Commit 87ad5225 authored by aska-0096's avatar aska-0096
Browse files

Bug fix

parent a75152d6
...@@ -62,7 +62,7 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio ...@@ -62,7 +62,7 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
// static constexpr ck::index_t Scale_Block_M = 128; // static constexpr ck::index_t Scale_Block_M = 128;
static constexpr ck::index_t Scale_Block_N = 1; static constexpr ck::index_t Scale_Block_N = 1;
static constexpr ck::index_t Scale_Block_K = 128; static constexpr ck::index_t Scale_Block_K = 64;
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_BScale_Xdl_CShuffle_V3 using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_BScale_Xdl_CShuffle_V3
// clang-format off // clang-format off
...@@ -70,18 +70,18 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_BScale_X ...@@ -70,18 +70,18 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_BScale_X
A0DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, A0DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec, AElementOp, BElementOp, CDEElementOp, GemmSpec,
256, Scale_Block_N, Scale_Block_K, 256, Scale_Block_N, Scale_Block_K,
128, 128, 128, 128, 128, 64,
// 16, 16, // 16, 16,
8, 8, 8, 8,
16, 16, 16, 16,
4, 4, 4, 4,
// S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
// S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, // S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, // S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>,
// ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>; // ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3>; ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1>;
// clang-format on // clang-format on
template <typename IntType> template <typename IntType>
......
...@@ -346,15 +346,20 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra ...@@ -346,15 +346,20 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
// make_tuple(I0, I0), // make_tuple(I0, I0),
// a_scale_thread_buf); // a_scale_thread_buf);
b_scale_thread_copy.Run(b_scale_grid_desc, static_for<0, NRepeat, 1>{}([&](auto n0) {
b_scale_grid_buf, b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_thread_desc, b_scale_grid_buf,
make_tuple(I0, I0), b_scale_thread_desc,
b_scale_thread_buf); make_tuple(n0, I0),
b_scale_thread_buf);
// a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, a_scale_thread_copy_step);
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
// Local prefill 1 b_scale_thread_copy_step.At(Number<0>{}));
});
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
b_scale_thread_copy_step.At(Number<1>{}));
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
...@@ -470,15 +475,23 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra ...@@ -470,15 +475,23 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
// make_tuple(I0, I0), // make_tuple(I0, I0),
// a_scale_thread_buf); // a_scale_thread_buf);
b_scale_thread_copy.Run(b_scale_grid_desc, static_for<0, NRepeat, 1>{}([&](auto n0) {
b_scale_grid_buf, b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_thread_desc, b_scale_grid_buf,
make_tuple(I0, I0), b_scale_thread_desc,
b_scale_thread_buf); make_tuple(n0, I0),
b_scale_thread_buf);
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc, b_scale_thread_copy_step.At(Number<0>{}));
});
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
b_scale_thread_copy_step.At(Number<1>{}));
// a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, // a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
// a_scale_thread_copy_step); // a_scale_thread_copy_step);
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
HotLoopScheduler(); HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
i += 1; i += 1;
...@@ -517,7 +530,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra ...@@ -517,7 +530,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
c_thread_buf(Number<c_offset>{}) += c_thread_buf(Number<c_offset>{}) +=
c_thread_buf_per_scale[Number<t>{}] * c_thread_buf_per_scale[Number<t>{}] *
// type_convert<AccDataType>(a_scale_thread_buf[I0]) * // type_convert<AccDataType>(a_scale_thread_buf[I0]) *
type_convert<AccDataType>(b_scale_thread_buf[I0]); type_convert<AccDataType>(b_scale_thread_buf[n0]);
}); });
}); });
}); });
......
...@@ -1383,21 +1383,28 @@ struct GridwiseGemmMultiD_BScale_xdl_cshuffle_v3 ...@@ -1383,21 +1383,28 @@ struct GridwiseGemmMultiD_BScale_xdl_cshuffle_v3
// a_scale_grid_desc_am_ak, make_multi_index(block_m_id * MPerBlock / ScaleBlockM, // a_scale_grid_desc_am_ak, make_multi_index(block_m_id * MPerBlock / ScaleBlockM,
// 0)); // 0));
constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
auto b_thread_offset =
get_thread_local_1d_id() % NPerXdl + (get_thread_local_1d_id() / 64) % NWaves * NPerXdl;
auto b_scale_thread_copy = auto b_scale_thread_copy =
ThreadwiseTensorSliceTransfer_v2<BScaleType, ThreadwiseTensorSliceTransfer_v2<BScaleType,
BScaleType, BScaleType,
decltype(b_scale_grid_desc_bn_ak), decltype(b_scale_grid_desc_bn_ak),
decltype(b_scale_thread_desc), decltype(b_scale_thread_desc),
Sequence<ScaleSliceSizeN, ScaleSliceSizeK>, Sequence<1, ScaleSliceSizeK>,
Sequence<0, 1>, Sequence<0, 1>,
1, 1,
1, 1,
1, 1,
false>( false>(
b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0)); b_scale_grid_desc_bn_ak,
make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset, 0));
// constexpr auto a_scale_thread_slice_copy_step = make_multi_index(0, 1); // constexpr auto a_scale_thread_slice_copy_step = make_multi_index(0, 1);
constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, 1); constexpr auto b_scale_thread_slice_copy_step =
make_tuple(make_multi_index(NWaves * NPerXdl, 0), make_multi_index(-NPerBlock, 1));
const index_t num_k_block_per_scale = ScaleBlockK / KPerBlock; const index_t num_k_block_per_scale = ScaleBlockK / KPerBlock;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment