Commit b74918bc authored by ThomasNing's avatar ThomasNing
Browse files

compiled version of cross gpu connection

parents 3fcad951 1c45ca35
...@@ -14,12 +14,34 @@ ...@@ -14,12 +14,34 @@
#include "ck_tile/host.hpp" #include "ck_tile/host.hpp"
#include "gemm_basic.hpp" #include "gemm_basic.hpp"
#define CK_TILE_PIPELINE_COMPUTE 1
#define CK_TILE_PIPELINE_MEMORY 2
#ifndef CK_TILE_PIPELINE_DEFAULT
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE
#endif
template <typename ALayout, typename BLayout, typename CLayout> template <typename ALayout, typename BLayout, typename CLayout>
float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
{ {
// ToDo: This will be modified by the codegen code later. #if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
// Memory friendly for Interwave scheduler
constexpr ck_tile::index_t M_Tile = 128; constexpr ck_tile::index_t M_Tile = 128;
constexpr ck_tile::index_t N_Tile = 128; constexpr ck_tile::index_t N_Tile = 32;
constexpr ck_tile::index_t K_Tile = 64;
constexpr ck_tile::index_t M_Warp = 4;
constexpr ck_tile::index_t N_Warp = 1;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8;
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
// Compute friendly for Intrawave scheduler
constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile = 32; constexpr ck_tile::index_t K_Tile = 32;
constexpr ck_tile::index_t M_Warp = 2; constexpr ck_tile::index_t M_Warp = 2;
...@@ -28,12 +50,12 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -28,12 +50,12 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
constexpr ck_tile::index_t M_Warp_Tile = 32; constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8; constexpr ck_tile::index_t K_Warp_Tile = 16;
#endif
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part. constexpr bool kPadM = false;
constexpr bool kPadA = true; constexpr bool kPadN = false;
constexpr bool kPadB = true; constexpr bool kPadK = false;
constexpr bool kPadC = true;
constexpr int kBlockPerCu = 1; constexpr int kBlockPerCu = 1;
...@@ -46,11 +68,14 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -46,11 +68,14 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
using TilePartitioner = ck_tile::GemmTilePartitioner<GemmShape>; using TilePartitioner = ck_tile::GemmTilePartitioner<GemmShape>;
using GemmEpilogue = ck_tile::Default2DEpilogue< using GemmEpilogue = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, false, kPadC>>; ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>;
using Traits = ck_tile::TileGemmTraits<kPadA, kPadB, kPadC, ALayout, BLayout, CLayout>;
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem< using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<
#endif
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>; ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K); const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K);
...@@ -63,13 +88,21 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -63,13 +88,21 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
constexpr bool has_hot_loop_v = has_hot_loop_.value; constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value; constexpr auto tail_number_v = tail_number_.value;
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem< using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<
#endif
ck_tile::UniversalGemmPipelineProblem<ADataType, ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType, BDataType,
AccDataType, AccDataType,
GemmShape, GemmShape,
Traits, Traits,
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
ck_tile::GemmPipelineScheduler::Interwave,
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
ck_tile::GemmPipelineScheduler::Intrawave, ck_tile::GemmPipelineScheduler::Intrawave,
#endif
has_hot_loop_v, has_hot_loop_v,
tail_number_v>>; tail_number_v>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>; using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
...@@ -86,6 +119,11 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -86,6 +119,11 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch); const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch);
constexpr dim3 blocks = Kernel::BlockSize(); constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0) if(s.log_level_ > 0)
{ {
std::cout << "Launching kernel with args:" std::cout << "Launching kernel with args:"
...@@ -174,8 +212,8 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -174,8 +212,8 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
{ {
std::ostringstream err; std::ostringstream err;
err << "When there's no hot loop, this tail number \"" << tail_num err << "When there's no hot loop, this tail number \"" << tail_num
<< "\" is not supported! " << __FILE__ << ":" << __LINE__ << "\" is not supported! PrefetchStages: " << BaseGemmPipeline::PrefetchStages
<< ", in function: " << __func__; << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str()); throw std::runtime_error(err.str());
} }
} }
......
...@@ -40,7 +40,7 @@ float matrix_core_swizzle(matrix_core_swizzle_traits t, ...@@ -40,7 +40,7 @@ float matrix_core_swizzle(matrix_core_swizzle_traits t,
else if(t.permute.compare("0,1,3,4,2,5") == 0) else if(t.permute.compare("0,1,3,4,2,5") == 0)
{ {
constexpr matrix_core_permute_style pstyle = constexpr matrix_core_permute_style pstyle =
matrix_core_permute_style::permute_b_nr_kr_kw_nw_kv; matrix_core_permute_style::b_nr_kr_kw_nw_kv;
using Kernel = using Kernel =
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>; matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
...@@ -83,7 +83,7 @@ float matrix_core_swizzle(matrix_core_swizzle_traits t, ...@@ -83,7 +83,7 @@ float matrix_core_swizzle(matrix_core_swizzle_traits t,
else if(t.permute.compare("0,1,3,4,2,5") == 0) else if(t.permute.compare("0,1,3,4,2,5") == 0)
{ {
constexpr matrix_core_permute_style pstyle = constexpr matrix_core_permute_style pstyle =
matrix_core_permute_style::permute_b_nr_kr_kw_nw_kv; matrix_core_permute_style::b_nr_kr_kw_nw_kv;
using Kernel = using Kernel =
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>; matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
......
...@@ -42,8 +42,8 @@ enum class matrix_core_permute_style ...@@ -42,8 +42,8 @@ enum class matrix_core_permute_style
{ {
permute_b_n0_k0_n1_k1_n2_k2 = 0, // 0,1,4,2,5,3,6 permute_b_n0_k0_n1_k1_n2_k2 = 0, // 0,1,4,2,5,3,6
permute_b_n0_n1_k0_k1_n2_k2 = 1, // 0,1,2,4,5,3,6 permute_b_n0_n1_k0_k1_n2_k2 = 1, // 0,1,2,4,5,3,6
permute_b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5 b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv, b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv,
}; };
// assume this is B matrix, originally we have batch*n*k // assume this is B matrix, originally we have batch*n*k
...@@ -203,7 +203,7 @@ struct matrix_core_swizzle_kernel ...@@ -203,7 +203,7 @@ struct matrix_core_swizzle_kernel
else else
{ {
// clang-format off // clang-format off
// permute_b_nr_kr_kw_nw_kv or permute_b_nr_kr_waveflatten // b_nr_kr_kw_nw_kv or b_nr_kr_waveflatten
constexpr index_t Kv = Alignment; constexpr index_t Kv = Alignment;
constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane; constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane; constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
...@@ -332,7 +332,7 @@ struct matrix_core_swizzle_kernel ...@@ -332,7 +332,7 @@ struct matrix_core_swizzle_kernel
make_tuple(sequence<0>{}, sequence<1>{})); make_tuple(sequence<0>{}, sequence<1>{}));
return tmp_1; return tmp_1;
#else #else
// permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv, // b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv,
constexpr index_t kv = Alignment; constexpr index_t kv = Alignment;
constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane; constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane; constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
...@@ -376,13 +376,13 @@ struct matrix_core_swizzle_kernel ...@@ -376,13 +376,13 @@ struct matrix_core_swizzle_kernel
else else
{ {
#if MERGE_2D_013425 #if MERGE_2D_013425
// permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv // b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv
return make_tile_window(dst_view, return make_tile_window(dst_view,
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
{i_n * NPerBlock, i_k * KPerBlock}, {i_n * NPerBlock, i_k * KPerBlock},
get_dst_dist()); get_dst_dist());
#else #else
// permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv // b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv
constexpr index_t kv = Alignment; constexpr index_t kv = Alignment;
constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane; constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane; constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
......
...@@ -264,7 +264,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -264,7 +264,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
{ {
if(arg_parser.get_str("perm") == std::string("0,1,3,4,2,5")) if(arg_parser.get_str("perm") == std::string("0,1,3,4,2,5"))
{ {
// permute_b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5 // b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
matrix_core_swizzle_traits t; matrix_core_swizzle_traits t;
t.data_type = data_type; t.data_type = data_type;
t.permute = arg_parser.get_str("perm"); t.permute = arg_parser.get_str("perm");
......
...@@ -18,7 +18,7 @@ function (add_smoothquant_example TARGET_NAME MAIN_SRC) ...@@ -18,7 +18,7 @@ function (add_smoothquant_example TARGET_NAME MAIN_SRC)
target_compile_options(${TARGET_NAME} PRIVATE ${COMPILE_OPTIONS}) target_compile_options(${TARGET_NAME} PRIVATE ${COMPILE_OPTIONS})
endfunction(add_smoothquant_example TARGET_NAME MAIN_SRC) endfunction(add_smoothquant_example TARGET_NAME MAIN_SRC)
file(GLOB INSTANCE_SRCS instances/*.cpp)
add_smoothquant_example(tile_smoothquant smoothquant.cpp ${INSTANCE_SRCS})
add_smoothquant_example(tile_example_smoothquant example_smoothquant.cpp) add_smoothquant_example(tile_example_smoothquant example_smoothquant.cpp)
file(GLOB INSTANCE_SRCS instances/*.cpp)
add_smoothquant_example(tile_smoothquant smoothquant.cpp ${INSTANCE_SRCS})
...@@ -35,7 +35,8 @@ auto create_args(int argc, char* argv[]) ...@@ -35,7 +35,8 @@ auto create_args(int argc, char* argv[])
ck_tile::ArgParser arg_parser; ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "3328", "m dimension") arg_parser.insert("m", "3328", "m dimension")
.insert("n", "4096", "n dimension") .insert("n", "4096", "n dimension")
.insert("stride", "-1", "stride per row, if -1 then equal to n") .insert("x_stride", "-1", "input stride per row, if -1 then equal to n")
.insert("y_stride", "-1", "output stride per row, if -1 then equal to n")
.insert("e", "1e-5", "epsilon") .insert("e", "1e-5", "epsilon")
.insert("v", "1", "cpu validation or not") .insert("v", "1", "cpu validation or not")
.insert("prec", "fp16", "precision") .insert("prec", "fp16", "precision")
...@@ -49,11 +50,14 @@ auto create_args(int argc, char* argv[]) ...@@ -49,11 +50,14 @@ auto create_args(int argc, char* argv[])
template <typename DataType> template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser) bool run(const ck_tile::ArgParser& arg_parser)
{ {
ck_tile::index_t m = arg_parser.get_int("m"); ck_tile::index_t m = arg_parser.get_int("m");
ck_tile::index_t n = arg_parser.get_int("n"); ck_tile::index_t n = arg_parser.get_int("n");
ck_tile::index_t stride = arg_parser.get_int("stride"); ck_tile::index_t x_stride = arg_parser.get_int("x_stride");
if(stride < 0) if(x_stride < 0)
stride = n; x_stride = n;
ck_tile::index_t y_stride = arg_parser.get_int("y_stride");
if(y_stride < 0)
y_stride = n;
std::string data_type = arg_parser.get_str("prec"); std::string data_type = arg_parser.get_str("prec");
int do_validation = arg_parser.get_int("v"); int do_validation = arg_parser.get_int("v");
int warmup = arg_parser.get_int("warmup"); int warmup = arg_parser.get_int("warmup");
...@@ -68,14 +72,14 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -68,14 +72,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
using ComputeDataType = float; using ComputeDataType = float;
// host verify // host verify
ck_tile::HostTensor<XDataType> x_host({m, n}, {stride, 1}); ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1});
ck_tile::HostTensor<XScaleDataType> xscale_host({n}); ck_tile::HostTensor<XScaleDataType> xscale_host({n});
ck_tile::HostTensor<YScaleDataType> yscale_host_ref({m}, {1}); ck_tile::HostTensor<YScaleDataType> yscale_host_ref({m}, {1});
ck_tile::HostTensor<YScaleDataType> yscale_host_dev({m}, {1}); ck_tile::HostTensor<YScaleDataType> yscale_host_dev({m}, {1});
ck_tile::HostTensor<QYDataType> qy_host_ref({m, n}, {stride, 1}); ck_tile::HostTensor<QYDataType> qy_host_ref({m, n}, {y_stride, 1});
ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {stride, 1}); ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {y_stride, 1});
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host); ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
ck_tile::FillUniformDistribution<XScaleDataType>{1e-3, .5f}(xscale_host); ck_tile::FillUniformDistribution<XScaleDataType>{1e-3, .5f}(xscale_host);
...@@ -116,7 +120,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -116,7 +120,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
qy_buf.GetDeviceBuffer(), qy_buf.GetDeviceBuffer(),
m, m,
n, n,
stride}; x_stride,
y_stride};
auto kargs = Kernel::MakeKargs(args); auto kargs = Kernel::MakeKargs(args);
...@@ -133,7 +138,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -133,7 +138,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(do_validation) if(do_validation)
{ {
using YDataType = ComputeDataType; using YDataType = ComputeDataType;
ck_tile::HostTensor<ComputeDataType> y_host({m, n}, {stride, 1}); ck_tile::HostTensor<ComputeDataType> y_host({m, n}, {y_stride, 1});
// smooth outlier // smooth outlier
{ {
auto f = [&](auto n_) { auto f = [&](auto n_) {
...@@ -183,7 +188,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -183,7 +188,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
qy_buf.FromDevice(qy_host_dev.data()); qy_buf.FromDevice(qy_host_dev.data());
auto [rtol, atol] = get_elimit<QYDataType>(); auto [rtol, atol] = get_elimit<QYDataType>();
if(stride == n) if(y_stride == n)
{ {
pass = ck_tile::check_err(qy_host_dev, pass = ck_tile::check_err(qy_host_dev,
qy_host_ref, qy_host_ref,
...@@ -195,10 +200,12 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -195,10 +200,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
{ {
for(int i_r = 0; i_r < m; i_r++) for(int i_r = 0; i_r < m; i_r++)
{ {
std::vector<QYDataType> qy_host_dev_row(qy_host_dev.begin() + i_r * stride, std::vector<QYDataType> qy_host_dev_row(qy_host_dev.begin() + i_r * y_stride,
qy_host_dev.begin() + i_r * stride + n); qy_host_dev.begin() + i_r * y_stride +
std::vector<QYDataType> qy_host_ref_row(qy_host_ref.begin() + i_r * stride, n);
qy_host_ref.begin() + i_r * stride + n); std::vector<QYDataType> qy_host_ref_row(qy_host_ref.begin() + i_r * y_stride,
qy_host_ref.begin() + i_r * y_stride +
n);
pass &= ck_tile::check_err(qy_host_dev_row, pass &= ck_tile::check_err(qy_host_dev_row,
qy_host_ref_row, qy_host_ref_row,
std::string("qy[") + std::to_string(i_r) + std::string("qy[") + std::to_string(i_r) +
...@@ -210,8 +217,9 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -210,8 +217,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
} }
std::cout << "[" << data_type << "]" std::cout << "[" << data_type << "]"
<< " m:" << m << ", n:" << n << ", stride:" << stride << " m:" << m << ", n:" << n << ", x_stride:" << x_stride
<< ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; << ", y_stride:" << y_stride << ", valid:" << (pass ? "y" : "n") << std::flush
<< std::endl;
} }
return pass; return pass;
......
...@@ -33,7 +33,8 @@ auto create_args(int argc, char* argv[]) ...@@ -33,7 +33,8 @@ auto create_args(int argc, char* argv[])
ck_tile::ArgParser arg_parser; ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "3328", "m dimension") arg_parser.insert("m", "3328", "m dimension")
.insert("n", "4096", "n dimension") .insert("n", "4096", "n dimension")
.insert("stride", "-1", "stride per row, if -1 then equal to n") .insert("x_stride", "-1", "input stride per row, if -1 then equal to n")
.insert("y_stride", "-1", "output stride per row, if -1 then equal to n")
.insert("v", "1", "cpu validation or not") .insert("v", "1", "cpu validation or not")
.insert("kname", "1", "print kernel name or not") .insert("kname", "1", "print kernel name or not")
.insert("prec", "fp16", "precision") .insert("prec", "fp16", "precision")
...@@ -47,18 +48,21 @@ auto create_args(int argc, char* argv[]) ...@@ -47,18 +48,21 @@ auto create_args(int argc, char* argv[])
template <typename DataType> template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser) bool run(const ck_tile::ArgParser& arg_parser)
{ {
ck_tile::index_t m = arg_parser.get_int("m"); ck_tile::index_t m = arg_parser.get_int("m");
ck_tile::index_t n = arg_parser.get_int("n"); ck_tile::index_t n = arg_parser.get_int("n");
ck_tile::index_t stride = arg_parser.get_int("stride"); ck_tile::index_t x_stride = arg_parser.get_int("x_stride");
if(stride < 0) if(x_stride < 0)
stride = n; x_stride = n;
ck_tile::index_t y_stride = arg_parser.get_int("y_stride");
if(y_stride < 0)
y_stride = n;
std::string data_type = arg_parser.get_str("prec"); std::string data_type = arg_parser.get_str("prec");
int kname = arg_parser.get_int("kname"); int kname = arg_parser.get_int("kname");
int do_validation = arg_parser.get_int("v"); int do_validation = arg_parser.get_int("v");
int warmup = arg_parser.get_int("warmup"); int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat"); int repeat = arg_parser.get_int("repeat");
assert(stride >= n); assert(x_stride >= n);
using TypeConfig = SmoothquantTypeConfig<DataType>; using TypeConfig = SmoothquantTypeConfig<DataType>;
...@@ -69,14 +73,14 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -69,14 +73,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
using ComputeDataType = typename TypeConfig::ComputeDataType; using ComputeDataType = typename TypeConfig::ComputeDataType;
// host verify // host verify
ck_tile::HostTensor<XDataType> x_host({m, n}, {stride, 1}); ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1});
ck_tile::HostTensor<XScaleDataType> xscale_host({n}); ck_tile::HostTensor<XScaleDataType> xscale_host({n});
ck_tile::HostTensor<YScaleDataType> yscale_host_ref({m}, {1}); ck_tile::HostTensor<YScaleDataType> yscale_host_ref({m}, {1});
ck_tile::HostTensor<YScaleDataType> yscale_host_dev({m}, {1}); ck_tile::HostTensor<YScaleDataType> yscale_host_dev({m}, {1});
ck_tile::HostTensor<QYDataType> qy_host_ref({m, n}, {stride, 1}); ck_tile::HostTensor<QYDataType> qy_host_ref({m, n}, {y_stride, 1});
ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {stride, 1}); ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {y_stride, 1});
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host); ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
ck_tile::FillUniformDistribution<XScaleDataType>{1e-3, .5f}(xscale_host); ck_tile::FillUniformDistribution<XScaleDataType>{1e-3, .5f}(xscale_host);
...@@ -90,7 +94,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -90,7 +94,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
xscale_buf.ToDevice(xscale_host.data()); xscale_buf.ToDevice(xscale_host.data());
std::cout << "[" << data_type << "]" std::cout << "[" << data_type << "]"
<< " m:" << m << ", n:" << n << ", stride:" << stride << std::flush; << " m:" << m << ", n:" << n << ", x_stride:" << x_stride << ", y_stride:" << y_stride
<< std::flush;
smoothquant_traits traits{data_type}; smoothquant_traits traits{data_type};
...@@ -100,7 +105,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -100,7 +105,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
qy_buf.GetDeviceBuffer(), qy_buf.GetDeviceBuffer(),
m, m,
n, n,
stride}; x_stride,
y_stride};
float ave_time = smoothquant( float ave_time = smoothquant(
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat}); traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
...@@ -116,7 +122,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -116,7 +122,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(do_validation) if(do_validation)
{ {
using YDataType = ComputeDataType; using YDataType = ComputeDataType;
ck_tile::HostTensor<ComputeDataType> y_host({m, n}, {stride, 1}); ck_tile::HostTensor<ComputeDataType> y_host({m, n}, {y_stride, 1});
// smooth outlier // smooth outlier
{ {
auto f = [&](auto n_) { auto f = [&](auto n_) {
...@@ -166,7 +172,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -166,7 +172,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
qy_buf.FromDevice(qy_host_dev.data()); qy_buf.FromDevice(qy_host_dev.data());
auto [rtol, atol] = get_elimit<QYDataType>(); auto [rtol, atol] = get_elimit<QYDataType>();
if(stride == n) if(y_stride == n)
{ {
pass = ck_tile::check_err(qy_host_dev, pass = ck_tile::check_err(qy_host_dev,
qy_host_ref, qy_host_ref,
...@@ -178,10 +184,12 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -178,10 +184,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
{ {
for(int i_r = 0; i_r < m; i_r++) for(int i_r = 0; i_r < m; i_r++)
{ {
std::vector<QYDataType> qy_host_dev_row(qy_host_dev.begin() + i_r * stride, std::vector<QYDataType> qy_host_dev_row(qy_host_dev.begin() + i_r * y_stride,
qy_host_dev.begin() + i_r * stride + n); qy_host_dev.begin() + i_r * y_stride +
std::vector<QYDataType> qy_host_ref_row(qy_host_ref.begin() + i_r * stride, n);
qy_host_ref.begin() + i_r * stride + n); std::vector<QYDataType> qy_host_ref_row(qy_host_ref.begin() + i_r * y_stride,
qy_host_ref.begin() + i_r * y_stride +
n);
pass &= ck_tile::check_err(qy_host_dev_row, pass &= ck_tile::check_err(qy_host_dev_row,
qy_host_ref_row, qy_host_ref_row,
std::string("qy[") + std::to_string(i_r) + std::string("qy[") + std::to_string(i_r) +
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#include <string> #include <string>
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/host.hpp" #include "ck_tile/host.hpp"
#include "ck_tile/ops/moe_sorting.hpp" #include "ck_tile/ops/fused_moe.hpp"
struct moe_sorting_trait struct moe_sorting_trait
{ {
......
function (add_moe_smoothquant_example TARGET_NAME MAIN_SRC)
message("adding ${TARGET_NAME}")
# not using add_example_executable() to add target, since we don't want this to have
# to be included in "make all/install/check"
add_executable(${TARGET_NAME} EXCLUDE_FROM_ALL ${MAIN_SRC})
target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
foreach(source IN LISTS ARGN)
list(APPEND INSTANCE_SRCS ${source})
endforeach()
target_sources(${TARGET_NAME} PRIVATE ${INSTANCE_SRCS})
set(COMPILE_OPTIONS)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
list(APPEND COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
# list(APPEND COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
target_compile_options(${TARGET_NAME} PRIVATE ${COMPILE_OPTIONS})
endfunction(add_moe_smoothquant_example TARGET_NAME MAIN_SRC)
file(GLOB INSTANCE_SRCS instances/*.cpp)
add_moe_smoothquant_example(tile_example_moe_smoothquant moe_smoothquant.cpp ${INSTANCE_SRCS})
# moe-smoothquant
This folder contains example for moe-smoothquant using ck_tile tile-programming implementation.
![](misc/moe-sm.png)
Unlike standard smoothquant op, the input scale is from different expert `[expert, hidden]`, we need reuse the `topk-id` from previous `topk-softmax` and select the corresponding `expert` from current topk, and expand the output/per-token-scale by `topk`
## build
```
# in the root of ck_tile
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
make tile_example_moe_smoothquant -j
```
This will result in an executable `build/bin/tile_example_moe_smoothquant`
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe_smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
#if 0
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 8, 4, 64, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 16, 4, 64, 1, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 4, true, false>>(const S&, A);
#endif
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 2, 128, 8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 2, 128, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 2, 128, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 1, true, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe_smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 4, 64, 8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 2, 128, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 1, 256, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 6, 1, 256, 1, true, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe_smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 1, 256, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 8, 1, 256, 1, true, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe_smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 1, true, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe_smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 1, 128, 8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 1, 256, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 6, 1, 256, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 1, 1024, 1, true, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe_smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 1, 256, 8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 1, 1024, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 1, 1024, 1, true, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe_smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 1, 256, 8, true, true>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 4, true, true>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 1, 1024, 2, true, true>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 1, 1024, 1, true, true>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe_smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 8, 4, 64, 1, true, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe_smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 1, true , false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 2, true , false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 1, true , false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe_smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 4, 64, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 6, 4, 64, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 12, 4, 64, 1, true, false>>(const S&, A);
// clang-format on
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment