Commit 9dac9713 authored by mtgu0705's avatar mtgu0705
Browse files

Enable blockwise pipeline v1 and v2. v1 works for small K.

parent d58d55e5
...@@ -72,7 +72,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_ABScale_ ...@@ -72,7 +72,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_ABScale_
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
1, 2, S<1, 32, 1, 8>, S<8>, 1, 2, S<1, 32, 1, 8>, S<8>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>; ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, FP8>;
// clang-format on // clang-format on
int main(int argc, char* argv[]) int main(int argc, char* argv[])
......
...@@ -96,7 +96,8 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr ...@@ -96,7 +96,8 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
NPerXDL, NPerXDL,
MRepeat, MRepeat,
NRepeat, NRepeat,
KPack> KPack,
true>
{ {
using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize, using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
...@@ -117,7 +118,8 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr ...@@ -117,7 +118,8 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
NPerXDL, NPerXDL,
MRepeat, MRepeat,
NRepeat, NRepeat,
KPack>; KPack,
true>;
using Base::I0; using Base::I0;
using Base::KRepeat; using Base::KRepeat;
using Base::xdlops_gemm; using Base::xdlops_gemm;
...@@ -231,11 +233,26 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr ...@@ -231,11 +233,26 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_scale_thread_copy.Run(a_scale_grid_desc, static_for<0, MRepeat, 1>{}([&](auto m0) {
a_scale_grid_buf, a_scale_thread_copy.Run(a_scale_grid_desc,
a_scale_thread_desc, a_scale_grid_buf,
make_tuple(I0, I0), a_scale_thread_desc,
a_scale_thread_buf); make_tuple(m0, I0),
a_scale_thread_buf);
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
a_scale_thread_copy_step.At(Number<0>{}));
});
if(num_loop_per_scale == 1)
{
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
a_scale_thread_copy_step.At(Number<2>{}));
}
else
{
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
a_scale_thread_copy_step.At(Number<1>{}));
}
b_scale_thread_copy.Run(b_scale_grid_desc, b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf, b_scale_grid_buf,
...@@ -243,7 +260,6 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr ...@@ -243,7 +260,6 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
make_tuple(I0, I0), make_tuple(I0, I0),
b_scale_thread_buf); b_scale_thread_buf);
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, a_scale_thread_copy_step);
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
// Local prefill 1 // Local prefill 1
...@@ -318,17 +334,32 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr ...@@ -318,17 +334,32 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
c_thread_buf(Number<c_offset>{}) += c_thread_buf(Number<c_offset>{}) +=
c_thread_buf_per_scale[Number<t>{}] * c_thread_buf_per_scale[Number<t>{}] *
type_convert<AccDataType>(a_scale_thread_buf[I0]) * type_convert<AccDataType>(a_scale_thread_buf[m0]) *
type_convert<AccDataType>(b_scale_thread_buf[I0]); type_convert<AccDataType>(b_scale_thread_buf[I0]);
}); });
}); });
}); });
a_scale_thread_copy.Run(a_scale_grid_desc, static_for<0, MRepeat, 1>{}([&](auto m0) {
a_scale_grid_buf, a_scale_thread_copy.Run(a_scale_grid_desc,
a_scale_thread_desc, a_scale_grid_buf,
make_tuple(I0, I0), a_scale_thread_desc,
a_scale_thread_buf); make_tuple(m0, I0),
a_scale_thread_buf);
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
a_scale_thread_copy_step.At(Number<0>{}));
});
if(num_loop_per_scale == 1)
{
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
a_scale_thread_copy_step.At(Number<2>{}));
}
else
{
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
a_scale_thread_copy_step.At(Number<1>{}));
}
b_scale_thread_copy.Run(b_scale_grid_desc, b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf, b_scale_grid_buf,
...@@ -336,7 +367,6 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr ...@@ -336,7 +367,6 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
make_tuple(I0, I0), make_tuple(I0, I0),
b_scale_thread_buf); b_scale_thread_buf);
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, a_scale_thread_copy_step);
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
block_sync_lds(); block_sync_lds();
...@@ -400,7 +430,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr ...@@ -400,7 +430,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
c_thread_buf(Number<c_offset>{}) += c_thread_buf(Number<c_offset>{}) +=
c_thread_buf_per_scale[Number<t>{}] * c_thread_buf_per_scale[Number<t>{}] *
type_convert<AccDataType>(a_scale_thread_buf[I0]) * type_convert<AccDataType>(a_scale_thread_buf[m0]) *
type_convert<AccDataType>(b_scale_thread_buf[I0]); type_convert<AccDataType>(b_scale_thread_buf[I0]);
}); });
}); });
......
...@@ -96,7 +96,8 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr ...@@ -96,7 +96,8 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
NPerXDL, NPerXDL,
MRepeat, MRepeat,
NRepeat, NRepeat,
KPack> KPack,
true>
{ {
using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize, using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
...@@ -117,7 +118,8 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr ...@@ -117,7 +118,8 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
NPerXDL, NPerXDL,
MRepeat, MRepeat,
NRepeat, NRepeat,
KPack>; KPack,
true>;
using Base::I0; using Base::I0;
using Base::KRepeat; using Base::KRepeat;
using Base::xdlops_gemm; using Base::xdlops_gemm;
...@@ -270,11 +272,26 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr ...@@ -270,11 +272,26 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_scale_thread_copy.Run(a_scale_grid_desc, static_for<0, MRepeat, 1>{}([&](auto m0) {
a_scale_grid_buf, a_scale_thread_copy.Run(a_scale_grid_desc,
a_scale_thread_desc, a_scale_grid_buf,
make_tuple(I0, I0), a_scale_thread_desc,
a_scale_thread_buf); make_tuple(m0, I0),
a_scale_thread_buf);
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
a_scale_thread_copy_step.At(Number<0>{}));
});
if(num_loop_per_scale == 1)
{
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
a_scale_thread_copy_step.At(Number<2>{}));
}
else
{
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
a_scale_thread_copy_step.At(Number<1>{}));
}
b_scale_thread_copy.Run(b_scale_grid_desc, b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf, b_scale_grid_buf,
...@@ -282,7 +299,6 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr ...@@ -282,7 +299,6 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
make_tuple(I0, I0), make_tuple(I0, I0),
b_scale_thread_buf); b_scale_thread_buf);
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, a_scale_thread_copy_step);
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
// Local prefill 1 // Local prefill 1
...@@ -360,17 +376,32 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr ...@@ -360,17 +376,32 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
c_thread_buf(Number<c_offset>{}) += c_thread_buf(Number<c_offset>{}) +=
c_thread_buf_per_scale[Number<t>{}] * c_thread_buf_per_scale[Number<t>{}] *
type_convert<AccDataType>(a_scale_thread_buf[I0]) * type_convert<AccDataType>(a_scale_thread_buf[m0]) *
type_convert<AccDataType>(b_scale_thread_buf[I0]); type_convert<AccDataType>(b_scale_thread_buf[I0]);
}); });
}); });
}); });
a_scale_thread_copy.Run(a_scale_grid_desc, static_for<0, MRepeat, 1>{}([&](auto m0) {
a_scale_grid_buf, a_scale_thread_copy.Run(a_scale_grid_desc,
a_scale_thread_desc, a_scale_grid_buf,
make_tuple(I0, I0), a_scale_thread_desc,
a_scale_thread_buf); make_tuple(m0, I0),
a_scale_thread_buf);
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
a_scale_thread_copy_step.At(Number<0>{}));
});
if(num_loop_per_scale == 1)
{
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
a_scale_thread_copy_step.At(Number<2>{}));
}
else
{
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
a_scale_thread_copy_step.At(Number<1>{}));
}
b_scale_thread_copy.Run(b_scale_grid_desc, b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf, b_scale_grid_buf,
...@@ -378,8 +409,6 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr ...@@ -378,8 +409,6 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
make_tuple(I0, I0), make_tuple(I0, I0),
b_scale_thread_buf); b_scale_thread_buf);
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
a_scale_thread_copy_step);
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
b_scale_thread_copy_step); b_scale_thread_copy_step);
...@@ -453,17 +482,32 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr ...@@ -453,17 +482,32 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
c_thread_buf(Number<c_offset>{}) += c_thread_buf(Number<c_offset>{}) +=
c_thread_buf_per_scale[Number<t>{}] * c_thread_buf_per_scale[Number<t>{}] *
type_convert<AccDataType>(a_scale_thread_buf[I0]) * type_convert<AccDataType>(a_scale_thread_buf[m0]) *
type_convert<AccDataType>(b_scale_thread_buf[I0]); type_convert<AccDataType>(b_scale_thread_buf[I0]);
}); });
}); });
}); });
a_scale_thread_copy.Run(a_scale_grid_desc, static_for<0, MRepeat, 1>{}([&](auto m0) {
a_scale_grid_buf, a_scale_thread_copy.Run(a_scale_grid_desc,
a_scale_thread_desc, a_scale_grid_buf,
make_tuple(I0, I0), a_scale_thread_desc,
a_scale_thread_buf); make_tuple(m0, I0),
a_scale_thread_buf);
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
a_scale_thread_copy_step.At(Number<0>{}));
});
if(num_loop_per_scale == 1)
{
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
a_scale_thread_copy_step.At(Number<2>{}));
}
else
{
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
a_scale_thread_copy_step.At(Number<1>{}));
}
b_scale_thread_copy.Run(b_scale_grid_desc, b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf, b_scale_grid_buf,
...@@ -471,7 +515,6 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr ...@@ -471,7 +515,6 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
make_tuple(I0, I0), make_tuple(I0, I0),
b_scale_thread_buf); b_scale_thread_buf);
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, a_scale_thread_copy_step);
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
block_sync_lds(); block_sync_lds();
...@@ -528,7 +571,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr ...@@ -528,7 +571,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
c_thread_buf(Number<c_offset>{}) += c_thread_buf(Number<c_offset>{}) +=
c_thread_buf_per_scale[Number<t>{}] * c_thread_buf_per_scale[Number<t>{}] *
type_convert<AccDataType>(a_scale_thread_buf[I0]) * type_convert<AccDataType>(a_scale_thread_buf[m0]) *
type_convert<AccDataType>(b_scale_thread_buf[I0]); type_convert<AccDataType>(b_scale_thread_buf[I0]);
}); });
}); });
...@@ -586,7 +629,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr ...@@ -586,7 +629,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
c_thread_buf(Number<c_offset>{}) += c_thread_buf(Number<c_offset>{}) +=
c_thread_buf_per_scale[Number<t>{}] * c_thread_buf_per_scale[Number<t>{}] *
type_convert<AccDataType>(a_scale_thread_buf[I0]) * type_convert<AccDataType>(a_scale_thread_buf[m0]) *
type_convert<AccDataType>(b_scale_thread_buf[I0]); type_convert<AccDataType>(b_scale_thread_buf[I0]);
}); });
}); });
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment