Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
627a27bd
Unverified
Commit
627a27bd
authored
Dec 17, 2024
by
jakpiase
Committed by
GitHub
Dec 17, 2024
Browse files
Added unit tests for CK Tile compute bound gemm pipeline (#1728)
parent
d46196f2
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
90 additions
and
26 deletions
+90
-26
test/ck_tile/gemm/CMakeLists.txt
test/ck_tile/gemm/CMakeLists.txt
+1
-1
test/ck_tile/gemm/test_gemm_pipeline.cpp
test/ck_tile/gemm/test_gemm_pipeline.cpp
+42
-0
test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc
test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc
+5
-5
test/ck_tile/gemm/test_gemm_pipeline_util.hpp
test/ck_tile/gemm/test_gemm_pipeline_util.hpp
+42
-20
No files found.
test/ck_tile/gemm/CMakeLists.txt
View file @
627a27bd
# Currently ck_tile is only built on gfx9
if
(
GPU_TARGETS MATCHES
"gfx9"
)
add_gtest_executable
(
test_ck_tile_gemm_
mem_
pipeline test_gemm_
mem_
pipeline.cpp
)
add_gtest_executable
(
test_ck_tile_gemm_pipeline test_gemm_pipeline.cpp
)
endif
()
test/ck_tile/gemm/test_gemm_
mem_
pipeline.cpp
→
test/ck_tile/gemm/test_gemm_pipeline.cpp
View file @
627a27bd
...
...
@@ -6,7 +6,7 @@
#include "gtest/gtest.h"
#include "ck_tile/host.hpp"
#include "test_gemm_
mem_
pipeline_util.hpp"
#include "test_gemm_pipeline_util.hpp"
using
F16
=
ck_tile
::
half_t
;
using
F32
=
float
;
...
...
@@ -16,21 +16,27 @@ using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
ck_tile
::
GemmPipelineScheduler
::
Intrawave
>
;
using
Interwave
=
ck_tile
::
integral_constant
<
ck_tile
::
GemmPipelineScheduler
,
ck_tile
::
GemmPipelineScheduler
::
Interwave
>
;
using
Mem
=
ck_tile
::
integral_constant
<
GemmPipelineType
,
GemmPipelineType
::
Mem
>
;
using
Comp
=
ck_tile
::
integral_constant
<
GemmPipelineType
,
GemmPipelineType
::
Comp
>
;
// clang-format off
using
KernelTypes
=
::
testing
::
Types
<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
>
,
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
>
,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
>
,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
>
,
std
::
tuple
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
>
,
std
::
tuple
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
>
,
std
::
tuple
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
>
,
std
::
tuple
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
>
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler, PipelineType
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Mem
>
,
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
>
,
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
,
Mem
>
,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Mem
>
,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
>
,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
,
Mem
>
,
std
::
tuple
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Mem
>
,
std
::
tuple
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
>
,
std
::
tuple
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
,
Mem
>
,
std
::
tuple
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Mem
>
,
std
::
tuple
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
>
,
std
::
tuple
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
,
Mem
>
>
;
// clang-format on
TYPED_TEST_SUITE
(
TestCkTileGemm
Mem
Pipeline
,
KernelTypes
);
TYPED_TEST_SUITE
(
TestCkTileGemmPipeline
,
KernelTypes
);
#include "test_gemm_
mem_
pipeline_ut_cases.inc"
#include "test_gemm_pipeline_ut_cases.inc"
test/ck_tile/gemm/test_gemm_
mem_
pipeline_ut_cases.inc
→
test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc
View file @
627a27bd
...
...
@@ -3,7 +3,7 @@
#pragma once
TYPED_TEST
(
TestCkTileGemm
Mem
Pipeline
,
SmallM
)
TYPED_TEST
(
TestCkTileGemmPipeline
,
SmallM
)
{
std
::
vector
<
int
>
Ms
{
1
,
2
,
3
,
4
,
5
,
6
};
constexpr
int
N
=
1024
;
...
...
@@ -13,7 +13,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, SmallM)
this
->
Run
(
M
,
N
,
K
);
}
TYPED_TEST
(
TestCkTileGemm
Mem
Pipeline
,
MidLargeM
)
TYPED_TEST
(
TestCkTileGemmPipeline
,
MidLargeM
)
{
std
::
vector
<
int
>
Ms
{
127
,
255
,
312
,
799
,
1573
};
constexpr
int
N
=
1024
;
...
...
@@ -23,7 +23,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, MidLargeM)
this
->
Run
(
M
,
N
,
K
);
}
TYPED_TEST
(
TestCkTileGemm
Mem
Pipeline
,
PaddK
)
TYPED_TEST
(
TestCkTileGemmPipeline
,
PaddK
)
{
std
::
vector
<
int
>
Ms
{
127
};
constexpr
int
N
=
1024
;
...
...
@@ -33,7 +33,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, PaddK)
this
->
Run
(
M
,
N
,
K
);
}
TYPED_TEST
(
TestCkTileGemm
Mem
Pipeline
,
Regular
)
TYPED_TEST
(
TestCkTileGemmPipeline
,
Regular
)
{
std
::
vector
<
int
>
Ms
{
512
};
constexpr
int
N
=
1024
;
...
...
@@ -43,7 +43,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, Regular)
this
->
Run
(
M
,
N
,
K
);
}
TYPED_TEST
(
TestCkTileGemm
Mem
Pipeline
,
NotSupportedArgument
)
TYPED_TEST
(
TestCkTileGemmPipeline
,
NotSupportedArgument
)
{
constexpr
int
M
=
512
;
constexpr
int
N
=
1025
;
...
...
test/ck_tile/gemm/test_gemm_
mem_
pipeline_util.hpp
→
test/ck_tile/gemm/test_gemm_pipeline_util.hpp
View file @
627a27bd
...
...
@@ -11,18 +11,24 @@
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
enum
struct
GemmPipelineType
{
Mem
,
Comp
};
template
<
typename
Tuple
>
class
TestCkTileGemm
Mem
Pipeline
:
public
::
testing
::
Test
class
TestCkTileGemmPipeline
:
public
::
testing
::
Test
{
protected:
using
ALayout
=
std
::
tuple_element_t
<
0
,
Tuple
>
;
using
BLayout
=
std
::
tuple_element_t
<
1
,
Tuple
>
;
using
CLayout
=
std
::
tuple_element_t
<
2
,
Tuple
>
;
using
ADataType
=
std
::
tuple_element_t
<
3
,
Tuple
>
;
using
BDataType
=
std
::
tuple_element_t
<
4
,
Tuple
>
;
using
AccDataType
=
std
::
tuple_element_t
<
5
,
Tuple
>
;
using
CDataType
=
std
::
tuple_element_t
<
6
,
Tuple
>
;
static
constexpr
auto
Scheduler
=
std
::
tuple_element_t
<
7
,
Tuple
>::
value
;
using
ALayout
=
std
::
tuple_element_t
<
0
,
Tuple
>
;
using
BLayout
=
std
::
tuple_element_t
<
1
,
Tuple
>
;
using
CLayout
=
std
::
tuple_element_t
<
2
,
Tuple
>
;
using
ADataType
=
std
::
tuple_element_t
<
3
,
Tuple
>
;
using
BDataType
=
std
::
tuple_element_t
<
4
,
Tuple
>
;
using
AccDataType
=
std
::
tuple_element_t
<
5
,
Tuple
>
;
using
CDataType
=
std
::
tuple_element_t
<
6
,
Tuple
>
;
static
constexpr
auto
Scheduler
=
std
::
tuple_element_t
<
7
,
Tuple
>::
value
;
static
constexpr
auto
PipelineType
=
std
::
tuple_element_t
<
8
,
Tuple
>::
value
;
// TODO: expose tile size through test t-param ?
struct
gemm_args
...
...
@@ -74,8 +80,13 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
using
Traits
=
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
using
BaseGemmPipeline
=
ck_tile
::
BaseGemmPipelineAgBgCrMem
<
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
>>
;
using
BaseGemmPipeline
=
std
::
conditional_t
<
PipelineType
==
GemmPipelineType
::
Mem
,
ck_tile
::
BaseGemmPipelineAgBgCrMem
<
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
>>
,
ck_tile
::
BaseGemmPipelineAgBgCrCompV3
<
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
>>>
;
const
ck_tile
::
index_t
num_loop
=
TilePartitioner
::
GetLoopNum
(
args
.
K
);
const
bool
has_hot_loop
=
BaseGemmPipeline
::
BlockHasHotloop
(
num_loop
);
...
...
@@ -85,15 +96,26 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
constexpr
bool
has_hot_loop_v
=
has_hot_loop_
.
value
;
constexpr
auto
tail_number_v
=
tail_number_
.
value
;
using
GemmPipeline
=
ck_tile
::
GemmPipelineAgBgCrMem
<
ck_tile
::
UniversalGemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
,
Scheduler
,
has_hot_loop_v
,
tail_number_v
>>
;
using
GemmPipeline
=
std
::
conditional_t
<
PipelineType
==
GemmPipelineType
::
Mem
,
ck_tile
::
GemmPipelineAgBgCrMem
<
ck_tile
::
UniversalGemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
,
Scheduler
,
has_hot_loop_v
,
tail_number_v
>>
,
ck_tile
::
GemmPipelineAgBgCrCompV3
<
ck_tile
::
UniversalGemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
,
Scheduler
,
has_hot_loop_v
,
tail_number_v
>>>
;
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
GemmPipeline
,
GemmEpilogue
>
;
auto
kargs
=
Kernel
::
MakeKargs
(
args
.
p_a
,
args
.
p_b
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment