Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
17774771
Commit
17774771
authored
Apr 28, 2023
by
danyao12
Browse files
restore original c_grid_desc_m0_n0_m1_n1_m2_n2
parent
8582c75c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
2 additions
and
6 deletions
+2
-6
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt3.hpp
...batched_multihead_attention_backward_xdl_cshuffle_pt3.hpp
+2
-6
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt3.hpp
View file @
17774771
...
@@ -778,16 +778,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -778,16 +778,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
__host__
__device__
static
auto
__host__
__device__
static
auto
MakeCGridDesc_M0_N0_M1_N1_M2_N2_N3_N4
(
const
CGradDesc_M_N
&
c_grid_desc_m_n
)
MakeCGridDesc_M0_N0_M1_N1_M2_N2_N3_N4
(
const
CGradDesc_M_N
&
c_grid_desc_m_n
)
{
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
const
auto
MRepeat
=
M
/
GemmMWave
/
MPerXdl
;
const
auto
NRepeat
=
N
/
GemmNWave
/
NPerXdl
;
// HACK: for unmerge transform, the length of highest dim is irrelevant so we put dummy
// HACK: for unmerge transform, the length of highest dim is irrelevant so we put dummy
// variable I1 there
// variable I1 there
const
auto
c_grid_desc_m0_n0_m1_n1_m2_n2
=
transform_tensor_descriptor
(
const
auto
c_grid_desc_m0_n0_m1_n1_m2_n2
=
transform_tensor_descriptor
(
c_grid_desc_m_n
,
c_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MRepeat
,
GemmMWave
,
MPerXdl
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
I1
,
GemmMWave
,
MPerXdl
)),
make_unmerge_transform
(
make_tuple
(
NRepeat
,
GemmNWave
,
NPerXdl
))),
make_unmerge_transform
(
make_tuple
(
I1
,
GemmNWave
,
NPerXdl
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
>
{}));
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
>
{}));
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment