gaoqiong / composable_kernel

Commit f3e61c0a
Authored Apr 13, 2023 by danyao12

datatype of bwd output can be selected

Parent: f7e05f9e
Showing 8 changed files with 441 additions and 387 deletions.
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward.cpp  (+134, -114)
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward.cpp  (+135, -115)
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp  (+48, -44)
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp  (+48, -44)
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_backward_xdl_cshuffle_v1.hpp  (+20, -18)
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_backward_xdl_cshuffle_v2.hpp  (+20, -18)
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp  (+18, -17)
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt2.hpp  (+18, -17)
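Taken together, these changes split the single DataType template parameter of the multi-head attention backward device and gridwise ops into InputDataType (Q, K, V, Y, dY and the other forward-pass operands) and OutputDataType (the dQ/dK/dV gradients), so the precision of the backward outputs can be selected independently of the inputs. The following is a minimal, self-contained sketch of that pattern only; it is not the CK API, and the names scale_grad, InT and OutT are illustrative:

#include <cstddef>
#include <cstdio>
#include <vector>

// Illustrative only (not the CK API): inputs arrive in InT, accumulation happens in float,
// and results are stored in an independently chosen OutT, mirroring the
// InputDataType / OutputDataType split this commit introduces for dQ/dK/dV.
template <typename InT, typename OutT>
std::vector<OutT> scale_grad(const std::vector<InT>& dy, float scale)
{
    std::vector<OutT> dx(dy.size());
    for(std::size_t i = 0; i < dy.size(); ++i)
    {
        const float acc = static_cast<float>(dy[i]) * scale; // accumulate in float
        dx[i]           = static_cast<OutT>(acc);            // convert to the selected output type
    }
    return dx;
}

int main()
{
    const std::vector<float> dy{0.5f, 1.5f, -2.0f};
    // The output type is chosen independently of the input type
    // (double here for portability; in CK it would be e.g. a half or float type).
    const auto dx = scale_grad<float, double>(dy, 0.125f);
    std::printf("%f %f %f\n", dx[0], dx[1], dx[2]);
    return 0;
}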
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward.cpp
(diff collapsed on the original page, not shown here)

example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward.cpp
(diff collapsed on the original page, not shown here)
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp

@@ -28,7 +28,8 @@ namespace tensor_operation {
 namespace device {

 template <typename GridwiseGemm,
-          typename DataType,
+          typename InputDataType,
+          typename OutputDataType,
           typename ZDataType,
           typename LSEDataType,
           typename AElementwiseOperation,
...
@@ -53,16 +54,16 @@ __global__ void
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
 #endif
     kernel_batched_multihead_attention_backward_xdl_cshuffle_v1(
-        const DataType* __restrict__ p_a_grid,
-        const DataType* __restrict__ p_b_grid,
+        const InputDataType* __restrict__ p_a_grid,
+        const InputDataType* __restrict__ p_b_grid,
         ZDataType* __restrict__ p_z_grid,
-        const DataType* __restrict__ p_b1_grid,
-        const DataType* __restrict__ p_c_grid,
+        const InputDataType* __restrict__ p_b1_grid,
+        const InputDataType* __restrict__ p_c_grid,
         const LSEDataType* __restrict__ p_lse_grid,
-        const DataType* __restrict__ p_ygrad_grid,
-        DataType* __restrict__ p_qgrad_grid,
-        DataType* __restrict__ p_kgrad_grid,
-        DataType* __restrict__ p_vgrad_grid,
+        const InputDataType* __restrict__ p_ygrad_grid,
+        OutputDataType* __restrict__ p_qgrad_grid,
+        OutputDataType* __restrict__ p_kgrad_grid,
+        OutputDataType* __restrict__ p_vgrad_grid,
         const AElementwiseOperation a_element_op,
         const BElementwiseOperation b_element_op,
         const AccElementwiseOperation acc_element_op,
...
@@ -171,7 +172,8 @@ template <index_t NumDimG,
           index_t NumDimN,
           index_t NumDimK,
           index_t NumDimO, // NumDimGemm1N
-          typename DataType,
+          typename InputDataType,
+          typename OutputDataType,
           typename GemmDataType,
           typename ZDataType,
           typename LSEDataType,
...
@@ -597,7 +599,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
    // GridwiseGemm
    using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
-        DataType, // TODO: distinguish A/B datatype
+        InputDataType, // TODO: distinguish A/B datatype
+        OutputDataType,
         GemmDataType,
         GemmAccDataType,
         CShuffleDataType,
...
@@ -666,16 +669,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
    struct Argument : public BaseArgument
    {
        Argument(
-            const DataType* p_a_grid,
-            const DataType* p_b_grid,
+            const InputDataType* p_a_grid,
+            const InputDataType* p_b_grid,
             ZDataType* p_z_grid,
-            const DataType* p_b1_grid,
-            const DataType* p_c_grid, // for dS
+            const InputDataType* p_b1_grid,
+            const InputDataType* p_c_grid, // for dS
             const LSEDataType* p_lse_grid,
-            const DataType* p_ygrad_grid,
-            DataType* p_qgrad_grid,
-            DataType* p_kgrad_grid,
-            DataType* p_vgrad_grid,
+            const InputDataType* p_ygrad_grid,
+            OutputDataType* p_qgrad_grid,
+            OutputDataType* p_kgrad_grid,
+            OutputDataType* p_vgrad_grid,
             const std::array<void*, NumAcc0Bias> p_acc0_biases,
             const std::array<void*, NumAcc1Bias> p_acc1_biases,
             const std::vector<index_t>& a_gs_ms_ks_lengths,
...
@@ -820,16 +823,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
        }

        // pointers
-        const DataType* p_a_grid_;
-        const DataType* p_b_grid_;
+        const InputDataType* p_a_grid_;
+        const InputDataType* p_b_grid_;
         ZDataType* p_z_grid_;
-        const DataType* p_b1_grid_;
-        const DataType* p_c_grid_;
+        const InputDataType* p_b1_grid_;
+        const InputDataType* p_c_grid_;
         const LSEDataType* p_lse_grid_;
-        const DataType* p_ygrad_grid_;
-        DataType* p_qgrad_grid_;
-        DataType* p_kgrad_grid_;
-        DataType* p_vgrad_grid_;
+        const InputDataType* p_ygrad_grid_;
+        OutputDataType* p_qgrad_grid_;
+        OutputDataType* p_kgrad_grid_;
+        OutputDataType* p_vgrad_grid_;

        // tensor descriptor
        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
...
@@ -901,7 +904,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
            auto launch_kernel = [&](auto has_main_k_block_loop_) {
                const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_v1<
                    GridwiseGemm,
-                    DataType,
+                    InputDataType,
+                    OutputDataType,
                     ZDataType,
                     LSEDataType,
                     AElementwiseOperation,
...
@@ -1067,16 +1071,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
    }

    static auto MakeArgument(
-        const DataType* p_a,
-        const DataType* p_b,
+        const InputDataType* p_a,
+        const InputDataType* p_b,
         ZDataType* p_z,
-        const DataType* p_b1,
-        const DataType* p_c,
+        const InputDataType* p_b1,
+        const InputDataType* p_c,
         const LSEDataType* p_lse,
-        const DataType* p_ygrad_grid,
-        DataType* p_qgrad_grid,
-        DataType* p_kgrad_grid,
-        DataType* p_vgrad_grid,
+        const InputDataType* p_ygrad_grid,
+        OutputDataType* p_qgrad_grid,
+        OutputDataType* p_kgrad_grid,
+        OutputDataType* p_vgrad_grid,
         const std::array<void*, NumAcc0Bias> p_acc0_biases,
         const std::array<void*, NumAcc1Bias> p_acc1_biases,
         const std::vector<index_t>& a_gs_ms_ks_lengths,
...
@@ -1182,16 +1186,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
        float p_drop,
        std::tuple<unsigned long long, unsigned long long> seeds) // override
    {
-        return std::make_unique<Argument>(static_cast<const DataType*>(p_a),
-                                          static_cast<const DataType*>(p_b),
+        return std::make_unique<Argument>(static_cast<const InputDataType*>(p_a),
+                                          static_cast<const InputDataType*>(p_b),
                                           static_cast<ZDataType*>(p_z),
-                                          static_cast<const DataType*>(p_b1),
-                                          static_cast<const DataType*>(p_c),
+                                          static_cast<const InputDataType*>(p_b1),
+                                          static_cast<const InputDataType*>(p_c),
                                           static_cast<const LSEDataType*>(p_lse),
-                                          static_cast<const DataType*>(p_ygrad_grid),
-                                          static_cast<DataType*>(p_qgrad_grid),
-                                          static_cast<DataType*>(p_kgrad_grid),
-                                          static_cast<DataType*>(p_vgrad_grid),
+                                          static_cast<const InputDataType*>(p_ygrad_grid),
+                                          static_cast<OutputDataType*>(p_qgrad_grid),
+                                          static_cast<OutputDataType*>(p_kgrad_grid),
+                                          static_cast<OutputDataType*>(p_vgrad_grid),
                                           p_acc0_biases, // cast in struct Argument
                                           p_acc1_biases, // cast in struct Argument
                                           a_gs_ms_ks_lengths,
...
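As the last hunk above shows, the polymorphic argument-creation path still receives type-erased void pointers and re-types them, now to const InputDataType* for inputs and OutputDataType* for the gradient outputs. A compilable sketch of that type-erasure pattern, with deliberately simplified names (ArgumentSketch holds only two pointers, unlike the real struct):

#include <memory>

// Simplified sketch (not the real Argument struct): the type-erased entry point receives
// void pointers and static_casts them to const InputDataType* for inputs and
// OutputDataType* for gradients, the same pattern as in the hunk above.
template <typename InputDataType, typename OutputDataType>
struct ArgumentSketch
{
    const InputDataType* p_q_;
    OutputDataType* p_qgrad_;
};

template <typename InputDataType, typename OutputDataType>
std::unique_ptr<ArgumentSketch<InputDataType, OutputDataType>>
make_argument_pointer(const void* p_q, void* p_qgrad)
{
    return std::make_unique<ArgumentSketch<InputDataType, OutputDataType>>(
        ArgumentSketch<InputDataType, OutputDataType>{static_cast<const InputDataType*>(p_q),
                                                      static_cast<OutputDataType*>(p_qgrad)});
}

int main()
{
    float q[4]   = {};
    double dq[4] = {};
    // The caller hands over raw buffers; the concrete types are supplied as template arguments.
    auto arg = make_argument_pointer<float, double>(q, dq);
    return arg ? 0 : 1;
}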
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp

@@ -27,7 +27,8 @@ namespace tensor_operation {
 namespace device {

 template <typename GridwiseGemm,
-          typename DataType,
+          typename InputDataType,
+          typename OutputDataType,
           typename ZDataType,
           typename LSEDataType,
           typename AElementwiseOperation,
...
@@ -52,16 +53,16 @@ __global__ void
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
 #endif
     kernel_batched_multihead_attention_backward_xdl_cshuffle_v2(
-        const DataType* __restrict__ p_a_grid,
-        const DataType* __restrict__ p_b_grid,
+        const InputDataType* __restrict__ p_a_grid,
+        const InputDataType* __restrict__ p_b_grid,
         ZDataType* __restrict__ p_z_grid,
-        const DataType* __restrict__ p_b1_grid,
-        const DataType* __restrict__ p_c_grid,
+        const InputDataType* __restrict__ p_b1_grid,
+        const InputDataType* __restrict__ p_c_grid,
         const LSEDataType* __restrict__ p_lse_grid,
-        const DataType* __restrict__ p_ygrad_grid,
-        DataType* __restrict__ p_qgrad_grid,
-        DataType* __restrict__ p_kgrad_grid,
-        DataType* __restrict__ p_vgrad_grid,
+        const InputDataType* __restrict__ p_ygrad_grid,
+        OutputDataType* __restrict__ p_qgrad_grid,
+        OutputDataType* __restrict__ p_kgrad_grid,
+        OutputDataType* __restrict__ p_vgrad_grid,
         const AElementwiseOperation a_element_op,
         const BElementwiseOperation b_element_op,
         const AccElementwiseOperation acc_element_op,
...
@@ -170,7 +171,8 @@ template <index_t NumDimG,
           index_t NumDimN,
           index_t NumDimK,
           index_t NumDimO, // NumDimGemm1N
-          typename DataType,
+          typename InputDataType,
+          typename OutputDataType,
           typename GemmDataType,
           typename ZDataType,
           typename LSEDataType,
...
@@ -596,7 +598,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
    // GridwiseGemm
    using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
-        DataType, // TODO: distinguish A/B datatype
+        InputDataType, // TODO: distinguish A/B datatype
+        OutputDataType,
         GemmDataType,
         GemmAccDataType,
         CShuffleDataType,
...
@@ -665,16 +668,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
    struct Argument : public BaseArgument
    {
        Argument(
-            const DataType* p_a_grid,
-            const DataType* p_b_grid,
+            const InputDataType* p_a_grid,
+            const InputDataType* p_b_grid,
             ZDataType* p_z_grid,
-            const DataType* p_b1_grid,
-            const DataType* p_c_grid, // for dS
+            const InputDataType* p_b1_grid,
+            const InputDataType* p_c_grid, // for dS
             const LSEDataType* p_lse_grid,
-            const DataType* p_ygrad_grid,
-            DataType* p_qgrad_grid,
-            DataType* p_kgrad_grid,
-            DataType* p_vgrad_grid,
+            const InputDataType* p_ygrad_grid,
+            OutputDataType* p_qgrad_grid,
+            OutputDataType* p_kgrad_grid,
+            OutputDataType* p_vgrad_grid,
             const std::array<void*, NumAcc0Bias> p_acc0_biases,
             const std::array<void*, NumAcc1Bias> p_acc1_biases,
             const std::vector<index_t>& a_gs_ms_ks_lengths,
...
@@ -818,16 +821,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
        }

        // pointers
-        const DataType* p_a_grid_;
-        const DataType* p_b_grid_;
+        const InputDataType* p_a_grid_;
+        const InputDataType* p_b_grid_;
         ZDataType* p_z_grid_;
-        const DataType* p_b1_grid_;
-        const DataType* p_c_grid_;
+        const InputDataType* p_b1_grid_;
+        const InputDataType* p_c_grid_;
         const LSEDataType* p_lse_grid_;
-        const DataType* p_ygrad_grid_;
-        DataType* p_qgrad_grid_;
-        DataType* p_kgrad_grid_;
-        DataType* p_vgrad_grid_;
+        const InputDataType* p_ygrad_grid_;
+        OutputDataType* p_qgrad_grid_;
+        OutputDataType* p_kgrad_grid_;
+        OutputDataType* p_vgrad_grid_;

        // tensor descriptor
        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
...
@@ -903,7 +906,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
            auto launch_kernel = [&](auto has_main_k_block_loop_) {
                const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_v2<
                    GridwiseGemm,
-                    DataType,
+                    InputDataType,
+                    OutputDataType,
                     ZDataType,
                     LSEDataType,
                     AElementwiseOperation,
...
@@ -1067,16 +1071,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
    }

    static auto MakeArgument(
-        const DataType* p_a,
-        const DataType* p_b,
+        const InputDataType* p_a,
+        const InputDataType* p_b,
         ZDataType* p_z,
-        const DataType* p_b1,
-        const DataType* p_c,
+        const InputDataType* p_b1,
+        const InputDataType* p_c,
         const LSEDataType* p_lse,
-        const DataType* p_ygrad_grid,
-        DataType* p_qgrad_grid,
-        DataType* p_kgrad_grid,
-        DataType* p_vgrad_grid,
+        const InputDataType* p_ygrad_grid,
+        OutputDataType* p_qgrad_grid,
+        OutputDataType* p_kgrad_grid,
+        OutputDataType* p_vgrad_grid,
         const std::array<void*, NumAcc0Bias> p_acc0_biases,
         const std::array<void*, NumAcc1Bias> p_acc1_biases,
         const std::vector<index_t>& a_gs_ms_ks_lengths,
...
@@ -1182,16 +1186,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
        float p_drop,
        std::tuple<unsigned long long, unsigned long long> seeds) // override
    {
-        return std::make_unique<Argument>(static_cast<const DataType*>(p_a),
-                                          static_cast<const DataType*>(p_b),
+        return std::make_unique<Argument>(static_cast<const InputDataType*>(p_a),
+                                          static_cast<const InputDataType*>(p_b),
                                           static_cast<ZDataType*>(p_z),
-                                          static_cast<const DataType*>(p_b1),
-                                          static_cast<const DataType*>(p_c),
+                                          static_cast<const InputDataType*>(p_b1),
+                                          static_cast<const InputDataType*>(p_c),
                                           static_cast<const LSEDataType*>(p_lse),
-                                          static_cast<const DataType*>(p_ygrad_grid),
-                                          static_cast<DataType*>(p_qgrad_grid),
-                                          static_cast<DataType*>(p_kgrad_grid),
-                                          static_cast<DataType*>(p_vgrad_grid),
+                                          static_cast<const InputDataType*>(p_ygrad_grid),
+                                          static_cast<OutputDataType*>(p_qgrad_grid),
+                                          static_cast<OutputDataType*>(p_kgrad_grid),
+                                          static_cast<OutputDataType*>(p_vgrad_grid),
                                           p_acc0_biases, // cast in struct Argument
                                           p_acc1_biases, // cast in struct Argument
                                           a_gs_ms_ks_lengths,
...
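The v2 device op follows the same scheme; inside its invoker the launch_kernel lambda instantiates the kernel template with both data types and selects the compile-time branch through an integral_constant argument. A small host-only sketch of that dispatch pattern, with fake_kernel standing in for the real GPU kernel (names simplified, not the CK API):

#include <cstdio>
#include <type_traits>

// Host-only sketch of the launch_kernel dispatch above: the lambda receives a
// std::integral_constant, so the runtime branch becomes a compile-time template argument,
// and the kernel template now also carries both the input and the output data type.
template <typename InputDataType, typename OutputDataType, bool HasMainKBlockLoop>
void fake_kernel()
{
    std::printf("input %zu bytes, output %zu bytes, main k loop: %d\n",
                sizeof(InputDataType),
                sizeof(OutputDataType),
                static_cast<int>(HasMainKBlockLoop));
}

int main()
{
    const bool has_main_k_block_loop = true; // would be derived from the problem size in CK

    auto launch_kernel = [&](auto has_main_k_block_loop_) {
        fake_kernel<float, double, decltype(has_main_k_block_loop_)::value>();
    };

    if(has_main_k_block_loop)
        launch_kernel(std::integral_constant<bool, true>{});
    else
        launch_kernel(std::integral_constant<bool, false>{});
    return 0;
}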
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_backward_xdl_cshuffle_v1.hpp

@@ -150,7 +150,8 @@ template <index_t NumDimG,
           index_t NumDimN,
           index_t NumDimK,
           index_t NumDimO, // NumDimGemm1N
-          typename DataType,
+          typename InputDataType,
+          typename OutputDataType,
           typename GemmDataType,
           typename ZDataType,
           typename LSEDataType,
...
@@ -534,7 +535,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
    // GridwiseGemm
    using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
-        DataType, // TODO: distinguish A/B datatype
+        InputDataType, // TODO: distinguish A/B datatype
+        OutputDataType,
         GemmDataType,
         GemmAccDataType,
         CShuffleDataType,
...
@@ -604,16 +606,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
    struct GroupKernelArg
    {
        // pointers
-        const DataType* p_a_grid_;
-        const DataType* p_b_grid_;
+        const InputDataType* p_a_grid_;
+        const InputDataType* p_b_grid_;
         ZDataType* p_z_grid_;
-        const DataType* p_b1_grid_;
-        const DataType* p_c_grid_;
+        const InputDataType* p_b1_grid_;
+        const InputDataType* p_c_grid_;
         const LSEDataType* p_lse_grid_;
-        const DataType* p_ygrad_grid_;
-        DataType* p_qgrad_grid_;
-        DataType* p_kgrad_grid_;
-        DataType* p_vgrad_grid_;
+        const InputDataType* p_ygrad_grid_;
+        OutputDataType* p_qgrad_grid_;
+        OutputDataType* p_kgrad_grid_;
+        OutputDataType* p_vgrad_grid_;

        // tensor descriptors for block/thread-wise copy
        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
...
@@ -712,16 +714,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
            grid_size_ = 0;

            for(index_t i = 0; i < group_count_; i++)
            {
-                const auto p_a_grid     = static_cast<const DataType*>(p_As[i]);
-                const auto p_b_grid     = static_cast<const DataType*>(p_Bs[i]);
+                const auto p_a_grid     = static_cast<const InputDataType*>(p_As[i]);
+                const auto p_b_grid     = static_cast<const InputDataType*>(p_Bs[i]);
                 auto p_z_grid           = static_cast<ZDataType*>(p_Zs[i]);
-                const auto p_b1_grid    = static_cast<const DataType*>(p_B1s[i]);
-                const auto p_c_grid     = static_cast<const DataType*>(p_Cs[i]);
+                const auto p_b1_grid    = static_cast<const InputDataType*>(p_B1s[i]);
+                const auto p_c_grid     = static_cast<const InputDataType*>(p_Cs[i]);
                 const auto p_lse_grid   = static_cast<const LSEDataType*>(p_LSEs[i]);
-                const auto p_ygrad_grid = static_cast<const DataType*>(p_Ygrads[i]);
-                auto p_qgrad_grid       = static_cast<DataType*>(p_Qgrads[i]);
-                auto p_kgrad_grid       = static_cast<DataType*>(p_Kgrads[i]);
-                auto p_vgrad_grid       = static_cast<DataType*>(p_Vgrads[i]);
+                const auto p_ygrad_grid = static_cast<const InputDataType*>(p_Ygrads[i]);
+                auto p_qgrad_grid       = static_cast<OutputDataType*>(p_Qgrads[i]);
+                auto p_kgrad_grid       = static_cast<OutputDataType*>(p_Kgrads[i]);
+                auto p_vgrad_grid       = static_cast<OutputDataType*>(p_Vgrads[i]);

                 const auto& problem_desc = problem_desc_vec[i];
...
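For the grouped variant, every problem in the group brings its own type-erased buffers, and the per-group kernel arguments re-type them in the loop shown above. A compilable sketch of that per-group re-typing, with deliberately simplified names (only one input and one gradient pointer per group, unlike the real GroupKernelArg):

#include <cstddef>
#include <vector>

// Sketch of the grouped-problem pointer handling above (illustrative names, not the CK API):
// each group supplies its own type-erased buffers, and the per-group kernel argument
// re-types them as const InputDataType* (inputs) and OutputDataType* (gradients).
template <typename InputDataType, typename OutputDataType>
struct GroupArgSketch
{
    const InputDataType* p_q_grid_;
    OutputDataType* p_qgrad_grid_;
};

template <typename InputDataType, typename OutputDataType>
std::vector<GroupArgSketch<InputDataType, OutputDataType>>
make_group_args(const std::vector<const void*>& p_Qs, const std::vector<void*>& p_Qgrads)
{
    std::vector<GroupArgSketch<InputDataType, OutputDataType>> args;
    for(std::size_t i = 0; i < p_Qs.size(); ++i)
    {
        args.push_back({static_cast<const InputDataType*>(p_Qs[i]),
                        static_cast<OutputDataType*>(p_Qgrads[i])});
    }
    return args;
}

int main()
{
    float q0[8]   = {};
    double dq0[8] = {};
    // One entry per problem in the group; the output precision is fixed per instantiation.
    const auto args = make_group_args<float, double>({q0}, {dq0});
    return args.size() == 1 ? 0 : 1;
}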
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_backward_xdl_cshuffle_v2.hpp

@@ -150,7 +150,8 @@ template <index_t NumDimG,
           index_t NumDimN,
           index_t NumDimK,
           index_t NumDimO, // NumDimGemm1N
-          typename DataType,
+          typename InputDataType,
+          typename OutputDataType,
           typename GemmDataType,
           typename ZDataType,
           typename LSEDataType,
...
@@ -527,7 +528,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
    // GridwiseGemm
    using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
-        DataType, // TODO: distinguish A/B datatype
+        InputDataType, // TODO: distinguish A/B datatype
+        OutputDataType,
         GemmDataType,
         GemmAccDataType,
         CShuffleDataType,
...
@@ -597,16 +599,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
    struct GroupKernelArg
    {
        // pointers
-        const DataType* p_a_grid_;
-        const DataType* p_b_grid_;
+        const InputDataType* p_a_grid_;
+        const InputDataType* p_b_grid_;
         ZDataType* p_z_grid_;
-        const DataType* p_b1_grid_;
-        const DataType* p_c_grid_;
+        const InputDataType* p_b1_grid_;
+        const InputDataType* p_c_grid_;
         const LSEDataType* p_lse_grid_;
-        const DataType* p_ygrad_grid_;
-        DataType* p_qgrad_grid_;
-        DataType* p_kgrad_grid_;
-        DataType* p_vgrad_grid_;
+        const InputDataType* p_ygrad_grid_;
+        OutputDataType* p_qgrad_grid_;
+        OutputDataType* p_kgrad_grid_;
+        OutputDataType* p_vgrad_grid_;

        // tensor descriptors for block/thread-wise copy
        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
...
@@ -705,16 +707,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
            grid_size_ = 0;

            for(index_t i = 0; i < group_count_; i++)
            {
-                const auto p_a_grid     = static_cast<const DataType*>(p_As[i]);
-                const auto p_b_grid     = static_cast<const DataType*>(p_Bs[i]);
+                const auto p_a_grid     = static_cast<const InputDataType*>(p_As[i]);
+                const auto p_b_grid     = static_cast<const InputDataType*>(p_Bs[i]);
                 auto p_z_grid           = static_cast<ZDataType*>(p_Zs[i]);
-                const auto p_b1_grid    = static_cast<const DataType*>(p_B1s[i]);
-                const auto p_c_grid     = static_cast<const DataType*>(p_Cs[i]);
+                const auto p_b1_grid    = static_cast<const InputDataType*>(p_B1s[i]);
+                const auto p_c_grid     = static_cast<const InputDataType*>(p_Cs[i]);
                 const auto p_lse_grid   = static_cast<const LSEDataType*>(p_LSEs[i]);
-                const auto p_ygrad_grid = static_cast<const DataType*>(p_Ygrads[i]);
-                auto p_qgrad_grid       = static_cast<DataType*>(p_Qgrads[i]);
-                auto p_kgrad_grid       = static_cast<DataType*>(p_Kgrads[i]);
-                auto p_vgrad_grid       = static_cast<DataType*>(p_Vgrads[i]);
+                const auto p_ygrad_grid = static_cast<const InputDataType*>(p_Ygrads[i]);
+                auto p_qgrad_grid       = static_cast<OutputDataType*>(p_Qgrads[i]);
+                auto p_kgrad_grid       = static_cast<OutputDataType*>(p_Kgrads[i]);
+                auto p_vgrad_grid       = static_cast<OutputDataType*>(p_Vgrads[i]);

                 const auto& problem_desc = problem_desc_vec[i];
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp

@@ -20,7 +20,8 @@
 namespace ck {

-template <typename DataType,
+template <typename InputDataType,
+          typename OutputDataType,
           typename GemmDataType,
           typename FloatGemmAcc,
           typename FloatCShuffle,
...
@@ -381,7 +382,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
            Sequence<AK0, MPerBlock, AK1>,
            ABlockTransferThreadClusterLengths_AK0_M_AK1,
            ABlockTransferThreadClusterArrangeOrder,
-            DataType,
+            InputDataType,
             GemmDataType,
             GridDesc_K0_M_K1,
             decltype(q_block_desc_k0_m_k1),
...
@@ -406,7 +407,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
            Sequence<BK0, NPerBlock, BK1>,
            BBlockTransferThreadClusterLengths_BK0_N_BK1,
            BBlockTransferThreadClusterArrangeOrder,
-            DataType,
+            InputDataType,
             GemmDataType,
             GridDesc_K0_N_K1,
             decltype(k_block_desc_k0_n_k1),
...
@@ -431,7 +432,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
            Sequence<BK0, NPerBlock, BK1>,
            BBlockTransferThreadClusterLengths_BK0_N_BK1,
            BBlockTransferThreadClusterArrangeOrder,
-            DataType,
+            InputDataType,
             GemmDataType,
             GridDesc_K0_N_K1,
             decltype(v_block_desc_k0_n_k1),
...
@@ -456,7 +457,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
            Sequence<AK0, MPerBlock, AK1>,
            ABlockTransferThreadClusterLengths_AK0_M_AK1,
            ABlockTransferThreadClusterArrangeOrder,
-            DataType,
+            InputDataType,
             GemmDataType,
             GridDesc_K0_M_K1,
             decltype(ygrad_block_desc_k0_m_k1),
...
@@ -1043,7 +1044,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
              typename ElementwiseOp = tensor_operation::element_wise::PassThrough>
    using CBlockwiseCopy = ThreadwiseTensorSliceTransfer_v1r3<
        FloatGemmAcc,
-        DataType,
+        OutputDataType,
         decltype(c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4),
         CGridDesc_N0_O0_N1_O1_N2_O2_O3_O4,
         ElementwiseOp, // CElementwiseOperation
...
@@ -1059,7 +1060,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
    template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_>
    struct YDotYGrad_M_O_
    {
-        static constexpr index_t SrcScalarPerVector = 16 / sizeof(DataType);
+        static constexpr index_t SrcScalarPerVector = 16 / sizeof(InputDataType);
         static constexpr auto ThreadClusterLength_O =
             Number<BlockSliceLength_O_ / SrcScalarPerVector>{};
         static constexpr auto ThreadClusterLength_M = Number<BlockSize_ / ThreadClusterLength_O>{};
...
@@ -1234,16 +1235,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
              typename C0MatrixMask,
              typename VGradGridDescriptor_N_O,
              typename YGradGridDesc_O0_M_O1>
-    __device__ static void Run(const DataType* __restrict__ p_q_grid,
-                               const DataType* __restrict__ p_k_grid,
+    __device__ static void Run(const InputDataType* __restrict__ p_q_grid,
+                               const InputDataType* __restrict__ p_k_grid,
                                unsigned short* __restrict__ p_z_grid,
-                               const DataType* __restrict__ p_v_grid,
-                               const DataType* __restrict__ p_y_grid,
+                               const InputDataType* __restrict__ p_v_grid,
+                               const InputDataType* __restrict__ p_y_grid,
                                const FloatLSE* __restrict__ p_lse_grid,
-                               const DataType* __restrict__ p_ygrad_grid,
-                               DataType* __restrict__ p_qgrad_grid,
-                               DataType* __restrict__ p_kgrad_grid,
-                               DataType* __restrict__ p_vgrad_grid,
+                               const InputDataType* __restrict__ p_ygrad_grid,
+                               OutputDataType* __restrict__ p_qgrad_grid,
+                               OutputDataType* __restrict__ p_kgrad_grid,
+                               OutputDataType* __restrict__ p_vgrad_grid,
                                void* __restrict__ p_shared,
                                const AElementwiseOperation& a_element_op,
                                const BElementwiseOperation& b_element_op,
...
@@ -1723,7 +1724,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
        // performs for y
        auto y_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
-            DataType,
+            InputDataType,
             FloatGemmAcc,
             YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
             decltype(y_thread_desc_m0_m1_o0_o1),
...
@@ -2307,7 +2308,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
            CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
            Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
            FloatCShuffle,        // typename SrcData,
-            DataType,             // typename DstData,
+            OutputDataType,       // typename DstData,
             decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
             decltype(qgrad_grid_desc_mblock_mperblock_kblock_kperblock),
             Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
...
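One behavioral detail worth noting in the gridwise kernels: YDotYGrad_M_O_ derives its load vector width from the element size, and that size is now taken from InputDataType, so the 16-byte vector loads hold 8 elements for a 2-byte input type and 4 for a 4-byte one. A sketch of that arithmetic, using std::uint16_t purely as a 2-byte size stand-in for a 16-bit float type:

#include <cstdint>
#include <cstdio>

// Sketch of the SrcScalarPerVector arithmetic in YDotYGrad_M_O_ above: the y/dy load vector
// width is derived from InputDataType (16 bytes per vector load), so a 2-byte input type
// yields 8 elements per load and a 4-byte input type yields 4.
template <typename InputDataType>
constexpr int src_scalar_per_vector = 16 / static_cast<int>(sizeof(InputDataType));

static_assert(src_scalar_per_vector<std::uint16_t> == 8, "16-bit input: 8 elements per load");
static_assert(src_scalar_per_vector<float> == 4, "32-bit input: 4 elements per load");

int main()
{
    std::printf("%d %d\n", src_scalar_per_vector<std::uint16_t>, src_scalar_per_vector<float>);
    return 0;
}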
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt2.hpp

@@ -20,7 +20,8 @@
 namespace ck {

-template <typename DataType,
+template <typename InputDataType,
+          typename OutputDataType,
           typename GemmDataType,
           typename FloatGemmAcc,
           typename FloatCShuffle,
...
@@ -457,7 +458,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
            Sequence<AK0, MPerBlock, AK1>,
            ABlockTransferThreadClusterLengths_AK0_M_AK1,
            ABlockTransferThreadClusterArrangeOrder,
-            DataType,
+            InputDataType,
             GemmDataType,
             GridDesc_K0_M_K1,
             decltype(a_block_desc_ak0_m_ak1),
...
@@ -482,7 +483,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
            Sequence<BK0, NPerBlock, BK1>,
            BBlockTransferThreadClusterLengths_BK0_N_BK1,
            BBlockTransferThreadClusterArrangeOrder,
-            DataType,
+            InputDataType,
             GemmDataType,
             GridDesc_K0_N_K1,
             decltype(b_block_desc_bk0_n_bk1),
...
@@ -585,7 +586,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
            Sequence<B1K0, Gemm1NPerBlock, B1K1>,
            B1BlockTransferThreadClusterLengths_BK0_N_BK1,
            B1BlockTransferThreadClusterArrangeOrder,
-            DataType,
+            InputDataType,
             GemmDataType,
             GridDesc_K0_N_K1,
             decltype(b_block_desc_bk0_n_bk1),
...
@@ -823,7 +824,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
            typename Gemm2Params_N_O_M::BBlockSliceLengths,
            typename Gemm2Params_N_O_M::BThreadClusterLengths,
            typename Gemm2Params_N_O_M::BThreadClusterArrangeOrder,
-            DataType,
+            InputDataType,
             GemmDataType,
             GridDesc_M0_O_M1,
             decltype(b_block_desc_m0_o_m1),
...
@@ -892,7 +893,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
              typename ElementwiseOp = tensor_operation::element_wise::PassThrough>
    using CBlockwiseCopy = ThreadwiseTensorSliceTransfer_v1r3<
        FloatGemmAcc,
-        DataType,
+        OutputDataType,
         decltype(c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4),
         CGridDesc_N0_O0_N1_O1_N2_O2_O3_O4,
         ElementwiseOp, // CElementwiseOperation
...
@@ -908,7 +909,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
    template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_>
    struct YDotYGrad_M_O_
    {
-        static constexpr index_t SrcScalarPerVector = 16 / sizeof(DataType);
+        static constexpr index_t SrcScalarPerVector = 16 / sizeof(InputDataType);
         static constexpr auto ThreadClusterLength_O =
             Number<BlockSliceLength_O_ / SrcScalarPerVector>{};
         static constexpr auto ThreadClusterLength_M = Number<BlockSize_ / ThreadClusterLength_O>{};
...
@@ -1144,16 +1145,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
              typename C0MatrixMask,
              typename VGradGridDescriptor_N_O,
              typename YGradGridDesc_M0_O_M1>
-    __device__ static void Run(const DataType* __restrict__ p_q_grid,
-                               const DataType* __restrict__ p_k_grid,
+    __device__ static void Run(const InputDataType* __restrict__ p_q_grid,
+                               const InputDataType* __restrict__ p_k_grid,
                                unsigned short* __restrict__ p_z_grid,
-                               const DataType* __restrict__ p_v_grid,
-                               const DataType* __restrict__ p_y_grid,
+                               const InputDataType* __restrict__ p_v_grid,
+                               const InputDataType* __restrict__ p_y_grid,
                                const FloatLSE* __restrict__ p_lse_grid,
-                               const DataType* __restrict__ p_ygrad_grid,
-                               DataType* __restrict__ p_qgrad_grid,
-                               DataType* __restrict__ p_kgrad_grid,
-                               DataType* __restrict__ p_vgrad_grid,
+                               const InputDataType* __restrict__ p_ygrad_grid,
+                               OutputDataType* __restrict__ p_qgrad_grid,
+                               OutputDataType* __restrict__ p_kgrad_grid,
+                               OutputDataType* __restrict__ p_vgrad_grid,
                                void* __restrict__ p_shared,
                                const AElementwiseOperation& a_element_op,
                                const BElementwiseOperation& b_element_op,
...
@@ -1646,7 +1647,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
        // performs double duty for both y and ygrad
        auto yygrad_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
-            DataType,
+            InputDataType,
             FloatGemmAcc,
             YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
             decltype(y_thread_desc_m0_m1_o0_o1),
...
@@ -2257,7 +2258,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
            CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
            Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
            FloatCShuffle,        // typename SrcData,
-            DataType,             // typename DstData,
+            OutputDataType,       // typename DstData,
             decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
             decltype(qgrad_grid_desc_mblock_mperblock_kblock_kperblock),
             Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
...
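At the other end of the pipeline, the final CShuffle store keeps FloatCShuffle as SrcData and now uses OutputDataType as DstData, so each accumulated gradient element is converted to the selected output type as it is written to the dQ/dK/dV grids. A minimal sketch of that element-wise conversion; store_cshuffle_tile and the buffer names are illustrative only, not the CK API:

#include <cstddef>
#include <vector>

// Illustrative sketch of the final CShuffle store above: the shuffled accumulator stays in
// fp32 (FloatCShuffle as SrcData) and each element is converted to the selected
// OutputDataType (DstData) as it is written out to the gradient grid.
template <typename OutputDataType>
void store_cshuffle_tile(const std::vector<float>& c_shuffle_buffer,
                         std::vector<OutputDataType>& grad_grid)
{
    grad_grid.resize(c_shuffle_buffer.size());
    for(std::size_t i = 0; i < c_shuffle_buffer.size(); ++i)
    {
        // element-wise SrcData (fp32) -> DstData (OutputDataType) conversion on store
        grad_grid[i] = static_cast<OutputDataType>(c_shuffle_buffer[i]);
    }
}

int main()
{
    const std::vector<float> acc{1.0f, 2.0f, 3.0f};
    std::vector<double> dq; // the grid's element type is whatever OutputDataType was selected
    store_cshuffle_tile(acc, dq);
    return dq.size() == 3 ? 0 : 1;
}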