Commit ab0d58b2 authored by letaoqin

change D0M name

parent 2220cf9a
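For context on the rename: the commit shifts the D0 M-dimension constants down one level to match the M0/M1/M2 hierarchy used in the descriptor names, so the innermost length of 4 is now `D0M2` (formerly `D0M1`) and the per-tile count `MPerBlock / 4` is now `D0M1` (formerly `D0M0`). A minimal sketch of the new relationship, with `MPerBlock = 128` as an assumed example value:

```cpp
// Sketch of the renamed D0 M-split constants. MPerBlock = 128 is an
// assumed example value; in the kernel it is a template parameter.
#include <cstdio>

int main()
{
    constexpr int MPerBlock = 128;              // block tile size along M (assumed)
    constexpr int D0M2      = 4;                // innermost (vector) length, was D0M1
    constexpr int D0M1      = MPerBlock / D0M2; // vectors per tile, was D0M0

    // The M axis of one block tile is now read as [D0M1, D0M2] = [32, 4].
    std::printf("D0M1 = %d, D0M2 = %d\n", D0M1, D0M2);
    return 0;
}
```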
@@ -1154,8 +1154,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
     }
     // D0
-    static constexpr auto D0M1 = Number<4>{};
-    static constexpr auto D0M0 = Number<MPerBlock>{} / D0M1;
+    static constexpr auto D0M2 = Number<4>{};
+    static constexpr auto D0M1 = Number<MPerBlock>{} / D0M2;
+    // static constexpr auto D0M = Number<MPerBlock>{} / D0M2;
     __host__ __device__ static constexpr auto
     MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(const D0GridDesc_M_N& d0_grid_desc_m_n)
@@ -1168,7 +1169,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
         const auto d0_grid_desc_m0_n0_m1_m2_n1_m3 = transform_tensor_descriptor(
             d0_grid_desc_m_n,
-            make_tuple(make_unmerge_transform(make_tuple(MBlock, D0M0, D0M1)),
+            make_tuple(make_unmerge_transform(make_tuple(MBlock, D0M1, D0M2)),
                        make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
             make_tuple(Sequence<0>{}, Sequence<1>{}),
             make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3>{}));
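In Composable Kernel, `make_unmerge_transform` splits one linear index into hierarchical coordinates with the last coordinate varying fastest; the hunk above therefore now decomposes the grid M index into `(MBlock, D0M1, D0M2)` instead of `(MBlock, D0M0, D0M1)`. A standalone sketch of that index arithmetic, with assumed example sizes and no dependency on the library headers:

```cpp
// Standalone illustration of the unmerge used above: a flat m index is
// split into (m_block, m1, m2) with lengths (MBlock, D0M1, D0M2), the
// rightmost coordinate varying fastest. Sizes are assumed examples.
#include <cassert>

int main()
{
    constexpr int D0M2 = 4;
    constexpr int D0M1 = 128 / D0M2; // MPerBlock = 128 assumed

    const int m = 517; // arbitrary flat index along M

    const int m2      = m % D0M2;
    const int m1      = (m / D0M2) % D0M1;
    const int m_block = m / (D0M2 * D0M1);

    // Re-merging the coordinates recovers the original index.
    assert((m_block * D0M1 + m1) * D0M2 + m2 == m);
    return 0;
}
```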
@@ -1187,24 +1188,24 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
     __host__ __device__ static constexpr auto GetD0BlockDescriptor_M0_N0_M1_M2_N1_M3()
     {
         // B1 matrix in LDS memory, dst of blockwise copy
-        return make_naive_tensor_descriptor(make_tuple(I1, I1, D0M0, Number<NPerBlock>{}, D0M1),
-                                            make_tuple(Number<NPerBlock>{} * D0M1,
-                                                       Number<NPerBlock>{} * D0M1,
-                                                       Number<NPerBlock>{} * D0M1,
-                                                       D0M1,
+        return make_naive_tensor_descriptor(make_tuple(I1, I1, D0M1, Number<NPerBlock>{}, D0M2),
+                                            make_tuple(Number<NPerBlock>{} * D0M2,
+                                                       Number<NPerBlock>{} * D0M2,
+                                                       Number<NPerBlock>{} * D0M2,
+                                                       D0M2,
                                                        I1));
     }
     __host__ __device__ static constexpr auto GetD0BlockReadDescriptor_N0_N1_M0_M1_M2_M3()
     {
         constexpr auto d0_raw_m0_n_m1 =
-            make_naive_tensor_descriptor(make_tuple(D0M0, Number<NPerBlock>{}, D0M1),
-                                         make_tuple(Number<NPerBlock>{} * D0M1, D0M1, I1));
+            make_naive_tensor_descriptor(make_tuple(D0M1, Number<NPerBlock>{}, D0M2),
+                                         make_tuple(Number<NPerBlock>{} * D0M2, D0M2, I1));
         constexpr auto d0_n0_n1_m0_m1_m2_m3 = transform_tensor_descriptor(
             d0_raw_m0_n_m1,
-            make_tuple(make_unmerge_transform(make_tuple(D0M0 / I2, I2)),
+            make_tuple(make_unmerge_transform(make_tuple(D0M1 / I2, I2)),
                        make_unmerge_transform(
                            make_tuple(Number<NPerBlock / NPerXdl>{}, Number<NPerXdl>{})),
-                       make_pass_through_transform(D0M1)),
+                       make_pass_through_transform(D0M2)),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
             make_tuple(Sequence<2, 3>{}, Sequence<0, 1>{}, Sequence<4>{}));
         return d0_n0_n1_m0_m1_m2_m3;
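Both LDS descriptors above pair lengths `[D0M1, NPerBlock, D0M2]` with strides `[NPerBlock * D0M2, D0M2, 1]`, i.e. a row-major `[D0M1][NPerBlock][D0M2]` layout in which the length-4 `D0M2` vector stays contiguous. A sketch of the implied offset arithmetic, under assumed example tile sizes:

```cpp
// Offset computation implied by make_naive_tensor_descriptor with
// lengths [D0M1, NPerBlock, D0M2] and strides [NPerBlock * D0M2, D0M2, 1].
// MPerBlock = 128 and NPerBlock = 64 are assumed example values.
#include <cstdio>

int main()
{
    constexpr int NPerBlock = 64;
    constexpr int D0M2      = 4;
    constexpr int D0M1      = 128 / D0M2; // m1 ranges over [0, D0M1)
    static_assert(D0M1 == 32, "consistency of the assumed sizes");

    constexpr int stride_m1 = NPerBlock * D0M2; // step between m1 slices
    constexpr int stride_n  = D0M2;             // step between n columns
    constexpr int stride_m2 = 1;                // contiguous vector dimension

    // Offset of element (m1, n, m2) within the LDS tile.
    auto offset = [&](int m1, int n, int m2) {
        return m1 * stride_m1 + n * stride_n + m2 * stride_m2;
    };
    std::printf("offset(1, 2, 3) = %d\n", offset(1, 2, 3)); // prints 267
    return 0;
}
```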
@@ -1215,14 +1216,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
         GetD0BlockReadDescriptor_N0_N1_M0_M1_M2_M3();
     static constexpr auto d0_thread_desc_ =
-        make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I8, I1, D0M1));
+        make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I8, I1, D0M2));
     using D0BlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1<
         ThisThreadBlock,
         tensor_operation::element_wise::PassThrough,
         tensor_operation::element_wise::PassThrough,
         InMemoryDataOperationEnum::Set,
-        Sequence<I1, I1, D0M0, NPerBlock, D0M1>, // BlockSliceLengths
+        Sequence<I1, I1, D0M1, NPerBlock, D0M2>, // BlockSliceLengths
         Sequence<1, 1, 8, 32, 1>,                // ThreadClusterLengths
         Sequence<0, 1, 2, 4, 3>,                 // ThreadClusterArrangeOrder
         D0DataType,                              // SrcData
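In `ThreadGroupTensorSliceTransfer_v4r1`, each thread of the group covers the elementwise quotient of `BlockSliceLengths` by `ThreadClusterLengths`. A sketch of that bookkeeping for the copy above, with the same assumed tile sizes:

```cpp
// Per-thread work split implied by the blockwise copy above:
// BlockSliceLengths [1, 1, D0M1, NPerBlock, D0M2] over a thread cluster
// [1, 1, 8, 32, 1]. MPerBlock = 128 and NPerBlock = 64 are assumed.
#include <cstdio>

int main()
{
    constexpr int D0M2      = 4;
    constexpr int MPerBlock = 128;
    constexpr int NPerBlock = 64;
    constexpr int D0M1      = MPerBlock / D0M2;

    constexpr int block_slice[5] = {1, 1, D0M1, NPerBlock, D0M2};
    constexpr int cluster[5]     = {1, 1, 8, 32, 1};

    int threads = 1, per_thread = 1;
    for(int d = 0; d < 5; ++d)
    {
        threads    *= cluster[d];
        per_thread *= block_slice[d] / cluster[d];
    }
    // 256 threads, each moving a [1, 1, 4, 2, 4] slice = 32 elements.
    std::printf("threads = %d, elements per thread = %d\n", threads, per_thread);
    return 0;
}
```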