gaoqiong / composable_kernel / Commits / 579f84c6

Commit 579f84c6, authored Mar 06, 2023 by aska-0096
Commit message: tempsave
Parent commit: 7e003d31

Showing 9 changed files with 395 additions and 119 deletions.
Changed files:
  example/01_gemm/gemm_wmma_fp16.cpp                                                                  +4    -4
  example/01_gemm/run_gemm_example.inc                                                                +1    -1
  example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp    +1    -1
  include/ck/host_utility/kernel_launch.hpp                                                           +2    -2
  include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp                                       +2    -3
  include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp                                    +3    -0
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp                     +374  -107
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp                                  +5    -0
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp                                         +3    -1
example/01_gemm/gemm_wmma_fp16.cpp

@@ -37,13 +37,13 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
         GemmDefault,
         256,            // BlockSize
         128,            // MPerBlock
-        16,             // NPerBlock
+        128,            // NPerBlock
         32,             // KPerBlock
         8,              // K1
         16,             // MPerWmma
         16,             // NPerWmma
-        1,              // M Repeat
-        1,              // N-Repeat
+        2,              // M Repeat
+        4,              // N-Repeat
         S<4, 64, 1>,
         S<1, 0, 2>,
         S<1, 0, 2>,
@@ -60,7 +60,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
         true,
         1,              // C shuffle (M Repeat) Per store
         1,              // C shuffle (N Repeat) Per store
-        S<1, 128, 1, 2>,
+        S<1, 64, 1, 4>,
         8>;
// clang-format on
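Note: the retuned instance still launches 256 threads per block; using the MWaves/NWaves formulas introduced later in this commit (and assuming the gfx11 wave size of 32), the 128x128 tile with MRepeat 2 and NRepeat 4 maps onto 4 x 2 waves. A quick illustrative check, not code from the repository:

    // Illustrative check of the retuned tile (values from the diff above; wave size is an assumption).
    #include <cstdio>

    int main()
    {
        constexpr int BlockSize = 256;
        constexpr int MPerBlock = 128, NPerBlock = 128;
        constexpr int MPerWmma  = 16,  NPerWmma  = 16;
        constexpr int MRepeat   = 2,   NRepeat   = 4;
        constexpr int WaveSize  = 32; // assumption: gfx11 wave32

        // Same formulas as blockwise_gemm_wmma.hpp / the gridwise header in this commit.
        constexpr int MWaves = MPerBlock / (MRepeat * MPerWmma); // 128 / 32 = 4
        constexpr int NWaves = NPerBlock / (NRepeat * NPerWmma); // 128 / 64 = 2

        static_assert(MWaves * NWaves * WaveSize == BlockSize, "tile does not fill the thread block");
        std::printf("MWaves=%d NWaves=%d -> %d threads\n", MWaves, NWaves, MWaves * NWaves * WaveSize);
    }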
example/01_gemm/run_gemm_example.inc

@@ -44,7 +44,7 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
         ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
         break;
     case 4:
-        ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
+        ck::utils::FillUniformDistributionIntegerValue<ADataType>{1.f, 1.f}(a_m_k);
         ck::utils::FillUniformDistributionIntegerValue<BDataType>{1.f, 1.f}(b_k_n);
         break;
     default:
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp

@@ -129,7 +129,7 @@ using DeviceGemmInstance =
         S<0, 2, 1>,
         1,
         8,
-        1,
+        1,     // be eight?
         false,
         1,     // CShuffleMWmmaPerWavePerShuffle
         2,     // CShuffleNWmmaPerWavePerShuffle
include/ck/host_utility/kernel_launch.hpp

@@ -33,9 +33,9 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
         printf("Warm up 1 time\n");
 #endif
         // warm up
-        kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
+        // kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);

-        const int nrepeat = 100;
+        const int nrepeat = 1;
 #if DEBUG_LOG
         printf("Start running %d times...\n", nrepeat);
 #endif
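The change above comments out the warm-up launch and drops nrepeat from 100 to 1, so the reported figure is effectively the time of a single cold launch. For reference, the usual HIP event-based timing pattern such a helper is built around looks like this (a generic sketch, not the CK implementation):

    // Generic HIP event timing sketch; names and structure are illustrative only.
    #include <hip/hip_runtime.h>

    template <typename Kernel, typename... Args>
    float time_kernel(Kernel kernel, dim3 grid, dim3 block, size_t lds_byte, hipStream_t stream,
                      int nrepeat, Args... args)
    {
        hipEvent_t start, stop;
        hipEventCreate(&start);
        hipEventCreate(&stop);

        hipEventRecord(start, stream);
        for(int i = 0; i < nrepeat; ++i)
            kernel<<<grid, block, lds_byte, stream>>>(args...);
        hipEventRecord(stop, stream);
        hipEventSynchronize(stop);

        float total_ms = 0.f;
        hipEventElapsedTime(&total_ms, start, stop);
        hipEventDestroy(start);
        hipEventDestroy(stop);
        return total_ms / nrepeat; // average milliseconds per launch
    }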
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp

@@ -27,6 +27,8 @@ template <index_t BlockSize,
           index_t MRepeat,
           index_t NRepeat,
           index_t KPack,
+          bool AEnableLds = true,
+          bool BEnableLds = true,
           bool TransposeC = false>
 /* Option: Read from LDS, big buffer hold all threads required data
  * Source
@@ -83,9 +85,6 @@ struct BlockwiseGemmWMMA
     static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
     static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);
-    static constexpr bool AEnableLds = NWaves == 1 ? false : true;
-    static constexpr bool BEnableLds = MWaves == 1 ? false : true;
-
     // Read from Lds, duplicate Twice, Read from VGPR, no duplication.
     static constexpr index_t A_Data_Duplicated_Rate = AEnableLds ? 2 : 1;
     static constexpr index_t B_Data_Duplicated_Rate = BEnableLds ? 2 : 1;
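Taken together with device_gemm_wmma.hpp below, this moves the AEnableLds/BEnableLds decision out of BlockwiseGemmWMMA and into the caller: the flags become template parameters (defaulting to true) instead of constants derived inside the struct. The heuristic itself is unchanged and can be restated as follows (an illustrative condensation of the removed lines, not repository code):

    // Illustrative restatement of the LDS-staging heuristic that moved out of BlockwiseGemmWMMA.
    #include <cstdint>
    using index_t = std::int32_t;

    constexpr bool enable_lds_for_A(index_t NPerBlock, index_t NRepeat, index_t NPerWmma)
    {
        // A is reused by every wave along N; with a single N-wave there is no
        // cross-wave reuse, so A can stay in VGPRs instead of being staged in LDS.
        return (NPerBlock / (NRepeat * NPerWmma)) != 1;
    }

    constexpr bool enable_lds_for_B(index_t MPerBlock, index_t MRepeat, index_t MPerWmma)
    {
        return (MPerBlock / (MRepeat * MPerWmma)) != 1;
    }

    static_assert(enable_lds_for_A(128, 4, 16), "NWaves = 2, so A goes through LDS");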
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp

@@ -89,6 +89,9 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
     static constexpr auto AEnableLds = NWaves == 1 ? false : true;
     static constexpr auto BEnableLds = MWaves == 1 ? false : true;
+    // static constexpr auto AEnableLds = true;
+    // static constexpr auto BEnableLds = true;
+
     static constexpr auto matrix_padder =
         MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
     // Describe how data read from Global memory
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp

@@ -45,7 +45,7 @@ __global__ void
             const BElementwiseOperation b_element_op,
             const CDEElementwiseOperation cde_element_op,
             const index_t batch_count,
-            const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
+            const AGridDesc_AK0_M_AK1 a_grid_desc,
             const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
             const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                 ds_grid_desc_mblock_mperblock_nblock_nperblock,
@@ -84,7 +84,7 @@ __global__ void
             p_ds_grid_grp,
             p_e_grid + e_batch_offset,
             p_shared,
-            a_grid_desc_k0_m_k1,
+            a_grid_desc,
             b_grid_desc_k0_n_k1,
             ds_grid_desc_mblock_mperblock_nblock_nperblock,
             e_grid_desc_mblock_mperblock_nblock_nperblock_,
@@ -98,7 +98,7 @@ __global__ void
     ignore = p_ds_grid;
     ignore = p_e_grid;
     ignore = batch_count;
-    ignore = a_grid_desc_k0_m_k1;
+    ignore = a_grid_desc;
     ignore = b_grid_desc_k0_n_k1;
     ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
     ignore = e_grid_desc_mblock_mperblock_nblock_nperblock_;
@@ -115,7 +115,7 @@ template <typename GridwiseOp,
           typename BDataType,
           typename DsPointer,
           typename EDataType,
-          typename AGridDesc_K0_M_K1,
+          typename AGridDesc,
           typename BGridDesc_K0_N_K1,
           typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
           typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -135,7 +135,7 @@ __global__ void
             DsPointer p_ds_grid,
             EDataType* __restrict__ p_e_grid,
             const index_t batch_count,
-            const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
+            const AGridDesc a_grid_desc,
             const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
             const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                 ds_grid_desc_mblock_mperblock_nblock_nperblock,
@@ -176,7 +176,7 @@ __global__ void
             p_ds_grid_grp,
             p_e_grid + e_batch_offset,
             p_shared,
-            a_grid_desc_k0_m_k1,
+            a_grid_desc,
             b_grid_desc_k0_n_k1,
             ds_grid_desc_mblock_mperblock_nblock_nperblock,
             e_grid_desc_mblock_mperblock_nblock_nperblock,
@@ -193,7 +193,7 @@ __global__ void
     ignore = a_element_op;
     ignore = b_element_op;
     ignore = cde_element_op;
-    ignore = a_grid_desc_k0_m_k1;
+    ignore = a_grid_desc;
     ignore = b_grid_desc_k0_n_k1;
     ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
     ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
@@ -207,7 +207,7 @@ template <typename GridwiseOp,
           typename BDataType,
           typename DsPointer,
           typename EDataType,
-          typename AGridDesc_K0_M_K1,
+          typename AGridDesc,
           typename BGridDesc_K0_N_K1,
           typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
           typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -225,7 +225,7 @@ __global__ void
             const BDataType* __restrict__ p_b_grid,
             DsPointer p_ds_grid,
             EDataType* __restrict__ p_e_grid,
-            const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
+            const AGridDesc a_grid_desc,
             const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
             const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                 ds_grid_desc_mblock_mperblock_nblock_nperblock,
@@ -244,7 +244,7 @@ __global__ void
             p_ds_grid,
             p_e_grid,
             p_shared,
-            a_grid_desc_k0_m_k1,
+            a_grid_desc,
             b_grid_desc_k0_n_k1,
             ds_grid_desc_mblock_mperblock_nblock_nperblock,
             e_grid_desc_mblock_mperblock_nblock_nperblock,
@@ -257,7 +257,7 @@ __global__ void
     ignore = p_b_grid;
     ignore = p_ds_grid;
     ignore = p_e_grid;
-    ignore = a_grid_desc_k0_m_k1;
+    ignore = a_grid_desc;
     ignore = b_grid_desc_k0_n_k1;
     ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
     ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
@@ -276,7 +276,7 @@ template < // DataType Family
           typename DsDataType,
           typename EDataType,
           // InMemory Data Descriptor
-          typename AGridDesc_K0_M_K1,
+          typename AGridDesc,
           typename BGridDesc_K0_N_K1,
           typename DsGridDesc_M_N,
           typename EGridDesc_M_N,
@@ -288,7 +288,7 @@ template < // DataType Family
           // Tiling Family
           index_t MPerBlock,
           index_t NPerBlock,
-          index_t K0PerBlock,
+          index_t KPerBlock,
           index_t MPerWmma,
           index_t NPerWmma,
           index_t K1Value,
@@ -303,6 +303,7 @@ template < // DataType Family
           index_t ABlockTransferSrcScalarPerVector,
           index_t ABlockTransferDstScalarPerVector_K1,
           bool AThreadTransferSrcResetCoordinateAfterRun,
+          bool AEnableLds,
           bool ABlockLdsExtraM,
           typename BBlockTransferThreadClusterLengths_K0_N_K1,
           typename BBlockTransferThreadClusterArrangeOrder,
@@ -311,6 +312,7 @@ template < // DataType Family
           index_t BBlockTransferSrcScalarPerVector,
           index_t BBlockTransferDstScalarPerVector_K1,
           bool BThreadTransferSrcResetCoordinateAfterRun,
+          bool BEnableLds,
           bool BBlockLdsExtraN,
           index_t CShuffleMRepeatPerShuffle,
           index_t CShuffleNRepeatPerShuffle,
@@ -335,36 +337,161 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
     // K1 should be Number<...>
     static constexpr auto K1 = Number<K1Value>{};

+    static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
+    static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
+    static constexpr auto WmmaK  = 16;
+
     using ThisThreadBlock = ThisThreadBlock<BlockSize>;

-    using GridwiseGemmPipe = remove_cvref_t<decltype(
-        GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
+    using GridwiseGemmPipe = remove_cvref_t<decltype(
+        GridwiseGemmPipeline_Selector<PipelineVer, AEnableLds, BEnableLds, NumGemmKPrefetchStage, LoopSched>())>;

-    __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
+    // Describe how data store to (LDS/VGPR) buffer from Global memory
+    __host__ __device__ static constexpr auto MakeABlockDescriptor()
     {
-        constexpr auto max_lds_align = K1;
-
-        // A matrix in LDS memory, dst of blockwise copy
-        constexpr auto a_block_desc_k0perblock_mperblock_k1 = [&]() {
-            if constexpr(ABlockLdsExtraM)
-            {
-                return make_naive_tensor_descriptor(
-                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
-                    make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
-            }
-            else
-            {
-                return make_naive_tensor_descriptor_aligned(
-                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
-            }
-        }();
-
-        return a_block_desc_k0perblock_mperblock_k1;
+        constexpr auto a_block_desc = [&]() {
+            if constexpr(AEnableLds)
+            {
+                // K0->M->K1 Per Block
+                constexpr auto K0PerBlock    = KPerBlock / K1;
+                constexpr auto max_lds_align = K1;
+
+                // A matrix in LDS memory, dst of blockwise copy
+                if constexpr(ABlockLdsExtraM)
+                {
+                    return make_naive_tensor_descriptor(
+                        make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
+                        make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
+                }
+                else
+                {
+                    return make_naive_tensor_descriptor_aligned(
+                        make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
+                }
+            }
+            else
+            {
+                constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
+                // KWmma->MRepeat->MWave->KRow->MPerWmma->K1 Per Thread
+                return make_naive_tensor_descriptor(
+                    make_tuple(Number<KWmmaPerblock>{}, Number<MRepeat>{}, I1, I1, I1, K1),
+                    make_tuple(Number<MRepeat>{} * K1, K1, K1, K1, K1, I1));
+            }
+        }();
+
+        return a_block_desc;
     }

+    __host__ __device__ static constexpr auto MakeABlockSliceCopyStep()
+    {
+        constexpr auto a_block_copy_step = [&]() {
+            if constexpr(AEnableLds)
+            {
+                constexpr auto K0PerBlock = KPerBlock / K1;
+                return make_multi_index(K0PerBlock, 0, 0);
+            }
+            else
+            {
+                constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
+                return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0);
+            }
+        }();
+
+        return a_block_copy_step;
+    }
+
+    __host__ __device__ static constexpr auto MakeBBlockSliceCopyStep()
+    {
+        constexpr auto b_block_copy_step = [&]() {
+            if constexpr(BEnableLds)
+            {
+                constexpr auto K0PerBlock = KPerBlock / K1;
+                return make_multi_index(K0PerBlock, 0, 0);
+            }
+            else
+            {
+                constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
+                return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0);
+            }
+        }();
+
+        return b_block_copy_step;
+    }
+
+    // Describe how data read from (LDS/VGPR) buffer
+    template <typename ABlockDesc_>
+    __host__ __device__ static constexpr auto MakeAWaveDescriptor(const ABlockDesc_&)
+    {
+        constexpr auto a_wave_desc = [&]() {
+            if constexpr(AEnableLds)
+            {
+                // AK0_M_AK1 -> AK0_MRepeat_Mwaves_MPerWmma_AK1
+                constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
+                constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
+
+                return transform_tensor_descriptor(
+                    ABlockDesc_{},
+                    make_tuple(make_pass_through_transform(Number<A_K0>{}),
+                               make_unmerge_transform(make_tuple(
+                                   Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
+                               make_pass_through_transform(Number<A_K1>{})),
+                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+                    make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
+            }
+            else
+            {
+                // KWmma_MRepeat_MWave_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
+                constexpr auto KWmma = ABlockDesc_{}.GetLength(I0);
+                constexpr auto A_K1  = ABlockDesc_{}.GetLength(I5);
+
+                return transform_tensor_descriptor(
+                    ABlockDesc_{},
+                    make_tuple(make_merge_transform(make_tuple(Number<KWmma>{}, I1)),
+                               make_pass_through_transform(Number<MRepeat>{}),
+                               make_pass_through_transform(I1),
+                               make_pass_through_transform(I1),
+                               make_pass_through_transform(Number<A_K1>{})),
+                    make_tuple(Sequence<0, 3>{},
+                               Sequence<1>{},
+                               Sequence<2>{},
+                               Sequence<4>{},
+                               Sequence<5>{}),
+                    make_tuple(Sequence<0>{},
+                               Sequence<1>{},
+                               Sequence<2>{},
+                               Sequence<3>{},
+                               Sequence<4>{}));
+            }
+        }();
+
+        return a_wave_desc;
+    }
+
+    template <typename BBlockDesc_BK0_N_BK1>
+    __host__ __device__ static constexpr auto
+    MakeBBlockDescriptor_K0_N0_N1_N2_K1(const BBlockDesc_BK0_N_BK1&)
+    {
+        constexpr auto B_K0 = BBlockDesc_BK0_N_BK1{}.GetLength(I0);
+        constexpr auto B_K1 = BBlockDesc_BK0_N_BK1{}.GetLength(I2);
+
+        return transform_tensor_descriptor(
+            BBlockDesc_BK0_N_BK1{},
+            make_tuple(make_pass_through_transform(Number<B_K0>{}),
+                       make_unmerge_transform(
+                           make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
+                       make_pass_through_transform(Number<B_K1>{})),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+            make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
+    }
+
     __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
     {
         constexpr auto max_lds_align = K1;
+        constexpr auto K0PerBlock    = KPerBlock / K1;
         // B matrix in LDS memory, dst of blockwise copy
         constexpr auto b_block_desc_k0perblock_nperblock_k1 = [&]() {
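MakeABlockDescriptor therefore returns one of two shapes: a K0 x MPerBlock x K1 tile destined for LDS when AEnableLds is set, or a six-dimensional KWmma x MRepeat x MWave x KRow x MPerWmma x K1 per-thread layout for the direct-to-VGPR path. A small sketch of the element counts for the fp16 example configuration above (KPerBlock = 32, K1 = 8, WmmaK = 16; illustrative only):

    // Element-count bookkeeping for the two A-block layouts (illustrative values).
    #include <cstdio>

    int main()
    {
        constexpr int KPerBlock = 32, K1 = 8, WmmaK = 16, MPerBlock = 128, MRepeat = 2;

        // LDS path: K0 x MPerBlock x K1 per block.
        constexpr int K0PerBlock = KPerBlock / K1; // 4
        std::printf("LDS tile   : %d x %d x %d = %d elements per block\n",
                    K0PerBlock, MPerBlock, K1, K0PerBlock * MPerBlock * K1);

        // VGPR path: KWmma x MRepeat x MWave x KRow x MPerWmma x K1, where the
        // MWave/KRow/MPerWmma extents are 1 because each thread only owns its own slice.
        constexpr int KWmmaPerBlock = KPerBlock / WmmaK; // 2
        std::printf("VGPR slice : %d x %d x 1 x 1 x 1 x %d = %d elements per thread\n",
                    KWmmaPerBlock, MRepeat, K1, KWmmaPerBlock * MRepeat * K1);
    }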
@@ -416,28 +543,20 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
     __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
     {
-        // LDS allocation for A and B: be careful of alignment
-        constexpr auto a_block_desc_k0perblock_mperblock_k1 =
-            GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
-        constexpr auto b_block_desc_k0perblock_nperblock_k1 =
-            GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
-
-        constexpr auto max_lds_align = K1;
-
-        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
-            a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align);
-
-        constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
-            b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize(), max_lds_align);
-
-        return (a_block_space_size_aligned * sizeof(ADataType) +
-                b_block_space_size_aligned * sizeof(BDataType));
+        const index_t gemm_bytes_end =
+            SharedMemTrait::a_block_space_size_aligned * sizeof(ADataType) +
+            SharedMemTrait::b_block_space_size_aligned * sizeof(BDataType);
+
+        const index_t c_block_bytes_end =
+            SharedMemTrait::c_shuffle_block_space_size * sizeof(CShuffleDataType);
+
+        return math::max(gemm_bytes_end, c_block_bytes_end);
     }

     // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
     template <typename Block2CTileMap>
     __host__ __device__ static constexpr bool
-    CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
+    CheckValidity(const AGridDesc& a_grid_desc,
                   const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
                   const DsGridDesc_M_N& ds_grid_desc_m_n,
                   const EGridDesc_M_N& e_grid_desc_m_n,
@@ -450,9 +569,41 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
                           (NPerBlock % (NRepeat * NPerWmma)) == 0,
                       "Invalid tuning param!");

-        const auto M  = a_grid_desc_k0_m_k1.GetLength(I1);
-        const auto N  = b_grid_desc_k0_n_k1.GetLength(I1);
-        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
+        const auto GetAProblemsizeMK = [&]() {
+            if constexpr(AEnableLds)
+            {
+                return make_tuple(a_grid_desc.GetLength(I1),
+                                  a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2));
+            }
+            else
+            {
+                return make_tuple(a_grid_desc.GetLength(I1) * a_grid_desc.GetLength(I2) *
+                                      a_grid_desc.GetLength(I4),
+                                  a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) *
+                                      a_grid_desc.GetLength(I5));
+            }
+        };
+
+        const auto GetBProblemsizeNK = [&]() {
+            if constexpr(BEnableLds)
+            {
+                return make_tuple(b_grid_desc_k0_n_k1.GetLength(I1),
+                                  b_grid_desc_k0_n_k1.GetLength(I0) *
+                                      b_grid_desc_k0_n_k1.GetLength(I2));
+            }
+            else
+            {
+                return make_tuple(b_grid_desc_k0_n_k1.GetLength(I1) *
+                                      b_grid_desc_k0_n_k1.GetLength(I2) *
+                                      b_grid_desc_k0_n_k1.GetLength(I4),
+                                  b_grid_desc_k0_n_k1.GetLength(I0) *
+                                      b_grid_desc_k0_n_k1.GetLength(I3) *
+                                      b_grid_desc_k0_n_k1.GetLength(I5));
+            }
+        };
+
+        const auto M = GetAProblemsizeMK()[I0];
+        const auto N = GetBProblemsizeNK()[I0];
+        const auto K = GetAProblemsizeMK()[I1];

         bool valid = true;
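The point of the two lambdas: with the LDS path the A descriptor keeps the familiar K0 x M x K1 shape, so M = GetLength(I1) and K = GetLength(I0) * GetLength(I2), while on the VGPR path the same quantities have to be folded back out of the six-dimensional KWmma_MRepeat_MWave_KRow_MPerWmma_K1 layout, giving M = GetLength(I1) * GetLength(I2) * GetLength(I4) and K = GetLength(I0) * GetLength(I3) * GetLength(I5).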
@@ -468,21 +619,20 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
         }

         if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) &&
-             K0 == b_grid_desc_k0_n_k1.GetLength(I0) &&
-             K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
-             K1 == b_grid_desc_k0_n_k1.GetLength(I2)))
+             K == GetBProblemsizeNK()[I1]))
         {
             printf("GridwiseOp: ABE descriptor dimension cross check failure\n");
             return false;
         }

-        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
+        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
         {
             printf("GridwiseOp: Problemsize descriptor dimension check failure\n");
             return false;
         }

         // check gridwise gemm pipeline
-        const auto num_k_loop = K0 / K0PerBlock;
+        const auto num_k_loop = K / KPerBlock;

         if(!GridwiseGemmPipe::IsSupported(num_k_loop))
         {
@@ -546,6 +696,31 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
             e_grid_desc_m_n);
     }

+    struct SharedMemTrait
+    {
+        // LDS allocation for A and B: be careful of alignment
+        static constexpr auto max_lds_align = K1;
+
+        static constexpr auto a_block_space_size_aligned =
+            AEnableLds ? math::integer_least_multiple(MakeABlockDescriptor().GetElementSpaceSize(),
+                                                      max_lds_align)
+                       : 0;
+        static constexpr auto b_block_space_size_aligned =
+            BEnableLds ? math::integer_least_multiple(
+                             GetBBlockDescriptor_K0PerBlock_NPerBlock_K1().GetElementSpaceSize(),
+                             max_lds_align)
+                       : 0;
+
+        static constexpr auto a_block_space_offset = 0;
+        static constexpr auto b_block_space_offset = a_block_space_size_aligned;
+
+        // LDS allocation for C shuffle in LDS
+        static constexpr auto c_shuffle_block_space_size =
+            GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
+                .GetElementSpaceSize();
+        static constexpr auto c_shuffle_block_space_offset = 0;
+    };
+
     using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
         remove_cvref_t<decltype(MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
             DsGridDesc_M_N{}))>;
     using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
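SharedMemTrait centralizes the LDS budget: the A and B staging buffers collapse to zero when the corresponding EnableLds flag is off, B is placed immediately after A, and GetSharedMemoryNumberOfByte above now takes the maximum of the GEMM stage and the C-shuffle stage since the two reuse the same allocation. A rough worked number for the 128x128x32 fp16 instance at the top of this commit (padding ignored, C-shuffle term left symbolic; illustrative only):

    // Rough LDS budget for the A/B staging stage of the fp16 example (illustrative; ignores LdsExtra padding).
    #include <cstdio>

    int main()
    {
        constexpr int MPerBlock = 128, NPerBlock = 128, KPerBlock = 32, K1 = 8;
        constexpr int K0PerBlock = KPerBlock / K1;              // 4
        constexpr int a_elems    = K0PerBlock * MPerBlock * K1; // 4096
        constexpr int b_elems    = K0PerBlock * NPerBlock * K1; // 4096
        constexpr int gemm_bytes = (a_elems + b_elems) * 2;     // fp16 -> 16384 bytes

        // The final request is max(gemm_bytes, c_shuffle_bytes); the shuffle term comes
        // from GetCShuffleBlockDescriptor_... and is not reproduced here.
        std::printf("A+B staging: %d bytes of LDS\n", gemm_bytes);
    }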
@@ -560,7 +735,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
                                DsGridPointer p_ds_grid,
                                EDataType* __restrict__ p_e_grid,
                                void* __restrict__ p_shared,
-                               const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
+                               const AGridDesc& a_grid_desc,
                                const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
                                const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
                                    ds_grid_desc_mblock_mperblock_nblock_nperblock,
@@ -575,7 +750,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
         /*******************************************************************************/
         // Memory buffer zone.
         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
+            p_a_grid, a_grid_desc.GetElementSpaceSize());
         const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize());
         const auto ds_grid_buf = generate_tuple(
@@ -603,23 +778,39 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
         /*******************************************************************************/
         // BlockLevel, A/B Matrix ThreadMapping in LDS, As Destinaion of BlockWise_Copy
-        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
-        constexpr auto max_lds_align = K1;
-        constexpr auto a_block_desc_k0perblock_mperblock_k1 =
-            GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
-        constexpr auto b_block_desc_k0perblock_nperblock_k1 =
-            GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
-
-        // A matrix blockwise copy
-        auto a_blockwise_copy =
-            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
+        const auto K = [&]() {
+            if constexpr(AEnableLds)
+            {
+                return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2);
+            }
+            else
+            {
+                return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) *
+                       a_grid_desc.GetLength(I5);
+            }
+        }();
+
+        constexpr auto a_block_desc = MakeABlockDescriptor();
+        constexpr auto b_block_desc = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
+
+        auto a_block_trait = [&]() {
+            // A matrix blockwise copy
+            if constexpr(AEnableLds)
+            {
+                constexpr auto K0PerBlock = KPerBlock / K1;
+                auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+                    static_cast<FloatA*>(p_shared), a_block_desc.GetElementSpaceSize());
+
+                auto a_blockwise_copy =
+                    ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                 /* typename SrcElementwiseOperation,              */ AElementwiseOperation,
                 /* typename DstElementwiseOperation,              */ ck::tensor_operation::element_wise::PassThrough,
                 /* InMemoryDataOperationEnum DstInMemOp,          */ InMemoryDataOperationEnum::Set,
                 /* typename BlockSliceLengths,                    */ Sequence<K0PerBlock, MPerBlock, K1>,
                 /* typename ThreadClusterLengths,                 */ ABlockTransferThreadClusterLengths_K0_M_K1,
                 /* typename ThreadClusterArrangeOrder,            */ ABlockTransferThreadClusterArrangeOrder,
-                /* typename SrcData,                              */ ADataType,
-                /* typename DstData,                              */ ADataType,
-                /* typename SrcDesc,                              */ decltype(a_grid_desc_k0_m_k1),
-                /* typename DstDesc,                              */ decltype(a_block_desc_k0perblock_mperblock_k1),
+                /* typename SrcData,                              */ FloatA,
+                /* typename DstData,                              */ FloatA,
+                /* typename SrcDesc,                              */ decltype(a_grid_desc),
+                /* typename DstDesc,                              */ decltype(a_block_desc),
                 /* typename SrcDimAccessOrder,                    */ ABlockTransferSrcAccessOrder,
                 /* typename DstDimAccessOrder,                    */ Sequence<0, 1, 2>,
                 /* index_t SrcVectorDim,                          */ ABlockTransferSrcVectorDim,
@@ -630,62 +821,138 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
                 /* index_t DstScalarStrideInVector,               */ 1,
                 /* bool ThreadTransferSrcResetCoordinateAfterRun, */ AThreadTransferSrcResetCoordinateAfterRun,
                 /* bool ThreadTransferDstResetCoordinateAfterRun, */ true>(
-                    a_grid_desc_k0_m_k1,
+                    a_grid_desc,
                     make_multi_index(0, m_block_data_idx_on_grid, 0),
                     a_element_op,
-                    a_block_desc_k0perblock_mperblock_k1,
+                    a_block_desc,
                     make_multi_index(0, 0, 0),
                     ck::tensor_operation::element_wise::PassThrough{});

-        // B matrix blockwise copy
-        auto b_blockwise_copy =
-            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
-                                                BElementwiseOperation,
-                                                ck::tensor_operation::element_wise::PassThrough,
-                                                InMemoryDataOperationEnum::Set,
-                                                Sequence<K0PerBlock, NPerBlock, K1>,
-                                                BBlockTransferThreadClusterLengths_K0_N_K1,
-                                                BBlockTransferThreadClusterArrangeOrder,
-                                                BDataType,
-                                                BDataType,
-                                                decltype(b_grid_desc_k0_n_k1),
-                                                decltype(b_block_desc_k0perblock_nperblock_k1),
-                                                BBlockTransferSrcAccessOrder,
-                                                Sequence<0, 1, 2>,
-                                                BBlockTransferSrcVectorDim,
-                                                2,
-                                                BBlockTransferSrcScalarPerVector,
-                                                BBlockTransferDstScalarPerVector_K1,
-                                                1,
-                                                1,
-                                                BThreadTransferSrcResetCoordinateAfterRun,
-                                                true>(
-                b_grid_desc_k0_n_k1,
-                make_multi_index(0, n_block_data_idx_on_grid, 0),
-                b_element_op,
-                b_block_desc_k0perblock_nperblock_k1,
-                make_multi_index(0, 0, 0),
-                ck::tensor_operation::element_wise::PassThrough{});
+                return make_tuple(a_block_buf, a_blockwise_copy);
+            }
+            else
+            {
+                // Thread-wise copy
+                // KPerBlock/WmmaK -> MRepeat -> MWaves -> WmmaK/K1 -> MPerWmma -> K1
+                constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
+                auto a_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
+                    a_block_desc.GetElementSpaceSize());
+
+                // Limitation: NumDim of Src and Dst descriptor should be identical
+                auto a_blockwise_copy = ThreadwiseTensorSliceTransfer_v2<
+                    FloatA,
+                    FloatA,
+                    decltype(a_grid_desc),
+                    decltype(a_block_desc),
+                    Sequence<Number<KWmmaPerBlock>{}, Number<MRepeat>{}, I1, I1, I1, Number<K1Value>{}>,
+                    Sequence<0, 1, 2, 3, 4, 5>,
+                    5,
+                    ABlockTransferSrcScalarPerVector,
+                    AThreadTransferSrcResetCoordinateAfterRun,
+                    true>(
+                    a_grid_desc,
+                    make_multi_index(0,
+                                     m_block_data_idx_on_grid / (MWaves * MPerWmma),
+                                     get_thread_local_1d_id() / 32,
+                                     (get_thread_local_1d_id() % 32) / 16,
+                                     get_thread_local_1d_id() % 16,
+                                     0));
+
+                return make_tuple(a_block_buf, a_blockwise_copy);
+            }
+        };
+
+        auto b_block_trait = [&]() {
+            if constexpr(BEnableLds)
+            {
+                constexpr auto K0PerBlock = KPerBlock / K1;
+                auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+                    static_cast<FloatB*>(p_shared) + SharedMemTrait::a_block_space_size_aligned,
+                    b_block_desc.GetElementSpaceSize());
+
+                auto b_blockwise_copy =
+                    ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
+                                                        BElementwiseOperation,
+                                                        ck::tensor_operation::element_wise::PassThrough,
+                                                        InMemoryDataOperationEnum::Set,
+                                                        Sequence<K0PerBlock, NPerBlock, K1>,
+                                                        BBlockTransferThreadClusterLengths_K0_N_K1,
+                                                        BBlockTransferThreadClusterArrangeOrder,
+                                                        FloatB,
+                                                        FloatB,
+                                                        decltype(b_grid_desc_k0_n_k1),
+                                                        decltype(b_block_desc),
+                                                        BBlockTransferSrcAccessOrder,
+                                                        Sequence<0, 1, 2>,
+                                                        BBlockTransferSrcVectorDim,
+                                                        2,
+                                                        BBlockTransferSrcScalarPerVector,
+                                                        BBlockTransferDstScalarPerVector_K1,
+                                                        1,
+                                                        1,
+                                                        BThreadTransferSrcResetCoordinateAfterRun,
+                                                        true>(
+                        b_grid_desc_k0_n_k1,
+                        make_multi_index(0, n_block_data_idx_on_grid, 0),
+                        b_element_op,
+                        b_block_desc,
+                        make_multi_index(0, 0, 0),
+                        ck::tensor_operation::element_wise::PassThrough{});
+
+                return make_tuple(b_block_buf, b_blockwise_copy);
+            }
+            else
+            {
+                constexpr auto K0PerBlock = KPerBlock / K1;
+                auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
+                    b_block_desc.GetElementSpaceSize());
+
+                auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v4<
+                    FloatB,
+                    FloatB,
+                    decltype(b_grid_desc_k0_n_k1),
+                    decltype(b_block_desc),
+                    Sequence<Number<K0PerBlock>{}, Number<NRepeat>{}, Number<K1Value>{}>,
+                    Sequence<0, 1, 2>,
+                    2,
+                    BBlockTransferSrcScalarPerVector,
+                    1>(
+                    make_multi_index(0, get_thread_local_1d_id() / 32 * 16 + get_thread_local_1d_id() % 16, 0));
+
+                return make_tuple(b_block_buf, b_blockwise_copy);
+            }
+        };
+
+        auto a_block_buf      = a_block_trait()[I0];
+        auto a_blockwise_copy = a_block_trait()[I1];
+        auto b_block_buf      = b_block_trait()[I0];
+        auto b_blockwise_copy = b_block_trait()[I1];

         /*******************************************************************************/
         // GEMM
-        constexpr auto WmmaK = 16;
         constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);

         auto blockwise_gemm =
-            BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO<BlockSize,
-                                                                       ADataType,
-                                                                       BDataType,
-                                                                       AccDataType,
-                                                                       decltype(a_block_desc_k0perblock_mperblock_k1),
-                                                                       decltype(b_block_desc_k0perblock_nperblock_k1),
-                                                                       MPerWmma,
-                                                                       NPerWmma,
-                                                                       MRepeat,
-                                                                       NRepeat,
-                                                                       KPack,
-                                                                       false,
-                                                                       true>{};
+            BlockwiseGemmWMMA<BlockSize,
+                              ADataType,
+                              BDataType,
+                              AccDataType,
+                              decltype(MakeAWaveDescriptor(a_block_desc)),
+                              decltype(MakeBBlockDescriptor_K0_N0_N1_N2_K1(b_block_desc)),
+                              MPerBlock,
+                              NPerBlock,
+                              KPerBlock,
+                              MPerWmma,
+                              NPerWmma,
+                              MRepeat,
+                              NRepeat,
+                              KPack>{};

         // Prepare Register for C matrix
         auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
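In both VGPR branches the source coordinate is built straight from the flat thread id, splitting it into wave, half-wave row, and lane (the divisors assume a wave size of 32 with 16 lanes per WMMA row). A worked decomposition, for illustration:

    // Decomposing a flat thread id the way the VGPR-path copies above do (wave size assumed 32).
    #include <cstdio>
    #include <initializer_list>

    int main()
    {
        for(int tid : {0, 17, 77, 255})
        {
            const int wave = tid / 32;        // which wave within the block
            const int row  = (tid % 32) / 16; // upper or lower 16 lanes of the wave
            const int lane = tid % 16;        // lane within the 16-wide WMMA row
            std::printf("tid %3d -> wave %d, row %d, lane %2d\n", tid, wave, row, lane);
        }
    }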
@@ -702,7 +969,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
         // gridwise GEMM pipeline
         const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);

-        GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_k0_m_k1,
+        GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc,
                                                           a_block_desc_k0perblock_mperblock_k1,
                                                           a_blockwise_copy,
                                                           a_grid_buf,
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp

@@ -56,6 +56,8 @@ struct GridwiseGemmPipeline_v1<1, true, true>
                                CThreadBuffer& c_thread_buf,
                                index_t num_loop)
     {
+        if(get_thread_local_1d_id() < 32);
+        printf("Mat-A Lds Enabled, Mat-B Lds Enabled\n");
         // preload data into LDS
         a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
         b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
@@ -304,6 +306,9 @@ struct GridwiseGemmPipeline_v1<1, false, true>
             },
             Number<a_block_desc.GetLengths().GetSize()>{});
 #endif
+        if(get_thread_local_1d_id() < 32);
+        printf("Mat-A Lds Disabled, Mat-B Lds Enabled\n");
+
         constexpr auto a_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0);
         auto a_block_buf_switch = a_block_buf;
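Note that both added debug prints terminate the if with a semicolon, so the guard is an empty statement and every thread prints. A form that actually limits the output to a single thread would look like this (sketch only, placed inside the same Run body):

    if(get_thread_local_1d_id() == 0)
    {
        printf("Mat-A Lds Enabled, Mat-B Lds Enabled\n");
    }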
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp

@@ -694,7 +694,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
                                               NPerWmma,
                                               MRepeat,
                                               NRepeat,
-                                              KPack>{};
+                                              KPack,
+                                              AEnableLds,
+                                              BEnableLds>{};

         // Prepare Register for C matrix
         auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();