"vscode:/vscode.git/clone" did not exist on "a5842a7fbac50b19a6ede3784268c34a65e77cbf"
Commit 1867ffa3 authored by mtgu0705

Add Scale_Block_M = 1 based on ab_scale support

parent f728087c
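At a high level, this change widens the per-thread A-scale buffer by a dimension of size `xdlops_gemm.GetRegSizePerXdlops()` (the 16-iteration inner loop below), so that with `Scale_Block_M = 1` each accumulator element of an xdlops tile is multiplied by the scale of its own M row instead of one scale per `MRepeat`. A minimal standalone sketch of that accumulation step follows; `MRepeat`, `RegSizePerXdlops`, the flat arrays, and the sample values in `main` are simplified stand-ins for the CK descriptors, not the library API.

```cpp
#include <array>
#include <cstddef>
#include <cstdio>

// Simplified stand-ins for the CK tile parameters (illustrative only).
constexpr std::size_t MRepeat          = 2;
constexpr std::size_t RegSizePerXdlops = 16;

// Mirrors the per-element indexing introduced in this commit,
//   a_scale_thread_buf(Number<m0 * xdlops_gemm.GetRegSizePerXdlops() + t>{}),
// i.e. accumulator element t of repeat m0 gets its own A scale.
void apply_ab_scale(std::array<float, MRepeat * RegSizePerXdlops>& c_thread_buf,
                    const std::array<float, RegSizePerXdlops>& c_thread_buf_per_scale,
                    const std::array<float, MRepeat * RegSizePerXdlops>& a_scale_thread_buf,
                    float b_scale,
                    std::size_t m0)
{
    for(std::size_t t = 0; t < RegSizePerXdlops; ++t)
    {
        c_thread_buf[m0 * RegSizePerXdlops + t] +=
            c_thread_buf_per_scale[t] * a_scale_thread_buf[m0 * RegSizePerXdlops + t] * b_scale;
    }
}

int main()
{
    std::array<float, MRepeat * RegSizePerXdlops> c{};
    std::array<float, RegSizePerXdlops> acc{};
    std::array<float, MRepeat * RegSizePerXdlops> a_scale{};
    acc.fill(1.0f);
    a_scale.fill(0.5f);
    apply_ab_scale(c, acc, a_scale, 2.0f, 0);
    std::printf("c[0] = %f\n", c[0]); // 1.0 * 0.5 * 2.0 = 1.0
}
```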
@@ -338,18 +338,27 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

-           static_for<0, MRepeat, 1>{}([&](auto m0){
-               a_scale_thread_copy.Run(a_scale_grid_desc,
-                                       a_scale_grid_buf,
-                                       a_scale_thread_desc,
-                                       make_tuple(m0, I0),
-                                       a_scale_thread_buf);
+           static_for<0, MRepeat, 1>{}([&](auto m0) {
+               static_for<0, Number<16>{}, 1>{}([&](auto i0) {
+                   a_scale_thread_copy.Run(a_scale_grid_desc,
+                                           a_scale_grid_buf,
+                                           a_scale_thread_desc,
+                                           make_tuple(m0, i0, I0),
+                                           a_scale_thread_buf);
+                   a_scale_thread_copy.MoveSrcSliceWindow(
+                       a_scale_grid_desc, a_scale_thread_copy_step.At(Number<2>{}));
+                   if((i0 + 1) % 4 == 0)
+                   {
+                       a_scale_thread_copy.MoveSrcSliceWindow(
+                           a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{}));
+                   }
+               });

                a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
                                                       a_scale_thread_copy_step.At(Number<0>{}));
            });

            a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
-                                                  a_scale_thread_copy_step.At(Number<1>{}));
+                                                  a_scale_thread_copy_step.At(Number<3>{}));

            b_scale_thread_copy.Run(b_scale_grid_desc,
                                    b_scale_grid_buf,
@@ -443,7 +452,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
                        c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
                    c_thread_buf(Number<c_offset>{}) +=
                        c_thread_buf_per_scale[Number<t>{}] *
-                       type_convert<AccDataType>(a_scale_thread_buf[m0]) *
+                       type_convert<AccDataType>(a_scale_thread_buf(Number<m0 * xdlops_gemm.GetRegSizePerXdlops() + t>{})) *
+                       // type_convert<AccDataType>(a_scale_thread_buf(Number<a_offset>{})) *
                        type_convert<AccDataType>(b_scale_thread_buf[I0]);
                });
            });
@@ -468,18 +478,27 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
                                       b_thread_buf);
                });
            });

-           static_for<0,MRepeat,1>{}([&](auto m0){
-               a_scale_thread_copy.Run(a_scale_grid_desc,
-                                       a_scale_grid_buf,
-                                       a_scale_thread_desc,
-                                       make_tuple(m0, I0),
-                                       a_scale_thread_buf);
+           static_for<0, MRepeat, 1>{}([&](auto m0) {
+               static_for<0, Number<16>{}, 1>{}([&](auto i0) {
+                   a_scale_thread_copy.Run(a_scale_grid_desc,
+                                           a_scale_grid_buf,
+                                           a_scale_thread_desc,
+                                           make_tuple(m0, i0, I0),
+                                           a_scale_thread_buf);
+                   a_scale_thread_copy.MoveSrcSliceWindow(
+                       a_scale_grid_desc, a_scale_thread_copy_step.At(Number<2>{}));
+                   if((i0 + 1) % 4 == 0)
+                   {
+                       a_scale_thread_copy.MoveSrcSliceWindow(
+                           a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{}));
+                   }
+               });

                a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
                                                       a_scale_thread_copy_step.At(Number<0>{}));
            });

            a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
-                                                  a_scale_thread_copy_step.At(Number<1>{}));
+                                                  a_scale_thread_copy_step.At(Number<3>{}));

            b_scale_thread_copy.Run(b_scale_grid_desc,
                                    b_scale_grid_buf,
@@ -526,7 +545,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
                        c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
                    c_thread_buf(Number<c_offset>{}) +=
                        c_thread_buf_per_scale[Number<t>{}] *
-                       type_convert<AccDataType>(a_scale_thread_buf[m0]) *
+                       type_convert<AccDataType>(a_scale_thread_buf(Number<m0 * xdlops_gemm.GetRegSizePerXdlops() + t>{})) *
+                       // type_convert<AccDataType>(a_scale_thread_buf(Number<a_offset>{})) *
                        type_convert<AccDataType>(b_scale_thread_buf[I0]);
                });
            });
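For reference, here is a standalone sketch of the source-window walk performed by the new inner `static_for<0, Number<16>{}, 1>` loop together with the four-entry `a_scale_thread_copy_step` added in this commit: advance one scale row per element, plus four extra rows after every group of four. The row numbering and the MFMA-layout remark in the comments are inferred from this diff, not taken from CK documentation.

```cpp
#include <cstdio>

// Hypothetical walk of the A-scale source window for one MRepeat, mirroring
// the added loop: Run(), then step At(Number<2>{}) = +1 row each element,
// plus step At(Number<1>{}) = +4 rows after every fourth element.
int main()
{
    const int reg_size_per_xdlops = 16; // inner static_for trip count in the diff
    int row = 0;                        // offset within the scale-M dimension

    for(int i0 = 0; i0 < reg_size_per_xdlops; ++i0)
    {
        std::printf("element %2d reads scale row %2d\n", i0, row);
        row += 1;                 // a_scale_thread_copy_step.At(Number<2>{})
        if((i0 + 1) % 4 == 0)
        {
            row += 4;             // a_scale_thread_copy_step.At(Number<1>{})
        }
    }
    // Visits rows 0-3, 8-11, 16-19, 24-27 and ends 32 rows past the start,
    // consistent with the 4-row groups a 32x32 xdlops accumulator assigns
    // to each half-wave.
}
```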
@@ -1213,8 +1213,9 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
        const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
            make_tuple(math::integer_divide_ceil(problem.M, ScaleBlockM),
+                      16,
                       math::integer_divide_ceil(problem.K, ScaleBlockK)),
-           make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
+           make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 0, 1));

        const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
            make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
                       math::integer_divide_ceil(problem.K, ScaleBlockK)),
@@ -1357,34 +1358,36 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
            (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
                KPerBlock);

+       static constexpr auto xdlops_gemm =
+           XdlopsGemm<ComputeTypeA, MPerXdl, NPerXdl, KPack, ComputeTypeA, false>{};
+
        const index_t ScaleSliceSizeM = MXdlPerWave;
+       const index_t RegSizePerXdlops = xdlops_gemm.GetRegSizePerXdlops();
        const index_t ScaleSliceSizeN = 1;
        const index_t ScaleSliceSizeK = 1;

        constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
-           make_tuple(Number<ScaleSliceSizeM>{}, Number<ScaleSliceSizeK>{}));
+           make_tuple(Number<ScaleSliceSizeM>{}, Number<RegSizePerXdlops>{}, Number<ScaleSliceSizeK>{}));

        constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<ScaleSliceSizeN>{}, Number<ScaleSliceSizeK>{}));

        constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);

-       auto a_thread_offset = get_thread_local_1d_id() % MPerXdl + (get_thread_local_1d_id() / 128) * MPerXdl;
+       // auto a_thread_offset =
+       //     get_thread_local_1d_id() % MPerXdl + (get_thread_local_1d_id() / 64) % MWaves * MPerXdl;
+       auto a_thread_offset =
+           get_thread_local_1d_id() / 128 * 32 + get_thread_local_1d_id() % 64 / 32 * 4;

        auto a_scale_thread_copy =
            ThreadwiseTensorSliceTransfer_v2<AScaleType,
                                             AScaleType,
                                             decltype(a_scale_grid_desc_am_ak),
                                             decltype(a_scale_thread_desc),
-                                            Sequence<1, ScaleSliceSizeK>,
-                                            Sequence<0, 1>,
+                                            Sequence<1, 1, ScaleSliceSizeK>,
+                                            Sequence<0, 1, 2>,
                                             1,
                                             1,
                                             1,
                                             false>(
-               a_scale_grid_desc_am_ak, make_multi_index(block_m_id * MPerBlock / ScaleBlockM + a_thread_offset, 0));
+               a_scale_grid_desc_am_ak, make_multi_index(block_m_id * MPerBlock / ScaleBlockM + a_thread_offset, 0, 0));

        auto b_scale_thread_copy =
            ThreadwiseTensorSliceTransfer_v2<BScaleType,
@@ -1400,7 +1403,10 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
                b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));

        constexpr auto a_scale_thread_slice_copy_step =
-           make_tuple(make_multi_index(MWaves * MPerXdl, 0), make_multi_index(-MPerBlock, 1));
+           make_tuple(make_multi_index(MPerXdl, 0, 0),
+                      make_multi_index(4, 0, 0),
+                      make_multi_index(1, 0, 0),
+                      make_multi_index(-MPerBlock, 0, 1));

        constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, 1);

        const index_t num_k_block_per_scale = ScaleBlockK / KPerBlock;
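On the grid side, the A-scale tensor is now described as (M-blocks, 16, K-blocks) with stride 0 on the middle dimension, and each thread's starting scale row comes from `get_thread_local_1d_id() / 128 * 32 + get_thread_local_1d_id() % 64 / 32 * 4`. Below is a small sketch of that offset arithmetic, assuming 64-lane waves and a 256-thread block; those launch parameters are assumptions for illustration, not stated in the commit.

```cpp
#include <cstdio>

// Hypothetical reproduction of the new per-thread A-scale starting row,
//   a_thread_offset = tid / 128 * 32 + tid % 64 / 32 * 4;
// assuming 64-lane waves and a 256-thread block (illustrative only).
int main()
{
    for(int tid : {0, 32, 64, 96, 128, 160, 192, 224})
    {
        const int a_thread_offset = tid / 128 * 32 + tid % 64 / 32 * 4;
        std::printf("tid %3d -> starting scale row %2d\n", tid, a_thread_offset);
    }
    // tid 0/64 start at row 0, tid 32/96 at row 4, tid 128/192 at row 32,
    // tid 160/224 at row 36: each half-wave owns a different 4-row group.
}
```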