Unverified Commit 1ff50e78 authored by carlushuang, committed by GitHub

[CK_TILE] Fix mock token id, support g1u1/g1u0 through same inline code block (#1808)

* fix mock token id

* prepare host for g1u1

* reformat inline-asm

* restructure uk_0

* restructure gate_up

* done

* change default to init=1

* update readme

* fix a bug in interleave pipeline

* rcp for silu
parent 8c29e06f
......@@ -8,6 +8,9 @@ The benefit of this fused-moe:
* far fewer kernel instances, easy to maintain
# Implementation and feature support
## NOTES:
Currently the gate+up case in fp16 can easily overflow the accumulator past the fp16 max (65504), resulting in INF. Please use BF16 for the gate+up case; the API side does not check for this.
## moe-sorting
this is a common pre-processing step before the actual moe-gemm. The purpose is to transform the moe loop from token-by-token to expert-by-expert, making sure every workgroup works on a single expert (B matrix). Besides, we extend this op to also zero the output buffer (to be used as the reduce buffer with atomics).
......
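A minimal host-side sketch of the moe-sorting idea described in the note above (illustrative only; it is not the ck_tile moe-sorting kernel, and the struct/function names and the padding value are assumptions):

#include <cstddef>
#include <cstdint>
#include <vector>

struct SortedMoe
{
    std::vector<int32_t> sorted_token_ids;  // one slot per row of a block_m tile
    std::vector<int32_t> sorted_expert_ids; // one expert id per block_m tile
};

// Group the (token, topk) pairs expert-by-expert and pad every expert's segment
// to a multiple of block_m, so each workgroup only ever touches a single expert.
inline SortedMoe moe_sort(const std::vector<int32_t>& topk_ids, // tokens x topk, row-major
                          int32_t tokens, int32_t topk, int32_t experts, int32_t block_m)
{
    SortedMoe out;
    for(int32_t e = 0; e < experts; ++e)
    {
        const std::size_t begin = out.sorted_token_ids.size();
        for(int32_t t = 0; t < tokens; ++t)
            for(int32_t k = 0; k < topk; ++k)
                if(topk_ids[t * topk + k] == e)
                    out.sorted_token_ids.push_back(t);
        // pad with an out-of-range token id; the kernel skips rows with id >= tokens
        while((out.sorted_token_ids.size() - begin) % block_m != 0)
            out.sorted_token_ids.push_back(tokens);
        for(std::size_t i = begin; i < out.sorted_token_ids.size(); i += block_m)
            out.sorted_expert_ids.push_back(e);
    }
    return out;
}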
......@@ -26,7 +26,7 @@ struct fused_moe_args
ck_tile::index_t block_m; // block_m, used to divide the input
ck_tile::index_t hidden_size; // k
ck_tile::index_t intermediate_size; // n / TP, for Gate. if Gate+Up, Down need divide by 2
ck_tile::index_t intermediate_size; // n / TP, for Gate; Up and Down are also this value
ck_tile::index_t num_tokens; // input number of tokens for current iteration
ck_tile::index_t num_experts; // number of groups
ck_tile::index_t topk; // need this?
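To make the new intermediate_size convention concrete, a hedged sketch of the per-expert weight element counts it implies (helper names are illustrative, not part of the API):

#include <cstdint>

// gate(+up) weight of one expert: [intermediate_size (doubled for g1u1), hidden_size]
inline int64_t gate_up_weight_elems(int64_t hidden_size, int64_t intermediate_size, bool gate_only)
{
    return (gate_only ? 1 : 2) * intermediate_size * hidden_size;
}

// down weight of one expert is always [hidden_size, intermediate_size]
inline int64_t down_weight_elems(int64_t hidden_size, int64_t intermediate_size)
{
    return hidden_size * intermediate_size;
}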
......@@ -45,7 +45,8 @@ struct fused_moe_traits
std::string prec_sq; // smooth quant scale
std::string prec_kw; // topk-weight data type
int block_m;
int gate_only;
int activation; // 0:gelu, 1:silu
int gate_only; // 0:g1u1 (gate+up), 1:g1u0 (gate only)
int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
};
......
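As a usage illustration, a traits instance for a bf16 g1u1 kernel with SiLU might look like the following (a sketch assuming the example's fused_moe header is included; fields not visible in the hunk above, such as prec_i/prec_w/prec_o/prec_st/prec_sw, are assumed to exist as used by the dispatch code further down):

fused_moe_traits t{};
t.prec_i  = "bf16"; t.prec_w  = "bf16"; t.prec_o  = "bf16";
t.prec_st = "fp32"; t.prec_sw = "fp32"; t.prec_sq = "fp32"; t.prec_kw = "fp32";
t.block_m     = 32;
t.activation  = 1; // 0:gelu, 1:silu
t.gate_only   = 0; // g1u1: gate + up
t.fused_quant = 0; // no fused quantization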
......@@ -77,7 +77,8 @@ struct fused_moegemm_traits
std::string prec_sq; // smooth quant scale
std::string prec_kw; // topk-weight data type
int block_m;
int gate_only;
int activation; // 0:gelu, 1:silu
int gate_only; // 0:g1u1 (gate+up), 1:g1u0 (gate only)
int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
};
......
......@@ -41,6 +41,7 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf
t.prec_sq,
t.prec_kw,
t.block_m,
t.activation,
t.gate_only,
t.fused_quant};
auto a1 = fused_moegemm_args{
......
......@@ -17,15 +17,67 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
// clang-format off
float r = -1;
if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1)
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 && t.activation == 0)
{
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>;
constexpr ck_tile::index_t act_ = 0;
constexpr ck_tile::index_t go_ = 1;
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
r = fused_moegemm_<t_>(s, a);
}
else if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0 && t.activation == 0)
{
constexpr ck_tile::index_t act_ = 0;
constexpr ck_tile::index_t go_ = 0;
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
r = fused_moegemm_<t_>(s, a);
}
else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 && t.activation == 0)
{
constexpr ck_tile::index_t act_ = 0;
constexpr ck_tile::index_t go_ = 1;
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
r = fused_moegemm_<t_>(s, a);
}
else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0 && t.activation == 0)
{
constexpr ck_tile::index_t act_ = 0;
constexpr ck_tile::index_t go_ = 0;
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
r = fused_moegemm_<t_>(s, a);
}
else if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 && t.activation == 1)
{
constexpr ck_tile::index_t act_ = 1;
constexpr ck_tile::index_t go_ = 1;
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
r = fused_moegemm_<t_>(s, a);
}
else if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0 && t.activation == 1)
{
constexpr ck_tile::index_t act_ = 1;
constexpr ck_tile::index_t go_ = 0;
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
r = fused_moegemm_<t_>(s, a);
}
else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 && t.activation == 1)
{
constexpr ck_tile::index_t act_ = 1;
constexpr ck_tile::index_t go_ = 1;
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
r = fused_moegemm_<t_>(s, a);
}
else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1)
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0 && t.activation == 1)
{
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>;
constexpr ck_tile::index_t act_ = 1;
constexpr ck_tile::index_t go_ = 0;
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
r = fused_moegemm_<t_>(s, a);
}
// clang-format on
......
......@@ -21,21 +21,31 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
typename Ts_::BlockTile_1,
typename Ts_::WarpPerBlock_0,
typename Ts_::WarpTile_0>;
using f_problem =
ck_tile::FusedMoeGemmPipelineProblem<typename Ts_::ADataType,
typename Ts_::GDataType,
typename Ts_::DDataType,
typename Ts_::AccDataType,
typename Ts_::ODataType,
typename Ts_::AScaleDataType,
typename Ts_::GScaleDataType,
typename Ts_::DScaleDataType,
typename Ts_::YSmoothScaleDataType,
typename Ts_::TopkWeightDataType,
typename Ts_::IndexDataType,
ck_tile::element_wise::FastGeluAsm, // TODO: hardcoded
f_shape,
f_traits>;
constexpr auto get_activation_ = []() {
if constexpr(Ts_::Activation == 0)
{
return ck_tile::element_wise::FastGeluAsm{};
}
else
return ck_tile::element_wise::Silu{};
};
using f_act_ = ck_tile::remove_cvref_t<decltype(get_activation_())>;
using f_problem = ck_tile::FusedMoeGemmPipelineProblem<typename Ts_::ADataType,
typename Ts_::GDataType,
typename Ts_::DDataType,
typename Ts_::AccDataType,
typename Ts_::ODataType,
typename Ts_::AScaleDataType,
typename Ts_::GScaleDataType,
typename Ts_::DScaleDataType,
typename Ts_::YSmoothScaleDataType,
typename Ts_::TopkWeightDataType,
typename Ts_::IndexDataType,
f_act_, // activation selected from Ts_::Activation
f_shape,
f_traits>;
// using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmEx<f_problem>;
using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmUk<f_problem>;
......
......@@ -15,7 +15,8 @@ template <typename I,
typename KW,
typename BlockTIle_, // seq<b_token, b_interm, b_hidden, b_down>
typename WarpPerBlock_,
typename WarpTile_, // seq<*,*,*>, used to select mfma
typename WarpTile_, // seq<*,*,*>, used to select mfma
ck_tile::index_t Activation_ = 0, // 0: Gelu 1: Silu
ck_tile::index_t GateOnly_ = 0,
ck_tile::index_t FusedQuant_ = 0>
struct fmoe_ // traits, ugly name, only used for internal
......@@ -44,10 +45,11 @@ struct fmoe_ // traits, ugly name, only used for internal
using WarpPerBlock_0 = ck_tile::remove_cvref_t<WarpPerBlock_>;
using WarpTile_0 = ck_tile::remove_cvref_t<WarpTile_>;
using BlockTile_1 = ck_tile::sequence<BT_, BD_, BI_ / (GateOnly_ ? 1 : 2)>;
using BlockTile_1 = ck_tile::sequence<BT_, BD_, BI_>;
using WarpPerBlock_1 = ck_tile::remove_cvref_t<WarpPerBlock_>;
using WarpTile_1 = ck_tile::remove_cvref_t<WarpTile_>;
static constexpr ck_tile::index_t Activation = Activation_; // 0: Gelu 1: Silu
static constexpr ck_tile::index_t GateOnly = GateOnly_;
static constexpr ck_tile::index_t FusedQuant = FusedQuant_;
};
......@@ -8,7 +8,18 @@
// clang-format off
template float fused_moegemm_<
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
template float fused_moegemm_<
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 1, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
template float fused_moegemm_<
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
template float fused_moegemm_<
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 1, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
// clang-format on
......@@ -8,7 +8,19 @@
// clang-format off
template float fused_moegemm_<
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
template float fused_moegemm_<
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 1, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
template float fused_moegemm_<
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
template float fused_moegemm_<
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 1, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
// clang-format on
......@@ -108,12 +108,14 @@ auto create_args(int argc, char* argv[])
.insert(
"gate_only", "1", "w0(gate/up) style, 0:gate+up will double interm size, 1:only gate")
.insert("api", "0", "benchmark api set: 0:fused-moe(moe-gemm+moe-sorting), 1:moe-gemm")
.insert("act", "0", "activation after first gemm. 0:gelu, 1:silu")
.insert("balance",
"0",
"if set to 1, will try balance the expert in topk-ids(convenient for testing)")
.insert("init",
"2",
"init method. 0:random stepped float(fast). 1: random uniform, 2:rand normalized"
"1",
"init method. 0:random stepped float(fast). 1: random uniform[-0.5, 0.5], 2:rand "
"normalized[0, 1]"
"normalized(slow)")
.insert("seed", "11939", "seed used to do random")
.insert("warmup", "5", "cold iter")
......@@ -135,6 +137,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::index_t intermediate_size = arg_parser.get_int("i");
ck_tile::index_t stride = arg_parser.get_int("stride");
ck_tile::index_t block_m = arg_parser.get_int("bm");
ck_tile::index_t activation = arg_parser.get_int("act");
if(stride < 0)
stride = hidden_size;
std::string prec_i = arg_parser.get_str("prec_i");
......@@ -194,11 +197,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
return std::string(", st:") + std::to_string(stride);
}();
std::cout << "[" << api_str << "|" << prec_str << "]"
<< " t:" << tokens << ", e:" << experts << ", k:" << topk << stride_str
<< ", hidden:" << hidden_size << ", interm:" << intermediate_size << ", tp:" << tp
<< ", shrd_interm:" << shared_intermediate_size_0 << "|" << shared_intermediate_size_1
<< ", go:" << gate_only << ", q:" << fused_quant << std::flush;
std::cout
<< "[" << api_str << "|" << prec_str << "]"
<< " t:" << tokens << ", e:" << experts << ", k:" << topk << stride_str
<< ", hidden:" << hidden_size << ", interm:" << intermediate_size << ", tp:" << tp
<< ", act:"
<< activation
// << ", shrd_interm:" << shared_intermediate_size_0 << "|" << shared_intermediate_size_1
<< (gate_only ? ", g1u0" : ", g1u1") << ", q:" << fused_quant << std::flush;
using TypeConfig = FusedMoeGemmTypeConfig<I, W, O, ST, SW, SQ, KW>;
using ADataType = typename TypeConfig::ADataType;
......@@ -370,6 +376,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
prec_sq,
prec_kw,
block_m,
activation,
gate_only,
fused_quant};
......@@ -389,7 +396,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
num_sorted_tiles_buf.GetDeviceBuffer(),
block_m,
hidden_size,
shared_intermediate_size_0,
intermediate_size / tp,
tokens,
experts,
topk,
......@@ -408,6 +415,28 @@ bool run(const ck_tile::ArgParser& arg_parser)
<< cal_tbps(ave_time) << " TB/s" << std::flush;
bool pass = true;
#define CPU_FUSED_MOE(act_type_) \
ck_tile::reference_fused_moe<AccDataType, act_type_>(a_host, \
g_host, \
d_host, \
sa_host, \
sg_host, \
sd_host, \
sy_host, \
o_host, \
sorted_token_ids_host, \
sorted_weight_host, \
sorted_expert_ids_host, \
num_sorted_tiles_host, \
topk_ids_host, \
block_m, \
tokens, \
experts, \
hidden_size, \
intermediate_size / tp, \
topk, \
gate_only)
if(do_validation)
{
ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>(
......@@ -419,28 +448,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
num_sorted_tiles_host.mData[0],
experts,
block_m);
ck_tile::reference_fused_moe<AccDataType, ck_tile::element_wise::Gelu>(
a_host,
g_host,
d_host,
sa_host,
sg_host,
sd_host,
sy_host,
o_host,
sorted_token_ids_host,
sorted_weight_host,
sorted_expert_ids_host,
num_sorted_tiles_host,
topk_ids_host,
block_m,
tokens,
experts,
hidden_size,
shared_intermediate_size_0,
topk,
gate_only);
if(activation == 0)
{
CPU_FUSED_MOE(ck_tile::element_wise::Gelu);
}
else
{
CPU_FUSED_MOE(ck_tile::element_wise::Silu);
}
auto o_dev = o_buf.ToHost<ODataType>();
// o_dev.savetxt("gpu-out.txt", "float");
......@@ -491,6 +506,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
prec_sq,
prec_kw,
block_m,
activation,
gate_only,
fused_quant};
......@@ -507,7 +523,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
sorted_expert_ids_buf.GetDeviceBuffer(),
num_sorted_tiles_buf.GetDeviceBuffer(),
hidden_size,
shared_intermediate_size_0,
intermediate_size / tp,
tokens,
experts,
topk,
......@@ -529,27 +545,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(do_validation)
{
ck_tile::reference_fused_moe<AccDataType, ck_tile::element_wise::Gelu>(
a_host,
g_host,
d_host,
sa_host,
sg_host,
sd_host,
sy_host,
o_host,
sorted_token_ids_host,
sorted_weight_host,
sorted_expert_ids_host,
num_sorted_tiles_host,
topk_ids_host,
block_m,
tokens,
experts,
hidden_size,
shared_intermediate_size_0,
topk,
gate_only);
if(activation == 0)
{
CPU_FUSED_MOE(ck_tile::element_wise::Gelu);
}
else
{
CPU_FUSED_MOE(ck_tile::element_wise::Silu);
}
auto o_dev = o_buf.ToHost<ODataType>();
// o_dev.savetxt("gpu-out.txt", "float");
......
......@@ -73,7 +73,7 @@ void reference_fused_moe(
ck_tile::index_t tokens,
ck_tile::index_t experts,
ck_tile::index_t hidden_size,
ck_tile::index_t intermediate_size, // this size is for gate/up
ck_tile::index_t intermediate_size, // this size is for gate/up/down
ck_tile::index_t topk,
ck_tile::index_t gate_only)
{
......@@ -82,19 +82,8 @@ void reference_fused_moe(
assert(sorted_expert_ids_host.get_num_of_dimension() == 1);
assert(num_sorted_tiles_host.get_element_size() == 1);
ck_tile::index_t num_sorted_tiles = num_sorted_tiles_host.mData[0] / block_m;
ck_tile::index_t intermediate_size_0 = intermediate_size;
ck_tile::index_t intermediate_size_1 = intermediate_size / (gate_only ? 1 : 2);
// TODO: better remove this in the future, or modify the token_id value
auto get_topk_id = [&](ck_tile::index_t token_id_, ck_tile::index_t expert_id_) {
for(ck_tile::index_t i_ = 0; i_ < topk; i_++)
{
if(token_ids_host(token_id_, i_) == expert_id_)
return i_;
}
throw std::runtime_error("not correct token/expert pair\n");
return -1; // TODO: not correct!!
};
ck_tile::index_t intermediate_size_0 = intermediate_size * (gate_only ? 1 : 2);
ck_tile::index_t intermediate_size_1 = intermediate_size;
ck_tile::HostTensor<AccDataType> out_topk_tokens({tokens, topk, hidden_size});
......@@ -105,11 +94,31 @@ void reference_fused_moe(
if(i_tile >= num_sorted_tiles)
return;
ck_tile::index_t i_expert = sorted_expert_ids_host.mData[i_tile];
ck_tile::index_t i_token = sorted_token_ids_host.mData[i_flatten];
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
ck_tile::index_t i_token = sorted_token_ids_host.mData[i_flatten];
ck_tile::index_t i_topk = i_token >> 24;
i_token &= 0xffffff;
if(i_token >= tokens)
return;
(void)token_ids_host;
#else
// TODO: better remove this in the future, or modify the token_id value
auto get_topk_id = [&](ck_tile::index_t token_id_, ck_tile::index_t expert_id_) {
for(ck_tile::index_t i_ = 0; i_ < topk; i_++)
{
if(token_ids_host(token_id_, i_) == expert_id_)
return i_;
}
throw std::runtime_error("not correct token/expert pair\n");
return -1; // TODO: not correct!!
};
ck_tile::index_t i_token = sorted_token_ids_host.mData[i_flatten];
if(i_token >= tokens)
return;
ck_tile::index_t i_topk = get_topk_id(i_token, i_expert); // TODO: ugly
auto weight = sorted_weight_host.mData[i_flatten];
#endif
auto weight = sorted_weight_host.mData[i_flatten];
ck_tile::HostTensor<AccDataType> acc_0({1, intermediate_size_0});
// first gemm
......
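A minimal sketch of the mock token-id encoding implied by the masks above (the layout, topk index in the top 8 bits and token index in the low 24 bits, is inferred from this reference code and from the kernel-side token_id &= 0xffffff further down; the helper names are illustrative):

#include <cstdint>

inline int32_t pack_mock_token_id(int32_t token, int32_t topk) { return (topk << 24) | (token & 0xffffff); }
inline int32_t mock_token_of(int32_t id) { return id & 0xffffff; }
inline int32_t mock_topk_of(int32_t id) { return (id >> 24) & 0xff; }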
......@@ -719,7 +719,82 @@ struct Silu
constexpr T one = type_convert<T>(1);
y = x * (one / (one + ck_tile::exp(-x)));
};
template <>
CK_TILE_HOST_DEVICE void operator()<fp32x2_t>(fp32x2_t& y, const fp32x2_t& x) const
{
constexpr auto one = type_convert<float>(1);
y[0] = x[0] * __builtin_amdgcn_rcpf(one + ck_tile::exp(-x[0]));
y[1] = x[1] * __builtin_amdgcn_rcpf(one + ck_tile::exp(-x[1]));
};
};
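Both overloads compute silu(x) = x * sigmoid(x) = x / (1 + exp(-x)); the fp32x2_t specialization simply replaces the division with the hardware reciprocal __builtin_amdgcn_rcpf, trading a tiny amount of precision for throughput. The disabled SiluAsm variant below is a fully hand-written inline-asm version of the same formula, kept only for reference.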
#if 0
// Silu: the formula is not well suited to inline asm (serial dependency chain)
// we keep the code here on purpose in case people want to try it in the future
struct SiluAsm
{
template <typename T>
CK_TILE_HOST void operator()(T& y, T& x) const
{
static_assert(std::is_same_v<T, float>, "Data type is not supported by this operation!");
constexpr T one = type_convert<T>(1);
y = x * (one / (one + ck_tile::exp(-x)));
};
template <typename T>
CK_TILE_DEVICE void operator()(T& y, T& x) const
{
static_assert(std::is_same_v<T, float>, "Data type is not supported by this operation!");
const uint32_t log2e_neg_ = 0x3fb8aa3b | 0x80000000; // log2e_v<float> * -1;
// NOTE: x/y can't be the same register before the inline asm
// "+v" as y, "v" as x is not enough; x/y may still be put into the same register
T tmp = x;
asm volatile("v_mul_f32 %[v_y], %[s_log2e], %[v_x]\n"
"v_exp_f32 %[v_y], %[v_y]\n"
"s_nop 0 ; hazard for exp\n"
"v_add_f32 %[v_y], %[v_y], 1.0\n"
"v_rcp_f32 %[v_y], %[v_y]\n"
"s_nop 0 ; hazard for rcp\n"
"v_mul_f32 %[v_y], %[v_x], %[v_y]\n"
: [v_y] "+v"(y), [v_x] "+v"(tmp)
: [s_log2e] "s"(log2e_neg_)
:);
};
template <>
CK_TILE_HOST void operator()<fp32x2_t>(fp32x2_t& y, fp32x2_t& x) const
{
constexpr auto one = type_convert<float>(1);
y[0] = x[0] * (one / (one + ck_tile::exp(-x[0])));
y[1] = x[1] * (one / (one + ck_tile::exp(-x[1])));
};
template <>
CK_TILE_DEVICE void operator()<fp32x2_t>(fp32x2_t& y, fp32x2_t& x) const
{
const uint32_t log2e_neg_ = 0x3fb8aa3b | 0x80000000; // log2e_v<float> * -1;
// NOTE: x/y can't be same register before inline asm
// float tmp0 = x[0], tmp1 = x[1];
asm volatile("v_mul_f32 %[v_y0], %[s_log2e], %[v_x0]\n"
"v_mul_f32 %[v_y1], %[s_log2e], %[v_x1]\n"
"v_exp_f32 %[v_y0], %[v_y0]\n"
"v_exp_f32 %[v_y1], %[v_y1]\n"
"v_add_f32 %[v_y0], %[v_y0], 1.0\n"
"v_add_f32 %[v_y1], %[v_y1], 1.0\n"
"v_rcp_f32 %[v_y0], %[v_y0]\n"
"v_rcp_f32 %[v_y1], %[v_y1]\n"
"v_mul_f32 %[v_y0], %[v_x0], %[v_y0]\n"
"v_mul_f32 %[v_y1], %[v_x1], %[v_y1]\n"
: [v_y0] "+v"(y[0]), [v_y1] "+v"(y[1]), [v_x0] "+v"(x[0]), [v_x1] "+v"(x[1])
: [s_log2e] "s"(log2e_neg_)
:);
};
};
#endif
struct TanH
{
......
......@@ -65,7 +65,8 @@ struct FlatmmSn_32x128x512_1x4x1_16x16x32_Base
// in LDS we need store as
// M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16)*Nv(4) + 4)
// y y wave-id lid/16 lid%16 v
return 2 * 2 * 4 * 4 * (16 * 4 + 4) * sizeof(bf16_t);
constexpr index_t nbufs = 2;
return 2 * 2 * 4 * 4 * (16 * 4 + 4) * sizeof(bf16_t) * nbufs;
}
};
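For reference, the factors in the comment work out to 2 * 2 * 4 * 4 * (16 * 4 + 4) = 4352 bf16 elements, i.e. 8704 bytes per buffer, so 17408 bytes (17 KiB) of LDS once the nbufs = 2 double-buffering factor is applied.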
......@@ -173,7 +174,6 @@ struct FlatmmSn_32x128x512_1x4x1_16x16x32_BF16 : public FlatmmSn_32x128x512_1x4x
asm volatile(
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
#include "uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16.inc"
#undef CK_TILE_FLATMM_UK_MFMA
:[smem_]"+r"(smem),
[s_loop_cnt]"+s"(loop_cnt),
[c0]"+v" (v_c0),
......@@ -418,7 +418,6 @@ struct FlatmmSn_32x128x512_1x4x1_16x16x32_FP16 : public FlatmmSn_32x128x512_1x4x
asm volatile(
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_FP16
#include "uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16.inc"
#undef CK_TILE_FLATMM_UK_MFMA
:[smem_]"+r"(smem),
[s_loop_cnt]"+s"(loop_cnt),
[c0]"+v" (v_c0),
......
......@@ -477,7 +477,7 @@ struct FlatmmSn_32x128x512_1x4x1_16x16x32_FP16_itl : public FlatmmSn_32x128x512_
"a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251",
"a252", "a253", "a254", "a255",
"s8", "s9", "s12", "s13", "s14", "s15", "s38", "s39", "s52", "s86",
"s36", "s37","s59","s80",
"s36", "s37", "s56", "s59", "s60", "s80",
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17",
"v50", "v54", "v55",
"v64","v65","v66","v67","v68","v69","v70","v71",
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
// clang-format off
// define the CK_TILE_** macro before include this file to change kernel variation
// we will undef everything defined in this file
#ifndef CK_TILE_FLATMM_UK_MFMA
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
#endif
......@@ -816,3 +823,5 @@
#undef _UK_MFMA_
#undef _UK_PK_CVT_
#undef _UK_ATOMIC_ADD_
#undef CK_TILE_FLATMM_UK_MFMA
// clang-format on
......@@ -111,7 +111,7 @@ struct FusedMoeGemmHostArgs
const void* num_sorted_tiles_ptr; // [1]
index_t hidden_size; // k
index_t intermediate_size; // n / TP, for Gate. if Gate+Up, Down need divide by 2
index_t intermediate_size; // n / TP, for Gate/Up/Down
index_t num_tokens; // input number of tokens for current iteration
index_t num_experts; // number of groups
index_t topk; // need this?
......@@ -178,7 +178,7 @@ struct FusedMoeGemmKernel
return base_str;
}();
return _SS_("fused_moe_") + _SS_(prec_str) + "_" +
return _SS_("fused_moe_") + _SS_(prec_str) + "_" + (IsGateOnly ? "g1u0_":"g1u1_") +
_TS_(S_::Block_M0) + "x" + _TS_(S_::Block_N0) + "x" + _TS_(S_::Block_K0) + "x" + _TS_(S_::Block_N1) + "_" +
_TS_(S_::WarpPerBlock_M0) + "x" + _TS_(S_::WarpPerBlock_N0) + "x" + _TS_(S_::WarpPerBlock_K0) + "_" +
_TS_(S_::Warp_M0) + "x" + _TS_(S_::Warp_N0) + "x" + _TS_(S_::Warp_K0) + "_" + _SS_(Pipeline::name);
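With the bf16 g1u1 instance used elsewhere in this change (S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>), the generated name would look roughly like fused_moe_bf16_g1u1_32x512x128x128_1x4x1_16x16x32_<pipeline-name>; the exact prec_str and the trailing Pipeline::name depend on the instantiation.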
......@@ -204,7 +204,7 @@ struct FusedMoeGemmKernel
const void* num_sorted_tiles_ptr;
index_t hidden_size; // k
index_t intermediate_size; // n / TP, for Gate. if Gate+Up, Down need divide by 2
index_t intermediate_size; // n / TP, for Gate/Up/Down
index_t num_tokens; // input number of tokens for current iteration
index_t num_experts; // number of groups
index_t topk; // need this?
......@@ -239,7 +239,7 @@ struct FusedMoeGemmKernel
{
if constexpr(UseUK)
{
__shared__ CK_TILE_LDS_ADDR ADataType smem[GetSmemSize()];
__shared__ CK_TILE_LDS_ADDR char smem[GetSmemSize()];
IndexDataType num_sorted_tiles = __builtin_amdgcn_readfirstlane(
*reinterpret_cast<const IndexDataType*>(kargs.num_sorted_tiles_ptr));
......@@ -298,6 +298,9 @@ struct FusedMoeGemmKernel
index_t token_id =
reinterpret_cast<const index_t*>(kargs.sorted_token_ids_ptr)[sorted_token_id];
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
token_id &= 0xffffff;
#endif
auto topk_weight = reinterpret_cast<const TopkWeightDataType*>(
kargs.sorted_weight_ptr)[sorted_token_id];
......
......@@ -70,11 +70,16 @@ struct FusedMoeGemmPipeline_FlatmmUk
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
#if 1
constexpr index_t smem_0 = Policy::template GetUK_0<Problem>().GetSmemSize();
constexpr index_t smem_1 = Policy::template GetUK_1<Problem>().GetSmemSize();
constexpr index_t smem_bridge =
BlockShape::Block_M0 * BlockShape::Block_N0 * sizeof(YDataType);
return max(smem_0, max(smem_1, smem_bridge));
return max(smem_0 + smem_1, smem_bridge);
#else
// keep it here purposely in case we have regression
return 65536;
#endif
}
// this is the thread-offset along row/col
......@@ -125,6 +130,9 @@ struct FusedMoeGemmPipeline_FlatmmUk
array<index_t, n_size> row_ids;
static_for<0, n_size, 1>{}([&](auto i) {
row_ids.at(i) = sorted_token_ids_ptr[coords[i]]; // base_coord + i * MLans;
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
row_ids.at(i) &= 0xffffff;
#endif
});
return row_ids;
......@@ -164,9 +172,12 @@ struct FusedMoeGemmPipeline_FlatmmUk
index_t sorted_tile_id,
index_t intermediate_tile_id)
{
constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2;
ck_tile::index_t shared_intermediate_size_0 = kargs.intermediate_size;
ck_tile::index_t shared_intermediate_size_1 = kargs.intermediate_size / hidden_radio_0;
constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2;
ck_tile::index_t shared_intermediate_size_0 =
kargs.intermediate_size * hidden_radio_0; // total gate+up
ck_tile::index_t shared_intermediate_size_1 = kargs.intermediate_size;
// after weight shuffling, gate-only: [nr0, kr0, w0], gate+up: [nr0_gate + nr0_up, kr0, w0]
index_t nr_0 = shared_intermediate_size_0 / BlockShape::Warp_N0; // divide N in W
index_t kr_0 = kargs.hidden_size / BlockShape::Warp_K0; // divide K in W
......@@ -200,29 +211,35 @@ struct FusedMoeGemmPipeline_FlatmmUk
make_wave_buffer_resource(reinterpret_cast<const ADataType*>(kargs.a_ptr),
kargs.num_tokens * kargs.stride_token * sizeof(ADataType));
auto g_win = [&]() {
const GDataType* g_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) +
static_cast<long_index_t>(expert_id) * expert_stride_0 +
interm_idx_nr0 * kr_0 * BlockShape::Block_W0;
auto g_view_ = make_naive_tensor_view<address_space_enum::global>(
g_ptr,
auto make_gu_win = [&](const auto* ptr_) {
auto view_ = make_naive_tensor_view<address_space_enum::global>(
ptr_,
make_tuple(nr_0, kr_0, number<BlockShape::Block_W0>{}),
make_tuple(kr_0 * BlockShape::Block_W0, number<BlockShape::Block_W0>{}, 1),
number<kAlignmentG>{},
number<1>{});
auto g_window_ = make_tile_window_linear_raw(
g_view_,
auto win_ = make_tile_window_linear_raw(
view_,
make_tuple(number<BlockShape::Block_Nr0>{},
number<BlockShape::Block_Kr0>{},
number<BlockShape::Block_W0>{}),
{0, 0, 0},
Policy::template MakeGlobalTileDistribution_G<Problem>(),
sequence<0, 1, 1>{});
return g_window_;
}();
return win_;
};
const GDataType* gu_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) +
static_cast<long_index_t>(expert_id) * expert_stride_0 +
interm_idx_nr0 * kr_0 * BlockShape::Block_W0;
auto g_win = make_gu_win(gu_ptr);
// Note: gu swizzled, [nr_u+nr_g, kr, w], hence base offset to up is just interm*hidden
auto u_win = make_gu_win(gu_ptr + kargs.intermediate_size * kargs.hidden_size);
auto g_res = g_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
auto u_res = u_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
auto g_coords = generate_tuple([&](auto i) { return g_win.cached_coords_[i].get_offset(); },
number<decltype(g_win)::NumAccess_NonLinear>{});
......@@ -309,28 +326,73 @@ struct FusedMoeGemmPipeline_FlatmmUk
auto w_scale = GetWeightScale(
row_coords_o, reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr));
auto uk_0 = Policy::template GetUK_0<Problem>();
auto acc_0 = uk_0(a_res,
a_coords,
g_res,
g_coords,
smem,
kargs.hidden_size,
BlockShape::Block_K0, // tile offset for B matrix each unroll
BlockShape::Block_Kr0 *
BlockShape::Block_W0); // tile offset for B matrix each unroll
sweep_tile(
acc_0,
[&](auto idx0, auto idx1) {
fp32x2_t v_{acc_0(idx0), acc_0(idx1)};
typename Problem::GateActivation{}(v_, v_);
acc_0(idx0) = v_.x;
acc_0(idx1) = v_.y;
},
sequence<1, 2>{});
auto y_pre = cast_tile<YDataType>(acc_0);
auto uk_0 = Policy::template GetUK_0<Problem>();
auto y_pre = [&]() {
if constexpr(IsGateOnly)
{
auto acc_0 = uk_0(a_res,
a_coords,
g_res,
g_coords,
smem,
kargs.hidden_size,
BlockShape::Block_K0, // tile offset for B matrix each unroll
BlockShape::Block_Kr0 *
BlockShape::Block_W0); // tile offset for B matrix each unroll
sweep_tile(
acc_0,
[&](auto idx0, auto idx1) {
fp32x2_t v_{acc_0(idx0), acc_0(idx1)};
typename Problem::GateActivation{}(v_, v_);
acc_0(idx0) = v_.x;
acc_0(idx1) = v_.y;
},
sequence<1, 2>{});
return cast_tile<YDataType>(acc_0);
}
else
{
uint32x8_t gu_res;
gu_res[0] = g_res[0];
gu_res[1] = g_res[1];
gu_res[2] = g_res[2];
gu_res[3] = g_res[3];
gu_res[4] = u_res[0];
gu_res[5] = u_res[1];
gu_res[6] = u_res[2];
gu_res[7] = u_res[3];
auto acc_0 = uk_0(a_res,
a_coords,
gu_res,
g_coords,
smem,
kargs.hidden_size,
BlockShape::Block_K0, // tile offset for B matrix each unroll
BlockShape::Block_Kr0 * BlockShape::Block_W0,
bool_constant<true>{}); // tile offset for B matrix each unroll
sweep_tile(
acc_0.at(number<0>{}),
[&](auto idx0, auto idx1) {
fp32x2_t v_{acc_0.at(number<0>{})(idx0), acc_0.at(number<0>{})(idx1)};
typename Problem::GateActivation{}(v_, v_);
acc_0.at(number<0>{})(idx0) = v_.x;
acc_0.at(number<0>{})(idx1) = v_.y;
},
sequence<1, 2>{});
auto reduced_acc_0 =
tile_elementwise_in([&](const auto& a_, const auto& b_) { return a_ * b_; },
acc_0.at(number<0>{}),
acc_0.at(number<1>{}));
return cast_tile<YDataType>(reduced_acc_0);
}
}();
block_sync_lds();
......
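In formula terms, the g1u0 branch above computes y = Act(x * W_gate), while the g1u1 branch computes y = Act(x * W_gate) ⊙ (x * W_up): uk_0 returns the gate and up partial results as a pair, the activation is applied to the gate half, and the two halves are multiplied elementwise before being cast to YDataType ahead of the down GEMM.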