Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
883c060a
"configs/vscode:/vscode.git/clone" did not exist on "3cfe73de3ff75cdb657f0db77cd6228c28f41da1"
Commit
883c060a
authored
Sep 04, 2023
by
guangzlu
Browse files
bwd qloop v1 passed
parent
72498367
Changes
4
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
272 additions
and
501 deletions
+272
-501
client_example/08_fused_attention/fused_attention_bwd.cpp
client_example/08_fused_attention/fused_attention_bwd.cpp
+106
-323
library/include/ck/library/tensor_operation_instance/gpu/batched_mha_bwd_qloop.hpp
...y/tensor_operation_instance/gpu/batched_mha_bwd_qloop.hpp
+46
-50
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_mha_bwd_qloop_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
...ched_mha_bwd_qloop_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
+60
-64
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_mha_bwd_qloop_f16_f16_gmk_gnk_gno_gmo_instance.cpp
...atched_mha_bwd_qloop_f16_f16_gmk_gnk_gno_gmo_instance.cpp
+60
-64
No files found.
client_example/08_fused_attention/fused_attention_bwd.cpp
View file @
883c060a
This diff is collapsed.
Click to expand it.
library/include/ck/library/tensor_operation_instance/gpu/batched_mha_bwd_qloop.hpp
View file @
883c060a
...
@@ -18,15 +18,36 @@ namespace device {
...
@@ -18,15 +18,36 @@ namespace device {
namespace
instance
{
namespace
instance
{
void
add_device_batched_mha_bwd_qloop_casual_f16_f16_gmk_gnk_gno_gmo_instances
(
void
add_device_batched_mha_bwd_qloop_casual_f16_f16_gmk_gnk_gno_gmo_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedMultiheadAttentionBackward
<
2
,
1
,
1
,
1
,
1
,
F16
,
F16
,
unsigned
short
,
F32
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
PassThrough
,
PassThrough
,
Scale
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskUpperTriangleFromTopLeft
>>>&
instances
);
void
add_device_batched_mha_bwd_qloop_noncasual_f16_f16_gmk_gnk_gno_gmo_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedMultiheadAttentionBackward
<
2
,
DeviceBatchedMultiheadAttentionBackward
<
2
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
F16
,
F16
,
F16
,
F16
,
unsigned
short
,
unsigned
short
,
F32
,
F32
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
PassThrough
,
PassThrough
,
...
@@ -34,31 +55,29 @@ void add_device_batched_mha_bwd_qloop_casual_f16_f16_gmk_gnk_gno_gmo_instances(
...
@@ -34,31 +55,29 @@ void add_device_batched_mha_bwd_qloop_casual_f16_f16_gmk_gnk_gno_gmo_instances(
Scale
,
Scale
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskUpperTriangleFromTopLeft
>>>&
MaskingSpecialization
::
MaskDisabled
>>>&
instances
);
instances
);
void
add_device_batched_mha_bwd_qloop_noncasual_f16_f16_gmk_gnk_gno_gmo_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedMultiheadAttentionBackward
<
2
,
1
,
1
,
1
,
1
,
F16
,
F16
,
unsigned
short
,
F32
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
PassThrough
,
PassThrough
,
Scale
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskDisabled
>>>&
instances
);
void
add_device_batched_mha_bwd_qloop_casual_bf16_bf16_gmk_gnk_gno_gmo_instances
(
void
add_device_batched_mha_bwd_qloop_casual_bf16_bf16_gmk_gnk_gno_gmo_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedMultiheadAttentionBackward
<
2
,
1
,
1
,
1
,
1
,
BF16
,
BF16
,
unsigned
short
,
F32
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
PassThrough
,
PassThrough
,
Scale
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskUpperTriangleFromTopLeft
>>>&
instances
);
void
add_device_batched_mha_bwd_qloop_noncasual_bf16_bf16_gmk_gnk_gno_gmo_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedMultiheadAttentionBackward
<
2
,
DeviceBatchedMultiheadAttentionBackward
<
2
,
1
,
1
,
...
@@ -68,7 +87,7 @@ void add_device_batched_mha_bwd_qloop_casual_bf16_bf16_gmk_gnk_gno_gmo_instances
...
@@ -68,7 +87,7 @@ void add_device_batched_mha_bwd_qloop_casual_bf16_bf16_gmk_gnk_gno_gmo_instances
BF16
,
BF16
,
BF16
,
BF16
,
unsigned
short
,
unsigned
short
,
F32
,
F32
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
PassThrough
,
PassThrough
,
...
@@ -76,29 +95,7 @@ void add_device_batched_mha_bwd_qloop_casual_bf16_bf16_gmk_gnk_gno_gmo_instances
...
@@ -76,29 +95,7 @@ void add_device_batched_mha_bwd_qloop_casual_bf16_bf16_gmk_gnk_gno_gmo_instances
Scale
,
Scale
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskUpperTriangleFromTopLeft
>>>&
MaskingSpecialization
::
MaskDisabled
>>>&
instances
);
instances
);
void
add_device_batched_mha_bwd_qloop_noncasual_bf16_bf16_gmk_gnk_gno_gmo_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedMultiheadAttentionBackward
<
2
,
1
,
1
,
1
,
1
,
BF16
,
BF16
,
unsigned
short
,
F32
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
PassThrough
,
PassThrough
,
Scale
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskDisabled
>>>&
instances
);
template
<
typename
InputDataType
,
template
<
typename
InputDataType
,
typename
OutputDataType
,
typename
OutputDataType
,
...
@@ -151,8 +148,7 @@ struct DeviceOperationInstanceFactory<
...
@@ -151,8 +148,7 @@ struct DeviceOperationInstanceFactory<
{
{
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTriangleFromTopLeft
)
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTriangleFromTopLeft
)
{
{
add_device_batched_mha_bwd_qloop_casual_f16_f16_gmk_gnk_gno_gmo_instances
(
add_device_batched_mha_bwd_qloop_casual_f16_f16_gmk_gnk_gno_gmo_instances
(
op_ptrs
);
op_ptrs
);
}
}
else
if
(
MaskingSpec
==
MaskingSpecialization
::
MaskDisabled
)
else
if
(
MaskingSpec
==
MaskingSpecialization
::
MaskDisabled
)
{
{
...
...
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_mha_bwd_qloop_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
View file @
883c060a
...
@@ -32,10 +32,11 @@ using YElementOp = PassThrough;
...
@@ -32,10 +32,11 @@ using YElementOp = PassThrough;
using
Acc0BiasDataType
=
ck
::
Tuple
<>
;
using
Acc0BiasDataType
=
ck
::
Tuple
<>
;
using
Acc1BiasDataType
=
ck
::
Tuple
<>
;
using
Acc1BiasDataType
=
ck
::
Tuple
<>
;
//static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
//
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
//static constexpr auto TensorDefault = ck::tensor_operation::device::TensorSpecialization::Default;
// static constexpr auto TensorDefault =
// ck::tensor_operation::device::TensorSpecialization::Default;
static
constexpr
auto
TensorSpecQ
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecQ
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecK
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecK
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecV
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecV
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
...
@@ -50,9 +51,8 @@ template <index_t NumDimG,
...
@@ -50,9 +51,8 @@ template <index_t NumDimG,
index_t
NumDimK
,
index_t
NumDimK
,
index_t
NumDimO
,
index_t
NumDimO
,
MaskingSpecialization
MaskingSpec
>
MaskingSpecialization
MaskingSpec
>
using
device_batched_mha_bwd_qloop_bf16_bf16_gmk_gnk_gno_gmo_instances
=
using
device_batched_mha_bwd_qloop_bf16_bf16_gmk_gnk_gno_gmo_instances
=
std
::
tuple
<
std
::
tuple
<
// clang-format off
// clang-format off
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
...
@@ -61,71 +61,67 @@ using device_batched_mha_bwd_qloop_bf16_bf16_gmk_gnk_gno_gmo_instances =
...
@@ -61,71 +61,67 @@ using device_batched_mha_bwd_qloop_bf16_bf16_gmk_gnk_gno_gmo_instances =
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, unsigned short, F32, Acc0BiasDataType, Acc1BiasDataType, F32, F32, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 4, 1, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>,
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, unsigned short, F32, Acc0BiasDataType, Acc1BiasDataType, F32, F32, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 4, 1, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>,
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, unsigned short, F32, Acc0BiasDataType, Acc1BiasDataType, F32, F32, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 4, 1, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>,
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, unsigned short, F32, Acc0BiasDataType, Acc1BiasDataType, F32, F32, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 4, 1, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>,
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, unsigned short, F32, Acc0BiasDataType, Acc1BiasDataType, F32, F32, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 4, 1, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, unsigned short, F32, Acc0BiasDataType, Acc1BiasDataType, F32, F32, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 4, 1, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>
// clang-format on
// clang-format on
>
;
>
;
void
add_device_batched_mha_bwd_qloop_casual_bf16_bf16_gmk_gnk_gno_gmo_instances
(
void
add_device_batched_mha_bwd_qloop_casual_bf16_bf16_gmk_gnk_gno_gmo_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedMultiheadAttentionBackward
<
DeviceBatchedMultiheadAttentionBackward
<
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
BF16
,
BF16
,
BF16
,
BF16
,
unsigned
short
,
unsigned
short
,
F32
,
F32
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
Scale
,
Scale
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskUpperTriangleFromTopLeft
>>>&
MaskingSpecialization
::
MaskUpperTriangleFromTopLeft
>>>&
instances
)
instances
)
{
{
add_device_operation_instances
(
add_device_operation_instances
(
instances
,
instances
,
device_batched_mha_bwd_qloop_bf16_bf16_gmk_gnk_gno_gmo_instances
<
device_batched_mha_bwd_qloop_bf16_bf16_gmk_gnk_gno_gmo_instances
<
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
MaskingSpecialization
::
MaskUpperTriangleFromTopLeft
>
{});
MaskingSpecialization
::
MaskUpperTriangleFromTopLeft
>
{});
}
}
void
add_device_batched_mha_bwd_qloop_noncasual_bf16_bf16_gmk_gnk_gno_gmo_instances
(
void
add_device_batched_mha_bwd_qloop_noncasual_bf16_bf16_gmk_gnk_gno_gmo_instances
(
std
::
vector
<
std
::
vector
<
std
::
unique_ptr
<
std
::
unique_ptr
<
DeviceBatchedMultiheadAttentionBackward
<
2
,
DeviceBatchedMultiheadAttentionBackward
<
2
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
BF16
,
BF16
,
BF16
,
BF16
,
unsigned
short
,
unsigned
short
,
F32
,
F32
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
Scale
,
Scale
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskDisabled
>>>&
MaskingSpecialization
::
MaskDisabled
>>>&
instances
)
instances
)
{
{
add_device_operation_instances
(
add_device_operation_instances
(
instances
,
instances
,
device_batched_mha_bwd_qloop_bf16_bf16_gmk_gnk_gno_gmo_instances
<
device_batched_mha_bwd_qloop_bf16_bf16_gmk_gnk_gno_gmo_instances
<
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
MaskingSpecialization
::
MaskDisabled
>
{});
MaskingSpecialization
::
MaskDisabled
>
{});
}
}
}
// namespace instance
}
// namespace instance
...
...
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_mha_bwd_qloop_f16_f16_gmk_gnk_gno_gmo_instance.cpp
View file @
883c060a
...
@@ -32,10 +32,11 @@ using YElementOp = PassThrough;
...
@@ -32,10 +32,11 @@ using YElementOp = PassThrough;
using
Acc0BiasDataType
=
ck
::
Tuple
<>
;
using
Acc0BiasDataType
=
ck
::
Tuple
<>
;
using
Acc1BiasDataType
=
ck
::
Tuple
<>
;
using
Acc1BiasDataType
=
ck
::
Tuple
<>
;
//static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
//
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
//static constexpr auto TensorDefault = ck::tensor_operation::device::TensorSpecialization::Default;
// static constexpr auto TensorDefault =
// ck::tensor_operation::device::TensorSpecialization::Default;
static
constexpr
auto
TensorSpecQ
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecQ
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecK
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecK
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecV
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecV
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
...
@@ -50,9 +51,8 @@ template <index_t NumDimG,
...
@@ -50,9 +51,8 @@ template <index_t NumDimG,
index_t
NumDimK
,
index_t
NumDimK
,
index_t
NumDimO
,
index_t
NumDimO
,
MaskingSpecialization
MaskingSpec
>
MaskingSpecialization
MaskingSpec
>
using
device_batched_mha_bwd_qloop_f16_f16_gmk_gnk_gno_gmo_instances
=
using
device_batched_mha_bwd_qloop_f16_f16_gmk_gnk_gno_gmo_instances
=
std
::
tuple
<
std
::
tuple
<
// clang-format off
// clang-format off
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
...
@@ -61,71 +61,67 @@ using device_batched_mha_bwd_qloop_f16_f16_gmk_gnk_gno_gmo_instances =
...
@@ -61,71 +61,67 @@ using device_batched_mha_bwd_qloop_f16_f16_gmk_gnk_gno_gmo_instances =
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, unsigned short, F32, Acc0BiasDataType, Acc1BiasDataType, F32, F32, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 4, 1, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>,
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, unsigned short, F32, Acc0BiasDataType, Acc1BiasDataType, F32, F32, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 4, 1, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>,
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, unsigned short, F32, Acc0BiasDataType, Acc1BiasDataType, F32, F32, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 4, 1, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>,
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, unsigned short, F32, Acc0BiasDataType, Acc1BiasDataType, F32, F32, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 4, 1, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>,
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, unsigned short, F32, Acc0BiasDataType, Acc1BiasDataType, F32, F32, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 4, 1, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, unsigned short, F32, Acc0BiasDataType, Acc1BiasDataType, F32, F32, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 4, 1, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>
// clang-format on
// clang-format on
>
;
>
;
void
add_device_batched_mha_bwd_qloop_casual_f16_f16_gmk_gnk_gno_gmo_instances
(
void
add_device_batched_mha_bwd_qloop_casual_f16_f16_gmk_gnk_gno_gmo_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedMultiheadAttentionBackward
<
DeviceBatchedMultiheadAttentionBackward
<
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
F16
,
F16
,
F16
,
F16
,
unsigned
short
,
unsigned
short
,
F32
,
F32
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
Scale
,
Scale
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskUpperTriangleFromTopLeft
>>>&
MaskingSpecialization
::
MaskUpperTriangleFromTopLeft
>>>&
instances
)
instances
)
{
{
add_device_operation_instances
(
add_device_operation_instances
(
instances
,
instances
,
device_batched_mha_bwd_qloop_f16_f16_gmk_gnk_gno_gmo_instances
<
device_batched_mha_bwd_qloop_f16_f16_gmk_gnk_gno_gmo_instances
<
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
MaskingSpecialization
::
MaskUpperTriangleFromTopLeft
>
{});
MaskingSpecialization
::
MaskUpperTriangleFromTopLeft
>
{});
}
}
void
add_device_batched_mha_bwd_qloop_noncasual_f16_f16_gmk_gnk_gno_gmo_instances
(
void
add_device_batched_mha_bwd_qloop_noncasual_f16_f16_gmk_gnk_gno_gmo_instances
(
std
::
vector
<
std
::
vector
<
std
::
unique_ptr
<
std
::
unique_ptr
<
DeviceBatchedMultiheadAttentionBackward
<
2
,
DeviceBatchedMultiheadAttentionBackward
<
2
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
F16
,
F16
,
F16
,
F16
,
unsigned
short
,
unsigned
short
,
F32
,
F32
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
Scale
,
Scale
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskDisabled
>>>&
MaskingSpecialization
::
MaskDisabled
>>>&
instances
)
instances
)
{
{
add_device_operation_instances
(
add_device_operation_instances
(
instances
,
instances
,
device_batched_mha_bwd_qloop_f16_f16_gmk_gnk_gno_gmo_instances
<
device_batched_mha_bwd_qloop_f16_f16_gmk_gnk_gno_gmo_instances
<
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
MaskingSpecialization
::
MaskDisabled
>
{});
MaskingSpecialization
::
MaskDisabled
>
{});
}
}
}
// namespace instance
}
// namespace instance
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment