Unverified Commit 9a423017 authored by Dan Yao's avatar Dan Yao Committed by GitHub
Browse files

Merge pull request #1033 from ROCmSoftwarePlatform/mha_train_develop_d0shuffle_update

Update in D0 shuffled loading to support bigger KPerBlock size
parents ac3ef99c e1980d10
......@@ -361,10 +361,14 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(C1GridDesc_M_N{}))>;
static constexpr auto D0N2 = AK1;
static constexpr auto D0N1 = AK0;
static constexpr auto D0N0 = Number<NPerBlock / KPerBlock>{};
static_assert(NPerBlock % KPerBlock == 0);
static constexpr auto D0N2 = AK1;
static constexpr auto D0N1 = Number<32 / AK1.value>{};
static constexpr auto D0N0 = Number<NPerBlock / 32>{};
static constexpr auto D0N0_PerShuffle = Number<KPerBlock / 32>{};
static constexpr auto D0_NumShuffle = NPerBlock / KPerBlock;
static_assert(NPerBlock % KPerBlock == 0 && KPerBlock % 32 == 0,
"KPerBlock should be multiple of 32 and divisor of NPerBlock");
__host__ __device__ static constexpr auto
MakeD0GridDescriptor_M0_N0_N1_N2_M1_N3(const D0GridDesc_M_N& d0_grid_desc_m_n)
......@@ -408,47 +412,48 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
__host__ __device__ static constexpr auto GetD0BlockGlobalDescriptor_M0_N0_N1_N2_M1_N3()
{
return make_naive_tensor_descriptor_packed(
make_tuple(I1, I1, I1, D0N1, Number<MPerBlock>{}, D0N2));
make_tuple(I1, I1, D0N0_PerShuffle, D0N1, Number<MPerBlock>{}, D0N2));
}
__host__ __device__ static constexpr auto GetD0BlockVgprDescriptor_M0_M1_N0_N1_N2()
__host__ __device__ static constexpr auto GetD0BlockVgprDescriptor_M0_M1_N0_N1_N2_N3()
{
constexpr auto d0_raw_n0_m_n1 =
make_naive_tensor_descriptor_packed(make_tuple(D0N1, Number<MPerBlock>{}, D0N2));
constexpr auto d0_raw_n0_n1_m_n2 = make_naive_tensor_descriptor_packed(
make_tuple(D0N0_PerShuffle, D0N1, Number<MPerBlock>{}, D0N2));
constexpr auto d0_raw_m_n = transform_tensor_descriptor(
d0_raw_n0_m_n1,
d0_raw_n0_n1_m_n2,
make_tuple(make_pass_through_transform(Number<MPerBlock>{}),
make_merge_transform(make_tuple(D0N1, D0N2))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_merge_transform(make_tuple(D0N0_PerShuffle, D0N1, D0N2))),
make_tuple(Sequence<2>{}, Sequence<0, 1, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
constexpr auto d0_m0_m1_n0_n1_n2 = transform_tensor_descriptor(
constexpr auto d0_m0_m1_n0_n1_n2_n3 = transform_tensor_descriptor(
d0_raw_m_n,
make_tuple(make_unmerge_transform(
make_tuple(Number<MPerBlock / MPerXdl>{}, Number<MPerXdl>{})),
make_unmerge_transform(make_tuple((D0N1 * D0N2) / (I2 * I4), I2, I4))),
make_unmerge_transform(
make_tuple(D0N0_PerShuffle, (D0N1 * D0N2) / (I2 * I4), I2, I4))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3, 4>{}));
make_tuple(Sequence<0, 1>{}, Sequence<2, 3, 4, 5>{}));
return d0_m0_m1_n0_n1_n2;
return d0_m0_m1_n0_n1_n2_n3;
}
static constexpr auto d0_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I4, I1, I4));
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, D0N0_PerShuffle, I4, I1, I4));
static constexpr auto d0_block_dst_desc_m0_n0_n1_n2_m1_n3 =
GetD0BlockGlobalDescriptor_M0_N0_N1_N2_M1_N3();
static constexpr auto d0_block_src_desc_m0_m1_n0_n1_n2 =
GetD0BlockVgprDescriptor_M0_M1_N0_N1_N2();
static constexpr auto d0_block_src_desc_m0_m1_n0_n1_n2_n3 =
GetD0BlockVgprDescriptor_M0_M1_N0_N1_N2_N3();
using D0BlockwiseCopyGlobalToLds = ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
tensor_operation::element_wise::PassThrough,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<I1, I1, I1, D0N1, MPerBlock, D0N2>,
Sequence<I1, I1, D0N0_PerShuffle, D0N1, MPerBlock, D0N2>,
typename sequence_merge<Sequence<1, 1, 1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1>::type,
Sequence<0, 1, 2, 4, 3, 5>,
......@@ -468,16 +473,16 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
true, // DstResetCoord
NumGemmKPrefetchStage>;
using D0ThreadwiseCopyLdsToVgpr =
ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData
decltype(d0_block_src_desc_m0_m1_n0_n1_n2), // SrcDesc
decltype(d0_thread_desc_), // DstDesc
Sequence<1, 1, 4, 1, 4>, // SliceLengths
Sequence<0, 1, 2, 3, 4>, // DimAccessOrder
4, // SrcVectorDim
4, // SrcScalarPerVector
2>;
using D0ThreadwiseCopyLdsToVgpr = ThreadwiseTensorSliceTransfer_v4<
typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData
decltype(d0_block_src_desc_m0_m1_n0_n1_n2_n3), // SrcDesc
decltype(d0_thread_desc_), // DstDesc
Sequence<1, 1, D0N0_PerShuffle.value, 4, 1, 4>, // SliceLengths
Sequence<0, 1, 2, 3, 4, 5>, // DimAccessOrder
5, // SrcVectorDim
4, // SrcScalarPerVector
2>;
};
struct SharedMemTrait
......@@ -907,7 +912,7 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
tensor_operation::element_wise::PassThrough{});
auto d0_thread_copy_lds_to_vgpr = typename D0Operator::D0ThreadwiseCopyLdsToVgpr(
make_tuple(wave_id[I0], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0));
make_tuple(wave_id[I0], wave_m_n_id[I1], 0, 0, wave_m_n_id[I0], 0));
index_t gemm1_k_block_outer_index = 0;
do
......@@ -1016,29 +1021,31 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
D0Operator::d0_thread_desc_.GetElementSpaceSize());
static_for<0, D0N0, 1>{}([&](auto nr) {
static_for<0, D0_NumShuffle, 1>{}([&](auto nr) {
// load data to lds
d0_block_copy_global_to_lds.RunRead(d0_grid_desc_m0_n0_n1_n2_m1_n3,
d0_grid_buf);
d0_block_copy_global_to_lds.MoveSrcSliceWindow(
d0_grid_desc_m0_n0_n1_n2_m1_n3, make_multi_index(0, 0, 1, 0, 0, 0));
d0_grid_desc_m0_n0_n1_n2_m1_n3,
make_multi_index(0, 0, D0N0_PerShuffle, 0, 0, 0));
d0_block_copy_global_to_lds.RunWrite(
D0Operator::d0_block_dst_desc_m0_n0_n1_n2_m1_n3, d0_block_buf);
block_sync_lds();
// read data form lds
d0_thread_copy_lds_to_vgpr.Run(D0Operator::d0_block_src_desc_m0_m1_n0_n1_n2,
make_tuple(I0, I0, I0, I0, I0),
d0_block_buf,
D0Operator::d0_thread_desc_,
make_tuple(I0, I0, I0, I0, I0),
d0_thread_buf);
d0_thread_copy_lds_to_vgpr.Run(
D0Operator::d0_block_src_desc_m0_m1_n0_n1_n2_n3,
make_tuple(I0, I0, I0, I0, I0, I0),
d0_block_buf,
D0Operator::d0_thread_desc_,
make_tuple(I0, I0, I0, I0, I0, I0),
d0_thread_buf);
// bias add
static_for<0, d0_thread_buf.Size(), 1>{}([&](auto i) {
constexpr index_t c_offset =
c_thread_desc.CalculateOffset(make_tuple(I0, nr, i));
constexpr index_t c_offset = c_thread_desc.CalculateOffset(
make_tuple(I0, nr * D0N0_PerShuffle, i));
acc_thread_buf(Number<c_offset>{}) +=
ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment