Commit 38eb3bba authored by letaoqin
Browse files

Q load once when Q_k <= 64

parent 2f93e26f
...@@ -8,7 +8,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g ...@@ -8,7 +8,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
|-------------------------------------| |-------------------------------------|
Gemm1 Gemm1
*/ */
#define DIM 128 // DIM should be a multiple of 8. #define DIM 64 // DIM should be a multiple of 8.
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_softmax.hpp" #include "ck/tensor_operation/gpu/block/blockwise_softmax.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1r1.hpp"
namespace ck { namespace ck {
...@@ -95,14 +96,15 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle ...@@ -95,14 +96,15 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
static_assert(LoopSched == LoopScheduler::Default, static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported"); "Non-default loop scheduler is currently not supported");
static constexpr auto I0 = Number<0>{}; static constexpr auto Q_d = 64;
static constexpr auto I1 = Number<1>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I3 = Number<3>{}; static constexpr auto I2 = Number<2>{};
static constexpr auto I4 = Number<4>{}; static constexpr auto I3 = Number<3>{};
static constexpr auto I5 = Number<5>{}; static constexpr auto I4 = Number<4>{};
static constexpr auto I6 = Number<6>{}; static constexpr auto I5 = Number<5>{};
static constexpr auto I7 = Number<7>{}; static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
static constexpr auto WaveSize = 64; static constexpr auto WaveSize = 64;
// K1 should be Number<...> // K1 should be Number<...>
...@@ -232,12 +234,11 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle ...@@ -232,12 +234,11 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
const index_t gemm1_bytes_end = const index_t gemm1_bytes_end =
(SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) * (SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) *
sizeof(FloatAB); sizeof(FloatAB);
const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset + const index_t softmax_bytes_end =
SharedMemTrait::reduction_space_size_aligned) * SharedMemTrait::reduction_space_offset * sizeof(FloatAB) +
sizeof(FloatGemmAcc); SharedMemTrait::reduction_space_size_aligned * sizeof(FloatGemmAcc);
const index_t c_block_bytes_end = const index_t c_block_bytes_end =
SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle); SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
return math::max(gemm0_bytes_end, gemm1_bytes_end, softmax_bytes_end, c_block_bytes_end); return math::max(gemm0_bytes_end, gemm1_bytes_end, softmax_bytes_end, c_block_bytes_end);
} }
...@@ -498,7 +499,8 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle ...@@ -498,7 +499,8 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
static constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), B1K1); static constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), B1K1);
static constexpr auto a_block_space_size_aligned = math::integer_least_multiple( static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); a_block_desc_ak0_m_ak1.GetElementSpaceSize() * Number<Q_d / KPerBlock>{},
max_lds_align);
static constexpr auto b_block_space_size_aligned = math::integer_least_multiple( static constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple( static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple(
...@@ -506,13 +508,13 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle ...@@ -506,13 +508,13 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
static constexpr auto a_block_space_offset = 0; static constexpr auto a_block_space_offset = 0;
static constexpr auto b_block_space_offset = a_block_space_size_aligned.value; static constexpr auto b_block_space_offset = a_block_space_size_aligned.value;
static constexpr auto b1_block_space_offset = 0; static constexpr auto b1_block_space_offset = a_block_space_size_aligned.value;
// LDS allocation for reduction // LDS allocation for reduction
static constexpr index_t reduction_space_size_aligned = static constexpr index_t reduction_space_size_aligned =
math::integer_least_multiple(BlockSize, max_lds_align); math::integer_least_multiple(BlockSize, max_lds_align);
static constexpr auto reduction_space_offset = 0; static constexpr auto reduction_space_offset = a_block_space_size_aligned.value;
// LDS allocation for C shuffle in LDS // LDS allocation for C shuffle in LDS
static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
...@@ -611,7 +613,6 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle ...@@ -611,7 +613,6 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
a_block_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0), make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{}); tensor_operation::element_wise::PassThrough{});
// B matrix blockwise copy // B matrix blockwise copy
auto b_blockwise_copy = auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock, ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
...@@ -673,11 +674,7 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle ...@@ -673,11 +674,7 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
auto acc_thread_buf = blockwise_gemm.GetCThreadBuffer(); auto acc_thread_buf = blockwise_gemm.GetCThreadBuffer();
// LDS allocation for A and B: be careful of alignment // LDS allocation for B: be careful of alignment
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared) + SharedMemTrait::a_block_space_offset,
a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared) + SharedMemTrait::b_block_space_offset, static_cast<FloatAB*>(p_shared) + SharedMemTrait::b_block_space_offset,
b_block_desc_bk0_n_bk1.GetElementSpaceSize()); b_block_desc_bk0_n_bk1.GetElementSpaceSize());
...@@ -689,11 +686,13 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle ...@@ -689,11 +686,13 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
const auto b_block_reset_copy_step = const auto b_block_reset_copy_step =
make_multi_index(-b_grid_desc_bk0_n_bk1.GetLength(I0), NPerBlock, 0); make_multi_index(-b_grid_desc_bk0_n_bk1.GetLength(I0), NPerBlock, 0);
const auto Q_k = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
// gridwise GEMM pipeline // gridwise GEMM pipeline
// Only supports LoopScheduler::Default // Only supports LoopScheduler::Default
const auto gridwise_gemm_pipeline = GridwiseGemmPipeline_Selector<PipelineVer, const auto gridwise_gemm_pipeline = GridwiseGemmPipeline_v1r1<
NumGemmKPrefetchStage, NumGemmKPrefetchStage>{}; /*GridwiseGemmPipeline_Selector<PipelineVer,
LoopScheduler::Default>(); NumGemmKPrefetchStage,
LoopScheduler::Default>();*/
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
...@@ -843,7 +842,8 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle ...@@ -843,7 +842,8 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
// Blockwise softmax // Blockwise softmax
// //
auto workspace_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto workspace_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatGemmAcc*>(p_shared) + SharedMemTrait::reduction_space_offset, static_cast<FloatGemmAcc*>(p_shared) +
SharedMemTrait::reduction_space_offset * sizeof(FloatAB) / sizeof(FloatGemmAcc),
SharedMemTrait::reduction_space_size_aligned); SharedMemTrait::reduction_space_size_aligned);
// get acc0 8D thread cluster // get acc0 8D thread cluster
...@@ -940,21 +940,23 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle ...@@ -940,21 +940,23 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
continue; continue;
} }
// gemm0 // gemm0
gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1, gridwise_gemm_pipeline.template Run<HasMainKBlockLoop, FloatAB>(
a_block_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
a_blockwise_copy, a_block_desc_ak0_m_ak1,
a_grid_buf, a_blockwise_copy,
a_block_buf, a_grid_buf,
a_block_slice_copy_step, a_block_slice_copy_step,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
b_block_desc_bk0_n_bk1, b_block_desc_bk0_n_bk1,
b_blockwise_copy, b_blockwise_copy,
b_grid_buf, b_grid_buf,
b_block_buf, b_block_buf,
b_block_slice_copy_step, b_block_slice_copy_step,
blockwise_gemm, blockwise_gemm,
acc_thread_buf, acc_thread_buf,
num_k_block_main_loop); num_k_block_main_loop,
p_shared,
gemm1_k_block_outer_index == 0 && Q_k <= 64);
// do MNK padding or upper triangular masking // do MNK padding or upper triangular masking
if constexpr(MaskOutUpperTriangle || PadN) if constexpr(MaskOutUpperTriangle || PadN)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment