Commit ab0d58b2 authored by letaoqin

change D0M names: D0M1 -> D0M2, D0M0 -> D0M1

parent 2220cf9a
@@ -1154,8 +1154,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
     }
     // D0
-    static constexpr auto D0M1 = Number<4>{};
-    static constexpr auto D0M0 = Number<MPerBlock>{} / D0M1;
+    static constexpr auto D0M2 = Number<4>{};
+    static constexpr auto D0M1 = Number<MPerBlock>{} / D0M2;
+    // static constexpr auto D0M = Number<MPerBlock>{} / D0M2;
     __host__ __device__ static constexpr auto
     MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(const D0GridDesc_M_N& d0_grid_desc_m_n)
@@ -1168,7 +1169,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
         const auto d0_grid_desc_m0_n0_m1_m2_n1_m3 = transform_tensor_descriptor(
             d0_grid_desc_m_n,
-            make_tuple(make_unmerge_transform(make_tuple(MBlock, D0M0, D0M1)),
+            make_tuple(make_unmerge_transform(make_tuple(MBlock, D0M1, D0M2)),
                        make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
             make_tuple(Sequence<0>{}, Sequence<1>{}),
             make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3>{}));
@@ -1187,24 +1188,24 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
     __host__ __device__ static constexpr auto GetD0BlockDescriptor_M0_N0_M1_M2_N1_M3()
     {
         // B1 matrix in LDS memory, dst of blockwise copy
-        return make_naive_tensor_descriptor(make_tuple(I1, I1, D0M0, Number<NPerBlock>{}, D0M1),
-                                            make_tuple(Number<NPerBlock>{} * D0M1,
-                                                       Number<NPerBlock>{} * D0M1,
-                                                       Number<NPerBlock>{} * D0M1,
-                                                       D0M1,
+        return make_naive_tensor_descriptor(make_tuple(I1, I1, D0M1, Number<NPerBlock>{}, D0M2),
+                                            make_tuple(Number<NPerBlock>{} * D0M2,
+                                                       Number<NPerBlock>{} * D0M2,
+                                                       Number<NPerBlock>{} * D0M2,
+                                                       D0M2,
                                                        I1));
     }
     __host__ __device__ static constexpr auto GetD0BlockReadDescriptor_N0_N1_M0_M1_M2_M3()
     {
         constexpr auto d0_raw_m0_n_m1 =
-            make_naive_tensor_descriptor(make_tuple(D0M0, Number<NPerBlock>{}, D0M1),
-                                         make_tuple(Number<NPerBlock>{} * D0M1, D0M1, I1));
+            make_naive_tensor_descriptor(make_tuple(D0M1, Number<NPerBlock>{}, D0M2),
+                                         make_tuple(Number<NPerBlock>{} * D0M2, D0M2, I1));
         constexpr auto d0_n0_n1_m0_m1_m2_m3 = transform_tensor_descriptor(
             d0_raw_m0_n_m1,
-            make_tuple(make_unmerge_transform(make_tuple(D0M0 / I2, I2)),
+            make_tuple(make_unmerge_transform(make_tuple(D0M1 / I2, I2)),
                        make_unmerge_transform(
                            make_tuple(Number<NPerBlock / NPerXdl>{}, Number<NPerXdl>{})),
-                       make_pass_through_transform(D0M1)),
+                       make_pass_through_transform(D0M2)),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
             make_tuple(Sequence<2, 3>{}, Sequence<0, 1>{}, Sequence<4>{}));
         return d0_n0_n1_m0_m1_m2_m3;
@@ -1215,14 +1216,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
             GetD0BlockReadDescriptor_N0_N1_M0_M1_M2_M3();
     static constexpr auto d0_thread_desc_ =
-        make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I8, I1, D0M1));
+        make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I8, I1, D0M2));
     using D0BlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1<
         ThisThreadBlock,
         tensor_operation::element_wise::PassThrough,
         tensor_operation::element_wise::PassThrough,
         InMemoryDataOperationEnum::Set,
-        Sequence<I1, I1, D0M0, NPerBlock, D0M1>, // BlockSliceLengths
+        Sequence<I1, I1, D0M1, NPerBlock, D0M2>, // BlockSliceLengths
         Sequence<1, 1, 8, 32, 1>,                // ThreadClusterLengths
         Sequence<0, 1, 2, 4, 3>,                 // ThreadClusterArrangeOrder
         D0DataType,                              // SrcData
...
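For orientation: the rename shifts the constants down one slot, so the innermost factor of size 4 is now D0M2 and the MPerBlock / D0M2 factor is now D0M1, matching their positions after MBlock in the (MBlock, D0M1, D0M2) unmerge; no values or behavior change. Below is a minimal standalone sketch of the index arithmetic these constants express. It is plain C++ with no composable_kernel headers, and the tile sizes MPerBlock = 128 and NPerBlock = 32 are illustrative assumptions, not values taken from this commit.

// Minimal standalone sketch (plain C++, no composable_kernel headers) of the
// index arithmetic behind the renamed constants. MPerBlock = 128 and
// NPerBlock = 32 are illustrative assumptions, not values from this commit.
#include <cstdio>

constexpr int MPerBlock = 128;         // assumed block tile height
constexpr int NPerBlock = 32;          // assumed block tile width
constexpr int D0M2 = 4;                // innermost M factor (was named D0M1)
constexpr int D0M1 = MPerBlock / D0M2; // middle M factor    (was named D0M0)

// Unmerge: within one block, a flat m index splits as m = m1 * D0M2 + m2,
// the per-block part of make_unmerge_transform(make_tuple(MBlock, D0M1, D0M2)).
constexpr int m1_of(int m) { return m / D0M2; }
constexpr int m2_of(int m) { return m % D0M2; }

// Linear offset in the (D0M1, NPerBlock, D0M2) LDS layout built by
// GetD0BlockDescriptor_M0_N0_M1_M2_N1_M3: strides (NPerBlock * D0M2, D0M2, 1).
constexpr int lds_offset(int m, int n)
{
    return m1_of(m) * (NPerBlock * D0M2) + n * D0M2 + m2_of(m);
}

int main()
{
    // The layout is packed: the last (m, n) element lands on the last slot.
    static_assert(lds_offset(MPerBlock - 1, NPerBlock - 1) ==
                  MPerBlock * NPerBlock - 1);
    std::printf("m=5 -> (m1=%d, m2=%d), lds_offset(5, 3) = %d\n",
                m1_of(5), m2_of(5), lds_offset(5, 3));
    return 0;
}

The real block descriptor is 5-D, (I1, I1, D0M1, NPerBlock, D0M2); its two leading size-1 dimensions never contribute to the offset, so the sketch collapses the layout to 3-D.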