gaoqiong / composable_kernel

Commit 3b57967f, authored Apr 13, 2023 by danyao12
Parent: d042e931

    batched & grouped train output datatype can be selected
2 changed files with 328 additions and 279 deletions:

  example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train.cpp (+167, -142)
  example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_train.cpp (+161, -137)
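Both files make the same change: the single DataType alias used throughout these training examples is split into InputDataType (the element type of Q, K, V, Y and dY) and OutputDataType (the element type of the gradients dQ, dK, dV), so the backward outputs can be produced in a different precision than the inputs. A minimal, self-contained sketch of the split; BF16 below is a local stand-in alias, not the real ck type, and the buffer sizes are arbitrary:

    #include <cstdint>
    #include <vector>

    using BF16 = uint16_t; // stand-in for the example's BF16 alias
    using F32  = float;

    // Inputs of the fwd/bwd kernels keep one element type...
    using InputDataType  = BF16; // Q, K, V, Y, dY
    // ...while the bwd outputs may now use another.
    using OutputDataType = F32;  // dQ, dK, dV

    // Host-side staging buffers split the same way.
    std::vector<InputDataType>  q(64 * 64), k(64 * 64), v(64 * 64), y(64 * 64), ygrad(64 * 64);
    std::vector<OutputDataType> qgrad(64 * 64), kgrad(64 * 64), vgrad(64 * 64);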
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train.cpp
@@ -71,7 +71,8 @@ using Scale = ck::tensor_operation::element_wise::Scale;
 using QKVElementOp = PassThrough;
 using YElementOp   = PassThrough;

-using DataType        = BF16;
+using InputDataType   = BF16;
+using OutputDataType  = F32;
 using GemmDataType    = BF16;
 using AccDataType     = F32;
 using ShuffleDataType = F32;
@@ -85,6 +86,9 @@ static constexpr ck::index_t NumDimM = 1;
 static constexpr ck::index_t NumDimN = 1;
 static constexpr ck::index_t NumDimK = 1;
 static constexpr ck::index_t NumDimO = 1;
+// When OutputDataType == F32, bwd CShuffleBlockTransferScalarPerVector_NPerBlock = 4
+// When OutputDataType == F16/BF16, bwd CShuffleBlockTransferScalarPerVector_NPerBlock = 8
+static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 4;

 static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
#if USING_MASK
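The new comment ties the backward kernel's CShuffle vector width to the chosen output type. The examples hard-code 4 because OutputDataType is F32 here; a compile-time restatement of the same rule (purely illustrative, not part of the commit) could look like:

    #include <type_traits>

    using F32  = float;
    using BF16 = unsigned short; // stand-in for the example's BF16 alias

    using OutputDataType = F32;

    // 4 scalars per vector for F32 gradients, 8 for F16/BF16 gradients.
    constexpr int CShuffleBlockTransferScalarPerVector_NPerBlock =
        std::is_same_v<OutputDataType, F32> ? 4 : 8;

    static_assert(CShuffleBlockTransferScalarPerVector_NPerBlock == 4, "F32 output -> 4");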
@@ -112,10 +116,10 @@ using DeviceGemmInstanceFWD =
     NumDimN,
     NumDimK,
     NumDimO,
-    DataType,
-    DataType,
-    DataType,
-    DataType,
+    InputDataType,
+    InputDataType,
+    InputDataType,
+    InputDataType,
     GemmDataType,
     ZDataType,
     LSEDataType,
@@ -182,7 +186,8 @@ using DeviceGemmInstanceBWD =
     NumDimN,
     NumDimK,
     NumDimO,
-    DataType,
+    InputDataType,
+    OutputDataType,
     GemmDataType,
     ZDataType,
     LSEDataType,
@@ -240,8 +245,8 @@ using DeviceGemmInstanceBWD =
     1,              // CShuffleMXdlPerWavePerShuffle
     1,              // CShuffleNXdlPerWavePerShuffle
     S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
-    8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
-    MaskingSpec>;   // MaskingSpecialization
+    CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
+    MaskingSpec>;   // MaskingSpecialization
#elif(DIM <= 64)
using DeviceGemmInstanceFWD =
    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
@@ -250,10 +255,10 @@ using DeviceGemmInstanceFWD =
     NumDimN,
     NumDimK,
     NumDimO,
-    DataType,
-    DataType,
-    DataType,
-    DataType,
+    InputDataType,
+    InputDataType,
+    InputDataType,
+    InputDataType,
     GemmDataType,
     ZDataType,
     LSEDataType,
@@ -320,7 +325,8 @@ using DeviceGemmInstanceBWD =
     NumDimN,
     NumDimK,
     NumDimO,
-    DataType,
+    InputDataType,
+    OutputDataType,
     GemmDataType,
     ZDataType,
     LSEDataType,
@@ -378,8 +384,8 @@ using DeviceGemmInstanceBWD =
     1,              // CShuffleMXdlPerWavePerShuffle
     2,              // CShuffleNXdlPerWavePerShuffle
     S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
-    8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
-    MaskingSpec>;   // MaskingSpecialization
+    CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
+    MaskingSpec>;   // MaskingSpecialization
// using DeviceGemmInstanceBWD =
//     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
@@ -388,7 +394,8 @@ using DeviceGemmInstanceBWD =
//     NumDimN,
//     NumDimK,
//     NumDimO,
-//     DataType,
+//     InputDataType,
+//     OutputDataType,
//     GemmDataType,
//     ZDataType,
//     LSEDataType,
@@ -446,8 +453,8 @@ using DeviceGemmInstanceBWD =
//     1,              // CShuffleMXdlPerWavePerShuffle
//     2,              // CShuffleNXdlPerWavePerShuffle
//     S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
-//     8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
-//     MaskingSpec>;   // MaskingSpecialization
+//     CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
+//     MaskingSpec>;   // MaskingSpecialization
#elif(DIM <= 128)
using DeviceGemmInstanceFWD =
    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
@@ -456,10 +463,10 @@ using DeviceGemmInstanceFWD =
     NumDimN,
     NumDimK,
     NumDimO,
-    DataType,
-    DataType,
-    DataType,
-    DataType,
+    InputDataType,
+    InputDataType,
+    InputDataType,
+    InputDataType,
     GemmDataType,
     ZDataType,
     LSEDataType,
@@ -526,7 +533,8 @@ using DeviceGemmInstanceBWD =
     NumDimN,
     NumDimK,
     NumDimO,
-    DataType,
+    InputDataType,
+    OutputDataType,
     GemmDataType,
     ZDataType,
     LSEDataType,
@@ -584,14 +592,14 @@ using DeviceGemmInstanceBWD =
     1,              // CShuffleMXdlPerWavePerShuffle
     4,              // CShuffleNXdlPerWavePerShuffle
     S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
-    8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
-    MaskingSpec>;   // MaskingSpecialization
+    CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
+    MaskingSpec>;   // MaskingSpecialization
#endif

// Ref Gemm0: S = alpha * Q * K^T
// fp16 in, fp32 out
-using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<DataType,
-                                                                                DataType,
+using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<InputDataType,
+                                                                                InputDataType,
                                                                                 AccDataType,
                                                                                 AccDataType,
                                                                                 PassThrough,
@@ -601,13 +609,13 @@ using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
// Ref Softmax: P = Softmax(S)
// fp32 in, fp16 out
 using ReferenceSoftmaxInstance =
-    ck::tensor_operation::host::ReferenceSoftmax<AccDataType, DataType, AccDataType>;
+    ck::tensor_operation::host::ReferenceSoftmax<AccDataType, InputDataType, AccDataType>;

// Ref Gemm1: Y = P * V
// fp16 in, fp16 out
-using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<DataType,
-                                                                                DataType,
-                                                                                DataType,
+using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<InputDataType,
+                                                                                InputDataType,
+                                                                                InputDataType,
                                                                                 AccDataType,
                                                                                 PassThrough,
                                                                                 PassThrough,
@@ -615,16 +623,25 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
// Ref Gemm for backward pass
// fp16 in, fp16 out
-using ReferenceGemmGradInstance = ck::tensor_operation::host::ReferenceBatchedGemm<DataType,
-                                                                                   DataType,
-                                                                                   DataType,
-                                                                                   AccDataType,
-                                                                                   PassThrough,
-                                                                                   PassThrough,
-                                                                                   Scale>;
+using ReferenceGemm0GradInstance = ck::tensor_operation::host::ReferenceBatchedGemm<InputDataType,
+                                                                                    InputDataType,
+                                                                                    InputDataType,
+                                                                                    AccDataType,
+                                                                                    PassThrough,
+                                                                                    PassThrough,
+                                                                                    Scale>;
+using ReferenceGemm1GradInstance = ck::tensor_operation::host::ReferenceBatchedGemm<InputDataType,
+                                                                                    InputDataType,
+                                                                                    OutputDataType,
+                                                                                    AccDataType,
+                                                                                    PassThrough,
+                                                                                    PassThrough,
+                                                                                    Scale>;

// Ref dropout
 using ReferenceDropoutInstance =
-    ck::tensor_operation::host::ReferenceDropout<ushort, DataType, DataType>;
+    ck::tensor_operation::host::ReferenceDropout<ushort, InputDataType, InputDataType>;

 template <typename TensorQ,
           typename TensorK,
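The single ReferenceGemmGradInstance becomes two instances because the backward GEMMs no longer share one output type: dP (= dY * V^T) stays in InputDataType, while dV, dQ and dK are written in OutputDataType. A small host-side sketch of that second shape, kept free of CK types (the element aliases below are plain stand-ins and the helper name is made up):

    #include <array>

    using InputDataType  = float; // stand-in; the example uses BF16 here
    using OutputDataType = float; // gradients may use a wider type than the inputs
    using AccDataType    = float;

    // C[m][n] = scale * sum_k A[m][k] * B[k][n], accumulating in AccDataType and
    // storing in OutputDataType -- the same type shape as ReferenceGemm1GradInstance
    // (InputDataType, InputDataType, OutputDataType, AccDataType, ..., Scale).
    template <int M, int N, int K>
    void gemm_grad(const std::array<std::array<InputDataType, K>, M>& a,
                   const std::array<std::array<InputDataType, N>, K>& b,
                   std::array<std::array<OutputDataType, N>, M>& c,
                   AccDataType scale)
    {
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
            {
                AccDataType acc = 0;
                for(int k = 0; k < K; ++k)
                    acc += AccDataType(a[m][k]) * AccDataType(b[k][n]);
                c[m][n] = OutputDataType(scale * acc);
            }
    }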
@@ -811,17 +828,17 @@ int run(int argc, char* argv[])
     std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M};
     std::vector<ck::index_t> lse_gs_ms_strides{G1 * M, M, 1}; // LSE layout [G0, G1, M]

-    Tensor<DataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
-    Tensor<DataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
+    Tensor<InputDataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
+    Tensor<InputDataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
     Tensor<ZDataType> z_fwd_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
     Tensor<ZDataType> z_bwd_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
-    Tensor<DataType> v_gs_os_ns(v_gs_os_ns_lengths, v_gs_os_ns_strides);
-    Tensor<DataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
-    Tensor<DataType> y_gs_ms_os_device_result(y_gs_ms_os_lengths, y_gs_ms_os_strides);
+    Tensor<InputDataType> v_gs_os_ns(v_gs_os_ns_lengths, v_gs_os_ns_strides);
+    Tensor<InputDataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
+    Tensor<InputDataType> y_gs_ms_os_device_result(y_gs_ms_os_lengths, y_gs_ms_os_strides);
     Tensor<LSEDataType> lse_gs_ms_device_result(lse_gs_ms_lengths, lse_gs_ms_strides);
-    Tensor<DataType> qgrad_gs_ms_ks_device_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
-    Tensor<DataType> kgrad_gs_ns_ks_device_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
-    Tensor<DataType> vgrad_gs_os_ns_device_result(v_gs_os_ns_lengths, v_gs_os_ns_strides);
+    Tensor<OutputDataType> qgrad_gs_ms_ks_device_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
+    Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
+    Tensor<OutputDataType> vgrad_gs_os_ns_device_result(v_gs_os_ns_lengths, v_gs_os_ns_strides);

     std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
     std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl;
@@ -830,46 +847,46 @@ int run(int argc, char* argv[])
     std::cout << "y_gs_ms_os: " << y_gs_ms_os_device_result.mDesc << std::endl;
     std::cout << "lse_gs_ms_os: " << lse_gs_ms_device_result.mDesc << std::endl;

-    z_fwd_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<DataType>{0});
-    z_bwd_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<DataType>{0});
+    z_fwd_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{0});
+    z_bwd_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{0});

     switch(init_method)
     {
     case 0: break;
     case 1:
-        q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<DataType>{-2, 2});
-        k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<DataType>{-2, 2});
-        v_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<DataType>{-2, 2});
-        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_2<DataType>{-2, 2});
+        q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
+        k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
+        v_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
+        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
         break;
     case 2:
-        q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<DataType>{0.0, 1.0});
-        k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<DataType>{0.0, 1.0});
-        v_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<DataType>{-0.5, 0.5});
-        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_3<DataType>{-0.5, 0.5});
+        q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<InputDataType>{0.0, 1.0});
+        k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<InputDataType>{0.0, 1.0});
+        v_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<InputDataType>{-0.5, 0.5});
+        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_3<InputDataType>{-0.5, 0.5});
         break;
     case 3:
-        q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<DataType>{-5, 5});
-        k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
-        v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
-        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
+        q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-5, 5});
+        k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
+        v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
+        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
         break;
     case 4:
-        q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
-        k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
-        v_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
-        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
+        q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
+        k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
+        v_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
+        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
         break;
     case 5:
-        q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
-        k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
-        v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
+        q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
+        k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
+        v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
         ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, o]
         // dO dot O = [0; 1; 2; ...]
         break;
     case 6:
-        q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
-        k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
-        v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
+        q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
+        k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
+        v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
         ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1, m, o]
         // assume mnko = 256
         // P = softmax(QK) = 0.0039 * ones
@@ -880,10 +897,10 @@ int run(int argc, char* argv[])
        //
        break;
    default:
-        q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
-        k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
-        v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
-        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<DataType>{1}); // dy[g0, g1, m, o]
+        q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
+        k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
+        v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
+        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); // dy[g0, g1, m, o]
        // assume mnko = 256
        // P = softmax(QK) = 0.0039 * ones
        // O = P V = 0.0039 * ones
@@ -895,21 +912,22 @@ int run(int argc, char* argv[])
    }

    // qkv gradients have the same descriptor as with qkv
-    DeviceMem q_device_buf(sizeof(DataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
-    DeviceMem k_device_buf(sizeof(DataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize());
+    DeviceMem q_device_buf(sizeof(InputDataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
+    DeviceMem k_device_buf(sizeof(InputDataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize());
    DeviceMem z_fwd_device_buf(sizeof(ZDataType) * z_fwd_gs_ms_ns.mDesc.GetElementSpaceSize());
    DeviceMem z_bwd_device_buf(sizeof(ZDataType) * z_bwd_gs_ms_ns.mDesc.GetElementSpaceSize());
-    DeviceMem v_device_buf(sizeof(DataType) * v_gs_os_ns.mDesc.GetElementSpaceSize());
-    DeviceMem y_device_buf(sizeof(DataType) * y_gs_ms_os_device_result.mDesc.GetElementSpaceSize());
+    DeviceMem v_device_buf(sizeof(InputDataType) * v_gs_os_ns.mDesc.GetElementSpaceSize());
+    DeviceMem y_device_buf(sizeof(InputDataType) * y_gs_ms_os_device_result.mDesc.GetElementSpaceSize());
    DeviceMem lse_device_buf(sizeof(LSEDataType) * lse_gs_ms_device_result.mDesc.GetElementSpaceSize());
-    DeviceMem qgrad_device_buf(sizeof(DataType) *
+    DeviceMem qgrad_device_buf(sizeof(OutputDataType) *
                               qgrad_gs_ms_ks_device_result.mDesc.GetElementSpaceSize());
-    DeviceMem kgrad_device_buf(sizeof(DataType) *
+    DeviceMem kgrad_device_buf(sizeof(OutputDataType) *
                               kgrad_gs_ns_ks_device_result.mDesc.GetElementSpaceSize());
-    DeviceMem vgrad_device_buf(sizeof(DataType) *
+    DeviceMem vgrad_device_buf(sizeof(OutputDataType) *
                               vgrad_gs_os_ns_device_result.mDesc.GetElementSpaceSize());
-    DeviceMem ygrad_device_buf(sizeof(DataType) * ygrad_gs_ms_os.mDesc.GetElementSpaceSize());
+    DeviceMem ygrad_device_buf(sizeof(InputDataType) * ygrad_gs_ms_os.mDesc.GetElementSpaceSize());

    q_device_buf.ToDevice(q_gs_ms_ks.mData.data());
    k_device_buf.ToDevice(k_gs_ns_ks.mData.data());
@@ -926,10 +944,10 @@ int run(int argc, char* argv[])
    if(time_kernel)
    {
        auto argument_fwd = gemm_fwd.MakeArgument(
-            static_cast<DataType*>(q_device_buf.GetDeviceBuffer()),
-            static_cast<DataType*>(k_device_buf.GetDeviceBuffer()),
-            static_cast<DataType*>(v_device_buf.GetDeviceBuffer()),
-            static_cast<DataType*>(y_device_buf.GetDeviceBuffer()),
+            static_cast<InputDataType*>(q_device_buf.GetDeviceBuffer()),
+            static_cast<InputDataType*>(k_device_buf.GetDeviceBuffer()),
+            static_cast<InputDataType*>(v_device_buf.GetDeviceBuffer()),
+            static_cast<InputDataType*>(y_device_buf.GetDeviceBuffer()),
            static_cast<ZDataType*>(nullptr),
            static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
            {}, // std::array<void*, 1> p_acc0_biases;
@@ -967,10 +985,11 @@ int run(int argc, char* argv[])
        float ave_time_fwd = invoker_fwd.Run(argument_fwd, StreamConfig{nullptr, true});

-        std::size_t flop_fwd = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
-        std::size_t num_btype_fwd = (sizeof(DataType) * M * K + sizeof(DataType) * K * N +
-                                     sizeof(DataType) * N * O + sizeof(DataType) * M * O) *
-                                    BatchCount;
+        std::size_t flop_fwd = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
+        std::size_t num_btype_fwd =
+            (sizeof(InputDataType) * M * K + sizeof(InputDataType) * K * N +
+             sizeof(InputDataType) * N * O + sizeof(InputDataType) * M * O) *
+            BatchCount;

        float tflops_fwd = static_cast<float>(flop_fwd) / 1.E9 / ave_time_fwd;
@@ -981,16 +1000,16 @@ int run(int argc, char* argv[])
        // not need output z matrix
        auto argument_bwd = gemm_bwd.MakeArgument(
-            static_cast<DataType*>(q_device_buf.GetDeviceBuffer()),
-            static_cast<DataType*>(k_device_buf.GetDeviceBuffer()),
+            static_cast<InputDataType*>(q_device_buf.GetDeviceBuffer()),
+            static_cast<InputDataType*>(k_device_buf.GetDeviceBuffer()),
            static_cast<ZDataType*>(nullptr), // set to nullptr
-            static_cast<DataType*>(v_device_buf.GetDeviceBuffer()),
-            static_cast<DataType*>(y_device_buf.GetDeviceBuffer()),
+            static_cast<InputDataType*>(v_device_buf.GetDeviceBuffer()),
+            static_cast<InputDataType*>(y_device_buf.GetDeviceBuffer()),
            static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
-            static_cast<DataType*>(ygrad_device_buf.GetDeviceBuffer()),
-            static_cast<DataType*>(qgrad_device_buf.GetDeviceBuffer()),
-            static_cast<DataType*>(kgrad_device_buf.GetDeviceBuffer()),
-            static_cast<DataType*>(vgrad_device_buf.GetDeviceBuffer()),
+            static_cast<InputDataType*>(ygrad_device_buf.GetDeviceBuffer()),
+            static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
+            static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
+            static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
            {}, // std::array<void*, 1> p_acc0_biases;
            {}, // std::array<void*, 1> p_acc1_biases;
            q_gs_ms_ks_lengths,
@@ -1026,10 +1045,13 @@ int run(int argc, char* argv[])
        // 3x MNK + 2x MNO
        std::size_t flop_bwd =
            (size_t(3) * M * N * K + size_t(2) * M * N * O) * 2 * BatchCount;
        // Q/K/V/Y, dQ/dK/dV/dY, LSE
-        std::size_t num_btype_bwd = (sizeof(DataType) * M * K + sizeof(DataType) * K * N +
-                                     sizeof(DataType) * N * O + sizeof(DataType) * M * O) *
-                                        size_t(2) * BatchCount +
-                                    sizeof(LSEDataType) * M * BatchCount;
+        std::size_t num_btype_bwd =
+            (sizeof(InputDataType) * M * K + sizeof(InputDataType) * K * N +
+             sizeof(InputDataType) * N * O + sizeof(InputDataType) * M * O * size_t(2) +
+             sizeof(OutputDataType) * M * K + sizeof(OutputDataType) * K * N +
+             sizeof(OutputDataType) * N * O) *
+                BatchCount +
+            sizeof(LSEDataType) * M * BatchCount;

        float tflops_bwd = static_cast<float>(flop_bwd) / 1.E9 / ave_time_bwd;
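The backward bandwidth model now counts input-typed and output-typed traffic separately: Q, K, V once in InputDataType, Y and dY (hence the factor 2 on M*O) in InputDataType, dQ, dK, dV in OutputDataType, plus the LSE vector. A standalone rendering of the same formula with made-up problem sizes, just to make the accounting concrete:

    #include <cstddef>
    #include <cstdio>

    int main()
    {
        using InputDataType  = unsigned short; // BF16 stand-in, 2 bytes
        using OutputDataType = float;          // F32 gradients, 4 bytes
        using LSEDataType    = float;

        std::size_t M = 256, N = 256, K = 64, O = 64, BatchCount = 4; // hypothetical sizes

        // Reads: Q (M*K), the K matrix (K*N), V (N*O), Y and dY (2 * M*O) in InputDataType.
        // Writes: dQ (M*K), dK (K*N), dV (N*O) in OutputDataType; plus LSE (M).
        std::size_t num_btype_bwd =
            (sizeof(InputDataType) * M * K + sizeof(InputDataType) * K * N +
             sizeof(InputDataType) * N * O + sizeof(InputDataType) * M * O * std::size_t(2) +
             sizeof(OutputDataType) * M * K + sizeof(OutputDataType) * K * N +
             sizeof(OutputDataType) * N * O) *
                BatchCount +
            sizeof(LSEDataType) * M * BatchCount;

        std::printf("bytes moved (model): %zu\n", num_btype_bwd);
        return 0;
    }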
@@ -1042,39 +1064,39 @@ int run(int argc, char* argv[])
    bool pass = true;
    if(do_verification)
    {
-        Tensor<DataType> y_gs_ms_os_host_result(y_gs_ms_os_lengths, y_gs_ms_os_strides);
+        Tensor<InputDataType> y_gs_ms_os_host_result(y_gs_ms_os_lengths, y_gs_ms_os_strides);
        Tensor<LSEDataType> lse_gs_ms_host_result(lse_gs_ms_lengths, lse_gs_ms_strides);
-        Tensor<DataType> qgrad_gs_ms_ks_host_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
-        Tensor<DataType> kgrad_gs_ns_ks_host_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
-        Tensor<DataType> vgrad_gs_os_ns_host_result(v_gs_os_ns_lengths, v_gs_os_ns_strides);
+        Tensor<OutputDataType> qgrad_gs_ms_ks_host_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
+        Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
+        Tensor<OutputDataType> vgrad_gs_os_ns_host_result(v_gs_os_ns_lengths, v_gs_os_ns_strides);

-        Tensor<DataType> q_g_m_k({BatchCount, M, K});
-        Tensor<DataType> k_g_n_k({BatchCount, N, K});
+        Tensor<InputDataType> q_g_m_k({BatchCount, M, K});
+        Tensor<InputDataType> k_g_n_k({BatchCount, N, K});
        Tensor<ZDataType> z_fwd_g_m_n({BatchCount, M, N});
        Tensor<ZDataType> z_bwd_g_m_n({BatchCount, M, N});
-        Tensor<DataType> v_g_n_o({BatchCount, N, O});
+        Tensor<InputDataType> v_g_n_o({BatchCount, N, O});
        Tensor<AccDataType> s_g_m_n({BatchCount, M, N});
-        Tensor<DataType> p_g_m_n({BatchCount, M, N});
-        Tensor<DataType> p_drop_g_m_n({BatchCount, M, N});
-        Tensor<DataType> y_g_m_o({BatchCount, M, O});
+        Tensor<InputDataType> p_g_m_n({BatchCount, M, N});
+        Tensor<InputDataType> p_drop_g_m_n({BatchCount, M, N});
+        Tensor<InputDataType> y_g_m_o({BatchCount, M, O});
        Tensor<LSEDataType> lse_g_m({BatchCount, M});
-        Tensor<DataType> qgrad_g_m_k({BatchCount, M, K});
-        Tensor<DataType> kgrad_g_n_k({BatchCount, N, K});
-        Tensor<DataType> vgrad_g_n_o({BatchCount, N, O});
-        Tensor<DataType> sgrad_g_m_n({BatchCount, M, N});
-        Tensor<DataType> pgrad_g_m_n({BatchCount, M, N});
-        Tensor<DataType> pgrad_drop_g_m_n({BatchCount, M, N});
-        Tensor<DataType> ygrad_g_m_o({BatchCount, M, O});
-        Tensor<DataType> ygrad_dot_y_g_m({BatchCount, M});
+        Tensor<OutputDataType> qgrad_g_m_k({BatchCount, M, K});
+        Tensor<OutputDataType> kgrad_g_n_k({BatchCount, N, K});
+        Tensor<OutputDataType> vgrad_g_n_o({BatchCount, N, O});
+        Tensor<InputDataType> sgrad_g_m_n({BatchCount, M, N});
+        Tensor<InputDataType> pgrad_g_m_n({BatchCount, M, N});
+        Tensor<InputDataType> pgrad_drop_g_m_n({BatchCount, M, N});
+        Tensor<InputDataType> ygrad_g_m_o({BatchCount, M, O});
+        Tensor<InputDataType> ygrad_dot_y_g_m({BatchCount, M});

        // get kernel output matrixes
        {
            auto argument_fwd = gemm_fwd.MakeArgument(
-                static_cast<DataType*>(q_device_buf.GetDeviceBuffer()),
-                static_cast<DataType*>(k_device_buf.GetDeviceBuffer()),
-                static_cast<DataType*>(v_device_buf.GetDeviceBuffer()),
-                static_cast<DataType*>(y_device_buf.GetDeviceBuffer()),
+                static_cast<InputDataType*>(q_device_buf.GetDeviceBuffer()),
+                static_cast<InputDataType*>(k_device_buf.GetDeviceBuffer()),
+                static_cast<InputDataType*>(v_device_buf.GetDeviceBuffer()),
+                static_cast<InputDataType*>(y_device_buf.GetDeviceBuffer()),
                static_cast<ZDataType*>(z_fwd_device_buf.GetDeviceBuffer()),
                static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
                {}, // std::array<void*, 1> p_acc0_biases;
@@ -1125,16 +1147,16 @@ int run(int argc, char* argv[])
            vgrad_device_buf.SetZero();

            auto argument_bwd = gemm_bwd.MakeArgument(
-                static_cast<DataType*>(q_device_buf.GetDeviceBuffer()),
-                static_cast<DataType*>(k_device_buf.GetDeviceBuffer()),
+                static_cast<InputDataType*>(q_device_buf.GetDeviceBuffer()),
+                static_cast<InputDataType*>(k_device_buf.GetDeviceBuffer()),
                static_cast<ZDataType*>(z_bwd_device_buf.GetDeviceBuffer()),
-                static_cast<DataType*>(v_device_buf.GetDeviceBuffer()),
-                static_cast<DataType*>(y_device_buf.GetDeviceBuffer()),
+                static_cast<InputDataType*>(v_device_buf.GetDeviceBuffer()),
+                static_cast<InputDataType*>(y_device_buf.GetDeviceBuffer()),
                static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
-                static_cast<DataType*>(ygrad_device_buf.GetDeviceBuffer()),
-                static_cast<DataType*>(qgrad_device_buf.GetDeviceBuffer()),
-                static_cast<DataType*>(kgrad_device_buf.GetDeviceBuffer()),
-                static_cast<DataType*>(vgrad_device_buf.GetDeviceBuffer()),
+                static_cast<InputDataType*>(ygrad_device_buf.GetDeviceBuffer()),
+                static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
+                static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
+                static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
                {}, // std::array<void*, 1> p_acc0_biases;
                {}, // std::array<void*, 1> p_acc1_biases;
                q_gs_ms_ks_lengths,
@@ -1223,13 +1245,16 @@ int run(int argc, char* argv[])
#endif
        // Gradients
-        auto ref_gemm_grad         = ReferenceGemmGradInstance{};
-        auto ref_gemm_grad_invoker = ref_gemm_grad.MakeInvoker();
-        using RefGemmGradArg       = ReferenceGemmGradInstance::Argument;
+        auto ref_gemm0_grad         = ReferenceGemm0GradInstance{};
+        auto ref_gemm0_grad_invoker = ref_gemm0_grad.MakeInvoker();
+        using RefGemm0GradArg       = ReferenceGemm0GradInstance::Argument;
+        auto ref_gemm1_grad         = ReferenceGemm1GradInstance{};
+        auto ref_gemm1_grad_invoker = ref_gemm1_grad.MakeInvoker();
+        using RefGemm1GradArg       = ReferenceGemm1GradInstance::Argument;
        // dP_dropout = dY * V^T
        auto v_g_o_n = v_g_n_o.Transpose({0, 2, 1});
-        ref_gemm_grad_invoker.Run(RefGemmGradArg{
+        ref_gemm0_grad_invoker.Run(RefGemm0GradArg{
            ygrad_g_m_o, v_g_o_n, pgrad_drop_g_m_n, PassThrough{}, PassThrough{}, Scale{1.f}});
#if PRINT_HOST
        {
@@ -1256,7 +1281,7 @@ int run(int argc, char* argv[])
                ygrad_dot_y += ck::type_convert<AccDataType>(ygrad_g_m_o(idx_gmo)) *
                               ck::type_convert<AccDataType>(y_g_m_o(idx_gmo));
            }
-            self(idx_gmn) = ck::type_convert<DataType>(
+            self(idx_gmn) = ck::type_convert<InputDataType>(
                ck::type_convert<AccDataType>(p_g_m_n(idx_gmn)) *
                (ck::type_convert<AccDataType>(pgrad_g_m_n(idx_gmn)) - ygrad_dot_y));
        });
@@ -1272,7 +1297,7 @@ int run(int argc, char* argv[])
#endif
        // dV = P_drop^T * dY
        auto p_drop_g_n_m = p_drop_g_m_n.Transpose({0, 2, 1});
-        ref_gemm_grad_invoker.Run(RefGemmGradArg{
+        ref_gemm1_grad_invoker.Run(RefGemm1GradArg{
            p_drop_g_n_m, ygrad_g_m_o, vgrad_g_n_o, PassThrough{}, PassThrough{}, Scale{1.0f}});
#if PRINT_HOST
        {
@@ -1284,7 +1309,7 @@ int run(int argc, char* argv[])
#endif
        // dQ = alpha * dS * K
-        ref_gemm_grad_invoker.Run(RefGemmGradArg{
+        ref_gemm1_grad_invoker.Run(RefGemm1GradArg{
            sgrad_g_m_n, k_g_n_k, qgrad_g_m_k, PassThrough{}, PassThrough{}, Scale{alpha}});
#if PRINT_HOST
        {
@@ -1297,7 +1322,7 @@ int run(int argc, char* argv[])
        // dK = alpha * dS^T * Q
        auto sgrad_g_n_m = sgrad_g_m_n.Transpose({0, 2, 1});
-        ref_gemm_grad_invoker.Run(RefGemmGradArg{
+        ref_gemm1_grad_invoker.Run(RefGemm1GradArg{
            sgrad_g_n_m, q_g_m_k, kgrad_g_n_k, PassThrough{}, PassThrough{}, Scale{alpha}});
#if PRINT_HOST
        {
@@ -1355,7 +1380,7 @@ int run(int argc, char* argv[])
        double atol = 1e-3;

        // when BF16 is taken, set absolute error and relative error to 0.01
-        if(std::is_same_v<DataType, ck::bhalf_t> || std::is_same_v<GemmDataType, ck::bhalf_t>)
+        if(std::is_same_v<InputDataType, ck::bhalf_t> || std::is_same_v<GemmDataType, ck::bhalf_t>)
        {
            rtol = 1e-2;
            atol = 1e-2;
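Both examples loosen the verification tolerances when either the input type or the GEMM pipeline type is BF16. A tiny restatement of that rule; bhalf_t below is a local stand-in for ck::bhalf_t, and the helper name is made up:

    #include <type_traits>

    struct bhalf_t { unsigned short data; }; // stand-in for ck::bhalf_t

    template <typename InputDataType, typename GemmDataType>
    constexpr double pick_tolerance()
    {
        // 1e-2 for BF16 inputs or a BF16 GEMM pipeline, otherwise 1e-3.
        return (std::is_same_v<InputDataType, bhalf_t> || std::is_same_v<GemmDataType, bhalf_t>)
                   ? 1e-2
                   : 1e-3;
    }

    static_assert(pick_tolerance<bhalf_t, bhalf_t>() == 1e-2);
    static_assert(pick_tolerance<float, float>() == 1e-3);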
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_train.cpp
@@ -70,7 +70,8 @@ using Scale = ck::tensor_operation::element_wise::Scale;
 using QKVElementOp = PassThrough;
 using YElementOp   = PassThrough;

-using DataType        = BF16;
+using InputDataType   = BF16;
+using OutputDataType  = F32;
 using GemmDataType    = BF16;
 using AccDataType     = F32;
 using ShuffleDataType = F32;
@@ -84,6 +85,9 @@ static constexpr ck::index_t NumDimM = 1;
 static constexpr ck::index_t NumDimN = 1;
 static constexpr ck::index_t NumDimK = 1;
 static constexpr ck::index_t NumDimO = 1;
+// When OutputDataType == F32, bwd CShuffleBlockTransferScalarPerVector_NPerBlock = 4
+// When OutputDataType == F16/BF16, bwd CShuffleBlockTransferScalarPerVector_NPerBlock = 8
+static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 4;

 static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
#if USING_MASK
@@ -111,10 +115,10 @@ using DeviceGemmInstanceFWD =
     NumDimN,
     NumDimK,
     NumDimO,
-    DataType,
-    DataType,
-    DataType,
-    DataType,
+    InputDataType,
+    InputDataType,
+    InputDataType,
+    InputDataType,
     GemmDataType,
     ZDataType,
     LSEDataType,
@@ -181,7 +185,8 @@ using DeviceGemmInstanceBWD =
     NumDimN,
     NumDimK,
     NumDimO,
-    DataType,
+    InputDataType,
+    OutputDataType,
     GemmDataType,
     ZDataType,
     LSEDataType,
@@ -239,8 +244,8 @@ using DeviceGemmInstanceBWD =
     1,              // CShuffleMXdlPerWavePerShuffle
     1,              // CShuffleNXdlPerWavePerShuffle
     S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
-    8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
-    MaskingSpec>;   // MaskingSpecialization
+    CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
+    MaskingSpec>;   // MaskingSpecialization
#elif(DIM <= 64)
using DeviceGemmInstanceFWD =
    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
@@ -249,10 +254,10 @@ using DeviceGemmInstanceFWD =
     NumDimN,
     NumDimK,
     NumDimO,
-    DataType,
-    DataType,
-    DataType,
-    DataType,
+    InputDataType,
+    InputDataType,
+    InputDataType,
+    InputDataType,
     GemmDataType,
     ZDataType,
     LSEDataType,
@@ -319,7 +324,8 @@ using DeviceGemmInstanceBWD =
     NumDimN,
     NumDimK,
     NumDimO,
-    DataType,
+    InputDataType,
+    OutputDataType,
     GemmDataType,
     ZDataType,
     LSEDataType,
@@ -377,8 +383,8 @@ using DeviceGemmInstanceBWD =
     1,              // CShuffleMXdlPerWavePerShuffle
     2,              // CShuffleNXdlPerWavePerShuffle
     S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
-    8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
-    MaskingSpec>;   // MaskingSpecialization
+    CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
+    MaskingSpec>;   // MaskingSpecialization
// using DeviceGemmInstanceBWD =
//     ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2<
@@ -387,7 +393,8 @@ using DeviceGemmInstanceBWD =
//     NumDimN,
//     NumDimK,
//     NumDimO,
-//     DataType,
+//     InputDataType,
+//     OutputDataType,
//     GemmDataType,
//     ZDataType,
//     LSEDataType,
@@ -445,8 +452,8 @@ using DeviceGemmInstanceBWD =
//     1,              // CShuffleMXdlPerWavePerShuffle
//     2,              // CShuffleNXdlPerWavePerShuffle
//     S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
-//     8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
-//     MaskingSpec>;   // MaskingSpecialization
+//     CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
+//     MaskingSpec>;   // MaskingSpecialization
#elif(DIM <= 128)
using DeviceGemmInstanceFWD =
    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
@@ -455,10 +462,10 @@ using DeviceGemmInstanceFWD =
     NumDimN,
     NumDimK,
     NumDimO,
-    DataType,
-    DataType,
-    DataType,
-    DataType,
+    InputDataType,
+    InputDataType,
+    InputDataType,
+    InputDataType,
     GemmDataType,
     ZDataType,
     LSEDataType,
@@ -525,7 +532,8 @@ using DeviceGemmInstanceBWD =
     NumDimN,
     NumDimK,
     NumDimO,
-    DataType,
+    InputDataType,
+    OutputDataType,
     GemmDataType,
     ZDataType,
     LSEDataType,
@@ -583,14 +591,14 @@ using DeviceGemmInstanceBWD =
     1,              // CShuffleMXdlPerWavePerShuffle
     4,              // CShuffleNXdlPerWavePerShuffle
     S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
-    8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
-    MaskingSpec>;   // MaskingSpecialization
+    CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
+    MaskingSpec>;   // MaskingSpecialization
#endif

// Ref Gemm0: S = alpha * Q * K^T
// fp16 in, fp32 out
-using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<DataType,
-                                                                                DataType,
+using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<InputDataType,
+                                                                                InputDataType,
                                                                                 AccDataType,
                                                                                 AccDataType,
                                                                                 PassThrough,
@@ -600,13 +608,13 @@ using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
// Ref Softmax: P = Softmax(S)
// fp32 in, fp16 out
 using ReferenceSoftmaxInstance =
-    ck::tensor_operation::host::ReferenceSoftmax<AccDataType, DataType, AccDataType>;
+    ck::tensor_operation::host::ReferenceSoftmax<AccDataType, InputDataType, AccDataType>;

// Ref Gemm1: Y = P * V
// fp16 in, fp16 out
-using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<DataType,
-                                                                                DataType,
-                                                                                DataType,
+using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<InputDataType,
+                                                                                InputDataType,
+                                                                                InputDataType,
                                                                                 AccDataType,
                                                                                 PassThrough,
                                                                                 PassThrough,
@@ -614,16 +622,25 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
// Ref Gemm for backward pass
// fp16 in, fp16 out
-using ReferenceGemmGradInstance = ck::tensor_operation::host::ReferenceBatchedGemm<DataType,
-                                                                                   DataType,
-                                                                                   DataType,
-                                                                                   AccDataType,
-                                                                                   PassThrough,
-                                                                                   PassThrough,
-                                                                                   Scale>;
+using ReferenceGemm0GradInstance = ck::tensor_operation::host::ReferenceBatchedGemm<InputDataType,
+                                                                                    InputDataType,
+                                                                                    InputDataType,
+                                                                                    AccDataType,
+                                                                                    PassThrough,
+                                                                                    PassThrough,
+                                                                                    Scale>;
+using ReferenceGemm1GradInstance = ck::tensor_operation::host::ReferenceBatchedGemm<InputDataType,
+                                                                                    InputDataType,
+                                                                                    OutputDataType,
+                                                                                    AccDataType,
+                                                                                    PassThrough,
+                                                                                    PassThrough,
+                                                                                    Scale>;

// Ref dropout
 using ReferenceDropoutInstance =
-    ck::tensor_operation::host::ReferenceDropout<ushort, DataType, DataType>;
+    ck::tensor_operation::host::ReferenceDropout<ushort, InputDataType, InputDataType>;

 template <typename TensorQ,
           typename TensorK,
@@ -764,28 +781,28 @@ int run(int argc, char* argv[])
    std::vector<void*> p_vgrad;
    std::vector<const void*> p_ygrad;

-    std::vector<Tensor<DataType>> q_g_m_ks;
-    std::vector<Tensor<DataType>> k_g_n_ks;
+    std::vector<Tensor<InputDataType>> q_g_m_ks;
+    std::vector<Tensor<InputDataType>> k_g_n_ks;
    std::vector<Tensor<ZDataType>> z_fwd_g_m_ns;
    std::vector<Tensor<ZDataType>> z_bwd_g_m_ns;
-    std::vector<Tensor<DataType>> v_g_n_os;
+    std::vector<Tensor<InputDataType>> v_g_n_os;
    std::vector<Tensor<AccDataType>> s_g_m_ns;
-    std::vector<Tensor<DataType>> p_g_m_ns;
-    std::vector<Tensor<DataType>> y_g_m_os;
+    std::vector<Tensor<InputDataType>> p_g_m_ns;
+    std::vector<Tensor<InputDataType>> y_g_m_os;
    std::vector<Tensor<LSEDataType>> lse_g_ms;
-    std::vector<Tensor<DataType>> p_drop_g_m_ns;
+    std::vector<Tensor<InputDataType>> p_drop_g_m_ns;

-    std::vector<Tensor<DataType>> q_tensors;
-    std::vector<Tensor<DataType>> k_tensors;
-    std::vector<Tensor<DataType>> v_tensors;
-    std::vector<Tensor<DataType>> y_tensors;
+    std::vector<Tensor<InputDataType>> q_tensors;
+    std::vector<Tensor<InputDataType>> k_tensors;
+    std::vector<Tensor<InputDataType>> v_tensors;
+    std::vector<Tensor<InputDataType>> y_tensors;
    std::vector<Tensor<ZDataType>> z_fwd_tensors;
    std::vector<Tensor<ZDataType>> z_bwd_tensors;
    std::vector<Tensor<LSEDataType>> lse_tensors;
-    std::vector<Tensor<DataType>> qgrad_tensors;
-    std::vector<Tensor<DataType>> kgrad_tensors;
-    std::vector<Tensor<DataType>> vgrad_tensors;
-    std::vector<Tensor<DataType>> ygrad_tensors;
+    std::vector<Tensor<OutputDataType>> qgrad_tensors;
+    std::vector<Tensor<OutputDataType>> kgrad_tensors;
+    std::vector<Tensor<OutputDataType>> vgrad_tensors;
+    std::vector<Tensor<InputDataType>> ygrad_tensors;

    std::vector<DeviceMemPtr> q_tensors_device;
    std::vector<DeviceMemPtr> k_tensors_device;
@@ -885,24 +902,26 @@ int run(int argc, char* argv[])
        int BatchCount = G0 * G1;

        flop_fwd += (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
-        num_byte_fwd += (sizeof(DataType) * M * K + sizeof(DataType) * K * N +
-                         sizeof(DataType) * N * O + sizeof(DataType) * M * O) *
-                        BatchCount;
+        num_byte_fwd +=
+            (sizeof(InputDataType) * M * K + sizeof(InputDataType) * K * N +
+             sizeof(InputDataType) * N * O + sizeof(InputDataType) * M * O) *
+            BatchCount;

        flop_bwd += (size_t(3) * M * N * K + size_t(2) * M * N * O) * 2 * BatchCount;
        // Q/K/V/Y, dQ/dK/dV/dY, LSE
-        num_byte_bwd += (sizeof(DataType) * M * K + sizeof(DataType) * K * N +
-                         sizeof(DataType) * N * O + sizeof(DataType) * M * O) *
-                            size_t(2) * BatchCount +
-                        sizeof(LSEDataType) * M * BatchCount;
+        num_byte_bwd +=
+            (sizeof(InputDataType) * M * K + sizeof(InputDataType) * K * N +
+             sizeof(InputDataType) * N * O + sizeof(InputDataType) * M * O * size_t(2) +
+             sizeof(OutputDataType) * M * K + sizeof(OutputDataType) * K * N +
+             sizeof(OutputDataType) * N * O) *
+                BatchCount +
+            sizeof(LSEDataType) * M * BatchCount;

-        Tensor<DataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
-        Tensor<DataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
+        Tensor<InputDataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
+        Tensor<InputDataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
        Tensor<ZDataType> z_fwd_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
        Tensor<ZDataType> z_bwd_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
-        Tensor<DataType> v_gs_os_ns(v_gs_os_ns_lengths, v_gs_os_ns_strides);
-        Tensor<DataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
-        Tensor<DataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
+        Tensor<InputDataType> v_gs_os_ns(v_gs_os_ns_lengths, v_gs_os_ns_strides);
+        Tensor<InputDataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
+        Tensor<InputDataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
        Tensor<LSEDataType> lse_gs_ms(lse_gs_ms_lengths, lse_gs_ms_strides);

        if(i < 4)
        {
@@ -913,46 +932,46 @@ int run(int argc, char* argv[])
            std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl;
            std::cout << "lse_gs_ms_os: " << lse_gs_ms.mDesc << std::endl;
        }

-        z_fwd_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<DataType>{0});
-        z_bwd_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<DataType>{0});
+        z_fwd_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{0});
+        z_bwd_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{0});

        switch(init_method)
        {
        case 0: break;
        case 1:
-            q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<DataType>{-2, 2});
-            k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<DataType>{-2, 2});
-            v_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<DataType>{-2, 2});
-            ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_2<DataType>{-2, 2});
+            q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
+            k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
+            v_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
+            ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
            break;
        case 2:
-            q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<DataType>{0.0, 1.0});
-            k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<DataType>{0.0, 1.0});
-            v_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<DataType>{-0.5, 0.5});
-            ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_3<DataType>{-0.5, 0.5});
+            q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<InputDataType>{0.0, 1.0});
+            k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<InputDataType>{0.0, 1.0});
+            v_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<InputDataType>{-0.5, 0.5});
+            ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_3<InputDataType>{-0.5, 0.5});
            break;
        case 3:
-            q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<DataType>{-5, 5});
-            k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
-            v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
-            ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
+            q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-5, 5});
+            k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
+            v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
+            ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
            break;
        case 4:
-            q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
-            k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
-            v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
-            ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<DataType>{2});
+            q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
+            k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
+            v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
+            ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{2});
            break;
        case 5:
-            q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
-            k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
-            v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
+            q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
+            k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
+            v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
            ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, o]
            // dO dot O = [0; 1; 2; ...]
            break;
        case 6:
-            q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
-            k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
-            v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
+            q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
+            k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
+            v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
            ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1, m, o]
            // assume mnko = 256
            // P = softmax(QK) = 0.0039 * ones
@@ -963,10 +982,11 @@ int run(int argc, char* argv[])
            //
            break;
        default:
-            q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
-            k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
-            v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
-            ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<DataType>{1}); // dy[g0, g1, m, o]
+            q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
+            k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
+            v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
+            ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); // dy[g0, g1, m, o]
            // assume mnko = 256
            // P = softmax(QK) = 0.0039 * ones
            // O = P V = 0.0039 * ones
@@ -977,16 +997,16 @@ int run(int argc, char* argv[])
            // = 0
        }

-        Tensor<DataType> q_g_m_k({BatchCount, M, K});
-        Tensor<DataType> k_g_n_k({BatchCount, N, K});
+        Tensor<InputDataType> q_g_m_k({BatchCount, M, K});
+        Tensor<InputDataType> k_g_n_k({BatchCount, N, K});
        Tensor<ZDataType> z_fwd_g_m_n({BatchCount, M, N});
        Tensor<ZDataType> z_bwd_g_m_n({BatchCount, M, N});
-        Tensor<DataType> v_g_n_o({BatchCount, N, O});
+        Tensor<InputDataType> v_g_n_o({BatchCount, N, O});
        Tensor<AccDataType> s_g_m_n({BatchCount, M, N});
-        Tensor<DataType> p_g_m_n({BatchCount, M, N});
-        Tensor<DataType> y_g_m_o({BatchCount, M, O});
+        Tensor<InputDataType> p_g_m_n({BatchCount, M, N});
+        Tensor<InputDataType> y_g_m_o({BatchCount, M, O});
        Tensor<LSEDataType> lse_g_m({BatchCount, M});
-        Tensor<DataType> p_drop_g_m_n({BatchCount, M, N});
+        Tensor<InputDataType> p_drop_g_m_n({BatchCount, M, N});

        q_gs_ms_ks.ForEach([&](auto& self, auto idx) {
            q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
@@ -1019,27 +1039,27 @@ int run(int argc, char* argv[])
        ygrad_tensors.push_back(ygrad_gs_ms_os);

        q_tensors_device.emplace_back(
-            std::make_unique<DeviceMem>(sizeof(DataType) * q_gs_ms_ks.GetElementSpaceSize()));
+            std::make_unique<DeviceMem>(sizeof(InputDataType) * q_gs_ms_ks.GetElementSpaceSize()));
        k_tensors_device.emplace_back(
-            std::make_unique<DeviceMem>(sizeof(DataType) * k_gs_ns_ks.GetElementSpaceSize()));
+            std::make_unique<DeviceMem>(sizeof(InputDataType) * k_gs_ns_ks.GetElementSpaceSize()));
        z_fwd_tensors_device.emplace_back(
            std::make_unique<DeviceMem>(sizeof(ZDataType) * z_fwd_gs_ms_ns.GetElementSpaceSize()));
        z_bwd_tensors_device.emplace_back(
            std::make_unique<DeviceMem>(sizeof(ZDataType) * z_bwd_gs_ms_ns.GetElementSpaceSize()));
        v_tensors_device.emplace_back(
-            std::make_unique<DeviceMem>(sizeof(DataType) * v_gs_os_ns.GetElementSpaceSize()));
+            std::make_unique<DeviceMem>(sizeof(InputDataType) * v_gs_os_ns.GetElementSpaceSize()));
        y_tensors_device.emplace_back(
-            std::make_unique<DeviceMem>(sizeof(DataType) * y_gs_ms_os.GetElementSpaceSize()));
+            std::make_unique<DeviceMem>(sizeof(InputDataType) * y_gs_ms_os.GetElementSpaceSize()));
        lse_tensors_device.emplace_back(
            std::make_unique<DeviceMem>(sizeof(LSEDataType) * lse_gs_ms.GetElementSpaceSize()));
        qgrad_tensors_device.emplace_back(
-            std::make_unique<DeviceMem>(sizeof(DataType) * q_gs_ms_ks.GetElementSpaceSize()));
+            std::make_unique<DeviceMem>(sizeof(OutputDataType) * q_gs_ms_ks.GetElementSpaceSize()));
        kgrad_tensors_device.emplace_back(
-            std::make_unique<DeviceMem>(sizeof(DataType) * k_gs_ns_ks.GetElementSpaceSize()));
+            std::make_unique<DeviceMem>(sizeof(OutputDataType) * k_gs_ns_ks.GetElementSpaceSize()));
        vgrad_tensors_device.emplace_back(
-            std::make_unique<DeviceMem>(sizeof(DataType) * v_gs_os_ns.GetElementSpaceSize()));
+            std::make_unique<DeviceMem>(sizeof(OutputDataType) * v_gs_os_ns.GetElementSpaceSize()));
        ygrad_tensors_device.emplace_back(
-            std::make_unique<DeviceMem>(sizeof(DataType) * y_gs_ms_os.GetElementSpaceSize()));
+            std::make_unique<DeviceMem>(sizeof(InputDataType) * y_gs_ms_os.GetElementSpaceSize()));

        q_tensors_device.back()->ToDevice(q_gs_ms_ks.data());
        k_tensors_device.back()->ToDevice(k_gs_ns_ks.data());
@@ -1258,23 +1278,26 @@ int run(int argc, char* argv[])
            int M = q_tensors[i].GetLengths()[2];
            int K = q_tensors[i].GetLengths()[3];
            int BatchCount = G0 * G1;

-            Tensor<DataType> qgrad_g_m_k({BatchCount, M, K});
-            Tensor<DataType> kgrad_g_n_k({BatchCount, N, K});
-            Tensor<DataType> vgrad_g_n_o({BatchCount, N, O});
-            Tensor<DataType> sgrad_g_m_n({BatchCount, M, N});
-            Tensor<DataType> pgrad_g_m_n({BatchCount, M, N});
-            Tensor<DataType> pgrad_drop_g_m_n({BatchCount, M, N});
-            Tensor<DataType> ygrad_g_m_o({BatchCount, M, O});
+            Tensor<OutputDataType> qgrad_g_m_k({BatchCount, M, K});
+            Tensor<OutputDataType> kgrad_g_n_k({BatchCount, N, K});
+            Tensor<OutputDataType> vgrad_g_n_o({BatchCount, N, O});
+            Tensor<InputDataType> sgrad_g_m_n({BatchCount, M, N});
+            Tensor<InputDataType> pgrad_g_m_n({BatchCount, M, N});
+            Tensor<InputDataType> pgrad_drop_g_m_n({BatchCount, M, N});
+            Tensor<InputDataType> ygrad_g_m_o({BatchCount, M, O});

            ygrad_tensors[i].ForEach([&](auto& self, auto idx) {
                ygrad_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
            });

-            auto ref_gemm_grad         = ReferenceGemmGradInstance{};
-            auto ref_gemm_grad_invoker = ref_gemm_grad.MakeInvoker();
-            using RefGemmGradArg       = ReferenceGemmGradInstance::Argument;
+            auto ref_gemm0_grad         = ReferenceGemm0GradInstance{};
+            auto ref_gemm0_grad_invoker = ref_gemm0_grad.MakeInvoker();
+            using RefGemm0GradArg       = ReferenceGemm0GradInstance::Argument;
+            auto ref_gemm1_grad         = ReferenceGemm1GradInstance{};
+            auto ref_gemm1_grad_invoker = ref_gemm1_grad.MakeInvoker();
+            using RefGemm1GradArg       = ReferenceGemm1GradInstance::Argument;

            // dP = dY * V^T
            auto v_g_o_n = v_g_n_os[i].Transpose({0, 2, 1});
-            ref_gemm_grad_invoker.Run(RefGemmGradArg{
+            ref_gemm0_grad_invoker.Run(RefGemm0GradArg{
                ygrad_g_m_o, v_g_o_n, pgrad_drop_g_m_n, PassThrough{}, PassThrough{}, Scale{1.f}});

            auto ref_dropout         = ReferenceDropoutInstance{};
            auto ref_dropout_invoker = ref_dropout.MakeInvoker();
@@ -1291,39 +1314,39 @@ int run(int argc, char* argv[])
                    ygrad_dot_y += ck::type_convert<AccDataType>(ygrad_g_m_o(idx_gmo)) *
                                   ck::type_convert<AccDataType>(y_g_m_os[i](idx_gmo));
                }
-                self(idx_gmn) = ck::type_convert<DataType>(
+                self(idx_gmn) = ck::type_convert<InputDataType>(
                    ck::type_convert<AccDataType>(p_g_m_ns[i](idx_gmn)) *
                    (ck::type_convert<AccDataType>(pgrad_g_m_n(idx_gmn)) - ygrad_dot_y));
            });

            auto p_drop_g_n_m = p_drop_g_m_ns[i].Transpose({0, 2, 1});
-            ref_gemm_grad_invoker.Run(RefGemmGradArg{
+            ref_gemm1_grad_invoker.Run(RefGemm1GradArg{
                p_drop_g_n_m, ygrad_g_m_o, vgrad_g_n_o, PassThrough{}, PassThrough{}, Scale{1.0f}});
-            ref_gemm_grad_invoker.Run(RefGemmGradArg{
+            ref_gemm1_grad_invoker.Run(RefGemm1GradArg{
                sgrad_g_m_n, k_g_n_ks[i], qgrad_g_m_k, PassThrough{}, PassThrough{}, Scale{alpha}});
            auto sgrad_g_n_m = sgrad_g_m_n.Transpose({0, 2, 1});
-            ref_gemm_grad_invoker.Run(RefGemmGradArg{
+            ref_gemm1_grad_invoker.Run(RefGemm1GradArg{
                sgrad_g_n_m, q_g_m_ks[i], kgrad_g_n_k, PassThrough{}, PassThrough{}, Scale{alpha}});

-            Tensor<DataType> qgrad_gs_ms_ks_host_result(q_tensors[i].GetLengths(), q_tensors[i].GetStrides());
-            Tensor<DataType> kgrad_gs_ns_ks_host_result(k_tensors[i].GetLengths(), k_tensors[i].GetStrides());
-            Tensor<DataType> vgrad_gs_os_ns_host_result(v_tensors[i].GetLengths(), v_tensors[i].GetStrides());
-            Tensor<DataType> y_gs_ms_os_host_result(y_tensors[i].GetLengths(), y_tensors[i].GetStrides());
+            Tensor<OutputDataType> qgrad_gs_ms_ks_host_result(q_tensors[i].GetLengths(), q_tensors[i].GetStrides());
+            Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(k_tensors[i].GetLengths(), k_tensors[i].GetStrides());
+            Tensor<OutputDataType> vgrad_gs_os_ns_host_result(v_tensors[i].GetLengths(), v_tensors[i].GetStrides());
+            Tensor<InputDataType> y_gs_ms_os_host_result(y_tensors[i].GetLengths(), y_tensors[i].GetStrides());
            Tensor<LSEDataType> lse_gs_ms_host_result(lse_tensors[i].GetLengths(), lse_tensors[i].GetStrides());

-            Tensor<DataType> qgrad_gs_ms_ks_device_result(q_tensors[i].GetLengths(), q_tensors[i].GetStrides());
-            Tensor<DataType> kgrad_gs_ns_ks_device_result(k_tensors[i].GetLengths(), k_tensors[i].GetStrides());
-            Tensor<DataType> vgrad_gs_os_ns_device_result(v_tensors[i].GetLengths(), v_tensors[i].GetStrides());
-            Tensor<DataType> y_gs_ms_os_device_result(y_tensors[i].GetLengths(), y_tensors[i].GetStrides());
+            Tensor<OutputDataType> qgrad_gs_ms_ks_device_result(q_tensors[i].GetLengths(), q_tensors[i].GetStrides());
+            Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(k_tensors[i].GetLengths(), k_tensors[i].GetStrides());
+            Tensor<OutputDataType> vgrad_gs_os_ns_device_result(v_tensors[i].GetLengths(), v_tensors[i].GetStrides());
+            Tensor<InputDataType> y_gs_ms_os_device_result(y_tensors[i].GetLengths(), y_tensors[i].GetStrides());
            Tensor<LSEDataType> lse_gs_ms_device_result(lse_tensors[i].GetLengths(), lse_tensors[i].GetStrides());
@@ -1380,7 +1403,8 @@ int run(int argc, char* argv[])
            double atol = 1e-3;

            // when BF16 is taken, set absolute error and relative error to 0.01
-            if(std::is_same_v<DataType, ck::bhalf_t> || std::is_same_v<GemmDataType, ck::bhalf_t>)
+            if(std::is_same_v<InputDataType, ck::bhalf_t> ||
+               std::is_same_v<GemmDataType, ck::bhalf_t>)
            {
                rtol = 1e-2;
                atol = 1e-2;