Commit e3eb4381 authored by letaoqin

add d0_block_copy_global_to_lds

parent 77df3ccb
@@ -1179,13 +1179,51 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
return d0_grid_desc_m0_n0_m1_m2_n1_m3;
}
struct D0
{
};
using D0GridDescriptor_M0_N0_M1_M2_N1_M3 =
remove_cvref_t<decltype(MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(D0GridDesc_M_N{}))>;
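// D0Loader bundles the LDS destination descriptor and the blockwise global-to-LDS copy for the
// extra D0 input (presumably the bias that is added to S before the softmax, see the
// "add bias" step further below).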
struct D0Loader
{
__host__ __device__ static constexpr auto GetD0BlockDescriptor_M0_N0_M1_M2_N1_M3()
{
// D0 tile in LDS memory, dst of blockwise copy
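// Lengths (1, 1, 1, D0M2, NPerBlock, D0M3): effectively a D0M2 x NPerBlock x D0M3 row-major
// tile in LDS; the unit-length M0/N0/M1 dims reuse the M2 stride (NPerBlock * D0M3).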
return make_naive_tensor_descriptor(
make_tuple(I1, I1, I1, D0M2, Number<NPerBlock>{}, D0M3),
make_tuple(Number<NPerBlock>{} * D0M3,
Number<NPerBlock>{} * D0M3,
Number<NPerBlock>{} * D0M3,
Number<NPerBlock>{} * D0M3,
D0M3,
I1));
}
static constexpr auto d0_block_desc_m0_n0_m1_m2_n1_m3 =
GetD0BlockDescriptor_M0_N0_M1_M2_N1_M3();
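// Blockwise copy that stages one D0 tile of shape (1, 1, 1, D0M2, NPerBlock, D0M3) from global
// memory into the LDS buffer described above, distributed over a 1 x 1 x 1 x 8 x 32 x 1 thread
// cluster.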
using D0BlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
tensor_operation::element_wise::PassThrough,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<1, 1, 1, D0M2, NPerBlock, D0M3>, // BlockSliceLengths
Sequence<1, 1, 1, 8, 32, 1>, // ThreadClusterLengths
Sequence<0, 1, 2, 3, 5, 4>, // ThreadClusterArrangeOrder
D0DataType, // SrcData
D0DataType, // DstData
D0GridDescriptor_M0_N0_M1_M2_N1_M3, // SrcDesc
decltype(d0_block_desc_m0_n0_m1_m2_n1_m3), // DstDesc
Sequence<0, 1, 2, 3, 5, 4>, // SrcDimAccessOrder
Sequence<0, 1, 2, 4, 3, 5>, // DstDimAccessOrder
4, // SrcVectorDim
2, // DstVectorDim
NPerBlock / 32, // SrcScalarPerVector
D0M3.value / 1, // DstScalarPerVector
1, // SrcScalarStrideInVector
1, // DstScalarStrideInVector
false, // SrcResetCoordinateAfterRun
true, // DstResetCoordinateAfterRun
1>; // NumThreadScratch
};
template <bool HasMainKBlockLoop,
bool IsDropout,
typename Block2CTileMap,
@@ -1513,6 +1551,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
qgrad_thread_origin_on_grid_m0_o0_m1_o1_m2_o2_o3_o4,
scale_rp_dropout);
// D0
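// Construct the blockwise global->LDS copy for D0. The source window starts at
// (0, block_work_idx_n, 0, 0, 0, 0), i.e. at this workgroup's tile along N0; the destination
// window starts at the origin of the LDS block descriptor.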
auto d0_block_copy_global_to_lds =
typename D0Loader::D0BlockwiseCopy(d0_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(0, block_work_idx_n, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{},
D0Loader::d0_block_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(0, 0, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{});
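// Note: only the construction of the copy object appears in this hunk; the RunRead/RunWrite and
// MoveSrcSliceWindow calls that actually stream D0 tiles into LDS are presumably issued inside
// the main loop below (not shown here).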
//
// Blockwise softmax
//
@@ -1896,6 +1942,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
// add bias
// P_i: = softmax(scalar * S_i:)
// scaling is already performed in the preceding statements with s_element_op
blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf);