Commit 105c6a08 authored by qin letao's avatar qin letao
Browse files

change dv thread copy

parent 7409bc5d
......@@ -742,6 +742,28 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
1, // DstScalarStrideInVector
true>;
// Per-thread transfer type for the A operand of the dV GEMM.
// Instantiates ThreadwiseTensorSliceTransfer_v1r3 with an element-wise Relu
// applied to each value as it is copied.
// NOTE(review): the call site labels this copy "dV: A matrix VGPR-to-LDS" and
// annotates the Relu as "relu(P-dropped)"; presumably dropped entries of P are
// carried as negative values and Relu zeroes them before the GEMM - confirm.
// Argument roles not labelled in the original code are inferred from position
// and naming; verify them against the ThreadwiseTensorSliceTransfer_v1r3
// declaration.
using ABlockwiseCopy_dV = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc, // presumably the source element type - confirm
DataType, // presumably the destination element type - confirm
decltype(a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4), // source (thread) descriptor type
decltype(a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4), // destination (block) descriptor type
tensor_operation::element_wise::Relu, // element-wise op applied during the copy
// ThreadSliceLengths: eight entries, one per dimension of the
// m0_n0_m1_n1_m2_n2_n3_n4 descriptors above.
Sequence<Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(
I0),
Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I1),
I1,
I1,
I1,
N2,
I1,
N4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>, // presumably the dimension access order (identity here) - confirm
7, // DstVectorDim (index of the last of the eight dimensions)
1, // DstScalarPerVector
InMemoryDataOperationEnum::Set, // destination memory operation: Set
1, // DstScalarStrideInVector
true>; // NOTE(review): presumably the "reset destination coordinate after run" flag - confirm
template <typename GridDesc_M0_O_M1>
using BBlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
......@@ -1379,10 +1401,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
Gemm2::MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4(vgrad_grid_desc_n_o);
// dV: A matrix VGPR-to-LDS blockwise copy
auto vgrad_gemm_tile_p_thread_copy_vgpr_to_lds = typename Gemm2::ABlockwiseCopy{
auto vgrad_gemm_tile_p_thread_copy_vgpr_to_lds = typename Gemm2::ABlockwiseCopy_dV{
Gemm2::a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
Gemm2::MakeAThreadOriginOnBlock_M0_N0_M1_N1_M2_N2_N3_N4(),
tensor_operation::element_wise::PassThrough{}};
tensor_operation::element_wise::Relu{}}; //relu(P-dropped)
// dV: B matrix global-to-LDS blockwise copy
auto vgrad_gemm_tile_ygrad_blockwise_copy =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment