"src/parserstate.h" did not exist on "a2bd317397d797194acda9cbde1d54da9f2c1928"
Commit b321bd86 authored by rocking's avatar rocking
Browse files

Support pure quant in instance library

parent 26f221eb
...@@ -11,7 +11,8 @@ template <typename DataType_, ...@@ -11,7 +11,8 @@ template <typename DataType_,
ck_tile::index_t ThreadPerBlock_N_, // num threads along N ck_tile::index_t ThreadPerBlock_N_, // num threads along N
ck_tile::index_t Vector_N_, // vector size along N ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_, bool kPadN_,
bool kTwoPass_> bool kTwoPass_,
bool kSmoothX_>
using trait_ = smoothquant_traits_<DataType_, using trait_ = smoothquant_traits_<DataType_,
Repeat_M_, Repeat_M_,
Repeat_N_, Repeat_N_,
...@@ -19,9 +20,10 @@ using trait_ = smoothquant_traits_<DataType_, ...@@ -19,9 +20,10 @@ using trait_ = smoothquant_traits_<DataType_,
ThreadPerBlock_N_, ThreadPerBlock_N_,
Vector_N_, Vector_N_,
kPadN_, kPadN_,
kTwoPass_>; kTwoPass_,
kSmoothX_>;
template <typename data_type> template <typename data_type, bool smooth_x>
float smoothquant_dispatch(smoothquant_traits /*t*/, float smoothquant_dispatch(smoothquant_traits /*t*/,
smoothquant_args a, smoothquant_args a,
const ck_tile::stream_config& s) const ck_tile::stream_config& s)
...@@ -30,99 +32,99 @@ float smoothquant_dispatch(smoothquant_traits /*t*/, ...@@ -30,99 +32,99 @@ float smoothquant_dispatch(smoothquant_traits /*t*/,
// clang-format off // clang-format off
// rm rn tm tn vn pd 2p // rm rn tm tn vn pd 2p
if(a.n <= 64) { if(a.n <= 64) {
r = smoothquant_<trait_<data_type, 1, 1, 4, 64, 1, true, false>>(s, a); r = smoothquant_<trait_<data_type, 1, 1, 4, 64, 1, true, false, smooth_x>>(s, a);
} }
else if(a.n <= 128) { else if(a.n <= 128) {
if (a.n % 2 == 0) if (a.n % 2 == 0)
r = smoothquant_<trait_<data_type, 1, 1, 4, 64, 2, true, false>>(s, a); r = smoothquant_<trait_<data_type, 1, 1, 4, 64, 2, true, false, smooth_x>>(s, a);
else else
r = smoothquant_<trait_<data_type, 1, 2, 4, 64, 1, true, false>>(s, a); r = smoothquant_<trait_<data_type, 1, 2, 4, 64, 1, true, false, smooth_x>>(s, a);
} }
else if(a.n <= 256) { else if(a.n <= 256) {
if (a.n % 4 == 0) if (a.n % 4 == 0)
r = smoothquant_<trait_<data_type, 1, 1, 4, 64, 4, true, false>>(s, a); r = smoothquant_<trait_<data_type, 1, 1, 4, 64, 4, true, false, smooth_x>>(s, a);
else if (a.n % 2 == 0) else if (a.n % 2 == 0)
r = smoothquant_<trait_<data_type, 1, 2, 4, 64, 2, true, false>>(s, a); r = smoothquant_<trait_<data_type, 1, 2, 4, 64, 2, true, false, smooth_x>>(s, a);
else else
r = smoothquant_<trait_<data_type, 1, 4, 4, 64, 1, true, false>>(s, a); r = smoothquant_<trait_<data_type, 1, 4, 4, 64, 1, true, false, smooth_x>>(s, a);
} }
else if(a.n <= 512) { else if(a.n <= 512) {
if (a.n % 8 == 0) if (a.n % 8 == 0)
r = smoothquant_<trait_<data_type, 1, 1, 4, 64, 8, true, false>>(s, a); r = smoothquant_<trait_<data_type, 1, 1, 4, 64, 8, true, false, smooth_x>>(s, a);
else if (a.n % 4 == 0) else if (a.n % 4 == 0)
r = smoothquant_<trait_<data_type, 1, 2, 4, 64, 4, true, false>>(s, a); r = smoothquant_<trait_<data_type, 1, 2, 4, 64, 4, true, false, smooth_x>>(s, a);
else if (a.n % 2 == 0) else if (a.n % 2 == 0)
r = smoothquant_<trait_<data_type, 1, 4, 4, 64, 2, true, false>>(s, a); r = smoothquant_<trait_<data_type, 1, 4, 4, 64, 2, true, false, smooth_x>>(s, a);
else else
r = smoothquant_<trait_<data_type, 1, 8, 4, 64, 1, true, false>>(s, a); r = smoothquant_<trait_<data_type, 1, 8, 4, 64, 1, true, false, smooth_x>>(s, a);
} }
else if(a.n <= 768) { else if(a.n <= 768) {
if (a.n % 4 == 0) if (a.n % 4 == 0)
r = smoothquant_<trait_<data_type, 1, 3, 4, 64, 4, true, false>>(s, a); r = smoothquant_<trait_<data_type, 1, 3, 4, 64, 4, true, false, smooth_x>>(s, a);
else if (a.n % 2 == 0) else if (a.n % 2 == 0)
r = smoothquant_<trait_<data_type, 1, 6, 4, 64, 2, true, false>>(s, a); r = smoothquant_<trait_<data_type, 1, 6, 4, 64, 2, true, false, smooth_x>>(s, a);
else else
r = smoothquant_<trait_<data_type, 1,12, 4, 64, 1, true, false>>(s, a); r = smoothquant_<trait_<data_type, 1,12, 4, 64, 1, true, false, smooth_x>>(s, a);
} }
else if(a.n <= 1024) { else if(a.n <= 1024) {
if (a.n % 8 == 0) if (a.n % 8 == 0)
r = smoothquant_<trait_<data_type, 1, 1, 2, 128, 8, true, false>>(s, a); r = smoothquant_<trait_<data_type, 1, 1, 2, 128, 8, true, false, smooth_x>>(s, a);
else if (a.n % 4 == 0) else if (a.n % 4 == 0)
r = smoothquant_<trait_<data_type, 1, 2, 2, 128, 4, true, false>>(s, a); r = smoothquant_<trait_<data_type, 1, 2, 2, 128, 4, true, false, smooth_x>>(s, a);
else if (a.n % 2 == 0) else if (a.n % 2 == 0)
r = smoothquant_<trait_<data_type, 1, 4, 2, 128, 2, true, false>>(s, a); r = smoothquant_<trait_<data_type, 1, 4, 2, 128, 2, true, false, smooth_x>>(s, a);
else else
r = smoothquant_<trait_<data_type, 1, 4, 1, 256, 1, true, false>>(s, a); r = smoothquant_<trait_<data_type, 1, 4, 1, 256, 1, true, false, smooth_x>>(s, a);
} }
else if(a.n <= 1536) { else if(a.n <= 1536) {
if (a.n % 8 == 0) if (a.n % 8 == 0)
r = smoothquant_<trait_<data_type, 1, 3, 4, 64, 8, true, false>>(s, a); r = smoothquant_<trait_<data_type, 1, 3, 4, 64, 8, true, false, smooth_x>>(s, a);
else if (a.n % 4 == 0) else if (a.n % 4 == 0)
r = smoothquant_<trait_<data_type, 1, 3, 2, 128, 4, true, false>>(s, a); r = smoothquant_<trait_<data_type, 1, 3, 2, 128, 4, true, false, smooth_x>>(s, a);
else if (a.n % 2 == 0) else if (a.n % 2 == 0)
r = smoothquant_<trait_<data_type, 1, 3, 1, 256, 2, true, false>>(s, a); r = smoothquant_<trait_<data_type, 1, 3, 1, 256, 2, true, false, smooth_x>>(s, a);
else else
r = smoothquant_<trait_<data_type, 1, 6, 1, 256, 1, true, false>>(s, a); r = smoothquant_<trait_<data_type, 1, 6, 1, 256, 1, true, false, smooth_x>>(s, a);
} }
else if(a.n <= 2048) { else if(a.n <= 2048) {
if (a.n % 8 == 0) if (a.n % 8 == 0)
r = smoothquant_<trait_<data_type, 1, 1, 1, 256, 8, true, false>>(s, a); r = smoothquant_<trait_<data_type, 1, 1, 1, 256, 8, true, false, smooth_x>>(s, a);
else if (a.n % 4 == 0) else if (a.n % 4 == 0)
r = smoothquant_<trait_<data_type, 1, 2, 1, 256, 4, true, false>>(s, a); r = smoothquant_<trait_<data_type, 1, 2, 1, 256, 4, true, false, smooth_x>>(s, a);
else if (a.n % 2 == 0) else if (a.n % 2 == 0)
r = smoothquant_<trait_<data_type, 1, 4, 1, 256, 2, true, false>>(s, a); r = smoothquant_<trait_<data_type, 1, 4, 1, 256, 2, true, false, smooth_x>>(s, a);
else else
r = smoothquant_<trait_<data_type, 1, 8, 1, 256, 1, true, false>>(s, a); r = smoothquant_<trait_<data_type, 1, 8, 1, 256, 1, true, false, smooth_x>>(s, a);
} }
else if(a.n <= 3072) { else if(a.n <= 3072) {
if (a.n % 8 == 0) if (a.n % 8 == 0)
r = smoothquant_<trait_<data_type, 1, 3, 1, 128, 8, true, false>>(s, a); r = smoothquant_<trait_<data_type, 1, 3, 1, 128, 8, true, false, smooth_x>>(s, a);
else if (a.n % 4 == 0) else if (a.n % 4 == 0)
r = smoothquant_<trait_<data_type, 1, 3, 1, 256, 4, true, false>>(s, a); r = smoothquant_<trait_<data_type, 1, 3, 1, 256, 4, true, false, smooth_x>>(s, a);
else if (a.n % 2 == 0) else if (a.n % 2 == 0)
r = smoothquant_<trait_<data_type, 1, 6, 1, 256, 2, true, false>>(s, a); r = smoothquant_<trait_<data_type, 1, 6, 1, 256, 2, true, false, smooth_x>>(s, a);
else else
r = smoothquant_<trait_<data_type, 1, 3, 1, 1024, 1, true, false>>(s, a); r = smoothquant_<trait_<data_type, 1, 3, 1, 1024, 1, true, false, smooth_x>>(s, a);
} }
else if(a.n <= 4096) { else if(a.n <= 4096) {
if (a.n % 8 == 0) if (a.n % 8 == 0)
r = smoothquant_<trait_<data_type, 1, 2, 1, 256, 8, true, false>>(s, a); r = smoothquant_<trait_<data_type, 1, 2, 1, 256, 8, true, false, smooth_x>>(s, a);
else if (a.n % 4 == 0) else if (a.n % 4 == 0)
r = smoothquant_<trait_<data_type, 1, 4, 1, 256, 4, true, false>>(s, a); r = smoothquant_<trait_<data_type, 1, 4, 1, 256, 4, true, false, smooth_x>>(s, a);
else if (a.n % 2 == 0) else if (a.n % 2 == 0)
r = smoothquant_<trait_<data_type, 1, 2, 1, 1024, 2, true, false>>(s, a); r = smoothquant_<trait_<data_type, 1, 2, 1, 1024, 2, true, false, smooth_x>>(s, a);
else else
r = smoothquant_<trait_<data_type, 1, 4, 1, 1024, 1, true, false>>(s, a); r = smoothquant_<trait_<data_type, 1, 4, 1, 1024, 1, true, false, smooth_x>>(s, a);
} }
else if(a.n > 4096) { else if(a.n > 4096) {
if (a.n % 8 == 0) if (a.n % 8 == 0)
r = smoothquant_<trait_<data_type, 1, 2, 1, 256, 8, true, true>>(s, a); r = smoothquant_<trait_<data_type, 1, 2, 1, 256, 8, true, true, smooth_x>>(s, a);
else if (a.n % 4 == 0) else if (a.n % 4 == 0)
r = smoothquant_<trait_<data_type, 1, 4, 1, 256, 4, true, true>>(s, a); r = smoothquant_<trait_<data_type, 1, 4, 1, 256, 4, true, true, smooth_x>>(s, a);
else if (a.n % 2 == 0) else if (a.n % 2 == 0)
r = smoothquant_<trait_<data_type, 1, 2, 1, 1024, 2, true, true>>(s, a); r = smoothquant_<trait_<data_type, 1, 2, 1, 1024, 2, true, true, smooth_x>>(s, a);
else else
r = smoothquant_<trait_<data_type, 1, 4, 1, 1024, 1, true, true>>(s, a); r = smoothquant_<trait_<data_type, 1, 4, 1, 1024, 1, true, true, smooth_x>>(s, a);
} }
return r; return r;
// clang-format on // clang-format on
...@@ -132,11 +134,17 @@ float smoothquant(smoothquant_traits t, smoothquant_args a, const ck_tile::strea ...@@ -132,11 +134,17 @@ float smoothquant(smoothquant_traits t, smoothquant_args a, const ck_tile::strea
{ {
if(t.data_type.compare("fp16") == 0) if(t.data_type.compare("fp16") == 0)
{ {
return smoothquant_dispatch<ck_tile::fp16_t>(t, a, s); if (t.smooth_x)
return smoothquant_dispatch<ck_tile::fp16_t, true>(t, a, s);
else
return smoothquant_dispatch<ck_tile::fp16_t, false>(t, a, s);
} }
else if(t.data_type.compare("bf16") == 0) else if(t.data_type.compare("bf16") == 0)
{ {
return smoothquant_dispatch<ck_tile::bf16_t>(t, a, s); if (t.smooth_x)
return smoothquant_dispatch<ck_tile::bf16_t, true>(t, a, s);
else
return smoothquant_dispatch<ck_tile::bf16_t, false>(t, a, s);
} }
else else
throw std::runtime_error("Without supported instances!"); throw std::runtime_error("Without supported instances!");
......
...@@ -18,7 +18,8 @@ template <typename DataType_, ...@@ -18,7 +18,8 @@ template <typename DataType_,
ck_tile::index_t ThreadPerBlock_N_, // num threads along N ck_tile::index_t ThreadPerBlock_N_, // num threads along N
ck_tile::index_t Vector_N_, // vector size along N ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_, bool kPadN_,
bool kTwoPass_> bool kTwoPass_,
bool kSmoothX_>
using trait_ = smoothquant_traits_<DataType_, using trait_ = smoothquant_traits_<DataType_,
Repeat_M_, Repeat_M_,
Repeat_N_, Repeat_N_,
...@@ -26,7 +27,8 @@ using trait_ = smoothquant_traits_<DataType_, ...@@ -26,7 +27,8 @@ using trait_ = smoothquant_traits_<DataType_,
ThreadPerBlock_N_, ThreadPerBlock_N_,
Vector_N_, Vector_N_,
kPadN_, kPadN_,
kTwoPass_>; kTwoPass_,
kSmoothX_>;
template <typename Traits_> template <typename Traits_>
float smoothquant_(const S& s, A a) float smoothquant_(const S& s, A a)
...@@ -42,7 +44,7 @@ float smoothquant_(const S& s, A a) ...@@ -42,7 +44,7 @@ float smoothquant_(const S& s, A a)
typename Traits_::Shape, typename Traits_::Shape,
Traits_::kPadN, Traits_::kPadN,
Traits_::kTwoPass, Traits_::kTwoPass,
true>; Traits_::kSmoothX>;
using OnePassPipeline = ck_tile::SmoothquantPipelineOnePass<PipelineProblem>; using OnePassPipeline = ck_tile::SmoothquantPipelineOnePass<PipelineProblem>;
using TwoPassPipeline = ck_tile::SmoothquantPipelineTwoPass<PipelineProblem>; using TwoPassPipeline = ck_tile::SmoothquantPipelineTwoPass<PipelineProblem>;
......
...@@ -34,6 +34,7 @@ auto create_args(int argc, char* argv[]) ...@@ -34,6 +34,7 @@ auto create_args(int argc, char* argv[])
arg_parser.insert("m", "3328", "m dimension") arg_parser.insert("m", "3328", "m dimension")
.insert("n", "4096", "n dimension") .insert("n", "4096", "n dimension")
.insert("stride", "-1", "stride per row, if -1 then equal to n") .insert("stride", "-1", "stride per row, if -1 then equal to n")
.insert("sx", "1", "0 is pure quantization, 1 is to apply smoothquant")
.insert("v", "1", "cpu validation or not") .insert("v", "1", "cpu validation or not")
.insert("kname", "1", "print kernel name or not") .insert("kname", "1", "print kernel name or not")
.insert("prec", "fp16", "precision") .insert("prec", "fp16", "precision")
...@@ -53,6 +54,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -53,6 +54,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(stride < 0) if(stride < 0)
stride = n; stride = n;
std::string data_type = arg_parser.get_str("prec"); std::string data_type = arg_parser.get_str("prec");
bool smooth_x = arg_parser.get_bool("sx");
int kname = arg_parser.get_int("kname"); int kname = arg_parser.get_int("kname");
int do_validation = arg_parser.get_int("v"); int do_validation = arg_parser.get_int("v");
int warmup = arg_parser.get_int("warmup"); int warmup = arg_parser.get_int("warmup");
...@@ -92,7 +94,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -92,7 +94,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
std::cout << "[" << data_type << "]" std::cout << "[" << data_type << "]"
<< " m:" << m << ", n:" << n << ", stride:" << stride << std::flush; << " m:" << m << ", n:" << n << ", stride:" << stride << std::flush;
smoothquant_traits traits{data_type}; smoothquant_traits traits{data_type, smooth_x};
smoothquant_args args{x_buf.GetDeviceBuffer(), smoothquant_args args{x_buf.GetDeviceBuffer(),
xscale_buf.GetDeviceBuffer(), xscale_buf.GetDeviceBuffer(),
...@@ -124,8 +126,11 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -124,8 +126,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
for(int m_ = 0; m_ < m; ++m_) for(int m_ = 0; m_ < m; ++m_)
{ {
auto v_x = ck_tile::type_convert<ComputeDataType>(x_host(m_, n_)); auto v_x = ck_tile::type_convert<ComputeDataType>(x_host(m_, n_));
y_host(m_, n_) = v_x * v_xscale; if(smooth_x)
y_host(m_, n_) = v_x * v_xscale;
else
y_host(m_, n_) = v_x;
} }
}; };
......
...@@ -44,7 +44,8 @@ template <typename DataType_, ...@@ -44,7 +44,8 @@ template <typename DataType_,
ck_tile::index_t ThreadPerBlock_N_, // num threads along N ck_tile::index_t ThreadPerBlock_N_, // num threads along N
ck_tile::index_t Vector_N_, // vector size along N ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_, bool kPadN_,
bool kTwoPass_> bool kTwoPass_,
bool kSmoothX_>
struct smoothquant_traits_ struct smoothquant_traits_
{ {
using DataType = ck_tile::remove_cvref_t<DataType_>; using DataType = ck_tile::remove_cvref_t<DataType_>;
...@@ -100,6 +101,7 @@ struct smoothquant_traits_ ...@@ -100,6 +101,7 @@ struct smoothquant_traits_
static constexpr bool kPadN = kPadN_; static constexpr bool kPadN = kPadN_;
static constexpr bool kTwoPass = kTwoPass_; static constexpr bool kTwoPass = kTwoPass_;
static constexpr bool kSmoothX = kSmoothX_;
}; };
template <typename Traits_> template <typename Traits_>
...@@ -109,6 +111,7 @@ float smoothquant_(const ck_tile::stream_config& s, smoothquant_args a); ...@@ -109,6 +111,7 @@ float smoothquant_(const ck_tile::stream_config& s, smoothquant_args a);
struct smoothquant_traits struct smoothquant_traits
{ {
std::string data_type; std::string data_type;
bool smooth_x;
}; };
float smoothquant(smoothquant_traits, smoothquant_args, const ck_tile::stream_config&); float smoothquant(smoothquant_traits, smoothquant_args, const ck_tile::stream_config&);
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment