Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
e3eb4381
Commit
e3eb4381
authored
Aug 17, 2023
by
letaoqin
Browse files
add d0_block_copy_global_to_lds
parent
77df3ccb
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
52 additions
and
4 deletions
+52
-4
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
...id/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
+52
-4
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
View file @
e3eb4381
...
@@ -1179,13 +1179,51 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -1179,13 +1179,51 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
return
d0_grid_desc_m0_n0_m1_m2_n1_m3
;
return
d0_grid_desc_m0_n0_m1_m2_n1_m3
;
}
}
struct
D0
{
};
using
D0GridDescriptor_M0_N0_M1_M2_N1_M3
=
using
D0GridDescriptor_M0_N0_M1_M2_N1_M3
=
remove_cvref_t
<
decltype
(
MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3
(
D0GridDesc_M_N
{}))
>
;
remove_cvref_t
<
decltype
(
MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3
(
D0GridDesc_M_N
{}))
>
;
// Compile-time helper grouping the block-level descriptor for a D0 tile and the
// thread-group copy type that moves that tile from the D0 grid descriptor into
// the block descriptor. It holds no state: only a constexpr descriptor and a
// type alias.
struct D0Loader
{
    // Builds the destination (block) descriptor of the D0 copy.
    // Lengths are (1, 1, 1, D0M2, NPerBlock, D0M3); the strides make the last
    // three dimensions a packed layout (M2 stride = NPerBlock * D0M3,
    // N1 stride = D0M3, M3 stride = 1). The three leading dimensions have
    // length 1, so their strides never contribute to an offset.
    // NOTE(review): the function name ends in "_M" while the member below and
    // the grid descriptor type end in "_M3" - looks like a truncated name; confirm.
    __host__ __device__ static constexpr auto GetD0BlockDescriptor_M0_N0_M1_M2_N1_M()
    {
        // D0 tile in LDS memory (per the global-to-LDS copy that uses it), dst of blockwise copy
        return make_naive_tensor_descriptor(
            make_tuple(I1, I1, I1, D0M2, Number<NPerBlock>{}, D0M3),
            make_tuple(Number<NPerBlock>{} * D0M3,
                       Number<NPerBlock>{} * D0M3,
                       Number<NPerBlock>{} * D0M3,
                       Number<NPerBlock>{} * D0M3,
                       D0M3,
                       I1));
    }

    // Block descriptor instance; used as the destination descriptor type below
    // and passed by value when the copy object is constructed.
    static constexpr auto d0_block_desc_m0_n0_m1_m2_n1_m3 =
        GetD0BlockDescriptor_M0_N0_M1_M2_N1_M();

    // Thread-group slice transfer from the D0 grid descriptor (source) to the
    // block descriptor above (destination), with pass-through element-wise ops
    // on both sides and a plain "Set" store.
    // NOTE(review): ThreadClusterLengths multiplies out to 8 * 32 = 256 threads
    // and SrcScalarPerVector divides NPerBlock by the same literal 32 -
    // presumably tied to a 256-thread block; confirm against BlockSize.
    // NOTE(review): DstVectorDim is 2, which has length 1 in BlockSliceLengths,
    // while DstScalarPerVector is D0M3.value - these look inconsistent; confirm.
    using D0BlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1<
        ThisThreadBlock,
        tensor_operation::element_wise::PassThrough,
        tensor_operation::element_wise::PassThrough,
        InMemoryDataOperationEnum::Set,
        Sequence<1, 1, 1, D0M2, NPerBlock, D0M3>,  // BlockSliceLengths
        Sequence<1, 1, 1, 8, 32, 1>,               // ThreadClusterLengths
        Sequence<0, 1, 2, 3, 5, 4>,                // ThreadClusterArrangeOrder
        D0DataType,                                // SrcData
        D0DataType,                                // DstData
        D0GridDescriptor_M0_N0_M1_M2_N1_M3,        // SrcDesc
        decltype(d0_block_desc_m0_n0_m1_m2_n1_m3), // DstDesc
        Sequence<0, 1, 2, 3, 5, 4>,                // SrcDimAccessOrder
        Sequence<0, 1, 2, 4, 3, 5>,                // DstDimAccessOrder
        4,                                         // SrcVectorDim
        2,                                         // DstVectorDim
        NPerBlock / 32,                            // SrcScalarPerVector
        D0M3.value / 1,                            // DstScalarPerVector
        // The remaining arguments are unlabeled in the original; presumably
        // src/dst scalar stride in vector, src reset-coordinate flag, dst
        // reset-coordinate flag and thread scratch count - TODO confirm against
        // the ThreadGroupTensorSliceTransfer_v4r1 declaration.
        1,
        1,
        false,
        true, // DstResetCoord
        1>;
};
template
<
bool
HasMainKBlockLoop
,
template
<
bool
HasMainKBlockLoop
,
bool
IsDropout
,
bool
IsDropout
,
typename
Block2CTileMap
,
typename
Block2CTileMap
,
...
@@ -1513,6 +1551,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -1513,6 +1551,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
qgrad_thread_origin_on_grid_m0_o0_m1_o1_m2_o2_o3_o4
,
qgrad_thread_origin_on_grid_m0_o0_m1_o1_m2_o2_o3_o4
,
scale_rp_dropout
);
scale_rp_dropout
);
// D0
auto
d0_block_copy_global_to_lds
=
typename
D0Loader
::
D0BlockwiseCopy
(
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
0
,
block_work_idx_n
,
0
,
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{},
D0Loader
::
d0_block_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
0
,
0
,
0
,
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{});
//
//
// Blockwise softmax
// Blockwise softmax
//
//
...
@@ -1896,6 +1942,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -1896,6 +1942,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
block_sync_lds
();
// wait for lds read in gemm0 blockwise gemm
block_sync_lds
();
// wait for lds read in gemm0 blockwise gemm
// add bias
// P_i: = softmax(scalar * S_i:)
// P_i: = softmax(scalar * S_i:)
// scaling is already performed in the preceding statements with s_element_op
// scaling is already performed in the preceding statements with s_element_op
blockwise_softmax
.
RunWithPreCalcStats
(
s_slash_p_thread_buf
,
lse_thread_buf
);
blockwise_softmax
.
RunWithPreCalcStats
(
s_slash_p_thread_buf
,
lse_thread_buf
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment