Commit 17774771 authored by danyao12's avatar danyao12
Browse files

restore original c_grid_desc_m0_n0_m1_n1_m2_n2

parent 8582c75c
...@@ -778,16 +778,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -778,16 +778,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
__host__ __device__ static auto __host__ __device__ static auto
MakeCGridDesc_M0_N0_M1_N1_M2_N2_N3_N4(const CGradDesc_M_N& c_grid_desc_m_n) MakeCGridDesc_M0_N0_M1_N1_M2_N2_N3_N4(const CGradDesc_M_N& c_grid_desc_m_n)
{ {
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const auto MRepeat = M / GemmMWave / MPerXdl;
const auto NRepeat = N / GemmNWave / NPerXdl;
// HACK: for unmerge transform, the length of highest dim is irrelevant so we put dummy // HACK: for unmerge transform, the length of highest dim is irrelevant so we put dummy
// variable I1 there // variable I1 there
const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
c_grid_desc_m_n, c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MRepeat, GemmMWave, MPerXdl)), make_tuple(make_unmerge_transform(make_tuple(I1, GemmMWave, MPerXdl)),
make_unmerge_transform(make_tuple(NRepeat, GemmNWave, NPerXdl))), make_unmerge_transform(make_tuple(I1, GemmNWave, NPerXdl))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{})); make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment