Commit 940f786e authored by letaoqin
Browse files

fix d0 load parameter

parent 9de99bdb
......@@ -239,7 +239,12 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
SharedMemTrait::reduction_space_size_aligned * sizeof(FloatGemmAcc);
const index_t c_block_bytes_end =
SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
return math::max(gemm0_bytes_end, gemm1_bytes_end, softmax_bytes_end, c_block_bytes_end);
const index_t d0_bytes_end = SharedMemTrait::d0_block_space_offset * sizeof(FloatAB) +
SharedMemTrait::d0_block_space_size_aligned *
D0Operator::template TypeTransform<D0DataType>::Size0;
return math::max(
gemm0_bytes_end, gemm1_bytes_end, softmax_bytes_end, c_block_bytes_end, d0_bytes_end);
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
......@@ -362,14 +367,15 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(C1GridDesc_M_N{}))>;
static constexpr auto AKPerBlock = 32;
static constexpr auto D0N2 = AK1;
static constexpr auto D0N1 = Number<32 / AK1.value>{};
static constexpr auto D0N0 = Number<NPerBlock / 32>{};
static constexpr auto D0N0_PerShuffle = Number<KPerBlock / 32>{};
static constexpr auto D0_NumShuffle = NPerBlock / KPerBlock;
static constexpr auto D0N0_PerShuffle = Number<AKPerBlock / 32>{};
static constexpr auto D0_NumShuffle = NPerBlock / AKPerBlock;
static_assert(NPerBlock % KPerBlock == 0 && KPerBlock % 32 == 0,
"KPerBlock should be multiple of 32 and divisor of NPerBlock");
static_assert(NPerBlock % AKPerBlock == 0 && AKPerBlock % 32 == 0,
"AKPerBlock should be multiple of 32 and divisor of NPerBlock");
__host__ __device__ static constexpr auto
MakeD0GridDescriptor_M0_N0_N1_N2_M1_N3(const D0GridDesc_M_N& d0_grid_desc_m_n)
......@@ -402,12 +408,16 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
// Maps the D0 element type to the type and per-element byte sizes used by the
// D0 path. The alias is declared exactly once: repeating a member alias inside
// a class body is a redeclaration error.
template <typename DataType>
struct TypeTransform
{
    // Element type used for D0 data; identical to the requested type here.
    using Type = DataType;
    // Bytes per element counted when the D0 block's LDS requirement is
    // computed (multiplied by d0_block_space_size_aligned in d0_bytes_end).
    static constexpr index_t Size0 = sizeof(DataType);
    // Bytes per element of Type.
    static constexpr index_t Size = sizeof(DataType);
};
// Specialization for a void D0 data type (presumably meaning no D0 tensor is
// supplied -- confirm against callers). The alias is declared exactly once:
// repeating a member alias inside a class body is a redeclaration error.
template <>
struct TypeTransform<void>
{
    // Stand-in element type so dependent code still has a concrete type.
    using Type = ck::half_t;
    // Zero, so the size term of d0_bytes_end (d0_block_space_size_aligned *
    // Size0) contributes no bytes when D0DataType is void.
    static constexpr index_t Size0 = 0;
    // Bytes per element of the stand-in Type.
    static constexpr index_t Size = sizeof(ck::half_t);
};
__host__ __device__ static constexpr auto GetD0BlockGlobalDescriptor_M0_N0_N1_N2_M1_N3()
......@@ -499,7 +509,7 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
static constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), B1K1);
static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize() * Number<Q_d / KPerBlock>{},
a_block_desc_ak0_m_ak1.GetElementSpaceSize() * Number<std::max(Q_d / KPerBlock, 1)>{},
max_lds_align);
static constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
......@@ -521,6 +531,11 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
static constexpr auto c_block_space_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
// LDS allocation for D0 shuffle in LDS
static constexpr auto d0_block_space_offset = a_block_space_size_aligned.value;
static constexpr auto d0_block_space_size_aligned = math::integer_least_multiple(
D0Operator::d0_block_dst_desc_m0_n0_n1_n2_m1_n3.GetElementSpaceSize(), max_lds_align);
};
template <bool HasMainKBlockLoop, typename Block2CTileMap, typename C0MatrixMask>
......@@ -686,7 +701,7 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
const auto b_block_reset_copy_step =
make_multi_index(-b_grid_desc_bk0_n_bk1.GetLength(I0), NPerBlock, 0);
const auto Q_k = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
const auto q_k = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
// gridwise GEMM pipeline
// Only supports LoopScheduler::Default
const auto gridwise_gemm_pipeline = GridwiseGemmPipeline_v1r1<1>{};
......@@ -937,6 +952,7 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
continue;
}
// gemm0
const bool is_can_load_once = (q_k <= 64 && KPerBlock <= 64);
gridwise_gemm_pipeline.template Run<HasMainKBlockLoop, FloatAB>(
a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
......@@ -953,7 +969,8 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
acc_thread_buf,
num_k_block_main_loop,
p_shared,
gemm1_k_block_outer_index == 0 || Q_k > 64);
(gemm1_k_block_outer_index == 0 && is_can_load_once) || (!is_can_load_once),
KPerBlock == 32 && q_k == 64);
// do MNK padding or upper triangular masking
if constexpr(MaskOutUpperTriangle || PadN)
......@@ -1029,7 +1046,9 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
p_d0_grid, d0_grid_desc_m0_n0_n1_n2_m1_n3.GetElementSpaceSize());
auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<D0DataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset *
sizeof(FloatAB) /
sizeof(D0DataType),
D0Operator::d0_block_dst_desc_m0_n0_n1_n2_m1_n3.GetElementSpaceSize());
auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
......
......@@ -55,7 +55,8 @@ struct GridwiseGemmPipeline_v1r1<1>
CThreadBuffer& c_thread_buf,
index_t num_loop,
void* p_shared,
bool bIsLoadAblock)
bool bIsLoadAblock,
bool bIsChangeCache)
{
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared) + 0, a_block_desc.GetElementSpaceSize());
......@@ -95,9 +96,14 @@ struct GridwiseGemmPipeline_v1r1<1>
if(bIsLoadAblock)
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared) + a_block_desc.GetElementSpaceSize() * (i + 1),
a_block_desc.GetElementSpaceSize());
ignore = bIsChangeCache;
if(bIsChangeCache)
a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared) +
a_block_desc.GetElementSpaceSize() * (i + 1),
a_block_desc.GetElementSpaceSize());
if(bIsLoadAblock)
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment