Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
6fe9e964
Commit
6fe9e964
authored
Dec 16, 2024
by
Adam Osewski
Browse files
Adapt example to refactor changes.
parent
241baec1
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
36 additions
and
15 deletions
+36
-15
example/ck_tile/03_gemm/run_gemm_example.inc
example/ck_tile/03_gemm/run_gemm_example.inc
+8
-10
example/ck_tile/03_gemm/universal_gemm.cpp
example/ck_tile/03_gemm/universal_gemm.cpp
+28
-5
No files found.
example/ck_tile/03_gemm/run_gemm_example.inc
View file @
6fe9e964
...
...
@@ -199,16 +199,14 @@ int run_gemm_example(int argc, char* argv[])
{
return
run_gemm_example_with_layouts
(
argc
,
argv
,
Row
{},
Col
{},
Row
{});
}
// TODO: Fixme: with latest changes to GemmPipelineAGmemBGmemCRegV1DefaultPolicy below do not
// work.
// else if(a_layout == "C" && b_layout == "C")
// {
// return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{});
// }
// else if(a_layout == "C" && b_layout == "R")
// {
// return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{});
// }
else
if
(
a_layout
==
"C"
&&
b_layout
==
"C"
)
{
return
run_gemm_example_with_layouts
(
argc
,
argv
,
Col
{},
Col
{},
Row
{});
}
else
if
(
a_layout
==
"C"
&&
b_layout
==
"R"
)
{
return
run_gemm_example_with_layouts
(
argc
,
argv
,
Col
{},
Row
{},
Row
{});
}
else
{
throw
std
::
runtime_error
(
"Unsupported data layout configuration for A,B and C tensors!"
);
...
...
example/ck_tile/03_gemm/universal_gemm.cpp
View file @
6fe9e964
...
...
@@ -37,8 +37,8 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
constexpr
ck_tile
::
index_t
M_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
8
;
#
el
if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
#endif
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
// Compute friendly for Intrawave scheduler
constexpr
ck_tile
::
index_t
M_Tile
=
256
;
constexpr
ck_tile
::
index_t
N_Tile
=
256
;
...
...
@@ -57,6 +57,8 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
constexpr
bool
kPadN
=
false
;
constexpr
bool
kPadK
=
false
;
constexpr
bool
TransposeC
=
false
;
constexpr
int
kBlockPerCu
=
1
;
// ===============================================
...
...
@@ -71,10 +73,13 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadM
,
kPadN
>>
;
using
Traits
=
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
using
GemmUniversalTraits
=
ck_tile
::
TileGemmUniversalTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
,
TransposeC
>
;
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
using
BaseGemmPipeline
=
ck_tile
::
BaseGemmPipelineAgBgCrMem
<
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
using
BaseGemmPipeline
=
ck_tile
::
BaseGemmPipelineAgBgCrCompV3
<
using
BaseGemmPipeline
=
ck_tile
::
BaseGemmPipelineAgBgCrCompV3
<
#endif
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
>>
;
...
...
@@ -97,14 +102,16 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
BDataType
,
AccDataType
,
GemmShape
,
Traits
,
GemmUniversal
Traits
,
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
ck_tile
::
GemmPipelineScheduler
::
Interwave
,
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
ck_tile
::
GemmPipelineScheduler
::
Intrawave
,
#endif
has_hot_loop_v
,
tail_number_v
>>
;
tail_number_v
>
,
ck_tile
::
UniversalGemmPipelineAgBgCrPolicy
>
;
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
GemmPipeline
,
GemmEpilogue
>
;
auto
kargs
=
Kernel
::
MakeKargs
(
args
.
p_a
,
args
.
p_b
,
...
...
@@ -139,6 +146,21 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
if
(
has_hot_loop
)
{
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
if
(
tail_num
==
ck_tile
::
TailNumber
::
Full
)
{
Run
(
ck_tile
::
bool_constant
<
true
>
{},
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
TailNumber
::
Full
>
{});
}
else
{
std
::
ostringstream
err
;
err
<<
"For compute pipeline tail number should always be Full, but have
\"
"
<<
tail_num
<<
"
\"
which is not supported! PrefetchStages: "
<<
BaseGemmPipeline
::
PrefetchStages
<<
"
\n
File: "
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
;
throw
std
::
runtime_error
(
err
.
str
());
}
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
// Tail pipeline One to Seven
if
(
tail_num
==
ck_tile
::
TailNumber
::
One
)
{
...
...
@@ -199,6 +221,7 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
TailNumber
::
Seven
>
{});
}
}
#endif
}
else
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment