Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
5bedd21a
Commit
5bedd21a
authored
Dec 18, 2024
by
Aleksander Dudek
Browse files
[CK_TILE] Refactor GemmKernel - update tests
parent
13fe6e95
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
11 additions
and
45 deletions
+11
-45
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
+1
-1
test/ck_tile/batched_gemm/test_batched_gemm_util.hpp
test/ck_tile/batched_gemm/test_batched_gemm_util.hpp
+1
-13
test/ck_tile/gemm/test_gemm_pipeline_util.hpp
test/ck_tile/gemm/test_gemm_pipeline_util.hpp
+9
-31
No files found.
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
View file @
5bedd21a
...
@@ -95,7 +95,7 @@ struct GemmKernel
...
@@ -95,7 +95,7 @@ struct GemmKernel
index_t
stride_C
;
index_t
stride_C
;
};
};
CK_TILE_HOST
static
constexpr
GemmKernelArgs
MakeKernelArgs
(
GemmHostArgs
&
hostArgs
)
CK_TILE_HOST
static
constexpr
GemmKernelArgs
MakeKernelArgs
(
const
GemmHostArgs
&
hostArgs
)
{
{
return
GemmKernelArgs
{
hostArgs
.
a_ptr
,
return
GemmKernelArgs
{
hostArgs
.
a_ptr
,
hostArgs
.
b_ptr
,
hostArgs
.
b_ptr
,
...
...
test/ck_tile/batched_gemm/test_batched_gemm_util.hpp
View file @
5bedd21a
...
@@ -91,19 +91,7 @@ class TestCkTileBatchedGemm : public ::testing::Test
...
@@ -91,19 +91,7 @@ class TestCkTileBatchedGemm : public ::testing::Test
using
Kernel
=
using
Kernel
=
ck_tile
::
BatchedGemmKernel
<
TilePartitioner
,
CodegenGemmPipeline
,
GemmEpilogue
>
;
ck_tile
::
BatchedGemmKernel
<
TilePartitioner
,
CodegenGemmPipeline
,
GemmEpilogue
>
;
auto
kargs
=
Kernel
::
MakeKernelArgs
(
args
.
a_ptr
,
auto
kargs
=
Kernel
::
MakeKernelArgs
(
args
);
args
.
b_ptr
,
args
.
c_ptr
,
args
.
M
,
args
.
N
,
args
.
K
,
args
.
stride_A
,
args
.
stride_B
,
args
.
stride_C
,
args
.
batch_stride_A
,
args
.
batch_stride_B
,
args
.
batch_stride_C
,
args
.
batch_count
);
const
dim3
grids
=
Kernel
::
GridSize
(
args
.
M
,
args
.
N
,
args
.
batch_count
);
const
dim3
grids
=
Kernel
::
GridSize
(
args
.
M
,
args
.
N
,
args
.
batch_count
);
constexpr
dim3
blocks
=
Kernel
::
BlockSize
();
constexpr
dim3
blocks
=
Kernel
::
BlockSize
();
...
...
test/ck_tile/gemm/test_gemm_pipeline_util.hpp
View file @
5bedd21a
...
@@ -31,22 +31,8 @@ class TestCkTileGemmPipeline : public ::testing::Test
...
@@ -31,22 +31,8 @@ class TestCkTileGemmPipeline : public ::testing::Test
static
constexpr
auto
PipelineType
=
std
::
tuple_element_t
<
8
,
Tuple
>::
value
;
static
constexpr
auto
PipelineType
=
std
::
tuple_element_t
<
8
,
Tuple
>::
value
;
// TODO: expose tile size through test t-param ?
// TODO: expose tile size through test t-param ?
struct
gemm_args
{
const
void
*
p_a
;
const
void
*
p_b
;
void
*
p_c
;
ck_tile
::
index_t
kbatch
;
ck_tile
::
index_t
M
;
ck_tile
::
index_t
N
;
ck_tile
::
index_t
K
;
ck_tile
::
index_t
stride_A
;
ck_tile
::
index_t
stride_B
;
ck_tile
::
index_t
stride_C
;
};
template
<
bool
PadM
,
bool
PadN
,
bool
PadK
>
template
<
bool
PadM
,
bool
PadN
,
bool
PadK
>
void
invoke_gemm
(
const
gemm_a
rgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
void
invoke_gemm
(
const
ck_tile
::
GemmHostA
rgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
{
{
// TODO: This should be parameterized in tests
// TODO: This should be parameterized in tests
constexpr
ck_tile
::
index_t
M_Tile
=
128
;
constexpr
ck_tile
::
index_t
M_Tile
=
128
;
...
@@ -117,17 +103,9 @@ class TestCkTileGemmPipeline : public ::testing::Test
...
@@ -117,17 +103,9 @@ class TestCkTileGemmPipeline : public ::testing::Test
has_hot_loop_v
,
has_hot_loop_v
,
tail_number_v
>>>
;
tail_number_v
>>>
;
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
GemmPipeline
,
GemmEpilogue
>
;
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
GemmPipeline
,
GemmEpilogue
>
;
auto
kargs
=
Kernel
::
MakeKernelArgs
(
args
.
p_a
,
auto
kargs
=
Kernel
::
MakeKernelArgs
(
args
);
args
.
p_b
,
args
.
p_c
,
const
dim3
grids
=
Kernel
::
GridSize
(
args
.
M
,
args
.
N
,
args
.
k_batch
);
args
.
M
,
args
.
N
,
args
.
K
,
args
.
stride_A
,
args
.
stride_B
,
args
.
stride_C
);
const
dim3
grids
=
Kernel
::
GridSize
(
args
.
M
,
args
.
N
,
args
.
kbatch
);
constexpr
dim3
blocks
=
Kernel
::
BlockSize
();
constexpr
dim3
blocks
=
Kernel
::
BlockSize
();
if
(
!
Kernel
::
IsSupportedArgument
(
kargs
))
if
(
!
Kernel
::
IsSupportedArgument
(
kargs
))
...
@@ -319,11 +297,11 @@ class TestCkTileGemmPipeline : public ::testing::Test
...
@@ -319,11 +297,11 @@ class TestCkTileGemmPipeline : public ::testing::Test
c_m_n_dev_buf
.
SetZero
();
c_m_n_dev_buf
.
SetZero
();
c_m_n_dev_result
.
SetZero
();
c_m_n_dev_result
.
SetZero
();
gemm_a
rgs
args
;
ck_tile
::
GemmHostA
rgs
args
;
args
.
p_a
=
a_m_k_dev_buf
.
GetDeviceBuffer
();
args
.
a_ptr
=
a_m_k_dev_buf
.
GetDeviceBuffer
();
args
.
p_b
=
b_k_n_dev_buf
.
GetDeviceBuffer
();
args
.
b_ptr
=
b_k_n_dev_buf
.
GetDeviceBuffer
();
args
.
p_c
=
c_m_n_dev_buf
.
GetDeviceBuffer
();
args
.
c_ptr
=
c_m_n_dev_buf
.
GetDeviceBuffer
();
args
.
kbatch
=
kbatch
;
args
.
k
_
batch
=
kbatch
;
args
.
M
=
M
;
args
.
M
=
M
;
args
.
N
=
N
;
args
.
N
=
N
;
args
.
K
=
K
;
args
.
K
=
K
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment