Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
627a27bd
"...composable_kernel.git" did not exist on "cb6475c77d74f9d9f0a5fb2c0b80d5008fe420da"
Unverified
Commit
627a27bd
authored
Dec 17, 2024
by
jakpiase
Committed by
GitHub
Dec 17, 2024
Browse files
Added unit tests for CK Tile compute bound gemm pipeline (#1728)
parent
d46196f2
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
90 additions
and
26 deletions
+90
-26
test/ck_tile/gemm/CMakeLists.txt
test/ck_tile/gemm/CMakeLists.txt
+1
-1
test/ck_tile/gemm/test_gemm_pipeline.cpp
test/ck_tile/gemm/test_gemm_pipeline.cpp
+42
-0
test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc
test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc
+5
-5
test/ck_tile/gemm/test_gemm_pipeline_util.hpp
test/ck_tile/gemm/test_gemm_pipeline_util.hpp
+42
-20
No files found.
test/ck_tile/gemm/CMakeLists.txt
View file @
627a27bd
# Currently ck_tile is only built on gfx9
# Currently ck_tile is only built on gfx9
if
(
GPU_TARGETS MATCHES
"gfx9"
)
if
(
GPU_TARGETS MATCHES
"gfx9"
)
add_gtest_executable
(
test_ck_tile_gemm_
mem_
pipeline test_gemm_
mem_
pipeline.cpp
)
add_gtest_executable
(
test_ck_tile_gemm_pipeline test_gemm_pipeline.cpp
)
endif
()
endif
()
test/ck_tile/gemm/test_gemm_
mem_
pipeline.cpp
→
test/ck_tile/gemm/test_gemm_pipeline.cpp
View file @
627a27bd
...
@@ -6,7 +6,7 @@
...
@@ -6,7 +6,7 @@
#include "gtest/gtest.h"
#include "gtest/gtest.h"
#include "ck_tile/host.hpp"
#include "ck_tile/host.hpp"
#include "test_gemm_
mem_
pipeline_util.hpp"
#include "test_gemm_pipeline_util.hpp"
using
F16
=
ck_tile
::
half_t
;
using
F16
=
ck_tile
::
half_t
;
using
F32
=
float
;
using
F32
=
float
;
...
@@ -16,21 +16,27 @@ using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
...
@@ -16,21 +16,27 @@ using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
ck_tile
::
GemmPipelineScheduler
::
Intrawave
>
;
ck_tile
::
GemmPipelineScheduler
::
Intrawave
>
;
using
Interwave
=
ck_tile
::
integral_constant
<
ck_tile
::
GemmPipelineScheduler
,
using
Interwave
=
ck_tile
::
integral_constant
<
ck_tile
::
GemmPipelineScheduler
,
ck_tile
::
GemmPipelineScheduler
::
Interwave
>
;
ck_tile
::
GemmPipelineScheduler
::
Interwave
>
;
using
Mem
=
ck_tile
::
integral_constant
<
GemmPipelineType
,
GemmPipelineType
::
Mem
>
;
using
Comp
=
ck_tile
::
integral_constant
<
GemmPipelineType
,
GemmPipelineType
::
Comp
>
;
// clang-format off
// clang-format off
using
KernelTypes
=
::
testing
::
Types
<
using
KernelTypes
=
::
testing
::
Types
<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler, PipelineType
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
>
,
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Mem
>
,
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
>
,
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
>
,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
>
,
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
,
Mem
>
,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
>
,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Mem
>
,
std
::
tuple
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
>
,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
>
,
std
::
tuple
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
>
,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
,
Mem
>
,
std
::
tuple
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
>
,
std
::
tuple
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Mem
>
,
std
::
tuple
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
>
std
::
tuple
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
>
,
std
::
tuple
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
,
Mem
>
,
std
::
tuple
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Mem
>
,
std
::
tuple
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
>
,
std
::
tuple
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
,
Mem
>
>
;
>
;
// clang-format on
// clang-format on
TYPED_TEST_SUITE
(
TestCkTileGemm
Mem
Pipeline
,
KernelTypes
);
TYPED_TEST_SUITE
(
TestCkTileGemmPipeline
,
KernelTypes
);
#include "test_gemm_
mem_
pipeline_ut_cases.inc"
#include "test_gemm_pipeline_ut_cases.inc"
test/ck_tile/gemm/test_gemm_
mem_
pipeline_ut_cases.inc
→
test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc
View file @
627a27bd
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
#pragma once
#pragma once
TYPED_TEST
(
TestCkTileGemm
Mem
Pipeline
,
SmallM
)
TYPED_TEST
(
TestCkTileGemmPipeline
,
SmallM
)
{
{
std
::
vector
<
int
>
Ms
{
1
,
2
,
3
,
4
,
5
,
6
};
std
::
vector
<
int
>
Ms
{
1
,
2
,
3
,
4
,
5
,
6
};
constexpr
int
N
=
1024
;
constexpr
int
N
=
1024
;
...
@@ -13,7 +13,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, SmallM)
...
@@ -13,7 +13,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, SmallM)
this
->
Run
(
M
,
N
,
K
);
this
->
Run
(
M
,
N
,
K
);
}
}
TYPED_TEST
(
TestCkTileGemm
Mem
Pipeline
,
MidLargeM
)
TYPED_TEST
(
TestCkTileGemmPipeline
,
MidLargeM
)
{
{
std
::
vector
<
int
>
Ms
{
127
,
255
,
312
,
799
,
1573
};
std
::
vector
<
int
>
Ms
{
127
,
255
,
312
,
799
,
1573
};
constexpr
int
N
=
1024
;
constexpr
int
N
=
1024
;
...
@@ -23,7 +23,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, MidLargeM)
...
@@ -23,7 +23,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, MidLargeM)
this
->
Run
(
M
,
N
,
K
);
this
->
Run
(
M
,
N
,
K
);
}
}
TYPED_TEST
(
TestCkTileGemm
Mem
Pipeline
,
PaddK
)
TYPED_TEST
(
TestCkTileGemmPipeline
,
PaddK
)
{
{
std
::
vector
<
int
>
Ms
{
127
};
std
::
vector
<
int
>
Ms
{
127
};
constexpr
int
N
=
1024
;
constexpr
int
N
=
1024
;
...
@@ -33,7 +33,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, PaddK)
...
@@ -33,7 +33,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, PaddK)
this
->
Run
(
M
,
N
,
K
);
this
->
Run
(
M
,
N
,
K
);
}
}
TYPED_TEST
(
TestCkTileGemm
Mem
Pipeline
,
Regular
)
TYPED_TEST
(
TestCkTileGemmPipeline
,
Regular
)
{
{
std
::
vector
<
int
>
Ms
{
512
};
std
::
vector
<
int
>
Ms
{
512
};
constexpr
int
N
=
1024
;
constexpr
int
N
=
1024
;
...
@@ -43,7 +43,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, Regular)
...
@@ -43,7 +43,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, Regular)
this
->
Run
(
M
,
N
,
K
);
this
->
Run
(
M
,
N
,
K
);
}
}
TYPED_TEST
(
TestCkTileGemm
Mem
Pipeline
,
NotSupportedArgument
)
TYPED_TEST
(
TestCkTileGemmPipeline
,
NotSupportedArgument
)
{
{
constexpr
int
M
=
512
;
constexpr
int
M
=
512
;
constexpr
int
N
=
1025
;
constexpr
int
N
=
1025
;
...
...
test/ck_tile/gemm/test_gemm_
mem_
pipeline_util.hpp
→
test/ck_tile/gemm/test_gemm_pipeline_util.hpp
View file @
627a27bd
...
@@ -11,18 +11,24 @@
...
@@ -11,18 +11,24 @@
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm.hpp"
enum
struct
GemmPipelineType
{
Mem
,
Comp
};
template
<
typename
Tuple
>
template
<
typename
Tuple
>
class
TestCkTileGemm
Mem
Pipeline
:
public
::
testing
::
Test
class
TestCkTileGemmPipeline
:
public
::
testing
::
Test
{
{
protected:
protected:
using
ALayout
=
std
::
tuple_element_t
<
0
,
Tuple
>
;
using
ALayout
=
std
::
tuple_element_t
<
0
,
Tuple
>
;
using
BLayout
=
std
::
tuple_element_t
<
1
,
Tuple
>
;
using
BLayout
=
std
::
tuple_element_t
<
1
,
Tuple
>
;
using
CLayout
=
std
::
tuple_element_t
<
2
,
Tuple
>
;
using
CLayout
=
std
::
tuple_element_t
<
2
,
Tuple
>
;
using
ADataType
=
std
::
tuple_element_t
<
3
,
Tuple
>
;
using
ADataType
=
std
::
tuple_element_t
<
3
,
Tuple
>
;
using
BDataType
=
std
::
tuple_element_t
<
4
,
Tuple
>
;
using
BDataType
=
std
::
tuple_element_t
<
4
,
Tuple
>
;
using
AccDataType
=
std
::
tuple_element_t
<
5
,
Tuple
>
;
using
AccDataType
=
std
::
tuple_element_t
<
5
,
Tuple
>
;
using
CDataType
=
std
::
tuple_element_t
<
6
,
Tuple
>
;
using
CDataType
=
std
::
tuple_element_t
<
6
,
Tuple
>
;
static
constexpr
auto
Scheduler
=
std
::
tuple_element_t
<
7
,
Tuple
>::
value
;
static
constexpr
auto
Scheduler
=
std
::
tuple_element_t
<
7
,
Tuple
>::
value
;
static
constexpr
auto
PipelineType
=
std
::
tuple_element_t
<
8
,
Tuple
>::
value
;
// TODO: expose tile size through test t-param ?
// TODO: expose tile size through test t-param ?
struct
gemm_args
struct
gemm_args
...
@@ -74,8 +80,13 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
...
@@ -74,8 +80,13 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
using
Traits
=
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
using
Traits
=
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
using
BaseGemmPipeline
=
ck_tile
::
BaseGemmPipelineAgBgCrMem
<
using
BaseGemmPipeline
=
std
::
conditional_t
<
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
>>
;
PipelineType
==
GemmPipelineType
::
Mem
,
ck_tile
::
BaseGemmPipelineAgBgCrMem
<
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
>>
,
ck_tile
::
BaseGemmPipelineAgBgCrCompV3
<
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
>>>
;
const
ck_tile
::
index_t
num_loop
=
TilePartitioner
::
GetLoopNum
(
args
.
K
);
const
ck_tile
::
index_t
num_loop
=
TilePartitioner
::
GetLoopNum
(
args
.
K
);
const
bool
has_hot_loop
=
BaseGemmPipeline
::
BlockHasHotloop
(
num_loop
);
const
bool
has_hot_loop
=
BaseGemmPipeline
::
BlockHasHotloop
(
num_loop
);
...
@@ -85,15 +96,26 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
...
@@ -85,15 +96,26 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
constexpr
bool
has_hot_loop_v
=
has_hot_loop_
.
value
;
constexpr
bool
has_hot_loop_v
=
has_hot_loop_
.
value
;
constexpr
auto
tail_number_v
=
tail_number_
.
value
;
constexpr
auto
tail_number_v
=
tail_number_
.
value
;
using
GemmPipeline
=
ck_tile
::
GemmPipelineAgBgCrMem
<
using
GemmPipeline
=
ck_tile
::
UniversalGemmPipelineProblem
<
ADataType
,
std
::
conditional_t
<
PipelineType
==
GemmPipelineType
::
Mem
,
BDataType
,
ck_tile
::
GemmPipelineAgBgCrMem
<
AccDataType
,
ck_tile
::
UniversalGemmPipelineProblem
<
ADataType
,
GemmShape
,
BDataType
,
Traits
,
AccDataType
,
Scheduler
,
GemmShape
,
has_hot_loop_v
,
Traits
,
tail_number_v
>>
;
Scheduler
,
has_hot_loop_v
,
tail_number_v
>>
,
ck_tile
::
GemmPipelineAgBgCrCompV3
<
ck_tile
::
UniversalGemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
,
Scheduler
,
has_hot_loop_v
,
tail_number_v
>>>
;
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
GemmPipeline
,
GemmEpilogue
>
;
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
GemmPipeline
,
GemmEpilogue
>
;
auto
kargs
=
Kernel
::
MakeKargs
(
args
.
p_a
,
auto
kargs
=
Kernel
::
MakeKargs
(
args
.
p_a
,
args
.
p_b
,
args
.
p_b
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment