Unverified Commit f1e53807 authored by Illia Silin's avatar Illia Silin Committed by GitHub
Browse files

Merge branch 'develop' into ck_host_lib

parents 7450417d d9f1ead3
...@@ -35,7 +35,8 @@ auto create_args(int argc, char* argv[]) ...@@ -35,7 +35,8 @@ auto create_args(int argc, char* argv[])
ck_tile::ArgParser arg_parser; ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "3328", "m dimension") arg_parser.insert("m", "3328", "m dimension")
.insert("n", "4096", "n dimension") .insert("n", "4096", "n dimension")
.insert("stride", "-1", "stride per row, if -1 then equal to n") .insert("x_stride", "-1", "input stride per row, if -1 then equal to n")
.insert("y_stride", "-1", "output stride per row, if -1 then equal to n")
.insert("e", "1e-5", "epsilon") .insert("e", "1e-5", "epsilon")
.insert("v", "1", "cpu validation or not") .insert("v", "1", "cpu validation or not")
.insert("prec", "fp16", "precision") .insert("prec", "fp16", "precision")
...@@ -49,44 +50,47 @@ auto create_args(int argc, char* argv[]) ...@@ -49,44 +50,47 @@ auto create_args(int argc, char* argv[])
template <typename DataType> template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser) bool run(const ck_tile::ArgParser& arg_parser)
{ {
ck_tile::index_t m = arg_parser.get_int("m"); ck_tile::index_t m = arg_parser.get_int("m");
ck_tile::index_t n = arg_parser.get_int("n"); ck_tile::index_t n = arg_parser.get_int("n");
ck_tile::index_t stride = arg_parser.get_int("stride"); ck_tile::index_t x_stride = arg_parser.get_int("x_stride");
if(stride < 0) if(x_stride < 0)
stride = n; x_stride = n;
ck_tile::index_t y_stride = arg_parser.get_int("y_stride");
if(y_stride < 0)
y_stride = n;
std::string data_type = arg_parser.get_str("prec"); std::string data_type = arg_parser.get_str("prec");
int do_validation = arg_parser.get_int("v"); int do_validation = arg_parser.get_int("v");
int warmup = arg_parser.get_int("warmup"); int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat"); int repeat = arg_parser.get_int("repeat");
assert(stride >= n); assert(x_stride >= n);
using XDataType = DataType; using XDataType = DataType;
using XScaleDataType = float; using SmoothScaleDataType = float;
using YScaleDataType = float; using YScaleDataType = float;
using QYDataType = ck_tile::int8_t; using QYDataType = ck_tile::int8_t;
using ComputeDataType = float; using ComputeDataType = float;
// host verify // host verify
ck_tile::HostTensor<XDataType> x_host({m, n}, {stride, 1}); ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1});
ck_tile::HostTensor<XScaleDataType> xscale_host({n}); ck_tile::HostTensor<SmoothScaleDataType> smscale_host({n});
ck_tile::HostTensor<YScaleDataType> yscale_host_ref({m}, {1}); ck_tile::HostTensor<YScaleDataType> yscale_host_ref({m}, {1});
ck_tile::HostTensor<YScaleDataType> yscale_host_dev({m}, {1}); ck_tile::HostTensor<YScaleDataType> yscale_host_dev({m}, {1});
ck_tile::HostTensor<QYDataType> qy_host_ref({m, n}, {stride, 1}); ck_tile::HostTensor<QYDataType> qy_host_ref({m, n}, {y_stride, 1});
ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {stride, 1}); ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {y_stride, 1});
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host); ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
ck_tile::FillUniformDistribution<XScaleDataType>{1e-3, .5f}(xscale_host); ck_tile::FillUniformDistribution<SmoothScaleDataType>{1e-3, .5f}(smscale_host);
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem xscale_buf(xscale_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem smscale_buf(smscale_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes()); ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes()); ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes());
x_buf.ToDevice(x_host.data()); x_buf.ToDevice(x_host.data());
xscale_buf.ToDevice(xscale_host.data()); smscale_buf.ToDevice(smscale_host.data());
constexpr bool kTwoPass = true; constexpr bool kTwoPass = true;
...@@ -97,7 +101,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -97,7 +101,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>; using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
using Problem = ck_tile::SmoothquantPipelineProblem<XDataType, using Problem = ck_tile::SmoothquantPipelineProblem<XDataType,
XScaleDataType, SmoothScaleDataType,
ComputeDataType, ComputeDataType,
YScaleDataType, YScaleDataType,
QYDataType, QYDataType,
...@@ -111,12 +115,13 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -111,12 +115,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
using Kernel = ck_tile::Smoothquant<Pipeline>; using Kernel = ck_tile::Smoothquant<Pipeline>;
ck_tile::SmoothquantHostArgs args{x_buf.GetDeviceBuffer(), ck_tile::SmoothquantHostArgs args{x_buf.GetDeviceBuffer(),
xscale_buf.GetDeviceBuffer(), smscale_buf.GetDeviceBuffer(),
yscale_buf.GetDeviceBuffer(), yscale_buf.GetDeviceBuffer(),
qy_buf.GetDeviceBuffer(), qy_buf.GetDeviceBuffer(),
m, m,
n, n,
stride}; x_stride,
y_stride};
auto kargs = Kernel::MakeKargs(args); auto kargs = Kernel::MakeKargs(args);
...@@ -133,20 +138,20 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -133,20 +138,20 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(do_validation) if(do_validation)
{ {
using YDataType = ComputeDataType; using YDataType = ComputeDataType;
ck_tile::HostTensor<ComputeDataType> y_host({m, n}, {stride, 1}); ck_tile::HostTensor<ComputeDataType> y_host({m, n}, {y_stride, 1});
// smooth outlier // smooth outlier
{ {
auto f = [&](auto n_) { auto f = [&](auto n_) {
auto v_xscale = ck_tile::type_convert<ComputeDataType>(xscale_host(n_)); auto v_smscale = ck_tile::type_convert<ComputeDataType>(smscale_host(n_));
for(int m_ = 0; m_ < m; ++m_) for(int m_ = 0; m_ < m; ++m_)
{ {
auto v_x = ck_tile::type_convert<ComputeDataType>(x_host(m_, n_)); auto v_x = ck_tile::type_convert<ComputeDataType>(x_host(m_, n_));
y_host(m_, n_) = v_x * v_xscale; y_host(m_, n_) = v_x * v_smscale;
} }
}; };
ck_tile::make_ParallelTensorFunctor(f, xscale_host.get_element_space_size())( ck_tile::make_ParallelTensorFunctor(f, smscale_host.get_element_space_size())(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
} }
...@@ -183,7 +188,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -183,7 +188,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
qy_buf.FromDevice(qy_host_dev.data()); qy_buf.FromDevice(qy_host_dev.data());
auto [rtol, atol] = get_elimit<QYDataType>(); auto [rtol, atol] = get_elimit<QYDataType>();
if(stride == n) if(y_stride == n)
{ {
pass = ck_tile::check_err(qy_host_dev, pass = ck_tile::check_err(qy_host_dev,
qy_host_ref, qy_host_ref,
...@@ -195,10 +200,12 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -195,10 +200,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
{ {
for(int i_r = 0; i_r < m; i_r++) for(int i_r = 0; i_r < m; i_r++)
{ {
std::vector<QYDataType> qy_host_dev_row(qy_host_dev.begin() + i_r * stride, std::vector<QYDataType> qy_host_dev_row(qy_host_dev.begin() + i_r * y_stride,
qy_host_dev.begin() + i_r * stride + n); qy_host_dev.begin() + i_r * y_stride +
std::vector<QYDataType> qy_host_ref_row(qy_host_ref.begin() + i_r * stride, n);
qy_host_ref.begin() + i_r * stride + n); std::vector<QYDataType> qy_host_ref_row(qy_host_ref.begin() + i_r * y_stride,
qy_host_ref.begin() + i_r * y_stride +
n);
pass &= ck_tile::check_err(qy_host_dev_row, pass &= ck_tile::check_err(qy_host_dev_row,
qy_host_ref_row, qy_host_ref_row,
std::string("qy[") + std::to_string(i_r) + std::string("qy[") + std::to_string(i_r) +
...@@ -210,8 +217,9 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -210,8 +217,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
} }
std::cout << "[" << data_type << "]" std::cout << "[" << data_type << "]"
<< " m:" << m << ", n:" << n << ", stride:" << stride << " m:" << m << ", n:" << n << ", x_stride:" << x_stride
<< ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; << ", y_stride:" << y_stride << ", valid:" << (pass ? "y" : "n") << std::flush
<< std::endl;
} }
return pass; return pass;
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp> #include <ck_tile/core.hpp>
#include "smoothquant.hpp" #include "smoothquant.hpp"
...@@ -35,7 +35,7 @@ float smoothquant_(const S& s, A a) ...@@ -35,7 +35,7 @@ float smoothquant_(const S& s, A a)
using PipelineProblem = ck_tile::SmoothquantPipelineProblem< using PipelineProblem = ck_tile::SmoothquantPipelineProblem<
typename SmoothquantTypeConfig<DataType>::XDataType, typename SmoothquantTypeConfig<DataType>::XDataType,
typename SmoothquantTypeConfig<DataType>::XScaleDataType, typename SmoothquantTypeConfig<DataType>::SmoothScaleDataType,
typename SmoothquantTypeConfig<DataType>::ComputeDataType, typename SmoothquantTypeConfig<DataType>::ComputeDataType,
typename SmoothquantTypeConfig<DataType>::YScaleDataType, typename SmoothquantTypeConfig<DataType>::YScaleDataType,
typename SmoothquantTypeConfig<DataType>::QYDataType, typename SmoothquantTypeConfig<DataType>::QYDataType,
......
...@@ -33,7 +33,8 @@ auto create_args(int argc, char* argv[]) ...@@ -33,7 +33,8 @@ auto create_args(int argc, char* argv[])
ck_tile::ArgParser arg_parser; ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "3328", "m dimension") arg_parser.insert("m", "3328", "m dimension")
.insert("n", "4096", "n dimension") .insert("n", "4096", "n dimension")
.insert("stride", "-1", "stride per row, if -1 then equal to n") .insert("x_stride", "-1", "input stride per row, if -1 then equal to n")
.insert("y_stride", "-1", "output stride per row, if -1 then equal to n")
.insert("v", "1", "cpu validation or not") .insert("v", "1", "cpu validation or not")
.insert("kname", "1", "print kernel name or not") .insert("kname", "1", "print kernel name or not")
.insert("prec", "fp16", "precision") .insert("prec", "fp16", "precision")
...@@ -47,65 +48,70 @@ auto create_args(int argc, char* argv[]) ...@@ -47,65 +48,70 @@ auto create_args(int argc, char* argv[])
template <typename DataType> template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser) bool run(const ck_tile::ArgParser& arg_parser)
{ {
ck_tile::index_t m = arg_parser.get_int("m"); ck_tile::index_t m = arg_parser.get_int("m");
ck_tile::index_t n = arg_parser.get_int("n"); ck_tile::index_t n = arg_parser.get_int("n");
ck_tile::index_t stride = arg_parser.get_int("stride"); ck_tile::index_t x_stride = arg_parser.get_int("x_stride");
if(stride < 0) if(x_stride < 0)
stride = n; x_stride = n;
ck_tile::index_t y_stride = arg_parser.get_int("y_stride");
if(y_stride < 0)
y_stride = n;
std::string data_type = arg_parser.get_str("prec"); std::string data_type = arg_parser.get_str("prec");
int kname = arg_parser.get_int("kname"); int kname = arg_parser.get_int("kname");
int do_validation = arg_parser.get_int("v"); int do_validation = arg_parser.get_int("v");
int warmup = arg_parser.get_int("warmup"); int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat"); int repeat = arg_parser.get_int("repeat");
assert(stride >= n); assert(x_stride >= n);
using TypeConfig = SmoothquantTypeConfig<DataType>; using TypeConfig = SmoothquantTypeConfig<DataType>;
using XDataType = typename TypeConfig::XDataType; using XDataType = typename TypeConfig::XDataType;
using XScaleDataType = typename TypeConfig::XScaleDataType; using SmoothScaleDataType = typename TypeConfig::SmoothScaleDataType;
using YScaleDataType = typename TypeConfig::YScaleDataType; using YScaleDataType = typename TypeConfig::YScaleDataType;
using QYDataType = typename TypeConfig::QYDataType; using QYDataType = typename TypeConfig::QYDataType;
using ComputeDataType = typename TypeConfig::ComputeDataType; using ComputeDataType = typename TypeConfig::ComputeDataType;
// host verify // host verify
ck_tile::HostTensor<XDataType> x_host({m, n}, {stride, 1}); ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1});
ck_tile::HostTensor<XScaleDataType> xscale_host({n}); ck_tile::HostTensor<SmoothScaleDataType> smscale_host({n});
ck_tile::HostTensor<YScaleDataType> yscale_host_ref({m}, {1}); ck_tile::HostTensor<YScaleDataType> yscale_host_ref({m}, {1});
ck_tile::HostTensor<YScaleDataType> yscale_host_dev({m}, {1}); ck_tile::HostTensor<YScaleDataType> yscale_host_dev({m}, {1});
ck_tile::HostTensor<QYDataType> qy_host_ref({m, n}, {stride, 1}); ck_tile::HostTensor<QYDataType> qy_host_ref({m, n}, {y_stride, 1});
ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {stride, 1}); ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {y_stride, 1});
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host); ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
ck_tile::FillUniformDistribution<XScaleDataType>{1e-3, .5f}(xscale_host); ck_tile::FillUniformDistribution<SmoothScaleDataType>{1e-3, .5f}(smscale_host);
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem xscale_buf(xscale_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem smscale_buf(smscale_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes()); ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes()); ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes());
x_buf.ToDevice(x_host.data()); x_buf.ToDevice(x_host.data());
xscale_buf.ToDevice(xscale_host.data()); smscale_buf.ToDevice(smscale_host.data());
std::cout << "[" << data_type << "]" std::cout << "[" << data_type << "]"
<< " m:" << m << ", n:" << n << ", stride:" << stride << std::flush; << " m:" << m << ", n:" << n << ", x_stride:" << x_stride << ", y_stride:" << y_stride
<< std::flush;
smoothquant_traits traits{data_type}; smoothquant_traits traits{data_type};
smoothquant_args args{x_buf.GetDeviceBuffer(), smoothquant_args args{x_buf.GetDeviceBuffer(),
xscale_buf.GetDeviceBuffer(), smscale_buf.GetDeviceBuffer(),
yscale_buf.GetDeviceBuffer(), yscale_buf.GetDeviceBuffer(),
qy_buf.GetDeviceBuffer(), qy_buf.GetDeviceBuffer(),
m, m,
n, n,
stride}; x_stride,
y_stride};
float ave_time = smoothquant( float ave_time = smoothquant(
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat}); traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(XScaleDataType) * n + std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(SmoothScaleDataType) * n +
sizeof(YScaleDataType) * m + sizeof(QYDataType) * m * n; sizeof(YScaleDataType) * m + sizeof(QYDataType) * m * n;
float gb_per_sec = num_byte / 1.E6 / ave_time; float gb_per_sec = num_byte / 1.E6 / ave_time;
...@@ -116,20 +122,20 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -116,20 +122,20 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(do_validation) if(do_validation)
{ {
using YDataType = ComputeDataType; using YDataType = ComputeDataType;
ck_tile::HostTensor<ComputeDataType> y_host({m, n}, {stride, 1}); ck_tile::HostTensor<ComputeDataType> y_host({m, n}, {y_stride, 1});
// smooth outlier // smooth outlier
{ {
auto f = [&](auto n_) { auto f = [&](auto n_) {
auto v_xscale = ck_tile::type_convert<ComputeDataType>(xscale_host(n_)); auto v_smscale = ck_tile::type_convert<ComputeDataType>(smscale_host(n_));
for(int m_ = 0; m_ < m; ++m_) for(int m_ = 0; m_ < m; ++m_)
{ {
auto v_x = ck_tile::type_convert<ComputeDataType>(x_host(m_, n_)); auto v_x = ck_tile::type_convert<ComputeDataType>(x_host(m_, n_));
y_host(m_, n_) = v_x * v_xscale; y_host(m_, n_) = v_x * v_smscale;
} }
}; };
ck_tile::make_ParallelTensorFunctor(f, xscale_host.get_element_space_size())( ck_tile::make_ParallelTensorFunctor(f, smscale_host.get_element_space_size())(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
} }
...@@ -166,7 +172,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -166,7 +172,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
qy_buf.FromDevice(qy_host_dev.data()); qy_buf.FromDevice(qy_host_dev.data());
auto [rtol, atol] = get_elimit<QYDataType>(); auto [rtol, atol] = get_elimit<QYDataType>();
if(stride == n) if(y_stride == n)
{ {
pass = ck_tile::check_err(qy_host_dev, pass = ck_tile::check_err(qy_host_dev,
qy_host_ref, qy_host_ref,
...@@ -178,10 +184,12 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -178,10 +184,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
{ {
for(int i_r = 0; i_r < m; i_r++) for(int i_r = 0; i_r < m; i_r++)
{ {
std::vector<QYDataType> qy_host_dev_row(qy_host_dev.begin() + i_r * stride, std::vector<QYDataType> qy_host_dev_row(qy_host_dev.begin() + i_r * y_stride,
qy_host_dev.begin() + i_r * stride + n); qy_host_dev.begin() + i_r * y_stride +
std::vector<QYDataType> qy_host_ref_row(qy_host_ref.begin() + i_r * stride, n);
qy_host_ref.begin() + i_r * stride + n); std::vector<QYDataType> qy_host_ref_row(qy_host_ref.begin() + i_r * y_stride,
qy_host_ref.begin() + i_r * y_stride +
n);
pass &= ck_tile::check_err(qy_host_dev_row, pass &= ck_tile::check_err(qy_host_dev_row,
qy_host_ref_row, qy_host_ref_row,
std::string("qy[") + std::to_string(i_r) + std::string("qy[") + std::to_string(i_r) +
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -14,21 +14,21 @@ struct SmoothquantTypeConfig; ...@@ -14,21 +14,21 @@ struct SmoothquantTypeConfig;
template <> template <>
struct SmoothquantTypeConfig<ck_tile::half_t> struct SmoothquantTypeConfig<ck_tile::half_t>
{ {
using XDataType = ck_tile::half_t; using XDataType = ck_tile::half_t;
using XScaleDataType = float; using SmoothScaleDataType = float;
using YScaleDataType = float; using YScaleDataType = float;
using QYDataType = ck_tile::int8_t; using QYDataType = ck_tile::int8_t;
using ComputeDataType = float; using ComputeDataType = float;
}; };
template <> template <>
struct SmoothquantTypeConfig<ck_tile::bf16_t> struct SmoothquantTypeConfig<ck_tile::bf16_t>
{ {
using XDataType = ck_tile::bf16_t; using XDataType = ck_tile::bf16_t;
using XScaleDataType = float; using SmoothScaleDataType = float;
using YScaleDataType = float; using YScaleDataType = float;
using QYDataType = ck_tile::int8_t; using QYDataType = ck_tile::int8_t;
using ComputeDataType = float; using ComputeDataType = float;
}; };
// runtime args // runtime args
......
...@@ -3,18 +3,42 @@ ...@@ -3,18 +3,42 @@
#include "moe_sorting_api.hpp" #include "moe_sorting_api.hpp"
#define MOE_SORTING_DISPATCH(unroll_num_) \ #define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \
constexpr ck_tile::index_t unroll_num = unroll_num_; \ constexpr ck_tile::index_t unroll_num = unroll_num_; \
using ms_problem = ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num>; \ constexpr ck_tile::index_t expert_tile = expert_tile_; \
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \ using ms_problem = \
auto kargs = kernel::MakeKargs(a); \ ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num, expert_tile>; \
const dim3 grids = kernel::GridSize(a); \ using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
const dim3 blocks = kernel::BlockSize(a); \ auto kargs = kernel::MakeKargs(a); \
const auto lds_bytes = kernel::GetSmemSize(a); \ const dim3 grids = kernel::GridSize(a); \
float ave_time = ck_tile::launch_kernel( \ const dim3 blocks = kernel::BlockSize(a); \
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \ const auto lds_bytes = kernel::GetSmemSize(a); \
float ave_time = ck_tile::launch_kernel( \
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time; return ave_time;
#define MOE_SORTING_DISPATCH(unroll_num_) \
if(a.num_experts <= 8) \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 8) \
} \
else if(a.num_experts <= 16) \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 16) \
} \
else if(a.num_experts <= 32) \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 32) \
} \
else if(a.num_experts <= 64) \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 64) \
} \
else \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \
}
float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s) float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s)
{ {
if(t.weight_type == "fp32" && t.index_type == "int32") if(t.weight_type == "fp32" && t.index_type == "int32")
...@@ -49,21 +73,12 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi ...@@ -49,21 +73,12 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
case(6): { case(6): {
MOE_SORTING_DISPATCH(6); MOE_SORTING_DISPATCH(6);
} }
case(7): {
MOE_SORTING_DISPATCH(7);
}
case(8): { case(8): {
MOE_SORTING_DISPATCH(8); MOE_SORTING_DISPATCH(8);
} }
case(9): {
MOE_SORTING_DISPATCH(9);
}
case(10): { case(10): {
MOE_SORTING_DISPATCH(10); MOE_SORTING_DISPATCH(10);
} }
case(11): {
MOE_SORTING_DISPATCH(11);
}
default: { default: {
MOE_SORTING_DISPATCH(4); MOE_SORTING_DISPATCH(4);
} }
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#include <string> #include <string>
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/host.hpp" #include "ck_tile/host.hpp"
#include "ck_tile/ops/moe_sorting.hpp" #include "ck_tile/ops/fused_moe.hpp"
struct moe_sorting_trait struct moe_sorting_trait
{ {
......
...@@ -16,4 +16,5 @@ $EXE -t=127 -e=99 -k=19 ...@@ -16,4 +16,5 @@ $EXE -t=127 -e=99 -k=19
$EXE -t=71 -e=11 -k=11 $EXE -t=71 -e=11 -k=11
$EXE -t=1 -e=1 -k=1 $EXE -t=1 -e=1 -k=1
$EXE -t=99 -e=2 -k=1 $EXE -t=99 -e=2 -k=1
$EXE -t=333 -e=99 -k=13 $EXE -t=333 -e=99 -k=13
\ No newline at end of file $EXE -t=128 -e=32 -k=5 -moe_buf_size=262144
# Registers an opt-in executable target for the moe-smoothquant example.
#
# Arguments:
#   TARGET_NAME - name of the executable target to create
#   MAIN_SRC    - source file handed to add_executable()
#   ARGN        - any further sources (e.g. kernel instance files) attached to the target
function (add_moe_smoothquant_example TARGET_NAME MAIN_SRC)
message("adding ${TARGET_NAME}")
# not using add_example_executable() to add target, since we don't want this to have
# to be included in "make all/install/check"
add_executable(${TARGET_NAME} EXCLUDE_FROM_ALL ${MAIN_SRC})
target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
# A function body starts with a copy of the caller's variables. Appending ARGN to a
# name the caller also uses (INSTANCE_SRCS) would therefore list every caller-provided
# source twice, so the extra sources are gathered in an explicitly (re)set local instead.
set(_moe_sq_extra_srcs ${ARGN})
if(NOT "${_moe_sq_extra_srcs}" STREQUAL "")
target_sources(${TARGET_NAME} PRIVATE ${_moe_sq_extra_srcs})
endif()
set(COMPILE_OPTIONS)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
list(APPEND COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
# list(APPEND COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
target_compile_options(${TARGET_NAME} PRIVATE ${COMPILE_OPTIONS})
endfunction()
# Collect every .cpp file under instances/ and build the example executable from
# moe_smoothquant.cpp plus those files.
# NOTE(review): file(GLOB) without CONFIGURE_DEPENDS is only evaluated when CMake is
# (re)configured, so newly added instance files need a CMake re-run to be picked up.
file(GLOB INSTANCE_SRCS instances/*.cpp)
add_moe_smoothquant_example(tile_example_moe_smoothquant moe_smoothquant.cpp ${INSTANCE_SRCS})
# moe-smoothquant
This folder contains an example of moe-smoothquant using the ck_tile tile-programming implementation.
![](misc/moe-sm.png)
Unlike the standard smoothquant op, the input scale comes from different experts (`[expert, hidden]`), so we need to reuse the `topk-id` from the previous `topk-softmax`, select the corresponding `expert` for the current topk, and expand the output/per-token-scale by `topk`.
## build
```
# in the root of ck_tile
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
make tile_example_moe_smoothquant -j
```
This will result in an executable `build/bin/tile_example_moe_smoothquant`
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe_smoothquant_instance_common.hpp"
// Explicit instantiations of moe_smoothquant_ for trait_ specializations whose first type
// argument is ck_tile::bf16_t and whose second is ck_tile::int8_t or ck_tile::fp8_t.
// clang-format off
// Column legend for the numeric/bool trait_ arguments that follow the type arguments.
// NOTE(review): presumably repeat_m, repeat_n, thread_m, thread_n, vector_n, pad, two-pass
// -- confirm against the trait_ definition.
// rm rn tm tn vn pd 2p
// Disabled instantiations. NOTE(review): these pass one type argument fewer to trait_ than
// the live instantiations below, so they would presumably need updating before re-enabling.
#if 0
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 8, 4, 64, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 16, 4, 64, 1, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 4, true, false>>(const S&, A);
#endif
// ck_tile::int8_t variants
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 1, 2, 128, 8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 2, 2, 128, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 4, 2, 128, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 4, 1, 256, 1, true, false>>(const S&, A);
// ck_tile::fp8_t variants (same numeric configurations as the int8_t group above)
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 1, 2, 128, 8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 2, 2, 128, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 4, 2, 128, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 4, 1, 256, 1, true, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe_smoothquant_instance_common.hpp"
// Explicit instantiations of moe_smoothquant_ for trait_ specializations whose first type
// argument is ck_tile::bf16_t and whose second is ck_tile::int8_t or ck_tile::fp8_t.
// clang-format off
// Column legend for the numeric/bool trait_ arguments that follow the type arguments.
// NOTE(review): presumably repeat_m, repeat_n, thread_m, thread_n, vector_n, pad, two-pass
// -- confirm against the trait_ definition.
// rm rn tm tn vn pd 2p
// ck_tile::int8_t variants
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 3, 4, 64, 8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 3, 2, 128, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 3, 1, 256, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 6, 1, 256, 1, true, false>>(const S&, A);
// ck_tile::fp8_t variants (same numeric configurations as the int8_t group above)
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 3, 4, 64, 8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 3, 2, 128, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 3, 1, 256, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 6, 1, 256, 1, true, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe_smoothquant_instance_common.hpp"
// Explicit instantiations of moe_smoothquant_ for trait_ specializations whose first type
// argument is ck_tile::bf16_t and whose second is ck_tile::int8_t or ck_tile::fp8_t.
// clang-format off
// Column legend for the numeric/bool trait_ arguments that follow the type arguments.
// NOTE(review): presumably repeat_m, repeat_n, thread_m, thread_n, vector_n, pad, two-pass
// -- confirm against the trait_ definition.
// rm rn tm tn vn pd 2p
// ck_tile::int8_t variants
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 1, 1, 256, 8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 2, 1, 256, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 4, 1, 256, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 8, 1, 256, 1, true, false>>(const S&, A);
// ck_tile::fp8_t variants (same numeric configurations as the int8_t group above)
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 1, 1, 256, 8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 2, 1, 256, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 4, 1, 256, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 8, 1, 256, 1, true, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe_smoothquant_instance_common.hpp"
// Explicit instantiations of moe_smoothquant_ for trait_ specializations whose first type
// argument is ck_tile::bf16_t and whose second is ck_tile::int8_t or ck_tile::fp8_t.
// clang-format off
// Column legend for the numeric/bool trait_ arguments that follow the type arguments.
// NOTE(review): presumably repeat_m, repeat_n, thread_m, thread_n, vector_n, pad, two-pass
// -- confirm against the trait_ definition.
// rm rn tm tn vn pd 2p
// ck_tile::int8_t variants
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 1, 4, 64, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 2, 4, 64, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 4, 4, 64, 1, true, false>>(const S&, A);
// ck_tile::fp8_t variants (same numeric configurations as the int8_t group above)
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 1, 4, 64, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 2, 4, 64, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 4, 4, 64, 1, true, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe_smoothquant_instance_common.hpp"
// Explicit instantiations of moe_smoothquant_ for trait_ specializations whose first type
// argument is ck_tile::bf16_t and whose second is ck_tile::int8_t or ck_tile::fp8_t.
// clang-format off
// Column legend for the numeric/bool trait_ arguments that follow the type arguments.
// NOTE(review): presumably repeat_m, repeat_n, thread_m, thread_n, vector_n, pad, two-pass
// -- confirm against the trait_ definition.
// rm rn tm tn vn pd 2p
// ck_tile::int8_t variants
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 3, 1, 128, 8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 3, 1, 256, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 6, 1, 256, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 3, 1, 1024, 1, true, false>>(const S&, A);
// ck_tile::fp8_t variants (same numeric configurations as the int8_t group above)
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 3, 1, 128, 8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 3, 1, 256, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 6, 1, 256, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 3, 1, 1024, 1, true, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe_smoothquant_instance_common.hpp"
// Explicit instantiations of moe_smoothquant_ for trait_ specializations whose first type
// argument is ck_tile::bf16_t and whose second is ck_tile::int8_t or ck_tile::fp8_t.
// clang-format off
// Column legend for the numeric/bool trait_ arguments that follow the type arguments.
// NOTE(review): presumably repeat_m, repeat_n, thread_m, thread_n, vector_n, pad, two-pass
// -- confirm against the trait_ definition.
// rm rn tm tn vn pd 2p
// ck_tile::int8_t variants
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 2, 1, 256, 8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 4, 1, 256, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 2, 1, 1024, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 4, 1, 1024, 1, true, false>>(const S&, A);
// ck_tile::fp8_t variants (same numeric configurations as the int8_t group above)
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 2, 1, 256, 8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 4, 1, 256, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 2, 1, 1024, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 4, 1, 1024, 1, true, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe_smoothquant_instance_common.hpp"
// Explicit instantiations of moe_smoothquant_ for trait_ specializations whose first type
// argument is ck_tile::bf16_t and whose second is ck_tile::int8_t or ck_tile::fp8_t.
// Every instantiation in this list sets the final trait_ flag (the `2p` column) to true.
// clang-format off
// Column legend for the numeric/bool trait_ arguments that follow the type arguments.
// NOTE(review): presumably repeat_m, repeat_n, thread_m, thread_n, vector_n, pad, two-pass
// -- confirm against the trait_ definition.
// rm rn tm tn vn pd 2p
// ck_tile::int8_t variants
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 2, 1, 256, 8, true, true>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 4, 1, 256, 4, true, true>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 2, 1, 1024, 2, true, true>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 4, 1, 1024, 1, true, true>>(const S&, A);
// ck_tile::fp8_t variants (same numeric configurations as the int8_t group above)
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 2, 1, 256, 8, true, true>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 4, 1, 256, 4, true, true>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 2, 1, 1024, 2, true, true>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 4, 1, 1024, 1, true, true>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe_smoothquant_instance_common.hpp"
// Explicit instantiation definitions of moe_smoothquant_ for trait_
// configurations whose first type argument is ck_tile::bf16_t, with tm = 4 and
// tn = 64. Every shape is instantiated twice: once with ck_tile::int8_t and
// once with ck_tile::fp8_t as the second type argument. The column header
// labels the seven trailing trait_ arguments (presumably repeat/thread/vector
// sizes along M and N, a padding flag and a two-pass flag -- TODO confirm
// against trait_).
// clang-format off
//                                                                     rm rn  tm tn    vn pd    2p
// second type argument: ck_tile::int8_t
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 1,  4, 64,   8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 2,  4, 64,   4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 4,  4, 64,   2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 8,  4, 64,   1, true, false>>(const S&, A);
// second type argument: ck_tile::fp8_t
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t,  1, 1,  4, 64,   8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t,  1, 2,  4, 64,   4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t,  1, 4,  4, 64,   2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t,  1, 8,  4, 64,   1, true, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe_smoothquant_instance_common.hpp"
// Explicit instantiation definitions of moe_smoothquant_ for trait_
// configurations whose first type argument is ck_tile::bf16_t, with tm = 4,
// tn = 64 and vn of 1 or 2. Every shape is instantiated twice: once with
// ck_tile::int8_t and once with ck_tile::fp8_t as the second type argument.
// The column header labels the seven trailing trait_ arguments (presumably
// repeat/thread/vector sizes along M and N, a padding flag and a two-pass
// flag -- TODO confirm against trait_).
// clang-format off
//                                                                     rm rn  tm tn    vn pd    2p
// second type argument: ck_tile::int8_t
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 1,  4, 64,   1, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 1,  4, 64,   2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 2,  4, 64,   1, true, false>>(const S&, A);
// second type argument: ck_tile::fp8_t
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t,  1, 1,  4, 64,   1, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t,  1, 1,  4, 64,   2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t,  1, 2,  4, 64,   1, true, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe_smoothquant_instance_common.hpp"
// Explicit instantiation definitions of moe_smoothquant_ for trait_
// configurations whose first type argument is ck_tile::bf16_t, with tm = 4,
// tn = 64 and rn of 3, 6 or 12. Every shape is instantiated twice: once with
// ck_tile::int8_t and once with ck_tile::fp8_t as the second type argument.
// The column header labels the seven trailing trait_ arguments (presumably
// repeat/thread/vector sizes along M and N, a padding flag and a two-pass
// flag -- TODO confirm against trait_).
// clang-format off
//                                                                     rm rn  tm tn    vn pd    2p
// second type argument: ck_tile::int8_t
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 3,  4, 64,   4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 6,  4, 64,   2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 12, 4, 64,   1, true, false>>(const S&, A);
// second type argument: ck_tile::fp8_t
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t,  1, 3,  4, 64,   4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t,  1, 6,  4, 64,   2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t,  1, 12, 4, 64,   1, true, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe_smoothquant_instance_common.hpp"
// Explicit instantiation definitions of moe_smoothquant_ for trait_
// configurations whose first type argument is ck_tile::fp16_t. Every shape is
// instantiated twice: once with ck_tile::int8_t and once with ck_tile::fp8_t
// as the second type argument. The column header labels the seven trailing
// trait_ arguments (presumably repeat/thread/vector sizes along M and N, a
// padding flag and a two-pass flag -- TODO confirm against trait_).
// clang-format off
//                                                                     rm rn  tm tn    vn pd    2p
// second type argument: ck_tile::int8_t
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 1,  2, 128,  8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 2,  2, 128,  4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4,  2, 128,  2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4,  1, 256,  1, true, false>>(const S&, A);
// second type argument: ck_tile::fp8_t
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::fp8_t,  1, 1,  2, 128,  8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::fp8_t,  1, 2,  2, 128,  4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::fp8_t,  1, 4,  2, 128,  2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::fp8_t,  1, 4,  1, 256,  1, true, false>>(const S&, A);
// clang-format on
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment