"vscode:/vscode.git/clone" did not exist on "fa35750d3ba8c2aadac4ce33515be302e12d00fc"
Commit 72da1930 authored by mtgu0705
Browse files

Add support for KPerBlock > ScaleBlockK

parent 8713ade3
...@@ -345,6 +345,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra ...@@ -345,6 +345,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
b_scale_thread_copy_step.At(Number<1>{})); b_scale_thread_copy_step.At(Number<1>{}));
} }
constexpr auto num_scale_k_block = BScaleThreadDesc{}.GetLength(I1);
// Local prefill 1 // Local prefill 1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
...@@ -374,7 +376,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra ...@@ -374,7 +376,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k0 * BMmaKStride>{}), make_tuple(n0, I0, I0, Number<k0 * BMmaKStride>{}),
b_block_buf, b_block_buf,
b_scale_thread_buf[n0], b_scale_thread_buf[Number<n0 * num_scale_k_block + k0/num_scale_k_block>{}],
b_thread_desc_, b_thread_desc_,
make_tuple(n0, I0, k0, I0), make_tuple(n0, I0, k0, I0),
b_thread_buf); b_thread_buf);
...@@ -467,7 +469,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra ...@@ -467,7 +469,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k0 * BMmaKStride>{}), make_tuple(n0, I0, I0, Number<k0 * BMmaKStride>{}),
b_block_buf, b_block_buf,
b_scale_thread_buf[n0], b_scale_thread_buf[Number<n0 * num_scale_k_block + k0/num_scale_k_block>{}],
b_thread_desc_, b_thread_desc_,
make_tuple(n0, I0, k0, I0), make_tuple(n0, I0, k0, I0),
b_thread_buf); b_thread_buf);
......
...@@ -1423,9 +1423,9 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1423,9 +1423,9 @@ struct GridwiseGemm_xdl_cshuffle_v3
KPerBlock); KPerBlock);
// b scale // b scale
static_assert(KPerBlock <= ScaleBlockK); // static_assert(KPerBlock <= ScaleBlockK);
const index_t ScaleSliceSizeN = NXdlPerWave; const index_t ScaleSliceSizeN = NXdlPerWave;
const index_t ScaleSliceSizeK = 1; const index_t ScaleSliceSizeK = (KPerBlock + ScaleBlockK - 1) / ScaleBlockK;
constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<ScaleSliceSizeN>{}, Number<ScaleSliceSizeK>{})); make_tuple(Number<ScaleSliceSizeN>{}, Number<ScaleSliceSizeK>{}));
...@@ -1452,9 +1452,9 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1452,9 +1452,9 @@ struct GridwiseGemm_xdl_cshuffle_v3
constexpr auto b_scale_thread_slice_copy_step = constexpr auto b_scale_thread_slice_copy_step =
make_tuple(make_multi_index(NWaves * NPerXdl, 0), make_tuple(make_multi_index(NWaves * NPerXdl, 0),
make_multi_index(-NPerBlock, 0), make_multi_index(-NPerBlock, 0),
make_multi_index(-NPerBlock, 1)); make_multi_index(-NPerBlock, ScaleSliceSizeK));
const index_t num_k_block_per_scale = ScaleBlockK / KPerBlock; const index_t num_k_block_per_scale = (ScaleBlockK + KPerBlock - 1) / KPerBlock;
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>( blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment