Unverified commit 4e731776, authored by chenjun, committed by GitHub

Ck tile/smoothquant out stride (#1742)

* add ck_tile/smoothquant out stride parameter

* Remove the default stride value

---------

Co-authored-by: so <a.com>
parent 77a38e02
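This change splits the single row-stride parameter into separate input (x_stride) and output (y_stride) row strides, so the smoothquant example and kernel can read from and write to buffers with different row pitches. As a minimal standalone sketch of what the two pitches mean for a row-major [m, n] tensor (plain C++, not ck_tile code; sizes and names here only mirror the parameters added below):

```cpp
#include <cstddef>
#include <vector>

int main()
{
    const std::size_t m = 4, n = 6;
    const std::size_t x_stride = 8; // input rows padded to 8 elements (x_stride >= n)
    const std::size_t y_stride = 6; // output rows tightly packed (y_stride == n)

    // Each buffer holds m rows of `stride` elements; element (r, c) lives at r * stride + c.
    std::vector<float> x(m * x_stride);
    std::vector<signed char> qy(m * y_stride);

    // e.g. the last valid element of row 2 in each buffer:
    float in_elem        = x[2 * x_stride + (n - 1)];
    signed char out_elem = qy[2 * y_stride + (n - 1)];
    (void)in_elem;
    (void)out_elem;
    return 0;
}
```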
@@ -35,7 +35,8 @@ auto create_args(int argc, char* argv[])
     ck_tile::ArgParser arg_parser;
     arg_parser.insert("m", "3328", "m dimension")
         .insert("n", "4096", "n dimension")
-        .insert("stride", "-1", "stride per row, if -1 then equal to n")
+        .insert("x_stride", "-1", "input stride per row, if -1 then equal to n")
+        .insert("y_stride", "-1", "output stride per row, if -1 then equal to n")
         .insert("e", "1e-5", "epsilon")
         .insert("v", "1", "cpu validation or not")
         .insert("prec", "fp16", "precision")
@@ -49,11 +50,14 @@ auto create_args(int argc, char* argv[])
 template <typename DataType>
 bool run(const ck_tile::ArgParser& arg_parser)
 {
     ck_tile::index_t m = arg_parser.get_int("m");
     ck_tile::index_t n = arg_parser.get_int("n");
-    ck_tile::index_t stride = arg_parser.get_int("stride");
-    if(stride < 0)
-        stride = n;
+    ck_tile::index_t x_stride = arg_parser.get_int("x_stride");
+    if(x_stride < 0)
+        x_stride = n;
+    ck_tile::index_t y_stride = arg_parser.get_int("y_stride");
+    if(y_stride < 0)
+        y_stride = n;
     std::string data_type = arg_parser.get_str("prec");
     int do_validation = arg_parser.get_int("v");
     int warmup = arg_parser.get_int("warmup");
@@ -68,14 +72,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
     using ComputeDataType = float;

     // host verify
-    ck_tile::HostTensor<XDataType> x_host({m, n}, {stride, 1});
+    ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1});
     ck_tile::HostTensor<XScaleDataType> xscale_host({n});
     ck_tile::HostTensor<YScaleDataType> yscale_host_ref({m}, {1});
     ck_tile::HostTensor<YScaleDataType> yscale_host_dev({m}, {1});
-    ck_tile::HostTensor<QYDataType> qy_host_ref({m, n}, {stride, 1});
-    ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {stride, 1});
+    ck_tile::HostTensor<QYDataType> qy_host_ref({m, n}, {y_stride, 1});
+    ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {y_stride, 1});

     ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
     ck_tile::FillUniformDistribution<XScaleDataType>{1e-3, .5f}(xscale_host);
@@ -116,7 +120,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
                              qy_buf.GetDeviceBuffer(),
                              m,
                              n,
-                             stride};
+                             x_stride,
+                             y_stride};

     auto kargs = Kernel::MakeKargs(args);
@@ -133,7 +138,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     if(do_validation)
     {
         using YDataType = ComputeDataType;
-        ck_tile::HostTensor<ComputeDataType> y_host({m, n}, {stride, 1});
+        ck_tile::HostTensor<ComputeDataType> y_host({m, n}, {y_stride, 1});
         // smooth outlier
         {
             auto f = [&](auto n_) {
@@ -183,7 +188,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
         qy_buf.FromDevice(qy_host_dev.data());

         auto [rtol, atol] = get_elimit<QYDataType>();
-        if(stride == n)
+        if(y_stride == n)
         {
             pass = ck_tile::check_err(qy_host_dev,
                                       qy_host_ref,
@@ -195,10 +200,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
         {
             for(int i_r = 0; i_r < m; i_r++)
             {
-                std::vector<QYDataType> qy_host_dev_row(qy_host_dev.begin() + i_r * stride,
-                                                        qy_host_dev.begin() + i_r * stride + n);
-                std::vector<QYDataType> qy_host_ref_row(qy_host_ref.begin() + i_r * stride,
-                                                        qy_host_ref.begin() + i_r * stride + n);
+                std::vector<QYDataType> qy_host_dev_row(qy_host_dev.begin() + i_r * y_stride,
+                                                        qy_host_dev.begin() + i_r * y_stride +
+                                                            n);
+                std::vector<QYDataType> qy_host_ref_row(qy_host_ref.begin() + i_r * y_stride,
+                                                        qy_host_ref.begin() + i_r * y_stride +
+                                                            n);
                 pass &= ck_tile::check_err(qy_host_dev_row,
                                            qy_host_ref_row,
                                            std::string("qy[") + std::to_string(i_r) +
@@ -210,8 +217,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
         }

         std::cout << "[" << data_type << "]"
-                  << " m:" << m << ", n:" << n << ", stride:" << stride
-                  << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
+                  << " m:" << m << ", n:" << n << ", x_stride:" << x_stride
+                  << ", y_stride:" << y_stride << ", valid:" << (pass ? "y" : "n") << std::flush
+                  << std::endl;
     }

     return pass;
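Note on the validation path above: when y_stride != n, the padding elements at the end of each output row are never written, so the device and reference buffers are compared row by row over the first n elements only. A minimal sketch of that idea in plain C++ (int8_t stands in for QYDataType here; the real code compares with ck_tile::check_err and rtol/atol rather than exact equality):

```cpp
#include <cstdint>
#include <vector>

// Compare two row-padded int8 buffers, ignoring the padding tail of each row.
bool rows_equal(const std::vector<int8_t>& dev,
                const std::vector<int8_t>& ref,
                int m, int n, int y_stride)
{
    bool pass = true;
    for(int i_r = 0; i_r < m; i_r++)
    {
        // Slice out the n meaningful elements of this row from each buffer.
        std::vector<int8_t> dev_row(dev.begin() + i_r * y_stride,
                                    dev.begin() + i_r * y_stride + n);
        std::vector<int8_t> ref_row(ref.begin() + i_r * y_stride,
                                    ref.begin() + i_r * y_stride + n);
        pass &= (dev_row == ref_row);
    }
    return pass;
}
```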
@@ -33,7 +33,8 @@ auto create_args(int argc, char* argv[])
     ck_tile::ArgParser arg_parser;
     arg_parser.insert("m", "3328", "m dimension")
         .insert("n", "4096", "n dimension")
-        .insert("stride", "-1", "stride per row, if -1 then equal to n")
+        .insert("x_stride", "-1", "input stride per row, if -1 then equal to n")
+        .insert("y_stride", "-1", "output stride per row, if -1 then equal to n")
         .insert("v", "1", "cpu validation or not")
         .insert("kname", "1", "print kernel name or not")
         .insert("prec", "fp16", "precision")
@@ -47,18 +48,21 @@ auto create_args(int argc, char* argv[])
 template <typename DataType>
 bool run(const ck_tile::ArgParser& arg_parser)
 {
     ck_tile::index_t m = arg_parser.get_int("m");
     ck_tile::index_t n = arg_parser.get_int("n");
-    ck_tile::index_t stride = arg_parser.get_int("stride");
-    if(stride < 0)
-        stride = n;
+    ck_tile::index_t x_stride = arg_parser.get_int("x_stride");
+    if(x_stride < 0)
+        x_stride = n;
+    ck_tile::index_t y_stride = arg_parser.get_int("y_stride");
+    if(y_stride < 0)
+        y_stride = n;
     std::string data_type = arg_parser.get_str("prec");
     int kname = arg_parser.get_int("kname");
     int do_validation = arg_parser.get_int("v");
     int warmup = arg_parser.get_int("warmup");
     int repeat = arg_parser.get_int("repeat");
-    assert(stride >= n);
+    assert(x_stride >= n);

     using TypeConfig = SmoothquantTypeConfig<DataType>;
@@ -69,14 +73,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
     using ComputeDataType = typename TypeConfig::ComputeDataType;

     // host verify
-    ck_tile::HostTensor<XDataType> x_host({m, n}, {stride, 1});
+    ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1});
     ck_tile::HostTensor<XScaleDataType> xscale_host({n});
     ck_tile::HostTensor<YScaleDataType> yscale_host_ref({m}, {1});
     ck_tile::HostTensor<YScaleDataType> yscale_host_dev({m}, {1});
-    ck_tile::HostTensor<QYDataType> qy_host_ref({m, n}, {stride, 1});
-    ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {stride, 1});
+    ck_tile::HostTensor<QYDataType> qy_host_ref({m, n}, {y_stride, 1});
+    ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {y_stride, 1});

     ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
     ck_tile::FillUniformDistribution<XScaleDataType>{1e-3, .5f}(xscale_host);
@@ -90,7 +94,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
     xscale_buf.ToDevice(xscale_host.data());

     std::cout << "[" << data_type << "]"
-              << " m:" << m << ", n:" << n << ", stride:" << stride << std::flush;
+              << " m:" << m << ", n:" << n << ", x_stride:" << x_stride << ", y_stride:" << y_stride
+              << std::flush;

     smoothquant_traits traits{data_type};
@@ -100,7 +105,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
                             qy_buf.GetDeviceBuffer(),
                             m,
                             n,
-                            stride};
+                            x_stride,
+                            y_stride};

     float ave_time = smoothquant(
         traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
@@ -116,7 +122,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     if(do_validation)
     {
         using YDataType = ComputeDataType;
-        ck_tile::HostTensor<ComputeDataType> y_host({m, n}, {stride, 1});
+        ck_tile::HostTensor<ComputeDataType> y_host({m, n}, {y_stride, 1});
         // smooth outlier
         {
             auto f = [&](auto n_) {
@@ -166,7 +172,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
         qy_buf.FromDevice(qy_host_dev.data());

         auto [rtol, atol] = get_elimit<QYDataType>();
-        if(stride == n)
+        if(y_stride == n)
        {
             pass = ck_tile::check_err(qy_host_dev,
                                       qy_host_ref,
@@ -178,10 +184,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
         {
             for(int i_r = 0; i_r < m; i_r++)
             {
-                std::vector<QYDataType> qy_host_dev_row(qy_host_dev.begin() + i_r * stride,
-                                                        qy_host_dev.begin() + i_r * stride + n);
-                std::vector<QYDataType> qy_host_ref_row(qy_host_ref.begin() + i_r * stride,
-                                                        qy_host_ref.begin() + i_r * stride + n);
+                std::vector<QYDataType> qy_host_dev_row(qy_host_dev.begin() + i_r * y_stride,
+                                                        qy_host_dev.begin() + i_r * y_stride +
+                                                            n);
+                std::vector<QYDataType> qy_host_ref_row(qy_host_ref.begin() + i_r * y_stride,
+                                                        qy_host_ref.begin() + i_r * y_stride +
+                                                            n);
                 pass &= ck_tile::check_err(qy_host_dev_row,
                                            qy_host_ref_row,
                                            std::string("qy[") + std::to_string(i_r) +
@@ -19,7 +19,8 @@ struct SmoothquantHostArgs
     index_t m;
     index_t n;
-    index_t stride; // row_stride
+    index_t x_stride; // input row_stride
+    index_t y_stride; // output row_stride
 };

 // TODO: Extract some type to wrapper class
@@ -58,14 +59,21 @@ struct Smoothquant
         index_t m;
         index_t n;
-        index_t stride; // row_stride
+        index_t x_stride; // input row_stride
+        index_t y_stride; // out row_stride
     };

     using Hargs = SmoothquantHostArgs;

     CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
     {
-        return Kargs{
-            hargs.p_x, hargs.p_xscale, hargs.p_yscale, hargs.p_qy, hargs.m, hargs.n, hargs.stride};
+        return Kargs{hargs.p_x,
+                     hargs.p_xscale,
+                     hargs.p_yscale,
+                     hargs.p_qy,
+                     hargs.m,
+                     hargs.n,
+                     hargs.x_stride,
+                     hargs.y_stride};
     }

     CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
@@ -116,7 +124,7 @@ struct Smoothquant
         const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
             static_cast<const XDataType*>(kargs.p_x),
             make_tuple(kargs.m, kargs.n),
-            make_tuple(kargs.stride, 1),
+            make_tuple(kargs.x_stride, 1),
             number<Vector_N>{},
             number<1>{});
@@ -157,7 +165,7 @@ struct Smoothquant
         auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
             static_cast<QYDataType*>(kargs.p_qy),
             make_tuple(kargs.m, kargs.n),
-            make_tuple(kargs.stride, 1),
+            make_tuple(kargs.y_stride, 1),
             number<Vector_N>{},
             number<1>{});
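For reference, here is a rough, self-contained host model of the data flow these kernel arguments describe: the input view is read with the input pitch (x_stride) and the quantized output is written with the output pitch (y_stride). The quantization recipe (per-column smoothing scale, then per-row absmax quantization to int8) follows the usual smoothquant formulation and the tensor shapes used in the examples above; it is a sketch under those assumptions, not the ck_tile reference implementation.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// x:      [m, n] input, row pitch x_stride (in elements)
// xscale: [n]    per-column smoothing scale
// yscale: [m]    per-row quantization scale (output)
// qy:     [m, n] int8 output, row pitch y_stride (in elements)
void smoothquant_host_sketch(const float* x, const float* xscale,
                             float* yscale, int8_t* qy,
                             int m, int n, int x_stride, int y_stride)
{
    std::vector<float> y(n);
    for(int r = 0; r < m; ++r)
    {
        // Smooth: y = x * xscale, reading row r with the input pitch.
        float absmax = 0.f;
        for(int c = 0; c < n; ++c)
        {
            y[c]   = x[r * x_stride + c] * xscale[c];
            absmax = std::max(absmax, std::fabs(y[c]));
        }
        // Per-row scale, then quantize and write row r with the output pitch.
        yscale[r]       = absmax / 127.f;
        const float inv = yscale[r] > 0.f ? 1.f / yscale[r] : 0.f;
        for(int c = 0; c < n; ++c)
            qy[r * y_stride + c] = static_cast<int8_t>(std::lround(y[c] * inv));
    }
}
```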