Commit 70fa98ad authored by carlushuang's avatar carlushuang
Browse files

update code

parent 7c81aee8
# BUG FIX: variable name was misspelled "EXAPMLE"; renamed consistently.
# (the target name string itself is unchanged)
set(TILE_EXAMPLE_FUSED_MOE "tile_example_fused_moe")
# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"
message("adding ${TILE_EXAMPLE_FUSED_MOE}")
# kernel instances are generated/kept under instances/
file(GLOB INSTANCE_SRCS instances/*.cpp)
add_executable(${TILE_EXAMPLE_FUSED_MOE} EXCLUDE_FROM_ALL main.cpp)
target_include_directories(${TILE_EXAMPLE_FUSED_MOE} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${TILE_EXAMPLE_FUSED_MOE} PRIVATE ${INSTANCE_SRCS})

set(TILE_EXAMPLE_FUSED_MOE_COMPILE_OPTIONS)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
list(APPEND TILE_EXAMPLE_FUSED_MOE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
target_compile_options(${TILE_EXAMPLE_FUSED_MOE} PRIVATE ${TILE_EXAMPLE_FUSED_MOE_COMPILE_OPTIONS})
...@@ -16,33 +16,33 @@ struct FusedMoeGemmTypeConfig; ...@@ -16,33 +16,33 @@ struct FusedMoeGemmTypeConfig;
template <typename ST, typename SW, typename SQ, typename KW> template <typename ST, typename SW, typename SQ, typename KW>
struct FusedMoeGemmTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, ST, SW, SQ, KW>; struct FusedMoeGemmTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, ST, SW, SQ, KW>;
{ {
using ADataType = ck_tile::bf16_t; using ADataType = ck_tile::bf16_t;
using GDataType = ck_tile::bf16_t; using GDataType = ck_tile::bf16_t;
using DDataType = ck_tile::bf16_t; using DDataType = ck_tile::bf16_t;
using AccDataType = float; using AccDataType = float;
using ODataType = ck_tile::bf16_t; using ODataType = ck_tile::bf16_t;
using AScaleDataType = ck_tile::remove_cvref_t<ST>; using AScaleDataType = ck_tile::remove_cvref_t<ST>;
using W0ScaleDataType = ck_tile::remove_cvref_t<SW>; using GScaleDataType = ck_tile::remove_cvref_t<SW>;
using W1ScaleDataType = ck_tile::remove_cvref_t<SW>; using DScaleDataType = ck_tile::remove_cvref_t<SW>;
using YSmoothScaleDataType = ck_tile::remove_cvref_t<SQ>; using YSmoothScaleDataType = ck_tile::remove_cvref_t<SQ>;
using TopkWeightDataType = ck_tile::remove_cvref_t<KW>; using TopkWeightDataType = ck_tile::remove_cvref_t<KW>;
using IndexDataType = ck_tile::index_t; using IndexDataType = ck_tile::index_t;
}; };
template <typename ST, typename SW, typename SQ, typename KW> template <typename ST, typename SW, typename SQ, typename KW>
struct FusedMoeGemmTypeConfig<ck_tile::int8_t, ck_tile::int8_t, ck_tile::bf16_t, ST, SW, SQ, KW>; struct FusedMoeGemmTypeConfig<ck_tile::int8_t, ck_tile::int8_t, ck_tile::bf16_t, ST, SW, SQ, KW>;
{ {
using ADataType = ck_tile::int8_t; using ADataType = ck_tile::int8_t;
using GDataType = ck_tile::int8_t; using GDataType = ck_tile::int8_t;
using DDataType = ck_tile::int8_t; using DDataType = ck_tile::int8_t;
using AccDataType = int32_t; using AccDataType = int32_t;
using ODataType = ck_tile::bf16_t; using ODataType = ck_tile::bf16_t;
using AScaleDataType = ck_tile::remove_cvref_t<ST>; using AScaleDataType = ck_tile::remove_cvref_t<ST>;
using W0ScaleDataType = ck_tile::remove_cvref_t<SW>; using GScaleDataType = ck_tile::remove_cvref_t<SW>;
using W1ScaleDataType = ck_tile::remove_cvref_t<SW>; using DScaleDataType = ck_tile::remove_cvref_t<SW>;
using YSmoothScaleDataType = ck_tile::remove_cvref_t<SQ>; using YSmoothScaleDataType = ck_tile::remove_cvref_t<SQ>;
using TopkWeightDataType = ck_tile::remove_cvref_t<KW>; using TopkWeightDataType = ck_tile::remove_cvref_t<KW>;
using IndexDataType = ck_tile::index_t; using IndexDataType = ck_tile::index_t;
}; };
// runtime args // runtime args
...@@ -53,14 +53,16 @@ struct fused_moegemm_args : public ck_tile::Layernorm2dFwdHostArgs ...@@ -53,14 +53,16 @@ struct fused_moegemm_args : public ck_tile::Layernorm2dFwdHostArgs
// This is the public API, will be generated by script // This is the public API, will be generated by script
struct fused_moegemm_traits struct fused_moegemm_traits
{ {
std::string prec_i; // input precision std::string prec_i; // input precision
std::string prec_w; // weight precision std::string prec_w; // weight precision
std::string prec_o; // output precision std::string prec_o; // output precision
std::string prec_st; // token scale data type std::string prec_st; // token scale data type
std::string prec_sw; // weight scale data type std::string prec_sw; // weight scale data type
std::string prec_sq; // smooth quant scale std::string prec_sq; // smooth quant scale
std::string prec_kw; // topk-weight data type std::string prec_kw; // topk-weight data type
int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant int block_m;
int gate_only;
int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
}; };
float fused_moegemm(fused_moegemm_traits, fused_moegemm_args, const ck_tile::stream_config&); float fused_moegemm(fused_moegemm_traits, fused_moegemm_args, const ck_tile::stream_config&);
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "fused_moegemm.hpp"
// Note: this internal API is only declared here, not defined; otherwise every instance
// would be compiled into this TU and slow down `make -j`
template <typename Traits_>
float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a);
// public dispatch entry: pattern-match the runtime traits against the set of
// pre-instantiated kernels and launch the matching one.
// returns the averaged kernel time, or -1 if no instance matches.
float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile::stream_config& s)
{
    // BUG FIX: an alias template ("using S = ...") is not allowed at block
    // scope; spell the sequences out instead.
    float r = -1;
    // BUG FIX: block_m / gate_only were referenced unqualified (they are
    // members of the traits struct, not locals)
    if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
       t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 &&
       t.gate_only == 1)
    {
        using t_ = fmoe_<ck_tile::bf16_t, // input
                         ck_tile::bf16_t, // weight
                         ck_tile::bf16_t, // output
                         float,           // token scale
                         float,           // weight scale
                         float,           // smooth-quant scale
                         float,           // topk weight
                         ck_tile::sequence<32, 512, 128, 128>, // block tile
                         ck_tile::sequence<4, 1, 1>,           // warps per block
                         ck_tile::sequence<32, 32, 16>,        // warp tile (mfma)
                         1,  // gate only
                         0>; // no fused quant
        // BUG FIX: the measured time was discarded, so -1 was always returned
        r = fused_moegemm_<t_>(s, a);
    }
    return r;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "fused_moegemm_api_traits.hpp"
#include "ck_tile/ops/fused_moe.hpp"
// internal: instantiate one fused-moe kernel from the tile configuration
// carried by Ts_ (a fmoe_ traits bundle) and launch it on stream `s`.
// returns the averaged kernel time as reported by launch_kernel.
template <typename Ts_>
float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
{
    using f_traits = ck_tile::FusedMoeGemmTraits<Ts_::GateOnly, Ts_::FusedQuant == 1, 1 /*atomic*/>;
    // gemm-0 (gate/up) and gemm-1 (down) each carry their own tile shape.
    // BUG FIX: "Ts::" was a typo for "Ts_::", and the second gemm reused the
    // _0 warp config although fmoe_ provides WarpPerBlock_1 / WarpTile_1.
    using f_shape = ck_tile::FusedMoeGemmShape<typename Ts_::BlockTile_0,
                                               typename Ts_::WarpPerBlock_0,
                                               typename Ts_::WarpTile_0,
                                               typename Ts_::BlockTile_1,
                                               typename Ts_::WarpPerBlock_1,
                                               typename Ts_::WarpTile_1>;
    // BUG FIX: this alias declaration was missing its terminating ';'
    using f_problem = ck_tile::FusedMoeGemmPipelineProblem<typename Ts_::ADataType,
                                                           typename Ts_::GDataType,
                                                           typename Ts_::DDataType,
                                                           typename Ts_::AccDataType,
                                                           typename Ts_::ODataType,
                                                           typename Ts_::AScaleDataType,
                                                           typename Ts_::GScaleDataType,
                                                           typename Ts_::DScaleDataType,
                                                           typename Ts_::YSmoothScaleDataType,
                                                           typename Ts_::TopkWeightDataType,
                                                           typename Ts_::IndexDataType,
                                                           ck_tile::Gelu, // TODO: hardcoded
                                                           f_shape,
                                                           f_traits>;

    using f_pipeline    = ck_tile::FusedMoeGemmPipeline_Flatmm<f_problem>;
    using f_partitioner = ck_tile::FusedMoeGemmTilePartitioner_Linear<f_shape>;
    using f_kernel      = ck_tile::FusedMoeGemmKernel<f_partitioner, f_pipeline, void>;

    const dim3 grids                       = f_kernel::GridSize(a);
    constexpr dim3 blocks                  = f_kernel::BlockSize();
    constexpr ck_tile::index_t kBlockPerCu = 1;

    auto kargs = f_kernel::MakeKargs(a);
    if(s.log_level_ > 0)
        std::cout << ", " << f_kernel::GetName() << std::flush;

    return ck_tile::launch_kernel(
        s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(f_kernel{}, grids, blocks, 0, kargs));
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
// this is used to pattern-match the internal kernel implementation, not to instantiate a kernel
template <typename I,  // input type
          typename W,  // weight type
          typename O,  // output type
          typename ST, // token scale type
          typename SW, // weight scale type
          typename SQ, // smooth-quant scale type
          typename KW, // topk-weight type
          typename BlockTile_, // seq<b_token, b_interm, b_hidden, b_down> (BUG FIX: was "BlockTIle_")
          typename WarpPerBlock_,
          typename WarpTile_, // seq<*,*,*>, used to select mfma
          ck_tile::index_t GateOnly_   = 0,
          ck_tile::index_t FusedQuant_ = 0>
struct fmoe_ // traits, ugly name, only used for internal
{
    using TypeConfig = FusedMoeGemmTypeConfig<I, W, O, ST, SW, SQ, KW>;

    // BUG FIX: remove_cvref_t / index_t / number were used unqualified below;
    // this struct lives at global scope, so they must be ck_tile-qualified
    // (matching the qualified usage already present in this struct).
    using ADataType            = ck_tile::remove_cvref_t<typename TypeConfig::ADataType>;
    using GDataType            = ck_tile::remove_cvref_t<typename TypeConfig::GDataType>;
    using DDataType            = ck_tile::remove_cvref_t<typename TypeConfig::DDataType>;
    using AccDataType          = ck_tile::remove_cvref_t<typename TypeConfig::AccDataType>;
    using ODataType            = ck_tile::remove_cvref_t<typename TypeConfig::ODataType>;
    using AScaleDataType       = ck_tile::remove_cvref_t<typename TypeConfig::AScaleDataType>;
    using GScaleDataType       = ck_tile::remove_cvref_t<typename TypeConfig::GScaleDataType>;
    using DScaleDataType       = ck_tile::remove_cvref_t<typename TypeConfig::DScaleDataType>;
    using YSmoothScaleDataType = ck_tile::remove_cvref_t<typename TypeConfig::YSmoothScaleDataType>;
    using TopkWeightDataType   = ck_tile::remove_cvref_t<typename TypeConfig::TopkWeightDataType>;
    using IndexDataType        = ck_tile::remove_cvref_t<typename TypeConfig::IndexDataType>;

    static constexpr ck_tile::index_t BT_ = BlockTile_::at(ck_tile::number<0>{}); // block token
    static constexpr ck_tile::index_t BI_ = BlockTile_::at(ck_tile::number<1>{}); // block intermediate
    static constexpr ck_tile::index_t BH_ = BlockTile_::at(ck_tile::number<2>{}); // block hidden
    static constexpr ck_tile::index_t BD_ = BlockTile_::at(ck_tile::number<3>{}); // block down

    // gemm-0 (gate/up) tile
    using BlockTile_0    = ck_tile::sequence<BT_, BI_, BH_>;
    using WarpPerBlock_0 = ck_tile::remove_cvref_t<WarpPerBlock_>;
    using WarpTile_0     = ck_tile::remove_cvref_t<WarpTile_>;
    // gemm-1 (down) tile; gate+up doubles the intermediate dim, so halve it back
    using BlockTile_1    = ck_tile::sequence<BT_, BD_, BI_ / (GateOnly_ ? 1 : 2)>;
    using WarpPerBlock_1 = ck_tile::remove_cvref_t<WarpPerBlock_>;
    using WarpTile_1     = ck_tile::remove_cvref_t<WarpTile_>;

    static constexpr ck_tile::index_t GateOnly   = GateOnly_;
    static constexpr ck_tile::index_t FusedQuant = FusedQuant_;
};
#include "ck_tile/host.hpp" #include "ck_tile/host.hpp"
#include "layernorm2d_fwd.hpp" #include "fused_moegemm.hpp"
#include <algorithm> #include <algorithm>
#include <cstring> #include <cstring>
#include <cassert>
#include <set>
#include <unordered_set>
#include <vector>
// different threshold for different dtype // different threshold for different dtype
template <typename DataType> template <typename DataType>
...@@ -20,18 +23,64 @@ auto get_elimit<ck_tile::bf16_t>() ...@@ -20,18 +23,64 @@ auto get_elimit<ck_tile::bf16_t>()
return ck_tile::make_tuple(rtol, atol); return ck_tile::make_tuple(rtol, atol);
} }
// mfma_type, 0:32x32, 1:16x16 // mfma_type, 0:32x32, 1:16x16
template<typename H> // TODO: padding?
auto shuffle_moe_weight(const H& t, std::string mfma_dtype, int mfma_type = 0) template <typename T>
auto shuffle_moe_weight(const ck_tile::HostTensor<T>& t, std::string mfma_dtype, int mfma_type = 0)
{ {
static_assert(t.get_lengths().size() == 3); static_assert(t.get_lengths().size() == 3);
int b_ = t.get_lengths()[0]; int b_ = t.get_lengths()[0];
int n_ = t.get_lengths()[1]; int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[2]; int k_ = t.get_lengths()[2];
if ((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 0) { if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 0)
std::vector<ck_tile::index_t> new_lens {b_, n_/32, 32, k_/16, 2, 8}; {
ck_tile::HostTensor<T> t_view({b_, n_ / 32, 32, k_ / 16, 2, 8});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5});
}
else if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 1)
{
ck_tile::HostTensor<T> t_view({b_, n_ / 16, 16, k_ / 32, 4, 8});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5});
}
else if((mfma_dtype == "int8" || mfma_dtype == "fp8") && mfma_type == 0)
{
ck_tile::HostTensor<T> t_view({b_, n_ / 32, 32, k_ / 32, 2, 16});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5});
}
else if((mfma_dtype == "int8" || mfma_dtype == "fp8") && mfma_type == 1)
{
ck_tile::HostTensor<T> t_view({b_, n_ / 16, 16, k_ / 64, 4, 16});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5});
}
return t;
} }
// fill host_tensor (tokens * topk entries) with random expert ids such that
// the topk ids of every token are pairwise distinct. deterministic for a
// given seed (uses srand/rand).
// BUG FIX: requires topk <= num_expert; otherwise the rejection loop below
// could never terminate. guard it explicitly.
template <typename IndexType>
void topid_unique_gen(
    std::vector<IndexType>& host_tensor, int tokens, int topk, int num_expert, int seed)
{
    assert(topk <= num_expert);
    size_t total_size = static_cast<size_t>(topk) * tokens;
    std::srand(seed);
    std::set<IndexType> unique_set;
    IndexType current_v;
    for(size_t i = 0; i < total_size; i++)
    {
        // a new token starts every topk entries; reset the per-token pool
        if(i % topk == 0)
        {
            unique_set.clear();
        }
        // rejection-sample until we hit an expert not yet used by this token
        current_v = std::rand() % num_expert;
        while(unique_set.find(current_v) != unique_set.end())
        {
            current_v = std::rand() % num_expert;
        }
        unique_set.insert(current_v);
        host_tensor[i] = current_v;
    }
}
auto create_args(int argc, char* argv[]) auto create_args(int argc, char* argv[])
...@@ -55,8 +104,11 @@ auto create_args(int argc, char* argv[]) ...@@ -55,8 +104,11 @@ auto create_args(int argc, char* argv[])
.insert("prec_sq", "auto", "(dynamic) smooth quant data type. auto will set to fp32") .insert("prec_sq", "auto", "(dynamic) smooth quant data type. auto will set to fp32")
.insert("prec_kw", "auto", "topk-weight data type. auto will set to fp32") .insert("prec_kw", "auto", "topk-weight data type. auto will set to fp32")
.insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant") .insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant")
.insert("gonly", "0", "w0(gate/up) style, 0:gate+up will double interm size, 1:only gate") .insert(
.insert("balance", "1", "if set to 1, will try balance the expert in topk-ids(convenient for testing)") "gate_only", "0", "w0(gate/up) style, 0:gate+up will double interm size, 1:only gate")
.insert("balance",
"1",
"if set to 1, will try balance the expert in topk-ids(convenient for testing)")
.insert("warmup", "5", "cold iter") .insert("warmup", "5", "cold iter")
.insert("repeat", "20", "hot iter"); .insert("repeat", "20", "hot iter");
...@@ -64,133 +116,178 @@ auto create_args(int argc, char* argv[]) ...@@ -64,133 +116,178 @@ auto create_args(int argc, char* argv[])
return std::make_tuple(result, arg_parser); return std::make_tuple(result, arg_parser);
} }
// I:input-type, W:weight-type, O:output-type, ST:toke-scale-tpye, SW:weight-scale-type, SQ:smooth-quant-type, KW:topk-weight-type // I:input-type, W:weight-type, O:output-type, ST:toke-scale-tpye, SW:weight-scale-type,
// SQ:smooth-quant-type, KW:topk-weight-type
template <typename I, typename W, typename O, typename ST, typename SW, typename SQ, typename KW> template <typename I, typename W, typename O, typename ST, typename SW, typename SQ, typename KW>
bool run(const ck_tile::ArgParser& arg_parser) bool run(const ck_tile::ArgParser& arg_parser)
{ {
ck_tile::index_t tokens = arg_parser.get_int("t"); ck_tile::index_t tokens = arg_parser.get_int("t");
ck_tile::index_t experts = arg_parser.get_int("e"); ck_tile::index_t experts = arg_parser.get_int("e");
ck_tile::index_t topk = arg_parser.get_int("k"); ck_tile::index_t topk = arg_parser.get_int("k");
ck_tile::index_t hidden_size = arg_parser.get_int("h"); ck_tile::index_t hidden_size = arg_parser.get_int("h");
ck_tile::index_t intermediate_size = arg_parser.get_int("i"); ck_tile::index_t intermediate_size = arg_parser.get_int("i");
ck_tile::index_t stride = arg_parser.get_int("stride"); ck_tile::index_t stride = arg_parser.get_int("stride");
ck_tile::index_t block_m = arg_parser.get_int("bm"); ck_tile::index_t block_m = arg_parser.get_int("bm");
if(stride < 0) if(stride < 0)
stride = hidden_size; stride = hidden_size;
std::string prec_i = arg_parser.get_str("prec_i"); std::string prec_i = arg_parser.get_str("prec_i");
std::string prec_w = arg_parser.get_str("prec_w"); std::string prec_w = arg_parser.get_str("prec_w");
std::string prec_o = arg_parser.get_str("prec_o"); std::string prec_o = arg_parser.get_str("prec_o");
std::string prec_st = arg_parser.get_str("prec_st"); std::string prec_st = arg_parser.get_str("prec_st");
std::string prec_sw = arg_parser.get_str("prec_sw"); std::string prec_sw = arg_parser.get_str("prec_sw");
std::string prec_sq = arg_parser.get_str("prec_sq"); std::string prec_sq = arg_parser.get_str("prec_sq");
std::string prec_kw = arg_parser.get_str("prec_kw"); std::string prec_kw = arg_parser.get_str("prec_kw");
prec_st = (prec_st == "auto") ? "fp32" : prec_st; prec_st = (prec_st == "auto") ? "fp32" : prec_st;
prec_sw = (prec_sw == "auto") ? "fp32" : prec_sw; prec_sw = (prec_sw == "auto") ? "fp32" : prec_sw;
prec_sq = (prec_sq == "auto") ? "fp32" : prec_sq; prec_sq = (prec_sq == "auto") ? "fp32" : prec_sq;
prec_kw = (prec_kw == "auto") ? "fp32" : prec_kw; prec_kw = (prec_kw == "auto") ? "fp32" : prec_kw;
int kname = arg_parser.get_int("kname"); int kname = arg_parser.get_int("kname");
int do_validation = arg_parser.get_int("v"); int do_validation = arg_parser.get_int("v");
int warmup = arg_parser.get_int("warmup"); int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat"); int repeat = arg_parser.get_int("repeat");
int fused_quant = arg_parser.get_int("fquant"); int fused_quant = arg_parser.get_int("fquant");
int gonly = arg_parser.get_int("gonly"); int gate_only = arg_parser.get_int("gate_only");
int balance = arg_parser.get_int("balance"); int balance = arg_parser.get_int("balance");
int tp = arg_parser.get_int("tp"); int tp = arg_parser.get_int("tp");
ck_tile::index_t shared_intermediate_size = intermediate_size * (gonly ? 1 : 2) / tp;
ck_tile::index_t shared_intermediate_size = intermediate_size * (gate_only ? 1 : 2) / tp;
using TypeConfig = FusedMoeGemmTypeConfig<I, W, O, ST, SW, SQ, KW>; using TypeConfig = FusedMoeGemmTypeConfig<I, W, O, ST, SW, SQ, KW>;
using ADataType = typename TypeConfig::ADataType ; using ADataType = typename TypeConfig::ADataType;
using GDataType = typename TypeConfig::GDataType ; using GDataType = typename TypeConfig::GDataType;
using DDataType = typename TypeConfig::DDataType ; using DDataType = typename TypeConfig::DDataType;
using AccDataType = typename TypeConfig::AccDataType ; using AccDataType = typename TypeConfig::AccDataType;
using ODataType = typename TypeConfig::ODataType ; using ODataType = typename TypeConfig::ODataType;
using AScaleDataType = typename TypeConfig::AScaleDataType ; using AScaleDataType = typename TypeConfig::AScaleDataType;
using W0ScaleDataType = typename TypeConfig::W0ScaleDataType ; using GScaleDataType = typename TypeConfig::GScaleDataType;
using W1ScaleDataType = typename TypeConfig::W1ScaleDataType ; using DScaleDataType = typename TypeConfig::DScaleDataType;
using YSmoothScaleDataType = typename TypeConfig::YSmoothScaleDataType; using YSmoothScaleDataType = typename TypeConfig::YSmoothScaleDataType;
using TopkWeightDataType = typename TypeConfig::TopkWeightDataType ; using TopkWeightDataType = typename TypeConfig::TopkWeightDataType;
using IndexDataType = typename TypeConfig::IndexDataType ; using IndexDataType = typename TypeConfig::IndexDataType;
// host verify // host verify
ck_tile::HostTensor<ADataType> a_host({tokens, hidden_size}, {stride, 1}); ck_tile::HostTensor<ADataType> a_host({tokens, hidden_size}, {stride, 1});
ck_tile::HostTensor<ADataType> g_host({e, shared_intermediate_size, hidden_size}); ck_tile::HostTensor<GDataType> g_host({e, shared_intermediate_size, hidden_size});
ck_tile::HostTensor<ADataType> d_host({e, intermediate_size, hidden_size}); ck_tile::HostTensor<DDataType> d_host({e, intermediate_size, hidden_size});
ck_tile::HostTensor<ODataType> o_host({tokens, hidden_size}, {stride, 1});
ck_tile::HostTensor<AScaleDataType> sa_host({tokens});
ck_tile::HostTensor<XResidualDataType> x_residual_host({m, n}, {stride, 1}); ck_tile::HostTensor<GScaleDataType> sg_host({shared_intermediate_size});
ck_tile::HostTensor<YResidualDataType> y_residual_host({m, n}, {stride, 1}); ck_tile::HostTensor<DScaleDataType> sd_host({intermediate_size});
ck_tile::HostTensor<YSmoothScaleDataType> sy_host({intermediate_size}); // smooth-quant
ck_tile::HostTensor<YDataType> y_host_ref({m, n}, {stride, 1}); ck_tile::HostTensor<IndexDataType> topk_ids_host({tokens, topk}); // to be sort
ck_tile::HostTensor<YDataType> y_host_dev({m, n}, {stride, 1}); ck_tile::HostTensor<TopkWeightDataType> topk_weight_host({tokens, topk}); // to be sort
ck_tile::HostTensor<MeanDataType> mean_host_ref({m}); int max_num_tokens_padded = topk * tokens + experts * (block_m - 1);
ck_tile::HostTensor<InvStdDataType> invStd_host_ref({m}); ck_tile::HostTensor<IndexDataType> sorted_token_ids_host({max_num_tokens_padded});
ck_tile::HostTensor<YScaleDataType> y_scale_host_ref({m}); ck_tile::HostTensor<TopkWeightDataType> sorted_weight_host({max_num_tokens_padded});
ck_tile::HostTensor<YScaleDataType> y_scale_host_dev({m}); ck_tile::HostTensor<IndexDataType> sorted_expert_ids_host(
{(max_num_tokens_padded + block_m - 1) / block_m});
ck_tile::HostTensor<XScaleDataType> x_scale_host({n}); ck_tile::HostTensor<IndexDataType> num_sorted_tiles_host({1});
ck_tile::HostTensor<XScaleDataType> x_scale_host_dev({n});
// permute weight
ck_tile::HostTensor<GDataType> g_perm_host = shuffle_moe_weight(g_host, prec_w);
ck_tile::HostTensor<DDataType> d_perm_host = shuffle_moe_weight(d_host, prec_w);
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_host); ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_host);
ck_tile::FillUniformDistribution<XResidualDataType>{-.5f, .5f}(x_residual_host); ck_tile::FillUniformDistribution<GDataType>{-.5f, .5f}(g_perm_host);
ck_tile::FillUniformDistribution<XScaleDataType>{-1.f, 1.f}(x_scale_host); ck_tile::FillUniformDistribution<DDataType>{-.5f, .5f}(d_perm_host);
ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host); ck_tile::FillUniformDistribution<AScaleDataType>{-.5f, .5f}(sa_host);
ck_tile::FillUniformDistribution<BetaDataType>{-.5f, .5f}(beta_host); ck_tile::FillUniformDistribution<GScaleDataType>{-.5f, .5f}(sg_host);
ck_tile::FillUniformDistribution<DScaleDataType>{-.5f, .5f}(sd_host);
ck_tile::DeviceMem x_buf(a_host.get_element_space_size_in_bytes()); ck_tile::FillUniformDistribution<YSmoothScaleDataType>{-.5f, .5f}(sy_host);
ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes()); ck_tile::FillUniformDistribution<TopkWeightDataType>{-.5f, .5f}(topk_weight_host);
ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes()); // do moe sorting
ck_tile::DeviceMem y_scale_buf(y_scale_host_dev.get_element_space_size_in_bytes()); if(balance)
ck_tile::DeviceMem x_scale_buf(x_scale_host_dev.get_element_space_size_in_bytes()); {
int e_cnt = 0 for(int i = 0; i < static_cast<int>(topk_ids_host.mData.size()); i++)
ck_tile::DeviceMem x_residual_buf(x_residual_host.get_element_space_size_in_bytes()); {
ck_tile::DeviceMem y_residual_buf(y_residual_host.get_element_space_size_in_bytes()); topk_ids_host.mData[i] = e_cnt;
e_cnt++;
x_buf.ToDevice(a_host.data()); if(e_cnt >= experts)
gamma_buf.ToDevice(gamma_host.data()); e_cnt = 0;
beta_buf.ToDevice(beta_host.data()); }
x_residual_buf.ToDevice(x_residual_host.data()); }
x_scale_buf.ToDevice(x_scale_host.data()); else
{
topid_unique_gen<IndexType>(topk_ids_host.mData, tokens, topk, experts, 11913);
}
ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>(
topk_ids_host,
topk_weight_host,
sorted_token_ids_host,
sorted_weight_host,
sorted_expert_ids_host,
num_sorted_tiles_host.mData[0],
experts,
block_m);
// done, preparing GPU buffer
ck_tile::DeviceMem a_buf(a_host);
ck_tile::DeviceMem g_perm_buf(g_perm_host);
ck_tile::DeviceMem d_perm_buf(d_perm_host);
ck_tile::DeviceMem sa_buf(sa_host);
ck_tile::DeviceMem sg_buf(sg_host);
ck_tile::DeviceMem sd_buf(sd_host);
ck_tile::DeviceMem sy_buf(sy_host);
ck_tile::DeviceMem o_buf(o_host);
ck_tile::DeviceMem sorted_token_ids_buf(sorted_token_ids_host);
ck_tile::DeviceMem sorted_weight_buf(sorted_weight_host);
ck_tile::DeviceMem sorted_expert_ids_buf(sorted_expert_ids_host);
ck_tile::DeviceMem num_sorted_tiles_buf(num_sorted_tiles_host);
auto prec_str = [&]() { auto prec_str = [&]() {
auto base_str = prec_i; auto base_str = prec_i;
if(prec_i != prec_w)
base_str += "x" + prec_w;
if(prec_i != prec_o) if(prec_i != prec_o)
base_str += "=" + prec_o;
if(fused_quant != 0)
{ {
base_str += "|" + prec_o; base_str += std::string("(") + prec_sa + "|" + prec_sg + "|" + prec_sq + ")";
}
if(fused_quant == 1)
{
base_str += std::string("(") + prec_sy + ")";
} }
return base_str; return base_str;
}(); }();
std::cout << "[" << prec_str << "]" std::cout << "[" << prec_str << "]"
<< " m:" << m << ", n:" << n << ", stride:" << stride << std::flush; << " t:" << tokens << ", e:" << experts << ", k:" << topk << ", st:" << stride
<< ", hidden:" << hidden_size << ", interm:" << intermediate_size << ", tp:" << tp
layernorm2d_fwd_traits traits{ << ", go:" << gate_only << ", q:" << fused_quant << std::flush;
prec_i, prec_o, prec_sx, prec_sy, SaveMeanVar, fused_add, fused_quant};
fused_moegemm_traits traits{prec_i,
layernorm2d_fwd_args args{x_buf.GetDeviceBuffer(), prec_w,
fused_add != 0 ? x_residual_buf.GetDeviceBuffer() : nullptr, prec_o,
fused_quant == 1 ? x_scale_buf.GetDeviceBuffer() : nullptr, prec_st,
gamma_buf.GetDeviceBuffer(), prec_sw,
beta_buf.GetDeviceBuffer(), prec_sq,
prec_kw,
y_buf.GetDeviceBuffer(), block_m,
fused_add == 1 ? y_residual_buf.GetDeviceBuffer() : nullptr, gate_only,
fused_quant != 0 ? y_scale_buf.GetDeviceBuffer() : nullptr, fused_quant};
nullptr, // p_mean, unsupported yet
nullptr, // p_invStd, unsupported yet fused_moegemm_args args{a_buf.GetDeviceBuffer(),
fused_quant != 0 ? sa_buf.GetDeviceBuffer() : nullptr,
epsilon, g_buf.GetDeviceBuffer(),
m, d_buf.GetDeviceBuffer(),
n, fused_quant != 0
stride}; ? sg_buf.GetDeviceBuffer(),
fused_quant != 0
float ave_time = layernorm2d_fwd( ? sd_buf.GetDeviceBuffer(),
fused_quant == 1
? sy_buf.GetDeviceBuffer(),
o_buf.GetDeviceBuffer(),
sorted_token_ids_buf.GetDeviceBuffer(),
sorted_weight_buf.GetDeviceBuffer(),
sorted_expert_ids_buf.GetDeviceBuffer(),
num_sorted_tiles_buf.GetDeviceBuffer(),
hidden_size,
intermediate_size,
num_tokens,
experts,
stride };
float ave_time = fused_moegemm(
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat}); traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
if(ave_time < 0) if(ave_time < 0)
...@@ -199,22 +296,30 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -199,22 +296,30 @@ bool run(const ck_tile::ArgParser& arg_parser)
return false; return false;
} }
#if 0
std::size_t num_byte = sizeof(ADataType) * m * n + sizeof(GammaDataType) * n + std::size_t num_byte = sizeof(ADataType) * m * n + sizeof(GammaDataType) * n +
sizeof(BetaDataType) * n + sizeof(YDataType) * m * n; sizeof(BetaDataType) * n + sizeof(YDataType) * m * n;
float gb_per_sec = num_byte / 1.E6 / ave_time; float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush; std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush;
#else
std::size_t flop_gemm_0 = 2 * tokens * topk * shared_intermediate_size * hidden_size;
std::size_t flop_gemm_1 = 2 * tokens * topk * hidden_size * hidden_size;
double tflops = (flop_gemm_0 + flop_gemm_1) / (static_cast<double>(ave_time) * 1e-3) / 1e12;
// float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << ", " << ave_time * 1.E3 << " us, " << tflops << " tflops" << std::flush;
#endif
bool pass = true; bool pass = true;
if(do_validation) if(do_validation)
{ {
#if 0
// reference // reference
if(fused_add != 0) if(fused_add != 0)
{ {
// fused pre_add/pre_add_store // fused pre_add/pre_add_store
// TODO we accumulate directly to a_host for simplcity here... // TODO we accumulate directly to a_host for simplcity here...
std::transform(a_host.mData.cbegin(), std::transform(a_host.mData.cbegin(),
a_host.mData.cend(), a_host.mData.cend(),
x_residual_host.mData.cbegin(), x_residual_host.mData.cbegin(),
...@@ -353,6 +458,9 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -353,6 +458,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
} }
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
#else
std::cout << std::flush << std::endl;
#endif
} }
return pass; return pass;
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include "ck_tile/host/reference/reference_gemm.hpp" #include "ck_tile/host/reference/reference_gemm.hpp"
#include "ck_tile/host/reference/reference_im2col.hpp" #include "ck_tile/host/reference/reference_im2col.hpp"
#include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp" #include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp"
#include "ck_tile/host/reference/reference_moe_sorting.hpp"
#include "ck_tile/host/reference/reference_permute.hpp" #include "ck_tile/host/reference/reference_permute.hpp"
#include "ck_tile/host/reference/reference_reduce.hpp" #include "ck_tile/host/reference/reference_reduce.hpp"
#include "ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp" #include "ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp"
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <stdint.h> #include <stdint.h>
#include <stdexcept> #include <stdexcept>
#include "ck_tile/host/hip_check_error.hpp" #include "ck_tile/host/hip_check_error.hpp"
#include "ck_tile/host/host_tensor.hpp"
namespace ck_tile { namespace ck_tile {
template <typename T> template <typename T>
...@@ -36,6 +37,19 @@ struct DeviceMem ...@@ -36,6 +37,19 @@ struct DeviceMem
mpDeviceBuf = nullptr; mpDeviceBuf = nullptr;
} }
} }
// allocate a device buffer sized for the host tensor and copy its contents
// over. NOTE(review): assumes the HostTensor's storage is contiguous —
// confirm against HostTensor's layout guarantees.
// BUG FIX: "template <T>" was missing the `typename` keyword; also skip the
// copy entirely for an empty tensor instead of calling ToDevice on a null buffer.
template <typename T>
DeviceMem(const HostTensor<T>& t) : mMemSize(t.get_element_space_size_in_bytes())
{
    if(mMemSize != 0)
    {
        HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
        ToDevice(t.data());
    }
    else
    {
        mpDeviceBuf = nullptr;
    }
}
void Realloc(std::size_t mem_size) void Realloc(std::size_t mem_size)
{ {
if(mpDeviceBuf) if(mpDeviceBuf)
...@@ -92,6 +106,22 @@ struct DeviceMem ...@@ -92,6 +106,22 @@ struct DeviceMem
HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, cpySize, hipMemcpyDeviceToHost)); HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, cpySize, hipMemcpyDeviceToHost));
} }
} }
// construct a host tensor with element type T and copy the device buffer back.
// BUG FIX: a default argument cannot reference the non-static member mMemSize,
// so "copy everything" is expressed as a separate overload; also the original
// was missing the ';' between the element-count expression and the declaration.
template <typename T>
HostTensor<T> ToHost()
{
    return ToHost<T>(mMemSize);
}

