"test/contraction/test_contraction_interface.cpp" did not exist on "3eee1b9b8fa13d044509089c7fc8186f4439d412"
Commit 26f221eb authored by rocking's avatar rocking
Browse files

Support Pure quant kernel

parent bb652696
...@@ -89,6 +89,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -89,6 +89,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
xscale_buf.ToDevice(xscale_host.data()); xscale_buf.ToDevice(xscale_host.data());
constexpr bool kTwoPass = true; constexpr bool kTwoPass = true;
constexpr bool kSmoothX = true;
using BlockWarps = ck_tile::sequence<2, 2>; using BlockWarps = ck_tile::sequence<2, 2>;
using BlockTile = ck_tile::sequence<2, 128>; using BlockTile = ck_tile::sequence<2, 128>;
...@@ -103,7 +104,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -103,7 +104,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
QYDataType, QYDataType,
Shape, Shape,
true, true,
kTwoPass>; kTwoPass,
kSmoothX>;
using OnePassPipeline = ck_tile::SmoothquantPipelineOnePass<Problem>; using OnePassPipeline = ck_tile::SmoothquantPipelineOnePass<Problem>;
using TwoPassPipeline = ck_tile::SmoothquantPipelineTwoPass<Problem>; using TwoPassPipeline = ck_tile::SmoothquantPipelineTwoPass<Problem>;
...@@ -141,8 +143,11 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -141,8 +143,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
for(int m_ = 0; m_ < m; ++m_) for(int m_ = 0; m_ < m; ++m_)
{ {
auto v_x = ck_tile::type_convert<ComputeDataType>(x_host(m_, n_)); auto v_x = ck_tile::type_convert<ComputeDataType>(x_host(m_, n_));
y_host(m_, n_) = v_x * v_xscale; if constexpr(kSmoothX)
y_host(m_, n_) = v_x * v_xscale;
else
y_host(m_, n_) = v_x;
} }
}; };
......
...@@ -41,7 +41,8 @@ float smoothquant_(const S& s, A a) ...@@ -41,7 +41,8 @@ float smoothquant_(const S& s, A a)
typename SmoothquantTypeConfig<DataType>::QYDataType, typename SmoothquantTypeConfig<DataType>::QYDataType,
typename Traits_::Shape, typename Traits_::Shape,
Traits_::kPadN, Traits_::kPadN,
Traits_::kTwoPass>; Traits_::kTwoPass,
true>;
using OnePassPipeline = ck_tile::SmoothquantPipelineOnePass<PipelineProblem>; using OnePassPipeline = ck_tile::SmoothquantPipelineOnePass<PipelineProblem>;
using TwoPassPipeline = ck_tile::SmoothquantPipelineTwoPass<PipelineProblem>; using TwoPassPipeline = ck_tile::SmoothquantPipelineTwoPass<PipelineProblem>;
......
...@@ -41,7 +41,8 @@ float moe_smoothquant_(const S& s, A a) ...@@ -41,7 +41,8 @@ float moe_smoothquant_(const S& s, A a)
typename MoeSmoothquantTypeConfig<DataType>::QYDataType, typename MoeSmoothquantTypeConfig<DataType>::QYDataType,
typename Traits_::Shape, typename Traits_::Shape,
Traits_::kPadN, Traits_::kPadN,
Traits_::kTwoPass>; Traits_::kTwoPass,
true>;
using OnePassPipeline = ck_tile::SmoothquantPipelineOnePass<PipelineProblem>; using OnePassPipeline = ck_tile::SmoothquantPipelineOnePass<PipelineProblem>;
using TwoPassPipeline = ck_tile::SmoothquantPipelineTwoPass<PipelineProblem>; using TwoPassPipeline = ck_tile::SmoothquantPipelineTwoPass<PipelineProblem>;
......
...@@ -40,6 +40,7 @@ struct Smoothquant ...@@ -40,6 +40,7 @@ struct Smoothquant
static constexpr bool kPadM = false; // always no need to pad along M static constexpr bool kPadM = false; // always no need to pad along M
static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kTwoPass = Problem::kTwoPass; static constexpr bool kTwoPass = Problem::kTwoPass;
static constexpr bool kSmoothX = Problem::kSmoothX;
static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N; static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
static constexpr index_t Vector_N = Problem::BlockShape::Vector_N; static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
...@@ -95,6 +96,7 @@ struct Smoothquant ...@@ -95,6 +96,7 @@ struct Smoothquant
std::string n; std::string n;
if (kPadN) n += "_pn"; if (kPadN) n += "_pn";
if (kTwoPass) n += "_2p"; if (kTwoPass) n += "_2p";
if (kSmoothX) n += "_sx";
return n; }(); return n; }();
#define _SS_ std::string #define _SS_ std::string
...@@ -127,17 +129,22 @@ struct Smoothquant ...@@ -127,17 +129,22 @@ struct Smoothquant
}(); }();
const auto xscale_window = [&]() { const auto xscale_window = [&]() {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>( if constexpr(kSmoothX)
static_cast<const XScaleDataType*>(kargs.p_xscale), {
make_tuple(kargs.n), const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
make_tuple(1), static_cast<const XScaleDataType*>(kargs.p_xscale),
number<Vector_N>{}, make_tuple(kargs.n),
number<1>{}); make_tuple(1),
number<Vector_N>{},
const auto tmp2_ = number<1>{});
pad_tensor_view(tmp_, make_tuple(number<Block_N>{}), sequence<kPadN>{});
const auto tmp2_ =
return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {0}); pad_tensor_view(tmp_, make_tuple(number<Block_N>{}), sequence<kPadN>{});
return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {0});
}
else
return make_null_tile_window(make_tuple(number<Block_N>{}));
}(); }();
auto yscale_window = [&]() { auto yscale_window = [&]() {
......
...@@ -23,9 +23,10 @@ struct SmoothquantPipelineOnePass ...@@ -23,9 +23,10 @@ struct SmoothquantPipelineOnePass
using YScaleDataType = ck_tile::remove_cvref_t<typename Problem::YScaleDataType>; using YScaleDataType = ck_tile::remove_cvref_t<typename Problem::YScaleDataType>;
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync; static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
static constexpr bool kPadM = false; // TODO - BlockSmoothquantProblem::kPadM static constexpr bool kPadM = false; // No need to pad M
static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadN = Problem::kPadN;
static constexpr bool UseMax3 = true; // TODO - Move to trait static constexpr bool kSmoothX = Problem::kSmoothX;
static constexpr bool UseMax3 = true; // TODO - Move to Problem
static constexpr const char* name = []() { static constexpr const char* name = []() {
if constexpr(kNeedCrossWarpSync) if constexpr(kNeedCrossWarpSync)
...@@ -67,14 +68,22 @@ struct SmoothquantPipelineOnePass ...@@ -67,14 +68,22 @@ struct SmoothquantPipelineOnePass
auto block_reduce2d_cross_warp_sync = auto block_reduce2d_cross_warp_sync =
Policy::template GetBlockReduce2dCrossWarpSync<Problem>(); Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
const auto x = load_tile(x_window); const auto x = load_tile(x_window);
const auto xscale = load_tile(xscale_window);
auto y = tile_elementwise_in( auto y = [&]() {
[&](const auto& a, const auto& b) { if constexpr(kSmoothX)
return type_convert<ComputeDataType>(a) * type_convert<ComputeDataType>(b); {
}, const auto xscale = load_tile(xscale_window);
x, return tile_elementwise_in(
xscale); [&](const auto& a, const auto& b) {
return type_convert<ComputeDataType>(a) * type_convert<ComputeDataType>(b);
},
x,
xscale);
}
else
return cast_tile<ComputeDataType>(x);
}();
// compute absmax, cross-lane->cross-warp // compute absmax, cross-lane->cross-warp
auto absmax = [&]() { auto absmax = [&]() {
......
...@@ -15,7 +15,8 @@ template <typename XDataType_, ...@@ -15,7 +15,8 @@ template <typename XDataType_,
typename QYDataType_, typename QYDataType_,
typename BlockShape_, typename BlockShape_,
bool kPadN_, bool kPadN_,
bool kTwoPass_> bool kTwoPass_,
bool kSmoothX_>
struct SmoothquantPipelineProblem struct SmoothquantPipelineProblem
{ {
using XDataType = remove_cvref_t<XDataType_>; using XDataType = remove_cvref_t<XDataType_>;
...@@ -30,6 +31,7 @@ struct SmoothquantPipelineProblem ...@@ -30,6 +31,7 @@ struct SmoothquantPipelineProblem
static constexpr bool kPadN = kPadN_; static constexpr bool kPadN = kPadN_;
static constexpr bool kTwoPass = kTwoPass_; static constexpr bool kTwoPass = kTwoPass_;
static constexpr bool kSmoothX = kSmoothX_;
}; };
} // namespace ck_tile } // namespace ck_tile
...@@ -23,8 +23,9 @@ struct SmoothquantPipelineTwoPass ...@@ -23,8 +23,9 @@ struct SmoothquantPipelineTwoPass
using YScaleDataType = ck_tile::remove_cvref_t<typename Problem::YScaleDataType>; using YScaleDataType = ck_tile::remove_cvref_t<typename Problem::YScaleDataType>;
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync; static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
static constexpr bool kPadM = false; // TODO - BlockSmoothquantProblem::kPadM static constexpr bool kPadM = false; // No need to pad M
static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kSmoothX = Problem::kSmoothX;
static constexpr bool UseMax3 = true; // TODO - Move to trait static constexpr bool UseMax3 = true; // TODO - Move to trait
static constexpr const char* name = []() { static constexpr const char* name = []() {
...@@ -76,14 +77,23 @@ struct SmoothquantPipelineTwoPass ...@@ -76,14 +77,23 @@ struct SmoothquantPipelineTwoPass
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{ {
const auto x = load_tile(x_window); const auto x = load_tile(x_window);
const auto xscale = load_tile(xscale_window);
const auto y = tile_elementwise_in( auto y = [&]() {
[&](const auto& a, const auto& b) { if constexpr(kSmoothX)
return type_convert<ComputeDataType>(a) * type_convert<ComputeDataType>(b); {
}, const auto xscale = load_tile(xscale_window);
x, return tile_elementwise_in(
xscale); [&](const auto& a, const auto& b) {
return type_convert<ComputeDataType>(a) *
type_convert<ComputeDataType>(b);
},
x,
xscale);
}
else
return cast_tile<ComputeDataType>(x);
}();
constexpr auto x_size_per_row = constexpr auto x_size_per_row =
x.get_tile_distribution().get_ys_to_d_descriptor().get_lengths().at(number<1>{}); x.get_tile_distribution().get_ys_to_d_descriptor().get_lengths().at(number<1>{});
...@@ -93,8 +103,10 @@ struct SmoothquantPipelineTwoPass ...@@ -93,8 +103,10 @@ struct SmoothquantPipelineTwoPass
else else
block_reduce2d(y, absmax, reduce_absmax_func); block_reduce2d(y, absmax, reduce_absmax_func);
if constexpr(kSmoothX)
move_tile_window(xscale_window, {Block_N});
move_tile_window(x_window, {0, Block_N}); move_tile_window(x_window, {0, Block_N});
move_tile_window(xscale_window, {Block_N});
} }
// compute absmax, cross-lane->cross-warp // compute absmax, cross-lane->cross-warp
...@@ -113,21 +125,32 @@ struct SmoothquantPipelineTwoPass ...@@ -113,21 +125,32 @@ struct SmoothquantPipelineTwoPass
ck_tile::index_t stride_to_right_most_window = ck_tile::index_t stride_to_right_most_window =
row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N; row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N;
if constexpr(kSmoothX)
move_tile_window(xscale_window, {-Block_N});
move_tile_window(x_window, {0, -Block_N}); move_tile_window(x_window, {0, -Block_N});
move_tile_window(xscale_window, {-Block_N});
move_tile_window(qy_window, {0, stride_to_right_most_window}); move_tile_window(qy_window, {0, stride_to_right_most_window});
// recompute y and quantize y to qy // recompute y and quantize y to qy
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{ {
const auto x = load_tile(x_window); const auto x = load_tile(x_window);
const auto xscale = load_tile(xscale_window);
const auto y = tile_elementwise_in( auto y = [&]() {
[&](const auto& a, const auto& b) { if constexpr(kSmoothX)
return type_convert<ComputeDataType>(a) * type_convert<ComputeDataType>(b); {
}, const auto xscale = load_tile(xscale_window);
x, return tile_elementwise_in(
xscale); [&](const auto& a, const auto& b) {
return type_convert<ComputeDataType>(a) *
type_convert<ComputeDataType>(b);
},
x,
xscale);
}
else
return cast_tile<ComputeDataType>(x);
}();
auto qy = make_static_distributed_tensor<QYDataType>(y.get_tile_distribution()); auto qy = make_static_distributed_tensor<QYDataType>(y.get_tile_distribution());
sweep_tile(qy, [&](auto idx) { sweep_tile(qy, [&](auto idx) {
...@@ -137,8 +160,10 @@ struct SmoothquantPipelineTwoPass ...@@ -137,8 +160,10 @@ struct SmoothquantPipelineTwoPass
}); });
store_tile(qy_window, qy); store_tile(qy_window, qy);
if constexpr(kSmoothX)
move_tile_window(xscale_window, {0, -Block_N});
move_tile_window(x_window, {0, -Block_N}); move_tile_window(x_window, {0, -Block_N});
move_tile_window(xscale_window, {0, -Block_N});
move_tile_window(qy_window, {0, -Block_N}); move_tile_window(qy_window, {0, -Block_N});
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment