Commit 940f786e authored by letaoqin
Browse files

fix d0 load parameter

parent 9de99bdb
...@@ -239,7 +239,12 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle ...@@ -239,7 +239,12 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
SharedMemTrait::reduction_space_size_aligned * sizeof(FloatGemmAcc); SharedMemTrait::reduction_space_size_aligned * sizeof(FloatGemmAcc);
const index_t c_block_bytes_end = const index_t c_block_bytes_end =
SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle); SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
return math::max(gemm0_bytes_end, gemm1_bytes_end, softmax_bytes_end, c_block_bytes_end);
const index_t d0_bytes_end = SharedMemTrait::d0_block_space_offset * sizeof(FloatAB) +
SharedMemTrait::d0_block_space_size_aligned *
D0Operator::template TypeTransform<D0DataType>::Size0;
return math::max(
gemm0_bytes_end, gemm1_bytes_end, softmax_bytes_end, c_block_bytes_end, d0_bytes_end);
} }
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
...@@ -362,14 +367,15 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle ...@@ -362,14 +367,15 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
using DefaultBlock2CTileMap = using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(C1GridDesc_M_N{}))>; remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(C1GridDesc_M_N{}))>;
static constexpr auto AKPerBlock = 32;
static constexpr auto D0N2 = AK1; static constexpr auto D0N2 = AK1;
static constexpr auto D0N1 = Number<32 / AK1.value>{}; static constexpr auto D0N1 = Number<32 / AK1.value>{};
static constexpr auto D0N0 = Number<NPerBlock / 32>{}; static constexpr auto D0N0 = Number<NPerBlock / 32>{};
static constexpr auto D0N0_PerShuffle = Number<KPerBlock / 32>{}; static constexpr auto D0N0_PerShuffle = Number<AKPerBlock / 32>{};
static constexpr auto D0_NumShuffle = NPerBlock / KPerBlock; static constexpr auto D0_NumShuffle = NPerBlock / AKPerBlock;
static_assert(NPerBlock % KPerBlock == 0 && KPerBlock % 32 == 0, static_assert(NPerBlock % AKPerBlock == 0 && AKPerBlock % 32 == 0,
"KPerBlock should be multiple of 32 and divisor of NPerBlock"); "AKPerBlock should be multiple of 32 and divisor of NPerBlock");
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeD0GridDescriptor_M0_N0_N1_N2_M1_N3(const D0GridDesc_M_N& d0_grid_desc_m_n) MakeD0GridDescriptor_M0_N0_N1_N2_M1_N3(const D0GridDesc_M_N& d0_grid_desc_m_n)
...@@ -402,12 +408,16 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle ...@@ -402,12 +408,16 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
template <typename DataType> template <typename DataType>
struct TypeTransform struct TypeTransform
{ {
using Type = DataType; using Type = DataType;
static constexpr index_t Size0 = sizeof(DataType);
static constexpr index_t Size = sizeof(DataType);
}; };
template <> template <>
struct TypeTransform<void> struct TypeTransform<void>
{ {
using Type = ck::half_t; using Type = ck::half_t;
static constexpr index_t Size0 = 0;
static constexpr index_t Size = sizeof(ck::half_t);
}; };
__host__ __device__ static constexpr auto GetD0BlockGlobalDescriptor_M0_N0_N1_N2_M1_N3() __host__ __device__ static constexpr auto GetD0BlockGlobalDescriptor_M0_N0_N1_N2_M1_N3()
...@@ -499,7 +509,7 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle ...@@ -499,7 +509,7 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
static constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), B1K1); static constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), B1K1);
static constexpr auto a_block_space_size_aligned = math::integer_least_multiple( static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize() * Number<Q_d / KPerBlock>{}, a_block_desc_ak0_m_ak1.GetElementSpaceSize() * Number<std::max(Q_d / KPerBlock, 1)>{},
max_lds_align); max_lds_align);
static constexpr auto b_block_space_size_aligned = math::integer_least_multiple( static constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
...@@ -521,6 +531,11 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle ...@@ -521,6 +531,11 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
static constexpr auto c_block_space_size = static constexpr auto c_block_space_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
// LDS allocation for D0 shuffle in LDS
static constexpr auto d0_block_space_offset = a_block_space_size_aligned.value;
static constexpr auto d0_block_space_size_aligned = math::integer_least_multiple(
D0Operator::d0_block_dst_desc_m0_n0_n1_n2_m1_n3.GetElementSpaceSize(), max_lds_align);
}; };
template <bool HasMainKBlockLoop, typename Block2CTileMap, typename C0MatrixMask> template <bool HasMainKBlockLoop, typename Block2CTileMap, typename C0MatrixMask>
...@@ -686,7 +701,7 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle ...@@ -686,7 +701,7 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
const auto b_block_reset_copy_step = const auto b_block_reset_copy_step =
make_multi_index(-b_grid_desc_bk0_n_bk1.GetLength(I0), NPerBlock, 0); make_multi_index(-b_grid_desc_bk0_n_bk1.GetLength(I0), NPerBlock, 0);
const auto Q_k = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2); const auto q_k = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
// gridwise GEMM pipeline // gridwise GEMM pipeline
// Only supports LoopScheduler::Default // Only supports LoopScheduler::Default
const auto gridwise_gemm_pipeline = GridwiseGemmPipeline_v1r1<1>{}; const auto gridwise_gemm_pipeline = GridwiseGemmPipeline_v1r1<1>{};
...@@ -937,6 +952,7 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle ...@@ -937,6 +952,7 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
continue; continue;
} }
// gemm0 // gemm0
const bool is_can_load_once = (q_k <= 64 && KPerBlock <= 64);
gridwise_gemm_pipeline.template Run<HasMainKBlockLoop, FloatAB>( gridwise_gemm_pipeline.template Run<HasMainKBlockLoop, FloatAB>(
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1,
...@@ -953,7 +969,8 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle ...@@ -953,7 +969,8 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
acc_thread_buf, acc_thread_buf,
num_k_block_main_loop, num_k_block_main_loop,
p_shared, p_shared,
gemm1_k_block_outer_index == 0 || Q_k > 64); (gemm1_k_block_outer_index == 0 && is_can_load_once) || (!is_can_load_once),
KPerBlock == 32 && q_k == 64);
// do MNK padding or upper triangular masking // do MNK padding or upper triangular masking
if constexpr(MaskOutUpperTriangle || PadN) if constexpr(MaskOutUpperTriangle || PadN)
...@@ -1029,7 +1046,9 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle ...@@ -1029,7 +1046,9 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
p_d0_grid, d0_grid_desc_m0_n0_n1_n2_m1_n3.GetElementSpaceSize()); p_d0_grid, d0_grid_desc_m0_n0_n1_n2_m1_n3.GetElementSpaceSize());
auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<D0DataType*>(p_shared) + SharedMemTrait::a_block_space_offset, static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset *
sizeof(FloatAB) /
sizeof(D0DataType),
D0Operator::d0_block_dst_desc_m0_n0_n1_n2_m1_n3.GetElementSpaceSize()); D0Operator::d0_block_dst_desc_m0_n0_n1_n2_m1_n3.GetElementSpaceSize());
auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>( auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
......
...@@ -55,7 +55,8 @@ struct GridwiseGemmPipeline_v1r1<1> ...@@ -55,7 +55,8 @@ struct GridwiseGemmPipeline_v1r1<1>
CThreadBuffer& c_thread_buf, CThreadBuffer& c_thread_buf,
index_t num_loop, index_t num_loop,
void* p_shared, void* p_shared,
bool bIsLoadAblock) bool bIsLoadAblock,
bool bIsChangeCache)
{ {
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared) + 0, a_block_desc.GetElementSpaceSize()); static_cast<FloatAB*>(p_shared) + 0, a_block_desc.GetElementSpaceSize());
...@@ -95,9 +96,14 @@ struct GridwiseGemmPipeline_v1r1<1> ...@@ -95,9 +96,14 @@ struct GridwiseGemmPipeline_v1r1<1>
if(bIsLoadAblock) if(bIsLoadAblock)
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared) + a_block_desc.GetElementSpaceSize() * (i + 1), ignore = bIsChangeCache;
a_block_desc.GetElementSpaceSize()); if(bIsChangeCache)
a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared) +
a_block_desc.GetElementSpaceSize() * (i + 1),
a_block_desc.GetElementSpaceSize());
if(bIsLoadAblock) if(bIsLoadAblock)
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment