Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
627a27bd
"examples/vscode:/vscode.git/clone" did not exist on "f07b697b8dd1fee39c7ca7d0d95d40910c1e724d"
Unverified
Commit
627a27bd
authored
Dec 17, 2024
by
jakpiase
Committed by
GitHub
Dec 17, 2024
Browse files
Added unit tests for CK Tile compute bound gemm pipeline (#1728)
parent
d46196f2
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
90 additions
and
26 deletions
+90
-26
test/ck_tile/gemm/CMakeLists.txt
test/ck_tile/gemm/CMakeLists.txt
+1
-1
test/ck_tile/gemm/test_gemm_pipeline.cpp
test/ck_tile/gemm/test_gemm_pipeline.cpp
+42
-0
test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc
test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc
+5
-5
test/ck_tile/gemm/test_gemm_pipeline_util.hpp
test/ck_tile/gemm/test_gemm_pipeline_util.hpp
+42
-20
No files found.
test/ck_tile/gemm/CMakeLists.txt
View file @
627a27bd
# Currently ck_tile is only built on gfx9
if
(
GPU_TARGETS MATCHES
"gfx9"
)
add_gtest_executable
(
test_ck_tile_gemm_
mem_
pipeline test_gemm_
mem_
pipeline.cpp
)
add_gtest_executable
(
test_ck_tile_gemm_pipeline test_gemm_pipeline.cpp
)
endif
()
test/ck_tile/gemm/test_gemm_
mem_
pipeline.cpp
→
test/ck_tile/gemm/test_gemm_pipeline.cpp
View file @
627a27bd
...
...
@@ -6,7 +6,7 @@
#include "gtest/gtest.h"
#include "ck_tile/host.hpp"
#include "test_gemm_
mem_
pipeline_util.hpp"
#include "test_gemm_pipeline_util.hpp"
using
F16
=
ck_tile
::
half_t
;
using
F32
=
float
;
...
...
@@ -16,21 +16,27 @@ using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
ck_tile
::
GemmPipelineScheduler
::
Intrawave
>
;
using
Interwave
=
ck_tile
::
integral_constant
<
ck_tile
::
GemmPipelineScheduler
,
ck_tile
::
GemmPipelineScheduler
::
Interwave
>
;
using
Mem
=
ck_tile
::
integral_constant
<
GemmPipelineType
,
GemmPipelineType
::
Mem
>
;
using
Comp
=
ck_tile
::
integral_constant
<
GemmPipelineType
,
GemmPipelineType
::
Comp
>
;
// clang-format off
using
KernelTypes
=
::
testing
::
Types
<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
>
,
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
>
,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
>
,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
>
,
std
::
tuple
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
>
,
std
::
tuple
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
>
,
std
::
tuple
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
>
,
std
::
tuple
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
>
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler, PipelineType
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Mem
>
,
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
>
,
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
,
Mem
>
,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Mem
>
,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
>
,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
,
Mem
>
,
std
::
tuple
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Mem
>
,
std
::
tuple
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
>
,
std
::
tuple
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
,
Mem
>
,
std
::
tuple
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Mem
>
,
std
::
tuple
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
>
,
std
::
tuple
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
,
Mem
>
>
;
// clang-format on
TYPED_TEST_SUITE
(
TestCkTileGemm
Mem
Pipeline
,
KernelTypes
);
TYPED_TEST_SUITE
(
TestCkTileGemmPipeline
,
KernelTypes
);
#include "test_gemm_
mem_
pipeline_ut_cases.inc"
#include "test_gemm_pipeline_ut_cases.inc"
test/ck_tile/gemm/test_gemm_
mem_
pipeline_ut_cases.inc
→
test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc
View file @
627a27bd
...
...
@@ -3,7 +3,7 @@
#pragma once
TYPED_TEST
(
TestCkTileGemm
Mem
Pipeline
,
SmallM
)
TYPED_TEST
(
TestCkTileGemmPipeline
,
SmallM
)
{
std
::
vector
<
int
>
Ms
{
1
,
2
,
3
,
4
,
5
,
6
};
constexpr
int
N
=
1024
;
...
...
@@ -13,7 +13,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, SmallM)
this
->
Run
(
M
,
N
,
K
);
}
TYPED_TEST
(
TestCkTileGemm
Mem
Pipeline
,
MidLargeM
)
TYPED_TEST
(
TestCkTileGemmPipeline
,
MidLargeM
)
{
std
::
vector
<
int
>
Ms
{
127
,
255
,
312
,
799
,
1573
};
constexpr
int
N
=
1024
;
...
...
@@ -23,7 +23,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, MidLargeM)
this
->
Run
(
M
,
N
,
K
);
}
TYPED_TEST
(
TestCkTileGemm
Mem
Pipeline
,
PaddK
)
TYPED_TEST
(
TestCkTileGemmPipeline
,
PaddK
)
{
std
::
vector
<
int
>
Ms
{
127
};
constexpr
int
N
=
1024
;
...
...
@@ -33,7 +33,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, PaddK)
this
->
Run
(
M
,
N
,
K
);
}
TYPED_TEST
(
TestCkTileGemm
Mem
Pipeline
,
Regular
)
TYPED_TEST
(
TestCkTileGemmPipeline
,
Regular
)
{
std
::
vector
<
int
>
Ms
{
512
};
constexpr
int
N
=
1024
;
...
...
@@ -43,7 +43,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, Regular)
this
->
Run
(
M
,
N
,
K
);
}
TYPED_TEST
(
TestCkTileGemm
Mem
Pipeline
,
NotSupportedArgument
)
TYPED_TEST
(
TestCkTileGemmPipeline
,
NotSupportedArgument
)
{
constexpr
int
M
=
512
;
constexpr
int
N
=
1025
;
...
...
test/ck_tile/gemm/test_gemm_
mem_
pipeline_util.hpp
→
test/ck_tile/gemm/test_gemm_pipeline_util.hpp
View file @
627a27bd
...
...
@@ -11,8 +11,13 @@
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
enum
struct
GemmPipelineType
{
Mem
,
Comp
};
template
<
typename
Tuple
>
class
TestCkTileGemm
Mem
Pipeline
:
public
::
testing
::
Test
class
TestCkTileGemmPipeline
:
public
::
testing
::
Test
{
protected:
using
ALayout
=
std
::
tuple_element_t
<
0
,
Tuple
>
;
...
...
@@ -23,6 +28,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
using
AccDataType
=
std
::
tuple_element_t
<
5
,
Tuple
>
;
using
CDataType
=
std
::
tuple_element_t
<
6
,
Tuple
>
;
static
constexpr
auto
Scheduler
=
std
::
tuple_element_t
<
7
,
Tuple
>::
value
;
static
constexpr
auto
PipelineType
=
std
::
tuple_element_t
<
8
,
Tuple
>::
value
;
// TODO: expose tile size through test t-param ?
struct
gemm_args
...
...
@@ -74,8 +80,13 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
using
Traits
=
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
using
BaseGemmPipeline
=
ck_tile
::
BaseGemmPipelineAgBgCrMem
<
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
>>
;
using
BaseGemmPipeline
=
std
::
conditional_t
<
PipelineType
==
GemmPipelineType
::
Mem
,
ck_tile
::
BaseGemmPipelineAgBgCrMem
<
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
>>
,
ck_tile
::
BaseGemmPipelineAgBgCrCompV3
<
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
>>>
;
const
ck_tile
::
index_t
num_loop
=
TilePartitioner
::
GetLoopNum
(
args
.
K
);
const
bool
has_hot_loop
=
BaseGemmPipeline
::
BlockHasHotloop
(
num_loop
);
...
...
@@ -85,7 +96,18 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
constexpr
bool
has_hot_loop_v
=
has_hot_loop_
.
value
;
constexpr
auto
tail_number_v
=
tail_number_
.
value
;
using
GemmPipeline
=
ck_tile
::
GemmPipelineAgBgCrMem
<
using
GemmPipeline
=
std
::
conditional_t
<
PipelineType
==
GemmPipelineType
::
Mem
,
ck_tile
::
GemmPipelineAgBgCrMem
<
ck_tile
::
UniversalGemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
,
Scheduler
,
has_hot_loop_v
,
tail_number_v
>>
,
ck_tile
::
GemmPipelineAgBgCrCompV3
<
ck_tile
::
UniversalGemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
...
...
@@ -93,7 +115,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
Traits
,
Scheduler
,
has_hot_loop_v
,
tail_number_v
>>
;
tail_number_v
>>
>
;
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
GemmPipeline
,
GemmEpilogue
>
;
auto
kargs
=
Kernel
::
MakeKargs
(
args
.
p_a
,
args
.
p_b
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment