Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
3e178e7c
Unverified
Commit
3e178e7c
authored
Feb 13, 2025
by
Muhammed Emin Ozturk
Committed by
GitHub
Feb 13, 2025
Browse files
Merge branch 'develop' into muozturk_bf16fp8_streamk
parents
27fb084f
0e5e29c4
Changes
122
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
81 additions
and
21 deletions
+81
-21
test/ck_tile/gemm/test_gemm_pipeline.cpp
test/ck_tile/gemm/test_gemm_pipeline.cpp
+10
-5
test/ck_tile/gemm/test_gemm_pipeline_util.hpp
test/ck_tile/gemm/test_gemm_pipeline_util.hpp
+71
-16
No files found.
test/ck_tile/gemm/test_gemm_pipeline.cpp
View file @
3e178e7c
...
...
@@ -17,22 +17,27 @@ using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
using
Interwave
=
ck_tile
::
integral_constant
<
ck_tile
::
GemmPipelineScheduler
,
ck_tile
::
GemmPipelineScheduler
::
Interwave
>
;
using
Mem
=
ck_tile
::
integral_constant
<
GemmPipelineType
,
GemmPipelineType
::
Mem
>
;
using
Comp
=
ck_tile
::
integral_constant
<
GemmPipelineType
,
GemmPipelineType
::
Comp
>
;
using
CompV3
=
ck_tile
::
integral_constant
<
GemmPipelineType
,
GemmPipelineType
::
CompV3
>
;
using
CompV4
=
ck_tile
::
integral_constant
<
GemmPipelineType
,
GemmPipelineType
::
CompV4
>
;
// clang-format off
using
KernelTypes
=
::
testing
::
Types
<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler, PipelineType
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Mem
>
,
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
>
,
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
CompV3
>
,
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
CompV4
>
,
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
,
Mem
>
,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Mem
>
,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
>
,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
CompV3
>
,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
CompV4
>
,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
,
Mem
>
,
std
::
tuple
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Mem
>
,
std
::
tuple
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
>
,
std
::
tuple
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
CompV3
>
,
std
::
tuple
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
CompV4
>
,
std
::
tuple
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
,
Mem
>
,
std
::
tuple
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Mem
>
,
std
::
tuple
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
>
,
std
::
tuple
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
CompV3
>
,
std
::
tuple
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
CompV4
>
,
std
::
tuple
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
,
Mem
>
>
;
// clang-format on
...
...
test/ck_tile/gemm/test_gemm_pipeline_util.hpp
View file @
3e178e7c
...
...
@@ -14,7 +14,32 @@
enum
struct
GemmPipelineType
{
Mem
,
Comp
CompV3
,
CompV4
};
template
<
GemmPipelineType
PT
,
typename
Problem
>
struct
GemmPipelineTypeSelector
;
template
<
typename
Problem
>
struct
GemmPipelineTypeSelector
<
GemmPipelineType
::
Mem
,
Problem
>
{
using
base_pipeline
=
ck_tile
::
BaseGemmPipelineAgBgCrMem
<
Problem
>
;
using
pipeline
=
ck_tile
::
GemmPipelineAgBgCrMem
<
Problem
>
;
};
template
<
typename
Problem
>
struct
GemmPipelineTypeSelector
<
GemmPipelineType
::
CompV3
,
Problem
>
{
using
base_pipeline
=
ck_tile
::
BaseGemmPipelineAgBgCrCompV3
<
Problem
>
;
using
pipeline
=
ck_tile
::
GemmPipelineAgBgCrCompV3
<
Problem
>
;
};
template
<
typename
Problem
>
struct
GemmPipelineTypeSelector
<
GemmPipelineType
::
CompV4
,
Problem
>
{
using
base_pipeline
=
ck_tile
::
BaseGemmPipelineAgBgCrCompV4
<
Problem
>
;
using
pipeline
=
ck_tile
::
GemmPipelineAgBgCrCompV4
<
Problem
>
;
};
template
<
typename
Tuple
>
...
...
@@ -36,8 +61,8 @@ class TestCkTileGemmPipeline : public ::testing::Test
void
invoke_gemm
(
const
ck_tile
::
GemmHostArgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
{
// TODO: This should be parameterized in tests
constexpr
ck_tile
::
index_t
M_Tile
=
128
;
constexpr
ck_tile
::
index_t
N_Tile
=
128
;
constexpr
ck_tile
::
index_t
M_Tile
=
256
;
constexpr
ck_tile
::
index_t
N_Tile
=
256
;
constexpr
ck_tile
::
index_t
K_Tile
=
32
;
constexpr
ck_tile
::
index_t
M_Warp
=
2
;
...
...
@@ -52,6 +77,8 @@ class TestCkTileGemmPipeline : public ::testing::Test
constexpr
bool
kPadN
=
PadN
;
constexpr
bool
kPadK
=
PadK
;
constexpr
bool
DoubleSmemBuffer
=
(
PipelineType
==
GemmPipelineType
::
CompV4
)
?
true
:
false
;
// TODO: For now - but this should also be a test parameter
constexpr
bool
TransposeC
=
false
;
...
...
@@ -69,16 +96,20 @@ class TestCkTileGemmPipeline : public ::testing::Test
GemmSpatiallyLocalTilePartitioner
<
GemmShape
,
TileParitionerGroupNum
,
TileParitionerM01
>
;
using
Traits
=
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
using
GemmUniversalTraits
=
ck_tile
::
TileGemmUniversalTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
,
TransposeC
>
;
using
GemmUniversalTraits
=
ck_tile
::
TileGemmUniversalTraits
<
kPadM
,
kPadN
,
kPadK
,
DoubleSmemBuffer
,
ALayout
,
BLayout
,
CLayout
,
TransposeC
>
;
using
GemmPipelineProblem
=
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
>
;
using
BaseGemmPipeline
=
std
::
conditional_t
<
PipelineType
==
GemmPipelineType
::
Mem
,
ck_tile
::
BaseGemmPipelineAgBgCrMem
<
GemmPipelineProblem
>
,
ck_tile
::
BaseGemmPipelineAgBgCrCompV3
<
GemmPipelineProblem
>>
;
typename
GemmPipelineTypeSelector
<
PipelineType
,
GemmPipelineProblem
>::
base_pipeline
;
const
ck_tile
::
index_t
k_grain
=
args
.
k_batch
*
K_Tile
;
const
ck_tile
::
index_t
K_split
=
(
args
.
K
+
k_grain
-
1
)
/
k_grain
*
K_Tile
;
...
...
@@ -99,12 +130,8 @@ class TestCkTileGemmPipeline : public ::testing::Test
has_hot_loop_v
,
tail_number_v
>
;
using
GemmPipeline
=
std
::
conditional_t
<
PipelineType
==
GemmPipelineType
::
Mem
,
ck_tile
::
GemmPipelineAgBgCrMem
<
UniversalGemmProblem
,
ck_tile
::
UniversalGemmPipelineAgBgCrPolicy
>
,
ck_tile
::
GemmPipelineAgBgCrCompV3
<
UniversalGemmProblem
,
ck_tile
::
UniversalGemmPipelineAgBgCrPolicy
>>
;
using
GemmPipeline
=
typename
GemmPipelineTypeSelector
<
PipelineType
,
UniversalGemmProblem
>::
pipeline
;
using
GemmEpilogue
=
ck_tile
::
CShuffleEpilogue
<
ck_tile
::
CShuffleEpilogueProblem
<
AccDataType
,
...
...
@@ -145,7 +172,7 @@ class TestCkTileGemmPipeline : public ::testing::Test
if
(
has_hot_loop
)
{
if
constexpr
(
PipelineType
==
GemmPipelineType
::
Comp
)
if
constexpr
(
PipelineType
==
GemmPipelineType
::
Comp
V3
)
{
if
(
tail_num
==
ck_tile
::
TailNumber
::
Full
)
{
...
...
@@ -235,6 +262,22 @@ class TestCkTileGemmPipeline : public ::testing::Test
}
}
}
if
constexpr
(
PipelineType
==
GemmPipelineType
::
CompV4
)
{
if
(
tail_num
==
ck_tile
::
TailNumber
::
Three
)
{
Run
(
ck_tile
::
bool_constant
<
true
>
{},
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
TailNumber
::
Three
>
{});
}
else
{
Run
(
ck_tile
::
bool_constant
<
true
>
{},
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
TailNumber
::
Two
>
{});
}
}
}
else
{
...
...
@@ -258,7 +301,19 @@ class TestCkTileGemmPipeline : public ::testing::Test
public:
std
::
vector
<
int
>
k_batches_
;
void
SetUp
()
override
{
k_batches_
=
{
1
,
2
};
}
void
SetUp
()
override
{
if
constexpr
(
PipelineType
==
GemmPipelineType
::
CompV4
)
{
// Only do k_batch = 1 when pipeline is CompV4
k_batches_
=
{
1
};
}
else
{
// Otherwise, use k_batch = 1 and 2
k_batches_
=
{
1
,
2
};
}
}
template
<
bool
PadM
=
true
,
bool
PadN
=
true
,
bool
PadK
=
true
>
void
Run
(
const
int
M
,
...
...
Prev
1
…
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment