Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
f3e61c0a
"...git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "90e753c97ab4fb8920fe6ddc4a3dc0572a097b47"
Commit
f3e61c0a
authored
Apr 13, 2023
by
danyao12
Browse files
datatype of bwd output can be selected
parent
f7e05f9e
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
441 additions
and
387 deletions
+441
-387
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward.cpp
...ale_softmax_gemm/batched_multihead_attention_backward.cpp
+134
-114
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward.cpp
...ale_softmax_gemm/grouped_multihead_attention_backward.cpp
+135
-115
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp
..._batched_multihead_attention_backward_xdl_cshuffle_v1.hpp
+48
-44
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp
..._batched_multihead_attention_backward_xdl_cshuffle_v2.hpp
+48
-44
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_backward_xdl_cshuffle_v1.hpp
..._grouped_multihead_attention_backward_xdl_cshuffle_v1.hpp
+20
-18
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_backward_xdl_cshuffle_v2.hpp
..._grouped_multihead_attention_backward_xdl_cshuffle_v2.hpp
+20
-18
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp
...batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp
+18
-17
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt2.hpp
...batched_multihead_attention_backward_xdl_cshuffle_pt2.hpp
+18
-17
No files found.
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward.cpp
View file @
f3e61c0a
...
@@ -62,8 +62,9 @@ using Scale = ck::tensor_operation::element_wise::Scale;
...
@@ -62,8 +62,9 @@ using Scale = ck::tensor_operation::element_wise::Scale;
using
QKVElementOp
=
PassThrough
;
using
QKVElementOp
=
PassThrough
;
using
YElementOp
=
PassThrough
;
using
YElementOp
=
PassThrough
;
using
DataType
=
F16
;
using
InputDataType
=
BF16
;
using
GemmDataType
=
F16
;
using
OutputDataType
=
F32
;
using
GemmDataType
=
BF16
;
using
AccDataType
=
F32
;
using
AccDataType
=
F32
;
using
ShuffleDataType
=
F32
;
using
ShuffleDataType
=
F32
;
using
LSEDataType
=
F32
;
using
LSEDataType
=
F32
;
...
@@ -103,7 +104,8 @@ using DeviceGemmInstance =
...
@@ -103,7 +104,8 @@ using DeviceGemmInstance =
NumDimN
,
NumDimN
,
NumDimK
,
NumDimK
,
NumDimO
,
NumDimO
,
DataType
,
InputDataType
,
OutputDataType
,
GemmDataType
,
GemmDataType
,
ZDataType
,
ZDataType
,
LSEDataType
,
LSEDataType
,
...
@@ -161,7 +163,7 @@ using DeviceGemmInstance =
...
@@ -161,7 +163,7 @@ using DeviceGemmInstance =
1
,
// CShuffleMXdlPerWavePerShuffle
1
,
// CShuffleMXdlPerWavePerShuffle
1
,
// CShuffleNXdlPerWavePerShuffle
1
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
64
,
1
,
4
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
S
<
1
,
64
,
1
,
4
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
4
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec
>
;
// MaskingSpecialization
MaskingSpec
>
;
// MaskingSpecialization
#elif(DIM <= 64)
#elif(DIM <= 64)
using
DeviceGemmInstance
=
using
DeviceGemmInstance
=
...
@@ -171,7 +173,8 @@ using DeviceGemmInstance =
...
@@ -171,7 +173,8 @@ using DeviceGemmInstance =
NumDimN
,
NumDimN
,
NumDimK
,
NumDimK
,
NumDimO
,
NumDimO
,
DataType
,
InputDataType
,
OutputDataType
,
GemmDataType
,
GemmDataType
,
ZDataType
,
ZDataType
,
LSEDataType
,
LSEDataType
,
...
@@ -229,7 +232,7 @@ using DeviceGemmInstance =
...
@@ -229,7 +232,7 @@ using DeviceGemmInstance =
1
,
// CShuffleMXdlPerWavePerShuffle
1
,
// CShuffleMXdlPerWavePerShuffle
2
,
// CShuffleNXdlPerWavePerShuffle
2
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
4
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec
>
;
// MaskingSpecialization
MaskingSpec
>
;
// MaskingSpecialization
// using DeviceGemmInstance =
// using DeviceGemmInstance =
...
@@ -239,7 +242,8 @@ using DeviceGemmInstance =
...
@@ -239,7 +242,8 @@ using DeviceGemmInstance =
// NumDimN,
// NumDimN,
// NumDimK,
// NumDimK,
// NumDimO,
// NumDimO,
// DataType,
// InputDataType,
// OutputDataType,
// GemmDataType,
// GemmDataType,
// ZDataType,
// ZDataType,
// LSEDataType,
// LSEDataType,
...
@@ -297,7 +301,7 @@ using DeviceGemmInstance =
...
@@ -297,7 +301,7 @@ using DeviceGemmInstance =
// 1, // CShuffleMXdlPerWavePerShuffle
// 1, // CShuffleMXdlPerWavePerShuffle
// 2, // CShuffleNXdlPerWavePerShuffle
// 2, // CShuffleNXdlPerWavePerShuffle
// S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
// S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
//
8
, // CShuffleBlockTransferScalarPerVector_NPerBlock
//
4
, // CShuffleBlockTransferScalarPerVector_NPerBlock
// MaskingSpec>; // MaskingSpecialization
// MaskingSpec>; // MaskingSpecialization
#elif(DIM <= 128)
#elif(DIM <= 128)
using
DeviceGemmInstance
=
using
DeviceGemmInstance
=
...
@@ -307,7 +311,8 @@ using DeviceGemmInstance =
...
@@ -307,7 +311,8 @@ using DeviceGemmInstance =
NumDimN
,
NumDimN
,
NumDimK
,
NumDimK
,
NumDimO
,
NumDimO
,
DataType
,
InputDataType
,
OutputDataType
,
GemmDataType
,
GemmDataType
,
ZDataType
,
ZDataType
,
LSEDataType
,
LSEDataType
,
...
@@ -365,14 +370,14 @@ using DeviceGemmInstance =
...
@@ -365,14 +370,14 @@ using DeviceGemmInstance =
1
,
// CShuffleMXdlPerWavePerShuffle
1
,
// CShuffleMXdlPerWavePerShuffle
4
,
// CShuffleNXdlPerWavePerShuffle
4
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
4
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec
>
;
// MaskingSpecialization
MaskingSpec
>
;
// MaskingSpecialization
#endif
#endif
// Ref Gemm0: S = alpha * Q * K^T
// Ref Gemm0: S = alpha * Q * K^T
// fp16 in, fp32 out
// fp16 in, fp32 out
using
ReferenceGemm0Instance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
DataType
,
using
ReferenceGemm0Instance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
Input
DataType
,
DataType
,
Input
DataType
,
AccDataType
,
AccDataType
,
AccDataType
,
AccDataType
,
PassThrough
,
PassThrough
,
...
@@ -382,13 +387,13 @@ using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
...
@@ -382,13 +387,13 @@ using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
// Ref Softmax: P = Softmax(S)
// Ref Softmax: P = Softmax(S)
// fp32 in, fp16 out
// fp32 in, fp16 out
using
ReferenceSoftmaxInstance
=
using
ReferenceSoftmaxInstance
=
ck
::
tensor_operation
::
host
::
ReferenceSoftmax
<
AccDataType
,
DataType
,
AccDataType
>
;
ck
::
tensor_operation
::
host
::
ReferenceSoftmax
<
AccDataType
,
Input
DataType
,
AccDataType
>
;
// Ref Gemm1: Y = P * V
// Ref Gemm1: Y = P * V
// fp16 in, fp16 out
// fp16 in, fp16 out
using
ReferenceGemm1Instance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
DataType
,
using
ReferenceGemm1Instance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
Input
DataType
,
DataType
,
Input
DataType
,
DataType
,
Input
DataType
,
AccDataType
,
AccDataType
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
...
@@ -396,16 +401,25 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
...
@@ -396,16 +401,25 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
// Ref Gemm for backward pass
// Ref Gemm for backward pass
// fp16 in, fp16 out
// fp16 in, fp16 out
using
ReferenceGemmGradInstance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
DataType
,
using
ReferenceGemm0GradInstance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
InputDataType
,
DataType
,
InputDataType
,
DataType
,
InputDataType
,
AccDataType
,
AccDataType
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
Scale
>
;
Scale
>
;
using
ReferenceGemm1GradInstance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
InputDataType
,
InputDataType
,
OutputDataType
,
AccDataType
,
PassThrough
,
PassThrough
,
Scale
>
;
// Ref dropout
// Ref dropout
using
ReferenceDropoutInstance
=
using
ReferenceDropoutInstance
=
ck
::
tensor_operation
::
host
::
ReferenceDropout
<
ushort
,
DataType
,
DataType
>
;
ck
::
tensor_operation
::
host
::
ReferenceDropout
<
ushort
,
Input
DataType
,
Input
DataType
>
;
template
<
typename
TensorQ
,
template
<
typename
TensorQ
,
typename
TensorK
,
typename
TensorK
,
...
@@ -482,8 +496,8 @@ int run(int argc, char* argv[])
...
@@ -482,8 +496,8 @@ int run(int argc, char* argv[])
ck
::
index_t
N
=
512
;
ck
::
index_t
N
=
512
;
ck
::
index_t
K
=
DIM
;
ck
::
index_t
K
=
DIM
;
ck
::
index_t
O
=
DIM
;
ck
::
index_t
O
=
DIM
;
ck
::
index_t
G0
=
5
4
;
ck
::
index_t
G0
=
4
;
ck
::
index_t
G1
=
1
6
;
ck
::
index_t
G1
=
6
;
bool
input_permute
=
false
;
bool
input_permute
=
false
;
bool
output_permute
=
false
;
bool
output_permute
=
false
;
...
@@ -592,12 +606,12 @@ int run(int argc, char* argv[])
...
@@ -592,12 +606,12 @@ int run(int argc, char* argv[])
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_lengths
{
G0
,
G1
,
M
};
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_lengths
{
G0
,
G1
,
M
};
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_strides
{
G1
*
M
,
M
,
1
};
// LSE layout [G0, G1, M]
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_strides
{
G1
*
M
,
M
,
1
};
// LSE layout [G0, G1, M]
Tensor
<
DataType
>
q_gs_ms_ks
(
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
);
Tensor
<
Input
DataType
>
q_gs_ms_ks
(
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
);
Tensor
<
DataType
>
k_gs_ns_ks
(
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
);
Tensor
<
Input
DataType
>
k_gs_ns_ks
(
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
);
Tensor
<
ZDataType
>
z_gs_ms_ns
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
);
Tensor
<
ZDataType
>
z_gs_ms_ns
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
);
Tensor
<
DataType
>
v_gs_os_ns
(
v_gs_os_ns_lengths
,
v_gs_os_ns_strides
);
Tensor
<
Input
DataType
>
v_gs_os_ns
(
v_gs_os_ns_lengths
,
v_gs_os_ns_strides
);
Tensor
<
DataType
>
y_gs_ms_os
(
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
);
Tensor
<
Input
DataType
>
y_gs_ms_os
(
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
);
Tensor
<
DataType
>
ygrad_gs_ms_os
(
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
);
Tensor
<
Input
DataType
>
ygrad_gs_ms_os
(
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
);
Tensor
<
LSEDataType
>
lse_gs_ms
(
lse_gs_ms_lengths
,
lse_gs_ms_strides
);
Tensor
<
LSEDataType
>
lse_gs_ms
(
lse_gs_ms_lengths
,
lse_gs_ms_strides
);
std
::
cout
<<
"q_gs_ms_ks: "
<<
q_gs_ms_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"q_gs_ms_ks: "
<<
q_gs_ms_ks
.
mDesc
<<
std
::
endl
;
...
@@ -607,45 +621,45 @@ int run(int argc, char* argv[])
...
@@ -607,45 +621,45 @@ int run(int argc, char* argv[])
std
::
cout
<<
"y_gs_ms_os: "
<<
y_gs_ms_os
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"y_gs_ms_os: "
<<
y_gs_ms_os
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"lse_gs_ms_os: "
<<
lse_gs_ms
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"lse_gs_ms_os: "
<<
lse_gs_ms
.
mDesc
<<
std
::
endl
;
z_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
DataType
>
{
0
});
z_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
Input
DataType
>
{
0
});
switch
(
init_method
)
switch
(
init_method
)
{
{
case
0
:
break
;
case
0
:
break
;
case
1
:
case
1
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
DataType
>
{
-
2
,
2
});
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
Input
DataType
>
{
-
2
,
2
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
DataType
>
{
-
2
,
2
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
Input
DataType
>
{
-
2
,
2
});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_2
<
DataType
>
{
-
2
,
2
});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_2
<
Input
DataType
>
{
-
2
,
2
});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_2
<
DataType
>
{
-
2
,
2
});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_2
<
Input
DataType
>
{
-
2
,
2
});
break
;
break
;
case
2
:
case
2
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_3
<
DataType
>
{
0.0
,
1.0
});
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_3
<
Input
DataType
>
{
0.0
,
1.0
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_3
<
DataType
>
{
0.0
,
1.0
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_3
<
Input
DataType
>
{
0.0
,
1.0
});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_3
<
DataType
>
{
-
0.5
,
0.5
});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_3
<
Input
DataType
>
{
-
0.5
,
0.5
});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_3
<
DataType
>
{
-
0.5
,
0.5
});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_3
<
Input
DataType
>
{
-
0.5
,
0.5
});
break
;
break
;
case
3
:
case
3
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
DataType
>
{
-
5
,
5
});
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
Input
DataType
>
{
-
5
,
5
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
DataType
>
{});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
Input
DataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
DataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
Input
DataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
DataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
Input
DataType
>
{});
break
;
break
;
case
4
:
case
4
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
DataType
>
{
1
});
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
Input
DataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
DataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
Input
DataType
>
{
1
});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
DataType
>
{
1
});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
Input
DataType
>
{
1
});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_1
<
DataType
>
{
1
});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_1
<
Input
DataType
>
{
1
});
break
;
break
;
case
5
:
case
5
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
DataType
>
{
1
});
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
Input
DataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
DataType
>
{});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
Input
DataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
DataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
Input
DataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
2
>
{});
// dy[g0, g1, m, o]
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
2
>
{});
// dy[g0, g1, m, o]
// dO dot O = [0; 1; 2; ...]
// dO dot O = [0; 1; 2; ...]
break
;
break
;
case
6
:
case
6
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
DataType
>
{
1
});
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
Input
DataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
DataType
>
{});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
Input
DataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
DataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
Input
DataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
3
>
{});
// dy[g0, g1, m, o]
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
3
>
{});
// dy[g0, g1, m, o]
// assume mnko = 256
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// P = softmax(QK) = 0.0039 * ones
...
@@ -656,10 +670,10 @@ int run(int argc, char* argv[])
...
@@ -656,10 +670,10 @@ int run(int argc, char* argv[])
//
//
break
;
break
;
default:
default:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
DataType
>
{
1
});
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
Input
DataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
DataType
>
{});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
Input
DataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
DataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
Input
DataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_1
<
DataType
>
{
1
});
// dy[g0, g1, m, o]
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_1
<
Input
DataType
>
{
1
});
// dy[g0, g1, m, o]
// assume mnko = 256
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
// O = P V = 0.0039 * ones
...
@@ -670,14 +684,14 @@ int run(int argc, char* argv[])
...
@@ -670,14 +684,14 @@ int run(int argc, char* argv[])
// = 0
// = 0
}
}
Tensor
<
DataType
>
q_g_m_k
({
BatchCount
,
M
,
K
});
Tensor
<
Input
DataType
>
q_g_m_k
({
BatchCount
,
M
,
K
});
Tensor
<
DataType
>
k_g_n_k
({
BatchCount
,
N
,
K
});
Tensor
<
Input
DataType
>
k_g_n_k
({
BatchCount
,
N
,
K
});
Tensor
<
ZDataType
>
z_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
ZDataType
>
z_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
DataType
>
v_g_n_o
({
BatchCount
,
N
,
O
});
Tensor
<
Input
DataType
>
v_g_n_o
({
BatchCount
,
N
,
O
});
Tensor
<
AccDataType
>
s_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
AccDataType
>
s_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
DataType
>
p_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
Input
DataType
>
p_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
DataType
>
p_drop_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
Input
DataType
>
p_drop_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
DataType
>
y_g_m_o
({
BatchCount
,
M
,
O
});
Tensor
<
Input
DataType
>
y_g_m_o
({
BatchCount
,
M
,
O
});
Tensor
<
LSEDataType
>
lse_g_m
({
BatchCount
,
M
});
Tensor
<
LSEDataType
>
lse_g_m
({
BatchCount
,
M
});
q_gs_ms_ks
.
ForEach
(
q_gs_ms_ks
.
ForEach
(
...
@@ -688,16 +702,16 @@ int run(int argc, char* argv[])
...
@@ -688,16 +702,16 @@ int run(int argc, char* argv[])
[
&
](
auto
&
self
,
auto
idx
)
{
v_g_n_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
3
],
idx
[
2
])
=
self
(
idx
);
});
[
&
](
auto
&
self
,
auto
idx
)
{
v_g_n_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
3
],
idx
[
2
])
=
self
(
idx
);
});
// qkv gradients have the same descriptor as with qkv
// qkv gradients have the same descriptor as with qkv
DeviceMem
q_device_buf
(
sizeof
(
DataType
)
*
q_gs_ms_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
q_device_buf
(
sizeof
(
Input
DataType
)
*
q_gs_ms_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
k_device_buf
(
sizeof
(
DataType
)
*
k_gs_ns_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
k_device_buf
(
sizeof
(
Input
DataType
)
*
k_gs_ns_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
z_device_buf
(
sizeof
(
ZDataType
)
*
z_gs_ms_ns
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
z_device_buf
(
sizeof
(
ZDataType
)
*
z_gs_ms_ns
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
v_device_buf
(
sizeof
(
DataType
)
*
v_gs_os_ns
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
v_device_buf
(
sizeof
(
Input
DataType
)
*
v_gs_os_ns
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
y_device_buf
(
sizeof
(
DataType
)
*
y_gs_ms_os
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
y_device_buf
(
sizeof
(
Input
DataType
)
*
y_gs_ms_os
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
lse_device_buf
(
sizeof
(
LSEDataType
)
*
lse_gs_ms
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
lse_device_buf
(
sizeof
(
LSEDataType
)
*
lse_gs_ms
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
qgrad_device_buf
(
sizeof
(
DataType
)
*
q_gs_ms_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
qgrad_device_buf
(
sizeof
(
Output
DataType
)
*
q_gs_ms_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
kgrad_device_buf
(
sizeof
(
DataType
)
*
k_gs_ns_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
kgrad_device_buf
(
sizeof
(
Output
DataType
)
*
k_gs_ns_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
vgrad_device_buf
(
sizeof
(
DataType
)
*
v_gs_os_ns
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
vgrad_device_buf
(
sizeof
(
Output
DataType
)
*
v_gs_os_ns
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
ygrad_device_buf
(
sizeof
(
DataType
)
*
y_gs_ms_os
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
ygrad_device_buf
(
sizeof
(
Input
DataType
)
*
y_gs_ms_os
.
mDesc
.
GetElementSpaceSize
());
q_device_buf
.
ToDevice
(
q_gs_ms_ks
.
mData
.
data
());
q_device_buf
.
ToDevice
(
q_gs_ms_ks
.
mData
.
data
());
k_device_buf
.
ToDevice
(
k_gs_ns_ks
.
mData
.
data
());
k_device_buf
.
ToDevice
(
k_gs_ns_ks
.
mData
.
data
());
...
@@ -710,16 +724,16 @@ int run(int argc, char* argv[])
...
@@ -710,16 +724,16 @@ int run(int argc, char* argv[])
// get z matrix
// get z matrix
{
{
auto
argument
=
gemm
.
MakeArgument
(
auto
argument
=
gemm
.
MakeArgument
(
static_cast
<
DataType
*>
(
q_device_buf
.
GetDeviceBuffer
()),
static_cast
<
Input
DataType
*>
(
q_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DataType
*>
(
k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
Input
DataType
*>
(
k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ZDataType
*>
(
z_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ZDataType
*>
(
z_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DataType
*>
(
v_device_buf
.
GetDeviceBuffer
()),
static_cast
<
Input
DataType
*>
(
v_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DataType
*>
(
y_device_buf
.
GetDeviceBuffer
()),
static_cast
<
Input
DataType
*>
(
y_device_buf
.
GetDeviceBuffer
()),
static_cast
<
LSEDataType
*>
(
lse_device_buf
.
GetDeviceBuffer
()),
static_cast
<
LSEDataType
*>
(
lse_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DataType
*>
(
ygrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
Input
DataType
*>
(
ygrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DataType
*>
(
qgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
Output
DataType
*>
(
qgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DataType
*>
(
kgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
Output
DataType
*>
(
kgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DataType
*>
(
vgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
Output
DataType
*>
(
vgrad_device_buf
.
GetDeviceBuffer
()),
{},
// std::array<void*, 1> p_acc0_biases;
{},
// std::array<void*, 1> p_acc0_biases;
{},
// std::array<void*, 1> p_acc1_biases;
{},
// std::array<void*, 1> p_acc1_biases;
q_gs_ms_ks_lengths
,
q_gs_ms_ks_lengths
,
...
@@ -755,16 +769,16 @@ int run(int argc, char* argv[])
...
@@ -755,16 +769,16 @@ int run(int argc, char* argv[])
}
}
// not need output z matrix
// not need output z matrix
auto
argument
=
gemm
.
MakeArgument
(
auto
argument
=
gemm
.
MakeArgument
(
static_cast
<
DataType
*>
(
q_device_buf
.
GetDeviceBuffer
()),
static_cast
<
Input
DataType
*>
(
q_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DataType
*>
(
k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
Input
DataType
*>
(
k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ZDataType
*>
(
nullptr
),
// set to nullptr
static_cast
<
ZDataType
*>
(
nullptr
),
// set to nullptr
static_cast
<
DataType
*>
(
v_device_buf
.
GetDeviceBuffer
()),
static_cast
<
Input
DataType
*>
(
v_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DataType
*>
(
y_device_buf
.
GetDeviceBuffer
()),
static_cast
<
Input
DataType
*>
(
y_device_buf
.
GetDeviceBuffer
()),
static_cast
<
LSEDataType
*>
(
lse_device_buf
.
GetDeviceBuffer
()),
static_cast
<
LSEDataType
*>
(
lse_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DataType
*>
(
ygrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
Input
DataType
*>
(
ygrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DataType
*>
(
qgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
Output
DataType
*>
(
qgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DataType
*>
(
kgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
Output
DataType
*>
(
kgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DataType
*>
(
vgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
Output
DataType
*>
(
vgrad_device_buf
.
GetDeviceBuffer
()),
{},
// std::array<void*, 1> p_acc0_biases;
{},
// std::array<void*, 1> p_acc0_biases;
{},
// std::array<void*, 1> p_acc1_biases;
{},
// std::array<void*, 1> p_acc1_biases;
q_gs_ms_ks_lengths
,
q_gs_ms_ks_lengths
,
...
@@ -800,10 +814,13 @@ int run(int argc, char* argv[])
...
@@ -800,10 +814,13 @@ int run(int argc, char* argv[])
// 3x MNK + 2x MNO
// 3x MNK + 2x MNO
std
::
size_t
flop
=
(
size_t
(
3
)
*
M
*
N
*
K
+
size_t
(
2
)
*
M
*
N
*
O
)
*
2
*
BatchCount
;
std
::
size_t
flop
=
(
size_t
(
3
)
*
M
*
N
*
K
+
size_t
(
2
)
*
M
*
N
*
O
)
*
2
*
BatchCount
;
// Q/K/V/Y, dQ/dK/dV/dY, LSE
// Q/K/V/Y, dQ/dK/dV/dY, LSE
std
::
size_t
num_btype
=
(
sizeof
(
DataType
)
*
M
*
K
+
sizeof
(
DataType
)
*
K
*
N
+
std
::
size_t
num_btype
=
sizeof
(
DataType
)
*
N
*
O
+
sizeof
(
DataType
)
*
M
*
O
)
*
(
sizeof
(
InputDataType
)
*
M
*
K
+
sizeof
(
InputDataType
)
*
K
*
N
+
size_t
(
2
)
*
BatchCount
+
sizeof
(
InputDataType
)
*
N
*
O
+
sizeof
(
InputDataType
)
*
M
*
O
*
size_t
(
2
)
+
sizeof
(
LSEDataType
)
*
M
*
BatchCount
;
sizeof
(
OutputDataType
)
*
M
*
K
+
sizeof
(
OutputDataType
)
*
K
*
N
+
sizeof
(
OutputDataType
)
*
N
*
O
)
*
BatchCount
+
sizeof
(
LSEDataType
)
*
M
*
BatchCount
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
...
@@ -847,14 +864,14 @@ int run(int argc, char* argv[])
...
@@ -847,14 +864,14 @@ int run(int argc, char* argv[])
vgrad_device_buf
.
SetZero
();
vgrad_device_buf
.
SetZero
();
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
});
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
});
Tensor
<
DataType
>
qgrad_g_m_k
({
BatchCount
,
M
,
K
});
Tensor
<
Output
DataType
>
qgrad_g_m_k
({
BatchCount
,
M
,
K
});
Tensor
<
DataType
>
kgrad_g_n_k
({
BatchCount
,
N
,
K
});
Tensor
<
Output
DataType
>
kgrad_g_n_k
({
BatchCount
,
N
,
K
});
Tensor
<
DataType
>
vgrad_g_n_o
({
BatchCount
,
N
,
O
});
Tensor
<
Output
DataType
>
vgrad_g_n_o
({
BatchCount
,
N
,
O
});
Tensor
<
DataType
>
sgrad_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
Input
DataType
>
sgrad_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
DataType
>
pgrad_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
Input
DataType
>
pgrad_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
DataType
>
pgrad_drop_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
Input
DataType
>
pgrad_drop_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
DataType
>
ygrad_g_m_o
({
BatchCount
,
M
,
O
});
Tensor
<
Input
DataType
>
ygrad_g_m_o
({
BatchCount
,
M
,
O
});
Tensor
<
DataType
>
ygrad_dot_y_g_m
({
BatchCount
,
M
});
Tensor
<
Input
DataType
>
ygrad_dot_y_g_m
({
BatchCount
,
M
});
ygrad_gs_ms_os
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
ygrad_gs_ms_os
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
ygrad_g_m_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
ygrad_g_m_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
...
@@ -870,13 +887,16 @@ int run(int argc, char* argv[])
...
@@ -870,13 +887,16 @@ int run(int argc, char* argv[])
#endif
#endif
// Gradients
// Gradients
auto
ref_gemm_grad
=
ReferenceGemmGradInstance
{};
auto
ref_gemm0_grad
=
ReferenceGemm0GradInstance
{};
auto
ref_gemm_grad_invoker
=
ref_gemm_grad
.
MakeInvoker
();
auto
ref_gemm0_grad_invoker
=
ref_gemm0_grad
.
MakeInvoker
();
using
RefGemmGradArg
=
ReferenceGemmGradInstance
::
Argument
;
using
RefGemm0GradArg
=
ReferenceGemm0GradInstance
::
Argument
;
auto
ref_gemm1_grad
=
ReferenceGemm1GradInstance
{};
auto
ref_gemm1_grad_invoker
=
ref_gemm1_grad
.
MakeInvoker
();
using
RefGemm1GradArg
=
ReferenceGemm1GradInstance
::
Argument
;
// dP_dropout = dY * V^T
// dP_dropout = dY * V^T
auto
v_g_o_n
=
v_g_n_o
.
Transpose
({
0
,
2
,
1
});
auto
v_g_o_n
=
v_g_n_o
.
Transpose
({
0
,
2
,
1
});
ref_gemm_grad_invoker
.
Run
(
RefGemmGradArg
{
ref_gemm
0
_grad_invoker
.
Run
(
RefGemm
0
GradArg
{
ygrad_g_m_o
,
v_g_o_n
,
pgrad_drop_g_m_n
,
PassThrough
{},
PassThrough
{},
Scale
{
1.
f
}});
ygrad_g_m_o
,
v_g_o_n
,
pgrad_drop_g_m_n
,
PassThrough
{},
PassThrough
{},
Scale
{
1.
f
}});
#if PRINT_HOST
#if PRINT_HOST
{
{
...
@@ -903,7 +923,7 @@ int run(int argc, char* argv[])
...
@@ -903,7 +923,7 @@ int run(int argc, char* argv[])
ygrad_dot_y
+=
ck
::
type_convert
<
AccDataType
>
(
ygrad_g_m_o
(
idx_gmo
))
*
ygrad_dot_y
+=
ck
::
type_convert
<
AccDataType
>
(
ygrad_g_m_o
(
idx_gmo
))
*
ck
::
type_convert
<
AccDataType
>
(
y_g_m_o
(
idx_gmo
));
ck
::
type_convert
<
AccDataType
>
(
y_g_m_o
(
idx_gmo
));
}
}
self
(
idx_gmn
)
=
ck
::
type_convert
<
DataType
>
(
self
(
idx_gmn
)
=
ck
::
type_convert
<
Input
DataType
>
(
ck
::
type_convert
<
AccDataType
>
(
p_g_m_n
(
idx_gmn
))
*
ck
::
type_convert
<
AccDataType
>
(
p_g_m_n
(
idx_gmn
))
*
(
ck
::
type_convert
<
AccDataType
>
(
pgrad_g_m_n
(
idx_gmn
))
-
ygrad_dot_y
));
(
ck
::
type_convert
<
AccDataType
>
(
pgrad_g_m_n
(
idx_gmn
))
-
ygrad_dot_y
));
});
});
...
@@ -919,7 +939,7 @@ int run(int argc, char* argv[])
...
@@ -919,7 +939,7 @@ int run(int argc, char* argv[])
#endif
#endif
// dV = P_drop^T * dY
// dV = P_drop^T * dY
auto
p_drop_g_n_m
=
p_drop_g_m_n
.
Transpose
({
0
,
2
,
1
});
auto
p_drop_g_n_m
=
p_drop_g_m_n
.
Transpose
({
0
,
2
,
1
});
ref_gemm_grad_invoker
.
Run
(
RefGemmGradArg
{
ref_gemm
1
_grad_invoker
.
Run
(
RefGemm
1
GradArg
{
p_drop_g_n_m
,
ygrad_g_m_o
,
vgrad_g_n_o
,
PassThrough
{},
PassThrough
{},
Scale
{
1.0
f
}});
p_drop_g_n_m
,
ygrad_g_m_o
,
vgrad_g_n_o
,
PassThrough
{},
PassThrough
{},
Scale
{
1.0
f
}});
#if PRINT_HOST
#if PRINT_HOST
{
{
...
@@ -931,7 +951,7 @@ int run(int argc, char* argv[])
...
@@ -931,7 +951,7 @@ int run(int argc, char* argv[])
#endif
#endif
// dQ = alpha * dS * K
// dQ = alpha * dS * K
ref_gemm_grad_invoker
.
Run
(
RefGemmGradArg
{
ref_gemm
1
_grad_invoker
.
Run
(
RefGemm
1
GradArg
{
sgrad_g_m_n
,
k_g_n_k
,
qgrad_g_m_k
,
PassThrough
{},
PassThrough
{},
Scale
{
alpha
}});
sgrad_g_m_n
,
k_g_n_k
,
qgrad_g_m_k
,
PassThrough
{},
PassThrough
{},
Scale
{
alpha
}});
#if PRINT_HOST
#if PRINT_HOST
{
{
...
@@ -944,7 +964,7 @@ int run(int argc, char* argv[])
...
@@ -944,7 +964,7 @@ int run(int argc, char* argv[])
// dK = alpha * dS^T * Q
// dK = alpha * dS^T * Q
auto
sgrad_g_n_m
=
sgrad_g_m_n
.
Transpose
({
0
,
2
,
1
});
auto
sgrad_g_n_m
=
sgrad_g_m_n
.
Transpose
({
0
,
2
,
1
});
ref_gemm_grad_invoker
.
Run
(
RefGemmGradArg
{
ref_gemm
1
_grad_invoker
.
Run
(
RefGemm
1
GradArg
{
sgrad_g_n_m
,
q_g_m_k
,
kgrad_g_n_k
,
PassThrough
{},
PassThrough
{},
Scale
{
alpha
}});
sgrad_g_n_m
,
q_g_m_k
,
kgrad_g_n_k
,
PassThrough
{},
PassThrough
{},
Scale
{
alpha
}});
#if PRINT_HOST
#if PRINT_HOST
{
{
...
@@ -955,13 +975,13 @@ int run(int argc, char* argv[])
...
@@ -955,13 +975,13 @@ int run(int argc, char* argv[])
}
}
#endif
#endif
Tensor
<
DataType
>
qgrad_gs_ms_ks_host_result
(
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
);
Tensor
<
Output
DataType
>
qgrad_gs_ms_ks_host_result
(
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
);
Tensor
<
DataType
>
kgrad_gs_ns_ks_host_result
(
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
);
Tensor
<
Output
DataType
>
kgrad_gs_ns_ks_host_result
(
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
);
Tensor
<
DataType
>
vgrad_gs_os_ns_host_result
(
v_gs_os_ns_lengths
,
v_gs_os_ns_strides
);
Tensor
<
Output
DataType
>
vgrad_gs_os_ns_host_result
(
v_gs_os_ns_lengths
,
v_gs_os_ns_strides
);
Tensor
<
DataType
>
qgrad_gs_ms_ks_device_result
(
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
);
Tensor
<
Output
DataType
>
qgrad_gs_ms_ks_device_result
(
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
);
Tensor
<
DataType
>
kgrad_gs_ns_ks_device_result
(
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
);
Tensor
<
Output
DataType
>
kgrad_gs_ns_ks_device_result
(
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
);
Tensor
<
DataType
>
vgrad_gs_os_ns_device_result
(
v_gs_os_ns_lengths
,
v_gs_os_ns_strides
);
Tensor
<
Output
DataType
>
vgrad_gs_os_ns_device_result
(
v_gs_os_ns_lengths
,
v_gs_os_ns_strides
);
qgrad_device_buf
.
FromDevice
(
qgrad_gs_ms_ks_device_result
.
mData
.
data
());
qgrad_device_buf
.
FromDevice
(
qgrad_gs_ms_ks_device_result
.
mData
.
data
());
kgrad_device_buf
.
FromDevice
(
kgrad_gs_ns_ks_device_result
.
mData
.
data
());
kgrad_device_buf
.
FromDevice
(
kgrad_gs_ns_ks_device_result
.
mData
.
data
());
...
...
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward.cpp
View file @
f3e61c0a
...
@@ -61,8 +61,9 @@ using Scale = ck::tensor_operation::element_wise::Scale;
...
@@ -61,8 +61,9 @@ using Scale = ck::tensor_operation::element_wise::Scale;
using
QKVElementOp
=
PassThrough
;
using
QKVElementOp
=
PassThrough
;
using
YElementOp
=
PassThrough
;
using
YElementOp
=
PassThrough
;
using
DataType
=
F16
;
using
InputDataType
=
BF16
;
using
GemmDataType
=
F16
;
using
OutputDataType
=
F32
;
using
GemmDataType
=
BF16
;
using
AccDataType
=
F32
;
using
AccDataType
=
F32
;
using
ShuffleDataType
=
F32
;
using
ShuffleDataType
=
F32
;
using
LSEDataType
=
F32
;
using
LSEDataType
=
F32
;
...
@@ -102,7 +103,8 @@ using DeviceGemmInstance =
...
@@ -102,7 +103,8 @@ using DeviceGemmInstance =
NumDimN
,
NumDimN
,
NumDimK
,
NumDimK
,
NumDimO
,
NumDimO
,
DataType
,
InputDataType
,
OutputDataType
,
GemmDataType
,
GemmDataType
,
ZDataType
,
ZDataType
,
LSEDataType
,
LSEDataType
,
...
@@ -160,7 +162,7 @@ using DeviceGemmInstance =
...
@@ -160,7 +162,7 @@ using DeviceGemmInstance =
1
,
// CShuffleMXdlPerWavePerShuffle
1
,
// CShuffleMXdlPerWavePerShuffle
1
,
// CShuffleNXdlPerWavePerShuffle
1
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
64
,
1
,
4
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
S
<
1
,
64
,
1
,
4
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
4
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec
>
;
// MaskingSpecialization
MaskingSpec
>
;
// MaskingSpecialization
#elif(DIM <= 64)
#elif(DIM <= 64)
using
DeviceGemmInstance
=
using
DeviceGemmInstance
=
...
@@ -170,7 +172,8 @@ using DeviceGemmInstance =
...
@@ -170,7 +172,8 @@ using DeviceGemmInstance =
NumDimN
,
NumDimN
,
NumDimK
,
NumDimK
,
NumDimO
,
NumDimO
,
DataType
,
InputDataType
,
OutputDataType
,
GemmDataType
,
GemmDataType
,
ZDataType
,
ZDataType
,
LSEDataType
,
LSEDataType
,
...
@@ -228,7 +231,7 @@ using DeviceGemmInstance =
...
@@ -228,7 +231,7 @@ using DeviceGemmInstance =
1
,
// CShuffleMXdlPerWavePerShuffle
1
,
// CShuffleMXdlPerWavePerShuffle
2
,
// CShuffleNXdlPerWavePerShuffle
2
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
4
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec
>
;
// MaskingSpecialization
MaskingSpec
>
;
// MaskingSpecialization
// using DeviceGemmInstance =
// using DeviceGemmInstance =
...
@@ -238,7 +241,8 @@ using DeviceGemmInstance =
...
@@ -238,7 +241,8 @@ using DeviceGemmInstance =
// NumDimN,
// NumDimN,
// NumDimK,
// NumDimK,
// NumDimO,
// NumDimO,
// DataType,
// InputDataType,
// OutputDataType,
// GemmDataType,
// GemmDataType,
// ZDataType,
// ZDataType,
// LSEDataType,
// LSEDataType,
...
@@ -296,7 +300,7 @@ using DeviceGemmInstance =
...
@@ -296,7 +300,7 @@ using DeviceGemmInstance =
// 1, // CShuffleMXdlPerWavePerShuffle
// 1, // CShuffleMXdlPerWavePerShuffle
// 2, // CShuffleNXdlPerWavePerShuffle
// 2, // CShuffleNXdlPerWavePerShuffle
// S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
// S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
//
8
, // CShuffleBlockTransferScalarPerVector_NPerBlock
//
4
, // CShuffleBlockTransferScalarPerVector_NPerBlock
// MaskingSpec>; // MaskingSpecialization
// MaskingSpec>; // MaskingSpecialization
#elif(DIM <= 128)
#elif(DIM <= 128)
using
DeviceGemmInstance
=
using
DeviceGemmInstance
=
...
@@ -306,7 +310,8 @@ using DeviceGemmInstance =
...
@@ -306,7 +310,8 @@ using DeviceGemmInstance =
NumDimN
,
NumDimN
,
NumDimK
,
NumDimK
,
NumDimO
,
NumDimO
,
DataType
,
InputDataType
,
OutputDataType
,
GemmDataType
,
GemmDataType
,
ZDataType
,
ZDataType
,
LSEDataType
,
LSEDataType
,
...
@@ -364,14 +369,14 @@ using DeviceGemmInstance =
...
@@ -364,14 +369,14 @@ using DeviceGemmInstance =
1
,
// CShuffleMXdlPerWavePerShuffle
1
,
// CShuffleMXdlPerWavePerShuffle
4
,
// CShuffleNXdlPerWavePerShuffle
4
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
4
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec
>
;
// MaskingSpecialization
MaskingSpec
>
;
// MaskingSpecialization
#endif
#endif
// Ref Gemm0: S = alpha * Q * K^T
// Ref Gemm0: S = alpha * Q * K^T
// fp16 in, fp32 out
// fp16 in, fp32 out
using
ReferenceGemm0Instance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
DataType
,
using
ReferenceGemm0Instance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
Input
DataType
,
DataType
,
Input
DataType
,
AccDataType
,
AccDataType
,
AccDataType
,
AccDataType
,
PassThrough
,
PassThrough
,
...
@@ -381,13 +386,13 @@ using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
...
@@ -381,13 +386,13 @@ using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
// Ref Softmax: P = Softmax(S)
// Ref Softmax: P = Softmax(S)
// fp32 in, fp16 out
// fp32 in, fp16 out
using
ReferenceSoftmaxInstance
=
using
ReferenceSoftmaxInstance
=
ck
::
tensor_operation
::
host
::
ReferenceSoftmax
<
AccDataType
,
DataType
,
AccDataType
>
;
ck
::
tensor_operation
::
host
::
ReferenceSoftmax
<
AccDataType
,
Input
DataType
,
AccDataType
>
;
// Ref Gemm1: Y = P * V
// Ref Gemm1: Y = P * V
// fp16 in, fp16 out
// fp16 in, fp16 out
using
ReferenceGemm1Instance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
DataType
,
using
ReferenceGemm1Instance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
Input
DataType
,
DataType
,
Input
DataType
,
DataType
,
Input
DataType
,
AccDataType
,
AccDataType
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
...
@@ -395,16 +400,25 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
...
@@ -395,16 +400,25 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
// Ref Gemm for backward pass
// Ref Gemm for backward pass
// fp16 in, fp16 out
// fp16 in, fp16 out
using
ReferenceGemmGradInstance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
DataType
,
using
ReferenceGemm0GradInstance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
InputDataType
,
DataType
,
InputDataType
,
DataType
,
InputDataType
,
AccDataType
,
AccDataType
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
Scale
>
;
Scale
>
;
using
ReferenceGemm1GradInstance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
InputDataType
,
InputDataType
,
OutputDataType
,
AccDataType
,
PassThrough
,
PassThrough
,
Scale
>
;
// Ref dropout
// Ref dropout
using
ReferenceDropoutInstance
=
using
ReferenceDropoutInstance
=
ck
::
tensor_operation
::
host
::
ReferenceDropout
<
ushort
,
DataType
,
DataType
>
;
ck
::
tensor_operation
::
host
::
ReferenceDropout
<
ushort
,
Input
DataType
,
Input
DataType
>
;
template
<
typename
TensorQ
,
template
<
typename
TensorQ
,
typename
TensorK
,
typename
TensorK
,
...
@@ -539,26 +553,26 @@ int run(int argc, char* argv[])
...
@@ -539,26 +553,26 @@ int run(int argc, char* argv[])
std
::
vector
<
void
*>
p_vgrad
;
std
::
vector
<
void
*>
p_vgrad
;
std
::
vector
<
const
void
*>
p_ygrad
;
std
::
vector
<
const
void
*>
p_ygrad
;
std
::
vector
<
Tensor
<
DataType
>>
q_g_m_ks
;
std
::
vector
<
Tensor
<
Input
DataType
>>
q_g_m_ks
;
std
::
vector
<
Tensor
<
DataType
>>
k_g_n_ks
;
std
::
vector
<
Tensor
<
Input
DataType
>>
k_g_n_ks
;
std
::
vector
<
Tensor
<
ZDataType
>>
z_g_m_ns
;
std
::
vector
<
Tensor
<
ZDataType
>>
z_g_m_ns
;
std
::
vector
<
Tensor
<
DataType
>>
v_g_n_os
;
std
::
vector
<
Tensor
<
Input
DataType
>>
v_g_n_os
;
std
::
vector
<
Tensor
<
AccDataType
>>
s_g_m_ns
;
std
::
vector
<
Tensor
<
AccDataType
>>
s_g_m_ns
;
std
::
vector
<
Tensor
<
DataType
>>
p_g_m_ns
;
std
::
vector
<
Tensor
<
Input
DataType
>>
p_g_m_ns
;
std
::
vector
<
Tensor
<
DataType
>>
y_g_m_os
;
std
::
vector
<
Tensor
<
Input
DataType
>>
y_g_m_os
;
std
::
vector
<
Tensor
<
LSEDataType
>>
lse_g_ms
;
std
::
vector
<
Tensor
<
LSEDataType
>>
lse_g_ms
;
std
::
vector
<
Tensor
<
DataType
>>
p_drop_g_m_ns
;
std
::
vector
<
Tensor
<
Input
DataType
>>
p_drop_g_m_ns
;
std
::
vector
<
Tensor
<
DataType
>>
q_tensors
;
std
::
vector
<
Tensor
<
Input
DataType
>>
q_tensors
;
std
::
vector
<
Tensor
<
DataType
>>
k_tensors
;
std
::
vector
<
Tensor
<
Input
DataType
>>
k_tensors
;
std
::
vector
<
Tensor
<
DataType
>>
v_tensors
;
std
::
vector
<
Tensor
<
Input
DataType
>>
v_tensors
;
std
::
vector
<
Tensor
<
DataType
>>
y_tensors
;
std
::
vector
<
Tensor
<
Input
DataType
>>
y_tensors
;
std
::
vector
<
Tensor
<
ZDataType
>>
z_tensors
;
std
::
vector
<
Tensor
<
ZDataType
>>
z_tensors
;
std
::
vector
<
Tensor
<
LSEDataType
>>
lse_tensors
;
std
::
vector
<
Tensor
<
LSEDataType
>>
lse_tensors
;
std
::
vector
<
Tensor
<
DataType
>>
qgrad_tensors
;
std
::
vector
<
Tensor
<
Output
DataType
>>
qgrad_tensors
;
std
::
vector
<
Tensor
<
DataType
>>
kgrad_tensors
;
std
::
vector
<
Tensor
<
Output
DataType
>>
kgrad_tensors
;
std
::
vector
<
Tensor
<
DataType
>>
vgrad_tensors
;
std
::
vector
<
Tensor
<
Output
DataType
>>
vgrad_tensors
;
std
::
vector
<
Tensor
<
DataType
>>
ygrad_tensors
;
std
::
vector
<
Tensor
<
Input
DataType
>>
ygrad_tensors
;
std
::
vector
<
DeviceMemPtr
>
q_tensors_device
;
std
::
vector
<
DeviceMemPtr
>
q_tensors_device
;
std
::
vector
<
DeviceMemPtr
>
k_tensors_device
;
std
::
vector
<
DeviceMemPtr
>
k_tensors_device
;
...
@@ -639,17 +653,19 @@ int run(int argc, char* argv[])
...
@@ -639,17 +653,19 @@ int run(int argc, char* argv[])
int
BatchCount
=
G0
*
G1
;
int
BatchCount
=
G0
*
G1
;
flop
+=
(
size_t
(
3
)
*
M
*
N
*
K
+
size_t
(
2
)
*
M
*
N
*
O
)
*
2
*
BatchCount
;
flop
+=
(
size_t
(
3
)
*
M
*
N
*
K
+
size_t
(
2
)
*
M
*
N
*
O
)
*
2
*
BatchCount
;
// Q/K/V/Y, dQ/dK/dV/dY, LSE
// Q/K/V/Y, dQ/dK/dV/dY, LSE
num_byte
+=
(
sizeof
(
DataType
)
*
M
*
K
+
sizeof
(
DataType
)
*
K
*
N
+
num_byte
+=
(
sizeof
(
InputDataType
)
*
M
*
K
+
sizeof
(
InputDataType
)
*
K
*
N
+
sizeof
(
DataType
)
*
N
*
O
+
sizeof
(
DataType
)
*
M
*
O
)
*
sizeof
(
InputDataType
)
*
N
*
O
+
sizeof
(
InputDataType
)
*
M
*
O
*
size_t
(
2
)
+
size_t
(
2
)
*
BatchCount
+
sizeof
(
OutputDataType
)
*
M
*
K
+
sizeof
(
OutputDataType
)
*
K
*
N
+
sizeof
(
OutputDataType
)
*
N
*
O
)
*
BatchCount
+
sizeof
(
LSEDataType
)
*
M
*
BatchCount
;
sizeof
(
LSEDataType
)
*
M
*
BatchCount
;
Tensor
<
DataType
>
q_gs_ms_ks
(
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
);
Tensor
<
Input
DataType
>
q_gs_ms_ks
(
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
);
Tensor
<
DataType
>
k_gs_ns_ks
(
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
);
Tensor
<
Input
DataType
>
k_gs_ns_ks
(
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
);
Tensor
<
ZDataType
>
z_gs_ms_ns
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
);
Tensor
<
ZDataType
>
z_gs_ms_ns
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
);
Tensor
<
DataType
>
v_gs_os_ns
(
v_gs_os_ns_lengths
,
v_gs_os_ns_strides
);
Tensor
<
Input
DataType
>
v_gs_os_ns
(
v_gs_os_ns_lengths
,
v_gs_os_ns_strides
);
Tensor
<
DataType
>
y_gs_ms_os
(
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
);
Tensor
<
Input
DataType
>
y_gs_ms_os
(
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
);
Tensor
<
DataType
>
ygrad_gs_ms_os
(
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
);
Tensor
<
Input
DataType
>
ygrad_gs_ms_os
(
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
);
Tensor
<
LSEDataType
>
lse_gs_ms
(
lse_gs_ms_lengths
,
lse_gs_ms_strides
);
Tensor
<
LSEDataType
>
lse_gs_ms
(
lse_gs_ms_lengths
,
lse_gs_ms_strides
);
if
(
i
<
4
)
if
(
i
<
4
)
{
{
...
@@ -660,45 +676,45 @@ int run(int argc, char* argv[])
...
@@ -660,45 +676,45 @@ int run(int argc, char* argv[])
std
::
cout
<<
"y_gs_ms_os: "
<<
y_gs_ms_os
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"y_gs_ms_os: "
<<
y_gs_ms_os
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"lse_gs_ms_os: "
<<
lse_gs_ms
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"lse_gs_ms_os: "
<<
lse_gs_ms
.
mDesc
<<
std
::
endl
;
}
}
z_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
DataType
>
{
0
});
z_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
Input
DataType
>
{
0
});
switch
(
init_method
)
switch
(
init_method
)
{
{
case
0
:
break
;
case
0
:
break
;
case
1
:
case
1
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
DataType
>
{
-
2
,
2
});
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
Input
DataType
>
{
-
2
,
2
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
DataType
>
{
-
2
,
2
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
Input
DataType
>
{
-
2
,
2
});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_2
<
DataType
>
{
-
2
,
2
});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_2
<
Input
DataType
>
{
-
2
,
2
});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_2
<
DataType
>
{
-
2
,
2
});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_2
<
Input
DataType
>
{
-
2
,
2
});
break
;
break
;
case
2
:
case
2
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_3
<
DataType
>
{
0.0
,
1.0
});
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_3
<
Input
DataType
>
{
0.0
,
1.0
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_3
<
DataType
>
{
0.0
,
1.0
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_3
<
Input
DataType
>
{
0.0
,
1.0
});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_3
<
DataType
>
{
-
0.5
,
0.5
});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_3
<
Input
DataType
>
{
-
0.5
,
0.5
});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_3
<
DataType
>
{
-
0.5
,
0.5
});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_3
<
Input
DataType
>
{
-
0.5
,
0.5
});
break
;
break
;
case
3
:
case
3
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
DataType
>
{
-
5
,
5
});
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
Input
DataType
>
{
-
5
,
5
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
DataType
>
{});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
Input
DataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
DataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
Input
DataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
DataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
Input
DataType
>
{});
break
;
break
;
case
4
:
case
4
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
DataType
>
{
1
});
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
Input
DataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
DataType
>
{});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
Input
DataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
DataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
Input
DataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_1
<
DataType
>
{
2
});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_1
<
Input
DataType
>
{
2
});
break
;
break
;
case
5
:
case
5
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
DataType
>
{
1
});
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
Input
DataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
DataType
>
{});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
Input
DataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
DataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
Input
DataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
2
>
{});
// dy[g0, g1, m, o]
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
2
>
{});
// dy[g0, g1, m, o]
// dO dot O = [0; 1; 2; ...]
// dO dot O = [0; 1; 2; ...]
break
;
break
;
case
6
:
case
6
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
DataType
>
{
1
});
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
Input
DataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
DataType
>
{});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
Input
DataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
DataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
Input
DataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
3
>
{});
// dy[g0, g1, m, o]
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
3
>
{});
// dy[g0, g1, m, o]
// assume mnko = 256
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// P = softmax(QK) = 0.0039 * ones
...
@@ -709,10 +725,11 @@ int run(int argc, char* argv[])
...
@@ -709,10 +725,11 @@ int run(int argc, char* argv[])
//
//
break
;
break
;
default:
default:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
DataType
>
{
1
});
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
DataType
>
{});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
DataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_1
<
DataType
>
{
1
});
// dy[g0, g1, m, o]
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
// dy[g0, g1, m, o]
// assume mnko = 256
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
// O = P V = 0.0039 * ones
...
@@ -722,15 +739,15 @@ int run(int argc, char* argv[])
...
@@ -722,15 +739,15 @@ int run(int argc, char* argv[])
// = 0.0039 * ones * (ones - 1)
// = 0.0039 * ones * (ones - 1)
// = 0
// = 0
}
}
Tensor
<
DataType
>
q_g_m_k
({
BatchCount
,
M
,
K
});
Tensor
<
Input
DataType
>
q_g_m_k
({
BatchCount
,
M
,
K
});
Tensor
<
DataType
>
k_g_n_k
({
BatchCount
,
N
,
K
});
Tensor
<
Input
DataType
>
k_g_n_k
({
BatchCount
,
N
,
K
});
Tensor
<
ZDataType
>
z_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
ZDataType
>
z_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
DataType
>
v_g_n_o
({
BatchCount
,
N
,
O
});
Tensor
<
Input
DataType
>
v_g_n_o
({
BatchCount
,
N
,
O
});
Tensor
<
AccDataType
>
s_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
AccDataType
>
s_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
DataType
>
p_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
Input
DataType
>
p_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
DataType
>
y_g_m_o
({
BatchCount
,
M
,
O
});
Tensor
<
Input
DataType
>
y_g_m_o
({
BatchCount
,
M
,
O
});
Tensor
<
LSEDataType
>
lse_g_m
({
BatchCount
,
M
});
Tensor
<
LSEDataType
>
lse_g_m
({
BatchCount
,
M
});
Tensor
<
DataType
>
p_drop_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
Input
DataType
>
p_drop_g_m_n
({
BatchCount
,
M
,
N
});
q_gs_ms_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
q_gs_ms_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
q_g_m_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
q_g_m_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
...
@@ -759,25 +776,25 @@ int run(int argc, char* argv[])
...
@@ -759,25 +776,25 @@ int run(int argc, char* argv[])
lse_tensors
.
push_back
(
lse_gs_ms
);
lse_tensors
.
push_back
(
lse_gs_ms
);
ygrad_tensors
.
push_back
(
ygrad_gs_ms_os
);
ygrad_tensors
.
push_back
(
ygrad_gs_ms_os
);
q_tensors_device
.
emplace_back
(
q_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
DataType
)
*
q_gs_ms_ks
.
GetElementSpaceSize
()));
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
Input
DataType
)
*
q_gs_ms_ks
.
GetElementSpaceSize
()));
k_tensors_device
.
emplace_back
(
k_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
DataType
)
*
k_gs_ns_ks
.
GetElementSpaceSize
()));
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
Input
DataType
)
*
k_gs_ns_ks
.
GetElementSpaceSize
()));
z_tensors_device
.
emplace_back
(
z_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
ZDataType
)
*
z_gs_ms_ns
.
GetElementSpaceSize
()));
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
ZDataType
)
*
z_gs_ms_ns
.
GetElementSpaceSize
()));
v_tensors_device
.
emplace_back
(
v_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
DataType
)
*
v_gs_os_ns
.
GetElementSpaceSize
()));
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
Input
DataType
)
*
v_gs_os_ns
.
GetElementSpaceSize
()));
y_tensors_device
.
emplace_back
(
y_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
DataType
)
*
y_gs_ms_os
.
GetElementSpaceSize
()));
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
Input
DataType
)
*
y_gs_ms_os
.
GetElementSpaceSize
()));
lse_tensors_device
.
emplace_back
(
lse_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
LSEDataType
)
*
lse_gs_ms
.
GetElementSpaceSize
()));
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
LSEDataType
)
*
lse_gs_ms
.
GetElementSpaceSize
()));
qgrad_tensors_device
.
emplace_back
(
qgrad_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
DataType
)
*
q_gs_ms_ks
.
GetElementSpaceSize
()));
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
Output
DataType
)
*
q_gs_ms_ks
.
GetElementSpaceSize
()));
kgrad_tensors_device
.
emplace_back
(
kgrad_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
DataType
)
*
k_gs_ns_ks
.
GetElementSpaceSize
()));
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
Output
DataType
)
*
k_gs_ns_ks
.
GetElementSpaceSize
()));
vgrad_tensors_device
.
emplace_back
(
vgrad_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
DataType
)
*
v_gs_os_ns
.
GetElementSpaceSize
()));
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
Output
DataType
)
*
v_gs_os_ns
.
GetElementSpaceSize
()));
ygrad_tensors_device
.
emplace_back
(
ygrad_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
DataType
)
*
y_gs_ms_os
.
GetElementSpaceSize
()));
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
Input
DataType
)
*
y_gs_ms_os
.
GetElementSpaceSize
()));
q_tensors_device
.
back
()
->
ToDevice
(
q_gs_ms_ks
.
data
());
q_tensors_device
.
back
()
->
ToDevice
(
q_gs_ms_ks
.
data
());
k_tensors_device
.
back
()
->
ToDevice
(
k_gs_ns_ks
.
data
());
k_tensors_device
.
back
()
->
ToDevice
(
k_gs_ns_ks
.
data
());
z_tensors_device
.
back
()
->
ToDevice
(
z_gs_ms_ns
.
data
());
z_tensors_device
.
back
()
->
ToDevice
(
z_gs_ms_ns
.
data
());
...
@@ -918,23 +935,26 @@ int run(int argc, char* argv[])
...
@@ -918,23 +935,26 @@ int run(int argc, char* argv[])
int
M
=
q_tensors
[
i
].
GetLengths
()[
2
];
int
M
=
q_tensors
[
i
].
GetLengths
()[
2
];
int
K
=
q_tensors
[
i
].
GetLengths
()[
3
];
int
K
=
q_tensors
[
i
].
GetLengths
()[
3
];
int
BatchCount
=
G0
*
G1
;
int
BatchCount
=
G0
*
G1
;
Tensor
<
DataType
>
qgrad_g_m_k
({
BatchCount
,
M
,
K
});
Tensor
<
Output
DataType
>
qgrad_g_m_k
({
BatchCount
,
M
,
K
});
Tensor
<
DataType
>
kgrad_g_n_k
({
BatchCount
,
N
,
K
});
Tensor
<
Output
DataType
>
kgrad_g_n_k
({
BatchCount
,
N
,
K
});
Tensor
<
DataType
>
vgrad_g_n_o
({
BatchCount
,
N
,
O
});
Tensor
<
Output
DataType
>
vgrad_g_n_o
({
BatchCount
,
N
,
O
});
Tensor
<
DataType
>
sgrad_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
Input
DataType
>
sgrad_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
DataType
>
pgrad_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
Input
DataType
>
pgrad_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
DataType
>
pgrad_drop_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
Input
DataType
>
pgrad_drop_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
DataType
>
ygrad_g_m_o
({
BatchCount
,
M
,
O
});
Tensor
<
Input
DataType
>
ygrad_g_m_o
({
BatchCount
,
M
,
O
});
ygrad_tensors
[
i
].
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
ygrad_tensors
[
i
].
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
ygrad_g_m_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
ygrad_g_m_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
});
auto
ref_gemm_grad
=
ReferenceGemmGradInstance
{};
auto
ref_gemm0_grad
=
ReferenceGemm0GradInstance
{};
auto
ref_gemm_grad_invoker
=
ref_gemm_grad
.
MakeInvoker
();
auto
ref_gemm0_grad_invoker
=
ref_gemm0_grad
.
MakeInvoker
();
using
RefGemmGradArg
=
ReferenceGemmGradInstance
::
Argument
;
using
RefGemm0GradArg
=
ReferenceGemm0GradInstance
::
Argument
;
auto
ref_gemm1_grad
=
ReferenceGemm1GradInstance
{};
auto
ref_gemm1_grad_invoker
=
ref_gemm1_grad
.
MakeInvoker
();
using
RefGemm1GradArg
=
ReferenceGemm1GradInstance
::
Argument
;
// dP = dY * V^T
// dP = dY * V^T
auto
v_g_o_n
=
v_g_n_os
[
i
].
Transpose
({
0
,
2
,
1
});
auto
v_g_o_n
=
v_g_n_os
[
i
].
Transpose
({
0
,
2
,
1
});
ref_gemm_grad_invoker
.
Run
(
RefGemmGradArg
{
ref_gemm
0
_grad_invoker
.
Run
(
RefGemm
0
GradArg
{
ygrad_g_m_o
,
v_g_o_n
,
pgrad_drop_g_m_n
,
PassThrough
{},
PassThrough
{},
Scale
{
1.
f
}});
ygrad_g_m_o
,
v_g_o_n
,
pgrad_drop_g_m_n
,
PassThrough
{},
PassThrough
{},
Scale
{
1.
f
}});
auto
ref_dropout
=
ReferenceDropoutInstance
{};
auto
ref_dropout
=
ReferenceDropoutInstance
{};
auto
ref_dropout_invoker
=
ref_dropout
.
MakeInvoker
();
auto
ref_dropout_invoker
=
ref_dropout
.
MakeInvoker
();
...
@@ -951,33 +971,33 @@ int run(int argc, char* argv[])
...
@@ -951,33 +971,33 @@ int run(int argc, char* argv[])
ygrad_dot_y
+=
ck
::
type_convert
<
AccDataType
>
(
ygrad_g_m_o
(
idx_gmo
))
*
ygrad_dot_y
+=
ck
::
type_convert
<
AccDataType
>
(
ygrad_g_m_o
(
idx_gmo
))
*
ck
::
type_convert
<
AccDataType
>
(
y_g_m_os
[
i
](
idx_gmo
));
ck
::
type_convert
<
AccDataType
>
(
y_g_m_os
[
i
](
idx_gmo
));
}
}
self
(
idx_gmn
)
=
ck
::
type_convert
<
DataType
>
(
self
(
idx_gmn
)
=
ck
::
type_convert
<
Input
DataType
>
(
ck
::
type_convert
<
AccDataType
>
(
p_g_m_ns
[
i
](
idx_gmn
))
*
ck
::
type_convert
<
AccDataType
>
(
p_g_m_ns
[
i
](
idx_gmn
))
*
(
ck
::
type_convert
<
AccDataType
>
(
pgrad_g_m_n
(
idx_gmn
))
-
ygrad_dot_y
));
(
ck
::
type_convert
<
AccDataType
>
(
pgrad_g_m_n
(
idx_gmn
))
-
ygrad_dot_y
));
});
});
auto
p_drop_g_n_m
=
p_drop_g_m_ns
[
i
].
Transpose
({
0
,
2
,
1
});
auto
p_drop_g_n_m
=
p_drop_g_m_ns
[
i
].
Transpose
({
0
,
2
,
1
});
ref_gemm_grad_invoker
.
Run
(
RefGemmGradArg
{
ref_gemm
1
_grad_invoker
.
Run
(
RefGemm
1
GradArg
{
p_drop_g_n_m
,
ygrad_g_m_o
,
vgrad_g_n_o
,
PassThrough
{},
PassThrough
{},
Scale
{
1.0
f
}});
p_drop_g_n_m
,
ygrad_g_m_o
,
vgrad_g_n_o
,
PassThrough
{},
PassThrough
{},
Scale
{
1.0
f
}});
ref_gemm_grad_invoker
.
Run
(
RefGemmGradArg
{
ref_gemm
1
_grad_invoker
.
Run
(
RefGemm
1
GradArg
{
sgrad_g_m_n
,
k_g_n_ks
[
i
],
qgrad_g_m_k
,
PassThrough
{},
PassThrough
{},
Scale
{
alpha
}});
sgrad_g_m_n
,
k_g_n_ks
[
i
],
qgrad_g_m_k
,
PassThrough
{},
PassThrough
{},
Scale
{
alpha
}});
auto
sgrad_g_n_m
=
sgrad_g_m_n
.
Transpose
({
0
,
2
,
1
});
auto
sgrad_g_n_m
=
sgrad_g_m_n
.
Transpose
({
0
,
2
,
1
});
ref_gemm_grad_invoker
.
Run
(
RefGemmGradArg
{
ref_gemm
1
_grad_invoker
.
Run
(
RefGemm
1
GradArg
{
sgrad_g_n_m
,
q_g_m_ks
[
i
],
kgrad_g_n_k
,
PassThrough
{},
PassThrough
{},
Scale
{
alpha
}});
sgrad_g_n_m
,
q_g_m_ks
[
i
],
kgrad_g_n_k
,
PassThrough
{},
PassThrough
{},
Scale
{
alpha
}});
Tensor
<
DataType
>
qgrad_gs_ms_ks_host_result
(
q_tensors
[
i
].
GetLengths
(),
Tensor
<
Output
DataType
>
qgrad_gs_ms_ks_host_result
(
q_tensors
[
i
].
GetLengths
(),
q_tensors
[
i
].
GetStrides
());
q_tensors
[
i
].
GetStrides
());
Tensor
<
DataType
>
kgrad_gs_ns_ks_host_result
(
k_tensors
[
i
].
GetLengths
(),
Tensor
<
Output
DataType
>
kgrad_gs_ns_ks_host_result
(
k_tensors
[
i
].
GetLengths
(),
k_tensors
[
i
].
GetStrides
());
k_tensors
[
i
].
GetStrides
());
Tensor
<
DataType
>
vgrad_gs_os_ns_host_result
(
v_tensors
[
i
].
GetLengths
(),
Tensor
<
Output
DataType
>
vgrad_gs_os_ns_host_result
(
v_tensors
[
i
].
GetLengths
(),
v_tensors
[
i
].
GetStrides
());
v_tensors
[
i
].
GetStrides
());
Tensor
<
DataType
>
qgrad_gs_ms_ks_device_result
(
q_tensors
[
i
].
GetLengths
(),
Tensor
<
Output
DataType
>
qgrad_gs_ms_ks_device_result
(
q_tensors
[
i
].
GetLengths
(),
q_tensors
[
i
].
GetStrides
());
q_tensors
[
i
].
GetStrides
());
Tensor
<
DataType
>
kgrad_gs_ns_ks_device_result
(
k_tensors
[
i
].
GetLengths
(),
Tensor
<
Output
DataType
>
kgrad_gs_ns_ks_device_result
(
k_tensors
[
i
].
GetLengths
(),
k_tensors
[
i
].
GetStrides
());
k_tensors
[
i
].
GetStrides
());
Tensor
<
DataType
>
vgrad_gs_os_ns_device_result
(
v_tensors
[
i
].
GetLengths
(),
Tensor
<
Output
DataType
>
vgrad_gs_os_ns_device_result
(
v_tensors
[
i
].
GetLengths
(),
v_tensors
[
i
].
GetStrides
());
v_tensors
[
i
].
GetStrides
());
qgrad_tensors_device
[
i
]
->
FromDevice
(
qgrad_gs_ms_ks_device_result
.
data
());
qgrad_tensors_device
[
i
]
->
FromDevice
(
qgrad_gs_ms_ks_device_result
.
data
());
kgrad_tensors_device
[
i
]
->
FromDevice
(
kgrad_gs_ns_ks_device_result
.
data
());
kgrad_tensors_device
[
i
]
->
FromDevice
(
kgrad_gs_ns_ks_device_result
.
data
());
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp
View file @
f3e61c0a
...
@@ -28,7 +28,8 @@ namespace tensor_operation {
...
@@ -28,7 +28,8 @@ namespace tensor_operation {
namespace
device
{
namespace
device
{
template
<
typename
GridwiseGemm
,
template
<
typename
GridwiseGemm
,
typename
DataType
,
typename
InputDataType
,
typename
OutputDataType
,
typename
ZDataType
,
typename
ZDataType
,
typename
LSEDataType
,
typename
LSEDataType
,
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
...
@@ -53,16 +54,16 @@ __global__ void
...
@@ -53,16 +54,16 @@ __global__ void
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
/*CK_MIN_BLOCK_PER_CU*/
1
)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
/*CK_MIN_BLOCK_PER_CU*/
1
)
#endif
#endif
kernel_batched_multihead_attention_backward_xdl_cshuffle_v1
(
kernel_batched_multihead_attention_backward_xdl_cshuffle_v1
(
const
DataType
*
__restrict__
p_a_grid
,
const
Input
DataType
*
__restrict__
p_a_grid
,
const
DataType
*
__restrict__
p_b_grid
,
const
Input
DataType
*
__restrict__
p_b_grid
,
ZDataType
*
__restrict__
p_z_grid
,
ZDataType
*
__restrict__
p_z_grid
,
const
DataType
*
__restrict__
p_b1_grid
,
const
Input
DataType
*
__restrict__
p_b1_grid
,
const
DataType
*
__restrict__
p_c_grid
,
const
Input
DataType
*
__restrict__
p_c_grid
,
const
LSEDataType
*
__restrict__
p_lse_grid
,
const
LSEDataType
*
__restrict__
p_lse_grid
,
const
DataType
*
__restrict__
p_ygrad_grid
,
const
Input
DataType
*
__restrict__
p_ygrad_grid
,
DataType
*
__restrict__
p_qgrad_grid
,
Output
DataType
*
__restrict__
p_qgrad_grid
,
DataType
*
__restrict__
p_kgrad_grid
,
Output
DataType
*
__restrict__
p_kgrad_grid
,
DataType
*
__restrict__
p_vgrad_grid
,
Output
DataType
*
__restrict__
p_vgrad_grid
,
const
AElementwiseOperation
a_element_op
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
BElementwiseOperation
b_element_op
,
const
AccElementwiseOperation
acc_element_op
,
const
AccElementwiseOperation
acc_element_op
,
...
@@ -171,7 +172,8 @@ template <index_t NumDimG,
...
@@ -171,7 +172,8 @@ template <index_t NumDimG,
index_t
NumDimN
,
index_t
NumDimN
,
index_t
NumDimK
,
index_t
NumDimK
,
index_t
NumDimO
,
// NumDimGemm1N
index_t
NumDimO
,
// NumDimGemm1N
typename
DataType
,
typename
InputDataType
,
typename
OutputDataType
,
typename
GemmDataType
,
typename
GemmDataType
,
typename
ZDataType
,
typename
ZDataType
,
typename
LSEDataType
,
typename
LSEDataType
,
...
@@ -597,7 +599,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -597,7 +599,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
// GridwiseGemm
// GridwiseGemm
using
GridwiseGemm
=
GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
<
using
GridwiseGemm
=
GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
<
DataType
,
// TODO: distinguish A/B datatype
InputDataType
,
// TODO: distinguish A/B datatype
OutputDataType
,
GemmDataType
,
GemmDataType
,
GemmAccDataType
,
GemmAccDataType
,
CShuffleDataType
,
CShuffleDataType
,
...
@@ -666,16 +669,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -666,16 +669,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
struct
Argument
:
public
BaseArgument
struct
Argument
:
public
BaseArgument
{
{
Argument
(
Argument
(
const
DataType
*
p_a_grid
,
const
Input
DataType
*
p_a_grid
,
const
DataType
*
p_b_grid
,
const
Input
DataType
*
p_b_grid
,
ZDataType
*
p_z_grid
,
ZDataType
*
p_z_grid
,
const
DataType
*
p_b1_grid
,
const
Input
DataType
*
p_b1_grid
,
const
DataType
*
p_c_grid
,
// for dS
const
Input
DataType
*
p_c_grid
,
// for dS
const
LSEDataType
*
p_lse_grid
,
const
LSEDataType
*
p_lse_grid
,
const
DataType
*
p_ygrad_grid
,
const
Input
DataType
*
p_ygrad_grid
,
DataType
*
p_qgrad_grid
,
Output
DataType
*
p_qgrad_grid
,
DataType
*
p_kgrad_grid
,
Output
DataType
*
p_kgrad_grid
,
DataType
*
p_vgrad_grid
,
Output
DataType
*
p_vgrad_grid
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>
p_acc0_biases
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>
p_acc0_biases
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>
p_acc1_biases
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>
p_acc1_biases
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
...
@@ -820,16 +823,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -820,16 +823,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
}
}
// pointers
// pointers
const
DataType
*
p_a_grid_
;
const
Input
DataType
*
p_a_grid_
;
const
DataType
*
p_b_grid_
;
const
Input
DataType
*
p_b_grid_
;
ZDataType
*
p_z_grid_
;
ZDataType
*
p_z_grid_
;
const
DataType
*
p_b1_grid_
;
const
Input
DataType
*
p_b1_grid_
;
const
DataType
*
p_c_grid_
;
const
Input
DataType
*
p_c_grid_
;
const
LSEDataType
*
p_lse_grid_
;
const
LSEDataType
*
p_lse_grid_
;
const
DataType
*
p_ygrad_grid_
;
const
Input
DataType
*
p_ygrad_grid_
;
DataType
*
p_qgrad_grid_
;
Output
DataType
*
p_qgrad_grid_
;
DataType
*
p_kgrad_grid_
;
Output
DataType
*
p_kgrad_grid_
;
DataType
*
p_vgrad_grid_
;
Output
DataType
*
p_vgrad_grid_
;
// tensor descriptor
// tensor descriptor
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
...
@@ -901,7 +904,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -901,7 +904,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
)
{
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
)
{
const
auto
kernel
=
kernel_batched_multihead_attention_backward_xdl_cshuffle_v1
<
const
auto
kernel
=
kernel_batched_multihead_attention_backward_xdl_cshuffle_v1
<
GridwiseGemm
,
GridwiseGemm
,
DataType
,
InputDataType
,
OutputDataType
,
ZDataType
,
ZDataType
,
LSEDataType
,
LSEDataType
,
AElementwiseOperation
,
AElementwiseOperation
,
...
@@ -1067,16 +1071,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -1067,16 +1071,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
}
}
static
auto
MakeArgument
(
static
auto
MakeArgument
(
const
DataType
*
p_a
,
const
Input
DataType
*
p_a
,
const
DataType
*
p_b
,
const
Input
DataType
*
p_b
,
ZDataType
*
p_z
,
ZDataType
*
p_z
,
const
DataType
*
p_b1
,
const
Input
DataType
*
p_b1
,
const
DataType
*
p_c
,
const
Input
DataType
*
p_c
,
const
LSEDataType
*
p_lse
,
const
LSEDataType
*
p_lse
,
const
DataType
*
p_ygrad_grid
,
const
Input
DataType
*
p_ygrad_grid
,
DataType
*
p_qgrad_grid
,
Output
DataType
*
p_qgrad_grid
,
DataType
*
p_kgrad_grid
,
Output
DataType
*
p_kgrad_grid
,
DataType
*
p_vgrad_grid
,
Output
DataType
*
p_vgrad_grid
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>
p_acc0_biases
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>
p_acc0_biases
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>
p_acc1_biases
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>
p_acc1_biases
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
...
@@ -1182,16 +1186,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -1182,16 +1186,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
float
p_drop
,
float
p_drop
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
// override
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
// override
{
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
DataType
*>
(
p_a
),
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
Input
DataType
*>
(
p_a
),
static_cast
<
const
DataType
*>
(
p_b
),
static_cast
<
const
Input
DataType
*>
(
p_b
),
static_cast
<
ZDataType
*>
(
p_z
),
static_cast
<
ZDataType
*>
(
p_z
),
static_cast
<
const
DataType
*>
(
p_b1
),
static_cast
<
const
Input
DataType
*>
(
p_b1
),
static_cast
<
const
DataType
*>
(
p_c
),
static_cast
<
const
Input
DataType
*>
(
p_c
),
static_cast
<
const
LSEDataType
*>
(
p_lse
),
static_cast
<
const
LSEDataType
*>
(
p_lse
),
static_cast
<
const
DataType
*>
(
p_ygrad_grid
),
static_cast
<
const
Input
DataType
*>
(
p_ygrad_grid
),
static_cast
<
DataType
*>
(
p_qgrad_grid
),
static_cast
<
Output
DataType
*>
(
p_qgrad_grid
),
static_cast
<
DataType
*>
(
p_kgrad_grid
),
static_cast
<
Output
DataType
*>
(
p_kgrad_grid
),
static_cast
<
DataType
*>
(
p_vgrad_grid
),
static_cast
<
Output
DataType
*>
(
p_vgrad_grid
),
p_acc0_biases
,
// cast in struct Argument
p_acc0_biases
,
// cast in struct Argument
p_acc1_biases
,
// cast in struct Argument
p_acc1_biases
,
// cast in struct Argument
a_gs_ms_ks_lengths
,
a_gs_ms_ks_lengths
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp
View file @
f3e61c0a
...
@@ -27,7 +27,8 @@ namespace tensor_operation {
...
@@ -27,7 +27,8 @@ namespace tensor_operation {
namespace
device
{
namespace
device
{
template
<
typename
GridwiseGemm
,
template
<
typename
GridwiseGemm
,
typename
DataType
,
typename
InputDataType
,
typename
OutputDataType
,
typename
ZDataType
,
typename
ZDataType
,
typename
LSEDataType
,
typename
LSEDataType
,
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
...
@@ -52,16 +53,16 @@ __global__ void
...
@@ -52,16 +53,16 @@ __global__ void
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
/*CK_MIN_BLOCK_PER_CU*/
1
)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
/*CK_MIN_BLOCK_PER_CU*/
1
)
#endif
#endif
kernel_batched_multihead_attention_backward_xdl_cshuffle_v2
(
kernel_batched_multihead_attention_backward_xdl_cshuffle_v2
(
const
DataType
*
__restrict__
p_a_grid
,
const
Input
DataType
*
__restrict__
p_a_grid
,
const
DataType
*
__restrict__
p_b_grid
,
const
Input
DataType
*
__restrict__
p_b_grid
,
ZDataType
*
__restrict__
p_z_grid
,
ZDataType
*
__restrict__
p_z_grid
,
const
DataType
*
__restrict__
p_b1_grid
,
const
Input
DataType
*
__restrict__
p_b1_grid
,
const
DataType
*
__restrict__
p_c_grid
,
const
Input
DataType
*
__restrict__
p_c_grid
,
const
LSEDataType
*
__restrict__
p_lse_grid
,
const
LSEDataType
*
__restrict__
p_lse_grid
,
const
DataType
*
__restrict__
p_ygrad_grid
,
const
Input
DataType
*
__restrict__
p_ygrad_grid
,
DataType
*
__restrict__
p_qgrad_grid
,
Output
DataType
*
__restrict__
p_qgrad_grid
,
DataType
*
__restrict__
p_kgrad_grid
,
Output
DataType
*
__restrict__
p_kgrad_grid
,
DataType
*
__restrict__
p_vgrad_grid
,
Output
DataType
*
__restrict__
p_vgrad_grid
,
const
AElementwiseOperation
a_element_op
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
BElementwiseOperation
b_element_op
,
const
AccElementwiseOperation
acc_element_op
,
const
AccElementwiseOperation
acc_element_op
,
...
@@ -170,7 +171,8 @@ template <index_t NumDimG,
...
@@ -170,7 +171,8 @@ template <index_t NumDimG,
index_t
NumDimN
,
index_t
NumDimN
,
index_t
NumDimK
,
index_t
NumDimK
,
index_t
NumDimO
,
// NumDimGemm1N
index_t
NumDimO
,
// NumDimGemm1N
typename
DataType
,
typename
InputDataType
,
typename
OutputDataType
,
typename
GemmDataType
,
typename
GemmDataType
,
typename
ZDataType
,
typename
ZDataType
,
typename
LSEDataType
,
typename
LSEDataType
,
...
@@ -596,7 +598,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -596,7 +598,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
// GridwiseGemm
// GridwiseGemm
using
GridwiseGemm
=
GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
<
using
GridwiseGemm
=
GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
<
DataType
,
// TODO: distinguish A/B datatype
InputDataType
,
// TODO: distinguish A/B datatype
OutputDataType
,
GemmDataType
,
GemmDataType
,
GemmAccDataType
,
GemmAccDataType
,
CShuffleDataType
,
CShuffleDataType
,
...
@@ -665,16 +668,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -665,16 +668,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
struct
Argument
:
public
BaseArgument
struct
Argument
:
public
BaseArgument
{
{
Argument
(
Argument
(
const
DataType
*
p_a_grid
,
const
Input
DataType
*
p_a_grid
,
const
DataType
*
p_b_grid
,
const
Input
DataType
*
p_b_grid
,
ZDataType
*
p_z_grid
,
ZDataType
*
p_z_grid
,
const
DataType
*
p_b1_grid
,
const
Input
DataType
*
p_b1_grid
,
const
DataType
*
p_c_grid
,
// for dS
const
Input
DataType
*
p_c_grid
,
// for dS
const
LSEDataType
*
p_lse_grid
,
const
LSEDataType
*
p_lse_grid
,
const
DataType
*
p_ygrad_grid
,
const
Input
DataType
*
p_ygrad_grid
,
DataType
*
p_qgrad_grid
,
Output
DataType
*
p_qgrad_grid
,
DataType
*
p_kgrad_grid
,
Output
DataType
*
p_kgrad_grid
,
DataType
*
p_vgrad_grid
,
Output
DataType
*
p_vgrad_grid
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>
p_acc0_biases
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>
p_acc0_biases
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>
p_acc1_biases
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>
p_acc1_biases
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
...
@@ -818,16 +821,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -818,16 +821,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
}
}
// pointers
// pointers
const
DataType
*
p_a_grid_
;
const
Input
DataType
*
p_a_grid_
;
const
DataType
*
p_b_grid_
;
const
Input
DataType
*
p_b_grid_
;
ZDataType
*
p_z_grid_
;
ZDataType
*
p_z_grid_
;
const
DataType
*
p_b1_grid_
;
const
Input
DataType
*
p_b1_grid_
;
const
DataType
*
p_c_grid_
;
const
Input
DataType
*
p_c_grid_
;
const
LSEDataType
*
p_lse_grid_
;
const
LSEDataType
*
p_lse_grid_
;
const
DataType
*
p_ygrad_grid_
;
const
Input
DataType
*
p_ygrad_grid_
;
DataType
*
p_qgrad_grid_
;
Output
DataType
*
p_qgrad_grid_
;
DataType
*
p_kgrad_grid_
;
Output
DataType
*
p_kgrad_grid_
;
DataType
*
p_vgrad_grid_
;
Output
DataType
*
p_vgrad_grid_
;
// tensor descriptor
// tensor descriptor
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
...
@@ -903,7 +906,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -903,7 +906,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
)
{
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
)
{
const
auto
kernel
=
kernel_batched_multihead_attention_backward_xdl_cshuffle_v2
<
const
auto
kernel
=
kernel_batched_multihead_attention_backward_xdl_cshuffle_v2
<
GridwiseGemm
,
GridwiseGemm
,
DataType
,
InputDataType
,
OutputDataType
,
ZDataType
,
ZDataType
,
LSEDataType
,
LSEDataType
,
AElementwiseOperation
,
AElementwiseOperation
,
...
@@ -1067,16 +1071,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -1067,16 +1071,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
}
}
static
auto
MakeArgument
(
static
auto
MakeArgument
(
const
DataType
*
p_a
,
const
Input
DataType
*
p_a
,
const
DataType
*
p_b
,
const
Input
DataType
*
p_b
,
ZDataType
*
p_z
,
ZDataType
*
p_z
,
const
DataType
*
p_b1
,
const
Input
DataType
*
p_b1
,
const
DataType
*
p_c
,
const
Input
DataType
*
p_c
,
const
LSEDataType
*
p_lse
,
const
LSEDataType
*
p_lse
,
const
DataType
*
p_ygrad_grid
,
const
Input
DataType
*
p_ygrad_grid
,
DataType
*
p_qgrad_grid
,
Output
DataType
*
p_qgrad_grid
,
DataType
*
p_kgrad_grid
,
Output
DataType
*
p_kgrad_grid
,
DataType
*
p_vgrad_grid
,
Output
DataType
*
p_vgrad_grid
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>
p_acc0_biases
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>
p_acc0_biases
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>
p_acc1_biases
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>
p_acc1_biases
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
...
@@ -1182,16 +1186,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -1182,16 +1186,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
float
p_drop
,
float
p_drop
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
// override
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
// override
{
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
DataType
*>
(
p_a
),
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
Input
DataType
*>
(
p_a
),
static_cast
<
const
DataType
*>
(
p_b
),
static_cast
<
const
Input
DataType
*>
(
p_b
),
static_cast
<
ZDataType
*>
(
p_z
),
static_cast
<
ZDataType
*>
(
p_z
),
static_cast
<
const
DataType
*>
(
p_b1
),
static_cast
<
const
Input
DataType
*>
(
p_b1
),
static_cast
<
const
DataType
*>
(
p_c
),
static_cast
<
const
Input
DataType
*>
(
p_c
),
static_cast
<
const
LSEDataType
*>
(
p_lse
),
static_cast
<
const
LSEDataType
*>
(
p_lse
),
static_cast
<
const
DataType
*>
(
p_ygrad_grid
),
static_cast
<
const
Input
DataType
*>
(
p_ygrad_grid
),
static_cast
<
DataType
*>
(
p_qgrad_grid
),
static_cast
<
Output
DataType
*>
(
p_qgrad_grid
),
static_cast
<
DataType
*>
(
p_kgrad_grid
),
static_cast
<
Output
DataType
*>
(
p_kgrad_grid
),
static_cast
<
DataType
*>
(
p_vgrad_grid
),
static_cast
<
Output
DataType
*>
(
p_vgrad_grid
),
p_acc0_biases
,
// cast in struct Argument
p_acc0_biases
,
// cast in struct Argument
p_acc1_biases
,
// cast in struct Argument
p_acc1_biases
,
// cast in struct Argument
a_gs_ms_ks_lengths
,
a_gs_ms_ks_lengths
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_backward_xdl_cshuffle_v1.hpp
View file @
f3e61c0a
...
@@ -150,7 +150,8 @@ template <index_t NumDimG,
...
@@ -150,7 +150,8 @@ template <index_t NumDimG,
index_t
NumDimN
,
index_t
NumDimN
,
index_t
NumDimK
,
index_t
NumDimK
,
index_t
NumDimO
,
// NumDimGemm1N
index_t
NumDimO
,
// NumDimGemm1N
typename
DataType
,
typename
InputDataType
,
typename
OutputDataType
,
typename
GemmDataType
,
typename
GemmDataType
,
typename
ZDataType
,
typename
ZDataType
,
typename
LSEDataType
,
typename
LSEDataType
,
...
@@ -534,7 +535,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -534,7 +535,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
// GridwiseGemm
// GridwiseGemm
using
GridwiseGemm
=
GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
<
using
GridwiseGemm
=
GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
<
DataType
,
// TODO: distinguish A/B datatype
InputDataType
,
// TODO: distinguish A/B datatype
OutputDataType
,
GemmDataType
,
GemmDataType
,
GemmAccDataType
,
GemmAccDataType
,
CShuffleDataType
,
CShuffleDataType
,
...
@@ -604,16 +606,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -604,16 +606,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
struct
GroupKernelArg
struct
GroupKernelArg
{
{
// pointers
// pointers
const
DataType
*
p_a_grid_
;
const
Input
DataType
*
p_a_grid_
;
const
DataType
*
p_b_grid_
;
const
Input
DataType
*
p_b_grid_
;
ZDataType
*
p_z_grid_
;
ZDataType
*
p_z_grid_
;
const
DataType
*
p_b1_grid_
;
const
Input
DataType
*
p_b1_grid_
;
const
DataType
*
p_c_grid_
;
const
Input
DataType
*
p_c_grid_
;
const
LSEDataType
*
p_lse_grid_
;
const
LSEDataType
*
p_lse_grid_
;
const
DataType
*
p_ygrad_grid_
;
const
Input
DataType
*
p_ygrad_grid_
;
DataType
*
p_qgrad_grid_
;
Output
DataType
*
p_qgrad_grid_
;
DataType
*
p_kgrad_grid_
;
Output
DataType
*
p_kgrad_grid_
;
DataType
*
p_vgrad_grid_
;
Output
DataType
*
p_vgrad_grid_
;
// tensor descriptors for block/thread-wise copy
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
...
@@ -712,16 +714,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -712,16 +714,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
grid_size_
=
0
;
grid_size_
=
0
;
for
(
index_t
i
=
0
;
i
<
group_count_
;
i
++
)
for
(
index_t
i
=
0
;
i
<
group_count_
;
i
++
)
{
{
const
auto
p_a_grid
=
static_cast
<
const
DataType
*>
(
p_As
[
i
]);
const
auto
p_a_grid
=
static_cast
<
const
Input
DataType
*>
(
p_As
[
i
]);
const
auto
p_b_grid
=
static_cast
<
const
DataType
*>
(
p_Bs
[
i
]);
const
auto
p_b_grid
=
static_cast
<
const
Input
DataType
*>
(
p_Bs
[
i
]);
auto
p_z_grid
=
static_cast
<
ZDataType
*>
(
p_Zs
[
i
]);
auto
p_z_grid
=
static_cast
<
ZDataType
*>
(
p_Zs
[
i
]);
const
auto
p_b1_grid
=
static_cast
<
const
DataType
*>
(
p_B1s
[
i
]);
const
auto
p_b1_grid
=
static_cast
<
const
Input
DataType
*>
(
p_B1s
[
i
]);
const
auto
p_c_grid
=
static_cast
<
const
DataType
*>
(
p_Cs
[
i
]);
const
auto
p_c_grid
=
static_cast
<
const
Input
DataType
*>
(
p_Cs
[
i
]);
const
auto
p_lse_grid
=
static_cast
<
const
LSEDataType
*>
(
p_LSEs
[
i
]);
const
auto
p_lse_grid
=
static_cast
<
const
LSEDataType
*>
(
p_LSEs
[
i
]);
const
auto
p_ygrad_grid
=
static_cast
<
const
DataType
*>
(
p_Ygrads
[
i
]);
const
auto
p_ygrad_grid
=
static_cast
<
const
Input
DataType
*>
(
p_Ygrads
[
i
]);
auto
p_qgrad_grid
=
static_cast
<
DataType
*>
(
p_Qgrads
[
i
]);
auto
p_qgrad_grid
=
static_cast
<
Output
DataType
*>
(
p_Qgrads
[
i
]);
auto
p_kgrad_grid
=
static_cast
<
DataType
*>
(
p_Kgrads
[
i
]);
auto
p_kgrad_grid
=
static_cast
<
Output
DataType
*>
(
p_Kgrads
[
i
]);
auto
p_vgrad_grid
=
static_cast
<
DataType
*>
(
p_Vgrads
[
i
]);
auto
p_vgrad_grid
=
static_cast
<
Output
DataType
*>
(
p_Vgrads
[
i
]);
const
auto
&
problem_desc
=
problem_desc_vec
[
i
];
const
auto
&
problem_desc
=
problem_desc_vec
[
i
];
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_backward_xdl_cshuffle_v2.hpp
View file @
f3e61c0a
...
@@ -150,7 +150,8 @@ template <index_t NumDimG,
...
@@ -150,7 +150,8 @@ template <index_t NumDimG,
index_t
NumDimN
,
index_t
NumDimN
,
index_t
NumDimK
,
index_t
NumDimK
,
index_t
NumDimO
,
// NumDimGemm1N
index_t
NumDimO
,
// NumDimGemm1N
typename
DataType
,
typename
InputDataType
,
typename
OutputDataType
,
typename
GemmDataType
,
typename
GemmDataType
,
typename
ZDataType
,
typename
ZDataType
,
typename
LSEDataType
,
typename
LSEDataType
,
...
@@ -527,7 +528,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -527,7 +528,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
// GridwiseGemm
// GridwiseGemm
using
GridwiseGemm
=
GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
<
using
GridwiseGemm
=
GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
<
DataType
,
// TODO: distinguish A/B datatype
InputDataType
,
// TODO: distinguish A/B datatype
OutputDataType
,
GemmDataType
,
GemmDataType
,
GemmAccDataType
,
GemmAccDataType
,
CShuffleDataType
,
CShuffleDataType
,
...
@@ -597,16 +599,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -597,16 +599,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
struct
GroupKernelArg
struct
GroupKernelArg
{
{
// pointers
// pointers
const
DataType
*
p_a_grid_
;
const
Input
DataType
*
p_a_grid_
;
const
DataType
*
p_b_grid_
;
const
Input
DataType
*
p_b_grid_
;
ZDataType
*
p_z_grid_
;
ZDataType
*
p_z_grid_
;
const
DataType
*
p_b1_grid_
;
const
Input
DataType
*
p_b1_grid_
;
const
DataType
*
p_c_grid_
;
const
Input
DataType
*
p_c_grid_
;
const
LSEDataType
*
p_lse_grid_
;
const
LSEDataType
*
p_lse_grid_
;
const
DataType
*
p_ygrad_grid_
;
const
Input
DataType
*
p_ygrad_grid_
;
DataType
*
p_qgrad_grid_
;
Output
DataType
*
p_qgrad_grid_
;
DataType
*
p_kgrad_grid_
;
Output
DataType
*
p_kgrad_grid_
;
DataType
*
p_vgrad_grid_
;
Output
DataType
*
p_vgrad_grid_
;
// tensor descriptors for block/thread-wise copy
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
...
@@ -705,16 +707,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -705,16 +707,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
grid_size_
=
0
;
grid_size_
=
0
;
for
(
index_t
i
=
0
;
i
<
group_count_
;
i
++
)
for
(
index_t
i
=
0
;
i
<
group_count_
;
i
++
)
{
{
const
auto
p_a_grid
=
static_cast
<
const
DataType
*>
(
p_As
[
i
]);
const
auto
p_a_grid
=
static_cast
<
const
Input
DataType
*>
(
p_As
[
i
]);
const
auto
p_b_grid
=
static_cast
<
const
DataType
*>
(
p_Bs
[
i
]);
const
auto
p_b_grid
=
static_cast
<
const
Input
DataType
*>
(
p_Bs
[
i
]);
auto
p_z_grid
=
static_cast
<
ZDataType
*>
(
p_Zs
[
i
]);
auto
p_z_grid
=
static_cast
<
ZDataType
*>
(
p_Zs
[
i
]);
const
auto
p_b1_grid
=
static_cast
<
const
DataType
*>
(
p_B1s
[
i
]);
const
auto
p_b1_grid
=
static_cast
<
const
Input
DataType
*>
(
p_B1s
[
i
]);
const
auto
p_c_grid
=
static_cast
<
const
DataType
*>
(
p_Cs
[
i
]);
const
auto
p_c_grid
=
static_cast
<
const
Input
DataType
*>
(
p_Cs
[
i
]);
const
auto
p_lse_grid
=
static_cast
<
const
LSEDataType
*>
(
p_LSEs
[
i
]);
const
auto
p_lse_grid
=
static_cast
<
const
LSEDataType
*>
(
p_LSEs
[
i
]);
const
auto
p_ygrad_grid
=
static_cast
<
const
DataType
*>
(
p_Ygrads
[
i
]);
const
auto
p_ygrad_grid
=
static_cast
<
const
Input
DataType
*>
(
p_Ygrads
[
i
]);
auto
p_qgrad_grid
=
static_cast
<
DataType
*>
(
p_Qgrads
[
i
]);
auto
p_qgrad_grid
=
static_cast
<
Output
DataType
*>
(
p_Qgrads
[
i
]);
auto
p_kgrad_grid
=
static_cast
<
DataType
*>
(
p_Kgrads
[
i
]);
auto
p_kgrad_grid
=
static_cast
<
Output
DataType
*>
(
p_Kgrads
[
i
]);
auto
p_vgrad_grid
=
static_cast
<
DataType
*>
(
p_Vgrads
[
i
]);
auto
p_vgrad_grid
=
static_cast
<
Output
DataType
*>
(
p_Vgrads
[
i
]);
const
auto
&
problem_desc
=
problem_desc_vec
[
i
];
const
auto
&
problem_desc
=
problem_desc_vec
[
i
];
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp
View file @
f3e61c0a
...
@@ -20,7 +20,8 @@
...
@@ -20,7 +20,8 @@
namespace
ck
{
namespace
ck
{
template
<
typename
DataType
,
template
<
typename
InputDataType
,
typename
OutputDataType
,
typename
GemmDataType
,
typename
GemmDataType
,
typename
FloatGemmAcc
,
typename
FloatGemmAcc
,
typename
FloatCShuffle
,
typename
FloatCShuffle
,
...
@@ -381,7 +382,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -381,7 +382,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
Sequence
<
AK0
,
MPerBlock
,
AK1
>
,
Sequence
<
AK0
,
MPerBlock
,
AK1
>
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferThreadClusterArrangeOrder
,
DataType
,
Input
DataType
,
GemmDataType
,
GemmDataType
,
GridDesc_K0_M_K1
,
GridDesc_K0_M_K1
,
decltype
(
q_block_desc_k0_m_k1
),
decltype
(
q_block_desc_k0_m_k1
),
...
@@ -406,7 +407,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -406,7 +407,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
Sequence
<
BK0
,
NPerBlock
,
BK1
>
,
Sequence
<
BK0
,
NPerBlock
,
BK1
>
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferThreadClusterArrangeOrder
,
DataType
,
Input
DataType
,
GemmDataType
,
GemmDataType
,
GridDesc_K0_N_K1
,
GridDesc_K0_N_K1
,
decltype
(
k_block_desc_k0_n_k1
),
decltype
(
k_block_desc_k0_n_k1
),
...
@@ -431,7 +432,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -431,7 +432,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
Sequence
<
BK0
,
NPerBlock
,
BK1
>
,
Sequence
<
BK0
,
NPerBlock
,
BK1
>
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferThreadClusterArrangeOrder
,
DataType
,
Input
DataType
,
GemmDataType
,
GemmDataType
,
GridDesc_K0_N_K1
,
GridDesc_K0_N_K1
,
decltype
(
v_block_desc_k0_n_k1
),
decltype
(
v_block_desc_k0_n_k1
),
...
@@ -456,7 +457,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -456,7 +457,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
Sequence
<
AK0
,
MPerBlock
,
AK1
>
,
Sequence
<
AK0
,
MPerBlock
,
AK1
>
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferThreadClusterArrangeOrder
,
DataType
,
Input
DataType
,
GemmDataType
,
GemmDataType
,
GridDesc_K0_M_K1
,
GridDesc_K0_M_K1
,
decltype
(
ygrad_block_desc_k0_m_k1
),
decltype
(
ygrad_block_desc_k0_m_k1
),
...
@@ -1043,7 +1044,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -1043,7 +1044,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
typename
ElementwiseOp
=
tensor_operation
::
element_wise
::
PassThrough
>
typename
ElementwiseOp
=
tensor_operation
::
element_wise
::
PassThrough
>
using
CBlockwiseCopy
=
ThreadwiseTensorSliceTransfer_v1r3
<
using
CBlockwiseCopy
=
ThreadwiseTensorSliceTransfer_v1r3
<
FloatGemmAcc
,
FloatGemmAcc
,
DataType
,
Output
DataType
,
decltype
(
c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4
),
decltype
(
c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4
),
CGridDesc_N0_O0_N1_O1_N2_O2_O3_O4
,
CGridDesc_N0_O0_N1_O1_N2_O2_O3_O4
,
ElementwiseOp
,
// CElementwiseOperation
ElementwiseOp
,
// CElementwiseOperation
...
@@ -1059,7 +1060,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -1059,7 +1060,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
template
<
index_t
BlockSize_
,
index_t
BlockSliceLength_M_
,
index_t
BlockSliceLength_O_
>
template
<
index_t
BlockSize_
,
index_t
BlockSliceLength_M_
,
index_t
BlockSliceLength_O_
>
struct
YDotYGrad_M_O_
struct
YDotYGrad_M_O_
{
{
static
constexpr
index_t
SrcScalarPerVector
=
16
/
sizeof
(
DataType
);
static
constexpr
index_t
SrcScalarPerVector
=
16
/
sizeof
(
Input
DataType
);
static
constexpr
auto
ThreadClusterLength_O
=
static
constexpr
auto
ThreadClusterLength_O
=
Number
<
BlockSliceLength_O_
/
SrcScalarPerVector
>
{};
Number
<
BlockSliceLength_O_
/
SrcScalarPerVector
>
{};
static
constexpr
auto
ThreadClusterLength_M
=
Number
<
BlockSize_
/
ThreadClusterLength_O
>
{};
static
constexpr
auto
ThreadClusterLength_M
=
Number
<
BlockSize_
/
ThreadClusterLength_O
>
{};
...
@@ -1234,16 +1235,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -1234,16 +1235,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
typename
C0MatrixMask
,
typename
C0MatrixMask
,
typename
VGradGridDescriptor_N_O
,
typename
VGradGridDescriptor_N_O
,
typename
YGradGridDesc_O0_M_O1
>
typename
YGradGridDesc_O0_M_O1
>
__device__
static
void
Run
(
const
DataType
*
__restrict__
p_q_grid
,
__device__
static
void
Run
(
const
Input
DataType
*
__restrict__
p_q_grid
,
const
DataType
*
__restrict__
p_k_grid
,
const
Input
DataType
*
__restrict__
p_k_grid
,
unsigned
short
*
__restrict__
p_z_grid
,
unsigned
short
*
__restrict__
p_z_grid
,
const
DataType
*
__restrict__
p_v_grid
,
const
Input
DataType
*
__restrict__
p_v_grid
,
const
DataType
*
__restrict__
p_y_grid
,
const
Input
DataType
*
__restrict__
p_y_grid
,
const
FloatLSE
*
__restrict__
p_lse_grid
,
const
FloatLSE
*
__restrict__
p_lse_grid
,
const
DataType
*
__restrict__
p_ygrad_grid
,
const
Input
DataType
*
__restrict__
p_ygrad_grid
,
DataType
*
__restrict__
p_qgrad_grid
,
Output
DataType
*
__restrict__
p_qgrad_grid
,
DataType
*
__restrict__
p_kgrad_grid
,
Output
DataType
*
__restrict__
p_kgrad_grid
,
DataType
*
__restrict__
p_vgrad_grid
,
Output
DataType
*
__restrict__
p_vgrad_grid
,
void
*
__restrict__
p_shared
,
void
*
__restrict__
p_shared
,
const
AElementwiseOperation
&
a_element_op
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
BElementwiseOperation
&
b_element_op
,
...
@@ -1723,7 +1724,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -1723,7 +1724,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
// performs for y
// performs for y
auto
y_threadwise_copy
=
ThreadwiseTensorSliceTransfer_v2
<
auto
y_threadwise_copy
=
ThreadwiseTensorSliceTransfer_v2
<
DataType
,
Input
DataType
,
FloatGemmAcc
,
FloatGemmAcc
,
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
,
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
,
decltype
(
y_thread_desc_m0_m1_o0_o1
),
decltype
(
y_thread_desc_m0_m1_o0_o1
),
...
@@ -2307,7 +2308,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -2307,7 +2308,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename ThreadClusterArrangeOrder,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename ThreadClusterArrangeOrder,
FloatCShuffle
,
// typename SrcData,
FloatCShuffle
,
// typename SrcData,
DataType
,
// typename DstData,
Output
DataType
,
// typename DstData,
decltype
(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
),
decltype
(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
),
decltype
(
qgrad_grid_desc_mblock_mperblock_kblock_kperblock
),
decltype
(
qgrad_grid_desc_mblock_mperblock_kblock_kperblock
),
Sequence
<
0
,
1
,
2
,
3
>
,
// typename DimAccessOrder,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename DimAccessOrder,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt2.hpp
View file @
f3e61c0a
...
@@ -20,7 +20,8 @@
...
@@ -20,7 +20,8 @@
namespace
ck
{
namespace
ck
{
template
<
typename
DataType
,
template
<
typename
InputDataType
,
typename
OutputDataType
,
typename
GemmDataType
,
typename
GemmDataType
,
typename
FloatGemmAcc
,
typename
FloatGemmAcc
,
typename
FloatCShuffle
,
typename
FloatCShuffle
,
...
@@ -457,7 +458,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -457,7 +458,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
Sequence
<
AK0
,
MPerBlock
,
AK1
>
,
Sequence
<
AK0
,
MPerBlock
,
AK1
>
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferThreadClusterArrangeOrder
,
DataType
,
Input
DataType
,
GemmDataType
,
GemmDataType
,
GridDesc_K0_M_K1
,
GridDesc_K0_M_K1
,
decltype
(
a_block_desc_ak0_m_ak1
),
decltype
(
a_block_desc_ak0_m_ak1
),
...
@@ -482,7 +483,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -482,7 +483,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
Sequence
<
BK0
,
NPerBlock
,
BK1
>
,
Sequence
<
BK0
,
NPerBlock
,
BK1
>
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferThreadClusterArrangeOrder
,
DataType
,
Input
DataType
,
GemmDataType
,
GemmDataType
,
GridDesc_K0_N_K1
,
GridDesc_K0_N_K1
,
decltype
(
b_block_desc_bk0_n_bk1
),
decltype
(
b_block_desc_bk0_n_bk1
),
...
@@ -585,7 +586,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -585,7 +586,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
Sequence
<
B1K0
,
Gemm1NPerBlock
,
B1K1
>
,
Sequence
<
B1K0
,
Gemm1NPerBlock
,
B1K1
>
,
B1BlockTransferThreadClusterLengths_BK0_N_BK1
,
B1BlockTransferThreadClusterLengths_BK0_N_BK1
,
B1BlockTransferThreadClusterArrangeOrder
,
B1BlockTransferThreadClusterArrangeOrder
,
DataType
,
Input
DataType
,
GemmDataType
,
GemmDataType
,
GridDesc_K0_N_K1
,
GridDesc_K0_N_K1
,
decltype
(
b_block_desc_bk0_n_bk1
),
decltype
(
b_block_desc_bk0_n_bk1
),
...
@@ -823,7 +824,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -823,7 +824,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
typename
Gemm2Params_N_O_M
::
BBlockSliceLengths
,
typename
Gemm2Params_N_O_M
::
BBlockSliceLengths
,
typename
Gemm2Params_N_O_M
::
BThreadClusterLengths
,
typename
Gemm2Params_N_O_M
::
BThreadClusterLengths
,
typename
Gemm2Params_N_O_M
::
BThreadClusterArrangeOrder
,
typename
Gemm2Params_N_O_M
::
BThreadClusterArrangeOrder
,
DataType
,
Input
DataType
,
GemmDataType
,
GemmDataType
,
GridDesc_M0_O_M1
,
GridDesc_M0_O_M1
,
decltype
(
b_block_desc_m0_o_m1
),
decltype
(
b_block_desc_m0_o_m1
),
...
@@ -892,7 +893,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -892,7 +893,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
typename
ElementwiseOp
=
tensor_operation
::
element_wise
::
PassThrough
>
typename
ElementwiseOp
=
tensor_operation
::
element_wise
::
PassThrough
>
using
CBlockwiseCopy
=
ThreadwiseTensorSliceTransfer_v1r3
<
using
CBlockwiseCopy
=
ThreadwiseTensorSliceTransfer_v1r3
<
FloatGemmAcc
,
FloatGemmAcc
,
DataType
,
Output
DataType
,
decltype
(
c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4
),
decltype
(
c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4
),
CGridDesc_N0_O0_N1_O1_N2_O2_O3_O4
,
CGridDesc_N0_O0_N1_O1_N2_O2_O3_O4
,
ElementwiseOp
,
// CElementwiseOperation
ElementwiseOp
,
// CElementwiseOperation
...
@@ -908,7 +909,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -908,7 +909,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
template
<
index_t
BlockSize_
,
index_t
BlockSliceLength_M_
,
index_t
BlockSliceLength_O_
>
template
<
index_t
BlockSize_
,
index_t
BlockSliceLength_M_
,
index_t
BlockSliceLength_O_
>
struct
YDotYGrad_M_O_
struct
YDotYGrad_M_O_
{
{
static
constexpr
index_t
SrcScalarPerVector
=
16
/
sizeof
(
DataType
);
static
constexpr
index_t
SrcScalarPerVector
=
16
/
sizeof
(
Input
DataType
);
static
constexpr
auto
ThreadClusterLength_O
=
static
constexpr
auto
ThreadClusterLength_O
=
Number
<
BlockSliceLength_O_
/
SrcScalarPerVector
>
{};
Number
<
BlockSliceLength_O_
/
SrcScalarPerVector
>
{};
static
constexpr
auto
ThreadClusterLength_M
=
Number
<
BlockSize_
/
ThreadClusterLength_O
>
{};
static
constexpr
auto
ThreadClusterLength_M
=
Number
<
BlockSize_
/
ThreadClusterLength_O
>
{};
...
@@ -1144,16 +1145,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -1144,16 +1145,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
typename
C0MatrixMask
,
typename
C0MatrixMask
,
typename
VGradGridDescriptor_N_O
,
typename
VGradGridDescriptor_N_O
,
typename
YGradGridDesc_M0_O_M1
>
typename
YGradGridDesc_M0_O_M1
>
__device__
static
void
Run
(
const
DataType
*
__restrict__
p_q_grid
,
__device__
static
void
Run
(
const
Input
DataType
*
__restrict__
p_q_grid
,
const
DataType
*
__restrict__
p_k_grid
,
const
Input
DataType
*
__restrict__
p_k_grid
,
unsigned
short
*
__restrict__
p_z_grid
,
unsigned
short
*
__restrict__
p_z_grid
,
const
DataType
*
__restrict__
p_v_grid
,
const
Input
DataType
*
__restrict__
p_v_grid
,
const
DataType
*
__restrict__
p_y_grid
,
const
Input
DataType
*
__restrict__
p_y_grid
,
const
FloatLSE
*
__restrict__
p_lse_grid
,
const
FloatLSE
*
__restrict__
p_lse_grid
,
const
DataType
*
__restrict__
p_ygrad_grid
,
const
Input
DataType
*
__restrict__
p_ygrad_grid
,
DataType
*
__restrict__
p_qgrad_grid
,
Output
DataType
*
__restrict__
p_qgrad_grid
,
DataType
*
__restrict__
p_kgrad_grid
,
Output
DataType
*
__restrict__
p_kgrad_grid
,
DataType
*
__restrict__
p_vgrad_grid
,
Output
DataType
*
__restrict__
p_vgrad_grid
,
void
*
__restrict__
p_shared
,
void
*
__restrict__
p_shared
,
const
AElementwiseOperation
&
a_element_op
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
BElementwiseOperation
&
b_element_op
,
...
@@ -1646,7 +1647,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -1646,7 +1647,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
// performs double duty for both y and ygrad
// performs double duty for both y and ygrad
auto
yygrad_threadwise_copy
=
ThreadwiseTensorSliceTransfer_v2
<
auto
yygrad_threadwise_copy
=
ThreadwiseTensorSliceTransfer_v2
<
DataType
,
Input
DataType
,
FloatGemmAcc
,
FloatGemmAcc
,
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
,
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
,
decltype
(
y_thread_desc_m0_m1_o0_o1
),
decltype
(
y_thread_desc_m0_m1_o0_o1
),
...
@@ -2257,7 +2258,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -2257,7 +2258,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename ThreadClusterArrangeOrder,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename ThreadClusterArrangeOrder,
FloatCShuffle
,
// typename SrcData,
FloatCShuffle
,
// typename SrcData,
DataType
,
// typename DstData,
Output
DataType
,
// typename DstData,
decltype
(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
),
decltype
(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
),
decltype
(
qgrad_grid_desc_mblock_mperblock_kblock_kperblock
),
decltype
(
qgrad_grid_desc_mblock_mperblock_kblock_kperblock
),
Sequence
<
0
,
1
,
2
,
3
>
,
// typename DimAccessOrder,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename DimAccessOrder,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment