Commit 81639679 authored by Qianfeng Zhang
Browse files

Let D0 shuffled loading not depend on the ABlockTransfer spec

parent 9a423017
...@@ -361,14 +361,21 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle ...@@ -361,14 +361,21 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
using DefaultBlock2CTileMap = using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(C1GridDesc_M_N{}))>; remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(C1GridDesc_M_N{}))>;
static constexpr auto D0N2 = AK1; using D0StoreType =
static constexpr auto D0N1 = Number<32 / AK1.value>{}; typename conditional<is_same<D0DataType, void>::value, half_t, D0DataType>::type;
static constexpr auto D0ShuffleBlock_N =
ck::math::min(static_cast<index_t>(32768 / sizeof(D0StoreType)) / MPerBlock, NPerBlock);
static constexpr auto D0N2 = Number<4 * sizeof(float) / sizeof(D0StoreType)>{};
static constexpr auto D0N1 = Number<32 / D0N2.value>{};
static constexpr auto D0N0 = Number<NPerBlock / 32>{}; static constexpr auto D0N0 = Number<NPerBlock / 32>{};
static constexpr auto D0N0_PerShuffle = Number<KPerBlock / 32>{}; // ToDo: strange issue when D0N0_PerShuffle == 4 is used (too many vgpr consumption ?)
static constexpr auto D0_NumShuffle = NPerBlock / KPerBlock; static constexpr auto D0N0_PerShuffle = Number<ck::math::min(D0ShuffleBlock_N / 32, 2)>{};
static constexpr auto D0_NumShuffle = D0N0.value / D0N0_PerShuffle.value;
static constexpr auto I16 = Number<16>{};
static_assert(NPerBlock % KPerBlock == 0 && KPerBlock % 32 == 0, static_assert(NPerBlock % D0ShuffleBlock_N == 0 && D0ShuffleBlock_N % 32 == 0,
"KPerBlock should be multiple of 32 and divisor of NPerBlock"); "Calculated D0ShuffleBlock_N should be multiple of 32 and divisor of NPerBlock");
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeD0GridDescriptor_M0_N0_N1_N2_M1_N3(const D0GridDesc_M_N& d0_grid_desc_m_n) MakeD0GridDescriptor_M0_N0_N1_N2_M1_N3(const D0GridDesc_M_N& d0_grid_desc_m_n)
...@@ -394,9 +401,7 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle ...@@ -394,9 +401,7 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
struct D0Operator struct D0Operator
{ {
static_assert(ABlockTransferThreadClusterLengths_AK0_M_AK1::Size() == 3); static_assert(D0N2 % D0BlockTransferSrcScalarPerVector == 0);
static_assert(ABlockTransferDstScalarPerVector_AK1 % D0BlockTransferSrcScalarPerVector ==
0);
template <typename DataType> template <typename DataType>
struct TypeTransform struct TypeTransform
...@@ -454,8 +459,7 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle ...@@ -454,8 +459,7 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
Sequence<I1, I1, D0N0_PerShuffle, D0N1, MPerBlock, D0N2>, Sequence<I1, I1, D0N0_PerShuffle, D0N1, MPerBlock, D0N2>,
typename sequence_merge<Sequence<1, 1, 1>, typename sequence_merge<Sequence<1, 1, 1>, Sequence<4, BlockSize / 4, 1>>::type,
ABlockTransferThreadClusterLengths_AK0_M_AK1>::type,
Sequence<0, 1, 2, 4, 3, 5>, Sequence<0, 1, 2, 4, 3, 5>,
typename TypeTransform<D0DataType>::Type, // SrcData typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData typename TypeTransform<D0DataType>::Type, // DstData
...@@ -466,7 +470,7 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle ...@@ -466,7 +470,7 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
5, 5,
5, 5,
D0BlockTransferSrcScalarPerVector, D0BlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1, D0N2, // D0BlockTransferDstScalarPerVector
1, 1,
1, 1,
true, // SrcResetCoord true, // SrcResetCoord
...@@ -482,7 +486,7 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle ...@@ -482,7 +486,7 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
Sequence<0, 1, 2, 3, 4, 5>, // DimAccessOrder Sequence<0, 1, 2, 3, 4, 5>, // DimAccessOrder
5, // SrcVectorDim 5, // SrcVectorDim
4, // SrcScalarPerVector 4, // SrcScalarPerVector
2>; 2>; // SrcScalarStrideInVector (not used)
}; };
struct SharedMemTrait struct SharedMemTrait
...@@ -1045,7 +1049,7 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle ...@@ -1045,7 +1049,7 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
// bias add // bias add
static_for<0, d0_thread_buf.Size(), 1>{}([&](auto i) { static_for<0, d0_thread_buf.Size(), 1>{}([&](auto i) {
constexpr index_t c_offset = c_thread_desc.CalculateOffset( constexpr index_t c_offset = c_thread_desc.CalculateOffset(
make_tuple(I0, nr * D0N0_PerShuffle, i)); make_tuple(I0, nr * D0N0_PerShuffle + i / I16, i % I16));
acc_thread_buf(Number<c_offset>{}) += acc_thread_buf(Number<c_offset>{}) +=
ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]); ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment