Unverified Commit abae2afc authored by rocking's avatar rocking Committed by GitHub
Browse files

Support max3 in smoothquant and add + rmsnorm + rdquant (#1654)

* Fix cmake example build

* Support max3 in smoothquant one pass

* Support max3 in two pass

* Support max3 in add_rmsnorm_rdquant
parent bfe983a1
...@@ -18,7 +18,7 @@ function (add_smoothquant_example TARGET_NAME MAIN_SRC) ...@@ -18,7 +18,7 @@ function (add_smoothquant_example TARGET_NAME MAIN_SRC)
target_compile_options(${TARGET_NAME} PRIVATE ${COMPILE_OPTIONS}) target_compile_options(${TARGET_NAME} PRIVATE ${COMPILE_OPTIONS})
endfunction(add_smoothquant_example TARGET_NAME MAIN_SRC) endfunction(add_smoothquant_example TARGET_NAME MAIN_SRC)
file(GLOB INSTANCE_SRCS instances/*.cpp)
add_smoothquant_example(tile_smoothquant smoothquant.cpp ${INSTANCE_SRCS})
add_smoothquant_example(tile_example_smoothquant example_smoothquant.cpp) add_smoothquant_example(tile_example_smoothquant example_smoothquant.cpp)
file(GLOB INSTANCE_SRCS instances/*.cpp)
add_smoothquant_example(tile_smoothquant smoothquant.cpp ${INSTANCE_SRCS})
...@@ -28,8 +28,9 @@ struct AddRmsnorm2dRdquantFwdPipelineOnePass ...@@ -28,8 +28,9 @@ struct AddRmsnorm2dRdquantFwdPipelineOnePass
static constexpr bool kSaveX = Problem::kSaveX; static constexpr bool kSaveX = Problem::kSaveX;
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync; static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
static constexpr bool kPadM = false; // TODO - BlockAddRmsnorm2dRdquantFwdProblem::kPadM static constexpr bool kPadM = false; // TODO - BlockAddRmsnorm2dRdquantFwdProblem::kPadM
static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadN = Problem::kPadN;
static constexpr bool UseMax3 = true; // TODO - Move to trait
static constexpr const char* name = []() { static constexpr const char* name = []() {
if constexpr(kNeedCrossWarpSync) if constexpr(kNeedCrossWarpSync)
...@@ -69,9 +70,16 @@ struct AddRmsnorm2dRdquantFwdPipelineOnePass ...@@ -69,9 +70,16 @@ struct AddRmsnorm2dRdquantFwdPipelineOnePass
auto reduce_square_sum_func = ReduceOp::SquareAdd{}; auto reduce_square_sum_func = ReduceOp::SquareAdd{};
auto reduce_sum_func = ReduceOp::Add{}; auto reduce_sum_func = ReduceOp::Add{};
auto reduce_absmax_func = ReduceOp::AbsMax{}; auto reduce_absmax_func = ReduceOp::AbsMax{};
auto reduce_max_func = ReduceOp::Max{}; auto reduce_absmax3_func = [](auto acc_, auto v_0_, auto v_1_) {
auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>(); float rtn;
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>(); asm volatile("v_max3_f32 %0, %1, abs(%2), abs(%3)"
: "=v"(rtn)
: "v"(acc_), "v"(v_0_), "v"(v_1_));
return rtn;
};
auto reduce_max_func = ReduceOp::Max{};
auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
auto block_reduce2d_cross_warp_sync = auto block_reduce2d_cross_warp_sync =
Policy::template GetBlockReduce2dCrossWarpSync<Problem>(); Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
...@@ -116,8 +124,23 @@ struct AddRmsnorm2dRdquantFwdPipelineOnePass ...@@ -116,8 +124,23 @@ struct AddRmsnorm2dRdquantFwdPipelineOnePass
}); });
// compute absmax, each-thread->cross-lane->cross-warp // compute absmax, each-thread->cross-lane->cross-warp
auto absmax = block_reduce2d( auto absmax = [&]() {
y, reduce_absmax_func.GetIdentityValue<ComputeDataType>(), reduce_absmax_func); constexpr auto x_size_per_row =
x.get_tile_distribution().get_ys_to_d_descriptor().get_lengths().at(number<1>{});
if constexpr(UseMax3 && std::is_same_v<ComputeDataType, float> &&
x_size_per_row % 2 == 0)
{
return block_reduce2d(y,
reduce_absmax_func.GetIdentityValue<ComputeDataType>(),
reduce_absmax3_func,
sequence<1, 2>{});
}
else
{
return block_reduce2d(
y, reduce_absmax_func.GetIdentityValue<ComputeDataType>(), reduce_absmax_func);
}
}();
block_reduce2d_sync(absmax, reduce_max_func); block_reduce2d_sync(absmax, reduce_max_func);
block_reduce2d_cross_warp_sync(absmax, smem, reduce_max_func); block_reduce2d_cross_warp_sync(absmax, smem, reduce_max_func);
......
...@@ -28,8 +28,9 @@ struct AddRmsnorm2dRdquantFwdPipelineThreePass ...@@ -28,8 +28,9 @@ struct AddRmsnorm2dRdquantFwdPipelineThreePass
static constexpr bool kSaveX = Problem::kSaveX; static constexpr bool kSaveX = Problem::kSaveX;
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync; static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
static constexpr bool kPadM = false; // TODO - BlockAddRmsnorm2dRdquantFwdProblem::kPadM static constexpr bool kPadM = false; // TODO - BlockAddRmsnorm2dRdquantFwdProblem::kPadM
static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadN = Problem::kPadN;
static constexpr bool UseMax3 = true; // TODO - Move to trait
static constexpr const char* name = []() { static constexpr const char* name = []() {
if constexpr(kNeedCrossWarpSync) if constexpr(kNeedCrossWarpSync)
...@@ -76,9 +77,16 @@ struct AddRmsnorm2dRdquantFwdPipelineThreePass ...@@ -76,9 +77,16 @@ struct AddRmsnorm2dRdquantFwdPipelineThreePass
auto reduce_square_sum_func = ReduceOp::SquareAdd{}; auto reduce_square_sum_func = ReduceOp::SquareAdd{};
auto reduce_sum_func = ReduceOp::Add{}; auto reduce_sum_func = ReduceOp::Add{};
auto reduce_absmax_func = ReduceOp::AbsMax{}; auto reduce_absmax_func = ReduceOp::AbsMax{};
auto reduce_max_func = ReduceOp::Max{}; auto reduce_absmax3_func = [](auto acc_, auto v_0_, auto v_1_) {
auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>(); float rtn;
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>(); asm volatile("v_max3_f32 %0, %1, abs(%2), abs(%3)"
: "=v"(rtn)
: "v"(acc_), "v"(v_0_), "v"(v_1_));
return rtn;
};
auto reduce_max_func = ReduceOp::Max{};
auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
auto block_reduce2d_cross_warp_sync = auto block_reduce2d_cross_warp_sync =
Policy::template GetBlockReduce2dCrossWarpSync<Problem>(); Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
...@@ -177,7 +185,13 @@ struct AddRmsnorm2dRdquantFwdPipelineThreePass ...@@ -177,7 +185,13 @@ struct AddRmsnorm2dRdquantFwdPipelineThreePass
y(idx) = type_convert<ComputeDataType>(y_); y(idx) = type_convert<ComputeDataType>(y_);
}); });
block_reduce2d(y, absmax, reduce_absmax_func); constexpr auto x_size_per_row =
x.get_tile_distribution().get_ys_to_d_descriptor().get_lengths().at(number<1>{});
if constexpr(UseMax3 && std::is_same_v<ComputeDataType, float> &&
x_size_per_row % 2 == 0)
block_reduce2d(y, absmax, reduce_absmax3_func, sequence<1, 2>{});
else
block_reduce2d(y, absmax, reduce_absmax_func);
if constexpr(kSaveX) if constexpr(kSaveX)
move_tile_window(x_window, {0, -Block_N}); move_tile_window(x_window, {0, -Block_N});
......
...@@ -25,6 +25,7 @@ struct SmoothquantPipelineOnePass ...@@ -25,6 +25,7 @@ struct SmoothquantPipelineOnePass
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync; static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
static constexpr bool kPadM = false; // TODO - BlockSmoothquantProblem::kPadM static constexpr bool kPadM = false; // TODO - BlockSmoothquantProblem::kPadM
static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadN = Problem::kPadN;
static constexpr bool UseMax3 = true; // TODO - Move to trait
static constexpr const char* name = []() { static constexpr const char* name = []() {
if constexpr(kNeedCrossWarpSync) if constexpr(kNeedCrossWarpSync)
...@@ -52,7 +53,15 @@ struct SmoothquantPipelineOnePass ...@@ -52,7 +53,15 @@ struct SmoothquantPipelineOnePass
xscale_window_, Policy::template MakeXScaleBlockTileDistribution<Problem>()); xscale_window_, Policy::template MakeXScaleBlockTileDistribution<Problem>());
auto reduce_absmax_func = ReduceOp::AbsMax{}; auto reduce_absmax_func = ReduceOp::AbsMax{};
auto reduce_max_func = ReduceOp::Max{}; auto reduce_absmax3_func = [](auto acc_, auto v_0_, auto v_1_) {
float rtn;
asm volatile("v_max3_f32 %0, %1, abs(%2), abs(%3)"
: "=v"(rtn)
: "v"(acc_), "v"(v_0_), "v"(v_1_));
return rtn;
};
auto reduce_max_func = ReduceOp::Max{};
auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>(); auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>(); auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
auto block_reduce2d_cross_warp_sync = auto block_reduce2d_cross_warp_sync =
...@@ -68,8 +77,23 @@ struct SmoothquantPipelineOnePass ...@@ -68,8 +77,23 @@ struct SmoothquantPipelineOnePass
xscale); xscale);
// compute absmax, cross-lane->cross-warp // compute absmax, cross-lane->cross-warp
auto absmax = block_reduce2d( auto absmax = [&]() {
y, reduce_absmax_func.GetIdentityValue<ComputeDataType>(), reduce_absmax_func); constexpr auto x_size_per_row =
x.get_tile_distribution().get_ys_to_d_descriptor().get_lengths().at(number<1>{});
if constexpr(UseMax3 && std::is_same_v<ComputeDataType, float> &&
x_size_per_row % 2 == 0)
{
return block_reduce2d(y,
reduce_absmax_func.GetIdentityValue<ComputeDataType>(),
reduce_absmax3_func,
sequence<1, 2>{});
}
else
{
return block_reduce2d(
y, reduce_absmax_func.GetIdentityValue<ComputeDataType>(), reduce_absmax_func);
}
}();
block_reduce2d_sync(absmax, reduce_max_func); block_reduce2d_sync(absmax, reduce_max_func);
block_reduce2d_cross_warp_sync(absmax, smem, reduce_max_func); block_reduce2d_cross_warp_sync(absmax, smem, reduce_max_func);
......
...@@ -25,6 +25,7 @@ struct SmoothquantPipelineTwoPass ...@@ -25,6 +25,7 @@ struct SmoothquantPipelineTwoPass
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync; static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
static constexpr bool kPadM = false; // TODO - BlockSmoothquantProblem::kPadM static constexpr bool kPadM = false; // TODO - BlockSmoothquantProblem::kPadM
static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadN = Problem::kPadN;
static constexpr bool UseMax3 = true; // TODO - Move to trait
static constexpr const char* name = []() { static constexpr const char* name = []() {
if constexpr(kNeedCrossWarpSync) if constexpr(kNeedCrossWarpSync)
...@@ -56,6 +57,13 @@ struct SmoothquantPipelineTwoPass ...@@ -56,6 +57,13 @@ struct SmoothquantPipelineTwoPass
__builtin_amdgcn_readfirstlane(integer_divide_ceil(row_size, Block_N)); __builtin_amdgcn_readfirstlane(integer_divide_ceil(row_size, Block_N));
auto reduce_absmax_func = ReduceOp::AbsMax{}; auto reduce_absmax_func = ReduceOp::AbsMax{};
// Three-input reduction step for the abs-max computation: folds two new values into the
// running accumulator with one v_max3_f32 instruction (AMD GPU ISA, emitted via inline asm).
//   acc_       : running maximum; passed to the instruction unmodified (no abs()).
//   v_0_, v_1_ : the two new values; each is wrapped in the abs() operand modifier.
//   returns    : the instruction result as a float.
// NOTE(review): the parameters are `auto`, but the result is always `float` and all operands
// use the "v" (vector register) constraint, so this presumably only makes sense for 32-bit
// float inputs. The visible call site selects this functor only when ComputeDataType is float
// and x_size_per_row is even; confirm before reusing it elsewhere.
auto reduce_absmax3_func = [](auto acc_, auto v_0_, auto v_1_) {
float rtn;
// %0 = rtn (output), %1 = acc_, %2 = v_0_, %3 = v_1_.
asm volatile("v_max3_f32 %0, %1, abs(%2), abs(%3)"
: "=v"(rtn)
: "v"(acc_), "v"(v_0_), "v"(v_1_));
return rtn;
};
auto reduce_max_func = ReduceOp::Max{}; auto reduce_max_func = ReduceOp::Max{};
auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>(); auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>(); auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
...@@ -77,7 +85,13 @@ struct SmoothquantPipelineTwoPass ...@@ -77,7 +85,13 @@ struct SmoothquantPipelineTwoPass
x, x,
xscale); xscale);
block_reduce2d(y, absmax, reduce_absmax_func); constexpr auto x_size_per_row =
x.get_tile_distribution().get_ys_to_d_descriptor().get_lengths().at(number<1>{});
if constexpr(UseMax3 && std::is_same_v<ComputeDataType, float> &&
x_size_per_row % 2 == 0)
block_reduce2d(y, absmax, reduce_absmax3_func, sequence<1, 2>{});
else
block_reduce2d(y, absmax, reduce_absmax_func);
move_tile_window(x_window, {0, Block_N}); move_tile_window(x_window, {0, Block_N});
move_tile_window(xscale_window, {Block_N}); move_tile_window(xscale_window, {Block_N});
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment