Commit 38eb3bba authored by letaoqin

Load Q once when Q_k <= 64
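This change targets gemm0 (the Q*K^T stage). With DIM lowered to 64, the whole K-extent of the Q tile fits in LDS: the A block is enlarged by a factor of Q_d / KPerBlock so every K-slice of Q can be staged once and kept resident across the outer N-block loop, instead of being re-read from global memory on each outer iteration. To make room, the B1 and softmax-reduction regions now start after the resident A block instead of at offset 0, and the gemm0 pipeline is pinned to GridwiseGemmPipeline_v1r1, which receives p_shared and a load-once flag directly (see the Run call at the bottom of the diff).

A rough sketch of the resulting LDS layout, as read from the SharedMemTrait hunks below (offsets in FloatAB elements; B, B1, and the reduction workspace are active in different phases, so they can share the region after A):

A (Q) block       : [0, a_size)             resident for the whole outer loop
B (K) block       : [a_size, a_size + ...)  staged per outer iteration
B1 (V) block      : [a_size, a_size + ...)  same region, used during gemm1
softmax workspace : [a_size, a_size + ...)  same region, FloatGemmAcc elements

where a_size = a_block_space_size_aligned, now scaled by Q_d / KPerBlock.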

parent 2f93e26f
@@ -8,7 +8,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
|-------------------------------------|
Gemm1
*/
-#define DIM 128 // DIM should be a multiple of 8.
+#define DIM 64 // DIM should be a multiple of 8.
#include <iostream>
#include <numeric>
@@ -15,6 +15,7 @@
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_softmax.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1r1.hpp"
namespace ck {
@@ -95,6 +96,7 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported");
+static constexpr auto Q_d = 64;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
@@ -232,12 +234,11 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
const index_t gemm1_bytes_end =
(SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) *
sizeof(FloatAB);
-const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset +
-                                   SharedMemTrait::reduction_space_size_aligned) *
-                                  sizeof(FloatGemmAcc);
+const index_t softmax_bytes_end =
+    SharedMemTrait::reduction_space_offset * sizeof(FloatAB) +
+    SharedMemTrait::reduction_space_size_aligned * sizeof(FloatGemmAcc);
const index_t c_block_bytes_end =
SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
return math::max(gemm0_bytes_end, gemm1_bytes_end, softmax_bytes_end, c_block_bytes_end);
}
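The reworked softmax_bytes_end mixes units on purpose: reduction_space_offset now counts FloatAB elements (it sits right after the A block, which holds FloatAB data), while reduction_space_size_aligned counts FloatGemmAcc elements. A minimal compile-time check of that arithmetic, with hypothetical sizes (fp16 inputs, fp32 accumulator, an 8192-element A block, a 256-element workspace):

// Hypothetical values; only the unit bookkeeping mirrors the diff above.
#include <cstddef>
constexpr std::size_t sizeof_ab  = 2;    // sizeof(FloatAB), assumed fp16
constexpr std::size_t sizeof_acc = 4;    // sizeof(FloatGemmAcc), assumed fp32
constexpr std::size_t offset_ab  = 8192; // reduction_space_offset, FloatAB elems
constexpr std::size_t size_acc   = 256;  // reduction_space_size_aligned, acc elems
constexpr std::size_t softmax_bytes_end =
    offset_ab * sizeof_ab + size_acc * sizeof_acc;
static_assert(softmax_bytes_end == 17408, "8192 * 2 + 256 * 4");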
@@ -498,7 +499,8 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
static constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), B1K1);
static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
-    a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
+    a_block_desc_ak0_m_ak1.GetElementSpaceSize() * Number<Q_d / KPerBlock>{},
+    max_lds_align);
static constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple(
@@ -506,13 +508,13 @@
static constexpr auto a_block_space_offset = 0;
static constexpr auto b_block_space_offset = a_block_space_size_aligned.value;
-static constexpr auto b1_block_space_offset = 0;
+static constexpr auto b1_block_space_offset = a_block_space_size_aligned.value;
// LDS allocation for reduction
static constexpr index_t reduction_space_size_aligned =
math::integer_least_multiple(BlockSize, max_lds_align);
-static constexpr auto reduction_space_offset = 0;
+static constexpr auto reduction_space_offset = a_block_space_size_aligned.value;
// LDS allocation for C shuffle in LDS
static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
@@ -611,7 +613,6 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
// B matrix blockwise copy
auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
@@ -673,11 +674,7 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
auto acc_thread_buf = blockwise_gemm.GetCThreadBuffer();
-// LDS allocation for A and B: be careful of alignment
-auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-    static_cast<FloatAB*>(p_shared) + SharedMemTrait::a_block_space_offset,
-    a_block_desc_ak0_m_ak1.GetElementSpaceSize());
// LDS allocation for B: be careful of alignment
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared) + SharedMemTrait::b_block_space_offset,
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
@@ -689,11 +686,13 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
const auto b_block_reset_copy_step =
make_multi_index(-b_grid_desc_bk0_n_bk1.GetLength(I0), NPerBlock, 0);
+const auto Q_k = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
// gridwise GEMM pipeline
// Only supports LoopScheduler::Default
-const auto gridwise_gemm_pipeline = GridwiseGemmPipeline_Selector<PipelineVer,
+const auto gridwise_gemm_pipeline = GridwiseGemmPipeline_v1r1<
+    NumGemmKPrefetchStage>{}; /*GridwiseGemmPipeline_Selector<PipelineVer,
 NumGemmKPrefetchStage,
-LoopScheduler::Default>();
+LoopScheduler::Default>();*/
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
@@ -843,7 +842,8 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
// Blockwise softmax
//
auto workspace_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-    static_cast<FloatGemmAcc*>(p_shared) + SharedMemTrait::reduction_space_offset,
+    static_cast<FloatGemmAcc*>(p_shared) +
+        SharedMemTrait::reduction_space_offset * sizeof(FloatAB) / sizeof(FloatGemmAcc),
SharedMemTrait::reduction_space_size_aligned);
// get acc0 8D thread cluster
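The workspace pointer above has to undo the same unit mismatch: p_shared is cast to FloatGemmAcc*, but reduction_space_offset is an element count in FloatAB units, hence the sizeof(FloatAB) / sizeof(FloatGemmAcc) rescale. A standalone sketch with hypothetical stand-in types; note the conversion is exact only when the byte offset is a multiple of sizeof(FloatGemmAcc):

#include <cstdint>
using FloatAB      = std::uint16_t; // 2 bytes, stand-in for fp16
using FloatGemmAcc = float;         // 4 bytes, fp32 accumulator
constexpr unsigned offset_ab  = 8192; // offset in FloatAB elements (hypothetical)
constexpr unsigned offset_acc =
    offset_ab * sizeof(FloatAB) / sizeof(FloatGemmAcc); // = 4096
static_assert(offset_acc * sizeof(FloatGemmAcc) == offset_ab * sizeof(FloatAB),
              "rescale is exact: 8192 * 2 bytes == 4096 * 4 bytes");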
@@ -940,11 +940,11 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
continue;
}
// gemm0
-gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
+gridwise_gemm_pipeline.template Run<HasMainKBlockLoop, FloatAB>(
+    a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
-a_block_buf,
a_block_slice_copy_step,
b_grid_desc_bk0_n_bk1,
b_block_desc_bk0_n_bk1,
@@ -954,7 +954,9 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
b_block_slice_copy_step,
blockwise_gemm,
acc_thread_buf,
-num_k_block_main_loop);
+num_k_block_main_loop,
+p_shared,
+gemm1_k_block_outer_index == 0 && Q_k <= 64);
// do MNK padding or upper triangular masking
if constexpr(MaskOutUpperTriangle || PadN)
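The Run call now forwards FloatAB as an explicit template argument plus two new trailing arguments: the raw p_shared pointer, so the v1r1 pipeline can manage the resident A (Q) tile itself (the separate a_block_buf allocation above was removed), and a flag that is true only on the first outer N-block iteration of the Q_k <= 64 path. A hedged sketch of how a pipeline could consume them; this is not the real GridwiseGemmPipeline_v1r1 interface, whose implementation is not part of this diff:

// Illustrative only: names and signature are hypothetical.
struct PipelineSketch
{
    __device__ void Run(void* p_shared, int num_loop, bool load_a_once) const
    {
        if(load_a_once)
        {
            // first outer iteration of the Q_k <= 64 path: stage all
            // Q_d / KPerBlock K-slices of Q into the LDS at p_shared
        }
        for(int i = 0; i < num_loop; ++i)
        {
            // main gemm0 K loop: A is read from the resident LDS tile,
            // B is still streamed per K step
        }
        (void)p_shared;
    }
};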