#include "ck_tile/host.hpp" #include "smoothquant.hpp" #include // different threshold for different dtype template auto get_elimit() { double rtol = 1e-5; double atol = 1e-5; return ck_tile::make_tuple(rtol, atol); } template <> auto get_elimit() { double rtol = 1e-5; double atol = 1e-5; return ck_tile::make_tuple(rtol, atol); } template <> auto get_elimit() { // due to rounding, int8 quantization might have 1 abs error double rtol = 1; double atol = 1; return ck_tile::make_tuple(rtol, atol); } auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; arg_parser.insert("m", "3328", "m dimension") .insert("n", "4096", "n dimension") .insert("x_stride", "-1", "input stride per row, if -1 then equal to n") .insert("y_stride", "-1", "output stride per row, if -1 then equal to n") .insert("v", "1", "cpu validation or not") .insert("kname", "1", "print kernel name or not") .insert("prec", "fp16", "precision") .insert("warmup", "5", "cold iter") .insert("repeat", "20", "hot iter"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); } template bool run(const ck_tile::ArgParser& arg_parser) { ck_tile::index_t m = arg_parser.get_int("m"); ck_tile::index_t n = arg_parser.get_int("n"); ck_tile::index_t x_stride = arg_parser.get_int("x_stride"); if(x_stride < 0) x_stride = n; ck_tile::index_t y_stride = arg_parser.get_int("y_stride"); if(y_stride < 0) y_stride = n; std::string data_type = arg_parser.get_str("prec"); int kname = arg_parser.get_int("kname"); int do_validation = arg_parser.get_int("v"); int warmup = arg_parser.get_int("warmup"); int repeat = arg_parser.get_int("repeat"); assert(x_stride >= n); using TypeConfig = SmoothquantTypeConfig; using XDataType = typename TypeConfig::XDataType; using XScaleDataType = typename TypeConfig::XScaleDataType; using YScaleDataType = typename TypeConfig::YScaleDataType; using QYDataType = typename TypeConfig::QYDataType; using ComputeDataType = typename TypeConfig::ComputeDataType; // host verify ck_tile::HostTensor x_host({m, n}, {x_stride, 1}); ck_tile::HostTensor xscale_host({n}); ck_tile::HostTensor yscale_host_ref({m}, {1}); ck_tile::HostTensor yscale_host_dev({m}, {1}); ck_tile::HostTensor qy_host_ref({m, n}, {y_stride, 1}); ck_tile::HostTensor qy_host_dev({m, n}, {y_stride, 1}); ck_tile::FillUniformDistribution{-.5f, .5f}(x_host); ck_tile::FillUniformDistribution{1e-3, .5f}(xscale_host); ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem xscale_buf(xscale_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes()); ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes()); x_buf.ToDevice(x_host.data()); xscale_buf.ToDevice(xscale_host.data()); std::cout << "[" << data_type << "]" << " m:" << m << ", n:" << n << ", x_stride:" << x_stride << ", y_stride:" << y_stride << std::flush; smoothquant_traits traits{data_type}; smoothquant_args args{x_buf.GetDeviceBuffer(), xscale_buf.GetDeviceBuffer(), yscale_buf.GetDeviceBuffer(), qy_buf.GetDeviceBuffer(), m, n, x_stride, y_stride}; float ave_time = smoothquant( traits, args, ck_tile::stream_config{nullptr, true, kname ? 
    std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(XScaleDataType) * n +
                           sizeof(YScaleDataType) * m + sizeof(QYDataType) * m * n;

    float gb_per_sec = num_byte / 1.E6 / ave_time; // ave_time is in ms
    std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush;

    bool pass = true;

    if(do_validation)
    {
        using YDataType = ComputeDataType;
        ck_tile::HostTensor<YDataType> y_host({m, n}, {y_stride, 1});

        // smooth outlier: y(m, n) = x(m, n) * xscale(n), parallelized over columns
        {
            auto f = [&](auto n_) {
                auto v_xscale = ck_tile::type_convert<ComputeDataType>(xscale_host(n_));
                for(int m_ = 0; m_ < m; ++m_)
                {
                    auto v_x       = ck_tile::type_convert<ComputeDataType>(x_host(m_, n_));
                    y_host(m_, n_) = v_x * v_xscale;
                }
            };
            ck_tile::make_ParallelTensorFunctor(f, xscale_host.get_element_space_size())(
                std::thread::hardware_concurrency());
        }

        // yscale: per-row quantization scale, yscale(m) = amax(y(m, :)) / int8_max
        {
            ck_tile::HostTensor<YScaleDataType> y_rowwise_amax_host({m});

            using ReduceAmax = ck_tile::ReduceOp::AbsMax;
            ck_tile::reference_reduce<ComputeDataType, ComputeDataType, YScaleDataType>(
                y_host, y_rowwise_amax_host, ReduceAmax{});

            auto op = [](const auto& v0) {
                return v0 /
                       ck_tile::type_convert<ComputeDataType>(ck_tile::numeric<QYDataType>::max());
            };
            ck_tile::reference_unary_elementwise<YScaleDataType, YScaleDataType, ComputeDataType>(
                y_rowwise_amax_host, yscale_host_ref, op);

            yscale_buf.FromDevice(yscale_host_dev.mData.data());

            auto [rtol, atol] = get_elimit<YScaleDataType>();
            pass &= ck_tile::check_err(yscale_host_dev,
                                       yscale_host_ref,
                                       std::string("yscale Error: Incorrect results!"),
                                       rtol,
                                       atol);
        }

        // rowwise quantization: qy(m, n) = int8 cast of y(m, n) / yscale(m)
        {
            ck_tile::reference_rowwise_quantization2d<YDataType, YScaleDataType, QYDataType>(
                y_host, yscale_host_ref, qy_host_ref);

            qy_buf.FromDevice(qy_host_dev.data());
            auto [rtol, atol] = get_elimit<QYDataType>();

            if(y_stride == n)
            {
                pass &= ck_tile::check_err(qy_host_dev,
                                           qy_host_ref,
                                           std::string("qy Error: Incorrect results!"),
                                           rtol,
                                           atol);
            }
            else
            {
                // rows are padded to y_stride; compare only the valid n elements per row
                for(int i_r = 0; i_r < m; i_r++)
                {
                    std::vector<QYDataType> qy_host_dev_row(
                        qy_host_dev.begin() + i_r * y_stride,
                        qy_host_dev.begin() + i_r * y_stride + n);
                    std::vector<QYDataType> qy_host_ref_row(
                        qy_host_ref.begin() + i_r * y_stride,
                        qy_host_ref.begin() + i_r * y_stride + n);
                    pass &= ck_tile::check_err(qy_host_dev_row,
                                               qy_host_ref_row,
                                               std::string("qy[") + std::to_string(i_r) +
                                                   std::string("] Error: Incorrect results!"),
                                               rtol,
                                               atol);
                }
            }
        }

        std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
    }

    return pass;
}

int main(int argc, char* argv[])
{
    auto [result, arg_parser] = create_args(argc, argv);
    if(!result)
        return -1;

    const std::string data_type = arg_parser.get_str("prec");
    if(data_type == "fp16")
    {
        return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
    }
    else if(data_type == "bf16")
    {
        return run<ck_tile::bf16_t>(arg_parser) ? 0 : -2;
    }
    return -3;
}
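// Example invocation, per the arg parser above (the binary name is
// illustrative and depends on your build target):
//   ./tile_example_smoothquant -m=3328 -n=4096 -prec=fp16 -v=1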