Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
1a306e0d
Commit
1a306e0d
authored
May 18, 2023
by
danyao12
Browse files
add OutputDataType&Deterministic for pt1q2
parent
35b59ac6
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
367 additions
and
274 deletions
+367
-274
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v4.cpp
..._softmax_gemm/batched_multihead_attention_backward_v4.cpp
+160
-131
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v4.hpp
..._batched_multihead_attention_backward_xdl_cshuffle_v4.hpp
+163
-110
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt4.hpp
...batched_multihead_attention_backward_xdl_cshuffle_pt4.hpp
+44
-33
No files found.
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v4.cpp
View file @
1a306e0d
...
@@ -51,10 +51,11 @@ Kernel outputs:
...
@@ -51,10 +51,11 @@ Kernel outputs:
template
<
ck
::
index_t
...
Is
>
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
F16
=
ck
::
half_t
;
using
F16
=
ck
::
half_t
;
using
BF16
=
ck
::
bhalf_t
;
using
BF16
=
ck
::
bhalf_t
;
using
F32
=
float
;
using
F32
=
float
;
using
U16
=
unsigned
short
;
using
U16
=
unsigned
short
;
using
INT32
=
int32_t
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Scale
=
ck
::
tensor_operation
::
element_wise
::
Scale
;
using
Scale
=
ck
::
tensor_operation
::
element_wise
::
Scale
;
...
@@ -62,12 +63,13 @@ using Scale = ck::tensor_operation::element_wise::Scale;
...
@@ -62,12 +63,13 @@ using Scale = ck::tensor_operation::element_wise::Scale;
using
QKVElementOp
=
PassThrough
;
using
QKVElementOp
=
PassThrough
;
using
YElementOp
=
PassThrough
;
using
YElementOp
=
PassThrough
;
using
DataType
=
F16
;
using
InputDataType
=
F16
;
using
OutputDataType
=
F16
;
using
GemmDataType
=
F16
;
using
GemmDataType
=
F16
;
using
AccDataType
=
F32
;
using
AccDataType
=
F32
;
using
ShuffleDataType
=
F32
;
using
ShuffleDataType
=
F32
;
using
LSEDataType
=
F32
;
using
LSEDataType
=
F32
;
using
ZDataType
=
U16
;
using
ZDataType
=
U16
;
// INT32
using
Acc0BiasDataType
=
ck
::
Tuple
<>
;
using
Acc0BiasDataType
=
ck
::
Tuple
<>
;
using
Acc1BiasDataType
=
ck
::
Tuple
<>
;
using
Acc1BiasDataType
=
ck
::
Tuple
<>
;
...
@@ -76,6 +78,9 @@ static constexpr ck::index_t NumDimM = 1;
...
@@ -76,6 +78,9 @@ static constexpr ck::index_t NumDimM = 1;
static
constexpr
ck
::
index_t
NumDimN
=
1
;
static
constexpr
ck
::
index_t
NumDimN
=
1
;
static
constexpr
ck
::
index_t
NumDimK
=
1
;
static
constexpr
ck
::
index_t
NumDimK
=
1
;
static
constexpr
ck
::
index_t
NumDimO
=
1
;
static
constexpr
ck
::
index_t
NumDimO
=
1
;
// When OutputDataType == F32, CShuffleBlockTransferScalarPerVector_NPerBlock = 4
// When OutputDataType == F16/BF16, CShuffleBlockTransferScalarPerVector_NPerBlock = 8
static
constexpr
ck
::
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
=
8
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
#if USING_MASK
#if USING_MASK
...
@@ -86,10 +91,11 @@ static constexpr auto MaskingSpec =
...
@@ -86,10 +91,11 @@ static constexpr auto MaskingSpec =
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskDisabled
;
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskDisabled
;
#endif
#endif
static
constexpr
auto
TensorSpecQ
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecQ
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecK
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecK
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecV
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecV
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecY
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecY
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
bool
Deterministic
=
true
;
// DIM should be a multiple of 8.
// DIM should be a multiple of 8.
// If DIM <= 32 , ues prototype1 1st template.
// If DIM <= 32 , ues prototype1 1st template.
...
@@ -103,7 +109,8 @@ using DeviceGemmInstance =
...
@@ -103,7 +109,8 @@ using DeviceGemmInstance =
NumDimN
,
NumDimN
,
NumDimK
,
NumDimK
,
NumDimO
,
NumDimO
,
DataType
,
InputDataType
,
OutputDataType
,
GemmDataType
,
GemmDataType
,
ZDataType
,
ZDataType
,
LSEDataType
,
LSEDataType
,
...
@@ -154,8 +161,9 @@ using DeviceGemmInstance =
...
@@ -154,8 +161,9 @@ using DeviceGemmInstance =
1
,
// CShuffleMXdlPerWavePerShuffle
1
,
// CShuffleMXdlPerWavePerShuffle
1
,
// CShuffleNXdlPerWavePerShuffle
1
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
64
,
1
,
4
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
S
<
1
,
64
,
1
,
4
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
CShuffleBlockTransferScalarPerVector_NPerBlock
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec
>
;
// MaskingSpecialization
MaskingSpec
,
// MaskingSpecialization
Deterministic
>
;
#elif(DIM <= 64)
#elif(DIM <= 64)
using
DeviceGemmInstance
=
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
<
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
<
...
@@ -164,7 +172,8 @@ using DeviceGemmInstance =
...
@@ -164,7 +172,8 @@ using DeviceGemmInstance =
NumDimN
,
NumDimN
,
NumDimK
,
NumDimK
,
NumDimO
,
NumDimO
,
DataType
,
InputDataType
,
OutputDataType
,
GemmDataType
,
GemmDataType
,
ZDataType
,
ZDataType
,
LSEDataType
,
LSEDataType
,
...
@@ -215,8 +224,9 @@ using DeviceGemmInstance =
...
@@ -215,8 +224,9 @@ using DeviceGemmInstance =
1
,
// CShuffleMXdlPerWavePerShuffle
1
,
// CShuffleMXdlPerWavePerShuffle
2
,
// CShuffleNXdlPerWavePerShuffle
2
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
CShuffleBlockTransferScalarPerVector_NPerBlock
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec
>
;
// MaskingSpecialization
MaskingSpec
,
// MaskingSpecialization
Deterministic
>
;
// using DeviceGemmInstance =
// using DeviceGemmInstance =
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
...
@@ -225,7 +235,8 @@ using DeviceGemmInstance =
...
@@ -225,7 +235,8 @@ using DeviceGemmInstance =
// NumDimN,
// NumDimN,
// NumDimK,
// NumDimK,
// NumDimO,
// NumDimO,
// DataType,
// InputDataType,
// OutputDataType,
// GemmDataType,
// GemmDataType,
// ZDataType,
// ZDataType,
// LSEDataType,
// LSEDataType,
...
@@ -283,8 +294,9 @@ using DeviceGemmInstance =
...
@@ -283,8 +294,9 @@ using DeviceGemmInstance =
// 1, // CShuffleMXdlPerWavePerShuffle
// 1, // CShuffleMXdlPerWavePerShuffle
// 2, // CShuffleNXdlPerWavePerShuffle
// 2, // CShuffleNXdlPerWavePerShuffle
// S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
// S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
// 8, // CShuffleBlockTransferScalarPerVector_NPerBlock
// CShuffleBlockTransferScalarPerVector_NPerBlock,
// MaskingSpec>; // MaskingSpecialization
// MaskingSpec,
// Deterministic>;
#elif(DIM <= 128)
#elif(DIM <= 128)
using
DeviceGemmInstance
=
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
<
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
<
...
@@ -293,7 +305,8 @@ using DeviceGemmInstance =
...
@@ -293,7 +305,8 @@ using DeviceGemmInstance =
NumDimN
,
NumDimN
,
NumDimK
,
NumDimK
,
NumDimO
,
NumDimO
,
DataType
,
InputDataType
,
OutputDataType
,
GemmDataType
,
GemmDataType
,
ZDataType
,
ZDataType
,
LSEDataType
,
LSEDataType
,
...
@@ -351,14 +364,15 @@ using DeviceGemmInstance =
...
@@ -351,14 +364,15 @@ using DeviceGemmInstance =
1
,
// CShuffleMXdlPerWavePerShuffle
1
,
// CShuffleMXdlPerWavePerShuffle
4
,
// CShuffleNXdlPerWavePerShuffle
4
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
CShuffleBlockTransferScalarPerVector_NPerBlock
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec
>
;
// MaskingSpecialization
MaskingSpec
,
// MaskingSpecialization
Deterministic
>
;
#endif
#endif
// Ref Gemm0: S = alpha * Q * K^T
// Ref Gemm0: S = alpha * Q * K^T
// fp16 in, fp32 out
// fp16 in, fp32 out
using
ReferenceGemm0Instance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
DataType
,
using
ReferenceGemm0Instance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
Input
DataType
,
DataType
,
Input
DataType
,
AccDataType
,
AccDataType
,
AccDataType
,
AccDataType
,
PassThrough
,
PassThrough
,
...
@@ -368,13 +382,13 @@ using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
...
@@ -368,13 +382,13 @@ using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
// Ref Softmax: P = Softmax(S)
// Ref Softmax: P = Softmax(S)
// fp32 in, fp16 out
// fp32 in, fp16 out
using
ReferenceSoftmaxInstance
=
using
ReferenceSoftmaxInstance
=
ck
::
tensor_operation
::
host
::
ReferenceSoftmax
<
AccDataType
,
DataType
,
AccDataType
>
;
ck
::
tensor_operation
::
host
::
ReferenceSoftmax
<
AccDataType
,
Input
DataType
,
AccDataType
>
;
// Ref Gemm1: Y = P * V
// Ref Gemm1: Y = P * V
// fp16 in, fp16 out
// fp16 in, fp16 out
using
ReferenceGemm1Instance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
DataType
,
using
ReferenceGemm1Instance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
Input
DataType
,
DataType
,
Input
DataType
,
DataType
,
Input
DataType
,
AccDataType
,
AccDataType
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
...
@@ -382,16 +396,25 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
...
@@ -382,16 +396,25 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
// Ref Gemm for backward pass
// Ref Gemm for backward pass
// fp16 in, fp16 out
// fp16 in, fp16 out
using
ReferenceGemmGradInstance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
DataType
,
using
ReferenceGemm0GradInstance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
InputDataType
,
DataType
,
InputDataType
,
DataType
,
InputDataType
,
AccDataType
,
AccDataType
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
Scale
>
;
Scale
>
;
using
ReferenceGemm1GradInstance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
InputDataType
,
InputDataType
,
OutputDataType
,
AccDataType
,
PassThrough
,
PassThrough
,
Scale
>
;
// Ref dropout
// Ref dropout
using
ReferenceDropoutInstance
=
using
ReferenceDropoutInstance
=
ck
::
tensor_operation
::
host
::
ReferenceDropout
<
ushort
,
DataType
,
DataType
>
;
ck
::
tensor_operation
::
host
::
ReferenceDropout
<
ZDataType
,
Input
DataType
,
Input
DataType
>
;
template
<
typename
TensorQ
,
template
<
typename
TensorQ
,
typename
TensorK
,
typename
TensorK
,
...
@@ -411,7 +434,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
...
@@ -411,7 +434,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
TensorLSE
&
lse_g_m
,
TensorLSE
&
lse_g_m
,
TensorP
&
p_drop_g_m_n
,
TensorP
&
p_drop_g_m_n
,
TensorZ
&
z_g_m_n
,
TensorZ
&
z_g_m_n
,
ushort
p_dropout_in_16bits
,
ZDataType
p_dropout_in_16bits
,
float
rp_dropout
)
float
rp_dropout
)
{
{
// S = alpha * Q * K^T
// S = alpha * Q * K^T
...
@@ -468,8 +491,8 @@ int run(int argc, char* argv[])
...
@@ -468,8 +491,8 @@ int run(int argc, char* argv[])
ck
::
index_t
N
=
512
;
ck
::
index_t
N
=
512
;
ck
::
index_t
K
=
DIM
;
ck
::
index_t
K
=
DIM
;
ck
::
index_t
O
=
DIM
;
ck
::
index_t
O
=
DIM
;
ck
::
index_t
G0
=
5
4
;
ck
::
index_t
G0
=
4
;
ck
::
index_t
G1
=
1
6
;
ck
::
index_t
G1
=
6
;
bool
input_permute
=
false
;
bool
input_permute
=
false
;
bool
output_permute
=
false
;
bool
output_permute
=
false
;
...
@@ -517,10 +540,10 @@ int run(int argc, char* argv[])
...
@@ -517,10 +540,10 @@ int run(int argc, char* argv[])
exit
(
0
);
exit
(
0
);
}
}
float
p_dropout
=
1
-
p_drop
;
float
p_dropout
=
1
-
p_drop
;
uint16_t
p_dropout_in_16bits
=
uint16_t
(
std
::
floor
(
p_dropout
*
65535.0
));
ZDataType
p_dropout_in_16bits
=
ZDataType
(
std
::
floor
(
p_dropout
*
65535.0
));
float
rp_dropout
=
1.0
/
p_dropout
;
float
rp_dropout
=
1.0
/
p_dropout
;
float
alpha
=
1.
f
/
std
::
sqrt
(
K
);
float
alpha
=
1.
f
/
std
::
sqrt
(
K
);
std
::
cout
<<
"do_verification: "
<<
do_verification
<<
std
::
endl
;
std
::
cout
<<
"do_verification: "
<<
do_verification
<<
std
::
endl
;
std
::
cout
<<
"init_method: "
<<
init_method
<<
std
::
endl
;
std
::
cout
<<
"init_method: "
<<
init_method
<<
std
::
endl
;
...
@@ -578,12 +601,12 @@ int run(int argc, char* argv[])
...
@@ -578,12 +601,12 @@ int run(int argc, char* argv[])
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_lengths
{
G0
,
G1
,
M
};
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_lengths
{
G0
,
G1
,
M
};
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_strides
{
G1
*
M
,
M
,
1
};
// LSE layout [G0, G1, M]
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_strides
{
G1
*
M
,
M
,
1
};
// LSE layout [G0, G1, M]
Tensor
<
DataType
>
q_gs_ms_ks
(
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
);
Tensor
<
Input
DataType
>
q_gs_ms_ks
(
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
);
Tensor
<
DataType
>
k_gs_ns_ks
(
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
);
Tensor
<
Input
DataType
>
k_gs_ns_ks
(
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
);
Tensor
<
ZDataType
>
z_gs_ms_ns
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
);
Tensor
<
ZDataType
>
z_gs_ms_ns
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
);
Tensor
<
DataType
>
v_gs_os_ns
(
v_gs_os_ns_lengths
,
v_gs_os_ns_strides
);
Tensor
<
Input
DataType
>
v_gs_os_ns
(
v_gs_os_ns_lengths
,
v_gs_os_ns_strides
);
Tensor
<
DataType
>
y_gs_ms_os
(
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
);
Tensor
<
Input
DataType
>
y_gs_ms_os
(
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
);
Tensor
<
DataType
>
ygrad_gs_ms_os
(
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
);
Tensor
<
Input
DataType
>
ygrad_gs_ms_os
(
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
);
Tensor
<
LSEDataType
>
lse_gs_ms
(
lse_gs_ms_lengths
,
lse_gs_ms_strides
);
Tensor
<
LSEDataType
>
lse_gs_ms
(
lse_gs_ms_lengths
,
lse_gs_ms_strides
);
std
::
cout
<<
"q_gs_ms_ks: "
<<
q_gs_ms_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"q_gs_ms_ks: "
<<
q_gs_ms_ks
.
mDesc
<<
std
::
endl
;
...
@@ -593,45 +616,45 @@ int run(int argc, char* argv[])
...
@@ -593,45 +616,45 @@ int run(int argc, char* argv[])
std
::
cout
<<
"y_gs_ms_os: "
<<
y_gs_ms_os
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"y_gs_ms_os: "
<<
y_gs_ms_os
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"lse_gs_ms_os: "
<<
lse_gs_ms
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"lse_gs_ms_os: "
<<
lse_gs_ms
.
mDesc
<<
std
::
endl
;
z_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
DataType
>
{
0
});
z_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
Input
DataType
>
{
0
});
switch
(
init_method
)
switch
(
init_method
)
{
{
case
0
:
break
;
case
0
:
break
;
case
1
:
case
1
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
DataType
>
{
-
2
,
2
});
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
Input
DataType
>
{
-
2
,
2
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
DataType
>
{
-
2
,
2
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
Input
DataType
>
{
-
2
,
2
});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_2
<
DataType
>
{
-
2
,
2
});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_2
<
Input
DataType
>
{
-
2
,
2
});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_2
<
DataType
>
{
-
2
,
2
});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_2
<
Input
DataType
>
{
-
2
,
2
});
break
;
break
;
case
2
:
case
2
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_3
<
DataType
>
{
0.0
,
1.0
});
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_3
<
Input
DataType
>
{
0.0
,
1.0
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_3
<
DataType
>
{
0.0
,
1.0
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_3
<
Input
DataType
>
{
0.0
,
1.0
});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_3
<
DataType
>
{
-
0.5
,
0.5
});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_3
<
Input
DataType
>
{
-
0.5
,
0.5
});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_3
<
DataType
>
{
-
0.5
,
0.5
});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_3
<
Input
DataType
>
{
-
0.5
,
0.5
});
break
;
break
;
case
3
:
case
3
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
DataType
>
{
-
5
,
5
});
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
Input
DataType
>
{
-
5
,
5
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
DataType
>
{});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
Input
DataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
DataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
Input
DataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
DataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
Input
DataType
>
{});
break
;
break
;
case
4
:
case
4
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
DataType
>
{
1
});
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
Input
DataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
DataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
Input
DataType
>
{
1
});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
DataType
>
{
1
});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
Input
DataType
>
{
1
});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_1
<
DataType
>
{
1
});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_1
<
Input
DataType
>
{
1
});
break
;
break
;
case
5
:
case
5
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
DataType
>
{
1
});
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
Input
DataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
DataType
>
{});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
Input
DataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
DataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
Input
DataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
2
>
{});
// dy[g0, g1, m, o]
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
2
>
{});
// dy[g0, g1, m, o]
// dO dot O = [0; 1; 2; ...]
// dO dot O = [0; 1; 2; ...]
break
;
break
;
case
6
:
case
6
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
DataType
>
{
1
});
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
Input
DataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
DataType
>
{});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
Input
DataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
DataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
Input
DataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
3
>
{});
// dy[g0, g1, m, o]
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
3
>
{});
// dy[g0, g1, m, o]
// assume mnko = 256
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// P = softmax(QK) = 0.0039 * ones
...
@@ -642,10 +665,10 @@ int run(int argc, char* argv[])
...
@@ -642,10 +665,10 @@ int run(int argc, char* argv[])
//
//
break
;
break
;
default:
default:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
DataType
>
{
1
});
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
Input
DataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
DataType
>
{});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
Input
DataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
DataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
Input
DataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_1
<
DataType
>
{
1
});
// dy[g0, g1, m, o]
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_1
<
Input
DataType
>
{
1
});
// dy[g0, g1, m, o]
// assume mnko = 256
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
// O = P V = 0.0039 * ones
...
@@ -656,14 +679,14 @@ int run(int argc, char* argv[])
...
@@ -656,14 +679,14 @@ int run(int argc, char* argv[])
// = 0
// = 0
}
}
Tensor
<
DataType
>
q_g_m_k
({
BatchCount
,
M
,
K
});
Tensor
<
Input
DataType
>
q_g_m_k
({
BatchCount
,
M
,
K
});
Tensor
<
DataType
>
k_g_n_k
({
BatchCount
,
N
,
K
});
Tensor
<
Input
DataType
>
k_g_n_k
({
BatchCount
,
N
,
K
});
Tensor
<
ZDataType
>
z_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
ZDataType
>
z_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
DataType
>
v_g_n_o
({
BatchCount
,
N
,
O
});
Tensor
<
Input
DataType
>
v_g_n_o
({
BatchCount
,
N
,
O
});
Tensor
<
AccDataType
>
s_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
AccDataType
>
s_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
DataType
>
p_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
Input
DataType
>
p_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
DataType
>
p_drop_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
Input
DataType
>
p_drop_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
DataType
>
y_g_m_o
({
BatchCount
,
M
,
O
});
Tensor
<
Input
DataType
>
y_g_m_o
({
BatchCount
,
M
,
O
});
Tensor
<
LSEDataType
>
lse_g_m
({
BatchCount
,
M
});
Tensor
<
LSEDataType
>
lse_g_m
({
BatchCount
,
M
});
q_gs_ms_ks
.
ForEach
(
q_gs_ms_ks
.
ForEach
(
...
@@ -674,16 +697,16 @@ int run(int argc, char* argv[])
...
@@ -674,16 +697,16 @@ int run(int argc, char* argv[])
[
&
](
auto
&
self
,
auto
idx
)
{
v_g_n_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
3
],
idx
[
2
])
=
self
(
idx
);
});
[
&
](
auto
&
self
,
auto
idx
)
{
v_g_n_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
3
],
idx
[
2
])
=
self
(
idx
);
});
// qkv gradients have the same descriptor as with qkv
// qkv gradients have the same descriptor as with qkv
DeviceMem
q_device_buf
(
sizeof
(
DataType
)
*
q_gs_ms_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
q_device_buf
(
sizeof
(
Input
DataType
)
*
q_gs_ms_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
k_device_buf
(
sizeof
(
DataType
)
*
k_gs_ns_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
k_device_buf
(
sizeof
(
Input
DataType
)
*
k_gs_ns_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
z_device_buf
(
sizeof
(
ZDataType
)
*
z_gs_ms_ns
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
z_device_buf
(
sizeof
(
ZDataType
)
*
z_gs_ms_ns
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
v_device_buf
(
sizeof
(
DataType
)
*
v_gs_os_ns
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
v_device_buf
(
sizeof
(
Input
DataType
)
*
v_gs_os_ns
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
y_device_buf
(
sizeof
(
DataType
)
*
y_gs_ms_os
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
y_device_buf
(
sizeof
(
Input
DataType
)
*
y_gs_ms_os
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
lse_device_buf
(
sizeof
(
LSEDataType
)
*
lse_gs_ms
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
lse_device_buf
(
sizeof
(
LSEDataType
)
*
lse_gs_ms
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
qgrad_device_buf
(
sizeof
(
DataType
)
*
q_gs_ms_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
qgrad_device_buf
(
sizeof
(
Output
DataType
)
*
q_gs_ms_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
kgrad_device_buf
(
sizeof
(
DataType
)
*
k_gs_ns_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
kgrad_device_buf
(
sizeof
(
Output
DataType
)
*
k_gs_ns_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
vgrad_device_buf
(
sizeof
(
DataType
)
*
v_gs_os_ns
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
vgrad_device_buf
(
sizeof
(
Output
DataType
)
*
v_gs_os_ns
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
ygrad_device_buf
(
sizeof
(
DataType
)
*
y_gs_ms_os
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
ygrad_device_buf
(
sizeof
(
Input
DataType
)
*
y_gs_ms_os
.
mDesc
.
GetElementSpaceSize
());
q_device_buf
.
ToDevice
(
q_gs_ms_ks
.
mData
.
data
());
q_device_buf
.
ToDevice
(
q_gs_ms_ks
.
mData
.
data
());
k_device_buf
.
ToDevice
(
k_gs_ns_ks
.
mData
.
data
());
k_device_buf
.
ToDevice
(
k_gs_ns_ks
.
mData
.
data
());
...
@@ -696,16 +719,16 @@ int run(int argc, char* argv[])
...
@@ -696,16 +719,16 @@ int run(int argc, char* argv[])
// get z matrix
// get z matrix
{
{
auto
argument
=
gemm
.
MakeArgument
(
auto
argument
=
gemm
.
MakeArgument
(
static_cast
<
DataType
*>
(
q_device_buf
.
GetDeviceBuffer
()),
static_cast
<
Input
DataType
*>
(
q_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DataType
*>
(
k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
Input
DataType
*>
(
k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ZDataType
*>
(
z_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ZDataType
*>
(
z_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DataType
*>
(
v_device_buf
.
GetDeviceBuffer
()),
static_cast
<
Input
DataType
*>
(
v_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DataType
*>
(
y_device_buf
.
GetDeviceBuffer
()),
static_cast
<
Input
DataType
*>
(
y_device_buf
.
GetDeviceBuffer
()),
static_cast
<
LSEDataType
*>
(
lse_device_buf
.
GetDeviceBuffer
()),
static_cast
<
LSEDataType
*>
(
lse_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DataType
*>
(
ygrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
Input
DataType
*>
(
ygrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DataType
*>
(
qgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
Output
DataType
*>
(
qgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DataType
*>
(
kgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
Output
DataType
*>
(
kgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DataType
*>
(
vgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
Output
DataType
*>
(
vgrad_device_buf
.
GetDeviceBuffer
()),
{},
// std::array<void*, 1> p_acc0_biases;
{},
// std::array<void*, 1> p_acc0_biases;
{},
// std::array<void*, 1> p_acc1_biases;
{},
// std::array<void*, 1> p_acc1_biases;
q_gs_ms_ks_lengths
,
q_gs_ms_ks_lengths
,
...
@@ -741,16 +764,16 @@ int run(int argc, char* argv[])
...
@@ -741,16 +764,16 @@ int run(int argc, char* argv[])
}
}
// not need output z matrix
// not need output z matrix
auto
argument
=
gemm
.
MakeArgument
(
auto
argument
=
gemm
.
MakeArgument
(
static_cast
<
DataType
*>
(
q_device_buf
.
GetDeviceBuffer
()),
static_cast
<
Input
DataType
*>
(
q_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DataType
*>
(
k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
Input
DataType
*>
(
k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ZDataType
*>
(
nullptr
),
// set to nullptr
static_cast
<
ZDataType
*>
(
nullptr
),
// set to nullptr
static_cast
<
DataType
*>
(
v_device_buf
.
GetDeviceBuffer
()),
static_cast
<
Input
DataType
*>
(
v_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DataType
*>
(
y_device_buf
.
GetDeviceBuffer
()),
static_cast
<
Input
DataType
*>
(
y_device_buf
.
GetDeviceBuffer
()),
static_cast
<
LSEDataType
*>
(
lse_device_buf
.
GetDeviceBuffer
()),
static_cast
<
LSEDataType
*>
(
lse_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DataType
*>
(
ygrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
Input
DataType
*>
(
ygrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DataType
*>
(
qgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
Output
DataType
*>
(
qgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DataType
*>
(
kgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
Output
DataType
*>
(
kgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DataType
*>
(
vgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
Output
DataType
*>
(
vgrad_device_buf
.
GetDeviceBuffer
()),
{},
// std::array<void*, 1> p_acc0_biases;
{},
// std::array<void*, 1> p_acc0_biases;
{},
// std::array<void*, 1> p_acc1_biases;
{},
// std::array<void*, 1> p_acc1_biases;
q_gs_ms_ks_lengths
,
q_gs_ms_ks_lengths
,
...
@@ -785,10 +808,13 @@ int run(int argc, char* argv[])
...
@@ -785,10 +808,13 @@ int run(int argc, char* argv[])
// 3x MNK + 2x MNO
// 3x MNK + 2x MNO
std
::
size_t
flop
=
(
size_t
(
3
)
*
M
*
N
*
K
+
size_t
(
2
)
*
M
*
N
*
O
)
*
2
*
BatchCount
;
std
::
size_t
flop
=
(
size_t
(
3
)
*
M
*
N
*
K
+
size_t
(
2
)
*
M
*
N
*
O
)
*
2
*
BatchCount
;
// Q/K/V/Y, dQ/dK/dV/dY, LSE
// Q/K/V/Y, dQ/dK/dV/dY, LSE
std
::
size_t
num_btype
=
(
sizeof
(
DataType
)
*
M
*
K
+
sizeof
(
DataType
)
*
K
*
N
+
std
::
size_t
num_btype
=
sizeof
(
DataType
)
*
N
*
O
+
sizeof
(
DataType
)
*
M
*
O
)
*
(
sizeof
(
InputDataType
)
*
M
*
K
+
sizeof
(
InputDataType
)
*
K
*
N
+
size_t
(
2
)
*
BatchCount
+
sizeof
(
InputDataType
)
*
N
*
O
+
sizeof
(
InputDataType
)
*
M
*
O
*
size_t
(
2
)
+
sizeof
(
LSEDataType
)
*
M
*
BatchCount
;
sizeof
(
OutputDataType
)
*
M
*
K
+
sizeof
(
OutputDataType
)
*
K
*
N
+
sizeof
(
OutputDataType
)
*
N
*
O
)
*
BatchCount
+
sizeof
(
LSEDataType
)
*
M
*
BatchCount
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
...
@@ -831,14 +857,14 @@ int run(int argc, char* argv[])
...
@@ -831,14 +857,14 @@ int run(int argc, char* argv[])
qgrad_device_buf
.
SetZero
();
qgrad_device_buf
.
SetZero
();
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
});
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
});
Tensor
<
DataType
>
qgrad_g_m_k
({
BatchCount
,
M
,
K
});
Tensor
<
Output
DataType
>
qgrad_g_m_k
({
BatchCount
,
M
,
K
});
Tensor
<
DataType
>
kgrad_g_n_k
({
BatchCount
,
N
,
K
});
Tensor
<
Output
DataType
>
kgrad_g_n_k
({
BatchCount
,
N
,
K
});
Tensor
<
DataType
>
vgrad_g_n_o
({
BatchCount
,
N
,
O
});
Tensor
<
Output
DataType
>
vgrad_g_n_o
({
BatchCount
,
N
,
O
});
Tensor
<
DataType
>
sgrad_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
Input
DataType
>
sgrad_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
DataType
>
pgrad_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
Input
DataType
>
pgrad_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
DataType
>
pgrad_drop_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
Input
DataType
>
pgrad_drop_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
DataType
>
ygrad_g_m_o
({
BatchCount
,
M
,
O
});
Tensor
<
Input
DataType
>
ygrad_g_m_o
({
BatchCount
,
M
,
O
});
Tensor
<
DataType
>
ygrad_dot_y_g_m
({
BatchCount
,
M
});
Tensor
<
Input
DataType
>
ygrad_dot_y_g_m
({
BatchCount
,
M
});
ygrad_gs_ms_os
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
ygrad_gs_ms_os
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
ygrad_g_m_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
ygrad_g_m_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
...
@@ -854,13 +880,16 @@ int run(int argc, char* argv[])
...
@@ -854,13 +880,16 @@ int run(int argc, char* argv[])
#endif
#endif
// Gradients
// Gradients
auto
ref_gemm_grad
=
ReferenceGemmGradInstance
{};
auto
ref_gemm0_grad
=
ReferenceGemm0GradInstance
{};
auto
ref_gemm_grad_invoker
=
ref_gemm_grad
.
MakeInvoker
();
auto
ref_gemm0_grad_invoker
=
ref_gemm0_grad
.
MakeInvoker
();
using
RefGemmGradArg
=
ReferenceGemmGradInstance
::
Argument
;
using
RefGemm0GradArg
=
ReferenceGemm0GradInstance
::
Argument
;
auto
ref_gemm1_grad
=
ReferenceGemm1GradInstance
{};
auto
ref_gemm1_grad_invoker
=
ref_gemm1_grad
.
MakeInvoker
();
using
RefGemm1GradArg
=
ReferenceGemm1GradInstance
::
Argument
;
// dP_dropout = dY * V^T
// dP_dropout = dY * V^T
auto
v_g_o_n
=
v_g_n_o
.
Transpose
({
0
,
2
,
1
});
auto
v_g_o_n
=
v_g_n_o
.
Transpose
({
0
,
2
,
1
});
ref_gemm_grad_invoker
.
Run
(
RefGemmGradArg
{
ref_gemm
0
_grad_invoker
.
Run
(
RefGemm
0
GradArg
{
ygrad_g_m_o
,
v_g_o_n
,
pgrad_drop_g_m_n
,
PassThrough
{},
PassThrough
{},
Scale
{
1.
f
}});
ygrad_g_m_o
,
v_g_o_n
,
pgrad_drop_g_m_n
,
PassThrough
{},
PassThrough
{},
Scale
{
1.
f
}});
#if PRINT_HOST
#if PRINT_HOST
{
{
...
@@ -887,7 +916,7 @@ int run(int argc, char* argv[])
...
@@ -887,7 +916,7 @@ int run(int argc, char* argv[])
ygrad_dot_y
+=
ck
::
type_convert
<
AccDataType
>
(
ygrad_g_m_o
(
idx_gmo
))
*
ygrad_dot_y
+=
ck
::
type_convert
<
AccDataType
>
(
ygrad_g_m_o
(
idx_gmo
))
*
ck
::
type_convert
<
AccDataType
>
(
y_g_m_o
(
idx_gmo
));
ck
::
type_convert
<
AccDataType
>
(
y_g_m_o
(
idx_gmo
));
}
}
self
(
idx_gmn
)
=
ck
::
type_convert
<
DataType
>
(
self
(
idx_gmn
)
=
ck
::
type_convert
<
Input
DataType
>
(
ck
::
type_convert
<
AccDataType
>
(
p_g_m_n
(
idx_gmn
))
*
ck
::
type_convert
<
AccDataType
>
(
p_g_m_n
(
idx_gmn
))
*
(
ck
::
type_convert
<
AccDataType
>
(
pgrad_g_m_n
(
idx_gmn
))
-
ygrad_dot_y
));
(
ck
::
type_convert
<
AccDataType
>
(
pgrad_g_m_n
(
idx_gmn
))
-
ygrad_dot_y
));
});
});
...
@@ -903,7 +932,7 @@ int run(int argc, char* argv[])
...
@@ -903,7 +932,7 @@ int run(int argc, char* argv[])
#endif
#endif
// dV = P_drop^T * dY
// dV = P_drop^T * dY
auto
p_drop_g_n_m
=
p_drop_g_m_n
.
Transpose
({
0
,
2
,
1
});
auto
p_drop_g_n_m
=
p_drop_g_m_n
.
Transpose
({
0
,
2
,
1
});
ref_gemm_grad_invoker
.
Run
(
RefGemmGradArg
{
ref_gemm
1
_grad_invoker
.
Run
(
RefGemm
1
GradArg
{
p_drop_g_n_m
,
ygrad_g_m_o
,
vgrad_g_n_o
,
PassThrough
{},
PassThrough
{},
Scale
{
1.0
f
}});
p_drop_g_n_m
,
ygrad_g_m_o
,
vgrad_g_n_o
,
PassThrough
{},
PassThrough
{},
Scale
{
1.0
f
}});
#if PRINT_HOST
#if PRINT_HOST
{
{
...
@@ -915,7 +944,7 @@ int run(int argc, char* argv[])
...
@@ -915,7 +944,7 @@ int run(int argc, char* argv[])
#endif
#endif
// dQ = alpha * dS * K
// dQ = alpha * dS * K
ref_gemm_grad_invoker
.
Run
(
RefGemmGradArg
{
ref_gemm
1
_grad_invoker
.
Run
(
RefGemm
1
GradArg
{
sgrad_g_m_n
,
k_g_n_k
,
qgrad_g_m_k
,
PassThrough
{},
PassThrough
{},
Scale
{
alpha
}});
sgrad_g_m_n
,
k_g_n_k
,
qgrad_g_m_k
,
PassThrough
{},
PassThrough
{},
Scale
{
alpha
}});
#if PRINT_HOST
#if PRINT_HOST
{
{
...
@@ -928,7 +957,7 @@ int run(int argc, char* argv[])
...
@@ -928,7 +957,7 @@ int run(int argc, char* argv[])
// dK = alpha * dS^T * Q
// dK = alpha * dS^T * Q
auto
sgrad_g_n_m
=
sgrad_g_m_n
.
Transpose
({
0
,
2
,
1
});
auto
sgrad_g_n_m
=
sgrad_g_m_n
.
Transpose
({
0
,
2
,
1
});
ref_gemm_grad_invoker
.
Run
(
RefGemmGradArg
{
ref_gemm
1
_grad_invoker
.
Run
(
RefGemm
1
GradArg
{
sgrad_g_n_m
,
q_g_m_k
,
kgrad_g_n_k
,
PassThrough
{},
PassThrough
{},
Scale
{
alpha
}});
sgrad_g_n_m
,
q_g_m_k
,
kgrad_g_n_k
,
PassThrough
{},
PassThrough
{},
Scale
{
alpha
}});
#if PRINT_HOST
#if PRINT_HOST
{
{
...
@@ -939,13 +968,13 @@ int run(int argc, char* argv[])
...
@@ -939,13 +968,13 @@ int run(int argc, char* argv[])
}
}
#endif
#endif
Tensor
<
DataType
>
qgrad_gs_ms_ks_host_result
(
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
);
Tensor
<
Output
DataType
>
qgrad_gs_ms_ks_host_result
(
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
);
Tensor
<
DataType
>
kgrad_gs_ns_ks_host_result
(
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
);
Tensor
<
Output
DataType
>
kgrad_gs_ns_ks_host_result
(
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
);
Tensor
<
DataType
>
vgrad_gs_os_ns_host_result
(
v_gs_os_ns_lengths
,
v_gs_os_ns_strides
);
Tensor
<
Output
DataType
>
vgrad_gs_os_ns_host_result
(
v_gs_os_ns_lengths
,
v_gs_os_ns_strides
);
Tensor
<
DataType
>
qgrad_gs_ms_ks_device_result
(
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
);
Tensor
<
Output
DataType
>
qgrad_gs_ms_ks_device_result
(
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
);
Tensor
<
DataType
>
kgrad_gs_ns_ks_device_result
(
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
);
Tensor
<
Output
DataType
>
kgrad_gs_ns_ks_device_result
(
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
);
Tensor
<
DataType
>
vgrad_gs_os_ns_device_result
(
v_gs_os_ns_lengths
,
v_gs_os_ns_strides
);
Tensor
<
Output
DataType
>
vgrad_gs_os_ns_device_result
(
v_gs_os_ns_lengths
,
v_gs_os_ns_strides
);
qgrad_device_buf
.
FromDevice
(
qgrad_gs_ms_ks_device_result
.
mData
.
data
());
qgrad_device_buf
.
FromDevice
(
qgrad_gs_ms_ks_device_result
.
mData
.
data
());
kgrad_device_buf
.
FromDevice
(
kgrad_gs_ns_ks_device_result
.
mData
.
data
());
kgrad_device_buf
.
FromDevice
(
kgrad_gs_ns_ks_device_result
.
mData
.
data
());
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v4.hpp
View file @
1a306e0d
...
@@ -28,7 +28,8 @@ namespace tensor_operation {
...
@@ -28,7 +28,8 @@ namespace tensor_operation {
namespace
device
{
namespace
device
{
template
<
typename
GridwiseGemm
,
template
<
typename
GridwiseGemm
,
typename
DataType
,
typename
InputDataType
,
typename
OutputDataType
,
typename
ZDataType
,
typename
ZDataType
,
typename
LSEDataType
,
typename
LSEDataType
,
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
...
@@ -46,22 +47,23 @@ template <typename GridwiseGemm,
...
@@ -46,22 +47,23 @@ template <typename GridwiseGemm,
typename
Block2CTileMap
,
typename
Block2CTileMap
,
typename
ComputeBasePtrOfStridedBatch
,
typename
ComputeBasePtrOfStridedBatch
,
typename
C0MatrixMask
,
typename
C0MatrixMask
,
bool
HasMainKBlockLoop
>
bool
HasMainKBlockLoop
,
bool
Deterministic
>
__global__
void
__global__
void
#if CK_USE_LAUNCH_BOUNDS
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
/*CK_MIN_BLOCK_PER_CU*/
1
)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
/*CK_MIN_BLOCK_PER_CU*/
1
)
#endif
#endif
kernel_batched_multihead_attention_backward_xdl_cshuffle_v1
(
kernel_batched_multihead_attention_backward_xdl_cshuffle_v1
(
const
DataType
*
__restrict__
p_a_grid
,
const
Input
DataType
*
__restrict__
p_a_grid
,
const
DataType
*
__restrict__
p_b_grid
,
const
Input
DataType
*
__restrict__
p_b_grid
,
ZDataType
*
__restrict__
p_z_grid
,
ZDataType
*
__restrict__
p_z_grid
,
const
DataType
*
__restrict__
p_b1_grid
,
const
Input
DataType
*
__restrict__
p_b1_grid
,
const
DataType
*
__restrict__
p_c_grid
,
const
Input
DataType
*
__restrict__
p_c_grid
,
const
LSEDataType
*
__restrict__
p_lse_grid
,
const
LSEDataType
*
__restrict__
p_lse_grid
,
const
DataType
*
__restrict__
p_ygrad_grid
,
const
Input
DataType
*
__restrict__
p_ygrad_grid
,
DataType
*
__restrict__
p_qgrad_grid
,
Output
DataType
*
__restrict__
p_qgrad_grid
,
DataType
*
__restrict__
p_kgrad_grid
,
Output
DataType
*
__restrict__
p_kgrad_grid
,
DataType
*
__restrict__
p_vgrad_grid
,
Output
DataType
*
__restrict__
p_vgrad_grid
,
const
AElementwiseOperation
a_element_op
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
BElementwiseOperation
b_element_op
,
const
AccElementwiseOperation
acc_element_op
,
const
AccElementwiseOperation
acc_element_op
,
...
@@ -78,6 +80,7 @@ __global__ void
...
@@ -78,6 +80,7 @@ __global__ void
const
YGradGridDesc_O0_M_O1
ygrad_grid_desc_o0_m_o1
,
const
YGradGridDesc_O0_M_O1
ygrad_grid_desc_o0_m_o1
,
const
Block2CTileMap
block_2_ctile_map
,
const
Block2CTileMap
block_2_ctile_map
,
const
index_t
batch_count
,
const
index_t
batch_count
,
const
index_t
nblock
,
const
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch
,
const
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch
,
const
C0MatrixMask
c0_matrix_mask
,
const
C0MatrixMask
c0_matrix_mask
,
const
float
p_drop
,
const
float
p_drop
,
...
@@ -109,33 +112,72 @@ __global__ void
...
@@ -109,33 +112,72 @@ __global__ void
ck
::
philox
ph
(
seed
,
global_thread_id
,
offset
);
ck
::
philox
ph
(
seed
,
global_thread_id
,
offset
);
ZDataType
*
z_matrix_ptr
=
(
p_z_grid
==
nullptr
?
nullptr
:
p_z_grid
+
z_batch_offset
);
ZDataType
*
z_matrix_ptr
=
(
p_z_grid
==
nullptr
?
nullptr
:
p_z_grid
+
z_batch_offset
);
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
+
a_batch_offset
,
if
constexpr
(
Deterministic
)
p_b_grid
+
b_batch_offset
,
{
z_matrix_ptr
,
for
(
index_t
i
=
0
;
i
<
nblock
;
i
++
)
p_b1_grid
+
b1_batch_offset
,
{
p_c_grid
+
c_batch_offset
,
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_lse_grid
+
lse_batch_offset
,
p_a_grid
+
a_batch_offset
,
p_ygrad_grid
+
c_batch_offset
,
p_b_grid
+
b_batch_offset
,
p_qgrad_grid
+
a_batch_offset
,
z_matrix_ptr
,
p_kgrad_grid
+
b_batch_offset
,
p_b1_grid
+
b1_batch_offset
,
p_vgrad_grid
+
b1_batch_offset
,
p_c_grid
+
c_batch_offset
,
p_shared
,
p_lse_grid
+
lse_batch_offset
,
a_element_op
,
p_ygrad_grid
+
c_batch_offset
,
b_element_op
,
p_qgrad_grid
+
a_batch_offset
,
acc_element_op
,
p_kgrad_grid
+
b_batch_offset
,
b1_element_op
,
p_vgrad_grid
+
b1_batch_offset
,
c_element_op
,
p_shared
,
a_grid_desc_ak0_m_ak1
,
a_element_op
,
b_grid_desc_bk0_n_bk1
,
b_element_op
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
acc_element_op
,
b1_grid_desc_bk0_n_bk1
,
b1_element_op
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_element_op
,
lse_grid_desc_m
,
a_grid_desc_ak0_m_ak1
,
ygrad_grid_desc_o0_m_o1
,
b_grid_desc_bk0_n_bk1
,
block_2_ctile_map
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
c0_matrix_mask
,
b1_grid_desc_bk0_n_bk1
,
p_drop
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
ph
);
lse_grid_desc_m
,
ygrad_grid_desc_o0_m_o1
,
block_2_ctile_map
,
c0_matrix_mask
,
p_drop
,
ph
,
i
);
}
}
else
{
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
+
a_batch_offset
,
p_b_grid
+
b_batch_offset
,
z_matrix_ptr
,
p_b1_grid
+
b1_batch_offset
,
p_c_grid
+
c_batch_offset
,
p_lse_grid
+
lse_batch_offset
,
p_ygrad_grid
+
c_batch_offset
,
p_qgrad_grid
+
a_batch_offset
,
p_kgrad_grid
+
b_batch_offset
,
p_vgrad_grid
+
b1_batch_offset
,
p_shared
,
a_element_op
,
b_element_op
,
acc_element_op
,
b1_element_op
,
c_element_op
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
b1_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
lse_grid_desc_m
,
ygrad_grid_desc_o0_m_o1
,
block_2_ctile_map
,
c0_matrix_mask
,
p_drop
,
ph
,
0
);
}
#else
#else
ignore
=
p_a_grid
;
ignore
=
p_a_grid
;
ignore
=
p_b_grid
;
ignore
=
p_b_grid
;
...
@@ -168,7 +210,8 @@ template <index_t NumDimG,
...
@@ -168,7 +210,8 @@ template <index_t NumDimG,
index_t
NumDimN
,
index_t
NumDimN
,
index_t
NumDimK
,
index_t
NumDimK
,
index_t
NumDimO
,
// NumDimGemm1N
index_t
NumDimO
,
// NumDimGemm1N
typename
DataType
,
typename
InputDataType
,
typename
OutputDataType
,
typename
GemmDataType
,
typename
GemmDataType
,
typename
ZDataType
,
typename
ZDataType
,
typename
LSEDataType
,
typename
LSEDataType
,
...
@@ -221,6 +264,7 @@ template <index_t NumDimG,
...
@@ -221,6 +264,7 @@ template <index_t NumDimG,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpecialization
MaskingSpec
,
MaskingSpecialization
MaskingSpec
,
bool
Deterministic
,
LoopScheduler
LoopSched
=
LoopScheduler
::
Default
>
LoopScheduler
LoopSched
=
LoopScheduler
::
Default
>
struct
DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
struct
DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
:
public
BaseOperator
// TODO inherit atten bwd op once API stablizes
:
public
BaseOperator
// TODO inherit atten bwd op once API stablizes
...
@@ -587,7 +631,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -587,7 +631,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
// GridwiseGemm
// GridwiseGemm
using
GridwiseGemm
=
GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
<
using
GridwiseGemm
=
GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
<
DataType
,
// TODO: distinguish A/B datatype
InputDataType
,
// TODO: distinguish A/B datatype
OutputDataType
,
ZDataType
,
GemmDataType
,
GemmDataType
,
GemmAccDataType
,
GemmAccDataType
,
CShuffleDataType
,
CShuffleDataType
,
...
@@ -643,22 +689,23 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -643,22 +689,23 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
CShuffleBlockTransferScalarPerVector_NPerBlock
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopSched
,
LoopSched
,
Transform
::
matrix_padder
.
PadN
,
Transform
::
matrix_padder
.
PadN
,
MaskingSpec
==
MaskingSpecialization
::
MaskOutUpperTriangle
>
;
MaskingSpec
==
MaskingSpecialization
::
MaskOutUpperTriangle
,
Deterministic
>
;
// Argument
// Argument
struct
Argument
:
public
BaseArgument
struct
Argument
:
public
BaseArgument
{
{
Argument
(
Argument
(
const
DataType
*
p_a_grid
,
const
Input
DataType
*
p_a_grid
,
const
DataType
*
p_b_grid
,
const
Input
DataType
*
p_b_grid
,
ZDataType
*
p_z_grid
,
ZDataType
*
p_z_grid
,
const
DataType
*
p_b1_grid
,
const
Input
DataType
*
p_b1_grid
,
const
DataType
*
p_c_grid
,
// for dS
const
Input
DataType
*
p_c_grid
,
// for dS
const
LSEDataType
*
p_lse_grid
,
const
LSEDataType
*
p_lse_grid
,
const
DataType
*
p_ygrad_grid
,
const
Input
DataType
*
p_ygrad_grid
,
DataType
*
p_qgrad_grid
,
Output
DataType
*
p_qgrad_grid
,
DataType
*
p_kgrad_grid
,
Output
DataType
*
p_kgrad_grid
,
DataType
*
p_vgrad_grid
,
Output
DataType
*
p_vgrad_grid
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>
p_acc0_biases
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>
p_acc0_biases
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>
p_acc1_biases
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>
p_acc1_biases
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
...
@@ -802,16 +849,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -802,16 +849,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
}
}
// pointers
// pointers
const
DataType
*
p_a_grid_
;
const
Input
DataType
*
p_a_grid_
;
const
DataType
*
p_b_grid_
;
const
Input
DataType
*
p_b_grid_
;
ZDataType
*
p_z_grid_
;
ZDataType
*
p_z_grid_
;
const
DataType
*
p_b1_grid_
;
const
Input
DataType
*
p_b1_grid_
;
const
DataType
*
p_c_grid_
;
const
Input
DataType
*
p_c_grid_
;
const
LSEDataType
*
p_lse_grid_
;
const
LSEDataType
*
p_lse_grid_
;
const
DataType
*
p_ygrad_grid_
;
const
Input
DataType
*
p_ygrad_grid_
;
DataType
*
p_qgrad_grid_
;
Output
DataType
*
p_qgrad_grid_
;
DataType
*
p_kgrad_grid_
;
Output
DataType
*
p_kgrad_grid_
;
DataType
*
p_vgrad_grid_
;
Output
DataType
*
p_vgrad_grid_
;
// tensor descriptor
// tensor descriptor
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
...
@@ -876,14 +923,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -876,14 +923,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
}
}
const
index_t
grid_size
=
const
index_t
grid_size
=
arg
.
block_2_ctile_map_
.
CalculateGridSize
(
arg
.
k_grid_desc_n_k_
)
*
arg
.
batch_count_
;
(
Deterministic
?
1
:
arg
.
block_2_ctile_map_
.
CalculateGridSize
(
arg
.
k_grid_desc_n_k_
))
*
arg
.
batch_count_
;
float
ave_time
=
0
;
float
ave_time
=
0
;
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
)
{
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
)
{
const
auto
kernel
=
kernel_batched_multihead_attention_backward_xdl_cshuffle_v1
<
const
auto
kernel
=
kernel_batched_multihead_attention_backward_xdl_cshuffle_v1
<
GridwiseGemm
,
GridwiseGemm
,
DataType
,
InputDataType
,
OutputDataType
,
ZDataType
,
ZDataType
,
LSEDataType
,
LSEDataType
,
AElementwiseOperation
,
AElementwiseOperation
,
...
@@ -901,42 +951,45 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -901,42 +951,45 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
typename
GridwiseGemm
::
DefaultBlock2CTileMap
,
typename
GridwiseGemm
::
DefaultBlock2CTileMap
,
ComputeBasePtrOfStridedBatch
,
ComputeBasePtrOfStridedBatch
,
C0MatrixMask
,
C0MatrixMask
,
has_main_k_block_loop_
>
;
has_main_k_block_loop_
,
Deterministic
>
;
return
launch_and_time_kernel
(
stream_config
,
kernel
,
return
launch_and_time_kernel
(
dim3
(
grid_size
),
stream_config
,
dim3
(
BlockSize
),
kernel
,
0
,
dim3
(
grid_size
),
arg
.
p_a_grid_
,
dim3
(
BlockSize
),
arg
.
p_b_grid_
,
0
,
arg
.
p_z_grid_
,
arg
.
p_a_grid_
,
arg
.
p_b1_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
p_z_grid_
,
arg
.
p_lse_grid_
,
arg
.
p_b1_grid_
,
arg
.
p_ygrad_grid_
,
arg
.
p_c_grid_
,
arg
.
p_qgrad_grid_
,
arg
.
p_lse_grid_
,
arg
.
p_kgrad_grid_
,
arg
.
p_ygrad_grid_
,
arg
.
p_vgrad_grid_
,
arg
.
p_qgrad_grid_
,
arg
.
a_element_op_
,
arg
.
p_kgrad_grid_
,
arg
.
b_element_op_
,
arg
.
p_vgrad_grid_
,
arg
.
acc_element_op_
,
arg
.
a_element_op_
,
arg
.
b1_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
acc_element_op_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b1_element_op_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
c_element_op_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
y_grid_desc_mblock_mperblock_oblock_operblock_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
,
arg
.
lse_grid_desc_m_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
ygrad_grid_desc_o0_m_o1_
,
arg
.
y_grid_desc_mblock_mperblock_oblock_operblock_
,
arg
.
block_2_ctile_map_
,
arg
.
lse_grid_desc_m_
,
arg
.
batch_count_
,
arg
.
ygrad_grid_desc_o0_m_o1_
,
arg
.
compute_base_ptr_of_batch_
,
arg
.
block_2_ctile_map_
,
arg
.
c0_matrix_mask_
,
arg
.
batch_count_
,
arg
.
p_drop_
,
arg
.
block_2_ctile_map_
.
CalculateGridSize
(
arg
.
k_grid_desc_n_k_
),
arg
.
seed_
,
arg
.
compute_base_ptr_of_batch_
,
arg
.
offset_
);
arg
.
c0_matrix_mask_
,
arg
.
p_drop_
,
arg
.
seed_
,
arg
.
offset_
);
};
};
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
...
@@ -1041,16 +1094,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -1041,16 +1094,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
}
}
static
auto
MakeArgument
(
static
auto
MakeArgument
(
const
DataType
*
p_a
,
const
Input
DataType
*
p_a
,
const
DataType
*
p_b
,
const
Input
DataType
*
p_b
,
ZDataType
*
p_z
,
ZDataType
*
p_z
,
const
DataType
*
p_b1
,
const
Input
DataType
*
p_b1
,
const
DataType
*
p_c
,
const
Input
DataType
*
p_c
,
const
LSEDataType
*
p_lse
,
const
LSEDataType
*
p_lse
,
const
DataType
*
p_ygrad_grid
,
const
Input
DataType
*
p_ygrad_grid
,
DataType
*
p_qgrad_grid
,
Output
DataType
*
p_qgrad_grid
,
DataType
*
p_kgrad_grid
,
Output
DataType
*
p_kgrad_grid
,
DataType
*
p_vgrad_grid
,
Output
DataType
*
p_vgrad_grid
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>
p_acc0_biases
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>
p_acc0_biases
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>
p_acc1_biases
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>
p_acc1_biases
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
...
@@ -1156,16 +1209,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -1156,16 +1209,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
float
p_drop
,
float
p_drop
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
// override
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
// override
{
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
DataType
*>
(
p_a
),
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
Input
DataType
*>
(
p_a
),
static_cast
<
const
DataType
*>
(
p_b
),
static_cast
<
const
Input
DataType
*>
(
p_b
),
static_cast
<
ZDataType
*>
(
p_z
),
static_cast
<
ZDataType
*>
(
p_z
),
static_cast
<
const
DataType
*>
(
p_b1
),
static_cast
<
const
Input
DataType
*>
(
p_b1
),
static_cast
<
const
DataType
*>
(
p_c
),
static_cast
<
const
Input
DataType
*>
(
p_c
),
static_cast
<
const
LSEDataType
*>
(
p_lse
),
static_cast
<
const
LSEDataType
*>
(
p_lse
),
static_cast
<
const
DataType
*>
(
p_ygrad_grid
),
static_cast
<
const
Input
DataType
*>
(
p_ygrad_grid
),
static_cast
<
DataType
*>
(
p_qgrad_grid
),
static_cast
<
Output
DataType
*>
(
p_qgrad_grid
),
static_cast
<
DataType
*>
(
p_kgrad_grid
),
static_cast
<
Output
DataType
*>
(
p_kgrad_grid
),
static_cast
<
DataType
*>
(
p_vgrad_grid
),
static_cast
<
Output
DataType
*>
(
p_vgrad_grid
),
p_acc0_biases
,
// cast in struct Argument
p_acc0_biases
,
// cast in struct Argument
p_acc1_biases
,
// cast in struct Argument
p_acc1_biases
,
// cast in struct Argument
a_gs_ms_ks_lengths
,
a_gs_ms_ks_lengths
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt4.hpp
View file @
1a306e0d
...
@@ -20,7 +20,9 @@
...
@@ -20,7 +20,9 @@
namespace
ck
{
namespace
ck
{
template
<
typename
DataType
,
template
<
typename
InputDataType
,
typename
OutputDataType
,
typename
ZDataType
,
typename
GemmDataType
,
typename
GemmDataType
,
typename
FloatGemmAcc
,
typename
FloatGemmAcc
,
typename
FloatCShuffle
,
typename
FloatCShuffle
,
...
@@ -77,6 +79,7 @@ template <typename DataType,
...
@@ -77,6 +79,7 @@ template <typename DataType,
LoopScheduler
LoopSched
,
LoopScheduler
LoopSched
,
bool
PadN
,
bool
PadN
,
bool
MaskOutUpperTriangle
,
bool
MaskOutUpperTriangle
,
bool
Deterministic
,
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
>
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
>
struct
GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
struct
GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
{
{
...
@@ -424,7 +427,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -424,7 +427,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
Sequence
<
AK0
,
MPerBlock
,
AK1
>
,
Sequence
<
AK0
,
MPerBlock
,
AK1
>
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferThreadClusterArrangeOrder
,
DataType
,
Input
DataType
,
GemmDataType
,
GemmDataType
,
GridDesc_K0_M_K1
,
GridDesc_K0_M_K1
,
decltype
(
q_block_desc_k0_m_k1
),
decltype
(
q_block_desc_k0_m_k1
),
...
@@ -449,7 +452,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -449,7 +452,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
Sequence
<
BK0
,
NPerBlock
,
BK1
>
,
Sequence
<
BK0
,
NPerBlock
,
BK1
>
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferThreadClusterArrangeOrder
,
DataType
,
Input
DataType
,
GemmDataType
,
GemmDataType
,
GridDesc_K0_N_K1
,
GridDesc_K0_N_K1
,
decltype
(
k_block_desc_k0_n_k1
),
decltype
(
k_block_desc_k0_n_k1
),
...
@@ -474,7 +477,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -474,7 +477,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
Sequence
<
BK0
,
NPerBlock
,
BK1
>
,
Sequence
<
BK0
,
NPerBlock
,
BK1
>
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferThreadClusterArrangeOrder
,
DataType
,
Input
DataType
,
GemmDataType
,
GemmDataType
,
GridDesc_K0_N_K1
,
GridDesc_K0_N_K1
,
decltype
(
v_block_desc_k0_n_k1
),
decltype
(
v_block_desc_k0_n_k1
),
...
@@ -499,7 +502,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -499,7 +502,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
Sequence
<
AK0
,
MPerBlock
,
AK1
>
,
Sequence
<
AK0
,
MPerBlock
,
AK1
>
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferThreadClusterArrangeOrder
,
DataType
,
Input
DataType
,
GemmDataType
,
GemmDataType
,
GridDesc_K0_M_K1
,
GridDesc_K0_M_K1
,
decltype
(
ygrad_block_desc_k0_m_k1
),
decltype
(
ygrad_block_desc_k0_m_k1
),
...
@@ -1080,7 +1083,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -1080,7 +1083,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
typename
ElementwiseOp
=
tensor_operation
::
element_wise
::
PassThrough
>
typename
ElementwiseOp
=
tensor_operation
::
element_wise
::
PassThrough
>
using
CBlockwiseCopy
=
ThreadwiseTensorSliceTransfer_v1r3
<
using
CBlockwiseCopy
=
ThreadwiseTensorSliceTransfer_v1r3
<
FloatGemmAcc
,
FloatGemmAcc
,
DataType
,
Output
DataType
,
decltype
(
c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
),
decltype
(
c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
),
CGridDesc_M0_N0_M1_N1_M2_N2_N3_N4
,
CGridDesc_M0_N0_M1_N1_M2_N2_N3_N4
,
ElementwiseOp
,
// CElementwiseOperation
ElementwiseOp
,
// CElementwiseOperation
...
@@ -1096,7 +1099,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -1096,7 +1099,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
template
<
index_t
BlockSize_
,
index_t
BlockSliceLength_M_
,
index_t
BlockSliceLength_O_
>
template
<
index_t
BlockSize_
,
index_t
BlockSliceLength_M_
,
index_t
BlockSliceLength_O_
>
struct
YDotYGrad_M_O_
struct
YDotYGrad_M_O_
{
{
static
constexpr
index_t
SrcScalarPerVector
=
16
/
sizeof
(
DataType
);
static
constexpr
index_t
SrcScalarPerVector
=
16
/
sizeof
(
Input
DataType
);
static
constexpr
auto
ThreadClusterLength_O
=
static
constexpr
auto
ThreadClusterLength_O
=
Number
<
BlockSliceLength_O_
/
SrcScalarPerVector
>
{};
Number
<
BlockSliceLength_O_
/
SrcScalarPerVector
>
{};
static
constexpr
auto
ThreadClusterLength_M
=
Number
<
BlockSize_
/
ThreadClusterLength_O
>
{};
static
constexpr
auto
ThreadClusterLength_M
=
Number
<
BlockSize_
/
ThreadClusterLength_O
>
{};
...
@@ -1213,16 +1216,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -1213,16 +1216,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
typename
Block2CTileMap
,
typename
Block2CTileMap
,
typename
C0MatrixMask
,
typename
C0MatrixMask
,
typename
YGradGridDesc_O0_M_O1
>
typename
YGradGridDesc_O0_M_O1
>
__device__
static
void
Run
(
const
DataType
*
__restrict__
p_q_grid
,
__device__
static
void
Run
(
const
Input
DataType
*
__restrict__
p_q_grid
,
const
DataType
*
__restrict__
p_k_grid
,
const
Input
DataType
*
__restrict__
p_k_grid
,
unsigned
short
*
__restrict__
p_z_grid
,
ZDataType
*
__restrict__
p_z_grid
,
const
DataType
*
__restrict__
p_v_grid
,
const
Input
DataType
*
__restrict__
p_v_grid
,
const
DataType
*
__restrict__
p_y_grid
,
const
Input
DataType
*
__restrict__
p_y_grid
,
const
FloatLSE
*
__restrict__
p_lse_grid
,
const
FloatLSE
*
__restrict__
p_lse_grid
,
const
DataType
*
__restrict__
p_ygrad_grid
,
const
Input
DataType
*
__restrict__
p_ygrad_grid
,
DataType
*
__restrict__
p_qgrad_grid
,
Output
DataType
*
__restrict__
p_qgrad_grid
,
DataType
*
__restrict__
p_kgrad_grid
,
Output
DataType
*
__restrict__
p_kgrad_grid
,
DataType
*
__restrict__
p_vgrad_grid
,
Output
DataType
*
__restrict__
p_vgrad_grid
,
void
*
__restrict__
p_shared
,
void
*
__restrict__
p_shared
,
const
AElementwiseOperation
&
a_element_op
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
BElementwiseOperation
&
b_element_op
,
...
@@ -1241,7 +1244,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -1241,7 +1244,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
const
Block2CTileMap
&
block_2_ctile_map
,
const
Block2CTileMap
&
block_2_ctile_map
,
const
C0MatrixMask
&
c0_matrix_mask
,
const
C0MatrixMask
&
c0_matrix_mask
,
const
float
p_drop
,
const
float
p_drop
,
ck
::
philox
&
ph
)
ck
::
philox
&
ph
,
const
index_t
block_idx_n
)
{
{
const
FloatGemmAcc
p_dropout
=
type_convert
<
FloatGemmAcc
>
(
1.0
f
-
p_drop
);
const
FloatGemmAcc
p_dropout
=
type_convert
<
FloatGemmAcc
>
(
1.0
f
-
p_drop
);
const
FloatGemmAcc
rp_dropout
=
type_convert
<
FloatGemmAcc
>
(
1.0
f
/
p_dropout
);
const
FloatGemmAcc
rp_dropout
=
type_convert
<
FloatGemmAcc
>
(
1.0
f
/
p_dropout
);
...
@@ -1273,9 +1277,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -1273,9 +1277,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
const
auto
block_work_idx
=
const
auto
block_work_idx
=
block_2_ctile_map
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
block_2_ctile_map
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
const
index_t
block_work_idx_n
=
Deterministic
?
block_idx_n
:
block_work_idx
[
I0
];
// HACK: this force n_block_data_idx_on_grid into SGPR
// HACK: this force n_block_data_idx_on_grid into SGPR
const
index_t
n_block_data_idx_on_grid
=
const
index_t
n_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I0
]
*
NPerBlock
);
__builtin_amdgcn_readfirstlane
(
block_work_idx
_n
*
NPerBlock
);
// 6 GEMM operations are categorized into 3 buckets. SizeK == SizeO == head_dim
// 6 GEMM operations are categorized into 3 buckets. SizeK == SizeO == head_dim
// S_MNK / dP_MNO Gemm (Gemm0 rcr)
// S_MNK / dP_MNO Gemm (Gemm0 rcr)
...
@@ -1605,7 +1611,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -1605,7 +1611,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
auto
z_thread_copy_vgpr_to_global
=
ThreadwiseTensorSliceTransfer_v1r3
<
auto
z_thread_copy_vgpr_to_global
=
ThreadwiseTensorSliceTransfer_v1r3
<
ushort
,
ushort
,
ushort
,
ZDataType
,
decltype
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
),
decltype
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
),
decltype
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
),
decltype
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
),
tensor_operation
::
element_wise
::
PassThrough
,
tensor_operation
::
element_wise
::
PassThrough
,
...
@@ -1625,15 +1631,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -1625,15 +1631,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
Set
,
1
,
// DstScalarStrideInVector
1
,
// DstScalarStrideInVector
true
>
{
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
true
>
{
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_multi_index
(
block_work_idx
[
I0
]
,
// MBlockId
make_multi_index
(
block_work_idx
_n
,
// MBlockId
0
,
// NBlockId
0
,
// NBlockId
0
,
// mrepeat
0
,
// mrepeat
0
,
// nrepeat
0
,
// nrepeat
wave_id
[
I0
],
// MWaveId
wave_id
[
I0
],
// MWaveId
wave_id
[
I1
],
// NWaveId
wave_id
[
I1
],
// NWaveId
wave_m_n_id
[
I1
],
// MPerXdl
wave_m_n_id
[
I1
],
// MPerXdl
0
,
// group
0
,
// group
wave_m_n_id
[
I0
],
// NInputIndex
wave_m_n_id
[
I0
],
// NInputIndex
0
),
0
),
tensor_operation
::
element_wise
::
PassThrough
{}};
tensor_operation
::
element_wise
::
PassThrough
{}};
...
@@ -1678,7 +1684,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -1678,7 +1684,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
// performs for y
// performs for y
auto
y_threadwise_copy
=
ThreadwiseTensorSliceTransfer_v2
<
auto
y_threadwise_copy
=
ThreadwiseTensorSliceTransfer_v2
<
DataType
,
Input
DataType
,
FloatGemmAcc
,
FloatGemmAcc
,
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
,
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
,
decltype
(
y_thread_desc_m0_m1_o0_o1
),
decltype
(
y_thread_desc_m0_m1_o0_o1
),
...
@@ -1743,6 +1749,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -1743,6 +1749,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
const
index_t
num_gemm0_m_block_outer_loop
=
q_grid_desc_k0_m_k1
.
GetLength
(
I1
)
/
MPerBlock
;
const
index_t
num_gemm0_m_block_outer_loop
=
q_grid_desc_k0_m_k1
.
GetLength
(
I1
)
/
MPerBlock
;
constexpr
index_t
num_gemm1_k_block_inner_loop
=
MPerBlock
/
Gemm1KPerBlock
;
constexpr
index_t
num_gemm1_k_block_inner_loop
=
MPerBlock
/
Gemm1KPerBlock
;
if
constexpr
(
Deterministic
)
{
block_sync_lds
();
}
// Initialize dK&dV
// Initialize dK&dV
kgrad_thread_buf
.
Clear
();
kgrad_thread_buf
.
Clear
();
vgrad_thread_buf
.
Clear
();
vgrad_thread_buf
.
Clear
();
...
@@ -2286,7 +2297,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -2286,7 +2297,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename ThreadClusterArrangeOrder,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename ThreadClusterArrangeOrder,
FloatCShuffle
,
// typename SrcData,
FloatCShuffle
,
// typename SrcData,
DataType
,
// typename DstData,
Output
DataType
,
// typename DstData,
decltype
(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
),
decltype
(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
),
decltype
(
vgrad_grid_desc_nblock_nperblock_oblock_operblock
),
decltype
(
vgrad_grid_desc_nblock_nperblock_oblock_operblock
),
Sequence
<
0
,
1
,
2
,
3
>
,
// typename DimAccessOrder,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename DimAccessOrder,
...
@@ -2297,7 +2308,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -2297,7 +2308,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
{
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
,
{
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
,
make_multi_index
(
0
,
0
,
0
,
0
),
make_multi_index
(
0
,
0
,
0
,
0
),
vgrad_grid_desc_nblock_nperblock_oblock_operblock
,
vgrad_grid_desc_nblock_nperblock_oblock_operblock
,
make_multi_index
(
block_work_idx
[
I0
]
,
0
,
block_work_idx
[
I1
],
0
),
make_multi_index
(
block_work_idx
_n
,
0
,
block_work_idx
[
I1
],
0
),
c_element_op
};
c_element_op
};
// shuffle: threadwise copy C from VGPR to LDS
// shuffle: threadwise copy C from VGPR to LDS
...
@@ -2344,7 +2355,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -2344,7 +2355,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename ThreadClusterArrangeOrder,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename ThreadClusterArrangeOrder,
FloatCShuffle
,
// typename SrcData,
FloatCShuffle
,
// typename SrcData,
DataType
,
// typename DstData,
Output
DataType
,
// typename DstData,
decltype
(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
),
decltype
(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
),
decltype
(
kgrad_grid_desc_nblock_nperblock_oblock_operblock
),
decltype
(
kgrad_grid_desc_nblock_nperblock_oblock_operblock
),
Sequence
<
0
,
1
,
2
,
3
>
,
// typename DimAccessOrder,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename DimAccessOrder,
...
@@ -2355,7 +2366,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -2355,7 +2366,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
{
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
,
{
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
,
make_multi_index
(
0
,
0
,
0
,
0
),
make_multi_index
(
0
,
0
,
0
,
0
),
kgrad_grid_desc_nblock_nperblock_oblock_operblock
,
kgrad_grid_desc_nblock_nperblock_oblock_operblock
,
make_multi_index
(
block_work_idx
[
I0
]
,
0
,
block_work_idx
[
I1
],
0
),
make_multi_index
(
block_work_idx
_n
,
0
,
block_work_idx
[
I1
],
0
),
c_element_op
};
c_element_op
};
// space filling curve for threadwise C in VGPR
// space filling curve for threadwise C in VGPR
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment