Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
be38f68d
Commit
be38f68d
authored
Jul 10, 2023
by
ltqin
Browse files
add padding code for M
parent
a188073b
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
174 additions
and
56 deletions
+174
-56
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v2_phased.cpp
...x_gemm/batched_multihead_attention_backward_v2_phased.cpp
+4
-4
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v3.cpp
..._softmax_gemm/batched_multihead_attention_backward_v3.cpp
+1
-1
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v3.cpp
..._softmax_gemm/grouped_multihead_attention_backward_v3.cpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
...pl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
+42
-12
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
...pl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
+42
-12
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
...pl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
+37
-8
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
...pl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
+36
-8
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_ydotygrad.hpp
...gridwise_batched_mha_bwd_xdl_cshuffle_qloop_ydotygrad.hpp
+11
-10
No files found.
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v2_phased.cpp
View file @
be38f68d
...
...
@@ -102,8 +102,8 @@ static constexpr bool Deterministic = false;
// If 32 < DIM <= 64 , use prototype1 2nd template.
// If 64 < DIM <= 128, use prototype2 2nd template.
#if(DIM <= 32)
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward_Qloop_Phased_Xdl_CShuffle_V1
<
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward_Qloop_Phased_Xdl_CShuffle_V1
<
NumDimG
,
NumDimM
,
NumDimN
,
...
...
@@ -172,8 +172,8 @@ using DeviceGemmInstance =
MaskingSpec
,
// MaskingSpecialization
Deterministic
>
;
#elif(DIM <= 64)
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward_Qloop_Phased_Xdl_CShuffle_V1
<
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward_Qloop_Phased_Xdl_CShuffle_V1
<
NumDimG
,
NumDimM
,
NumDimN
,
...
...
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v3.cpp
View file @
be38f68d
...
...
@@ -25,7 +25,7 @@ Kernel outputs:
#define PRINT_HOST 0
#define USING_MASK 0
#define DIM
32
// DIM should be a multiple of 8.
#define DIM
128
// DIM should be a multiple of 8.
#include <iostream>
#include <numeric>
...
...
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v3.cpp
View file @
be38f68d
...
...
@@ -24,7 +24,7 @@ Kernel outputs:
*/
#define USING_MASK 0
#define DIM
128
// DIM should be a multiple of 8.
#define DIM
32
// DIM should be a multiple of 8.
#include <iostream>
#include <numeric>
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
View file @
be38f68d
...
...
@@ -337,6 +337,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
static
constexpr
index_t
NumAcc0Bias
=
Acc0BiasDataType
::
Size
();
static
constexpr
index_t
NumAcc1Bias
=
Acc1BiasDataType
::
Size
();
static
constexpr
index_t
DMPerBlock
=
BlockSize
;
// TODO: implement bias combination
static_assert
(
NumAcc0Bias
==
0
&&
NumAcc0Bias
==
0
,
"Bias addition is unimplemented"
);
...
...
@@ -371,6 +372,15 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
B1Spec
,
CSpec
>
;
using
DTransform
=
TransformBatchedContractionContractionToBatchedGemmGemm
<
Sequence
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
>
,
Sequence
<
DMPerBlock
,
NPerBlock
,
KPerBlock
,
Gemm1NPerBlock
>
,
GemmSpecialization
::
MNKOPadding
,
ASpec
,
BSpec
,
B1Spec
,
CSpec
>
;
/*
Descriptors for inputs:
...
...
@@ -596,6 +606,19 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
}
}
static
auto
MakeDGridDescriptor_M
(
index_t
MRaw
)
{
const
auto
d_grid_desc_mraw
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
MRaw
));
const
auto
M
=
math
::
integer_divide_ceil
(
MRaw
,
DMPerBlock
)
*
DMPerBlock
;
const
auto
MPad
=
M
-
MRaw
;
return
transform_tensor_descriptor
(
d_grid_desc_mraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
)),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
}
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
({},
{}));
using
BGridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
({},
{}));
using
B1GridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
({},
{}));
...
...
@@ -606,7 +629,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
using
B1GridDesc_G_N_K
=
decltype
(
Transform
::
MakeB1GridDescriptor_G_N_K
({},
{}));
using
CGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
ZGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
DGridDesc_M
=
decltype
(
MakeLSEGridDescriptor_M
(
1
));
using
DYGridDesc_M_O
=
decltype
(
DTransform
::
MakeCGridDescriptor_M_N
({},
{}));
using
DGridDesc_M
=
decltype
(
MakeDGridDescriptor_M
(
1
));
using
KGridDesc_N_K
=
decltype
(
Transform
::
MakeB0GridDescriptor_N_K
({},
{}));
using
YGradGridDesc_O0_M_O1
=
decltype
(
MakeYGradGridDescriptor_O0_M_O1
({},
{}));
...
...
@@ -705,7 +729,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
B1GridDesc_BK0_N_BK1
,
YGridDesc_M_O
,
LSEGridDesc_M
,
D
GridDesc_M
,
LSE
GridDesc_M
,
NumGemmKPrefetchStage
,
BlockSize
,
MPerBlock
,
...
...
@@ -754,7 +778,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
YGridDesc_M_O
,
DGridDesc_M
,
BlockSize
,
Block
Size
,
DMPer
Block
,
DKPerBlock
>
;
// Argument
struct
Argument
:
public
BaseArgument
...
...
@@ -818,8 +842,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
b1_gs_gemm1ns_gemm1ks_lengths
,
b1_gs_gemm1ns_gemm1ks_strides
)},
y_grid_desc_m_o_
{
Transform
::
MakeCGridDescriptor_M_N
(
c_gs_ms_gemm1ns_lengths
,
c_gs_ms_gemm1ns_strides
)},
d_y_grid_desc_m_o_
{
DTransform
::
MakeCGridDescriptor_M_N
(
c_gs_ms_gemm1ns_lengths
,
c_gs_ms_gemm1ns_strides
)},
lse_grid_desc_m_
{
DeviceOp
::
MakeLSEGridDescriptor_M
(
lse_gs_ms_lengths
[
NumDimG
])},
d_grid_desc_m_
{
DeviceOp
::
Make
LSE
GridDescriptor_M
(
d_gs_ms_lengths
[
NumDimG
])},
d_grid_desc_m_
{
DeviceOp
::
Make
D
GridDescriptor_M
(
d_gs_ms_lengths
[
NumDimG
])},
k_grid_desc_n_k_
{
Transform
::
MakeB0GridDescriptor_N_K
(
b_gs_ns_ks_lengths
,
b_gs_ns_ks_strides
)},
ygrad_grid_desc_o0_m_o1_
{
DeviceOp
::
MakeYGradGridDescriptor_O0_M_O1
(
...
...
@@ -836,7 +862,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
z_grid_desc_g_m_n_
{
Transform
::
MakeCGridDescriptor_G_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
)},
block_2_ctile_map_
{
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
k_grid_desc_n_k_
)},
d_block_2_ctile_map_
{
GridwiseYDotYGrad
::
MakeDefaultBlock2CTileMap
(
y_grid_desc_m_o_
)},
d_block_2_ctile_map_
{
GridwiseYDotYGrad
::
MakeDefaultBlock2CTileMap
(
d_y_grid_desc_m_o_
)},
d_y_grid_desc_mblock_mperblock_oblock_operblock_
{},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
...
...
@@ -881,7 +908,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
(
z_grid_desc_m_n_
);
d_y_grid_desc_mblock_mperblock_oblock_operblock_
=
GridwiseYDotYGrad
::
MakeYGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
y_grid_desc_m_o_
);
d_
y_grid_desc_m_o_
);
// Print();
m_raw_padded_
=
GridwiseGemm
::
GetPaddedSize
(
raw_lengths_mz_nz_kz_gemm1nz_
[
0
]);
...
...
@@ -932,6 +959,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
ZGridDesc_M_N
z_grid_desc_m_n_
;
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1_
;
YGridDesc_M_O
y_grid_desc_m_o_
;
DYGridDesc_M_O
d_y_grid_desc_m_o_
;
LSEGridDesc_M
lse_grid_desc_m_
;
DGridDesc_M
d_grid_desc_m_
;
KGridDesc_N_K
k_grid_desc_n_k_
;
...
...
@@ -998,15 +1026,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
const
index_t
grid_size
=
(
Deterministic
?
1
:
arg
.
d_block_2_ctile_map_
.
CalculateGridSize
(
arg
.
y_grid_desc_m_o_
))
*
:
arg
.
d_block_2_ctile_map_
.
CalculateGridSize
(
arg
.
d_
y_grid_desc_m_o_
))
*
arg
.
batch_count_
;
std
::
cout
<<
"grid_size: "
<<
grid_size
<<
"grid_size / arg.batch_count_: "
<<
grid_size
/
arg
.
batch_count_
<<
" arg.batch_count_: "
<<
arg
.
batch_count_
<<
std
::
endl
;
std
::
cout
<<
"MPerBlock: "
<<
MPerBlock
<<
" Gemm1NPerBlock: "
<<
Gemm1NPerBlock
<<
std
::
endl
;
std
::
cout
<<
"arg.y_grid_desc_m_o_: {"
<<
arg
.
y_grid_desc_m_o_
.
GetLength
(
I0
)
<<
","
<<
arg
.
y_grid_desc_m_o_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.d_y_grid_desc_m_o_: {"
<<
arg
.
d_y_grid_desc_m_o_
.
GetLength
(
I0
)
<<
","
<<
arg
.
d_y_grid_desc_m_o_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.d_grid_desc_m_: {"
<<
arg
.
d_grid_desc_m_
.
GetLength
(
I0
)
<<
"}"
<<
std
::
endl
;
auto
launch_kernel
=
[
&
]()
{
const
auto
kernel
=
kernel_batched_multihead_attention_backward_ydotygrad_v1
<
...
...
@@ -1062,7 +1092,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
,
DeviceOp
::
B1GridDesc_BK0_N_BK1
,
DeviceOp
::
LSEGridDesc_M
,
DeviceOp
::
D
GridDesc_M
,
DeviceOp
::
LSE
GridDesc_M
,
DeviceOp
::
YGradGridDesc_O0_M_O1
,
typename
GridwiseGemm
::
DefaultBlock2CTileMap
,
ComputeBasePtrOfStridedBatch
,
...
...
@@ -1096,7 +1126,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
arg
.
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
lse_grid_desc_m_
,
arg
.
d
_grid_desc_m_
,
arg
.
lse
_grid_desc_m_
,
arg
.
ygrad_grid_desc_o0_m_o1_
,
arg
.
block_2_ctile_map_
,
arg
.
batch_count_
,
...
...
@@ -1138,7 +1168,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
}
// TODO: Check if tensor specialization & strides mismatch
if
(
!
GridwiseYDotYGrad
::
CheckValidity
(
arg
.
y_grid_desc_m_o_
,
arg
.
d_block_2_ctile_map_
))
if
(
!
GridwiseYDotYGrad
::
CheckValidity
(
arg
.
d_
y_grid_desc_m_o_
,
arg
.
d_block_2_ctile_map_
))
{
return
false
;
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
View file @
be38f68d
...
...
@@ -342,6 +342,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
static
constexpr
index_t
NumAcc0Bias
=
Acc0BiasDataType
::
Size
();
static
constexpr
index_t
NumAcc1Bias
=
Acc1BiasDataType
::
Size
();
static
constexpr
index_t
DMPerBlock
=
BlockSize
;
// TODO: implement bias combination
static_assert
(
NumAcc0Bias
==
0
&&
NumAcc0Bias
==
0
,
"Bias addition is unimplemented"
);
...
...
@@ -377,6 +378,15 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
B1Spec
,
CSpec
>
;
using
DTransform
=
TransformBatchedContractionContractionToBatchedGemmGemm
<
Sequence
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
>
,
Sequence
<
DMPerBlock
,
NPerBlock
,
KPerBlock
,
Gemm1NPerBlock
>
,
GemmSpecialization
::
MNKOPadding
,
ASpec
,
BSpec
,
B1Spec
,
CSpec
>
;
/*
Descriptors for inputs:
...
...
@@ -602,6 +612,19 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
}
}
static
auto
MakeDGridDescriptor_M
(
index_t
MRaw
)
{
const
auto
d_grid_desc_mraw
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
MRaw
));
const
auto
M
=
math
::
integer_divide_ceil
(
MRaw
,
BlockSize
)
*
BlockSize
;
const
auto
MPad
=
M
-
MRaw
;
return
transform_tensor_descriptor
(
d_grid_desc_mraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
)),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
}
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
({},
{}));
using
BGridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
({},
{}));
using
B1GridDesc_BK0_N_BK1
=
decltype
(
MakeB1GridDescriptor_BK0_N_BK1
({},
{}));
...
...
@@ -612,7 +635,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
using
B1GridDesc_G_N_K
=
decltype
(
Transform
::
MakeB1GridDescriptor_G_N_K
({},
{}));
using
CGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
ZGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
DGridDesc_M
=
decltype
(
MakeLSEGridDescriptor_M
(
1
));
using
DYGridDesc_M_O
=
decltype
(
DTransform
::
MakeCGridDescriptor_M_N
({},
{}));
using
DGridDesc_M
=
decltype
(
MakeDGridDescriptor_M
(
1
));
using
KGridDesc_N_K
=
decltype
(
Transform
::
MakeB0GridDescriptor_N_K
({},
{}));
using
YGradGridDesc_M0_O_M1
=
decltype
(
MakeYGradGridDescriptor_M0_O_M1
(
YGridDesc_M_O
{}));
...
...
@@ -711,7 +735,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
B1GridDesc_BK0_N_BK1
,
YGridDesc_M_O
,
LSEGridDesc_M
,
D
GridDesc_M
,
LSE
GridDesc_M
,
NumGemmKPrefetchStage
,
BlockSize
,
MPerBlock
,
...
...
@@ -768,7 +792,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
YGridDesc_M_O
,
DGridDesc_M
,
BlockSize
,
Block
Size
,
DMPer
Block
,
DKPerBlock
>
;
// Argument
struct
Argument
:
public
BaseArgument
...
...
@@ -832,8 +856,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
b1_gs_gemm1ns_gemm1ks_lengths
,
b1_gs_gemm1ns_gemm1ks_strides
)},
y_grid_desc_m_o_
{
Transform
::
MakeCGridDescriptor_M_N
(
c_gs_ms_gemm1ns_lengths
,
c_gs_ms_gemm1ns_strides
)},
d_y_grid_desc_m_o_
{
DTransform
::
MakeCGridDescriptor_M_N
(
c_gs_ms_gemm1ns_lengths
,
c_gs_ms_gemm1ns_strides
)},
lse_grid_desc_m_
{
DeviceOp
::
MakeLSEGridDescriptor_M
(
lse_gs_ms_lengths
[
NumDimG
])},
d_grid_desc_m_
{
DeviceOp
::
Make
LSE
GridDescriptor_M
(
d_gs_ms_lengths
[
NumDimG
])},
d_grid_desc_m_
{
DeviceOp
::
Make
D
GridDescriptor_M
(
d_gs_ms_lengths
[
NumDimG
])},
k_grid_desc_n_k_
{
Transform
::
MakeB0GridDescriptor_N_K
(
b_gs_ns_ks_lengths
,
b_gs_ns_ks_strides
)},
ygrad_grid_desc_m0_o_m1_
{
DeviceOp
::
MakeYGradGridDescriptor_M0_O_M1
(
y_grid_desc_m_o_
)},
...
...
@@ -849,7 +875,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
z_grid_desc_g_m_n_
{
Transform
::
MakeCGridDescriptor_G_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
)},
block_2_ctile_map_
{
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
k_grid_desc_n_k_
)},
d_block_2_ctile_map_
{
GridwiseYDotYGrad
::
MakeDefaultBlock2CTileMap
(
y_grid_desc_m_o_
)},
d_block_2_ctile_map_
{
GridwiseYDotYGrad
::
MakeDefaultBlock2CTileMap
(
d_y_grid_desc_m_o_
)},
d_y_grid_desc_mblock_mperblock_oblock_operblock_
{},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
...
...
@@ -894,7 +921,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
(
z_grid_desc_m_n_
);
d_y_grid_desc_mblock_mperblock_oblock_operblock_
=
GridwiseYDotYGrad
::
MakeYGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
y_grid_desc_m_o_
);
d_
y_grid_desc_m_o_
);
// Print();
m_raw_padded_
=
GridwiseGemm
::
GetPaddedSize
(
raw_lengths_mz_nz_kz_gemm1nz_
[
0
]);
...
...
@@ -945,6 +972,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
ZGridDesc_M_N
z_grid_desc_m_n_
;
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1_
;
YGridDesc_M_O
y_grid_desc_m_o_
;
DYGridDesc_M_O
d_y_grid_desc_m_o_
;
LSEGridDesc_M
lse_grid_desc_m_
;
DGridDesc_M
d_grid_desc_m_
;
KGridDesc_N_K
k_grid_desc_n_k_
;
...
...
@@ -1011,15 +1039,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
const
index_t
grid_size
=
(
Deterministic
?
1
:
arg
.
d_block_2_ctile_map_
.
CalculateGridSize
(
arg
.
y_grid_desc_m_o_
))
*
:
arg
.
d_block_2_ctile_map_
.
CalculateGridSize
(
arg
.
d_
y_grid_desc_m_o_
))
*
arg
.
batch_count_
;
std
::
cout
<<
"grid_size: "
<<
grid_size
<<
"grid_size / arg.batch_count_: "
<<
grid_size
/
arg
.
batch_count_
<<
" arg.batch_count_: "
<<
arg
.
batch_count_
<<
std
::
endl
;
std
::
cout
<<
"MPerBlock: "
<<
MPerBlock
<<
" Gemm1NPerBlock: "
<<
Gemm1NPerBlock
<<
std
::
endl
;
std
::
cout
<<
"arg.y_grid_desc_m_o_: {"
<<
arg
.
y_grid_desc_m_o_
.
GetLength
(
I0
)
<<
","
<<
arg
.
y_grid_desc_m_o_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.d_y_grid_desc_m_o_: {"
<<
arg
.
d_y_grid_desc_m_o_
.
GetLength
(
I0
)
<<
","
<<
arg
.
d_y_grid_desc_m_o_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.d_grid_desc_m_: {"
<<
arg
.
d_grid_desc_m_
.
GetLength
(
I0
)
<<
"}"
<<
std
::
endl
;
auto
launch_kernel
=
[
&
]()
{
const
auto
kernel
=
kernel_batched_multihead_attention_backward_ydotygrad_v2
<
...
...
@@ -1079,7 +1109,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
,
DeviceOp
::
B1GridDesc_BK0_N_BK1
,
DeviceOp
::
LSEGridDesc_M
,
DeviceOp
::
D
GridDesc_M
,
DeviceOp
::
LSE
GridDesc_M
,
DeviceOp
::
YGradGridDesc_M0_O_M1
,
typename
GridwiseGemm
::
DefaultBlock2CTileMap
,
ComputeBasePtrOfStridedBatch
,
...
...
@@ -1113,7 +1143,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
arg
.
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
lse_grid_desc_m_
,
arg
.
d
_grid_desc_m_
,
arg
.
lse
_grid_desc_m_
,
arg
.
ygrad_grid_desc_m0_o_m1_
,
arg
.
block_2_ctile_map_
,
arg
.
batch_count_
,
...
...
@@ -1165,7 +1195,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
}
// TODO: Check if tensor specialization & strides mismatch
if
(
!
GridwiseYDotYGrad
::
CheckValidity
(
arg
.
y_grid_desc_m_o_
,
arg
.
d_block_2_ctile_map_
))
if
(
!
GridwiseYDotYGrad
::
CheckValidity
(
arg
.
d_
y_grid_desc_m_o_
,
arg
.
d_block_2_ctile_map_
))
{
return
false
;
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
View file @
be38f68d
...
...
@@ -315,6 +315,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
static
constexpr
index_t
NumAcc0Bias
=
Acc0BiasDataType
::
Size
();
static
constexpr
index_t
NumAcc1Bias
=
Acc1BiasDataType
::
Size
();
static
constexpr
index_t
DMPerBlock
=
BlockSize
;
// TODO: implement bias combination
static_assert
(
NumAcc0Bias
==
0
&&
NumAcc0Bias
==
0
,
"Bias addition is unimplemented"
);
...
...
@@ -378,6 +379,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
B1Spec
,
CSpec
>
;
using
DTransform
=
TransformBatchedContractionContractionToBatchedGemmGemm
<
Sequence
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
>
,
Sequence
<
DMPerBlock
,
NPerBlock
,
KPerBlock
,
Gemm1NPerBlock
>
,
GemmSpecialization
::
MNKOPadding
,
ASpec
,
BSpec
,
B1Spec
,
CSpec
>
;
/*
Descriptors for inputs:
...
...
@@ -547,6 +557,19 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
}
}
static
auto
MakeDGridDescriptor_M
(
index_t
MRaw
)
{
const
auto
d_grid_desc_mraw
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
MRaw
));
const
auto
M
=
math
::
integer_divide_ceil
(
MRaw
,
DMPerBlock
)
*
DMPerBlock
;
const
auto
MPad
=
M
-
MRaw
;
return
transform_tensor_descriptor
(
d_grid_desc_mraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
)),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
}
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
({},
{}));
using
BGridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
({},
{}));
using
B1GridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
({},
{}));
...
...
@@ -562,7 +585,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
using
YGradGridDesc_O0_M_O1
=
decltype
(
MakeYGradGridDescriptor_O0_M_O1
({},
{}));
using
ZGridDesc_M_N
=
decltype
(
MakeZGridDescriptor_M_N
({},
{}));
using
DGridDesc_M
=
decltype
(
MakeLSEGridDescriptor_M
(
1
));
using
DYGridDesc_M_O
=
decltype
(
DTransform
::
MakeCGridDescriptor_M_N
({},
{}));
using
DGridDesc_M
=
decltype
(
MakeDGridDescriptor_M
(
1
));
constexpr
static
auto
make_MaskOutPredicate
()
{
...
...
@@ -656,7 +680,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
B1GridDesc_BK0_N_BK1
,
YGridDesc_M_O
,
LSEGridDesc_M
,
D
GridDesc_M
,
LSE
GridDesc_M
,
NumGemmKPrefetchStage
,
BlockSize
,
MPerBlock
,
...
...
@@ -707,7 +731,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
YGridDesc_M_O
,
DGridDesc_M
,
BlockSize
,
Block
Size
,
DMPer
Block
,
DKPerBlock
>
;
using
DBlock2CTileMap
=
OffsettedBlockToCTileMap
<
typename
GridwiseYDotYGrad
::
DefaultBlock2CTileMap
>
;
...
...
@@ -752,6 +776,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
// D parameter
DDataType
*
p_d_grid_
;
DYGridDesc_M_O
d_y_grid_desc_m_o_
;
DGridDesc_M
d_grid_desc_m_
;
DBlock2CTileMap
d_block_2_ctile_map_
;
typename
GridwiseYDotYGrad
::
YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
...
...
@@ -931,15 +956,18 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
// D parameters
const
auto
p_d_grid
=
static_cast
<
DDataType
*>
(
p_Ds
[
i
]);
const
auto
d_grid_desc_m
=
DeviceOp
::
MakeLSEGridDescriptor_M
(
problem_desc
.
d_gs_ms_lengths
[
NumDimG
]);
DeviceOp
::
MakeDGridDescriptor_M
(
problem_desc
.
d_gs_ms_lengths
[
NumDimG
]);
const
auto
d_y_grid_desc_m_o
=
DTransform
::
MakeCGridDescriptor_M_N
(
problem_desc
.
c_gs_ms_gemm1ns_lengths
,
problem_desc
.
c_gs_ms_gemm1ns_strides
);
index_t
d_block_start
=
d_grid_size_
;
const
auto
d_block_2_ctile_map
=
DBlock2CTileMap
(
y_grid_desc_m_o
,
d_block_start
);
const
auto
d_block_2_ctile_map
=
DBlock2CTileMap
(
d_
y_grid_desc_m_o
,
d_block_start
);
const
auto
d_y_grid_desc_mblock_mperblock_nblock_nperblock
=
GridwiseYDotYGrad
::
MakeYGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
y_grid_desc_m_o
);
d_
y_grid_desc_m_o
);
index_t
d_num_blocks_per_batch
=
d_block_2_ctile_map
.
CalculateGridSize
(
y_grid_desc_m_o
);
d_block_2_ctile_map
.
CalculateGridSize
(
d_
y_grid_desc_m_o
);
index_t
d_block_end
=
d_block_start
+
d_num_blocks_per_batch
*
batch_count
;
d_grid_size_
=
d_block_end
;
...
...
@@ -973,6 +1001,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
raw_m_padded
,
raw_n_padded
,
p_d_grid
,
d_y_grid_desc_m_o
,
d_grid_desc_m
,
d_block_2_ctile_map
,
d_y_grid_desc_mblock_mperblock_nblock_nperblock
,
...
...
@@ -1151,7 +1180,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
// TODO: Check if tensor specialization & strides mismatch
const
auto
&
kernel_arg
=
arg
.
group_kernel_args_
[
i
];
const
auto
&
device_arg
=
arg
.
group_device_args_
[
i
];
if
(
!
GridwiseYDotYGrad
::
CheckValidity
(
kernel_arg
.
y_grid_desc_m_o_
,
if
(
!
GridwiseYDotYGrad
::
CheckValidity
(
kernel_arg
.
d_
y_grid_desc_m_o_
,
kernel_arg
.
d_block_2_ctile_map_
))
{
return
false
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
View file @
be38f68d
...
...
@@ -322,6 +322,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
static
constexpr
index_t
NumAcc0Bias
=
Acc0BiasDataType
::
Size
();
static
constexpr
index_t
NumAcc1Bias
=
Acc1BiasDataType
::
Size
();
static
constexpr
index_t
DMPerBlock
=
BlockSize
;
// TODO: implement bias combination
static_assert
(
NumAcc0Bias
==
0
&&
NumAcc0Bias
==
0
,
"Bias addition is unimplemented"
);
...
...
@@ -385,6 +386,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
B1Spec
,
CSpec
>
;
using
DTransform
=
TransformBatchedContractionContractionToBatchedGemmGemm
<
Sequence
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
>
,
Sequence
<
DMPerBlock
,
NPerBlock
,
KPerBlock
,
Gemm1NPerBlock
>
,
GemmSpecialization
::
MNKOPadding
,
ASpec
,
BSpec
,
B1Spec
,
CSpec
>
;
/*
Descriptors for inputs:
...
...
@@ -547,6 +557,19 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
}
}
static
auto
MakeDGridDescriptor_M
(
index_t
MRaw
)
{
const
auto
d_grid_desc_mraw
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
MRaw
));
const
auto
M
=
math
::
integer_divide_ceil
(
MRaw
,
DMPerBlock
)
*
DMPerBlock
;
const
auto
MPad
=
M
-
MRaw
;
return
transform_tensor_descriptor
(
d_grid_desc_mraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
)),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
}
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
({},
{}));
using
BGridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
({},
{}));
using
B1GridDesc_BK0_N_BK1
=
decltype
(
MakeB1GridDescriptor_BK0_N_BK1
({},
{}));
...
...
@@ -562,7 +585,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
using
YGradGridDesc_M0_O_M1
=
decltype
(
MakeYGradGridDescriptor_M0_O_M1
(
YGridDesc_M_O
{}));
using
ZGridDesc_M_N
=
decltype
(
MakeZGridDescriptor_M_N
({},
{}));
using
DGridDesc_M
=
decltype
(
MakeLSEGridDescriptor_M
(
1
));
using
DYGridDesc_M_O
=
decltype
(
DTransform
::
MakeCGridDescriptor_M_N
({},
{}));
using
DGridDesc_M
=
decltype
(
MakeDGridDescriptor_M
(
1
));
constexpr
static
auto
make_MaskOutPredicate
()
{
...
...
@@ -656,7 +680,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
B1GridDesc_BK0_N_BK1
,
YGridDesc_M_O
,
LSEGridDesc_M
,
D
GridDesc_M
,
LSE
GridDesc_M
,
NumGemmKPrefetchStage
,
BlockSize
,
MPerBlock
,
...
...
@@ -715,7 +739,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
YGridDesc_M_O
,
DGridDesc_M
,
BlockSize
,
Block
Size
,
DMPer
Block
,
DKPerBlock
>
;
using
DBlock2CTileMap
=
OffsettedBlockToCTileMap
<
typename
GridwiseYDotYGrad
::
DefaultBlock2CTileMap
>
;
...
...
@@ -760,6 +784,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
// D parameter
DDataType
*
p_d_grid_
;
DYGridDesc_M_O
d_y_grid_desc_m_o_
;
DGridDesc_M
d_grid_desc_m_
;
DBlock2CTileMap
d_block_2_ctile_map_
;
typename
GridwiseYDotYGrad
::
YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
...
...
@@ -934,16 +959,18 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
// D parameters
const
auto
p_d_grid
=
static_cast
<
DDataType
*>
(
p_Ds
[
i
]);
const
auto
d_grid_desc_m
=
DeviceOp
::
Make
LSE
GridDescriptor_M
(
problem_desc
.
d_gs_ms_lengths
[
NumDimG
]);
DeviceOp
::
Make
D
GridDescriptor_M
(
problem_desc
.
d_gs_ms_lengths
[
NumDimG
]);
const
auto
d_y_grid_desc_m_o
=
DTransform
::
MakeCGridDescriptor_M_N
(
problem_desc
.
c_gs_ms_gemm1ns_lengths
,
problem_desc
.
c_gs_ms_gemm1ns_strides
);
index_t
d_block_start
=
d_grid_size_
;
const
auto
d_block_2_ctile_map
=
DBlock2CTileMap
(
y_grid_desc_m_o
,
d_block_start
);
const
auto
d_block_2_ctile_map
=
DBlock2CTileMap
(
d_
y_grid_desc_m_o
,
d_block_start
);
const
auto
d_y_grid_desc_mblock_mperblock_nblock_nperblock
=
GridwiseYDotYGrad
::
MakeYGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
y_grid_desc_m_o
);
d_
y_grid_desc_m_o
);
index_t
d_num_blocks_per_batch
=
d_block_2_ctile_map
.
CalculateGridSize
(
y_grid_desc_m_o
);
d_block_2_ctile_map
.
CalculateGridSize
(
d_
y_grid_desc_m_o
);
index_t
d_block_end
=
d_block_start
+
d_num_blocks_per_batch
*
batch_count
;
d_grid_size_
=
d_block_end
;
...
...
@@ -977,6 +1004,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
raw_m_padded
,
raw_n_padded
,
p_d_grid
,
d_y_grid_desc_m_o
,
d_grid_desc_m
,
d_block_2_ctile_map
,
d_y_grid_desc_mblock_mperblock_nblock_nperblock
,
...
...
@@ -1153,7 +1181,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
// TODO: Check if tensor specialization & strides mismatch
const
auto
&
kernel_arg
=
arg
.
group_kernel_args_
[
i
];
const
auto
&
device_arg
=
arg
.
group_device_args_
[
i
];
if
(
!
GridwiseYDotYGrad
::
CheckValidity
(
kernel_arg
.
y_grid_desc_m_o_
,
if
(
!
GridwiseYDotYGrad
::
CheckValidity
(
kernel_arg
.
d_
y_grid_desc_m_o_
,
kernel_arg
.
d_block_2_ctile_map_
))
{
return
false
;
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_ydotygrad.hpp
View file @
be38f68d
...
...
@@ -45,21 +45,22 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
{
return
false
;
}
//
const auto M = y_grid_desc_m_n.GetLength(I0);
const
auto
M
=
y_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
y_grid_desc_m_n
.
GetLength
(
I1
);
if
(
N
<
NPerBlock
)
{
return
false
;
}
// std::cout << "m: " << M <<" n: " << N << std::endl;
//
if(M < MPerBlock)
//
{
//
return false;
//
}
//
if(M % MPerBlock != 0)
//
{
//
return false;
//
}
if
(
M
<
MPerBlock
)
{
return
false
;
}
if
(
M
%
MPerBlock
!=
0
)
{
return
false
;
}
return
true
;
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment