Unverified Commit 5e93fa9e authored by Illia Silin's avatar Illia Silin Committed by GitHub
Browse files

Merge pull request #247 from ROCm/merge_from_public

Merge from public
parents 965b7ba4 2298a1a4
...@@ -145,7 +145,7 @@ decode_seqlen(mode_enum mode, ...@@ -145,7 +145,7 @@ decode_seqlen(mode_enum mode,
std::string k_val, std::string k_val,
std::string k_pad_val, std::string k_pad_val,
ck_tile::index_t seqlen_k_min = 0, ck_tile::index_t seqlen_k_min = 0,
bool use_kvcache = false, bool need_append_kvcache = false,
std::optional<unsigned> seed = std::nullopt) std::optional<unsigned> seed = std::nullopt)
{ {
#define _S2I_(str_) static_cast<ck_tile::index_t>(std::atoi((str_).c_str())) #define _S2I_(str_) static_cast<ck_tile::index_t>(std::atoi((str_).c_str()))
...@@ -159,7 +159,7 @@ decode_seqlen(mode_enum mode, ...@@ -159,7 +159,7 @@ decode_seqlen(mode_enum mode,
const ck_tile::index_t seqlen_k_max = (k < 0 ? q : k); const ck_tile::index_t seqlen_k_max = (k < 0 ? q : k);
std::vector<ck_tile::index_t> seqlen_ks(batch, seqlen_k_max); std::vector<ck_tile::index_t> seqlen_ks(batch, seqlen_k_max);
if(1 < batch && use_kvcache) if(1 < batch && need_append_kvcache)
{ {
// to keep the original s_k value, we always use seqlen_k_max in first batch // to keep the original s_k value, we always use seqlen_k_max in first batch
randints(std::next(seqlen_ks.begin()), randints(std::next(seqlen_ks.begin()),
......
add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp) add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp)
add_executable(tile_example_gemm_mem_pipeline EXCLUDE_FROM_ALL gemm_mem_pipeline.cpp) add_executable(tile_example_universal_gemm EXCLUDE_FROM_ALL universal_gemm.cpp)
...@@ -92,6 +92,11 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -92,6 +92,11 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch); const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch);
constexpr dim3 blocks = Kernel::BlockSize(); constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0) if(s.log_level_ > 0)
{ {
std::cout << "Launching kernel with args:" std::cout << "Launching kernel with args:"
......
...@@ -31,15 +31,13 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, ...@@ -31,15 +31,13 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
float ave_time = gemm_calc<ALayout, BLayout, CLayout>( float ave_time = gemm_calc<ALayout, BLayout, CLayout>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
std::string op_name{"Gemm{MemBoundPipeline}"};
std::size_t flop = std::size_t(2) * M * N * K; std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_byte = std::size_t num_byte =
sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N; sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time; float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time; float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Run " << op_name << "kernel with M =" << M << " N =" << N << " K =" << K std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K
<< " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C << " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C
<< " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< std::endl; << std::endl;
...@@ -114,7 +112,6 @@ int run_gemm_example_with_layouts(int argc, ...@@ -114,7 +112,6 @@ int run_gemm_example_with_layouts(int argc,
f_host_tensor_descriptor(M, N, stride_C, CLayout{})); f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
// TODO: add different init types // TODO: add different init types
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k); ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n); ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
...@@ -202,14 +199,16 @@ int run_gemm_example(int argc, char* argv[]) ...@@ -202,14 +199,16 @@ int run_gemm_example(int argc, char* argv[])
{ {
return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
} }
else if(a_layout == "C" && b_layout == "C") // TODO: Fixme: with latest changes to GemmPipelineAGmemBGmemCRegV1DefaultPolicy below do not
{ // work.
return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); // else if(a_layout == "C" && b_layout == "C")
} // {
else if(a_layout == "C" && b_layout == "R") // return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{});
{ // }
return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); // else if(a_layout == "C" && b_layout == "R")
} // {
// return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{});
// }
else else
{ {
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
......
...@@ -14,12 +14,34 @@ ...@@ -14,12 +14,34 @@
#include "ck_tile/host.hpp" #include "ck_tile/host.hpp"
#include "gemm_basic.hpp" #include "gemm_basic.hpp"
#define CK_TILE_PIPELINE_COMPUTE 1
#define CK_TILE_PIPELINE_MEMORY 2
#ifndef CK_TILE_PIPELINE_DEFAULT
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE
#endif
template <typename ALayout, typename BLayout, typename CLayout> template <typename ALayout, typename BLayout, typename CLayout>
float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
{ {
// ToDo: This will be modified by the codegen code later. #if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
// Memory friendly for Interwave scheduler
constexpr ck_tile::index_t M_Tile = 128; constexpr ck_tile::index_t M_Tile = 128;
constexpr ck_tile::index_t N_Tile = 128; constexpr ck_tile::index_t N_Tile = 32;
constexpr ck_tile::index_t K_Tile = 64;
constexpr ck_tile::index_t M_Warp = 4;
constexpr ck_tile::index_t N_Warp = 1;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8;
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
// Compute friendly for Intrawave scheduler
constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile = 32; constexpr ck_tile::index_t K_Tile = 32;
constexpr ck_tile::index_t M_Warp = 2; constexpr ck_tile::index_t M_Warp = 2;
...@@ -28,12 +50,12 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -28,12 +50,12 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
constexpr ck_tile::index_t M_Warp_Tile = 32; constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8; constexpr ck_tile::index_t K_Warp_Tile = 16;
#endif
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part. constexpr bool kPadM = false;
constexpr bool kPadM = true; constexpr bool kPadN = false;
constexpr bool kPadN = true; constexpr bool kPadK = false;
constexpr bool kPadK = true;
constexpr int kBlockPerCu = 1; constexpr int kBlockPerCu = 1;
...@@ -49,8 +71,11 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -49,8 +71,11 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>; ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>;
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>; using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem< using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<
#endif
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>; ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K); const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K);
...@@ -63,13 +88,21 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -63,13 +88,21 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
constexpr bool has_hot_loop_v = has_hot_loop_.value; constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value; constexpr auto tail_number_v = tail_number_.value;
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem< using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<
#endif
ck_tile::UniversalGemmPipelineProblem<ADataType, ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType, BDataType,
AccDataType, AccDataType,
GemmShape, GemmShape,
Traits, Traits,
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
ck_tile::GemmPipelineScheduler::Interwave,
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
ck_tile::GemmPipelineScheduler::Intrawave, ck_tile::GemmPipelineScheduler::Intrawave,
#endif
has_hot_loop_v, has_hot_loop_v,
tail_number_v>>; tail_number_v>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>; using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
...@@ -86,6 +119,11 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -86,6 +119,11 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch); const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch);
constexpr dim3 blocks = Kernel::BlockSize(); constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0) if(s.log_level_ > 0)
{ {
std::cout << "Launching kernel with args:" std::cout << "Launching kernel with args:"
...@@ -174,8 +212,8 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -174,8 +212,8 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
{ {
std::ostringstream err; std::ostringstream err;
err << "When there's no hot loop, this tail number \"" << tail_num err << "When there's no hot loop, this tail number \"" << tail_num
<< "\" is not supported! " << __FILE__ << ":" << __LINE__ << "\" is not supported! PrefetchStages: " << BaseGemmPipeline::PrefetchStages
<< ", in function: " << __func__; << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str()); throw std::runtime_error(err.str());
} }
} }
......
...@@ -40,7 +40,7 @@ float matrix_core_swizzle(matrix_core_swizzle_traits t, ...@@ -40,7 +40,7 @@ float matrix_core_swizzle(matrix_core_swizzle_traits t,
else if(t.permute.compare("0,1,3,4,2,5") == 0) else if(t.permute.compare("0,1,3,4,2,5") == 0)
{ {
constexpr matrix_core_permute_style pstyle = constexpr matrix_core_permute_style pstyle =
matrix_core_permute_style::permute_b_nr_kr_kw_nw_kv; matrix_core_permute_style::b_nr_kr_kw_nw_kv;
using Kernel = using Kernel =
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>; matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
...@@ -83,7 +83,7 @@ float matrix_core_swizzle(matrix_core_swizzle_traits t, ...@@ -83,7 +83,7 @@ float matrix_core_swizzle(matrix_core_swizzle_traits t,
else if(t.permute.compare("0,1,3,4,2,5") == 0) else if(t.permute.compare("0,1,3,4,2,5") == 0)
{ {
constexpr matrix_core_permute_style pstyle = constexpr matrix_core_permute_style pstyle =
matrix_core_permute_style::permute_b_nr_kr_kw_nw_kv; matrix_core_permute_style::b_nr_kr_kw_nw_kv;
using Kernel = using Kernel =
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>; matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
......
...@@ -42,8 +42,8 @@ enum class matrix_core_permute_style ...@@ -42,8 +42,8 @@ enum class matrix_core_permute_style
{ {
permute_b_n0_k0_n1_k1_n2_k2 = 0, // 0,1,4,2,5,3,6 permute_b_n0_k0_n1_k1_n2_k2 = 0, // 0,1,4,2,5,3,6
permute_b_n0_n1_k0_k1_n2_k2 = 1, // 0,1,2,4,5,3,6 permute_b_n0_n1_k0_k1_n2_k2 = 1, // 0,1,2,4,5,3,6
permute_b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5 b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv, b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv,
}; };
// assume this is B matrix, originally we have batch*n*k // assume this is B matrix, originally we have batch*n*k
...@@ -203,7 +203,7 @@ struct matrix_core_swizzle_kernel ...@@ -203,7 +203,7 @@ struct matrix_core_swizzle_kernel
else else
{ {
// clang-format off // clang-format off
// permute_b_nr_kr_kw_nw_kv or permute_b_nr_kr_waveflatten // b_nr_kr_kw_nw_kv or b_nr_kr_waveflatten
constexpr index_t Kv = Alignment; constexpr index_t Kv = Alignment;
constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane; constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane; constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
...@@ -332,7 +332,7 @@ struct matrix_core_swizzle_kernel ...@@ -332,7 +332,7 @@ struct matrix_core_swizzle_kernel
make_tuple(sequence<0>{}, sequence<1>{})); make_tuple(sequence<0>{}, sequence<1>{}));
return tmp_1; return tmp_1;
#else #else
// permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv, // b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv,
constexpr index_t kv = Alignment; constexpr index_t kv = Alignment;
constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane; constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane; constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
...@@ -376,13 +376,13 @@ struct matrix_core_swizzle_kernel ...@@ -376,13 +376,13 @@ struct matrix_core_swizzle_kernel
else else
{ {
#if MERGE_2D_013425 #if MERGE_2D_013425
// permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv // b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv
return make_tile_window(dst_view, return make_tile_window(dst_view,
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
{i_n * NPerBlock, i_k * KPerBlock}, {i_n * NPerBlock, i_k * KPerBlock},
get_dst_dist()); get_dst_dist());
#else #else
// permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv // b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv
constexpr index_t kv = Alignment; constexpr index_t kv = Alignment;
constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane; constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane; constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
......
...@@ -264,7 +264,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -264,7 +264,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
{ {
if(arg_parser.get_str("perm") == std::string("0,1,3,4,2,5")) if(arg_parser.get_str("perm") == std::string("0,1,3,4,2,5"))
{ {
// permute_b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5 // b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
matrix_core_swizzle_traits t; matrix_core_swizzle_traits t;
t.data_type = data_type; t.data_type = data_type;
t.permute = arg_parser.get_str("perm"); t.permute = arg_parser.get_str("perm");
......
...@@ -18,7 +18,7 @@ function (add_smoothquant_example TARGET_NAME MAIN_SRC) ...@@ -18,7 +18,7 @@ function (add_smoothquant_example TARGET_NAME MAIN_SRC)
target_compile_options(${TARGET_NAME} PRIVATE ${COMPILE_OPTIONS}) target_compile_options(${TARGET_NAME} PRIVATE ${COMPILE_OPTIONS})
endfunction(add_smoothquant_example TARGET_NAME MAIN_SRC) endfunction(add_smoothquant_example TARGET_NAME MAIN_SRC)
file(GLOB INSTANCE_SRCS instances/*.cpp)
add_smoothquant_example(tile_smoothquant smoothquant.cpp ${INSTANCE_SRCS})
add_smoothquant_example(tile_example_smoothquant example_smoothquant.cpp) add_smoothquant_example(tile_example_smoothquant example_smoothquant.cpp)
file(GLOB INSTANCE_SRCS instances/*.cpp)
add_smoothquant_example(tile_smoothquant smoothquant.cpp ${INSTANCE_SRCS})
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#include <string> #include <string>
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/host.hpp" #include "ck_tile/host.hpp"
#include "ck_tile/ops/moe_sorting.hpp" #include "ck_tile/ops/fused_moe.hpp"
struct moe_sorting_trait struct moe_sorting_trait
{ {
......
function (add_moe_smoothquant_example TARGET_NAME MAIN_SRC)
message("adding ${TARGET_NAME}")
# not using add_example_executable() to add target, since we don't want this to have
# to be included in "make all/install/check"
add_executable(${TARGET_NAME} EXCLUDE_FROM_ALL ${MAIN_SRC})
target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
foreach(source IN LISTS ARGN)
list(APPEND INSTANCE_SRCS ${source})
endforeach()
target_sources(${TARGET_NAME} PRIVATE ${INSTANCE_SRCS})
set(COMPILE_OPTIONS)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
list(APPEND COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
# list(APPEND COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
target_compile_options(${TARGET_NAME} PRIVATE ${COMPILE_OPTIONS})
endfunction(add_moe_smoothquant_example TARGET_NAME MAIN_SRC)
file(GLOB INSTANCE_SRCS instances/*.cpp)
add_moe_smoothquant_example(tile_example_moe_smoothquant moe_smoothquant.cpp ${INSTANCE_SRCS})
# moe-smoothquant
This folder contains example for moe-smoothquant using ck_tile tile-programming implementation.
![](misc/moe-sm.png)
Unlike standard smoothquant op, the input scale is from different expert `[expert, hidden]`, we need reuse the `topk-id` from previous `topk-softmax` and select the corresponding `expert` from current topk, and expand the output/per-token-scale by `topk`
## build
```
# in the root of ck_tile
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
make tile_example_moe_smoothquant -j
```
This will result in an executable `build/bin/tile_example_moe_smoothquant`
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe_smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
#if 0
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 8, 4, 64, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 16, 4, 64, 1, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 4, true, false>>(const S&, A);
#endif
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 2, 128, 8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 2, 128, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 2, 128, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 1, true, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe_smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 4, 64, 8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 2, 128, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 1, 256, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 6, 1, 256, 1, true, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe_smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 1, 256, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 8, 1, 256, 1, true, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe_smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 1, true, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe_smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 1, 128, 8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 1, 256, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 6, 1, 256, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 1, 1024, 1, true, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe_smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 1, 256, 8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 1, 1024, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 1, 1024, 1, true, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe_smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 1, 256, 8, true, true>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 4, true, true>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 1, 1024, 2, true, true>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 1, 1024, 1, true, true>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe_smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 8, 4, 64, 1, true, false>>(const S&, A);
// clang-format on
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment