Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
16467e0e
Commit
16467e0e
authored
Sep 02, 2022
by
wangshaojie6
Browse files
run gemm instance without code modification
parent
16f47b25
Changes
7
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
993 additions
and
67 deletions
+993
-67
example/01_gemm/gemm_xdl_fp16_splitk.cpp
example/01_gemm/gemm_xdl_fp16_splitk.cpp
+225
-41
include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle_small_gemm.hpp
...pu/device/device_gemm_xdl_splitk_c_shuffle_small_gemm.hpp
+2
-2
include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle_static.hpp
...on/gpu/device/device_gemm_xdl_splitk_c_shuffle_static.hpp
+43
-9
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
+1
-1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
...tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
+6
-2
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2_static.hpp
...operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2_static.hpp
+704
-0
include/ck/utility/common_header.hpp
include/ck/utility/common_header.hpp
+12
-12
No files found.
example/01_gemm/gemm_xdl_fp16_splitk.cpp
View file @
16467e0e
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle_small_gemm.hpp
View file @
16467e0e
...
@@ -224,7 +224,7 @@ struct DeviceGemmXdlSplitKCShuffleSmallGemm
...
@@ -224,7 +224,7 @@ struct DeviceGemmXdlSplitKCShuffleSmallGemm
using
CGridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
));
using
CGridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
));
// GridwiseGemm
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
<
using
GridwiseGemm
=
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
<
1
,
1
,
1
,
1
,
BlockSize
,
BlockSize
,
ADataType
,
// TODO: distinguish A/B datatype
ADataType
,
// TODO: distinguish A/B datatype
GemmAccDataType
,
GemmAccDataType
,
...
@@ -268,7 +268,7 @@ struct DeviceGemmXdlSplitKCShuffleSmallGemm
...
@@ -268,7 +268,7 @@ struct DeviceGemmXdlSplitKCShuffleSmallGemm
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
>
;
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
>
;
// GridwiseGemm
// GridwiseGemm
using
GridwiseGemmAtomicAdd
=
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
<
using
GridwiseGemmAtomicAdd
=
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
<
1
,
1
,
1
,
1
,
BlockSize
,
BlockSize
,
ADataType
,
// TODO: distinguish A/B datatype
ADataType
,
// TODO: distinguish A/B datatype
GemmAccDataType
,
GemmAccDataType
,
...
...
include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle_static.hpp
View file @
16467e0e
...
@@ -9,7 +9,7 @@
...
@@ -9,7 +9,7 @@
#include "tensor_layout.hpp"
#include "tensor_layout.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_v2r4r2.hpp"
#include "gridwise_gemm_xdlops_v2r4r2_static.hpp"
#include "gemm_specialization.hpp"
#include "gemm_specialization.hpp"
#ifndef CK_RUN_KERNEL_AND_TIME
#ifndef CK_RUN_KERNEL_AND_TIME
...
@@ -20,7 +20,11 @@ namespace ck {
...
@@ -20,7 +20,11 @@ namespace ck {
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
template
<
typename
ALayout
,
template
<
index_t
M_matrix
,
index_t
N_matrix
,
index_t
K_matrix
,
index_t
K_batch
,
typename
ALayout
,
typename
BLayout
,
typename
BLayout
,
typename
CLayout
,
typename
CLayout
,
typename
ADataType
,
typename
ADataType
,
...
@@ -246,7 +250,11 @@ struct DeviceGemmXdlSplitKCShuffleStatic
...
@@ -246,7 +250,11 @@ struct DeviceGemmXdlSplitKCShuffleStatic
using
CGridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
());
using
CGridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
());
// GridwiseGemm
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
<
using
GridwiseGemm
=
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2_static
<
M_matrix
,
N_matrix
,
K_matrix
,
K_batch
,
BlockSize
,
BlockSize
,
ADataType
,
// TODO: distinguish A/B datatype
ADataType
,
// TODO: distinguish A/B datatype
GemmAccDataType
,
GemmAccDataType
,
...
@@ -290,7 +298,11 @@ struct DeviceGemmXdlSplitKCShuffleStatic
...
@@ -290,7 +298,11 @@ struct DeviceGemmXdlSplitKCShuffleStatic
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
>
;
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
>
;
// GridwiseGemm
// GridwiseGemm
using
GridwiseGemmAtomicAdd
=
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
<
using
GridwiseGemmAtomicAdd
=
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2_static
<
M_matrix
,
N_matrix
,
K_matrix
,
K_batch
,
BlockSize
,
BlockSize
,
ADataType
,
// TODO: distinguish A/B datatype
ADataType
,
// TODO: distinguish A/B datatype
GemmAccDataType
,
GemmAccDataType
,
...
@@ -444,7 +456,7 @@ struct DeviceGemmXdlSplitKCShuffleStatic
...
@@ -444,7 +456,7 @@ struct DeviceGemmXdlSplitKCShuffleStatic
arg
.
block_2_ctile_map_
))
arg
.
block_2_ctile_map_
))
{
{
throw
std
::
runtime_error
(
throw
std
::
runtime_error
(
"wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 has invalid setting"
);
"wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2_static has invalid setting"
);
}
}
const
index_t
grid_size
=
const
index_t
grid_size
=
...
@@ -490,13 +502,35 @@ struct DeviceGemmXdlSplitKCShuffleStatic
...
@@ -490,13 +502,35 @@ struct DeviceGemmXdlSplitKCShuffleStatic
arg
.
b_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
c_element_op_
,
arg
.
block_2_ctile_map_
);
arg
.
block_2_ctile_map_
);
// check validity when using splitk
hipGetErrorString
(
hipMemset
(
arg
.
p_c_grid_
,
0
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
.
GetElementSpaceSize
()
*
sizeof
(
CDataType
)));
launch_and_time_kernel
({
stream_config
.
stream_id_
,
false
},
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
a_grid_desc_kbatch_k0_m_k1_
,
arg
.
b_grid_desc_kbatch_k0_n_k1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
block_2_ctile_map_
);
};
};
if
(
has_main_k0_block_loop
)
if
(
has_main_k0_block_loop
)
{
{
if
(
kbatch
==
1
)
if
(
kbatch
==
1
)
{
{
const
auto
kernel
=
kernel_gemm_xdlops_v2r4r2
<
const
auto
kernel
=
kernel_gemm_xdlops_v2r4r2
_static
<
GridwiseGemm
,
GridwiseGemm
,
ADataType
,
// TODO: distinguish A/B datatype
ADataType
,
// TODO: distinguish A/B datatype
CDataType
,
CDataType
,
...
@@ -514,7 +548,7 @@ struct DeviceGemmXdlSplitKCShuffleStatic
...
@@ -514,7 +548,7 @@ struct DeviceGemmXdlSplitKCShuffleStatic
}
}
else
else
{
{
const
auto
kernel
=
kernel_gemm_xdlops_v2r4r2
<
const
auto
kernel
=
kernel_gemm_xdlops_v2r4r2
_static
<
GridwiseGemmAtomicAdd
,
GridwiseGemmAtomicAdd
,
ADataType
,
// TODO: distinguish A/B datatype
ADataType
,
// TODO: distinguish A/B datatype
CDataType
,
CDataType
,
...
@@ -535,7 +569,7 @@ struct DeviceGemmXdlSplitKCShuffleStatic
...
@@ -535,7 +569,7 @@ struct DeviceGemmXdlSplitKCShuffleStatic
{
{
if
(
kbatch
==
1
)
if
(
kbatch
==
1
)
{
{
const
auto
kernel
=
kernel_gemm_xdlops_v2r4r2
<
const
auto
kernel
=
kernel_gemm_xdlops_v2r4r2
_static
<
GridwiseGemm
,
GridwiseGemm
,
ADataType
,
// TODO: distinguish A/B datatype
ADataType
,
// TODO: distinguish A/B datatype
CDataType
,
CDataType
,
...
@@ -553,7 +587,7 @@ struct DeviceGemmXdlSplitKCShuffleStatic
...
@@ -553,7 +587,7 @@ struct DeviceGemmXdlSplitKCShuffleStatic
}
}
else
else
{
{
const
auto
kernel
=
kernel_gemm_xdlops_v2r4r2
<
const
auto
kernel
=
kernel_gemm_xdlops_v2r4r2
_static
<
GridwiseGemmAtomicAdd
,
GridwiseGemmAtomicAdd
,
ADataType
,
// TODO: distinguish A/B datatype
ADataType
,
// TODO: distinguish A/B datatype
CDataType
,
CDataType
,
...
...
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
View file @
16467e0e
...
@@ -241,7 +241,7 @@ struct BlockToCTileMap_KSplit_M00_N0_M01Adapt
...
@@ -241,7 +241,7 @@ struct BlockToCTileMap_KSplit_M00_N0_M01Adapt
// 2D slices of column-vectors in 3D space
// 2D slices of column-vectors in 3D space
// This C-tile map dynamically adjusts M01 when C-tile index is out of range
// This C-tile map dynamically adjusts M01 when C-tile index is out of range
template
<
index_t
MPerBlock
,
index_t
NPerBlock
,
typename
CGridDesc_M_N
>
template
<
index_t
MPerBlock
,
index_t
NPerBlock
,
typename
CGridDesc_M_N
,
index_t
K_batch
>
struct
BlockToCTileMap_KSplit_M00_N0_M01Adapt_Static
struct
BlockToCTileMap_KSplit_M00_N0_M01Adapt_Static
{
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
View file @
16467e0e
...
@@ -77,7 +77,11 @@ __global__ void
...
@@ -77,7 +77,11 @@ __global__ void
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
}
template
<
index_t
BlockSize
,
template
<
index_t
M_matrix
,
index_t
N_matrix
,
index_t
K_matrix
,
index_t
K_batch
,
index_t
BlockSize
,
typename
FloatAB
,
typename
FloatAB
,
typename
FloatAcc
,
typename
FloatAcc
,
typename
FloatC
,
typename
FloatC
,
...
@@ -289,7 +293,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -289,7 +293,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
__host__
__device__
static
constexpr
auto
MakeCBlockClusterAdaptorStatic
(
__host__
__device__
static
constexpr
auto
MakeCBlockClusterAdaptorStatic
(
const
CMNGridDesc
&
c_m_n_grid_desc
)
const
CMNGridDesc
&
c_m_n_grid_desc
)
{
{
return
BlockToCTileMap_KSplit_M00_N0_M01Adapt_Static
<
MPerBlock
,
NPerBlock
,
CMNGridDesc
>
(
return
BlockToCTileMap_KSplit_M00_N0_M01Adapt_Static
<
MPerBlock
,
NPerBlock
,
CMNGridDesc
,
K_batch
>
(
c_m_n_grid_desc
);
c_m_n_grid_desc
);
}
}
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2_static.hpp
0 → 100644
View file @
16467e0e
This diff is collapsed.
Click to expand it.
include/ck/utility/common_header.hpp
View file @
16467e0e
...
@@ -50,7 +50,7 @@
...
@@ -50,7 +50,7 @@
#define USEING_STATIC_KERNEL 1
#define USEING_STATIC_KERNEL 1
#define MNKB_0_8 1
#define MNKB_0_8 0
#define MNKB_1_4 0
#define MNKB_1_4 0
#define MNKB_2_8 0
#define MNKB_2_8 0
#define MNKB_3_5 0
#define MNKB_3_5 0
...
@@ -60,23 +60,23 @@
...
@@ -60,23 +60,23 @@
#if MNKB_0_8
#if MNKB_0_8
#define M_matrix 16
#define M_matrix 16
#define N_matrix 4096
#define N_matrix 1152
#define K_matrix 12800
#define K_matrix 5120
#define K_batch 5
#define K_batch 8
#elif MNKB_1_4
#elif MNKB_1_4
#define M_matrix 16
#define M_matrix 16
#define N_matrix 4096
#define N_matrix 5120
#define K_matrix 12800
#define K_matrix 384
#define K_batch 5
#define K_batch 4
#elif MNKB_2_8
#elif MNKB_2_8
#define M_matrix 16
#define M_matrix 16
#define N_matrix 4096
#define N_matrix 1280
#define K_matrix 12800
#define K_matrix 5120
#define K_batch 5
#define K_batch 8
#elif MNKB_3_5
#elif MNKB_3_5
#define M_matrix 16
#define M_matrix 16
#define N_matrix 4096
#define N_matrix 5120
#define K_matrix 12800
#define K_matrix 1280
#define K_batch 5
#define K_batch 5
#elif MNKB_4_5
#elif MNKB_4_5
#define M_matrix 16
#define M_matrix 16
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment