Commit 1fcd3329 authored by mtgu0705

Enable multiply_multiply for Scale_Block_M = 1 for deepseek

parent e5bc56a4
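
For context, a minimal host-side sketch of the product this kernel computes once Scale_Block_M = 1: one A scale per output row and per 128-wide K block, while B keeps 128x128 scale tiles. The function name, row-major layouts and plain-float inputs are illustrative assumptions for this sketch only; the commit itself only touches the CK example, pipeline, device op and gridwise kernel shown below.

// Reference sketch (assumed names and layouts), not part of the commit.
#include <cstdint>
#include <vector>

void reference_ab_scale_gemm(const std::vector<float>& a,       // M x K, row-major
                             const std::vector<float>& b,       // K x N, row-major
                             const std::vector<float>& a_scale, // M x (K / Scale_Block_K)
                             const std::vector<float>& b_scale, // (K / Scale_Block_K) x (N / Scale_Block_N)
                             std::vector<float>& c,             // M x N, row-major
                             std::int64_t M, std::int64_t N, std::int64_t K,
                             std::int64_t Scale_Block_N = 128, std::int64_t Scale_Block_K = 128)
{
    for(std::int64_t m = 0; m < M; ++m)
        for(std::int64_t n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(std::int64_t k = 0; k < K; ++k)
            {
                // Scale_Block_M = 1: the A scale is indexed directly by the output row m.
                const float sa = a_scale[m * (K / Scale_Block_K) + k / Scale_Block_K];
                const float sb = b_scale[(k / Scale_Block_K) * (N / Scale_Block_N) + n / Scale_Block_N];
                acc += a[m * K + k] * b[k * N + n] * sa * sb;
            }
            c[m * N + n] = acc;
        }
}
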
@@ -26,6 +26,7 @@ using S = ck::Sequence<Is...>;
 using BF16 = ck::bhalf_t;
 using FP8 = ck::f8_t;
+using F16 = ck::half_t;
 using F32 = float;
 using Row = ck::tensor_layout::gemm::RowMajor;
@@ -55,7 +56,7 @@ using CDEElementOp = PassThrough;
 static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
-static constexpr ck::index_t Scale_Block_M = 128;
+static constexpr ck::index_t Scale_Block_M = 1;
 static constexpr ck::index_t Scale_Block_N = 128;
 static constexpr ck::index_t Scale_Block_K = 128;
@@ -67,8 +68,8 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_ABScale_
     256, Scale_Block_M, Scale_Block_N, Scale_Block_K,
     128, 128,
     128, 16, 16,
-    16, 16,
-    4, 4,
+    32, 32,
+    2, 2,
     S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
     S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
     1, 2, S<1, 32, 1, 8>, S<8, 8, 1>,
@@ -187,6 +188,18 @@ int main(int argc, char* argv[])
         a1_m_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0, 1.0});
         b1_k_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
         break;
+    case 5:
+        a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
+        b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
+        a1_m_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
+        b1_k_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
+        break;
+    case 6:
+        a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
+        b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
+        a1_m_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0, 1.0});
+        b1_k_n.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
+        break;
     default:
         a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-0.5, 0.5});
         b0_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
......
@@ -338,11 +338,18 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
         a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
         b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
-        a_scale_thread_copy.Run(a_scale_grid_desc,
-                                a_scale_grid_buf,
-                                a_scale_thread_desc,
-                                make_tuple(I0, I0),
-                                a_scale_thread_buf);
+        static_for<0, MRepeat, 1>{}([&](auto m0) {
+            a_scale_thread_copy.Run(a_scale_grid_desc,
+                                    a_scale_grid_buf,
+                                    a_scale_thread_desc,
+                                    make_tuple(m0, I0),
+                                    a_scale_thread_buf);
+            a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
+                                                   a_scale_thread_copy_step.At(Number<0>{}));
+        });
+        a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
+                                               a_scale_thread_copy_step.At(Number<1>{}));
         b_scale_thread_copy.Run(b_scale_grid_desc,
                                 b_scale_grid_buf,
@@ -350,7 +357,6 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
                                 make_tuple(I0, I0),
                                 b_scale_thread_buf);
-        a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, a_scale_thread_copy_step);
         b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
         // Local prefill 1
         a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
@@ -437,7 +443,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
                             c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
                         c_thread_buf(Number<c_offset>{}) +=
                             c_thread_buf_per_scale[Number<t>{}] *
-                            type_convert<AccDataType>(a_scale_thread_buf[I0]) *
+                            type_convert<AccDataType>(a_scale_thread_buf[m0]) *
                             type_convert<AccDataType>(b_scale_thread_buf[I0]);
                     });
                 });
@@ -462,11 +468,18 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
                                 b_thread_buf);
             });
         });
-        a_scale_thread_copy.Run(a_scale_grid_desc,
-                                a_scale_grid_buf,
-                                a_scale_thread_desc,
-                                make_tuple(I0, I0),
-                                a_scale_thread_buf);
+        static_for<0, MRepeat, 1>{}([&](auto m0) {
+            a_scale_thread_copy.Run(a_scale_grid_desc,
+                                    a_scale_grid_buf,
+                                    a_scale_thread_desc,
+                                    make_tuple(m0, I0),
+                                    a_scale_thread_buf);
+            a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
+                                                   a_scale_thread_copy_step.At(Number<0>{}));
+        });
+        a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
+                                               a_scale_thread_copy_step.At(Number<1>{}));
         b_scale_thread_copy.Run(b_scale_grid_desc,
                                 b_scale_grid_buf,
@@ -474,7 +487,6 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
                                 make_tuple(I0, I0),
                                 b_scale_thread_buf);
-        a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, a_scale_thread_copy_step);
         b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
         HotLoopScheduler();
         __builtin_amdgcn_sched_barrier(0);
@@ -514,7 +526,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
                             c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
                         c_thread_buf(Number<c_offset>{}) +=
                             c_thread_buf_per_scale[Number<t>{}] *
-                            type_convert<AccDataType>(a_scale_thread_buf[I0]) *
+                            type_convert<AccDataType>(a_scale_thread_buf[m0]) *
                             type_convert<AccDataType>(b_scale_thread_buf[I0]);
                     });
                 });
......
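
The pipeline change above replaces the single per-block A-scale read with a static_for over MRepeat: each m0 iteration reads one scale row and advances the source window by the first component of the copy step; the second component then rewinds the M offset and moves to the next K scale block. A standalone sketch of that index walk, assuming the example instance's values from this commit (128-row block, 32x32 XDL tiles, MXdlPerWave = 2, ScaleBlockK = KPerBlock); everything else is illustrative.

// Sketch of the a_scale window walk per thread (assumed tile shape, see above).
#include <cstdio>

int main()
{
    const int MPerBlock   = 128; // from the example instance
    const int MPerXdl     = 32;
    const int MXdlPerWave = 2;   // MRepeat in the pipeline
    const int MWaves      = MPerBlock / (MXdlPerWave * MPerXdl); // = 2

    int row  = 0; // M offset of the thread's scale window, relative to its start
    int kblk = 0; // K scale-block index
    for(int loop = 0; loop < 2; ++loop) // two main-loop iterations, for illustration
    {
        for(int m0 = 0; m0 < MXdlPerWave; ++m0)
        {
            std::printf("m0 = %d: read a_scale at row offset %3d, k-block %d\n", m0, row, kblk);
            row += MWaves * MPerXdl; // copy-step component 0, applied inside static_for
        }
        row -= MPerBlock; // copy-step component 1: rewind the M offset...
        kblk += 1;        // ...and advance one K scale block
    }
    // Net effect: every main-loop iteration revisits the same MXdlPerWave rows,
    // spaced MWaves * MPerXdl apart, at the next 128-wide K scale block.
    return 0;
}
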
@@ -363,10 +363,10 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
             return false;
         }
-        if(ScaleBlockM % MPerBlock != 0 || ScaleBlockN % NPerBlock != 0 || ScaleBlockK != KPerBlock)
-        {
-            return false;
-        }
+        // if(ScaleBlockM % MPerBlock != 0 || ScaleBlockN % NPerBlock != 0 || ScaleBlockK != KPerBlock)
+        // {
+        //     return false;
+        // }
         if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
                                                        GemmSpec == GemmSpecialization::NKPadding ||
......
@@ -1357,7 +1357,7 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
             (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
                 KPerBlock);
-        const index_t ScaleSliceSizeM = 1;
+        const index_t ScaleSliceSizeM = MXdlPerWave;
         const index_t ScaleSliceSizeN = 1;
         const index_t ScaleSliceSizeK = 1;
@@ -1365,20 +1365,24 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
             make_tuple(Number<ScaleSliceSizeM>{}, Number<ScaleSliceSizeK>{}));
         constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
-            make_tuple(Number<ScaleSliceSizeM>{}, Number<ScaleSliceSizeK>{}));
+            make_tuple(Number<ScaleSliceSizeN>{}, Number<ScaleSliceSizeK>{}));
+        constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
+        auto a_thread_offset =
+            get_thread_local_1d_id() % MPerXdl + (get_thread_local_1d_id() / 64) % MWaves * MPerXdl;
         auto a_scale_thread_copy =
             ThreadwiseTensorSliceTransfer_v2<AScaleType,
                                              AScaleType,
                                              decltype(a_scale_grid_desc_am_ak),
                                              decltype(a_scale_thread_desc),
-                                             Sequence<ScaleSliceSizeM, ScaleSliceSizeK>,
+                                             Sequence<1, ScaleSliceSizeK>,
                                              Sequence<0, 1>,
                                              1,
                                              1,
                                              1,
                                              false>(
-                a_scale_grid_desc_am_ak, make_multi_index(block_m_id * MPerBlock / ScaleBlockM, 0));
+                a_scale_grid_desc_am_ak, make_multi_index(block_m_id * MPerBlock / ScaleBlockM + a_thread_offset, 0));
         auto b_scale_thread_copy =
             ThreadwiseTensorSliceTransfer_v2<BScaleType,
@@ -1393,7 +1397,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
                                              false>(
                 b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));
-        constexpr auto a_scale_thread_slice_copy_step = make_multi_index(0, 1);
+        constexpr auto a_scale_thread_slice_copy_step =
+            make_tuple(make_multi_index(MWaves * MPerXdl, 0), make_multi_index(-MPerBlock, 1));
         constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, 1);
         const index_t num_k_block_per_scale = ScaleBlockK / KPerBlock;
......
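
In the gridwise kernel, each thread now starts its A-scale window at block_m_id * MPerBlock / ScaleBlockM plus a per-thread row offset. A standalone sketch of what that offset expression evaluates to, assuming a 64-lane wave, MPerXdl = 32 and MWaves = 2 (matching the example instance above); how lanes map onto accumulator rows is the kernel's concern and not modeled here.

// Sketch: per-lane starting row offset produced by the new a_thread_offset
// expression (assumptions: 64-lane wave, MPerXdl = 32, MWaves = 2).
#include <cstdio>

int main()
{
    const int wave_size = 64;
    const int MPerXdl   = 32;
    const int MWaves    = 2;

    for(int tid = 0; tid < 256; tid += 32) // sample a few lanes of the 256-thread block
    {
        const int a_thread_offset =
            tid % MPerXdl + (tid / wave_size) % MWaves * MPerXdl;
        std::printf("thread %3d -> row offset %2d\n", tid, a_thread_offset);
    }
    // Lanes of waves 0 and 2 land on row offsets 0..31, lanes of waves 1 and 3 on
    // 32..63; the per-m0 copy step then strides by MWaves * MPerXdl = 64 so the
    // remaining rows of the 128-row block are covered across the MRepeat iterations.
    return 0;
}
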