Commit b321bd86 authored by rocking's avatar rocking
Browse files

Support pure quant in instance library

parent 26f221eb
......@@ -11,7 +11,8 @@ template <typename DataType_,
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_,
bool kTwoPass_>
bool kTwoPass_,
bool kSmoothX_>
using trait_ = smoothquant_traits_<DataType_,
Repeat_M_,
Repeat_N_,
......@@ -19,9 +20,10 @@ using trait_ = smoothquant_traits_<DataType_,
ThreadPerBlock_N_,
Vector_N_,
kPadN_,
kTwoPass_>;
kTwoPass_,
kSmoothX_>;
template <typename data_type>
template <typename data_type, bool smooth_x>
float smoothquant_dispatch(smoothquant_traits /*t*/,
smoothquant_args a,
const ck_tile::stream_config& s)
......@@ -30,99 +32,99 @@ float smoothquant_dispatch(smoothquant_traits /*t*/,
// clang-format off
// rm rn tm tn vn pd 2p
if(a.n <= 64) {
r = smoothquant_<trait_<data_type, 1, 1, 4, 64, 1, true, false>>(s, a);
r = smoothquant_<trait_<data_type, 1, 1, 4, 64, 1, true, false, smooth_x>>(s, a);
}
else if(a.n <= 128) {
if (a.n % 2 == 0)
r = smoothquant_<trait_<data_type, 1, 1, 4, 64, 2, true, false>>(s, a);
r = smoothquant_<trait_<data_type, 1, 1, 4, 64, 2, true, false, smooth_x>>(s, a);
else
r = smoothquant_<trait_<data_type, 1, 2, 4, 64, 1, true, false>>(s, a);
r = smoothquant_<trait_<data_type, 1, 2, 4, 64, 1, true, false, smooth_x>>(s, a);
}
else if(a.n <= 256) {
if (a.n % 4 == 0)
r = smoothquant_<trait_<data_type, 1, 1, 4, 64, 4, true, false>>(s, a);
r = smoothquant_<trait_<data_type, 1, 1, 4, 64, 4, true, false, smooth_x>>(s, a);
else if (a.n % 2 == 0)
r = smoothquant_<trait_<data_type, 1, 2, 4, 64, 2, true, false>>(s, a);
r = smoothquant_<trait_<data_type, 1, 2, 4, 64, 2, true, false, smooth_x>>(s, a);
else
r = smoothquant_<trait_<data_type, 1, 4, 4, 64, 1, true, false>>(s, a);
r = smoothquant_<trait_<data_type, 1, 4, 4, 64, 1, true, false, smooth_x>>(s, a);
}
else if(a.n <= 512) {
if (a.n % 8 == 0)
r = smoothquant_<trait_<data_type, 1, 1, 4, 64, 8, true, false>>(s, a);
r = smoothquant_<trait_<data_type, 1, 1, 4, 64, 8, true, false, smooth_x>>(s, a);
else if (a.n % 4 == 0)
r = smoothquant_<trait_<data_type, 1, 2, 4, 64, 4, true, false>>(s, a);
r = smoothquant_<trait_<data_type, 1, 2, 4, 64, 4, true, false, smooth_x>>(s, a);
else if (a.n % 2 == 0)
r = smoothquant_<trait_<data_type, 1, 4, 4, 64, 2, true, false>>(s, a);
r = smoothquant_<trait_<data_type, 1, 4, 4, 64, 2, true, false, smooth_x>>(s, a);
else
r = smoothquant_<trait_<data_type, 1, 8, 4, 64, 1, true, false>>(s, a);
r = smoothquant_<trait_<data_type, 1, 8, 4, 64, 1, true, false, smooth_x>>(s, a);
}
else if(a.n <= 768) {
if (a.n % 4 == 0)
r = smoothquant_<trait_<data_type, 1, 3, 4, 64, 4, true, false>>(s, a);
r = smoothquant_<trait_<data_type, 1, 3, 4, 64, 4, true, false, smooth_x>>(s, a);
else if (a.n % 2 == 0)
r = smoothquant_<trait_<data_type, 1, 6, 4, 64, 2, true, false>>(s, a);
r = smoothquant_<trait_<data_type, 1, 6, 4, 64, 2, true, false, smooth_x>>(s, a);
else
r = smoothquant_<trait_<data_type, 1,12, 4, 64, 1, true, false>>(s, a);
r = smoothquant_<trait_<data_type, 1,12, 4, 64, 1, true, false, smooth_x>>(s, a);
}
else if(a.n <= 1024) {
if (a.n % 8 == 0)
r = smoothquant_<trait_<data_type, 1, 1, 2, 128, 8, true, false>>(s, a);
r = smoothquant_<trait_<data_type, 1, 1, 2, 128, 8, true, false, smooth_x>>(s, a);
else if (a.n % 4 == 0)
r = smoothquant_<trait_<data_type, 1, 2, 2, 128, 4, true, false>>(s, a);
r = smoothquant_<trait_<data_type, 1, 2, 2, 128, 4, true, false, smooth_x>>(s, a);
else if (a.n % 2 == 0)
r = smoothquant_<trait_<data_type, 1, 4, 2, 128, 2, true, false>>(s, a);
r = smoothquant_<trait_<data_type, 1, 4, 2, 128, 2, true, false, smooth_x>>(s, a);
else
r = smoothquant_<trait_<data_type, 1, 4, 1, 256, 1, true, false>>(s, a);
r = smoothquant_<trait_<data_type, 1, 4, 1, 256, 1, true, false, smooth_x>>(s, a);
}
else if(a.n <= 1536) {
if (a.n % 8 == 0)
r = smoothquant_<trait_<data_type, 1, 3, 4, 64, 8, true, false>>(s, a);
r = smoothquant_<trait_<data_type, 1, 3, 4, 64, 8, true, false, smooth_x>>(s, a);
else if (a.n % 4 == 0)
r = smoothquant_<trait_<data_type, 1, 3, 2, 128, 4, true, false>>(s, a);
r = smoothquant_<trait_<data_type, 1, 3, 2, 128, 4, true, false, smooth_x>>(s, a);
else if (a.n % 2 == 0)
r = smoothquant_<trait_<data_type, 1, 3, 1, 256, 2, true, false>>(s, a);
r = smoothquant_<trait_<data_type, 1, 3, 1, 256, 2, true, false, smooth_x>>(s, a);
else
r = smoothquant_<trait_<data_type, 1, 6, 1, 256, 1, true, false>>(s, a);
r = smoothquant_<trait_<data_type, 1, 6, 1, 256, 1, true, false, smooth_x>>(s, a);
}
else if(a.n <= 2048) {
if (a.n % 8 == 0)
r = smoothquant_<trait_<data_type, 1, 1, 1, 256, 8, true, false>>(s, a);
r = smoothquant_<trait_<data_type, 1, 1, 1, 256, 8, true, false, smooth_x>>(s, a);
else if (a.n % 4 == 0)
r = smoothquant_<trait_<data_type, 1, 2, 1, 256, 4, true, false>>(s, a);
r = smoothquant_<trait_<data_type, 1, 2, 1, 256, 4, true, false, smooth_x>>(s, a);
else if (a.n % 2 == 0)
r = smoothquant_<trait_<data_type, 1, 4, 1, 256, 2, true, false>>(s, a);
r = smoothquant_<trait_<data_type, 1, 4, 1, 256, 2, true, false, smooth_x>>(s, a);
else
r = smoothquant_<trait_<data_type, 1, 8, 1, 256, 1, true, false>>(s, a);
r = smoothquant_<trait_<data_type, 1, 8, 1, 256, 1, true, false, smooth_x>>(s, a);
}
else if(a.n <= 3072) {
if (a.n % 8 == 0)
r = smoothquant_<trait_<data_type, 1, 3, 1, 128, 8, true, false>>(s, a);
r = smoothquant_<trait_<data_type, 1, 3, 1, 128, 8, true, false, smooth_x>>(s, a);
else if (a.n % 4 == 0)
r = smoothquant_<trait_<data_type, 1, 3, 1, 256, 4, true, false>>(s, a);
r = smoothquant_<trait_<data_type, 1, 3, 1, 256, 4, true, false, smooth_x>>(s, a);
else if (a.n % 2 == 0)
r = smoothquant_<trait_<data_type, 1, 6, 1, 256, 2, true, false>>(s, a);
r = smoothquant_<trait_<data_type, 1, 6, 1, 256, 2, true, false, smooth_x>>(s, a);
else
r = smoothquant_<trait_<data_type, 1, 3, 1, 1024, 1, true, false>>(s, a);
r = smoothquant_<trait_<data_type, 1, 3, 1, 1024, 1, true, false, smooth_x>>(s, a);
}
else if(a.n <= 4096) {
if (a.n % 8 == 0)
r = smoothquant_<trait_<data_type, 1, 2, 1, 256, 8, true, false>>(s, a);
r = smoothquant_<trait_<data_type, 1, 2, 1, 256, 8, true, false, smooth_x>>(s, a);
else if (a.n % 4 == 0)
r = smoothquant_<trait_<data_type, 1, 4, 1, 256, 4, true, false>>(s, a);
r = smoothquant_<trait_<data_type, 1, 4, 1, 256, 4, true, false, smooth_x>>(s, a);
else if (a.n % 2 == 0)
r = smoothquant_<trait_<data_type, 1, 2, 1, 1024, 2, true, false>>(s, a);
r = smoothquant_<trait_<data_type, 1, 2, 1, 1024, 2, true, false, smooth_x>>(s, a);
else
r = smoothquant_<trait_<data_type, 1, 4, 1, 1024, 1, true, false>>(s, a);
r = smoothquant_<trait_<data_type, 1, 4, 1, 1024, 1, true, false, smooth_x>>(s, a);
}
else if(a.n > 4096) {
if (a.n % 8 == 0)
r = smoothquant_<trait_<data_type, 1, 2, 1, 256, 8, true, true>>(s, a);
r = smoothquant_<trait_<data_type, 1, 2, 1, 256, 8, true, true, smooth_x>>(s, a);
else if (a.n % 4 == 0)
r = smoothquant_<trait_<data_type, 1, 4, 1, 256, 4, true, true>>(s, a);
r = smoothquant_<trait_<data_type, 1, 4, 1, 256, 4, true, true, smooth_x>>(s, a);
else if (a.n % 2 == 0)
r = smoothquant_<trait_<data_type, 1, 2, 1, 1024, 2, true, true>>(s, a);
r = smoothquant_<trait_<data_type, 1, 2, 1, 1024, 2, true, true, smooth_x>>(s, a);
else
r = smoothquant_<trait_<data_type, 1, 4, 1, 1024, 1, true, true>>(s, a);
r = smoothquant_<trait_<data_type, 1, 4, 1, 1024, 1, true, true, smooth_x>>(s, a);
}
return r;
// clang-format on
......@@ -132,11 +134,17 @@ float smoothquant(smoothquant_traits t, smoothquant_args a, const ck_tile::strea
{
if(t.data_type.compare("fp16") == 0)
{
return smoothquant_dispatch<ck_tile::fp16_t>(t, a, s);
if (t.smooth_x)
return smoothquant_dispatch<ck_tile::fp16_t, true>(t, a, s);
else
return smoothquant_dispatch<ck_tile::fp16_t, false>(t, a, s);
}
else if(t.data_type.compare("bf16") == 0)
{
return smoothquant_dispatch<ck_tile::bf16_t>(t, a, s);
if (t.smooth_x)
return smoothquant_dispatch<ck_tile::bf16_t, true>(t, a, s);
else
return smoothquant_dispatch<ck_tile::bf16_t, false>(t, a, s);
}
else
throw std::runtime_error("Without supported instances!");
......
......@@ -18,7 +18,8 @@ template <typename DataType_,
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_,
bool kTwoPass_>
bool kTwoPass_,
bool kSmoothX_>
using trait_ = smoothquant_traits_<DataType_,
Repeat_M_,
Repeat_N_,
......@@ -26,7 +27,8 @@ using trait_ = smoothquant_traits_<DataType_,
ThreadPerBlock_N_,
Vector_N_,
kPadN_,
kTwoPass_>;
kTwoPass_,
kSmoothX_>;
template <typename Traits_>
float smoothquant_(const S& s, A a)
......@@ -42,7 +44,7 @@ float smoothquant_(const S& s, A a)
typename Traits_::Shape,
Traits_::kPadN,
Traits_::kTwoPass,
true>;
Traits_::kSmoothX>;
using OnePassPipeline = ck_tile::SmoothquantPipelineOnePass<PipelineProblem>;
using TwoPassPipeline = ck_tile::SmoothquantPipelineTwoPass<PipelineProblem>;
......
......@@ -34,6 +34,7 @@ auto create_args(int argc, char* argv[])
arg_parser.insert("m", "3328", "m dimension")
.insert("n", "4096", "n dimension")
.insert("stride", "-1", "stride per row, if -1 then equal to n")
.insert("sx", "1", "0 is pure quantization, 1 is to apply smoothquant")
.insert("v", "1", "cpu validation or not")
.insert("kname", "1", "print kernel name or not")
.insert("prec", "fp16", "precision")
......@@ -53,6 +54,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(stride < 0)
stride = n;
std::string data_type = arg_parser.get_str("prec");
bool smooth_x = arg_parser.get_bool("sx");
int kname = arg_parser.get_int("kname");
int do_validation = arg_parser.get_int("v");
int warmup = arg_parser.get_int("warmup");
......@@ -92,7 +94,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
std::cout << "[" << data_type << "]"
<< " m:" << m << ", n:" << n << ", stride:" << stride << std::flush;
smoothquant_traits traits{data_type};
smoothquant_traits traits{data_type, smooth_x};
smoothquant_args args{x_buf.GetDeviceBuffer(),
xscale_buf.GetDeviceBuffer(),
......@@ -125,7 +127,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
for(int m_ = 0; m_ < m; ++m_)
{
auto v_x = ck_tile::type_convert<ComputeDataType>(x_host(m_, n_));
if(smooth_x)
y_host(m_, n_) = v_x * v_xscale;
else
y_host(m_, n_) = v_x;
}
};
......
......@@ -44,7 +44,8 @@ template <typename DataType_,
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_,
bool kTwoPass_>
bool kTwoPass_,
bool kSmoothX_>
struct smoothquant_traits_
{
using DataType = ck_tile::remove_cvref_t<DataType_>;
......@@ -100,6 +101,7 @@ struct smoothquant_traits_
static constexpr bool kPadN = kPadN_;
static constexpr bool kTwoPass = kTwoPass_;
static constexpr bool kSmoothX = kSmoothX_;
};
template <typename Traits_>
......@@ -109,6 +111,7 @@ float smoothquant_(const ck_tile::stream_config& s, smoothquant_args a);
struct smoothquant_traits
{
std::string data_type;
bool smooth_x;
};
float smoothquant(smoothquant_traits, smoothquant_args, const ck_tile::stream_config&);
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment