Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
7b915a10
Commit
7b915a10
authored
Jul 28, 2023
by
danyao12
Browse files
bwd qloop 2 kernels update mask
parent
2018bd28
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
43 additions
and
47 deletions
+43
-47
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v2.cpp
..._softmax_gemm/batched_multihead_attention_backward_v2.cpp
+2
-2
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v3.cpp
..._softmax_gemm/batched_multihead_attention_backward_v3.cpp
+3
-2
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v3.cpp
..._softmax_gemm/grouped_multihead_attention_backward_v3.cpp
+4
-3
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
...pl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
+8
-10
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
...pl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
+8
-10
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
...pl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
+9
-10
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
...pl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
+9
-10
No files found.
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v2.cpp
View file @
7b915a10
...
...
@@ -24,7 +24,7 @@ Kernel outputs:
*/
#define PRINT_HOST 0
#define USING_MASK
1
#define USING_MASK
0
#define DIM 128 // DIM should be a multiple of 8.
#include <iostream>
...
...
@@ -268,7 +268,7 @@ int run(int argc, char* argv[])
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
ck
::
index_t
M
=
253
;
ck
::
index_t
M
=
512
;
ck
::
index_t
N
=
512
;
ck
::
index_t
K
=
DIM
;
ck
::
index_t
O
=
DIM
;
...
...
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v3.cpp
View file @
7b915a10
...
...
@@ -86,7 +86,7 @@ static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
#if USING_MASK
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
Mask
Out
UpperTriangle
;
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskUpperTriangle
FromTopLeft
;
#else
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskDisabled
;
...
...
@@ -228,8 +228,9 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
ref_gemm0_invoker
.
Run
(
ref_gemm0_argument
);
// masking
auto
M
=
s_g_m_n
.
GetLengths
()[
1
];
auto
N
=
s_g_m_n
.
GetLengths
()[
2
];
const
auto
mask
=
DeviceGemmInstance
::
C0MatrixMask
(
N
);
const
auto
mask
=
DeviceGemmInstance
::
C0MatrixMask
(
M
,
N
);
s_g_m_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
if
(
mask
.
IsMaskedElement
(
idx
[
1
],
idx
[
2
]))
self
(
idx
)
=
-
ck
::
NumericLimits
<
float
>::
Infinity
();
...
...
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v3.cpp
View file @
7b915a10
...
...
@@ -68,7 +68,7 @@ using GemmDataType = F16;
using
AccDataType
=
F32
;
using
ShuffleDataType
=
F32
;
using
LSEDataType
=
F32
;
using
ZDataType
=
INT32
;
// U16
using
ZDataType
=
U16
;
// INT32
using
Acc0BiasDataType
=
ck
::
Tuple
<>
;
using
Acc1BiasDataType
=
ck
::
Tuple
<>
;
using
DDataType
=
F32
;
...
...
@@ -85,7 +85,7 @@ static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
#if USING_MASK
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
Mask
Out
UpperTriangle
;
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskUpperTriangle
FromTopLeft
;
#else
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskDisabled
;
...
...
@@ -227,8 +227,9 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
ref_gemm0_invoker
.
Run
(
ref_gemm0_argument
);
// masking
auto
M
=
s_g_m_n
.
GetLengths
()[
1
];
auto
N
=
s_g_m_n
.
GetLengths
()[
2
];
const
auto
mask
=
DeviceGemmInstance
::
C0MatrixMask
(
N
);
const
auto
mask
=
DeviceGemmInstance
::
C0MatrixMask
(
M
,
N
);
s_g_m_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
if
(
mask
.
IsMaskedElement
(
idx
[
1
],
idx
[
2
]))
self
(
idx
)
=
-
ck
::
NumericLimits
<
float
>::
Infinity
();
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
View file @
7b915a10
...
...
@@ -359,12 +359,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
index_t
Q_K1
=
8
;
static
constexpr
index_t
K_K1
=
8
;
static
constexpr
index_t
V_N1
=
2
;
static
constexpr
index_t
Q_M1
=
2
;
static
constexpr
index_t
K_N1
=
2
;
static
constexpr
index_t
V_O1
=
8
;
static
constexpr
index_t
Y_O1
=
8
;
static
constexpr
index_t
Y_M1
=
2
;
...
...
@@ -666,9 +660,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
{
return
MaskDisabledPredicate
{};
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskOutUpperTriangle
)
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTriangleFromTopLeft
)
{
return
MaskUpperTriangleFromTopLeftPredicate
{};
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTriangleFromBottomRight
)
{
return
Mask
Out
UpperTrianglePredicate
{};
return
MaskUpperTriangle
FromBottomRight
Predicate
{};
}
}
using
C0MatrixMask
=
C0MatrixMask_impl
<
decltype
(
make_MaskOutPredicate
())
>
;
...
...
@@ -791,7 +789,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
CShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopSched
,
Transform
::
matrix_padder
.
PadN
,
MaskingSpec
=
=
MaskingSpecialization
::
Mask
OutUpperTriang
le
,
MaskingSpec
!
=
MaskingSpecialization
::
Mask
Disab
le
d
,
Deterministic
>
;
// GridwiseYDotYGrad
...
...
@@ -892,7 +890,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
acc_element_op_
{
acc_element_op
},
b1_element_op_
{
b1_element_op
},
c_element_op_
{
c_element_op
},
c0_matrix_mask_
{
b_grid_desc_g_n_k_
.
GetLength
(
I1
)},
c0_matrix_mask_
{
a_grid_desc_g_m_k_
.
GetLength
(
I1
),
b_grid_desc_g_n_k_
.
GetLength
(
I1
)},
raw_lengths_mz_nz_kz_gemm1nz_
{
a_gs_ms_ks_lengths
[
NumDimG
+
NumDimM
-
1
],
b_gs_ns_ks_lengths
[
NumDimG
+
NumDimN
-
1
],
b_gs_ns_ks_lengths
[
NumDimG
+
NumDimN
+
NumDimK
-
1
],
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
View file @
7b915a10
...
...
@@ -365,12 +365,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
index_t
Q_K1
=
8
;
static
constexpr
index_t
K_K1
=
8
;
static
constexpr
index_t
V_N1
=
2
;
static
constexpr
index_t
Q_M1
=
2
;
static
constexpr
index_t
K_N1
=
2
;
static
constexpr
index_t
V_O1
=
8
;
static
constexpr
index_t
Y_O1
=
8
;
static
constexpr
index_t
Y_M1
=
2
;
...
...
@@ -672,9 +666,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
{
return
MaskDisabledPredicate
{};
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskOutUpperTriangle
)
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTriangleFromTopLeft
)
{
return
MaskUpperTriangleFromTopLeftPredicate
{};
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTriangleFromBottomRight
)
{
return
Mask
Out
UpperTrianglePredicate
{};
return
MaskUpperTriangle
FromBottomRight
Predicate
{};
}
}
using
C0MatrixMask
=
C0MatrixMask_impl
<
decltype
(
make_MaskOutPredicate
())
>
;
...
...
@@ -805,7 +803,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
CShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopSched
,
Transform
::
matrix_padder
.
PadN
,
MaskingSpec
=
=
MaskingSpecialization
::
Mask
OutUpperTriang
le
,
MaskingSpec
!
=
MaskingSpecialization
::
Mask
Disab
le
d
,
Deterministic
>
;
// GridwiseYDotYGrad
...
...
@@ -905,7 +903,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
acc_element_op_
{
acc_element_op
},
b1_element_op_
{
b1_element_op
},
c_element_op_
{
c_element_op
},
c0_matrix_mask_
{
b_grid_desc_g_n_k_
.
GetLength
(
I1
)},
c0_matrix_mask_
{
a_grid_desc_g_m_k_
.
GetLength
(
I1
),
b_grid_desc_g_n_k_
.
GetLength
(
I1
)},
raw_lengths_mz_nz_kz_gemm1nz_
{
a_gs_ms_ks_lengths
[
NumDimG
+
NumDimM
-
1
],
b_gs_ns_ks_lengths
[
NumDimG
+
NumDimN
-
1
],
b_gs_ns_ks_lengths
[
NumDimG
+
NumDimN
+
NumDimK
-
1
],
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
View file @
7b915a10
...
...
@@ -352,12 +352,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
index_t
Q_K1
=
8
;
static
constexpr
index_t
K_K1
=
8
;
static
constexpr
index_t
V_N1
=
2
;
static
constexpr
index_t
Q_M1
=
2
;
static
constexpr
index_t
K_N1
=
2
;
static
constexpr
index_t
V_O1
=
8
;
static
constexpr
index_t
Y_O1
=
8
;
static
constexpr
index_t
Y_M1
=
2
;
...
...
@@ -604,9 +598,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
{
return
MaskDisabledPredicate
{};
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskOutUpperTriangle
)
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTriangleFromTopLeft
)
{
return
MaskUpperTriangleFromTopLeftPredicate
{};
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTriangleFromBottomRight
)
{
return
Mask
Out
UpperTrianglePredicate
{};
return
MaskUpperTriangle
FromBottomRight
Predicate
{};
}
}
using
C0MatrixMask
=
C0MatrixMask_impl
<
decltype
(
make_MaskOutPredicate
())
>
;
...
...
@@ -728,7 +726,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
CShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopSched
,
Transform
::
matrix_padder
.
PadN
,
MaskingSpec
=
=
MaskingSpecialization
::
Mask
OutUpperTriang
le
,
MaskingSpec
!
=
MaskingSpecialization
::
Mask
Disab
le
d
,
Deterministic
>
;
using
Block2CTileMap
=
OffsettedBlockToCTileMap
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
;
...
...
@@ -940,7 +938,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
type_convert
<
index_t
>
(
lse_grid_desc_m
.
GetElementSpaceSize
()));
// C0 mask
const
auto
c0_matrix_mask
=
C0MatrixMask
(
b_grid_desc_g_n_k
.
GetLength
(
I1
));
const
auto
c0_matrix_mask
=
C0MatrixMask
(
a_grid_desc_g_m_k
.
GetLength
(
I1
),
b_grid_desc_g_n_k
.
GetLength
(
I1
));
grid_size_
+=
grid_size_grp
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
View file @
7b915a10
...
...
@@ -359,12 +359,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
index_t
Q_K1
=
8
;
static
constexpr
index_t
K_K1
=
8
;
static
constexpr
index_t
V_N1
=
2
;
static
constexpr
index_t
Q_M1
=
2
;
static
constexpr
index_t
K_N1
=
2
;
static
constexpr
index_t
V_O1
=
8
;
static
constexpr
index_t
Y_O1
=
8
;
static
constexpr
index_t
Y_M1
=
2
;
...
...
@@ -604,9 +598,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
{
return
MaskDisabledPredicate
{};
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskOutUpperTriangle
)
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTriangleFromTopLeft
)
{
return
MaskUpperTriangleFromTopLeftPredicate
{};
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTriangleFromBottomRight
)
{
return
Mask
Out
UpperTrianglePredicate
{};
return
MaskUpperTriangle
FromBottomRight
Predicate
{};
}
}
using
C0MatrixMask
=
C0MatrixMask_impl
<
decltype
(
make_MaskOutPredicate
())
>
;
...
...
@@ -736,7 +734,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
CShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopSched
,
Transform
::
matrix_padder
.
PadN
,
MaskingSpec
=
=
MaskingSpecialization
::
Mask
OutUpperTriang
le
,
MaskingSpec
!
=
MaskingSpecialization
::
Mask
Disab
le
d
,
Deterministic
>
;
using
Block2CTileMap
=
OffsettedBlockToCTileMap
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
;
...
...
@@ -948,7 +946,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
type_convert
<
index_t
>
(
lse_grid_desc_m
.
GetElementSpaceSize
()));
// C0 mask
const
auto
c0_matrix_mask
=
C0MatrixMask
(
b_grid_desc_g_n_k
.
GetLength
(
I1
));
const
auto
c0_matrix_mask
=
C0MatrixMask
(
a_grid_desc_g_m_k
.
GetLength
(
I1
),
b_grid_desc_g_n_k
.
GetLength
(
I1
));
grid_size_
+=
grid_size_grp
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment