Commit 48a16339 authored by letaoqin

load D0 data only for 4 xdl

parent ab0d58b2
@@ -1155,8 +1155,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
     // D0
     static constexpr auto D0M2 = Number<4>{};
-    static constexpr auto D0M1 = Number<MPerBlock>{} / D0M2;
-    // static constexpr auto D0M = Number<MPerBlock>{} / D0M2;
+    static constexpr auto D0M1 = Number<MPerXdl>{} / D0M2;
+    static constexpr auto D0M0 = Number<MPerBlock>{} / Number<MPerXdl>{};

     __host__ __device__ static constexpr auto
     MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(const D0GridDesc_M_N& d0_grid_desc_m_n)
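The split introduced here can be checked with a few lines of plain C++. The tile sizes below (MPerBlock = 128, MPerXdl = 32) are illustrative assumptions, not values read from this kernel; only the relations between D0M0, D0M1 and D0M2 mirror the lines above.

// Standalone sketch of the new D0 M-dimension split (assumed, illustrative sizes).
#include <cstdio>

int main()
{
    constexpr int MPerBlock = 128; // assumed block tile height
    constexpr int MPerXdl   = 32;  // assumed xdl (MFMA) tile height

    constexpr int D0M2 = 4;                   // innermost per-thread chunk along M
    constexpr int D0M1 = MPerXdl / D0M2;      // 8: chunks per xdl tile
    constexpr int D0M0 = MPerBlock / MPerXdl; // 4: xdl tiles per block tile

    static_assert(D0M0 * D0M1 * D0M2 == MPerBlock, "split must cover MPerBlock");

    // The D0 LDS tile now spans only MPerXdl rows (D0M1 * D0M2) and is loaded
    // D0M0 times per block tile, instead of once for all MPerBlock rows.
    std::printf("D0M0 = %d, D0M1 = %d, D0M2 = %d\n", D0M0, D0M1, D0M2);
}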
@@ -1169,10 +1169,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
         const auto d0_grid_desc_m0_n0_m1_m2_n1_m3 = transform_tensor_descriptor(
             d0_grid_desc_m_n,
-            make_tuple(make_unmerge_transform(make_tuple(MBlock, D0M1, D0M2)),
+            make_tuple(make_unmerge_transform(make_tuple(MBlock, D0M0, D0M1, D0M2)),
                        make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
             make_tuple(Sequence<0>{}, Sequence<1>{}),
-            make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3>{}));
+            make_tuple(Sequence<0, 2, 3, 5>{}, Sequence<1, 4>{}));

         return d0_grid_desc_m0_n0_m1_m2_n1_m3;
     }
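The reordered Sequence maps above can be sanity-checked with ordinary integer arithmetic: M is unmerged into (MBlock, D0M0, D0M1, D0M2), landing in output dimensions 0, 2, 3, 5, and N into (NBlock, NPerBlock), landing in dimensions 1, 4. A plain-C++ sketch under the same assumed tile sizes as above (not taken from this file):

// Sketch of the 6-d coordinate produced by the new unmerge transforms.
#include <cstdio>

int main()
{
    constexpr int MPerBlock = 128, MPerXdl = 32, NPerBlock = 128; // assumed
    constexpr int D0M2 = 4;
    constexpr int D0M1 = MPerXdl / D0M2;      // 8
    constexpr int D0M0 = MPerBlock / MPerXdl; // 4

    const int m = 200, n = 300; // arbitrary element of the D0 matrix

    // M unmerged into (MBlock, D0M0, D0M1, D0M2), slowest to fastest:
    const int m3 = m % D0M2;
    const int m2 = (m / D0M2) % D0M1;
    const int m1 = (m / (D0M2 * D0M1)) % D0M0;
    const int m0 = m / MPerBlock;

    // N unmerged into (NBlock, NPerBlock):
    const int n1 = n % NPerBlock;
    const int n0 = n / NPerBlock;

    // Output order is (M0, N0, M1, M2, N1, M3), matching Sequence<0, 2, 3, 5>
    // for M and Sequence<1, 4> for N.
    std::printf("(m0,n0,m1,m2,n1,m3) = (%d,%d,%d,%d,%d,%d)\n", m0, n0, m1, m2, n1, m3);
}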
@@ -1188,8 +1188,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
     __host__ __device__ static constexpr auto GetD0BlockDescriptor_M0_N0_M1_M2_N1_M3()
     {
         // B1 matrix in LDS memory, dst of blockwise copy
-        return make_naive_tensor_descriptor(make_tuple(I1, I1, D0M1, Number<NPerBlock>{}, D0M2),
-                                            make_tuple(Number<NPerBlock>{} * D0M2,
-                                                       Number<NPerBlock>{} * D0M2,
-                                                       Number<NPerBlock>{} * D0M2,
-                                                       D0M2,
+        return make_naive_tensor_descriptor(
+            make_tuple(I1, I1, I1, D0M1, Number<NPerBlock>{}, D0M2),
+            make_tuple(Number<NPerBlock>{} * D0M2,
+                       Number<NPerBlock>{} * D0M2,
+                       Number<NPerBlock>{} * D0M2,
+                       Number<NPerBlock>{} * D0M2,
+                       D0M2,
@@ -1216,24 +1218,24 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
             GetD0BlockReadDescriptor_N0_N1_M0_M1_M2_M3();

         static constexpr auto d0_thread_desc_ =
-            make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I8, I1, D0M2));
+            make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I4, I1, D0M2));

         using D0BlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1<
             ThisThreadBlock,
             tensor_operation::element_wise::PassThrough,
             tensor_operation::element_wise::PassThrough,
             InMemoryDataOperationEnum::Set,
-            Sequence<I1, I1, D0M1, NPerBlock, D0M2>,     // BlockSliceLengths
-            Sequence<1, 1, 8, 32, 1>,                    // ThreadClusterLengths
-            Sequence<0, 1, 2, 4, 3>,                     // ThreadClusterArrangeOrder
+            Sequence<I1, I1, I1, D0M1, NPerBlock, D0M2>, // BlockSliceLengths
+            Sequence<1, 1, 1, 8, 32, 1>,                 // ThreadClusterLengths
+            Sequence<0, 1, 2, 3, 5, 4>,                  // ThreadClusterArrangeOrder
             D0DataType,                                  // SrcData
             D0DataType,                                  // DstData
             D0GridDescriptor_M0_N0_M1_M2_N1_M3,          // SrcDesc
             decltype(d0_block_desc_m0_n0_m1_m2_n1_m3),   // DstDesc
-            Sequence<0, 1, 2, 4, 3>,                     // SrcDimAccessOrder
-            Sequence<0, 1, 3, 2, 4>,                     // DstDimAccessOrder
-            3,                                           // SrcVectorDim
-            4,                                           // DstVectorDim
+            Sequence<0, 1, 2, 3, 5, 4>,                  // SrcDimAccessOrder
+            Sequence<0, 1, 2, 4, 3, 5>,                  // DstDimAccessOrder
+            4,                                           // SrcVectorDim
+            5,                                           // DstVectorDim
             D0BlockTransferSrcScalarPerVector,           // SrcScalarPerVector
             4,                                           // DstScalarPerVector
             1,
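As a quick cross-check of the widened copy configuration, dividing the new BlockSliceLengths by the new ThreadClusterLengths elementwise gives each of the 256 threads a 1 x 1 x 1 x 1 x (NPerBlock/32) x D0M2 sub-slice. The sketch assumes NPerBlock = 128 and the D0M1/D0M2 values used above; none of these concrete numbers come from this file.

// Per-thread slice implied by BlockSliceLengths / ThreadClusterLengths
// (assumed tile sizes; only the elementwise division rule is illustrated).
#include <cstdio>

int main()
{
    constexpr int D0M1 = 8, NPerBlock = 128, D0M2 = 4; // assumed

    const int block_slice[6]    = {1, 1, 1, D0M1, NPerBlock, D0M2};
    const int thread_cluster[6] = {1, 1, 1, 8, 32, 1};

    int threads = 1;
    std::printf("per-thread slice:");
    for(int d = 0; d < 6; ++d)
    {
        threads *= thread_cluster[d];
        std::printf(" %d", block_slice[d] / thread_cluster[d]);
    }
    std::printf("\nthreads in cluster: %d\n", threads); // 256
}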
@@ -1247,7 +1249,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
             D0DataType,                                // DstData
             decltype(d0_block_desc_n0_n1_m0_m1_m2_m3), // SrcDesc
             decltype(d0_thread_desc_),                 // DstDesc
-            Sequence<1, 1, 8, 1, 4>,                   // SliceLengths
+            Sequence<1, 1, 4, 1, 4>,                   // SliceLengths
             Sequence<0, 1, 2, 3, 4>,                   // DimAccessOrder
             4,                                         // SrcVectorDim
             2,                                         // SrcScalarPerVector
@@ -1833,10 +1835,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
         // D0
         auto d0_block_copy_global_to_lds = typename D0Loader::D0BlockwiseCopy(
             d0_grid_desc_m0_n0_m1_m2_n1_m3,
-            make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0),
+            make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0),
             tensor_operation::element_wise::PassThrough{},
             D0Loader::d0_block_desc_m0_n0_m1_m2_n1_m3,
-            make_multi_index(0, 0, 0, 0, 0),
+            make_multi_index(0, 0, 0, 0, 0, 0),
             tensor_operation::element_wise::PassThrough{});

         auto d0_thread_copy_lds_to_vgpr = typename D0Loader::D0ThreadCopy(
@@ -2002,8 +2004,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
             // add bias
             if constexpr(!is_same<D0DataType, void>::value)
             {
-                // static constexpr auto c_thread_desc_ =
-                //     make_naive_tensor_descriptor_packed(make_tuple(I2, Number<16>{}));
+                static constexpr auto c_thread_desc_ =
+                    make_naive_tensor_descriptor_packed(make_tuple(D0M0, Number<16>{}));

                 const auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                     p_d0_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
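The re-enabled c_thread_desc_ is the packed (D0M0, 16) descriptor that the bias-add loop further down indexes with CalculateOffset(make_tuple(mr, i)); 16 matches the element count of the new d0_thread_desc_ (1 * 1 * 4 * 1 * D0M2). For a packed 2-d descriptor the offset is simply mr * 16 + i, as this plain-C++ sketch (assumed D0M0 = 4) illustrates:

// Offset rule of a packed (D0M0, 16) thread descriptor: offset(mr, i) = mr * 16 + i.
#include <cstdio>

int main()
{
    constexpr int D0M0    = 4;  // assumed, as in the sketches above
    constexpr int PerIter = 16; // d0 elements handled per mr iteration

    for(int mr = 0; mr < D0M0; ++mr)
        for(int i = 0; i < PerIter; ++i)
        {
            const int c_offset = mr * PerIter + i; // CalculateOffset(make_tuple(mr, i))
            if(i == 0)
                std::printf("mr=%d starts at accumulator offset %d\n", mr, c_offset);
        }
}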
@@ -2016,11 +2018,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                     D0Loader::d0_thread_desc_.GetElementSpaceSize());
                 ignore = d0_thread_buf;

+                static_for<0, D0M0, 1>{}([&](auto mr) {
                     // load data to lds
-                d0_block_copy_global_to_lds.RunRead(d0_grid_desc_m0_n0_m1_m2_n1_m3, d0_grid_buf);
+                    d0_block_copy_global_to_lds.RunRead(d0_grid_desc_m0_n0_m1_m2_n1_m3,
+                                                        d0_grid_buf);

-                // d0_block_copy_global_to_lds.MoveSrcSliceWindow(
-                //     d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(1, 0, 0, 0, 0));
+                    d0_block_copy_global_to_lds.MoveSrcSliceWindow(
+                        d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));

                     d0_block_copy_global_to_lds.RunWrite(D0Loader::d0_block_desc_m0_n0_m1_m2_n1_m3,
                                                          d0_block_buf);
@@ -2032,11 +2036,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                         D0Loader::d0_thread_desc_,
                         make_tuple(I0, I0, I0, I0, I0),
                         d0_thread_buf);

-                // static_for<0, 2, 1>{}([&](auto mr) {
                     // bias add
-                static_for<0, s_slash_p_thread_buf.Size(), 1>{}([&](auto i) {
-                    // constexpr index_t c_offset =
-                    //     c_thread_desc_.CalculateOffset(make_tuple(mr, i));
+                    static_for<0, d0_thread_buf.Size(), 1>{}([&](auto i) {
+                        constexpr index_t c_offset =
+                            c_thread_desc_.CalculateOffset(make_tuple(mr, i));
                         // if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0)
                         // if(ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]) != 1.0f)
@@ -2054,12 +2058,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                         //         ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]));
                         // }
-                    s_slash_p_thread_buf(i) += ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]);
+                        s_slash_p_thread_buf(Number<c_offset>{}) +=
+                            ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]);
+                    });
                 });
-                //});

-                d0_block_copy_global_to_lds.MoveSrcSliceWindow(d0_grid_desc_m0_n0_m1_m2_n1_m3,
-                                                               make_multi_index(-1, 0, 0, 0, 0));
+                d0_block_copy_global_to_lds.MoveSrcSliceWindow(
+                    d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
             }

             // P_i: = softmax(scalar * S_i:)
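Taken together, each mr iteration advances the D0 source window by one MPerXdl slab (the +1 step in the M1 dimension), and the final move of (-1, 0, -D0M0, 0, 0, 0) undoes those D0M0 steps and additionally backs up one M block, consistent with the old (-1, 0, 0, 0, 0) move. A plain-integer sketch of the window's starting row, under the same assumed tile sizes as before:

// Track the first M row covered by the D0 source window across one block tile.
#include <cstdio>

int main()
{
    constexpr int MPerBlock = 128, MPerXdl = 32; // assumed
    constexpr int D0M0 = MPerBlock / MPerXdl;    // 4

    int m_block = 3;                   // assumed current gemm0_m_block_outer_index
    int row     = m_block * MPerBlock; // window start right after construction

    for(int mr = 0; mr < D0M0; ++mr)
    {
        std::printf("mr=%d loads rows [%d, %d)\n", mr, row, row + MPerXdl);
        row += MPerXdl; // MoveSrcSliceWindow(0, 0, +1, 0, 0, 0)
    }

    // Final MoveSrcSliceWindow(-1, 0, -D0M0, 0, 0, 0): one block back along M0,
    // D0M0 slabs back along M1.
    row += -1 * MPerBlock - D0M0 * MPerXdl;
    std::printf("window now starts at row %d (= (m_block - 1) * MPerBlock)\n", row);
}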