template <typename T>
HostTensor<T> ToHost(std::size_t cpySize)
{
    // TODO: host tensor could be slightly larger than the device tensor;
    // round the element count up so all requested bytes fit
    std::size_t host_elements = (cpySize + sizeof(T) - 1) / sizeof(T);
    HostTensor<T> h_({host_elements});
    if(mpDeviceBuf)
    {
        HIP_CHECK_ERROR(hipMemcpy(h_.data(), mpDeviceBuf, cpySize, hipMemcpyDeviceToHost));
    }
    return h_;
}
void SetZero() const void SetZero() const
{ {
if(mpDeviceBuf) if(mpDeviceBuf)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
namespace ck_tile {
// reference (CPU) implementation of moe token sorting: groups the topk token
// ids per expert into unit_size-sized tiles, padding incomplete tiles with
// `num_token` (an out-of-range token id) and weight 0.
// outputs (expert-major):
//   sorted_token_ids  - token id per slot, tile-padded
//   sorted_weight     - matching topk weight per slot
//   sorted_expert_ids - expert id of each tile
//   unit_cnt          - total number of tiles written
template <typename WeightType, typename IndexType = index_t>
CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
                                        const HostTensor<WeightType>& weights,
                                        HostTensor<IndexType>& sorted_token_ids,
                                        HostTensor<WeightType>& sorted_weight,
                                        HostTensor<IndexType>& sorted_expert_ids,
                                        index_t& unit_cnt,
                                        const index_t experts,
                                        const index_t unit_size)
{
    const index_t num_token = topk_ids.mDesc.get_lengths()[0];
    const index_t topk      = topk_ids.mDesc.get_lengths()[1];

    // per-expert slot arrays, pre-filled with one padded tile
    std::vector<std::vector<IndexType>> expert_tokens(experts,
                                                      std::vector<IndexType>(unit_size, num_token));
    std::vector<std::vector<WeightType>> expert_token_weights(
        experts, std::vector<WeightType>(unit_size, 0));
    std::vector<IndexType> expert_slices(experts, 1);     // tiles allocated per expert
    std::vector<IndexType> expert_slice_idxs(experts, 0); // next free slot per expert

    // BUG FIX: compute the tile count from scratch instead of relying on the
    // caller having zero-initialized the output reference
    unit_cnt = 0;

    for(index_t t = 0; t < num_token; t++)
    {
        for(index_t k = 0; k < topk; k++)
        {
            IndexType e  = topk_ids(t, k);
            WeightType w = weights(t, k);
            index_t idx  = expert_slice_idxs[e];
            if(idx > expert_slices[e] * unit_size - 1)
            {
                // current tile is full; grow this expert by one padded tile
                expert_slices[e]++;
                index_t new_size = expert_slices[e] * unit_size;
                expert_tokens[e].resize(new_size);
                expert_token_weights[e].resize(new_size);
                for(index_t i = (expert_slices[e] - 1) * unit_size; i < new_size; i++)
                {
                    expert_tokens[e][i]        = num_token;
                    expert_token_weights[e][i] = 0;
                }
            }
            expert_tokens[e][idx]        = t;
            expert_token_weights[e][idx] = w;
            expert_slice_idxs[e]++;
        }
    }

    // flatten per-expert tiles into the output buffers, expert-major
    IndexType* out_tokens    = sorted_token_ids.data();
    WeightType* out_weights  = sorted_weight.data();
    IndexType* out_expert_id = sorted_expert_ids.data();
    for(index_t e = 0; e < experts; e++)
    {
        // BUG FIX: was sizeof(index_t) — wrong byte count when IndexType is
        // not index_t
        memcpy(out_tokens,
               expert_tokens[e].data(),
               sizeof(IndexType) * expert_slices[e] * unit_size);
        out_tokens += expert_slices[e] * unit_size;
        memcpy(out_weights,
               expert_token_weights[e].data(),
               sizeof(WeightType) * expert_slices[e] * unit_size);
        out_weights += expert_slices[e] * unit_size;
        for(index_t s = 0; s < expert_slices[e]; s++)
        {
            out_expert_id[s] = e;
            unit_cnt++;
        }
        out_expert_id += expert_slices[e];
    }
    return;
}
} // namespace ck_tile
...@@ -56,11 +56,10 @@ reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::v ...@@ -56,11 +56,10 @@ reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::v
} }
template <typename DataType> template <typename DataType>
CK_TILE_HOST auto CK_TILE_HOST auto reference_permute(const HostTensor<DataType>& x, std::vector<index_t> perm)
reference_permute(const HostTensor<DataType>& x, std::vector<index_t> perm)
{ {
auto x_shape = x.get_lengths(); auto x_shape = x.get_lengths();
ck_tile::index_t rank = perm.size(); ck_tile::index_t rank = perm.size();
std::vector<ck_tile::index_t> y_shape = [&]() { std::vector<ck_tile::index_t> y_shape = [&]() {
std::vector<ck_tile::index_t> tmp(rank, 0); std::vector<ck_tile::index_t> tmp(rank, 0);
for(int i = 0; i < static_cast<int>(rank); i++) for(int i = 0; i < static_cast<int>(rank); i++)
......
...@@ -3,12 +3,12 @@ ...@@ -3,12 +3,12 @@
#pragma once #pragma once
#include "ck_tile/ops/fused_moe/kernel/fused_moe_kernel.hpp" #include "ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp"
#include "ck_tile/ops/fused_moe/kernel/fused_moe_shape.hpp" #include "ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp"
#include "ck_tile/ops/fused_moe/kernel/fused_moe_tile_partitioner.hpp" #include "ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moe_pipeline_flatmm.hpp" #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moe_pipeline_flatmm_policy.hpp" #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moe_pipeline_problem.hpp" #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moe_traits.hpp" #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/tensor_layout.hpp"
...@@ -22,17 +22,17 @@ ...@@ -22,17 +22,17 @@
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5 // (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]] // weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
// //
// max_tokens_post_padded : top_k * input_tokens + num_experts * (M_a - 1) // max_num_tokens_padded : top_k * input_tokens + num_experts * (M_a - 1)
// * this could be larger than actual, since actual tokens are on GPU // * this could be larger than actual, since actual tokens are on GPU
// //
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5] // sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -| // |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -|
// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o] // sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
// //
// * length is max_tokens_post_padded, actual size is num_tokens_post_padded_ptr // * length is max_num_tokens_padded, actual size is num_tokens_post_padded_ptr
// //
// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5] // sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5]
// * length is (max_tokens_post_padded + block_size - 1) / block_size // * length is (max_num_tokens_padded + block_size - 1) / block_size
// //
// num_tokens_post_padded_ptr : [28] // num_tokens_post_padded_ptr : [28]
// num_sorted_tiles_ptr : [7] // num_sorted_tiles_ptr : [7]
...@@ -43,11 +43,12 @@ ...@@ -43,11 +43,12 @@
// 3) use num_sorted_tiles_ptr, already divided by M_a // 3) use num_sorted_tiles_ptr, already divided by M_a
// //
// * below used for indexing // * below used for indexing
// 1) sorted_token_ids_ptr // 1) sorted_token_ids_ptr [max_num_tokens_padded]
// 2) sorted_weight_ptr // 2) sorted_weight_ptr
// 3) sorted_expert_ids_ptr // 3) sorted_expert_ids_ptr
// 4)num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one) // 4)num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one)
// //
// max_num_tokens_padded: topk_ids.numel() + num_experts * (block_size - 1)
// //
// [indexing implementation-2] // [indexing implementation-2]
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]] // before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
...@@ -92,15 +93,15 @@ struct FusedMoeGemmHostArgs ...@@ -92,15 +93,15 @@ struct FusedMoeGemmHostArgs
const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input
void* o_ptr; // [m, k], output token void* o_ptr; // [m, k], output token
const void* sorted_token_ids_ptr; const void* sorted_token_ids_ptr; // [max_num_tokens_padded]
const void* sorted_weight_ptr; const void* sorted_weight_ptr; // [max_num_tokens_padded]
const void* sorted_expert_ids_ptr; const void* sorted_expert_ids_ptr; // [(max_num_tokens_padded + block_size - 1) / block_size]
const void* num_sorted_tiles_ptr; const void* num_sorted_tiles_ptr; // [1]
index_t hidden_size; // k index_t hidden_size; // k
index_t intermediate_size; // n (TP slice this) index_t intermediate_size; // n (TP slice this)
index_t num_tokens; // input number of tokens for current iteration index_t num_tokens; // input number of tokens for current iteration
index_t num_experts; // number of groups index_t num_experts; // number of groups
// index_t top_k; // need this? // index_t top_k; // need this?
index_t stride_token; // for input/output, stride for each row, should >= hidden_size index_t stride_token; // for input/output, stride for each row, should >= hidden_size
...@@ -134,10 +135,10 @@ struct FusedMoeGemmKernel ...@@ -134,10 +135,10 @@ struct FusedMoeGemmKernel
using Traits = typename Pipeline::Problem::Traits; using Traits = typename Pipeline::Problem::Traits;
static constexpr bool IsGateOnly = Traits::IsGateOnly; static constexpr bool IsGateOnly = Traits::IsGateOnly;
static constexpr bool UseSmoothQuant = Traits::UseSmoothQuant; static constexpr bool UseSmoothQuant = Traits::UseSmoothQuant;
static constexpr bool PadHiddenSize = Traits::PadHiddenSize; static constexpr bool PadHiddenSize = Traits::PadHiddenSize;
static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize; static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize;
// clang-format off // clang-format off
template <typename T> struct t2s; template <typename T> struct t2s;
...@@ -173,10 +174,10 @@ struct FusedMoeGemmKernel ...@@ -173,10 +174,10 @@ struct FusedMoeGemmKernel
const void* sorted_expert_ids_ptr; const void* sorted_expert_ids_ptr;
const void* num_sorted_tiles_ptr; const void* num_sorted_tiles_ptr;
index_t hidden_size; // k index_t hidden_size; // k
index_t intermediate_size; // n (TP slice this) index_t intermediate_size; // n (TP slice this)
index_t num_tokens; // input number of tokens for current iteration index_t num_tokens; // input number of tokens for current iteration
index_t num_experts; // number of groups index_t num_experts; // number of groups
// index_t top_k; // need this? // index_t top_k; // need this?
index_t stride_token; // for input/output, stride for each row, should >= hidden_size index_t stride_token; // for input/output, stride for each row, should >= hidden_size
...@@ -214,7 +215,7 @@ struct FusedMoeGemmKernel ...@@ -214,7 +215,7 @@ struct FusedMoeGemmKernel
index_t nr_0 = kargs.intermediate_size / Pipeline::Block_Nr0; index_t nr_0 = kargs.intermediate_size / Pipeline::Block_Nr0;
index_t kr_0 = kargs.hidden_size / Pipeline::Block_Kr0; index_t kr_0 = kargs.hidden_size / Pipeline::Block_Kr0;
index_t nr_1 = kargs.hidden_size / Pipeline::Block_Nr1; // should be same as kr_0 index_t nr_1 = kargs.hidden_size / Pipeline::Block_Nr1; // should be same as kr_0
index_t kr_1 = kargs.intermediate_size / Pipeline::Block_Kr1; // should be same as nr_0 index_t kr_1 = kargs.intermediate_size / Pipeline::Block_Kr1; // should be same as nr_0
index_t expert_stride_0 = kargs.intermediate_size * hidden_radio_0 * kargs.hidden_size; index_t expert_stride_0 = kargs.intermediate_size * hidden_radio_0 * kargs.hidden_size;
...@@ -280,11 +281,12 @@ struct FusedMoeGemmKernel ...@@ -280,11 +281,12 @@ struct FusedMoeGemmKernel
make_tuple(kr_0 * BlockShape::Block_W0, number<Pipeline::Block_W0>{}, 1), make_tuple(kr_0 * BlockShape::Block_W0, number<Pipeline::Block_W0>{}, 1),
number<Pipeline::kAlignmentG>{}, number<Pipeline::kAlignmentG>{},
number<1>{}); number<1>{});
const auto g_view_1_ = pad_tensor_view(g_view_, const auto g_view_1_ =
make_tuple(number<Pipeline::Block_Nr0>{}, pad_tensor_view(g_view_,
number<Pipeline::Block_Kr0>{}, make_tuple(number<Pipeline::Block_Nr0>{},
number<Pipeline::Block_W0>{}), number<Pipeline::Block_Kr0>{},
sequence<PadIntermediateSize, PadHiddenSize, 0>{}); number<Pipeline::Block_W0>{}),
sequence<PadIntermediateSize, PadHiddenSize, 0>{});
const auto g_window_ = make_tile_window(g_view_1_, const auto g_window_ = make_tile_window(g_view_1_,
make_tuple(number<BlockShape::Block_Nr0>{}, make_tuple(number<BlockShape::Block_Nr0>{},
...@@ -308,11 +310,12 @@ struct FusedMoeGemmKernel ...@@ -308,11 +310,12 @@ struct FusedMoeGemmKernel
make_tuple(kr_1 * Pipeline::Block_W1, Pipeline::Block_W1, 1), make_tuple(kr_1 * Pipeline::Block_W1, Pipeline::Block_W1, 1),
number<Pipeline::kAlignmentD>{}, number<Pipeline::kAlignmentD>{},
number<1>{}); number<1>{});
const auto d_view_1_ = pad_tensor_view(d_view_, const auto d_view_1_ =
make_tuple(number<Pipeline::kBlockNr_1>{}, pad_tensor_view(d_view_,
number<Pipeline::kBlockKr_1>{}, make_tuple(number<Pipeline::kBlockNr_1>{},
number<Pipeline::Block_W1>{}), number<Pipeline::kBlockKr_1>{},
sequence<PadHiddenSize, PadIntermediateSize, 0>{}); number<Pipeline::Block_W1>{}),
sequence<PadHiddenSize, PadIntermediateSize, 0>{});
const auto d_window_ = make_tile_window(d_view_1_, const auto d_window_ = make_tile_window(d_view_1_,
make_tuple(number<Pipeline::kBlockNr_1>{}, make_tuple(number<Pipeline::kBlockNr_1>{},
......
...@@ -44,10 +44,10 @@ struct FusedMoeGemmPipeline_Flatmm ...@@ -44,10 +44,10 @@ struct FusedMoeGemmPipeline_Flatmm
using Traits = typename Pipeline::Problem::Traits; using Traits = typename Pipeline::Problem::Traits;
static constexpr bool IsGateOnly = Traits::IsGateOnly; static constexpr bool IsGateOnly = Traits::IsGateOnly;
static constexpr bool UseSmoothQuant = Traits::UseSmoothQuant; static constexpr bool UseSmoothQuant = Traits::UseSmoothQuant;
static constexpr bool PadHiddenSize = Traits::PadHiddenSize; static constexpr bool PadHiddenSize = Traits::PadHiddenSize;
static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize; static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize;
static constexpr index_t kAlignmentA = Policy::GetAlignment_A<Problem>(); static constexpr index_t kAlignmentA = Policy::GetAlignment_A<Problem>();
static constexpr index_t kAlignmentG = Policy::GetAlignment_G<Problem>(); static constexpr index_t kAlignmentG = Policy::GetAlignment_G<Problem>();
...@@ -133,11 +133,12 @@ struct FusedMoeGemmPipeline_Flatmm ...@@ -133,11 +133,12 @@ struct FusedMoeGemmPipeline_Flatmm
make_tuple(kr_0 * BlockShape::Block_W0, number<BlockShape::Block_W0>{}, 1), make_tuple(kr_0 * BlockShape::Block_W0, number<BlockShape::Block_W0>{}, 1),
number<kAlignmentG>{}, number<kAlignmentG>{},
number<1>{}); number<1>{});
const auto u_view_1_ = pad_tensor_view(u_view_, const auto u_view_1_ =
make_tuple(number<BlockShape::Block_Nr0>{}, pad_tensor_view(u_view_,
number<BlockShape::Block_Kr0>{}, make_tuple(number<BlockShape::Block_Nr0>{},
number<BlockShape::Block_W0>{}), number<BlockShape::Block_Kr0>{},
sequence<PadIntermediateSize, PadHiddenSize, 0>{}); number<BlockShape::Block_W0>{}),
sequence<PadIntermediateSize, PadHiddenSize, 0>{});
return u_view_1_; return u_view_1_;
} }
}(); }();
......
...@@ -225,7 +225,8 @@ struct FusedMoeGemmPipelineFlatmmPolicy ...@@ -225,7 +225,8 @@ struct FusedMoeGemmPipelineFlatmmPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetMatrixCoreSwizzledBlockTIle_0() CK_TILE_HOST_DEVICE static constexpr auto GetMatrixCoreSwizzledBlockTIle_0()
{ {
if constexpr(Problem::Traits::PermuteEnum == FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten) if constexpr(Problem::Traits::PermuteEnum ==
FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten)
{ {
using WarpGemm = GetWarpGemm0<Problem>{}; // assume warpgemm0/1 are the same using WarpGemm = GetWarpGemm0<Problem>{}; // assume warpgemm0/1 are the same
constexpr index_t NPerBlock = Problem::BlockShape::Block_N0; constexpr index_t NPerBlock = Problem::BlockShape::Block_N0;
...@@ -703,7 +704,8 @@ struct FusedMoeGemmPipelineFlatmmPolicy ...@@ -703,7 +704,8 @@ struct FusedMoeGemmPipelineFlatmmPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetMatrixCoreSwizzledBlockTIle_0() CK_TILE_HOST_DEVICE static constexpr auto GetMatrixCoreSwizzledBlockTIle_0()
{ {
if constexpr(Problem::Traits::PermuteEnum == FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten) if constexpr(Problem::Traits::PermuteEnum ==
FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten)
{ {
using WarpGemm = GetWarpGemm0<Problem>{}; // assume warpgemm0/1 are the same using WarpGemm = GetWarpGemm0<Problem>{}; // assume warpgemm0/1 are the same
constexpr index_t NPerBlock = Problem::BlockShape::Block_N0; constexpr index_t NPerBlock = Problem::BlockShape::Block_N0;
...@@ -723,7 +725,8 @@ struct FusedMoeGemmPipelineFlatmmPolicy ...@@ -723,7 +725,8 @@ struct FusedMoeGemmPipelineFlatmmPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetMatrixCoreSwizzledBlockTIle_1() CK_TILE_HOST_DEVICE static constexpr auto GetMatrixCoreSwizzledBlockTIle_1()
{ {
if constexpr(Problem::Traits::PermuteEnum == FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten) if constexpr(Problem::Traits::PermuteEnum ==
FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten)
{ {
using WarpGemm = GetWarpGemm1<Problem>{}; // assume warpgemm0/1 are the same using WarpGemm = GetWarpGemm1<Problem>{}; // assume warpgemm0/1 are the same
constexpr index_t NPerBlock = Problem::BlockShape::kBlockN_1; constexpr index_t NPerBlock = Problem::BlockShape::kBlockN_1;
......
...@@ -14,8 +14,8 @@ template <typename ADataType_, ...@@ -14,8 +14,8 @@ template <typename ADataType_,
typename AccDataType_, typename AccDataType_,
typename ODataType_, typename ODataType_,
typename AScaleDataType_, typename AScaleDataType_,
typename W0ScaleDataType_, typename GScaleDataType_,
typename W1ScaleDataType_, typename DScaleDataType_,
typename YSmoothScaleDataType_, typename YSmoothScaleDataType_,
typename TopkWeightDataType_, typename TopkWeightDataType_,
typename IndexDataType_, // data type for all indexing typename IndexDataType_, // data type for all indexing
......
...@@ -19,14 +19,18 @@ enum class FusedMoeGemmWeightPermuteEnum ...@@ -19,14 +19,18 @@ enum class FusedMoeGemmWeightPermuteEnum
template <bool IsGateOnly_, template <bool IsGateOnly_,
bool UseSmoothQuant_, bool UseSmoothQuant_,
index_t OAtomic_, // 0-no atomic, 1-atomic-pk-f16/bf16, 2-atomic-f32 index_t OAtomic_, // 0-no atomic, 1-atomic-pk-f16/bf16, 2-atomic-f32
FusedMoeGemmWeightPermuteEnum PermuteEnum_ = FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten; FusedMoeGemmWeightPermuteEnum PermuteEnum_ =
bool PadHiddenSize_ = false, bool PadIntermediateSize_ = false > struct FusedMoeGemmTraits FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten,
bool PadHiddenSize_ = false,
bool PadIntermediateSize_ = false>
struct FusedMoeGemmTraits
{ {
// Gate+Up or Gate only // Gate+Up or Gate only
static constexpr bool IsGateOnly = IsGateOnly_; static constexpr bool IsGateOnly = IsGateOnly_;
static constexpr bool UseSmoothQuant = UseSmoothQuant_; static constexpr bool UseSmoothQuant = UseSmoothQuant_;
static constexpr index_t OAtomic = OAtomic_; static constexpr index_t OAtomic = OAtomic_;
static constexpr bool PadHiddenSize = PadHiddenSize_; static constexpr FusedMoeGemmWeightPermuteEnum PermuteEnum = PermuteEnum_;
static constexpr bool PadIntermediateSize = PadIntermediateSize_; static constexpr bool PadHiddenSize = PadHiddenSize_;
static constexpr bool PadIntermediateSize = PadIntermediateSize_;
}; };
} // namespace ck_tile } // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment