Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
ab0d58b2
Commit
ab0d58b2
authored
Aug 23, 2023
by
letaoqin
Browse files
change D0M name
parent
2220cf9a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
15 additions
and
14 deletions
+15
-14
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
...id/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
+15
-14
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
View file @
ab0d58b2
...
@@ -1154,8 +1154,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -1154,8 +1154,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
}
}
// D0
// D0
static
constexpr
auto
D0M1
=
Number
<
4
>
{};
static
constexpr
auto
D0M2
=
Number
<
4
>
{};
static
constexpr
auto
D0M0
=
Number
<
MPerBlock
>
{}
/
D0M1
;
static
constexpr
auto
D0M1
=
Number
<
MPerBlock
>
{}
/
D0M2
;
// static constexpr auto D0M = Number<MPerBlock>{} / D0M2;
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3
(
const
D0GridDesc_M_N
&
d0_grid_desc_m_n
)
MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3
(
const
D0GridDesc_M_N
&
d0_grid_desc_m_n
)
...
@@ -1168,7 +1169,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -1168,7 +1169,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const
auto
d0_grid_desc_m0_n0_m1_m2_n1_m3
=
transform_tensor_descriptor
(
const
auto
d0_grid_desc_m0_n0_m1_m2_n1_m3
=
transform_tensor_descriptor
(
d0_grid_desc_m_n
,
d0_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MBlock
,
D0M
0
,
D0M
1
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MBlock
,
D0M
1
,
D0M
2
)),
make_unmerge_transform
(
make_tuple
(
NBlock
,
Number
<
NPerBlock
>
{}))),
make_unmerge_transform
(
make_tuple
(
NBlock
,
Number
<
NPerBlock
>
{}))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
>
{}));
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
>
{}));
...
@@ -1187,24 +1188,24 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -1187,24 +1188,24 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
__host__
__device__
static
constexpr
auto
GetD0BlockDescriptor_M0_N0_M1_M2_N1_M3
()
__host__
__device__
static
constexpr
auto
GetD0BlockDescriptor_M0_N0_M1_M2_N1_M3
()
{
{
// B1 matrix in LDS memory, dst of blockwise copy
// B1 matrix in LDS memory, dst of blockwise copy
return
make_naive_tensor_descriptor
(
make_tuple
(
I1
,
I1
,
D0M
0
,
Number
<
NPerBlock
>
{},
D0M
1
),
return
make_naive_tensor_descriptor
(
make_tuple
(
I1
,
I1
,
D0M
1
,
Number
<
NPerBlock
>
{},
D0M
2
),
make_tuple
(
Number
<
NPerBlock
>
{}
*
D0M
1
,
make_tuple
(
Number
<
NPerBlock
>
{}
*
D0M
2
,
Number
<
NPerBlock
>
{}
*
D0M
1
,
Number
<
NPerBlock
>
{}
*
D0M
2
,
Number
<
NPerBlock
>
{}
*
D0M
1
,
Number
<
NPerBlock
>
{}
*
D0M
2
,
D0M
1
,
D0M
2
,
I1
));
I1
));
}
}
__host__
__device__
static
constexpr
auto
GetD0BlockReadDescriptor_N0_N1_M0_M1_M2_M3
()
__host__
__device__
static
constexpr
auto
GetD0BlockReadDescriptor_N0_N1_M0_M1_M2_M3
()
{
{
constexpr
auto
d0_raw_m0_n_m1
=
constexpr
auto
d0_raw_m0_n_m1
=
make_naive_tensor_descriptor
(
make_tuple
(
D0M
0
,
Number
<
NPerBlock
>
{},
D0M
1
),
make_naive_tensor_descriptor
(
make_tuple
(
D0M
1
,
Number
<
NPerBlock
>
{},
D0M
2
),
make_tuple
(
Number
<
NPerBlock
>
{}
*
D0M
1
,
D0M
1
,
I1
));
make_tuple
(
Number
<
NPerBlock
>
{}
*
D0M
2
,
D0M
2
,
I1
));
constexpr
auto
d0_n0_n1_m0_m1_m2_m3
=
transform_tensor_descriptor
(
constexpr
auto
d0_n0_n1_m0_m1_m2_m3
=
transform_tensor_descriptor
(
d0_raw_m0_n_m1
,
d0_raw_m0_n_m1
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
D0M
0
/
I2
,
I2
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
D0M
1
/
I2
,
I2
)),
make_unmerge_transform
(
make_unmerge_transform
(
make_tuple
(
Number
<
NPerBlock
/
NPerXdl
>
{},
Number
<
NPerXdl
>
{})),
make_tuple
(
Number
<
NPerBlock
/
NPerXdl
>
{},
Number
<
NPerXdl
>
{})),
make_pass_through_transform
(
D0M
1
)),
make_pass_through_transform
(
D0M
2
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
2
,
3
>
{},
Sequence
<
0
,
1
>
{},
Sequence
<
4
>
{}));
make_tuple
(
Sequence
<
2
,
3
>
{},
Sequence
<
0
,
1
>
{},
Sequence
<
4
>
{}));
return
d0_n0_n1_m0_m1_m2_m3
;
return
d0_n0_n1_m0_m1_m2_m3
;
...
@@ -1215,14 +1216,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -1215,14 +1216,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
GetD0BlockReadDescriptor_N0_N1_M0_M1_M2_M3
();
GetD0BlockReadDescriptor_N0_N1_M0_M1_M2_M3
();
static
constexpr
auto
d0_thread_desc_
=
static
constexpr
auto
d0_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
I1
,
I8
,
I1
,
D0M
1
));
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
I1
,
I8
,
I1
,
D0M
2
));
using
D0BlockwiseCopy
=
ThreadGroupTensorSliceTransfer_v4r1
<
using
D0BlockwiseCopy
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
ThisThreadBlock
,
tensor_operation
::
element_wise
::
PassThrough
,
tensor_operation
::
element_wise
::
PassThrough
,
tensor_operation
::
element_wise
::
PassThrough
,
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
I1
,
I1
,
D0M
0
,
NPerBlock
,
D0M
1
>
,
// BlockSliceLengths
Sequence
<
I1
,
I1
,
D0M
1
,
NPerBlock
,
D0M
2
>
,
// BlockSliceLengths
Sequence
<
1
,
1
,
8
,
32
,
1
>
,
// ThreadClusterLengths
Sequence
<
1
,
1
,
8
,
32
,
1
>
,
// ThreadClusterLengths
Sequence
<
0
,
1
,
2
,
4
,
3
>
,
// ThreadClusterArrangeOrder
Sequence
<
0
,
1
,
2
,
4
,
3
>
,
// ThreadClusterArrangeOrder
D0DataType
,
// SrcData
D0DataType
,
// SrcData
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment