Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
105c6a08
Commit
105c6a08
authored
Jan 16, 2023
by
qin letao
Browse files
change dv thread copy
parent
7409bc5d
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
24 additions
and
2 deletions
+24
-2
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp
..._batched_multihead_attention_backward_xdl_cshuffle_v1.hpp
+24
-2
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp
View file @
105c6a08
...
@@ -742,6 +742,28 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
...
@@ -742,6 +742,28 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
1
,
// DstScalarStrideInVector
1
,
// DstScalarStrideInVector
true
>
;
true
>
;
// Transfer type used for the dV GEMM's A operand; the use site labels it the
// "dV: A matrix VGPR-to-LDS blockwise copy". The element-wise operation applied
// while copying is Relu, and the destination memory operation is Set.
// NOTE(review): the per-argument roles noted below follow the inline labels and
// the presumed parameter order of ThreadwiseTensorSliceTransfer_v1r3 — confirm
// against its declaration.
using ABlockwiseCopy_dV = ThreadwiseTensorSliceTransfer_v1r3<
    FloatGemmAcc, // presumably the source element type — confirm
    DataType,     // presumably the destination element type — confirm
    decltype(a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
    decltype(a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
    tensor_operation::element_wise::Relu,
    // ThreadSliceLengths: the first two entries come from
    // ABlockSliceLengths_M0_N0_M1_N1; the remaining six are I1, I1, I1, N2, I1, N4.
    Sequence<Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I0),
             Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I1),
             I1,
             I1,
             I1,
             N2,
             I1,
             N4>,
    Sequence<0, 1, 2, 3, 4, 5, 6, 7>, // presumably the dimension access order — confirm
    7,                                // DstVectorDim
    1,                                // DstScalarPerVector
    InMemoryDataOperationEnum::Set,
    1, // DstScalarStrideInVector
    true>;
template
<
typename
GridDesc_M0_O_M1
>
template
<
typename
GridDesc_M0_O_M1
>
using
BBlockwiseCopy
=
ThreadGroupTensorSliceTransfer_v4r1
<
using
BBlockwiseCopy
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
ThisThreadBlock
,
...
@@ -1379,10 +1401,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
...
@@ -1379,10 +1401,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
Gemm2
::
MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4
(
vgrad_grid_desc_n_o
);
Gemm2
::
MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4
(
vgrad_grid_desc_n_o
);
// dV: A matrix VGPR-to-LDS blockwise copy
// dV: A matrix VGPR-to-LDS blockwise copy
auto
vgrad_gemm_tile_p_thread_copy_vgpr_to_lds
=
typename
Gemm2
::
ABlockwiseCopy
{
auto
vgrad_gemm_tile_p_thread_copy_vgpr_to_lds
=
typename
Gemm2
::
ABlockwiseCopy
_dV
{
Gemm2
::
a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
,
Gemm2
::
a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
,
Gemm2
::
MakeAThreadOriginOnBlock_M0_N0_M1_N1_M2_N2_N3_N4
(),
Gemm2
::
MakeAThreadOriginOnBlock_M0_N0_M1_N1_M2_N2_N3_N4
(),
tensor_operation
::
element_wise
::
PassThrough
{}};
tensor_operation
::
element_wise
::
Relu
{}};
//relu(P-dropped)
// dV: B matrix global-to-LDS blockwise copy
// dV: B matrix global-to-LDS blockwise copy
auto
vgrad_gemm_tile_ygrad_blockwise_copy
=
auto
vgrad_gemm_tile_ygrad_blockwise_copy
=
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment