Commit 72da1930 authored by mtgu0705's avatar mtgu0705
Browse files

Add support for KPerBlock > ScaleBlockK

parent 8713ade3
......@@ -345,6 +345,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
b_scale_thread_copy_step.At(Number<1>{}));
}
constexpr auto num_scale_k_block = BScaleThreadDesc{}.GetLength(I1);
// Local prefill 1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
......@@ -374,7 +376,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k0 * BMmaKStride>{}),
b_block_buf,
b_scale_thread_buf[n0],
b_scale_thread_buf[Number<n0 * num_scale_k_block + k0/num_scale_k_block>{}],
b_thread_desc_,
make_tuple(n0, I0, k0, I0),
b_thread_buf);
......@@ -467,7 +469,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k0 * BMmaKStride>{}),
b_block_buf,
b_scale_thread_buf[n0],
b_scale_thread_buf[Number<n0 * num_scale_k_block + k0/num_scale_k_block>{}],
b_thread_desc_,
make_tuple(n0, I0, k0, I0),
b_thread_buf);
......
......@@ -1423,9 +1423,9 @@ struct GridwiseGemm_xdl_cshuffle_v3
KPerBlock);
// b scale
static_assert(KPerBlock <= ScaleBlockK);
// static_assert(KPerBlock <= ScaleBlockK);
const index_t ScaleSliceSizeN = NXdlPerWave;
const index_t ScaleSliceSizeK = 1;
const index_t ScaleSliceSizeK = (KPerBlock + ScaleBlockK - 1) / ScaleBlockK;
constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<ScaleSliceSizeN>{}, Number<ScaleSliceSizeK>{}));
......@@ -1452,9 +1452,9 @@ struct GridwiseGemm_xdl_cshuffle_v3
constexpr auto b_scale_thread_slice_copy_step =
make_tuple(make_multi_index(NWaves * NPerXdl, 0),
make_multi_index(-NPerBlock, 0),
make_multi_index(-NPerBlock, 1));
make_multi_index(-NPerBlock, ScaleSliceSizeK));
const index_t num_k_block_per_scale = ScaleBlockK / KPerBlock;
const index_t num_k_block_per_scale = (ScaleBlockK + KPerBlock - 1) / KPerBlock;
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
a_grid_desc_ak0_m_ak1,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment