Commit b74918bc authored by ThomasNing's avatar ThomasNing
Browse files

compiled version of cross gpu connection

parents 3fcad951 1c45ca35
......@@ -14,12 +14,34 @@
#include "ck_tile/host.hpp"
#include "gemm_basic.hpp"
#define CK_TILE_PIPELINE_COMPUTE 1
#define CK_TILE_PIPELINE_MEMORY 2
#ifndef CK_TILE_PIPELINE_DEFAULT
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE
#endif
template <typename ALayout, typename BLayout, typename CLayout>
float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
{
// ToDo: This will be modified by the codegen code later.
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
// Memory friendly for Interwave scheduler
constexpr ck_tile::index_t M_Tile = 128;
constexpr ck_tile::index_t N_Tile = 128;
constexpr ck_tile::index_t N_Tile = 32;
constexpr ck_tile::index_t K_Tile = 64;
constexpr ck_tile::index_t M_Warp = 4;
constexpr ck_tile::index_t N_Warp = 1;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8;
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
// Compute friendly for Intrawave scheduler
constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile = 32;
constexpr ck_tile::index_t M_Warp = 2;
......@@ -28,12 +50,12 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8;
constexpr ck_tile::index_t K_Warp_Tile = 16;
#endif
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
constexpr bool kPadA = true;
constexpr bool kPadB = true;
constexpr bool kPadC = true;
constexpr bool kPadM = false;
constexpr bool kPadN = false;
constexpr bool kPadK = false;
constexpr int kBlockPerCu = 1;
......@@ -46,11 +68,14 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
using TilePartitioner = ck_tile::GemmTilePartitioner<GemmShape>;
using GemmEpilogue = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, false, kPadC>>;
using Traits = ck_tile::TileGemmTraits<kPadA, kPadB, kPadC, ALayout, BLayout, CLayout>;
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>;
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<
#endif
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K);
......@@ -63,13 +88,21 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<
#endif
ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
Traits,
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
ck_tile::GemmPipelineScheduler::Interwave,
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
ck_tile::GemmPipelineScheduler::Intrawave,
#endif
has_hot_loop_v,
tail_number_v>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
......@@ -86,6 +119,11 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch);
constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args:"
......@@ -174,8 +212,8 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
{
std::ostringstream err;
err << "When there's no hot loop, this tail number \"" << tail_num
<< "\" is not supported! " << __FILE__ << ":" << __LINE__
<< ", in function: " << __func__;
<< "\" is not supported! PrefetchStages: " << BaseGemmPipeline::PrefetchStages
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
}
......
......@@ -40,7 +40,7 @@ float matrix_core_swizzle(matrix_core_swizzle_traits t,
else if(t.permute.compare("0,1,3,4,2,5") == 0)
{
constexpr matrix_core_permute_style pstyle =
matrix_core_permute_style::permute_b_nr_kr_kw_nw_kv;
matrix_core_permute_style::b_nr_kr_kw_nw_kv;
using Kernel =
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
......@@ -83,7 +83,7 @@ float matrix_core_swizzle(matrix_core_swizzle_traits t,
else if(t.permute.compare("0,1,3,4,2,5") == 0)
{
constexpr matrix_core_permute_style pstyle =
matrix_core_permute_style::permute_b_nr_kr_kw_nw_kv;
matrix_core_permute_style::b_nr_kr_kw_nw_kv;
using Kernel =
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
......
......@@ -42,8 +42,8 @@ enum class matrix_core_permute_style
{
permute_b_n0_k0_n1_k1_n2_k2 = 0, // 0,1,4,2,5,3,6
permute_b_n0_n1_k0_k1_n2_k2 = 1, // 0,1,2,4,5,3,6
permute_b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv,
b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv,
};
// assume this is B matrix, originally we have batch*n*k
......@@ -203,7 +203,7 @@ struct matrix_core_swizzle_kernel
else
{
// clang-format off
// permute_b_nr_kr_kw_nw_kv or permute_b_nr_kr_waveflatten
// b_nr_kr_kw_nw_kv or b_nr_kr_waveflatten
constexpr index_t Kv = Alignment;
constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
......@@ -332,7 +332,7 @@ struct matrix_core_swizzle_kernel
make_tuple(sequence<0>{}, sequence<1>{}));
return tmp_1;
#else
// permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv,
// b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv,
constexpr index_t kv = Alignment;
constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
......@@ -376,13 +376,13 @@ struct matrix_core_swizzle_kernel
else
{
#if MERGE_2D_013425
// permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv
// b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv
return make_tile_window(dst_view,
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
{i_n * NPerBlock, i_k * KPerBlock},
get_dst_dist());
#else
// permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv
// b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv
constexpr index_t kv = Alignment;
constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
......
......@@ -264,7 +264,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
{
if(arg_parser.get_str("perm") == std::string("0,1,3,4,2,5"))
{
// permute_b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
// b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
matrix_core_swizzle_traits t;
t.data_type = data_type;
t.permute = arg_parser.get_str("perm");
......
......@@ -18,7 +18,7 @@ function (add_smoothquant_example TARGET_NAME MAIN_SRC)
target_compile_options(${TARGET_NAME} PRIVATE ${COMPILE_OPTIONS})
endfunction(add_smoothquant_example TARGET_NAME MAIN_SRC)
file(GLOB INSTANCE_SRCS instances/*.cpp)
add_smoothquant_example(tile_smoothquant smoothquant.cpp ${INSTANCE_SRCS})
add_smoothquant_example(tile_example_smoothquant example_smoothquant.cpp)
file(GLOB INSTANCE_SRCS instances/*.cpp)
add_smoothquant_example(tile_smoothquant smoothquant.cpp ${INSTANCE_SRCS})
......@@ -35,7 +35,8 @@ auto create_args(int argc, char* argv[])
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "3328", "m dimension")
.insert("n", "4096", "n dimension")
.insert("stride", "-1", "stride per row, if -1 then equal to n")
.insert("x_stride", "-1", "input stride per row, if -1 then equal to n")
.insert("y_stride", "-1", "output stride per row, if -1 then equal to n")
.insert("e", "1e-5", "epsilon")
.insert("v", "1", "cpu validation or not")
.insert("prec", "fp16", "precision")
......@@ -49,11 +50,14 @@ auto create_args(int argc, char* argv[])
template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser)
{
ck_tile::index_t m = arg_parser.get_int("m");
ck_tile::index_t n = arg_parser.get_int("n");
ck_tile::index_t stride = arg_parser.get_int("stride");
if(stride < 0)
stride = n;
ck_tile::index_t m = arg_parser.get_int("m");
ck_tile::index_t n = arg_parser.get_int("n");
ck_tile::index_t x_stride = arg_parser.get_int("x_stride");
if(x_stride < 0)
x_stride = n;
ck_tile::index_t y_stride = arg_parser.get_int("y_stride");
if(y_stride < 0)
y_stride = n;
std::string data_type = arg_parser.get_str("prec");
int do_validation = arg_parser.get_int("v");
int warmup = arg_parser.get_int("warmup");
......@@ -68,14 +72,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
using ComputeDataType = float;
// host verify
ck_tile::HostTensor<XDataType> x_host({m, n}, {stride, 1});
ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1});
ck_tile::HostTensor<XScaleDataType> xscale_host({n});
ck_tile::HostTensor<YScaleDataType> yscale_host_ref({m}, {1});
ck_tile::HostTensor<YScaleDataType> yscale_host_dev({m}, {1});
ck_tile::HostTensor<QYDataType> qy_host_ref({m, n}, {stride, 1});
ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {stride, 1});
ck_tile::HostTensor<QYDataType> qy_host_ref({m, n}, {y_stride, 1});
ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {y_stride, 1});
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
ck_tile::FillUniformDistribution<XScaleDataType>{1e-3, .5f}(xscale_host);
......@@ -116,7 +120,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
qy_buf.GetDeviceBuffer(),
m,
n,
stride};
x_stride,
y_stride};
auto kargs = Kernel::MakeKargs(args);
......@@ -133,7 +138,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(do_validation)
{
using YDataType = ComputeDataType;
ck_tile::HostTensor<ComputeDataType> y_host({m, n}, {stride, 1});
ck_tile::HostTensor<ComputeDataType> y_host({m, n}, {y_stride, 1});
// smooth outlier
{
auto f = [&](auto n_) {
......@@ -183,7 +188,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
qy_buf.FromDevice(qy_host_dev.data());
auto [rtol, atol] = get_elimit<QYDataType>();
if(stride == n)
if(y_stride == n)
{
pass = ck_tile::check_err(qy_host_dev,
qy_host_ref,
......@@ -195,10 +200,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
{
for(int i_r = 0; i_r < m; i_r++)
{
std::vector<QYDataType> qy_host_dev_row(qy_host_dev.begin() + i_r * stride,
qy_host_dev.begin() + i_r * stride + n);
std::vector<QYDataType> qy_host_ref_row(qy_host_ref.begin() + i_r * stride,
qy_host_ref.begin() + i_r * stride + n);
std::vector<QYDataType> qy_host_dev_row(qy_host_dev.begin() + i_r * y_stride,
qy_host_dev.begin() + i_r * y_stride +
n);
std::vector<QYDataType> qy_host_ref_row(qy_host_ref.begin() + i_r * y_stride,
qy_host_ref.begin() + i_r * y_stride +
n);
pass &= ck_tile::check_err(qy_host_dev_row,
qy_host_ref_row,
std::string("qy[") + std::to_string(i_r) +
......@@ -210,8 +217,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
std::cout << "[" << data_type << "]"
<< " m:" << m << ", n:" << n << ", stride:" << stride
<< ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
<< " m:" << m << ", n:" << n << ", x_stride:" << x_stride
<< ", y_stride:" << y_stride << ", valid:" << (pass ? "y" : "n") << std::flush
<< std::endl;
}
return pass;
......
......@@ -33,7 +33,8 @@ auto create_args(int argc, char* argv[])
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "3328", "m dimension")
.insert("n", "4096", "n dimension")
.insert("stride", "-1", "stride per row, if -1 then equal to n")
.insert("x_stride", "-1", "input stride per row, if -1 then equal to n")
.insert("y_stride", "-1", "output stride per row, if -1 then equal to n")
.insert("v", "1", "cpu validation or not")
.insert("kname", "1", "print kernel name or not")
.insert("prec", "fp16", "precision")
......@@ -47,18 +48,21 @@ auto create_args(int argc, char* argv[])
template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser)
{
ck_tile::index_t m = arg_parser.get_int("m");
ck_tile::index_t n = arg_parser.get_int("n");
ck_tile::index_t stride = arg_parser.get_int("stride");
if(stride < 0)
stride = n;
ck_tile::index_t m = arg_parser.get_int("m");
ck_tile::index_t n = arg_parser.get_int("n");
ck_tile::index_t x_stride = arg_parser.get_int("x_stride");
if(x_stride < 0)
x_stride = n;
ck_tile::index_t y_stride = arg_parser.get_int("y_stride");
if(y_stride < 0)
y_stride = n;
std::string data_type = arg_parser.get_str("prec");
int kname = arg_parser.get_int("kname");
int do_validation = arg_parser.get_int("v");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
assert(stride >= n);
assert(x_stride >= n);
using TypeConfig = SmoothquantTypeConfig<DataType>;
......@@ -69,14 +73,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
using ComputeDataType = typename TypeConfig::ComputeDataType;
// host verify
ck_tile::HostTensor<XDataType> x_host({m, n}, {stride, 1});
ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1});
ck_tile::HostTensor<XScaleDataType> xscale_host({n});
ck_tile::HostTensor<YScaleDataType> yscale_host_ref({m}, {1});
ck_tile::HostTensor<YScaleDataType> yscale_host_dev({m}, {1});
ck_tile::HostTensor<QYDataType> qy_host_ref({m, n}, {stride, 1});
ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {stride, 1});
ck_tile::HostTensor<QYDataType> qy_host_ref({m, n}, {y_stride, 1});
ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {y_stride, 1});
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
ck_tile::FillUniformDistribution<XScaleDataType>{1e-3, .5f}(xscale_host);
......@@ -90,7 +94,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
xscale_buf.ToDevice(xscale_host.data());
std::cout << "[" << data_type << "]"
<< " m:" << m << ", n:" << n << ", stride:" << stride << std::flush;
<< " m:" << m << ", n:" << n << ", x_stride:" << x_stride << ", y_stride:" << y_stride
<< std::flush;
smoothquant_traits traits{data_type};
......@@ -100,7 +105,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
qy_buf.GetDeviceBuffer(),
m,
n,
stride};
x_stride,
y_stride};
float ave_time = smoothquant(
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
......@@ -116,7 +122,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(do_validation)
{
using YDataType = ComputeDataType;
ck_tile::HostTensor<ComputeDataType> y_host({m, n}, {stride, 1});
ck_tile::HostTensor<ComputeDataType> y_host({m, n}, {y_stride, 1});
// smooth outlier
{
auto f = [&](auto n_) {
......@@ -166,7 +172,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
qy_buf.FromDevice(qy_host_dev.data());
auto [rtol, atol] = get_elimit<QYDataType>();
if(stride == n)
if(y_stride == n)
{
pass = ck_tile::check_err(qy_host_dev,
qy_host_ref,
......@@ -178,10 +184,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
{
for(int i_r = 0; i_r < m; i_r++)
{
std::vector<QYDataType> qy_host_dev_row(qy_host_dev.begin() + i_r * stride,
qy_host_dev.begin() + i_r * stride + n);
std::vector<QYDataType> qy_host_ref_row(qy_host_ref.begin() + i_r * stride,
qy_host_ref.begin() + i_r * stride + n);
std::vector<QYDataType> qy_host_dev_row(qy_host_dev.begin() + i_r * y_stride,
qy_host_dev.begin() + i_r * y_stride +
n);
std::vector<QYDataType> qy_host_ref_row(qy_host_ref.begin() + i_r * y_stride,
qy_host_ref.begin() + i_r * y_stride +
n);
pass &= ck_tile::check_err(qy_host_dev_row,
qy_host_ref_row,
std::string("qy[") + std::to_string(i_r) +
......
......@@ -5,7 +5,7 @@
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/moe_sorting.hpp"
#include "ck_tile/ops/fused_moe.hpp"
struct moe_sorting_trait
{
......
function (add_moe_smoothquant_example TARGET_NAME MAIN_SRC)
message("adding ${TARGET_NAME}")
# not using add_example_executable() to add target, since we don't want this to have
# to be included in "make all/install/check"
add_executable(${TARGET_NAME} EXCLUDE_FROM_ALL ${MAIN_SRC})
target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
foreach(source IN LISTS ARGN)
list(APPEND INSTANCE_SRCS ${source})
endforeach()
target_sources(${TARGET_NAME} PRIVATE ${INSTANCE_SRCS})
set(COMPILE_OPTIONS)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
list(APPEND COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
# list(APPEND COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
target_compile_options(${TARGET_NAME} PRIVATE ${COMPILE_OPTIONS})
endfunction(add_moe_smoothquant_example TARGET_NAME MAIN_SRC)
file(GLOB INSTANCE_SRCS instances/*.cpp)
add_moe_smoothquant_example(tile_example_moe_smoothquant moe_smoothquant.cpp ${INSTANCE_SRCS})
# moe-smoothquant
This folder contains example for moe-smoothquant using ck_tile tile-programming implementation.
![](misc/moe-sm.png)
Unlike standard smoothquant op, the input scale is from different expert `[expert, hidden]`, we need reuse the `topk-id` from previous `topk-softmax` and select the corresponding `expert` from current topk, and expand the output/per-token-scale by `topk`
## build
```
# in the root of ck_tile
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
make tile_example_moe_smoothquant -j
```
This will result in an executable `build/bin/tile_example_moe_smoothquant`
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe_smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
#if 0
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 8, 4, 64, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 16, 4, 64, 1, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 4, true, false>>(const S&, A);
#endif
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 2, 128, 8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 2, 128, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 2, 128, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 1, true, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe_smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 4, 64, 8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 2, 128, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 1, 256, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 6, 1, 256, 1, true, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe_smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 1, 256, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 8, 1, 256, 1, true, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe_smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 1, true, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe_smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 1, 128, 8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 1, 256, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 6, 1, 256, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 1, 1024, 1, true, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe_smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 1, 256, 8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 1, 1024, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 1, 1024, 1, true, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe_smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 1, 256, 8, true, true>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 4, true, true>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 1, 1024, 2, true, true>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 1, 1024, 1, true, true>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe_smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 8, 4, 64, 1, true, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe_smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 1, true , false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 2, true , false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 1, true , false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe_smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 4, 64, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 6, 4, 64, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 12, 4, 64, 1, true, false>>(const S&, A);
// clang-format on
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment