Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
3b395570
Commit
3b395570
authored
Aug 28, 2023
by
danyao12
Browse files
add Gemm2KPerBlock template for split kernels
parent
228e9cd1
Changes
3
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
41 additions
and
40 deletions
+41
-40
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v3_protro.cpp
...x_gemm/batched_multihead_attention_backward_v3_protro.cpp
+19
-17
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v2_protro.hpp
...ce_batched_mha_bwd_xdl_cshuffle_qloop_light_v2_protro.hpp
+3
-0
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2_protro.hpp
...atched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2_protro.hpp
+19
-23
No files found.
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v3_protro.cpp
View file @
3b395570
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v2_protro.hpp
View file @
3b395570
...
@@ -306,6 +306,7 @@ template <index_t NumDimG,
...
@@ -306,6 +306,7 @@ template <index_t NumDimG,
index_t
KPerBlock
,
// Gemm0KPerBlock
index_t
KPerBlock
,
// Gemm0KPerBlock
index_t
Gemm1NPerBlock
,
index_t
Gemm1NPerBlock
,
index_t
Gemm1KPerBlock
,
index_t
Gemm1KPerBlock
,
index_t
Gemm2KPerBlock
,
index_t
AK1
,
index_t
AK1
,
index_t
BK1
,
index_t
BK1
,
index_t
B1K1
,
index_t
B1K1
,
...
@@ -761,6 +762,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -761,6 +762,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
KPerBlock
,
KPerBlock
,
Gemm1NPerBlock
,
Gemm1NPerBlock
,
Gemm1KPerBlock
,
Gemm1KPerBlock
,
Gemm2KPerBlock
,
AK1
,
AK1
,
BK1
,
BK1
,
B1K1
,
B1K1
,
...
@@ -1457,6 +1459,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1457,6 +1459,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
<<
MPerBlock
<<
", "
<<
MPerBlock
<<
", "
<<
Gemm1NPerBlock
<<
", "
<<
Gemm1NPerBlock
<<
", "
<<
Gemm1KPerBlock
<<
", "
<<
Gemm1KPerBlock
<<
", "
<<
Gemm2KPerBlock
<<
", "
<<
B1K1
<<
", "
<<
B1K1
<<
", "
<<
getGemmSpecializationString
(
GemmSpec
)
<<
", "
<<
getGemmSpecializationString
(
GemmSpec
)
<<
", "
<<
"ASpec"
<<
getTensorSpecializationString
(
ASpec
)
<<
", "
<<
"ASpec"
<<
getTensorSpecializationString
(
ASpec
)
<<
", "
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2_protro.hpp
View file @
3b395570
...
@@ -48,6 +48,7 @@ template <typename InputDataType,
...
@@ -48,6 +48,7 @@ template <typename InputDataType,
index_t
KPerBlock
,
index_t
KPerBlock
,
index_t
Gemm1NPerBlock
,
index_t
Gemm1NPerBlock
,
index_t
Gemm1KPerBlock
,
index_t
Gemm1KPerBlock
,
index_t
Gemm2KPerBlock
,
index_t
AK1Value
,
index_t
AK1Value
,
index_t
BK1Value
,
index_t
BK1Value
,
index_t
B1K1Value
,
index_t
B1K1Value
,
...
@@ -786,13 +787,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -786,13 +787,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// dQ Gemm (type 3 crr)
// dQ Gemm (type 3 crr)
// Describes tuning parameter for C2_m_n = A2_m_k * B2_k_n
// Describes tuning parameter for C2_m_n = A2_m_k * B2_k_n
template
<
index_t
Sum_K_
=
NPerXdl
*
2
>
struct
Gemm2Params
struct
Gemm2Params_
{
{
static
constexpr
index_t
Gemm2_M
=
MPerBlock
;
// 64
static
constexpr
index_t
Gemm2_M
=
MPerBlock
;
// 64
static
constexpr
index_t
Gemm2_K
=
NPerBlock
;
// 128
static
constexpr
index_t
Gemm2_K
=
NPerBlock
;
// 128
static
constexpr
index_t
Gemm2_N
=
Gemm1NPerBlock
;
// 128
static
constexpr
index_t
Gemm2_N
=
Gemm1NPerBlock
;
// 128
static
constexpr
index_t
Sum_K
=
Sum_K_
;
static
constexpr
index_t
Sum_K
=
Gemm2KPerBlock
;
static
constexpr
index_t
A_K1
=
8
;
// dS will be row-major
static
constexpr
index_t
A_K1
=
8
;
// dS will be row-major
static
constexpr
index_t
A_K0
=
Sum_K
/
A_K1
;
static
constexpr
index_t
A_K0
=
Sum_K
/
A_K1
;
...
@@ -815,13 +815,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -815,13 +815,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
__host__
__device__
static
constexpr
auto
GetABlockSliceLengths_M0_K0_M1_K1_M2_K2
()
__host__
__device__
static
constexpr
auto
GetABlockSliceLengths_M0_K0_M1_K1_M2_K2
()
{
{
// perform manual unmerge: n -> n_repeat, n_waves, n_per_xdl
// perform manual unmerge: n -> n_repeat, n_waves, n_per_xdl
constexpr
index_t
k
=
Gemm2Params
::
Sum_K
-
1
;
constexpr
index_t
k
=
Sum_K
-
1
;
constexpr
index_t
k2
=
k
%
NPerXdl
;
constexpr
index_t
k2
=
k
%
NPerXdl
;
constexpr
index_t
k1
=
k
/
NPerXdl
%
Gemm0NWaves
;
constexpr
index_t
k1
=
k
/
NPerXdl
%
Gemm0NWaves
;
constexpr
index_t
k0
=
k
/
NPerXdl
/
Gemm0NWaves
%
NXdlPerWave
;
constexpr
index_t
k0
=
k
/
NPerXdl
/
Gemm0NWaves
%
NXdlPerWave
;
// perform manual unmerge: m -> m_repeat, m_waves, m_per_xdl
// perform manual unmerge: m -> m_repeat, m_waves, m_per_xdl
constexpr
index_t
m
=
Gemm2Params
::
Gemm2_M
-
1
;
constexpr
index_t
m
=
Gemm2_M
-
1
;
constexpr
index_t
m2
=
m
%
MPerXdl
;
constexpr
index_t
m2
=
m
%
MPerXdl
;
constexpr
index_t
m1
=
m
/
MPerXdl
%
Gemm0MWaves
;
constexpr
index_t
m1
=
m
/
MPerXdl
%
Gemm0MWaves
;
constexpr
index_t
m0
=
m
/
MPerXdl
/
Gemm0MWaves
%
MXdlPerWave
;
constexpr
index_t
m0
=
m
/
MPerXdl
/
Gemm0MWaves
%
MXdlPerWave
;
...
@@ -842,7 +842,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -842,7 +842,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
using
ABlockSliceLengths_M0_K0_M1_K1
=
using
ABlockSliceLengths_M0_K0_M1_K1
=
decltype
(
GetABlockSliceLengths_M0_K0_M1_K1
());
//(2, 1, 1, 2) //(4, 1, 1, 2)
decltype
(
GetABlockSliceLengths_M0_K0_M1_K1
());
//(2, 1, 1, 2) //(4, 1, 1, 2)
};
};
using
Gemm2Params
=
Gemm2Params_
<>
;
// tune later
// dQ Gemm (type 3 crr)
// dQ Gemm (type 3 crr)
template
<
typename
Gemm2Params
,
typename
ASrcBlockwiseGemm
>
template
<
typename
Gemm2Params
,
typename
ASrcBlockwiseGemm
>
...
@@ -1033,14 +1032,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1033,14 +1032,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
using
BBlockwiseCopy
=
using
BBlockwiseCopy
=
ThreadwiseTensorSliceTransfer_v2
<
GemmDataType
,
ThreadwiseTensorSliceTransfer_v2
<
GemmDataType
,
GemmDataType
,
GemmDataType
,
decltype
(
b_block_desc_n0_n1_n2_k0_k1_k2_k3
),
decltype
(
b_block_desc_n0_n1_n2_k0_k1_k2_k3
),
decltype
(
b_thread_desc_n0_n1_n2_k0_k1_k2_k3
),
decltype
(
b_thread_desc_n0_n1_n2_k0_k1_k2_k3
),
BThreadSlice_N0_N1_N2_K0_K1_K2_K3
,
BThreadSlice_N0_N1_N2_K0_K1_K2_K3
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
>
,
6
,
6
,
1
,
1
,
1
,
1
,
true
>
;
true
>
;
static
constexpr
auto
b_block_slice_copy_step
=
make_multi_index
(
0
,
0
,
0
,
1
,
0
,
0
,
0
);
static
constexpr
auto
b_block_slice_copy_step
=
make_multi_index
(
0
,
0
,
0
,
1
,
0
,
0
,
0
);
...
@@ -1049,20 +1048,20 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1049,20 +1048,20 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
using
BlockwiseGemm
=
BlockwiseGemmXdlops_v2
<
using
BlockwiseGemm
=
BlockwiseGemmXdlops_v2
<
BlockSize
,
BlockSize
,
GemmDataType
,
GemmDataType
,
FloatGemmAcc
,
FloatGemmAcc
,
decltype
(
a_block_desc_k0_m_k1
),
decltype
(
a_block_desc_k0_m_k1
),
decltype
(
b_thread_desc_k0_n_k1
),
decltype
(
b_thread_desc_k0_n_k1
),
decltype
(
MakeGemm2AMmaTileDescriptor_M0_M1_M2_K
(
a_block_desc_k0_m_k1
)),
decltype
(
MakeGemm2AMmaTileDescriptor_M0_M1_M2_K
(
a_block_desc_k0_m_k1
)),
decltype
(
MakeGemm2BMmaTileDescriptor_N0_N1_N2_K
(
b_thread_desc_k0_n_k1
)),
decltype
(
MakeGemm2BMmaTileDescriptor_N0_N1_N2_K
(
b_thread_desc_k0_n_k1
)),
MPerBlock
,
MPerBlock
,
Gemm1NPerBlock
,
Gemm1NPerBlock
,
Gemm2Params
::
Sum_K
,
Gemm2Params
::
Sum_K
,
MPerXdl
,
MPerXdl
,
NPerXdl
,
NPerXdl
,
Gemm2Params
::
GemmMRepeat
,
Gemm2Params
::
GemmMRepeat
,
Gemm2Params
::
GemmNRepeat
,
Gemm2Params
::
GemmNRepeat
,
Gemm2Params
::
GemmKPack
,
Gemm2Params
::
GemmKPack
,
true
,
// TransposeC
true
,
// TransposeC
Gemm2Params
::
GemmKPack
*
Gemm2Params
::
GemmKPack
*
XdlopsGemm
<
GemmDataType
,
MPerXdl
,
NPerXdl
,
Gemm2Params
::
GemmKPack
,
false
>
{}
XdlopsGemm
<
GemmDataType
,
MPerXdl
,
NPerXdl
,
Gemm2Params
::
GemmKPack
,
false
>
{}
...
@@ -1343,7 +1342,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1343,7 +1342,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
sizeof
(
GemmDataType
);
sizeof
(
GemmDataType
);
const
index_t
gemm1_bytes_end
=
(
SharedMemTrait
::
k_block_space_size_aligned
+
const
index_t
gemm1_bytes_end
=
(
SharedMemTrait
::
k_block_space_size_aligned
+
SharedMemTrait
::
b1_block_space_size_aligned
)
*
SharedMemTrait
::
b1_block_space_size_aligned
)
*
sizeof
(
GemmDataType
);
sizeof
(
GemmDataType
);
const
index_t
gemm2_bytes_end
=
(
SharedMemTrait
::
k_block_space_size_aligned
+
const
index_t
gemm2_bytes_end
=
(
SharedMemTrait
::
k_block_space_size_aligned
+
SharedMemTrait
::
a2_block_space_size_aligned
)
*
SharedMemTrait
::
a2_block_space_size_aligned
)
*
sizeof
(
GemmDataType
);
sizeof
(
GemmDataType
);
...
@@ -1353,11 +1352,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1353,11 +1352,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
const
index_t
c_block_bytes_end
=
const
index_t
c_block_bytes_end
=
SharedMemTrait
::
c_block_space_size
*
sizeof
(
FloatCShuffle
);
SharedMemTrait
::
c_block_space_size
*
sizeof
(
FloatCShuffle
);
return
math
::
max
(
gemm0_bytes_end
,
return
math
::
max
(
gemm1_bytes_end
,
gemm0_bytes_end
,
gemm1_bytes_end
,
gemm2_bytes_end
,
gemm3_bytes_end
,
c_block_bytes_end
);
gemm2_bytes_end
,
gemm3_bytes_end
,
c_block_bytes_end
);
}
}
template
<
bool
HasMainKBlockLoop
,
template
<
bool
HasMainKBlockLoop
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment