Commit e1980d10 authored by Qianfeng Zhang

Update D0 shuffled loading to support bigger KPerBlock sizes

parent ac3ef99c
@@ -361,10 +361,14 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
     using DefaultBlock2CTileMap =
         remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(C1GridDesc_M_N{}))>;

     static constexpr auto D0N2 = AK1;
-    static constexpr auto D0N1 = AK0;
-    static constexpr auto D0N0 = Number<NPerBlock / KPerBlock>{};
-    static_assert(NPerBlock % KPerBlock == 0);
+    static constexpr auto D0N1 = Number<32 / AK1.value>{};
+    static constexpr auto D0N0 = Number<NPerBlock / 32>{};
+    static constexpr auto D0N0_PerShuffle = Number<KPerBlock / 32>{};
+    static constexpr auto D0_NumShuffle = NPerBlock / KPerBlock;
+    static_assert(NPerBlock % KPerBlock == 0 && KPerBlock % 32 == 0,
+                  "KPerBlock should be multiple of 32 and divisor of NPerBlock");
     __host__ __device__ static constexpr auto
     MakeD0GridDescriptor_M0_N0_N1_N2_M1_N3(const D0GridDesc_M_N& d0_grid_desc_m_n)
@@ -408,47 +412,48 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
     __host__ __device__ static constexpr auto GetD0BlockGlobalDescriptor_M0_N0_N1_N2_M1_N3()
     {
         return make_naive_tensor_descriptor_packed(
-            make_tuple(I1, I1, I1, D0N1, Number<MPerBlock>{}, D0N2));
+            make_tuple(I1, I1, D0N0_PerShuffle, D0N1, Number<MPerBlock>{}, D0N2));
     }

-    __host__ __device__ static constexpr auto GetD0BlockVgprDescriptor_M0_M1_N0_N1_N2()
+    __host__ __device__ static constexpr auto GetD0BlockVgprDescriptor_M0_M1_N0_N1_N2_N3()
     {
-        constexpr auto d0_raw_n0_m_n1 =
-            make_naive_tensor_descriptor_packed(make_tuple(D0N1, Number<MPerBlock>{}, D0N2));
+        constexpr auto d0_raw_n0_n1_m_n2 = make_naive_tensor_descriptor_packed(
+            make_tuple(D0N0_PerShuffle, D0N1, Number<MPerBlock>{}, D0N2));

         constexpr auto d0_raw_m_n = transform_tensor_descriptor(
-            d0_raw_n0_m_n1,
+            d0_raw_n0_n1_m_n2,
             make_tuple(make_pass_through_transform(Number<MPerBlock>{}),
-                       make_merge_transform(make_tuple(D0N1, D0N2))),
-            make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
+                       make_merge_transform(make_tuple(D0N0_PerShuffle, D0N1, D0N2))),
+            make_tuple(Sequence<2>{}, Sequence<0, 1, 3>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}));

-        constexpr auto d0_m0_m1_n0_n1_n2 = transform_tensor_descriptor(
+        constexpr auto d0_m0_m1_n0_n1_n2_n3 = transform_tensor_descriptor(
             d0_raw_m_n,
             make_tuple(make_unmerge_transform(
                            make_tuple(Number<MPerBlock / MPerXdl>{}, Number<MPerXdl>{})),
-                       make_unmerge_transform(make_tuple((D0N1 * D0N2) / (I2 * I4), I2, I4))),
+                       make_unmerge_transform(
+                           make_tuple(D0N0_PerShuffle, (D0N1 * D0N2) / (I2 * I4), I2, I4))),
             make_tuple(Sequence<0>{}, Sequence<1>{}),
-            make_tuple(Sequence<0, 1>{}, Sequence<2, 3, 4>{}));
+            make_tuple(Sequence<0, 1>{}, Sequence<2, 3, 4, 5>{}));

-        return d0_m0_m1_n0_n1_n2;
+        return d0_m0_m1_n0_n1_n2_n3;
     }
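In plain index arithmetic, the two transforms above first merge the raw packed layout (D0N0_PerShuffle, D0N1, MPerBlock, D0N2) into an (M, N) view and then unmerge it into the xdlops lane layout, now with the per-shuffle N0 axis in front. A sketch of the equivalent offset computation, reusing the assumed parameters from the previous example plus a hypothetical MPerBlock = 128 and MPerXdl = 32:

```cpp
#include <cassert>
#include <cstdio>

// Hypothetical parameters, matching the earlier sketch (KPerBlock = 64, AK1 = 8).
constexpr int MPerBlock = 128, MPerXdl = 32;
constexpr int D0N0_PerShuffle = 2, D0N1 = 4, D0N2 = 8;

// Equivalent of GetD0BlockVgprDescriptor_M0_M1_N0_N1_N2_N3 as plain index math;
// the packed raw layout is (D0N0_PerShuffle, D0N1, MPerBlock, D0N2).
int VgprOffset(int m0, int m1, int n0, int n1, int n2, int n3)
{
    // unmerge(M) -> m; unmerge(N) -> merged column index
    const int m = m0 * MPerXdl + m1;
    const int n = ((n0 * ((D0N1 * D0N2) / (2 * 4)) + n1) * 2 + n2) * 4 + n3;
    // split the merged column back into the raw (n0, n1, n2) coordinates
    const int rn0 = n / (D0N1 * D0N2);
    const int rn1 = (n / D0N2) % D0N1;
    const int rn2 = n % D0N2;
    // packed offset of coordinate (rn0, rn1, m, rn2)
    return ((rn0 * D0N1 + rn1) * MPerBlock + m) * D0N2 + rn2;
}

int main()
{
    // Consecutive n3 values stay contiguous in the raw layout, which is what
    // permits SrcVectorDim = 5 with SrcScalarPerVector = 4 in the copy below.
    assert(VgprOffset(0, 0, 0, 0, 0, 1) - VgprOffset(0, 0, 0, 0, 0, 0) == 1);
    std::printf("offset(0,1,1,0,1,2) = %d\n", VgprOffset(0, 1, 1, 0, 1, 2));
}
```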
     static constexpr auto d0_thread_desc_ =
-        make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I4, I1, I4));
+        make_naive_tensor_descriptor_packed(make_tuple(I1, I1, D0N0_PerShuffle, I4, I1, I4));

     static constexpr auto d0_block_dst_desc_m0_n0_n1_n2_m1_n3 =
         GetD0BlockGlobalDescriptor_M0_N0_N1_N2_M1_N3();
-    static constexpr auto d0_block_src_desc_m0_m1_n0_n1_n2 =
-        GetD0BlockVgprDescriptor_M0_M1_N0_N1_N2();
+    static constexpr auto d0_block_src_desc_m0_m1_n0_n1_n2_n3 =
+        GetD0BlockVgprDescriptor_M0_M1_N0_N1_N2_N3();

     using D0BlockwiseCopyGlobalToLds = ThreadGroupTensorSliceTransfer_v4r1<
         ThisThreadBlock,
         tensor_operation::element_wise::PassThrough,
         tensor_operation::element_wise::PassThrough,
         InMemoryDataOperationEnum::Set,
-        Sequence<I1, I1, I1, D0N1, MPerBlock, D0N2>,
+        Sequence<I1, I1, D0N0_PerShuffle, D0N1, MPerBlock, D0N2>,
         typename sequence_merge<Sequence<1, 1, 1>,
                                 ABlockTransferThreadClusterLengths_AK0_M_AK1>::type,
         Sequence<0, 1, 2, 4, 3, 5>,
@@ -468,16 +473,16 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
         true, // DstResetCoord
         NumGemmKPrefetchStage>;

-    using D0ThreadwiseCopyLdsToVgpr =
-        ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type, // SrcData
-                                         typename TypeTransform<D0DataType>::Type, // DstData
-                                         decltype(d0_block_src_desc_m0_m1_n0_n1_n2), // SrcDesc
-                                         decltype(d0_thread_desc_),                  // DstDesc
-                                         Sequence<1, 1, 4, 1, 4>, // SliceLengths
-                                         Sequence<0, 1, 2, 3, 4>, // DimAccessOrder
-                                         4,                       // SrcVectorDim
-                                         4,                       // SrcScalarPerVector
-                                         2>;
+    using D0ThreadwiseCopyLdsToVgpr = ThreadwiseTensorSliceTransfer_v4<
+        typename TypeTransform<D0DataType>::Type,       // SrcData
+        typename TypeTransform<D0DataType>::Type,       // DstData
+        decltype(d0_block_src_desc_m0_m1_n0_n1_n2_n3),  // SrcDesc
+        decltype(d0_thread_desc_),                      // DstDesc
+        Sequence<1, 1, D0N0_PerShuffle.value, 4, 1, 4>, // SliceLengths
+        Sequence<0, 1, 2, 3, 4, 5>,                     // DimAccessOrder
+        5,                                              // SrcVectorDim
+        4,                                              // SrcScalarPerVector
+        2>;
 };
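The per-thread LDS-to-VGPR copy gains the D0N0_PerShuffle slice dimension, so a single pass now moves D0N0_PerShuffle times as many elements per thread, and SrcVectorDim moves from 4 to 5 so vectorized reads stay on the innermost N3 axis. A small compile-time check, under the same assumed KPerBlock = 64:

```cpp
// Hypothetical value matching the earlier sketches: KPerBlock / 32 = 2.
constexpr int D0N0_PerShuffle = 2;

// Old SliceLengths: Sequence<1, 1, 4, 1, 4>                  -> 16 elements
// New SliceLengths: Sequence<1, 1, D0N0_PerShuffle, 4, 1, 4> -> 32 elements
constexpr int old_elems = 1 * 1 * 4 * 1 * 4;
constexpr int new_elems = 1 * 1 * D0N0_PerShuffle * 4 * 1 * 4;

static_assert(new_elems == D0N0_PerShuffle * old_elems,
              "one shuffle pass now drains a KPerBlock-wide stripe per thread");

int main() {}
```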

 struct SharedMemTrait
@@ -907,7 +912,7 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
             tensor_operation::element_wise::PassThrough{});

         auto d0_thread_copy_lds_to_vgpr = typename D0Operator::D0ThreadwiseCopyLdsToVgpr(
-            make_tuple(wave_id[I0], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0));
+            make_tuple(wave_id[I0], wave_m_n_id[I1], 0, 0, wave_m_n_id[I0], 0));

         index_t gemm1_k_block_outer_index = 0;
         do
@@ -1016,29 +1021,31 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
                 auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
                     D0Operator::d0_thread_desc_.GetElementSpaceSize());

-                static_for<0, D0N0, 1>{}([&](auto nr) {
+                static_for<0, D0_NumShuffle, 1>{}([&](auto nr) {
                     // load data to lds
                     d0_block_copy_global_to_lds.RunRead(d0_grid_desc_m0_n0_n1_n2_m1_n3,
                                                         d0_grid_buf);
                     d0_block_copy_global_to_lds.MoveSrcSliceWindow(
-                        d0_grid_desc_m0_n0_n1_n2_m1_n3, make_multi_index(0, 0, 1, 0, 0, 0));
+                        d0_grid_desc_m0_n0_n1_n2_m1_n3,
+                        make_multi_index(0, 0, D0N0_PerShuffle, 0, 0, 0));
                     d0_block_copy_global_to_lds.RunWrite(
                         D0Operator::d0_block_dst_desc_m0_n0_n1_n2_m1_n3, d0_block_buf);

                     block_sync_lds();

                     // read data from lds
-                    d0_thread_copy_lds_to_vgpr.Run(D0Operator::d0_block_src_desc_m0_m1_n0_n1_n2,
-                                                   make_tuple(I0, I0, I0, I0, I0),
-                                                   d0_block_buf,
-                                                   D0Operator::d0_thread_desc_,
-                                                   make_tuple(I0, I0, I0, I0, I0),
-                                                   d0_thread_buf);
+                    d0_thread_copy_lds_to_vgpr.Run(
+                        D0Operator::d0_block_src_desc_m0_m1_n0_n1_n2_n3,
+                        make_tuple(I0, I0, I0, I0, I0, I0),
+                        d0_block_buf,
+                        D0Operator::d0_thread_desc_,
+                        make_tuple(I0, I0, I0, I0, I0, I0),
+                        d0_thread_buf);

                     // bias add
                     static_for<0, d0_thread_buf.Size(), 1>{}([&](auto i) {
-                        constexpr index_t c_offset =
-                            c_thread_desc.CalculateOffset(make_tuple(I0, nr, i));
+                        constexpr index_t c_offset = c_thread_desc.CalculateOffset(
+                            make_tuple(I0, nr * D0N0_PerShuffle, i));
                         acc_thread_buf(Number<c_offset>{}) +=
                             ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]);
...
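Putting the pieces together: the main loop now runs D0_NumShuffle passes instead of D0N0, and each pass reads one KPerBlock-wide stripe, advances the global window by D0N0_PerShuffle groups (previously 1), and rebases the accumulator offset at nr * D0N0_PerShuffle. A simplified host-side sketch of the control flow, with the same assumed tile sizes (plain ints stand in for the CK copy objects and buffers):

```cpp
#include <cassert>

// Assumed illustrative parameters, as in the earlier sketches.
constexpr int NPerBlock = 128, KPerBlock = 64;
constexpr int D0_NumShuffle   = NPerBlock / KPerBlock; // loop trip count
constexpr int D0N0_PerShuffle = KPerBlock / 32;        // window step per pass

int main()
{
    int n0 = 0; // N0 position of the global source window
    for (int nr = 0; nr < D0_NumShuffle; ++nr)
    {
        // RunRead pulls groups [n0, n0 + D0N0_PerShuffle) into LDS, then
        // MoveSrcSliceWindow advances by D0N0_PerShuffle (it used to be 1).
        n0 += D0N0_PerShuffle;

        // The accumulator offset is rebased once per pass:
        // CalculateOffset({0, nr * D0N0_PerShuffle, i}) instead of {0, nr, i}.
        const int c_n0_base = nr * D0N0_PerShuffle;
        assert(c_n0_base + D0N0_PerShuffle <= NPerBlock / 32);
    }
    // After the loop, every 32-wide group of the NPerBlock tile was read once.
    assert(n0 == NPerBlock / 32);
    return 0;
}
```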