Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
72498367
"docs/source/vscode:/vscode.git/clone" did not exist on "92e5ddd2959c734c1e24e31414c8b685aafaf583"
Commit
72498367
authored
Sep 04, 2023
by
guangzlu
Browse files
fixed bugs
parent
d1280ed5
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
13 additions
and
13 deletions
+13
-13
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_mha_bwd_qloop_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
...ched_mha_bwd_qloop_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
+6
-6
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_mha_bwd_qloop_f16_f16_gmk_gnk_gno_gmo_instance.cpp
...atched_mha_bwd_qloop_f16_f16_gmk_gnk_gno_gmo_instance.cpp
+7
-7
No files found.
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_mha_bwd_qloop_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
View file @
72498367
...
@@ -32,10 +32,10 @@ using YElementOp = PassThrough;
...
@@ -32,10 +32,10 @@ using YElementOp = PassThrough;
using
Acc0BiasDataType
=
ck
::
Tuple
<>
;
using
Acc0BiasDataType
=
ck
::
Tuple
<>
;
using
Acc1BiasDataType
=
ck
::
Tuple
<>
;
using
Acc1BiasDataType
=
ck
::
Tuple
<>
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
//
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
static
constexpr
auto
TensorDefault
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
//
static constexpr auto TensorDefault = ck::tensor_operation::device::TensorSpecialization::Default;
static
constexpr
auto
TensorSpecQ
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecQ
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecK
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecK
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecV
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecV
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
...
@@ -57,10 +57,10 @@ using device_batched_mha_bwd_qloop_bf16_bf16_gmk_gnk_gno_gmo_instances =
...
@@ -57,10 +57,10 @@ using device_batched_mha_bwd_qloop_bf16_bf16_gmk_gnk_gno_gmo_instances =
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | |
// ########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | |
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
BF16
,
BF16
,
BF16
,
unsigned
short
,
F32
,
Acc0BiasDataType
,
Acc1BiasDataType
,
F32
,
F32
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
64
,
128
,
64
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
2
,
1
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpec
,
Deterministic
>
,
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
BF16
,
BF16
,
BF16
,
unsigned
short
,
F32
,
Acc0BiasDataType
,
Acc1BiasDataType
,
F32
,
F32
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
64
,
128
,
64
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
2
,
1
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpec
,
Deterministic
>
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
BF16
,
BF16
,
BF16
,
unsigned
short
,
F32
,
Acc0BiasDataType
,
Acc1BiasDataType
,
F32
,
F32
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
128
,
64
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
4
,
1
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpec
,
Deterministic
>
,
//
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, unsigned short, F32, Acc0BiasDataType, Acc1BiasDataType, F32, F32, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 4, 1, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>,
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
BF16
,
BF16
,
BF16
,
unsigned
short
,
F32
,
Acc0BiasDataType
,
Acc1BiasDataType
,
F32
,
F32
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
128
,
32
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
4
,
1
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpec
,
Deterministic
>
,
//
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, unsigned short, F32, Acc0BiasDataType, Acc1BiasDataType, F32, F32, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 4, 1, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>,
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
BF16
,
BF16
,
BF16
,
unsigned
short
,
F32
,
Acc0BiasDataType
,
Acc1BiasDataType
,
F32
,
F32
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
128
,
64
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
4
,
1
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpec
,
Deterministic
>
//
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, unsigned short, F32, Acc0BiasDataType, Acc1BiasDataType, F32, F32, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 4, 1, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>
// clang-format on
// clang-format on
>
;
>
;
...
...
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_mha_bwd_qloop_f16_f16_gmk_gnk_gno_gmo_instance.cpp
View file @
72498367
...
@@ -32,10 +32,10 @@ using YElementOp = PassThrough;
...
@@ -32,10 +32,10 @@ using YElementOp = PassThrough;
using
Acc0BiasDataType
=
ck
::
Tuple
<>
;
using
Acc0BiasDataType
=
ck
::
Tuple
<>
;
using
Acc1BiasDataType
=
ck
::
Tuple
<>
;
using
Acc1BiasDataType
=
ck
::
Tuple
<>
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
//
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
static
constexpr
auto
TensorDefault
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
//
static constexpr auto TensorDefault = ck::tensor_operation::device::TensorSpecialization::Default;
static
constexpr
auto
TensorSpecQ
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecQ
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecK
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecK
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecV
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecV
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
...
@@ -57,16 +57,16 @@ using device_batched_mha_bwd_qloop_f16_f16_gmk_gnk_gno_gmo_instances =
...
@@ -57,16 +57,16 @@ using device_batched_mha_bwd_qloop_f16_f16_gmk_gnk_gno_gmo_instances =
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | |
// ########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | |
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
F16
,
F16
,
F16
,
unsigned
short
,
F32
,
Acc0BiasDataType
,
Acc1BiasDataType
,
F32
,
F32
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
64
,
128
,
64
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
2
,
1
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpec
,
Deterministic
>
,
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
F16
,
F16
,
F16
,
unsigned
short
,
F32
,
Acc0BiasDataType
,
Acc1BiasDataType
,
F32
,
F32
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
64
,
128
,
64
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
2
,
1
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpec
,
Deterministic
>
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
F16
,
F16
,
F16
,
unsigned
short
,
F32
,
Acc0BiasDataType
,
Acc1BiasDataType
,
F32
,
F32
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
128
,
64
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
4
,
1
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpec
,
Deterministic
>
,
//
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, unsigned short, F32, Acc0BiasDataType, Acc1BiasDataType, F32, F32, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 4, 1, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>,
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
F16
,
F16
,
F16
,
unsigned
short
,
F32
,
Acc0BiasDataType
,
Acc1BiasDataType
,
F32
,
F32
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
128
,
32
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
4
,
1
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpec
,
Deterministic
>
,
//
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, unsigned short, F32, Acc0BiasDataType, Acc1BiasDataType, F32, F32, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 4, 1, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>,
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
F16
,
F16
,
F16
,
unsigned
short
,
F32
,
Acc0BiasDataType
,
Acc1BiasDataType
,
F32
,
F32
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
128
,
64
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
4
,
1
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpec
,
Deterministic
>
//
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, unsigned short, F32, Acc0BiasDataType, Acc1BiasDataType, F32, F32, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 4, 1, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>
// clang-format on
// clang-format on
>
;
>
;
void
add_device_batched_mha_bwd_qloop_casual_f16_f16_gmk_gnk_gno_gmo_instances
(
void
add_device_batched_mha_bwd_qloop_casual_f16_f16_gmk_gnk_gno_gmo_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceBatched
GemmSoftmaxGemmPermute
<
2
,
DeviceBatched
MultiheadAttentionBackward
<
2
,
1
,
1
,
1
,
1
,
1
,
1
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment