gaoqiong / composable_kernel / Commits

Commit 17bb1aaa
Authored Jan 17, 2023 by ltqin

add alpha for dV and change alpha for dK dQ

Parent: 4cbab521
Showing 4 changed files with 33 additions and 64 deletions (+33, -64).
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp  (+7, -1)
include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp  (+4, -2)
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp  (+2, -24)
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp  (+20, -37)
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp

@@ -255,6 +255,12 @@ int run(int argc, char* argv[])
     bool input_permute  = false;
     bool output_permute = false;
+    float p_drop           = 0.2;
+    float p_dropout        = 1 - p_drop;
+    float rp_dropout       = 1.0 / p_dropout;
+    float scale_rp_dropout = alpha * rp_dropout;
     if(argc == 1)
     {

@@ -479,7 +485,7 @@ int run(int argc, char* argv[])
         {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides}
         QKVElementOp{},
         QKVElementOp{},
-        Scale{alpha},
+        Scale{scale_rp_dropout}, // dQ *= scale_rp_dropout
         QKVElementOp{},
         YElementOp{});
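The quantities added above follow the usual inverted-dropout convention: with drop probability p_drop, the keep probability is p_dropout = 1 - p_drop and the rescale factor is rp_dropout = 1 / p_dropout; the commit then folds that rescale into the attention scale applied to dQ as scale_rp_dropout = alpha * rp_dropout. A minimal host-side sketch of the arithmetic, assuming alpha is the softmax scaling factor (e.g. 1/sqrt(head_dim)) defined earlier in the example; the concrete values are illustrative only:

// Standalone sketch of the scaling arithmetic used in the example (illustrative values).
#include <cmath>
#include <cstdio>

int main()
{
    const float head_dim = 64.f;                        // assumed head dimension
    const float alpha    = 1.f / std::sqrt(head_dim);   // assumed softmax scale

    const float p_drop           = 0.2f;                // dropout probability
    const float p_dropout        = 1.f - p_drop;        // keep probability (0.8)
    const float rp_dropout       = 1.f / p_dropout;     // inverted-dropout rescale (1.25)
    const float scale_rp_dropout = alpha * rp_dropout;  // combined scale applied to dQ

    std::printf("p_dropout = %f, rp_dropout = %f, scale_rp_dropout = %f\n",
                p_dropout, rp_dropout, scale_rp_dropout);
    return 0;
}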
include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp

@@ -21,8 +21,10 @@ struct BlockwiseDropout
    {
        auto execute_dropout = [&](bool keep, DataType val) {
-           return keep ? val * p_dropout_rescale
-                       : (using_sign_bit ? -val * p_dropout_rescale : float(0));
+           if constexpr(using_sign_bit)
+               return keep ? val : -val;
+           else
+               return keep ? val * p_dropout_rescale : float(0);
        };
        constexpr int tmp_size = MRepeat * KRepeat;
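On its own, the new lambda changes how dropped elements are encoded when using_sign_bit is set: kept elements pass through unchanged and dropped elements are negated, so the 1/(1 - p_drop) rescale is no longer applied here and can be deferred to a later element-wise operation (see the Relu and Scale{rp_dropout} changes in the gridwise kernels below). A simplified standalone sketch of the two behaviours, not the actual CK class:

// Simplified illustration of the two dropout encodings (not the CK implementation).
#include <cstdio>

template <bool using_sign_bit>
float execute_dropout(bool keep, float val, float p_dropout_rescale)
{
    if constexpr(using_sign_bit)
        return keep ? val : -val;                    // defer rescale; mark dropped values via the sign
    else
        return keep ? val * p_dropout_rescale : 0.f; // classic inverted dropout
}

int main()
{
    const float rescale = 1.f / 0.8f; // 1/(1 - p_drop) for p_drop = 0.2
    std::printf("sign-bit kept:    %f\n", execute_dropout<true>(true, 2.f, rescale));
    std::printf("sign-bit dropped: %f\n", execute_dropout<true>(false, 2.f, rescale));
    std::printf("classic  kept:    %f\n", execute_dropout<false>(true, 2.f, rescale));
    std::printf("classic  dropped: %f\n", execute_dropout<false>(false, 2.f, rescale));
    return 0;
}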
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp

@@ -742,28 +742,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
                                           1,    // DstScalarStrideInVector
                                           true>;
-    using ABlockwiseCopy_dV = ThreadwiseTensorSliceTransfer_v1r3<
-        FloatGemmAcc,
-        DataType,
-        decltype(a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
-        decltype(a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
-        tensor_operation::element_wise::Relu,
-        Sequence<Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I0), // ThreadSliceLengths
-                 Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I1),
-                 I1,
-                 I1,
-                 I1,
-                 N2,
-                 I1,
-                 N4>,
-        Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
-        7,                              // DstVectorDim
-        1,                              // DstScalarPerVector
-        InMemoryDataOperationEnum::Set,
-        1,                              // DstScalarStrideInVector
-        true>;
    template <typename GridDesc_M0_O_M1>
    using BBlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1<
        ThisThreadBlock,

@@ -1401,10 +1379,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
            Gemm2::MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4(vgrad_grid_desc_n_o);
        // dV: A matrix VGPR-to-LDS blockwise copy
-       auto vgrad_gemm_tile_p_thread_copy_vgpr_to_lds = typename Gemm2::ABlockwiseCopy_dV{
+       auto vgrad_gemm_tile_p_thread_copy_vgpr_to_lds = typename Gemm2::ABlockwiseCopy{
            Gemm2::a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
            Gemm2::MakeAThreadOriginOnBlock_M0_N0_M1_N1_M2_N2_N3_N4(),
-           tensor_operation::element_wise::Relu{}}; // relu(P-dropped)
+           tensor_operation::element_wise::PassThrough{}};
        // dV: B matrix global-to-LDS blockwise copy
        auto vgrad_gemm_tile_ygrad_blockwise_copy =
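For context on the two element-wise operators swapped here: PassThrough writes the value unchanged, while Relu clamps negatives to zero, which is how the sign-bit-encoded dropped elements of P were being discarded during the VGPR-to-LDS copy. A simplified sketch of that behaviour with stand-in functors (not the exact CK definitions):

// Stand-ins for the element-wise functors referenced above (illustrative only).
#include <algorithm>
#include <cstdio>

struct PassThrough
{
    void operator()(float& y, const float& x) const { y = x; }                // copy unchanged
};

struct Relu
{
    void operator()(float& y, const float& x) const { y = std::max(x, 0.f); } // zero negatives
};

int main()
{
    float y = 0.f;
    Relu{}(y, -3.f);        // a dropped (negated) element is cleared to 0
    std::printf("Relu(-3)        = %f\n", y);
    PassThrough{}(y, -3.f); // the value is forwarded as-is
    std::printf("PassThrough(-3) = %f\n", y);
    return 0;
}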
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp

@@ -722,34 +722,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
                             typename Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1,
                             false>;
+    template <typename ElementwiseOp = tensor_operation::element_wise::PassThrough>
     using ABlockwiseCopy = ThreadwiseTensorSliceTransfer_v1r3<
         FloatGemmAcc,
         DataType,
         decltype(a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
         decltype(a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
-        tensor_operation::element_wise::PassThrough,
+        ElementwiseOp,
         Sequence<Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I0), // ThreadSliceLengths
                  Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I1),
                  I1,
                  I1,
                  I1,
                  N2,
                  I1,
                  N4>,
         Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
         7,                              // DstVectorDim
         1,                              // DstScalarPerVector
         InMemoryDataOperationEnum::Set,
         1,                              // DstScalarStrideInVector
         true>;
-    using ABlockwiseCopy_dV = ThreadwiseTensorSliceTransfer_v1r3<
-        FloatGemmAcc,
-        DataType,
-        decltype(a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
-        decltype(a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
-        tensor_operation::element_wise::Relu,
-        Sequence<Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I0), // ThreadSliceLengths
-                 Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I1),
...
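The change above replaces the dedicated ABlockwiseCopy_dV alias with a single alias template whose element-wise operator defaults to PassThrough, so the dV path can instantiate it with Relu and the dK path with PassThrough (see the later hunks). A minimal sketch of that alias-template pattern with placeholder types, not the real ThreadwiseTensorSliceTransfer_v1r3 parameter list:

// Minimal sketch of an alias template with a defaulted element-wise op (placeholder types).
#include <cstdio>

struct PassThrough { float operator()(float x) const { return x; } };
struct Relu        { float operator()(float x) const { return x > 0.f ? x : 0.f; } };

template <typename ElementwiseOp>
struct Copy
{
    ElementwiseOp op;                           // op applied to each element during the copy
    float run(float x) const { return op(x); }
};

struct Gemm2
{
    template <typename ElementwiseOp = PassThrough>
    using ABlockwiseCopy = Copy<ElementwiseOp>; // defaulted op, as in the diff
};

int main()
{
    Gemm2::ABlockwiseCopy<Relu> dv_copy{Relu{}};         // dV path: Relu drops negated values
    Gemm2::ABlockwiseCopy<>     dk_copy{PassThrough{}};  // dK path: default PassThrough
    std::printf("dV copy of -1.5: %f\n", dv_copy.run(-1.5f));
    std::printf("dK copy of -1.5: %f\n", dk_copy.run(-1.5f));
    return 0;
}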
@@ -1410,10 +1389,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
            Gemm2::MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4(vgrad_grid_desc_n_o);
        // dV: A matrix VGPR-to-LDS blockwise copy
-       auto vgrad_gemm_tile_p_thread_copy_vgpr_to_lds = typename Gemm2::ABlockwiseCopy_dV{
-           Gemm2::a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
-           Gemm2::MakeAThreadOriginOnBlock_M0_N0_M1_N1_M2_N2_N3_N4(),
-           tensor_operation::element_wise::Relu{}}; // relu(P-dropped)
+       auto vgrad_gemm_tile_p_thread_copy_vgpr_to_lds =
+           typename Gemm2::template ABlockwiseCopy<tensor_operation::element_wise::Relu>{
+               Gemm2::a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
+               Gemm2::MakeAThreadOriginOnBlock_M0_N0_M1_N1_M2_N2_N3_N4(),
+               tensor_operation::element_wise::Relu{}}; // relu(P-dropped)
        // dV: B matrix global-to-LDS blockwise copy
        auto vgrad_gemm_tile_ygrad_blockwise_copy =
@@ -1438,11 +1418,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
            make_multi_index(
                I0, block_work_idx[I1] * Gemm2Params_N_O_M::GemmORepeat, I0, I0, I0, I0, I0, I0);
-       auto vgrad_thread_copy_vgpr_to_global = typename Gemm2::template CBlockwiseCopy<decltype(
-           vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4)>(
-           vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
-           vgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4,
-           tensor_operation::element_wise::PassThrough{});
+       auto vgrad_thread_copy_vgpr_to_global =
+           typename Gemm2::template CBlockwiseCopy<decltype(vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4),
+                                                   tensor_operation::element_wise::Scale>(
+               vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
+               vgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4,
+               tensor_operation::element_wise::Scale{rp_dropout});
        // dK: transform input and output tensor descriptors
        const auto q_grid_desc_m0_k_m1 =
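Since the sign-bit dropout path in blockwise_dropout.hpp no longer multiplies by p_dropout_rescale, the Scale{rp_dropout} operator passed to the dV VGPR-to-global copy appears to be where the 1/(1 - p_drop) factor is restored on write-out. A simplified stand-in for such a scaling functor (illustrative only, not the CK definition of tensor_operation::element_wise::Scale):

// Stand-in for an element-wise scaling functor such as Scale{rp_dropout} (illustrative only).
#include <cstdio>

struct Scale
{
    explicit Scale(float s) : scale(s) {}
    void operator()(float& y, const float& x) const { y = scale * x; } // rescale on write-out
    float scale;
};

int main()
{
    const float rp_dropout = 1.f / 0.8f; // 1/(1 - p_drop) for p_drop = 0.2
    Scale scale_op{rp_dropout};

    float dv_out = 0.f;
    scale_op(dv_out, 3.2f);              // a dV accumulator element being written to global memory
    std::printf("scaled dV element: %f\n", dv_out);
    return 0;
}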
@@ -1453,10 +1435,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
            Gemm2::MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4(kgrad_grid_desc_n_k);
        // dK: A matrix VGPR-to-LDS blockwise copy
-       auto kgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds = typename Gemm2::ABlockwiseCopy{
-           Gemm2::a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
-           Gemm2::MakeAThreadOriginOnBlock_M0_N0_M1_N1_M2_N2_N3_N4(),
-           tensor_operation::element_wise::PassThrough{}};
+       auto kgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds =
+           typename Gemm2::template ABlockwiseCopy<tensor_operation::element_wise::PassThrough>{
+               Gemm2::a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
+               Gemm2::MakeAThreadOriginOnBlock_M0_N0_M1_N1_M2_N2_N3_N4(),
+               tensor_operation::element_wise::PassThrough{}};
        // dK: B matrix global-to-LDS blockwise copy
        auto kgrad_gemm_tile_q_blockwise_copy =
@@ -1724,7 +1707,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
            // P_dropped
            blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), true>(
-               s_slash_p_thread_buf, ph, gemm1_k_block_outer_index, num_gemm1_k_block_outer_loop);
+               s_slash_p_thread_buf, ph);
            block_sync_lds(); // wait for gemm1 LDS read