Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
70d700b3
Commit
70d700b3
authored
Sep 04, 2023
by
danyao12
Browse files
optimized bwd split kernels w/ bias
parent
9e11dea6
Changes
10
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
967 additions
and
553 deletions
+967
-553
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v3.cpp
..._softmax_gemm/batched_multihead_attention_backward_v3.cpp
+2
-2
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v3.cpp
..._softmax_gemm/grouped_multihead_attention_backward_v3.cpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
...pl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
+262
-181
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
...pl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
+267
-188
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp
...ice/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp
+6
-6
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp
...ice/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp
+6
-6
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
...pl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
+177
-82
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
...pl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
+244
-85
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp
...ice/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp
...ice/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp
+1
-1
No files found.
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v3.cpp
View file @
70d700b3
...
...
@@ -83,7 +83,7 @@ static constexpr ck::index_t NumDimO = 1;
// When OutputDataType == F16/BF16, CShuffleBlockTransferScalarPerVector_NPerBlock = 8
static
constexpr
ck
::
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
=
8
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
#if USING_MASK
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskUpperTriangleFromTopLeft
;
...
...
@@ -119,7 +119,7 @@ using DeviceGemmInstance =
// ##############################################################################################| | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Block| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ##############################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | |
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, DDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 64, 128, 64, 64, 32, 64, 8, 8, 2, 32, 32, 2, 1, 2, 1, 64, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 4, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
OutputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
DDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
128
,
64
,
64
,
32
,
32
,
8
,
8
,
2
,
32
,
32
,
2
,
1
,
2
,
1
,
64
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
4
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpec
,
Deterministic
>
;
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
OutputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
DDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
128
,
64
,
64
,
32
,
32
,
8
,
8
,
2
,
32
,
32
,
4
,
1
,
2
,
1
,
64
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
4
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpec
,
Deterministic
>
;
// ##############################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| DDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| Gemm2| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2|YDotYGrad| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ##############################################################################################| | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| KPer| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ##############################################################################################| | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Block| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
...
...
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v3.cpp
View file @
70d700b3
...
...
@@ -118,7 +118,7 @@ using DeviceGemmInstance =
// ##############################################################################################| | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Block| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ##############################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | |
// ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, DDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 64, 128, 64, 64, 32, 64, 8, 8, 2, 32, 32, 2, 1, 2, 1, 64, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 4, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
OutputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
DDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
128
,
64
,
64
,
32
,
32
,
8
,
8
,
2
,
32
,
32
,
2
,
1
,
2
,
1
,
64
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
4
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpec
,
Deterministic
>
;
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
OutputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
DDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
128
,
64
,
64
,
32
,
32
,
8
,
8
,
2
,
32
,
32
,
4
,
1
,
2
,
1
,
64
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
4
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpec
,
Deterministic
>
;
// ##############################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| DDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| Gemm2| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2|YDotYGrad| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ##############################################################################################| | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| KPer| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ##############################################################################################| | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Block| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
View file @
70d700b3
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
View file @
70d700b3
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp
View file @
70d700b3
...
...
@@ -566,9 +566,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
return
lse_grid_desc_mraw
;
}
}
// D in Gemm0 C position
static
auto
MakeDGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
d_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
d_gs_ms_ns_strides
)
// D
0
in Gemm0 C position
static
auto
MakeD
0
GridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
d_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
d_gs_ms_ns_strides
)
{
return
Transform
::
MakeCGridDescriptor_M_N
(
d_gs_ms_ns_lengths
,
d_gs_ms_ns_strides
);
}
...
...
@@ -585,7 +585,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
using
CGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
ZGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
D0GridDesc_M_N
=
decltype
(
MakeDGridDescriptor_M_N
({},
{}));
using
D0GridDesc_M_N
=
decltype
(
MakeD
0
GridDescriptor_M_N
({},
{}));
using
KGridDesc_N_K
=
decltype
(
Transform
::
MakeB0GridDescriptor_N_K
({},
{}));
using
YGradGridDesc_O0_M_O1
=
decltype
(
MakeYGradGridDescriptor_O0_M_O1
({},
{}));
using
ZGridDesc_M_N
=
decltype
(
MakeZGridDescriptor_M_N
({},
{}));
...
...
@@ -857,8 +857,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
}
if
constexpr
(
!
is_same
<
D0DataType
,
void
>::
value
)
{
const
auto
d0_grid_desc_m_n
=
MakeDGridDescriptor_M_N
(
acc0_bias_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_strides
);
const
auto
d0_grid_desc_m_n
=
MakeD0GridDescriptor_M_N
(
acc0_bias_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_strides
);
d0_grid_desc_m0_n0_m1_m2_n1_m3_
=
GridwiseGemm
::
MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3
(
d0_grid_desc_m_n
);
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp
View file @
70d700b3
...
...
@@ -518,9 +518,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
return
Transform
::
MakeB0GridDescriptor_BK0_N_BK1
(
v_grid_desc_n_o
,
Number
<
V_O1
>
{});
}
// D in Gemm0 C position
static
auto
MakeDGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
d_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
d_gs_ms_ns_strides
)
// D
0
in Gemm0 C position
static
auto
MakeD
0
GridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
d_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
d_gs_ms_ns_strides
)
{
return
Transform
::
MakeCGridDescriptor_M_N
(
d_gs_ms_ns_lengths
,
d_gs_ms_ns_strides
);
}
...
...
@@ -594,7 +594,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
using
CGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
ZGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
D0GridDesc_M_N
=
decltype
(
MakeDGridDescriptor_M_N
({},
{}));
using
D0GridDesc_M_N
=
decltype
(
MakeD
0
GridDescriptor_M_N
({},
{}));
using
KGridDesc_N_K
=
decltype
(
Transform
::
MakeB0GridDescriptor_N_K
({},
{}));
using
YGradGridDesc_M0_O_M1
=
decltype
(
MakeYGradGridDescriptor_M0_O_M1
(
YGridDesc_M_O
{}));
using
ZGridDesc_M_N
=
decltype
(
MakeZGridDescriptor_M_N
({},
{}));
...
...
@@ -870,8 +870,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
if
constexpr
(
!
is_same
<
D0DataType
,
void
>::
value
)
{
const
auto
d0_grid_desc_m_n
=
MakeDGridDescriptor_M_N
(
acc0_bias_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_strides
);
const
auto
d0_grid_desc_m_n
=
MakeD0GridDescriptor_M_N
(
acc0_bias_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_strides
);
d0_grid_desc_m0_n0_m1_m2_n1_m3_
=
GridwiseGemm
::
MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3
(
d0_grid_desc_m_n
);
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
View file @
70d700b3
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
View file @
70d700b3
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp
View file @
70d700b3
...
...
@@ -498,7 +498,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
return
lse_grid_desc_mraw
;
}
}
// D in Gemm0 C position
// D
0
in Gemm0 C position
static
auto
MakeD0GridDescriptor_M_N
(
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
)
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp
View file @
70d700b3
...
...
@@ -561,7 +561,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
return
lse_grid_desc_mraw
;
}
}
// D in Gemm0 C position
// D
0
in Gemm0 C position
static
auto
MakeD0GridDescriptor_M_N
(
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment