Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
8385597f
Commit
8385597f
authored
Dec 13, 2024
by
Aleksander Dudek
Browse files
[CK_TILE] Refactor GemmKernel - update tests
parent
75535dd8
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
77 additions
and
67 deletions
+77
-67
example/ck_tile/16_batched_gemm/batched_gemm.hpp
example/ck_tile/16_batched_gemm/batched_gemm.hpp
+0
-32
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
+2
-3
include/ck_tile/ops/gemm/problem/gemm_problem.hpp
include/ck_tile/ops/gemm/problem/gemm_problem.hpp
+32
-0
test/ck_tile/batched_gemm/test_batched_gemm_util.hpp
test/ck_tile/batched_gemm/test_batched_gemm_util.hpp
+34
-23
test/ck_tile/gemm/test_gemm_mem_pipeline_util.hpp
test/ck_tile/gemm/test_gemm_mem_pipeline_util.hpp
+9
-9
No files found.
example/ck_tile/16_batched_gemm/batched_gemm.hpp
View file @
8385597f
...
@@ -30,38 +30,6 @@ using BDataType = Types::BDataType;
...
@@ -30,38 +30,6 @@ using BDataType = Types::BDataType;
using
AccDataType
=
Types
::
AccDataType
;
using
AccDataType
=
Types
::
AccDataType
;
using
CDataType
=
Types
::
CDataType
;
using
CDataType
=
Types
::
CDataType
;
struct
BatchedGemmHostArgs
:
public
ck_tile
::
GemmHostArgs
{
CK_TILE_HOST
BatchedGemmHostArgs
()
=
default
;
CK_TILE_HOST
BatchedGemmHostArgs
(
const
void
*
a_ptr_
,
const
void
*
b_ptr_
,
void
*
c_ptr_
,
ck_tile
::
index_t
k_batch_
,
ck_tile
::
index_t
M_
,
ck_tile
::
index_t
N_
,
ck_tile
::
index_t
K_
,
ck_tile
::
index_t
stride_A_
,
ck_tile
::
index_t
stride_B_
,
ck_tile
::
index_t
stride_C_
,
ck_tile
::
index_t
batch_stride_A_
,
ck_tile
::
index_t
batch_stride_B_
,
ck_tile
::
index_t
batch_stride_C_
,
ck_tile
::
index_t
batch_count_
)
:
GemmHostArgs
(
a_ptr_
,
b_ptr_
,
c_ptr_
,
k_batch_
,
M_
,
N_
,
K_
,
stride_A_
,
stride_B_
,
stride_C_
),
batch_stride_A
(
batch_stride_A_
),
batch_stride_B
(
batch_stride_B_
),
batch_stride_C
(
batch_stride_C_
),
batch_count
(
batch_count_
)
{
}
ck_tile
::
index_t
batch_stride_A
;
ck_tile
::
index_t
batch_stride_B
;
ck_tile
::
index_t
batch_stride_C
;
ck_tile
::
index_t
batch_count
;
};
auto
create_args
(
int
argc
,
char
*
argv
[])
auto
create_args
(
int
argc
,
char
*
argv
[])
{
{
ck_tile
::
ArgParser
arg_parser
;
ck_tile
::
ArgParser
arg_parser
;
...
...
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
View file @
8385597f
...
@@ -299,7 +299,8 @@ struct GemmKernel
...
@@ -299,7 +299,8 @@ struct GemmKernel
}
}
/**
/**
* Create tensor views, pad views, tile windows, run gemm and epilogue pipeline
* Create tensor views, pad views, tile windows.
* Runs GEMM cooperatively by whole workgroup with CShuffle or Default 2D Epilogue
*
*
* @param a_ptr input A pointer
* @param a_ptr input A pointer
* @param b_ptr input B pointer
* @param b_ptr input B pointer
...
@@ -307,8 +308,6 @@ struct GemmKernel
...
@@ -307,8 +308,6 @@ struct GemmKernel
* @param kargs GEMM kernel arguments
* @param kargs GEMM kernel arguments
* @param block_idx_m M block index
* @param block_idx_m M block index
* @param block_idx_n N block index
* @param block_idx_n N block index
*
* @return Runs GEMM cooperatively by whole workgroup with CShuffle or Default 2D Epilogue
*/
*/
CK_TILE_DEVICE
void
RunGemm
(
const
ADataType
*
a_ptr
,
CK_TILE_DEVICE
void
RunGemm
(
const
ADataType
*
a_ptr
,
const
BDataType
*
b_ptr
,
const
BDataType
*
b_ptr
,
...
...
include/ck_tile/ops/gemm/problem/gemm_problem.hpp
View file @
8385597f
...
@@ -53,4 +53,36 @@ struct GemmHostArgs : public Problem
...
@@ -53,4 +53,36 @@ struct GemmHostArgs : public Problem
index_t
k_batch
;
index_t
k_batch
;
};
};
struct
BatchedGemmHostArgs
:
public
ck_tile
::
GemmHostArgs
{
CK_TILE_HOST
BatchedGemmHostArgs
()
=
default
;
CK_TILE_HOST
BatchedGemmHostArgs
(
const
void
*
a_ptr_
,
const
void
*
b_ptr_
,
void
*
c_ptr_
,
ck_tile
::
index_t
k_batch_
,
ck_tile
::
index_t
M_
,
ck_tile
::
index_t
N_
,
ck_tile
::
index_t
K_
,
ck_tile
::
index_t
stride_A_
,
ck_tile
::
index_t
stride_B_
,
ck_tile
::
index_t
stride_C_
,
ck_tile
::
index_t
batch_stride_A_
,
ck_tile
::
index_t
batch_stride_B_
,
ck_tile
::
index_t
batch_stride_C_
,
ck_tile
::
index_t
batch_count_
)
:
GemmHostArgs
(
a_ptr_
,
b_ptr_
,
c_ptr_
,
k_batch_
,
M_
,
N_
,
K_
,
stride_A_
,
stride_B_
,
stride_C_
),
batch_stride_A
(
batch_stride_A_
),
batch_stride_B
(
batch_stride_B_
),
batch_stride_C
(
batch_stride_C_
),
batch_count
(
batch_count_
)
{
}
ck_tile
::
index_t
batch_stride_A
;
ck_tile
::
index_t
batch_stride_B
;
ck_tile
::
index_t
batch_stride_C
;
ck_tile
::
index_t
batch_count
;
};
}
// namespace ck_tile
}
// namespace ck_tile
test/ck_tile/batched_gemm/test_batched_gemm_util.hpp
View file @
8385597f
...
@@ -11,6 +11,7 @@
...
@@ -11,6 +11,7 @@
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/problem/gemm_problem.hpp"
template
<
typename
Tuple
>
template
<
typename
Tuple
>
class
TestCkTileBatchedGemm
:
public
::
testing
::
Test
class
TestCkTileBatchedGemm
:
public
::
testing
::
Test
...
@@ -24,12 +25,9 @@ class TestCkTileBatchedGemm : public ::testing::Test
...
@@ -24,12 +25,9 @@ class TestCkTileBatchedGemm : public ::testing::Test
using
AccDataType
=
std
::
tuple_element_t
<
5
,
Tuple
>
;
using
AccDataType
=
std
::
tuple_element_t
<
5
,
Tuple
>
;
using
CDataType
=
std
::
tuple_element_t
<
6
,
Tuple
>
;
using
CDataType
=
std
::
tuple_element_t
<
6
,
Tuple
>
;
struct
batched_gemm_kargs
:
public
ck_tile
::
BatchedGemmHostArgs
{
};
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
void
invoke_batched_gemm
(
const
batched_gemm_kargs
&
args
,
const
ck_tile
::
stream_config
&
s
)
void
invoke_batched_gemm
(
const
ck_tile
::
BatchedGemmHostArgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
{
{
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr
bool
kPadM
=
false
;
constexpr
bool
kPadM
=
false
;
...
@@ -94,9 +92,21 @@ class TestCkTileBatchedGemm : public ::testing::Test
...
@@ -94,9 +92,21 @@ class TestCkTileBatchedGemm : public ::testing::Test
using
Kernel
=
using
Kernel
=
ck_tile
::
BatchedGemmKernel
<
TilePartitioner
,
CodegenGemmPipeline
,
GemmEpilogue
>
;
ck_tile
::
BatchedGemmKernel
<
TilePartitioner
,
CodegenGemmPipeline
,
GemmEpilogue
>
;
auto
kargs
=
Kernel
::
MakeKargs
(
args
);
auto
kargs
=
Kernel
::
MakeKernelArgs
(
args
.
a_ptr
,
args
.
b_ptr
,
const
dim3
grids
=
Kernel
::
GridSize
(
args
);
args
.
c_ptr
,
args
.
M
,
args
.
N
,
args
.
K
,
args
.
stride_A
,
args
.
stride_B
,
args
.
stride_C
,
args
.
batch_stride_A
,
args
.
batch_stride_B
,
args
.
batch_stride_C
,
args
.
batch_count
);
const
dim3
grids
=
Kernel
::
GridSize
(
args
.
M
,
args
.
N
,
args
.
batch_count
);
constexpr
dim3
blocks
=
Kernel
::
BlockSize
();
constexpr
dim3
blocks
=
Kernel
::
BlockSize
();
if
(
s
.
log_level_
>
0
)
if
(
s
.
log_level_
>
0
)
...
@@ -185,21 +195,22 @@ class TestCkTileBatchedGemm : public ::testing::Test
...
@@ -185,21 +195,22 @@ class TestCkTileBatchedGemm : public ::testing::Test
c_m_n_dev_buf
.
SetZero
();
c_m_n_dev_buf
.
SetZero
();
c_m_n_dev_result
.
SetZero
();
c_m_n_dev_result
.
SetZero
();
batched_gemm_kargs
kargs
{
a_m_k_dev_buf
.
GetDeviceBuffer
(),
ck_tile
::
BatchedGemmHostArgs
args
;
b_k_n_dev_buf
.
GetDeviceBuffer
(),
args
.
a_ptr
=
a_m_k_dev_buf
.
GetDeviceBuffer
();
c_m_n_dev_buf
.
GetDeviceBuffer
(),
args
.
b_ptr
=
b_k_n_dev_buf
.
GetDeviceBuffer
();
M
,
args
.
c_ptr
=
c_m_n_dev_buf
.
GetDeviceBuffer
();
N
,
args
.
M
=
M
;
K
,
args
.
N
=
N
;
StrideA
,
args
.
K
=
K
;
StrideB
,
args
.
stride_A
=
StrideA
;
StrideC
,
args
.
stride_B
=
StrideB
;
BatchStrideA
,
args
.
stride_C
=
StrideC
;
BatchStrideB
,
args
.
batch_stride_A
=
BatchStrideA
;
BatchStrideC
,
args
.
batch_stride_B
=
BatchStrideB
;
BatchCount
};
args
.
batch_stride_C
=
BatchStrideC
;
args
.
batch_count
=
BatchCount
;
invoke_batched_gemm
<
ALayout
,
BLayout
,
CLayout
>
(
kargs
,
invoke_batched_gemm
<
ALayout
,
BLayout
,
CLayout
>
(
args
,
ck_tile
::
stream_config
{
nullptr
,
false
});
ck_tile
::
stream_config
{
nullptr
,
false
});
std
::
cout
<<
"Run kernel with M ="
<<
M
<<
" N ="
<<
N
<<
" K ="
<<
K
std
::
cout
<<
"Run kernel with M ="
<<
M
<<
" N ="
<<
N
<<
" K ="
<<
K
...
...
test/ck_tile/gemm/test_gemm_mem_pipeline_util.hpp
View file @
8385597f
...
@@ -95,7 +95,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
...
@@ -95,7 +95,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
has_hot_loop_v
,
has_hot_loop_v
,
tail_number_v
>>
;
tail_number_v
>>
;
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
GemmPipeline
,
GemmEpilogue
>
;
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
GemmPipeline
,
GemmEpilogue
>
;
auto
kargs
=
Kernel
::
MakeK
a
rgs
(
args
.
p_a
,
auto
kargs
=
Kernel
::
MakeK
ernelA
rgs
(
args
.
p_a
,
args
.
p_b
,
args
.
p_b
,
args
.
p_c
,
args
.
p_c
,
args
.
M
,
args
.
M
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment