Commit 703ef6d7 authored by letaoqin

change d0_block_desc_n0_n1_m0_m1_m2_m3 to d0_block_desc_n0_n1_m0_m1_m2

parent 1696ca42
@@ -1208,12 +1208,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                 D0M2,
                 I1));
     }
-    __host__ __device__ static constexpr auto GetD0BlockReadDescriptor_N0_N1_M0_M1_M2_M3()
+    __host__ __device__ static constexpr auto GetD0BlockReadDescriptor_N0_N1_M0_M1_M2()
     {
         constexpr auto d0_raw_m0_n_m1 =
             make_naive_tensor_descriptor(make_tuple(D0M1, Number<NPerBlock>{}, D0M2),
                                          make_tuple(Number<NPerBlock>{} * D0M2, D0M2, I1));
-        constexpr auto d0_n0_n1_m0_m1_m2_m3 = transform_tensor_descriptor(
+        constexpr auto d0_n0_n1_m0_m1_m2 = transform_tensor_descriptor(
             d0_raw_m0_n_m1,
             make_tuple(make_unmerge_transform(make_tuple(D0M1 / I2, I2)),
                        make_unmerge_transform(
@@ -1221,12 +1221,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                        make_pass_through_transform(D0M2)),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
             make_tuple(Sequence<2, 3>{}, Sequence<0, 1>{}, Sequence<4>{}));
-        return d0_n0_n1_m0_m1_m2_m3;
+        return d0_n0_n1_m0_m1_m2;
     }
     static constexpr auto d0_block_desc_m0_n0_m1_m2_n1_m3 =
         GetD0BlockDescriptor_M0_N0_M1_M2_N1_M3();
-    static constexpr auto d0_block_desc_n0_n1_m0_m1_m2_m3 =
-        GetD0BlockReadDescriptor_N0_N1_M0_M1_M2_M3();
+    static constexpr auto d0_block_desc_n0_n1_m0_m1_m2 =
+        GetD0BlockReadDescriptor_N0_N1_M0_M1_M2();
     static constexpr auto d0_thread_desc_ =
         make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I4, I1, D0M2));
@@ -1261,10 +1261,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                                      1>;
     using D0ThreadCopy =
         ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type,  // SrcData
                                          typename TypeTransform<D0DataType>::Type,  // DstData
-                                         decltype(d0_block_desc_n0_n1_m0_m1_m2_m3), // SrcDesc
+                                         decltype(d0_block_desc_n0_n1_m0_m1_m2),    // SrcDesc
                                          decltype(d0_thread_desc_),                 // DstDesc
                                          Sequence<1, 1, 4, 1, 4>,                   // SliceLengths
                                          Sequence<0, 1, 2, 3, 4>,                   // DimAccessOrder
                                          4,                                         // SrcVectorDim
@@ -2018,7 +2018,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                 d0_block_buf);
             block_sync_lds();
             // read data from lds
-            d0_thread_copy_lds_to_vgpr.Run(D0Loader::d0_block_desc_n0_n1_m0_m1_m2_m3,
+            d0_thread_copy_lds_to_vgpr.Run(D0Loader::d0_block_desc_n0_n1_m0_m1_m2,
                                            make_tuple(I0, I0, I0, I0, I0),
                                            d0_block_buf,
                                            D0Loader::d0_thread_desc_,
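
For context, the new name matches what the descriptor actually is: the unmerge of m0 produces (m0, m1), the unmerge of n produces (n0, n1), and m2 is passed through, so the result has five upper dimensions and the trailing `_m3` in the old name was misleading. Below is a minimal, hedged host-side sketch of that dimension count, assuming the composable_kernel headers are available (include paths may differ between versions); the tile sizes and the split of n are illustrative placeholders, not the values used in this kernel.

```cpp
// Sketch only: placeholder tile sizes and a guessed split of n; compile with
// hipcc against composable_kernel. Header paths are assumptions.
#include "ck/utility/number.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"

int main()
{
    using ck::Number;
    using ck::Sequence;
    using ck::make_tuple;

    // Placeholders standing in for the kernel's D0M1, NPerBlock, D0M2.
    constexpr auto D0M1      = Number<8>{};
    constexpr auto NPerBlock = Number<128>{};
    constexpr auto D0M2      = Number<4>{};
    constexpr auto I1        = Number<1>{};
    constexpr auto I2        = Number<2>{};

    // Raw (m0, n, m1) view of the d0 block tile, as in the diff.
    constexpr auto d0_raw_m0_n_m1 = ck::make_naive_tensor_descriptor(
        make_tuple(D0M1, NPerBlock, D0M2),
        make_tuple(NPerBlock * D0M2, D0M2, I1));

    // Unmerging m0 -> (m0, m1) and n -> (n0, n1) while passing m2 through
    // yields exactly five upper dimensions: (n0, n1, m0, m1, m2).
    constexpr auto d0_n0_n1_m0_m1_m2 = ck::transform_tensor_descriptor(
        d0_raw_m0_n_m1,
        make_tuple(ck::make_unmerge_transform(make_tuple(D0M1 / I2, I2)),
                   ck::make_unmerge_transform(make_tuple(NPerBlock / I2, I2)), // split of n is a placeholder
                   ck::make_pass_through_transform(D0M2)),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
        make_tuple(Sequence<2, 3>{}, Sequence<0, 1>{}, Sequence<4>{}));

    static_assert(d0_n0_n1_m0_m1_m2.GetNumOfDimension() == 5,
                  "the renamed descriptor has five dimensions: n0, n1, m0, m1, m2");
    return 0;
}
```

The five-dimensional shape is also consistent with the surrounding code in the diff: the thread descriptor is built from five lengths (`make_tuple(I1, I1, I4, I1, D0M2)`) and the copy uses five-entry `SliceLengths`/`DimAccessOrder` sequences.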