Commit 81639679 authored by Qianfeng Zhang's avatar Qianfeng Zhang
Browse files

Let D0 shuffled laoding not depend on ABlockTransfer Spec

parent 9a423017
......@@ -361,14 +361,21 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(C1GridDesc_M_N{}))>;
static constexpr auto D0N2 = AK1;
static constexpr auto D0N1 = Number<32 / AK1.value>{};
using D0StoreType =
typename conditional<is_same<D0DataType, void>::value, half_t, D0DataType>::type;
static constexpr auto D0ShuffleBlock_N =
ck::math::min(static_cast<index_t>(32768 / sizeof(D0StoreType)) / MPerBlock, NPerBlock);
static constexpr auto D0N2 = Number<4 * sizeof(float) / sizeof(D0StoreType)>{};
static constexpr auto D0N1 = Number<32 / D0N2.value>{};
static constexpr auto D0N0 = Number<NPerBlock / 32>{};
static constexpr auto D0N0_PerShuffle = Number<KPerBlock / 32>{};
static constexpr auto D0_NumShuffle = NPerBlock / KPerBlock;
// ToDo: strange issue when D0N0_PerShuffle == 4 is used (too many vgpr consumption ?)
static constexpr auto D0N0_PerShuffle = Number<ck::math::min(D0ShuffleBlock_N / 32, 2)>{};
static constexpr auto D0_NumShuffle = D0N0.value / D0N0_PerShuffle.value;
static constexpr auto I16 = Number<16>{};
static_assert(NPerBlock % KPerBlock == 0 && KPerBlock % 32 == 0,
"KPerBlock should be multiple of 32 and divisor of NPerBlock");
static_assert(NPerBlock % D0ShuffleBlock_N == 0 && D0ShuffleBlock_N % 32 == 0,
"Calculated D0ShuffleBlock_N should be multiple of 32 and divisor of NPerBlock");
__host__ __device__ static constexpr auto
MakeD0GridDescriptor_M0_N0_N1_N2_M1_N3(const D0GridDesc_M_N& d0_grid_desc_m_n)
......@@ -394,9 +401,7 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
struct D0Operator
{
static_assert(ABlockTransferThreadClusterLengths_AK0_M_AK1::Size() == 3);
static_assert(ABlockTransferDstScalarPerVector_AK1 % D0BlockTransferSrcScalarPerVector ==
0);
static_assert(D0N2 % D0BlockTransferSrcScalarPerVector == 0);
template <typename DataType>
struct TypeTransform
......@@ -454,8 +459,7 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<I1, I1, D0N0_PerShuffle, D0N1, MPerBlock, D0N2>,
typename sequence_merge<Sequence<1, 1, 1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1>::type,
typename sequence_merge<Sequence<1, 1, 1>, Sequence<4, BlockSize / 4, 1>>::type,
Sequence<0, 1, 2, 4, 3, 5>,
typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData
......@@ -466,7 +470,7 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
5,
5,
D0BlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
D0N2, // D0BlockTransferDstScalarPerVector
1,
1,
true, // SrcResetCoord
......@@ -482,7 +486,7 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
Sequence<0, 1, 2, 3, 4, 5>, // DimAccessOrder
5, // SrcVectorDim
4, // SrcScalarPerVector
2>;
2>; // SrcScalarStrideInVector (not used)
};
struct SharedMemTrait
......@@ -1045,7 +1049,7 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
// bias add
static_for<0, d0_thread_buf.Size(), 1>{}([&](auto i) {
constexpr index_t c_offset = c_thread_desc.CalculateOffset(
make_tuple(I0, nr * D0N0_PerShuffle, i));
make_tuple(I0, nr * D0N0_PerShuffle + i / I16, i % I16));
acc_thread_buf(Number<c_offset>{}) +=
ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment