Unverified commit a11cf2c6 authored by arai713, committed by GitHub

Merge branch 'develop' into codegen_hiprtc

parents a72e9efa 64d5c4d6
@@ -41,6 +41,7 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf
         t.prec_sq,
         t.prec_kw,
         t.block_m,
+        t.activation,
         t.gate_only,
         t.fused_quant};
     auto a1 = fused_moegemm_args{
......
@@ -17,15 +17,67 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
     // clang-format off
     float r = -1;
     if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
-       t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1)
+       t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 && t.activation == 0)
     {
-        using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>;
+        constexpr ck_tile::index_t act_ = 0;
+        constexpr ck_tile::index_t go_  = 1;
+        using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
         r = fused_moegemm_<t_>(s, a);
     }
+    else if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
+            t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0 && t.activation == 0)
+    {
+        constexpr ck_tile::index_t act_ = 0;
+        constexpr ck_tile::index_t go_  = 0;
+        using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
+        r = fused_moegemm_<t_>(s, a);
+    }
+    else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
+            t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 && t.activation == 0)
+    {
+        constexpr ck_tile::index_t act_ = 0;
+        constexpr ck_tile::index_t go_  = 1;
+        using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
+        r = fused_moegemm_<t_>(s, a);
+    }
+    else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
+            t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0 && t.activation == 0)
+    {
+        constexpr ck_tile::index_t act_ = 0;
+        constexpr ck_tile::index_t go_  = 0;
+        using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
+        r = fused_moegemm_<t_>(s, a);
+    }
+    else if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
+            t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 && t.activation == 1)
+    {
+        constexpr ck_tile::index_t act_ = 1;
+        constexpr ck_tile::index_t go_  = 1;
+        using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
+        r = fused_moegemm_<t_>(s, a);
+    }
+    else if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
+            t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0 && t.activation == 1)
+    {
+        constexpr ck_tile::index_t act_ = 1;
+        constexpr ck_tile::index_t go_  = 0;
+        using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
+        r = fused_moegemm_<t_>(s, a);
+    }
+    else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
+            t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 && t.activation == 1)
+    {
+        constexpr ck_tile::index_t act_ = 1;
+        constexpr ck_tile::index_t go_  = 1;
+        using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
+        r = fused_moegemm_<t_>(s, a);
+    }
     else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
-            t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1)
+            t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0 && t.activation == 1)
     {
-        using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>;
+        constexpr ck_tile::index_t act_ = 1;
+        constexpr ck_tile::index_t go_  = 0;
+        using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
         r = fused_moegemm_<t_>(s, a);
     }
     // clang-format on
......
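For context, a minimal caller sketch of the dispatch above. The trait field names are taken from this diff; the args object, its population, and the call shape are assumptions from the surrounding example harness, not code from the commit:

    fused_moegemm_traits t{};
    t.prec_i  = "bf16"; t.prec_w  = "bf16"; t.prec_o  = "bf16";
    t.prec_st = "fp32"; t.prec_sw = "fp32"; t.prec_sq = "fp32"; t.prec_kw = "fp32";
    t.block_m    = 32;
    t.activation = 1; // 0 selects the FastGeluAsm branches, 1 the Silu branches
    t.gate_only  = 0; // together with activation, picks one of the eight kernels above
    fused_moegemm_args a{}; // assumed populated elsewhere (pointers, sizes, topk, ...)
    float elapsed = fused_moegemm(t, a, ck_tile::stream_config{});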
@@ -21,21 +21,31 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
                                            typename Ts_::BlockTile_1,
                                            typename Ts_::WarpPerBlock_0,
                                            typename Ts_::WarpTile_0>;
-    using f_problem =
-        ck_tile::FusedMoeGemmPipelineProblem<typename Ts_::ADataType,
-                                             typename Ts_::GDataType,
-                                             typename Ts_::DDataType,
-                                             typename Ts_::AccDataType,
-                                             typename Ts_::ODataType,
-                                             typename Ts_::AScaleDataType,
-                                             typename Ts_::GScaleDataType,
-                                             typename Ts_::DScaleDataType,
-                                             typename Ts_::YSmoothScaleDataType,
-                                             typename Ts_::TopkWeightDataType,
-                                             typename Ts_::IndexDataType,
-                                             ck_tile::element_wise::FastGeluAsm, // TODO: hardcoded
-                                             f_shape,
-                                             f_traits>;
+    constexpr auto get_activation_ = []() {
+        if constexpr(Ts_::Activation == 0)
+        {
+            return ck_tile::element_wise::FastGeluAsm{};
+        }
+        else
+            return ck_tile::element_wise::Silu{};
+    };
+    using f_act_ = ck_tile::remove_cvref_t<decltype(get_activation_())>;
+
+    using f_problem = ck_tile::FusedMoeGemmPipelineProblem<typename Ts_::ADataType,
+                                                           typename Ts_::GDataType,
+                                                           typename Ts_::DDataType,
+                                                           typename Ts_::AccDataType,
+                                                           typename Ts_::ODataType,
+                                                           typename Ts_::AScaleDataType,
+                                                           typename Ts_::GScaleDataType,
+                                                           typename Ts_::DScaleDataType,
+                                                           typename Ts_::YSmoothScaleDataType,
+                                                           typename Ts_::TopkWeightDataType,
+                                                           typename Ts_::IndexDataType,
+                                                           f_act_, // TODO: hardcoded
+                                                           f_shape,
+                                                           f_traits>;
     // using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmEx<f_problem>;
     using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmUk<f_problem>;
......
@@ -15,7 +15,8 @@ template <typename I,
           typename KW,
           typename BlockTIle_, // seq<b_token, b_interm, b_hidden, b_down>
           typename WarpPerBlock_,
           typename WarpTile_, // seq<*,*,*>, used to select mfma
+          ck_tile::index_t Activation_ = 0, // 0: Gelu  1: Silu
           ck_tile::index_t GateOnly_   = 0,
           ck_tile::index_t FusedQuant_ = 0>
 struct fmoe_ // traits, ugly name, only used for internal
@@ -44,10 +45,11 @@ struct fmoe_ // traits, ugly name, only used for internal
     using WarpPerBlock_0 = ck_tile::remove_cvref_t<WarpPerBlock_>;
     using WarpTile_0     = ck_tile::remove_cvref_t<WarpTile_>;
-    using BlockTile_1    = ck_tile::sequence<BT_, BD_, BI_ / (GateOnly_ ? 1 : 2)>;
+    using BlockTile_1    = ck_tile::sequence<BT_, BD_, BI_>;
     using WarpPerBlock_1 = ck_tile::remove_cvref_t<WarpPerBlock_>;
     using WarpTile_1     = ck_tile::remove_cvref_t<WarpTile_>;
+    static constexpr ck_tile::index_t Activation = Activation_; // 0: Gelu  1: Silu
     static constexpr ck_tile::index_t GateOnly   = GateOnly_;
     static constexpr ck_tile::index_t FusedQuant = FusedQuant_;
 };
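An equivalent way to read the new integer trait: it is only ever mapped to a functor type at compile time. A hedged restatement of the get_activation_ lambda earlier in this diff (this formulation needs <type_traits> and is not code from the commit):

    template <typename Ts_>
    using activation_t = std::conditional_t<Ts_::Activation == 0,
                                            ck_tile::element_wise::FastGeluAsm, // 0: Gelu
                                            ck_tile::element_wise::Silu>;       // 1: Silu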
@@ -8,7 +8,18 @@
 // clang-format off
 template float fused_moegemm_<
-    fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>
+    fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0, 0>
 >(const ck_tile::stream_config& s, fused_moegemm_args a);
+template float fused_moegemm_<
+    fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 1, 0>
+>(const ck_tile::stream_config& s, fused_moegemm_args a);
+template float fused_moegemm_<
+    fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0, 0>
+>(const ck_tile::stream_config& s, fused_moegemm_args a);
+template float fused_moegemm_<
+    fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 1, 0>
+>(const ck_tile::stream_config& s, fused_moegemm_args a);
 // clang-format on
@@ -8,7 +8,19 @@
 // clang-format off
 template float fused_moegemm_<
-    fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>
+    fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0, 0>
+>(const ck_tile::stream_config& s, fused_moegemm_args a);
+template float fused_moegemm_<
+    fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 1, 0>
+>(const ck_tile::stream_config& s, fused_moegemm_args a);
+template float fused_moegemm_<
+    fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0, 0>
+>(const ck_tile::stream_config& s, fused_moegemm_args a);
+template float fused_moegemm_<
+    fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 1, 0>
 >(const ck_tile::stream_config& s, fused_moegemm_args a);
 // clang-format on
@@ -108,12 +108,14 @@ auto create_args(int argc, char* argv[])
         .insert(
             "gate_only", "1", "w0(gate/up) style, 0:gate+up will double interm size, 1:only gate")
         .insert("api", "0", "benchmark api set: 0:fused-moe(moe-gemm+moe-sorting), 1:moe-gemm")
+        .insert("act", "0", "activation after first gemm. 0:gelu, 1:silu")
         .insert("balance",
                 "0",
                 "if set to 1, will try balance the expert in topk-ids(convenient for testing)")
         .insert("init",
-                "2",
-                "init method. 0:random stepped float(fast). 1: random uniform, 2:rand "
+                "1",
+                "init method. 0:random stepped float(fast). 1: random uniform[-0.5, 0.5], 2:rand "
+                "normalized[0, 1]"
                 "normalized(slow)")
         .insert("seed", "11939", "seed used to do random")
         .insert("warmup", "5", "cold iter")
@@ -135,6 +137,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     ck_tile::index_t intermediate_size = arg_parser.get_int("i");
     ck_tile::index_t stride            = arg_parser.get_int("stride");
     ck_tile::index_t block_m           = arg_parser.get_int("bm");
+    ck_tile::index_t activation        = arg_parser.get_int("act");
     if(stride < 0)
         stride = hidden_size;
     std::string prec_i = arg_parser.get_str("prec_i");
@@ -194,11 +197,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
         return std::string(", st:") + std::to_string(stride);
     }();
-    std::cout << "[" << api_str << "|" << prec_str << "]"
-              << " t:" << tokens << ", e:" << experts << ", k:" << topk << stride_str
-              << ", hidden:" << hidden_size << ", interm:" << intermediate_size << ", tp:" << tp
-              << ", shrd_interm:" << shared_intermediate_size_0 << "|" << shared_intermediate_size_1
-              << ", go:" << gate_only << ", q:" << fused_quant << std::flush;
+    std::cout
+        << "[" << api_str << "|" << prec_str << "]"
+        << " t:" << tokens << ", e:" << experts << ", k:" << topk << stride_str
+        << ", hidden:" << hidden_size << ", interm:" << intermediate_size << ", tp:" << tp
+        << ", act:" << activation
+        // << ", shrd_interm:" << shared_intermediate_size_0 << "|" << shared_intermediate_size_1
+        << (gate_only ? ", g1u0" : ", g1u1") << ", q:" << fused_quant << std::flush;

     using TypeConfig = FusedMoeGemmTypeConfig<I, W, O, ST, SW, SQ, KW>;
     using ADataType  = typename TypeConfig::ADataType;
@@ -370,6 +376,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
                            prec_sq,
                            prec_kw,
                            block_m,
+                           activation,
                            gate_only,
                            fused_quant};
@@ -389,7 +396,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
             num_sorted_tiles_buf.GetDeviceBuffer(),
             block_m,
             hidden_size,
-            shared_intermediate_size_0,
+            intermediate_size / tp,
             tokens,
             experts,
             topk,
@@ -408,6 +415,28 @@ bool run(const ck_tile::ArgParser& arg_parser)
               << cal_tbps(ave_time) << " TB/s" << std::flush;

     bool pass = true;

+#define CPU_FUSED_MOE(act_type_)                                                \
+    ck_tile::reference_fused_moe<AccDataType, act_type_>(a_host,                \
+                                                         g_host,                \
+                                                         d_host,                \
+                                                         sa_host,               \
+                                                         sg_host,               \
+                                                         sd_host,               \
+                                                         sy_host,               \
+                                                         o_host,                \
+                                                         sorted_token_ids_host, \
+                                                         sorted_weight_host,    \
+                                                         sorted_expert_ids_host,\
+                                                         num_sorted_tiles_host, \
+                                                         topk_ids_host,         \
+                                                         block_m,               \
+                                                         tokens,                \
+                                                         experts,               \
+                                                         hidden_size,           \
+                                                         intermediate_size / tp,\
+                                                         topk,                  \
+                                                         gate_only)
+
     if(do_validation)
     {
         ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>(
@@ -419,28 +448,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
             num_sorted_tiles_host.mData[0],
             experts,
             block_m);
-
-        ck_tile::reference_fused_moe<AccDataType, ck_tile::element_wise::Gelu>(
-            a_host,
-            g_host,
-            d_host,
-            sa_host,
-            sg_host,
-            sd_host,
-            sy_host,
-            o_host,
-            sorted_token_ids_host,
-            sorted_weight_host,
-            sorted_expert_ids_host,
-            num_sorted_tiles_host,
-            topk_ids_host,
-            block_m,
-            tokens,
-            experts,
-            hidden_size,
-            shared_intermediate_size_0,
-            topk,
-            gate_only);
+        if(activation == 0)
+        {
+            CPU_FUSED_MOE(ck_tile::element_wise::Gelu);
+        }
+        else
+        {
+            CPU_FUSED_MOE(ck_tile::element_wise::Silu);
+        }

         auto o_dev = o_buf.ToHost<ODataType>();
         // o_dev.savetxt("gpu-out.txt", "float");
@@ -491,6 +506,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
                            prec_sq,
                            prec_kw,
                            block_m,
+                           activation,
                            gate_only,
                            fused_quant};
...@@ -507,7 +523,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -507,7 +523,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
sorted_expert_ids_buf.GetDeviceBuffer(), sorted_expert_ids_buf.GetDeviceBuffer(),
num_sorted_tiles_buf.GetDeviceBuffer(), num_sorted_tiles_buf.GetDeviceBuffer(),
hidden_size, hidden_size,
shared_intermediate_size_0, intermediate_size / tp,
tokens, tokens,
experts, experts,
topk, topk,
@@ -529,27 +545,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
     if(do_validation)
     {
-        ck_tile::reference_fused_moe<AccDataType, ck_tile::element_wise::Gelu>(
-            a_host,
-            g_host,
-            d_host,
-            sa_host,
-            sg_host,
-            sd_host,
-            sy_host,
-            o_host,
-            sorted_token_ids_host,
-            sorted_weight_host,
-            sorted_expert_ids_host,
-            num_sorted_tiles_host,
-            topk_ids_host,
-            block_m,
-            tokens,
-            experts,
-            hidden_size,
-            shared_intermediate_size_0,
-            topk,
-            gate_only);
+        if(activation == 0)
+        {
+            CPU_FUSED_MOE(ck_tile::element_wise::Gelu);
+        }
+        else
+        {
+            CPU_FUSED_MOE(ck_tile::element_wise::Silu);
+        }

         auto o_dev = o_buf.ToHost<ODataType>();
         // o_dev.savetxt("gpu-out.txt", "float");
......
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

 #include <hip/hip_runtime.h>
@@ -51,7 +51,7 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
                                ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
                                ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;

-    using TilePartitioner = ck_tile::GemmTilePartitioner<CodegenGemmShape>;
+    using TilePartitioner = ck_tile::GemmTile2DPartitioner<CodegenGemmShape>;

     using GemmEpilogue = std::conditional_t<
         CShuffleEpilogue,
@@ -63,8 +63,8 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
                                                       kOutputRank,
                                                       1,
                                                       0,
-                                                      TilePartitioner::kM,
-                                                      TilePartitioner::kN>>,
+                                                      TilePartitioner::MPerBlock,
+                                                      TilePartitioner::NPerBlock>>,
         ck_tile::Default2DEpilogue<
             ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>>;
......
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

+template <typename Layout>
+static constexpr inline auto is_row_major(Layout layout_)
+{
+    return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
+                                                 ck_tile::tensor_layout::gemm::RowMajor>>{};
+}
+
+auto calculate_rtol_atol(const ck_tile::index_t K,
+                         const ck_tile::index_t kbatch,
+                         const float max_accumulated_value)
+{
+    using ComputeType =
+        std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
+    // Calculate thresholds
+    const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
+        ck_tile::integer_divide_ceil(K, kbatch));
+    const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
+        max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
+    // Calculate error due to split_k accumulation
+    const auto rtol_split_k =
+        ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
+    const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
+        max_accumulated_value, kbatch);
+    // Use higher threshold
+    return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
+}
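The intent, restated: with split-K each partial GEMM accumulates roughly K/kbatch products in AccDataType, and the kbatch partial results are then combined at CDataType precision, so the check keeps the larger of the two error bounds. A hedged usage sketch mirroring the call sites later in this diff:

    const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_accumulated_value);
    pass = ck_tile::check_err(c_m_n_dev_result,
                              c_m_n_host_ref,
                              "Error: Incorrect results!",
                              rtol_atol.at(ck_tile::number<0>{}),  // relative bound
                              rtol_atol.at(ck_tile::number<1>{})); // absolute bound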
 template <typename ALayout, typename BLayout, typename CLayout>
 float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
                           ck_tile::DeviceMem& b_k_n_dev_buf,
@@ -86,56 +113,16 @@ int run_batched_gemm_example_with_layouts(int argc,
     int n_warmup = arg_parser.get_int("warmup");
     int n_repeat = arg_parser.get_int("repeat");

-    using namespace ck_tile::literals;
-
-    auto f_host_tensor_descriptor = [](std::size_t batch_count_,
-                                       std::size_t row,
-                                       std::size_t col,
-                                       std::size_t stride,
-                                       std::size_t batch_stride,
-                                       auto layout) {
-        if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
-        {
-            return ck_tile::HostTensorDescriptor({batch_count_, row, col},
-                                                 {batch_stride, stride, 1_uz});
-        }
-        else
-        {
-            return ck_tile::HostTensorDescriptor({batch_count_, row, col},
-                                                 {batch_stride, 1_uz, stride});
-        }
-    };
-
-    auto f_get_default_stride = [](std::size_t row,
-                                   std::size_t col,
-                                   std::size_t stride,
-                                   auto layout) {
-        if(stride == 0)
-        {
-            // give a chance if stride is zero, return a default packed stride
-            if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
-            {
-                return col;
-            }
-            else
-            {
-                return row;
-            }
-        }
-        else
-            return stride;
-    };
-
-    stride_A = f_get_default_stride(M, K, stride_A, a_layout);
-    stride_B = f_get_default_stride(K, N, stride_B, b_layout);
-    stride_C = f_get_default_stride(M, N, stride_C, c_layout);
-
-    ck_tile::HostTensor<ADataType> a_m_k(
-        f_host_tensor_descriptor(batch_count, M, K, stride_A, batch_stride_A, a_layout));
-    ck_tile::HostTensor<BDataType> b_k_n(
-        f_host_tensor_descriptor(batch_count, K, N, stride_B, batch_stride_B, b_layout));
-    ck_tile::HostTensor<CDataType> c_m_n_dev_result(
-        f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, c_layout));
+    stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
+    stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
+    stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(c_layout));
+
+    ck_tile::HostTensor<ADataType> a_m_k(ck_tile::host_tensor_descriptor(
+        batch_count, M, K, stride_A, batch_stride_A, is_row_major(a_layout)));
+    ck_tile::HostTensor<BDataType> b_k_n(ck_tile::host_tensor_descriptor(
+        batch_count, K, N, stride_B, batch_stride_B, is_row_major(b_layout)));
+    ck_tile::HostTensor<CDataType> c_m_n_dev_result(ck_tile::host_tensor_descriptor(
+        batch_count, M, N, stride_C, batch_stride_C, is_row_major(c_layout)));

     ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
     ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
@@ -171,23 +158,33 @@ int run_batched_gemm_example_with_layouts(int argc,
     if(arg_parser.get_int("v") == 1)
     {
-        ck_tile::HostTensor<CDataType> c_m_n_host_ref(
-            f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, CLayout{}));
+        ck_tile::HostTensor<CDataType> c_m_n_host_ref(ck_tile::host_tensor_descriptor(
+            batch_count, M, N, stride_C, batch_stride_C, is_row_major(CLayout{})));
         c_m_n_host_ref.SetZero();

         const auto b_n_k = b_k_n.transpose({0, 2, 1});
         ck_tile::reference_batched_gemm<ADataType, BDataType, AccDataType, CDataType>(
             a_m_k, b_n_k, c_m_n_host_ref);

-        pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_host_ref);
+        const float max_accumulated_value =
+            *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
+        const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_accumulated_value);
+        pass = ck_tile::check_err(c_m_n_dev_result,
+                                  c_m_n_host_ref,
+                                  "Error: Incorrect results!",
+                                  rtol_atol.at(ck_tile::number<0>{}),
+                                  rtol_atol.at(ck_tile::number<1>{}));
+        std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
+                  << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
+                  << std::endl;

         std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
     }
     else if(arg_parser.get_int("v") == 2)
     {
-        ck_tile::HostTensor<CDataType> c_m_n_gpu_ref(
-            f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, CLayout{}));
+        ck_tile::HostTensor<CDataType> c_m_n_gpu_ref(ck_tile::host_tensor_descriptor(
+            batch_count, M, N, stride_C, batch_stride_C, is_row_major(CLayout{})));
         ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes());
         c_m_n_gpu_ref.SetZero();
         c_m_n_gpu_buf_ref.SetZero();
@@ -240,7 +237,18 @@ int run_batched_gemm_example_with_layouts(int argc,
         ck_tile::hip_check_error(hipFree(d_C));
         c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());

-        pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_gpu_ref);
+        const float max_accumulated_value =
+            *std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end());
+        const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_accumulated_value);
+        pass = ck_tile::check_err(c_m_n_dev_result,
+                                  c_m_n_gpu_ref,
+                                  "Error: Incorrect results!",
+                                  rtol_atol.at(ck_tile::number<0>{}),
+                                  rtol_atol.at(ck_tile::number<1>{}));
+        std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
+                  << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
+                  << std::endl;

         std::cout << "The GPU verification result is: " << (pass ? "correct" : "fail") << std::endl;
     }
......
@@ -15,7 +15,6 @@
 #include "ck_tile/ops/gemm.hpp"
 #include "ck_tile/host.hpp"
 #include "grouped_gemm.hpp"
-#include "utils.hpp"

 namespace {
@@ -102,7 +101,7 @@ using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner,
                                           GemmEpilogue<CLayout>>;
 }; // namespace

-std::size_t GetWorkspaceSize(const std::vector<grouped_gemm_kargs>& gemm_descs)
+std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
 {
     return ::Kernel<std::nullptr_t, std::nullptr_t, std::nullptr_t>::GetWorkSpaceSize(gemm_descs);
 }
......
@@ -52,8 +52,8 @@ auto create_args(int argc, char* argv[])
     return std::make_tuple(result, arg_parser);
 }

-std::size_t GetWorkspaceSize(const std::vector<grouped_gemm_kargs>& gemm_descs);
-float grouped_gemm_calc(const std::vector<grouped_gemm_kargs>& gemm_descs,
-                        const ck_tile::stream_config& s,
-                        void* p_workspace_);
+std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs);
+float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
+                   const ck_tile::stream_config& s,
+                   void* p_workspace_);
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

+template <typename Layout>
+static constexpr inline auto is_row_major(Layout layout_)
+{
+    return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
+                                                 ck_tile::tensor_layout::gemm::RowMajor>>{};
+}
+
+auto calculate_rtol_atol(const ck_tile::index_t K,
+                         const ck_tile::index_t kbatch,
+                         const float max_accumulated_value)
+{
+    using ComputeType =
+        std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
+    // Calculate thresholds
+    const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
+        ck_tile::integer_divide_ceil(K, kbatch));
+    const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
+        max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
+    // Calculate error due to split_k accumulation
+    const auto rtol_split_k =
+        ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
+    const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
+        max_accumulated_value, kbatch);
+    // Use higher threshold
+    return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
+}
 template <typename ALayout, typename BLayout, typename CLayout>
 float invoke_gemm(int n_warmup,
                   int n_repeat,
@@ -11,7 +38,7 @@ float invoke_gemm(int n_warmup,
 {
     ck_tile::DeviceMem gemm_workspace;
-    gemm_workspace.Realloc(GetWorkspaceSize(args));
+    gemm_workspace.Realloc(get_workspace_size(args));

     float ave_time = grouped_gemm<ALayout, BLayout, CLayout>(
         args,
@@ -108,16 +135,19 @@ int run_grouped_gemm_example_with_layouts(int argc,
         const ck_tile::index_t N = Ns[i];
         const ck_tile::index_t K = Ks[i];

-        stride_As[i] = f_get_default_stride(M, N, stride_As[i], a_layout);
-        stride_Bs[i] = f_get_default_stride(K, N, stride_Bs[i], b_layout);
-        stride_Cs[i] = f_get_default_stride(M, N, stride_Cs[i], CLayout{});
+        stride_As[i] =
+            ck_tile::get_default_stride(M, N, stride_As[i], is_row_major(a_layout));
+        stride_Bs[i] =
+            ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(b_layout));
+        stride_Cs[i] =
+            ck_tile::get_default_stride(M, N, stride_Cs[i], is_row_major(CLayout{}));

-        a_m_k_tensors.push_back(
-            ck_tile::HostTensor<ADataType>(f_host_tensor_descriptor(M, K, stride_As[i], a_layout)));
-        b_k_n_tensors.push_back(
-            ck_tile::HostTensor<BDataType>(f_host_tensor_descriptor(K, N, stride_Bs[i], b_layout)));
-        c_m_n_tensors.push_back(ck_tile::HostTensor<CDataType>(
-            f_host_tensor_descriptor(M, N, stride_Cs[i], CLayout{})));
+        a_m_k_tensors.push_back(ck_tile::HostTensor<ADataType>(
+            ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(a_layout))));
+        b_k_n_tensors.push_back(ck_tile::HostTensor<BDataType>(
+            ck_tile::host_tensor_descriptor(K, N, stride_Bs[i], is_row_major(b_layout))));
+        c_m_n_tensors.push_back(ck_tile::HostTensor<CDataType>(
+            ck_tile::host_tensor_descriptor(M, N, stride_Cs[i], is_row_major(CLayout{}))));

         std::cout << "gemm[" << i << "]"
                   << " a_m_k: " << a_m_k_tensors[i].mDesc << " b_k_n: " << b_k_n_tensors[i].mDesc
@@ -157,12 +187,23 @@ int run_grouped_gemm_example_with_layouts(int argc,
     {
         for(int i = 0; i < group_count; ++i)
         {
-            ck_tile::HostTensor<CDataType> c_m_n_host_ref(
-                f_host_tensor_descriptor(Ms[i], Ns[i], stride_Cs[i], CLayout{}));
+            ck_tile::HostTensor<CDataType> c_m_n_host_ref(ck_tile::host_tensor_descriptor(
+                Ms[i], Ns[i], stride_Cs[i], is_row_major(CLayout{})));
             c_m_n_host_ref.SetZero();

             ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
                 a_m_k_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);

-            pass &= ck_tile::check_err(c_m_n_tensors[i], c_m_n_host_ref);
+            const float max_accumulated_value =
+                *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
+            const auto rtol_atol = calculate_rtol_atol(Ks[i], 1 /*kbatch*/, max_accumulated_value);
+            pass &= ck_tile::check_err(c_m_n_tensors[i],
+                                       c_m_n_host_ref,
+                                       "Error: Incorrect results!",
+                                       rtol_atol.at(ck_tile::number<0>{}),
+                                       rtol_atol.at(ck_tile::number<1>{}));
+            std::cout << "gemm[" << i
+                      << "] Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
+                      << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
+                      << std::endl;
         }
         std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
     }
......
(file removed; its helpers are superseded by the ck_tile versions used above)
-// SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
-
-#pragma once
-
-template <typename TLayout>
-constexpr auto
-f_host_tensor_descriptor(std::size_t row, std::size_t col, std::size_t stride, TLayout layout)
-{
-    using namespace ck_tile::literals;
-
-    if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
-    {
-        return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz});
-    }
-    else
-    {
-        return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride});
-    }
-}
-
-template <typename TLayout>
-constexpr auto
-f_get_default_stride(std::size_t row, std::size_t col, std::size_t stride, TLayout layout)
-{
-    if(stride == 0)
-    {
-        if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
-        {
-            return col;
-        }
-        else
-        {
-            return row;
-        }
-    }
-    else
-        return stride;
-}
@@ -17,7 +17,9 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
 #endif

 // to do: add various levels of logging with CK_LOG_LEVEL

+#ifndef CK_TIME_KERNEL
 #define CK_TIME_KERNEL 1
+#endif

 // constant address space for kernel parameter
 // https://llvm.org/docs/AMDGPUUsage.html#address-spaces
@@ -155,6 +157,9 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
 // LDS direct loads using inline assembly
 #define CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM 0

+// set rounding to nearest even as default for bf16 conversions
+#define CK_USE_RNE_BF16_CONVERSION 1
+
 // set rounding to nearest even as default for f8 conversions
 #define CK_USE_SR_F8_CONVERSION 0
......
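The new #ifndef guard means the default can now be overridden from the build line instead of by editing the header; a hypothetical compile invocation:

    hipcc -DCK_TIME_KERNEL=0 ...   # disables kernel timing without touching the config header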
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
@@ -122,19 +122,6 @@ __global__ void
         static_for<0, NumDTensor, 1>{}(
             [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_group_offset[i]; });

-        if constexpr(is_same_v<AElementwiseOperation, element_wise::DynamicUnaryOp>)
-        {
-            a_element_op.InitUnaryOpPtrOnDevice();
-        }
-
-        if constexpr(is_same_v<BElementwiseOperation, element_wise::DynamicUnaryOp>)
-        {
-            b_element_op.InitUnaryOpPtrOnDevice();
-        }
-
-        if constexpr(is_same_v<CDEElementwiseOperation, element_wise::DynamicUnaryOp>)
-        {
-            cde_element_op.InitUnaryOpPtrOnDevice();
-        }
-
         if constexpr(isMultiA || isMultiB)
         {
             AsPointer p_as_grid_grp;
......
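The hunks below replace the virtual UnaryOpBase hierarchy with plain template functors, so the element-wise call is resolved at compile time again (no vtable or device-side pointer setup, which also explains the InitUnaryOpPtrOnDevice removal above). For context, a hedged usage sketch (the namespace is the usual CK location for these operators, assumed here):

    ck::tensor_operation::element_wise::Relu relu{};
    float y = 0.f;
    relu(y, -2.f); // y == 0.f; supported types are enforced by static_assert, not a vtable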
@@ -247,32 +247,6 @@ struct DequantPack8
     constexpr const static bool is_pack8_invocable = true;
 };

-#pragma clang diagnostic push
-#pragma clang diagnostic ignored "-Wnon-virtual-dtor"
-struct UnaryOpBase
-{
-    public:
-    __host__ __device__ ~UnaryOpBase() = default;
-
-    __host__ __device__ constexpr UnaryOpBase() = default;
-    __host__ __device__ constexpr UnaryOpBase(const UnaryOpBase&) = default;
-    __host__ __device__ constexpr UnaryOpBase(UnaryOpBase&&) = default;
-    __host__ __device__ UnaryOpBase& operator=(const UnaryOpBase&) = default;
-    __host__ __device__ UnaryOpBase& operator=(UnaryOpBase&&) = default;
-
-    __host__ __device__ virtual inline void operator()(float& y, const float& x) const = 0;
-    __host__ __device__ virtual inline void operator()(double& y, const double& x) const = 0;
-    __host__ __device__ virtual inline void operator()(int32_t& y, const int32_t& x) const = 0;
-    __host__ __device__ virtual inline void operator()(int8_t& y, const int8_t& x) const = 0;
-    __host__ __device__ virtual inline void operator()(half_t& y, const half_t& x) const = 0;
-    __host__ __device__ virtual inline void operator()(bhalf_t& y, const bhalf_t& x) const = 0;
-};
-
 struct PassThroughPack2
 {
     template <typename Y, typename X>
@@ -304,27 +278,8 @@ struct PassThroughPack2
     constexpr const static bool is_pack2_invocable = true;
 };

-struct PassThrough final : public UnaryOpBase
+struct PassThrough
 {
-    __host__ __device__ constexpr PassThrough() = default;
-    __host__ __device__ constexpr PassThrough(const PassThrough&) = default;
-    __host__ __device__ constexpr PassThrough(PassThrough&&) = default;
-    __host__ __device__ PassThrough& operator=(const PassThrough&) = default;
-    __host__ __device__ PassThrough& operator=(PassThrough&&) = default;
-    __host__ __device__ ~PassThrough() = default;
-
-    __host__ __device__ inline void operator()(float& y, const float& x) const final { y = x; }
-    __host__ __device__ inline void operator()(double& y, const double& x) const final { y = x; }
-    __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final { y = x; }
-    __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final { y = x; }
-    __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final { y = x; }
-    __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final { y = x; }
-
     template <typename Y, typename X>
     __host__ __device__ void operator()(Y& y, const X& x) const;
@@ -334,6 +289,12 @@ struct PassThrough
         y = x;
     }

+    template <>
+    __host__ __device__ void operator()<double, double>(double& y, const double& x) const
+    {
+        y = x;
+    }
+
     template <>
     __host__ __device__ void operator()<float, double>(float& y, const double& x) const
     {
@@ -346,12 +307,36 @@ struct PassThrough
         y = type_convert<double>(x);
     }

+    template <>
+    __host__ __device__ void operator()<float, float>(float& y, const float& x) const
+    {
+        y = x;
+    }
+
+    template <>
+    __host__ __device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
+    {
+        y = x;
+    }
+
     template <>
     __host__ __device__ void operator()<half_t, float>(half_t& y, const float& x) const
     {
         y = type_convert<half_t>(x);
     }

+    template <>
+    __host__ __device__ void operator()<bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x) const
+    {
+        y = x;
+    }
+
+    template <>
+    __host__ __device__ void operator()<int32_t, int32_t>(int32_t& y, const int32_t& x) const
+    {
+        y = x;
+    }
+
     template <>
     __host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
     {
@@ -376,6 +361,12 @@ struct PassThrough
         y = type_convert<float>(x);
     }

+    template <>
+    __host__ __device__ void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
+    {
+        y = x;
+    }
+
     template <>
     __host__ __device__ void operator()<half_t, int8_t>(half_t& y, const int8_t& x) const
     {
@@ -675,45 +666,21 @@ struct UnarySquare
     };
 };

-struct UnaryAbs final : public UnaryOpBase
+struct UnaryAbs
 {
-    __host__ __device__ constexpr UnaryAbs() = default;
-    __host__ __device__ constexpr UnaryAbs(const UnaryAbs&) = default;
-    __host__ __device__ constexpr UnaryAbs(UnaryAbs&&) = default;
-    __host__ __device__ UnaryAbs& operator=(const UnaryAbs&) = default;
-    __host__ __device__ UnaryAbs& operator=(UnaryAbs&&) = default;
-    __host__ __device__ ~UnaryAbs() = default;
-
-    __host__ __device__ inline void operator()(float& y, const float& x) const final
-    {
-        y = ck::math::abs(x);
-    }
-
-    __host__ __device__ inline void operator()(double& y, const double& x) const final
-    {
-        y = ck::math::abs(x);
-    }
-
-    __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
-    {
-        y = ck::math::abs(x);
-    }
-
-    __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
-    {
-        y = ck::math::abs(x);
-    }
-
-    __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
-    {
-        y = math::abs(x);
-    }
-
-    __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
-    {
-        y = ck::math::abs(x);
-    }
-
+    template <typename T>
+    __host__ __device__ void operator()(T& y, const T& x) const
+    {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
+                          is_same<T, int8_t>::value,
+                      "Data type is not supported by this operation!");
+
+        y = math::abs(x);
+    };
+
+    template <>
     __host__ __device__ void operator()(f8_t& y, const f8_t& x) const
     {
         y = ck::type_convert<f8_t>(ck::math::abs(ck::type_convert<float>(x)));
@@ -732,41 +699,20 @@ struct UnarySqrt
     };
 };

-struct Relu final : public UnaryOpBase
+struct Relu
 {
-    __host__ __device__ constexpr Relu() = default;
-    __host__ __device__ constexpr Relu(const Relu&) = default;
-    __host__ __device__ constexpr Relu(Relu&&) = default;
-    __host__ __device__ Relu& operator=(const Relu&) = default;
-    __host__ __device__ Relu& operator=(Relu&&) = default;
-    __host__ __device__ ~Relu() = default;
-
-    __host__ __device__ inline void operator()(float& y, const float& x) const final
-    {
-        y = x > 0 ? x : 0;
-    }
-
-    __host__ __device__ inline void operator()(double& y, const double& x) const final
-    {
-        y = x > 0 ? x : 0;
-    }
-
-    __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
-    {
-        y = x > 0 ? x : 0;
-    }
-
-    __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
-    {
-        y = x > 0 ? x : 0;
-    }
-
-    __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
-    {
-        y = x > 0 ? x : 0;
-    }
-
-    __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
+    template <typename T>
+    __host__ __device__ void operator()(T& y, const T& x) const
+    {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
+                          is_same<T, int8_t>::value,
+                      "Data type is not supported by this operation!");
+        y = x > 0 ? x : 0;
+    }
+
+    template <>
+    __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const
     {
         float x_f32 = type_convert<float>(x);
         float y_f32 = x_f32 > 0 ? x_f32 : 0;
@@ -913,52 +859,18 @@ struct Gelu
     }
 };

-struct Sigmoid final : public UnaryOpBase
+struct Sigmoid
 {
-    __host__ __device__ constexpr Sigmoid() = default;
-    __host__ __device__ constexpr Sigmoid(const Sigmoid&) = default;
-    __host__ __device__ constexpr Sigmoid(Sigmoid&&) = default;
-    __host__ __device__ Sigmoid& operator=(const Sigmoid&) = default;
-    __host__ __device__ Sigmoid& operator=(Sigmoid&&) = default;
-    __host__ __device__ ~Sigmoid() = default;
-
-    __host__ __device__ inline void operator()(float& y, const float& x) const final
-    {
-        constexpr float one = type_convert<float>(1);
-        y = one / (one + math::exp(-x));
-    }
-
-    __host__ __device__ inline void operator()(double& y, const double& x) const final
-    {
-        constexpr double one = type_convert<double>(1);
-        y = one / (one + ck::math::exp(-x));
-    }
-
-    __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
-    {
-        constexpr int32_t one = type_convert<int32_t>(1);
-        y = one / (one + ck::math::exp(-x));
-    }
-
-    __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
-    {
-        constexpr int8_t one = type_convert<int8_t>(1);
-        y = one / (one + ck::math::exp(-x));
-    }
-
-    __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
-    {
-        constexpr half_t one = type_convert<half_t>(1);
-        y = one / (one + math::exp(-x));
-    }
-
-    __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
-    {
-        constexpr float one = type_convert<float>(1);
-        float x_f32 = ck::type_convert<float>(x);
-        float y_f32 = one / (one + ck::math::exp(x_f32));
-        y = ck::type_convert<bhalf_t>(y_f32);
-    }
+    template <typename T>
+    __host__ __device__ void operator()(T& y, const T& x) const
+    {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
+                          is_same<T, int32_t>::value,
+                      "Data type is not supported by this operation!");
+        constexpr T one = type_convert<T>(1);
+        y = one / (one + math::exp(-x));
+    };
 };
 struct Silu
@@ -974,44 +886,18 @@ struct Silu
     };
 };

-struct TanH final : public UnaryOpBase
+struct TanH
 {
-    __host__ __device__ constexpr TanH() = default;
-    __host__ __device__ constexpr TanH(const TanH&) = default;
-    __host__ __device__ constexpr TanH(TanH&&) = default;
-    __host__ __device__ TanH& operator=(const TanH&) = default;
-    __host__ __device__ TanH& operator=(TanH&&) = default;
-    __host__ __device__ ~TanH() = default;
-
-    __host__ __device__ inline void operator()(float& y, const float& x) const final
-    {
-        y = math::tanh(x);
-    }
-
-    __host__ __device__ inline void operator()(double& y, const double& x) const final
-    {
-        y = ck::math::tanh(x);
-    }
-
-    __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
-    {
-        y = ck::math::tanh(x);
-    }
-
-    __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
-    {
-        y = ck::math::tanh(x);
-    }
-
-    __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
-    {
-        y = ck::math::tanh(x);
-    }
-
-    __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
-    {
-        y = ck::math::tanh(x);
-    }
+    template <typename T>
+    __host__ __device__ void operator()(T& y, const T& x) const
+    {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
+                          is_same<T, int32_t>::value,
+                      "Data type is not supported by this operation!");
+
+        y = math::tanh(x);
+    };
 };
 struct ACos
@@ -1252,418 +1138,138 @@ struct Rcp
     };
 };
-struct Swish final : public UnaryOpBase
+struct Swish
 {
-    __host__ __device__ constexpr Swish(const Swish&) = default;
-    __host__ __device__ constexpr Swish(Swish&&) = default;
-    __host__ __device__ ~Swish() = default;
-
-    __host__ __device__ Swish(float beta = 1.0f) : beta_(beta) {}
-
-    __host__ __device__ float get_beta() const { return beta_; }
-
-    const float beta_;
-
-    __host__ __device__ inline void operator()(float& y, const float& x) const final
-    {
-        float bx = -beta_ * type_convert<float>(x);
-        y = type_convert<float>(x / (1.f + ck::math::exp(bx)));
-    }
-
-    __host__ __device__ inline void operator()(double& y, const double& x) const final
-    {
-        float bx = -beta_ * type_convert<float>(x);
-        y = type_convert<double>(x / (1.f + ck::math::exp(bx)));
-    }
-
-    __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
-    {
-        float bx = -beta_ * type_convert<float>(x);
-        y = type_convert<int32_t>(x / (1.f + ck::math::exp(bx)));
-    }
-
-    __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
-    {
-        float bx = -beta_ * type_convert<float>(x);
-        y = type_convert<int8_t>(x / (1.f + ck::math::exp(bx)));
-    }
-
-    __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
-    {
-        float bx = -beta_ * type_convert<float>(x);
-        y = type_convert<half_t>(x / (1.f + ck::math::exp(bx)));
-    }
-
-    __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
-    {
-        float bx = -beta_ * type_convert<float>(x);
-        y = type_convert<bhalf_t>(x / (1.f + ck::math::exp(bx)));
-    }
+    Swish(float beta = 1.0f) : beta_(beta) {}

     template <typename Y, typename X>
     __host__ __device__ void operator()(Y& y, const X& x) const
     {
         static_assert(is_same<X, float>::value || is_same<X, double>::value ||
-                          is_same<X, half_t>::value,
+                          is_same<X, ck::half_t>::value || is_same<X, int8_t>::value,
                       "Data type is not supported by this operation!");

         static_assert(is_same<Y, float>::value || is_same<Y, double>::value ||
-                          is_same<Y, half_t>::value,
+                          is_same<Y, ck::half_t>::value || is_same<Y, int8_t>::value,
                       "Data type is not supported by this operation!");

         float bx = -beta_ * type_convert<float>(x);
         y = type_convert<Y>(x / (1.f + math::exp(bx)));
-    }
+    };
+
+    const float beta_;
 };
-struct SoftRelu final : public UnaryOpBase
+struct SoftRelu
 {
-    __host__ __device__ constexpr SoftRelu(const SoftRelu&) = default;
-    __host__ __device__ constexpr SoftRelu(SoftRelu&&) = default;
-    __host__ __device__ ~SoftRelu() = default;
-
-    __host__ __device__ SoftRelu(float alpha = 1.0f) : alpha_(alpha) {}
-
-    __host__ __device__ float get_alpha() const { return alpha_; }
-
-    const float alpha_;
-
-    __host__ __device__ inline void operator()(float& y, const float& x) const final
-    {
-        float casted_alpha  = type_convert<float>(alpha_);
-        constexpr float one = type_convert<float>(1);
-        y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha;
-    }
-
-    __host__ __device__ inline void operator()(double& y, const double& x) const final
-    {
-        double casted_alpha  = type_convert<double>(alpha_);
-        constexpr double one = type_convert<double>(1);
-        y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha;
-    }
-
-    __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
-    {
-        int32_t casted_alpha  = type_convert<int32_t>(alpha_);
-        constexpr int32_t one = type_convert<int32_t>(1);
-        y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha;
-    }
-
-    __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
-    {
-        int8_t casted_alpha  = type_convert<int8_t>(alpha_);
-        constexpr int8_t one = type_convert<int8_t>(1);
-        y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha;
-    }
-
-    __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
-    {
-        half_t casted_alpha  = type_convert<half_t>(alpha_);
-        constexpr half_t one = type_convert<half_t>(1);
-        y = math::log(one + math::exp(x * casted_alpha)) / casted_alpha;
-    }
-
-    __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
-    {
-        bhalf_t casted_alpha  = type_convert<bhalf_t>(alpha_);
-        constexpr bhalf_t one = type_convert<bhalf_t>(1);
-        y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha;
-    }
+    SoftRelu(float alpha = 1.f) : alpha_(alpha){};
+
+    template <typename T>
+    __host__ __device__ void operator()(T& y, const T& x) const
+    {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
+                          is_same<T, int8_t>::value,
+                      "Data type is not supported by this operation!");
+        T casted_alpha  = type_convert<T>(alpha_);
+        constexpr T one = type_convert<T>(1);
+        y               = math::log(one + math::exp(x * casted_alpha)) / casted_alpha;
+    }
+    const float alpha_;
 };
-struct Power final : public UnaryOpBase
+struct Power
 {
-    __host__ __device__ constexpr Power(const Power&) = default;
-    __host__ __device__ constexpr Power(Power&&) = default;
-    __host__ __device__ ~Power() = default;
-
-    __host__ __device__ Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f)
-        : alpha_(alpha), beta_(beta), gamma_(gamma)
-    {
-    }
-
-    __host__ __device__ float get_alpha() const { return alpha_; }
-    __host__ __device__ float get_beta() const { return beta_; }
-    __host__ __device__ float get_gamma() const { return gamma_; }
+    Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f)
+        : alpha_(alpha), beta_(beta), gamma_(gamma){};
+
+    template <typename T>
+    __host__ __device__ void operator()(T& y, const T& x) const
+    {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
+                          is_same<T, int8_t>::value,
+                      "Data type is not supported by this operation!");
+        T casted_alpha     = type_convert<T>(alpha_);
+        T casted_beta      = type_convert<T>(beta_);
+        T casted_gamma     = type_convert<T>(gamma_);
+        T shifted_scaled_x = casted_alpha + casted_beta * x;
+        y                  = math::pow(shifted_scaled_x, casted_gamma);
+    }

     const float alpha_;
     const float beta_;
     const float gamma_;
-
-    __host__ __device__ inline void operator()(float& y, const float& x) const final
-    {
-        float casted_alpha     = type_convert<float>(alpha_);
-        float casted_beta      = type_convert<float>(beta_);
-        float casted_gamma     = type_convert<float>(gamma_);
-        float shifted_scaled_x = casted_alpha + casted_beta * x;
-        y = ck::math::pow(shifted_scaled_x, casted_gamma);
-    }
-
-    __host__ __device__ inline void operator()(double& y, const double& x) const final
-    {
-        double casted_alpha     = type_convert<double>(alpha_);
-        double casted_beta      = type_convert<double>(beta_);
-        double casted_gamma     = type_convert<double>(gamma_);
-        double shifted_scaled_x = casted_alpha + casted_beta * x;
-        y = ck::math::pow(shifted_scaled_x, casted_gamma);
-    }
-
-    __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
-    {
-        int32_t casted_alpha     = type_convert<int32_t>(alpha_);
-        int32_t casted_beta      = type_convert<int32_t>(beta_);
-        int32_t casted_gamma     = type_convert<int32_t>(gamma_);
-        int32_t shifted_scaled_x = casted_alpha + casted_beta * x;
-        y = ck::math::pow(shifted_scaled_x, casted_gamma);
-    }
-
-    __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
-    {
-        int8_t casted_alpha     = type_convert<int8_t>(alpha_);
-        int8_t casted_beta      = type_convert<int8_t>(beta_);
-        int8_t casted_gamma     = type_convert<int8_t>(gamma_);
-        int8_t shifted_scaled_x = casted_alpha + casted_beta * x;
-        y = ck::math::pow(shifted_scaled_x, casted_gamma);
-    }
-
-    __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
-    {
-        half_t casted_alpha     = type_convert<half_t>(alpha_);
-        half_t casted_beta      = type_convert<half_t>(beta_);
-        half_t casted_gamma     = type_convert<half_t>(gamma_);
-        half_t shifted_scaled_x = casted_alpha + casted_beta * x;
-        y = math::pow(shifted_scaled_x, casted_gamma);
-    }
-
-    __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
-    {
-        bhalf_t casted_alpha     = type_convert<bhalf_t>(alpha_);
-        bhalf_t casted_beta      = type_convert<bhalf_t>(beta_);
-        bhalf_t casted_gamma     = type_convert<bhalf_t>(gamma_);
-        bhalf_t shifted_scaled_x = casted_alpha + casted_beta * x;
-        y = ck::math::pow(shifted_scaled_x, casted_gamma);
-    }
 };
-struct ClippedRelu final : public UnaryOpBase
+struct ClippedRelu
 {
-    __host__ __device__ constexpr ClippedRelu(const ClippedRelu&) = default;
-    __host__ __device__ constexpr ClippedRelu(ClippedRelu&&)      = default;
-    __host__ __device__ ~ClippedRelu()                            = default;
-    __host__ __device__ ClippedRelu(float alpha = 0.f, float beta = 1.f)
-        : alpha_(alpha), beta_(beta)
-    {
-    }
-    __host__ __device__ float get_alpha() const { return alpha_; }
-    __host__ __device__ float get_beta() const { return beta_; }
+    ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta){};
+
+    template <typename T>
+    __host__ __device__ void operator()(T& y, const T& x) const
+    {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
+                          is_same<T, int8_t>::value,
+                      "Data type is not supported by this operation!");
+        T casted_alpha = type_convert<T>(alpha_);
+        T casted_beta  = type_convert<T>(beta_);
+        y = math::min(casted_beta, math::max(casted_alpha, x));
+    }
 
     const float alpha_;
     const float beta_;
-    __host__ __device__ inline void operator()(float& y, const float& x) const final
-    {
-        float casted_alpha = type_convert<float>(alpha_);
-        float casted_beta  = type_convert<float>(beta_);
-        y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x));
-    }
-    __host__ __device__ inline void operator()(double& y, const double& x) const final
-    {
-        double casted_alpha = type_convert<double>(alpha_);
-        double casted_beta  = type_convert<double>(beta_);
-        y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x));
-    }
-    __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
-    {
-        int32_t casted_alpha = type_convert<int32_t>(alpha_);
-        int32_t casted_beta  = type_convert<int32_t>(beta_);
-        y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x));
-    }
-    __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
-    {
-        int8_t casted_alpha = type_convert<int8_t>(alpha_);
-        int8_t casted_beta  = type_convert<int8_t>(beta_);
-        y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x));
-    }
-    __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
-    {
-        half_t casted_alpha = type_convert<half_t>(alpha_);
-        half_t casted_beta  = type_convert<half_t>(beta_);
-        y = math::min(casted_beta, math::max(casted_alpha, x));
-    }
-    __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
-    {
-        bhalf_t casted_alpha = type_convert<bhalf_t>(alpha_);
-        bhalf_t casted_beta  = type_convert<bhalf_t>(beta_);
-        y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x));
-    }
 };
-struct LeakyRelu final : public UnaryOpBase
+struct LeakyRelu
 {
-    __host__ __device__ constexpr LeakyRelu(const LeakyRelu&) = default;
-    __host__ __device__ constexpr LeakyRelu(LeakyRelu&&)      = default;
-    __host__ __device__ ~LeakyRelu()                          = default;
-    __host__ __device__ LeakyRelu(float alpha = 0.f) : alpha_(alpha) {}
-    __host__ __device__ float get_alpha() const { return alpha_; }
-    const float alpha_;
-    __host__ __device__ inline void operator()(float& y, const float& x) const final
-    {
-        float casted_alpha = type_convert<float>(alpha_);
-        y = x >= 0 ? x : x * casted_alpha;
-    }
-    __host__ __device__ inline void operator()(double& y, const double& x) const final
-    {
-        double casted_alpha = type_convert<double>(alpha_);
-        y = x >= 0 ? x : x * casted_alpha;
-    }
-    __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
-    {
-        int32_t casted_alpha = type_convert<int32_t>(alpha_);
-        y = x >= 0 ? x : x * casted_alpha;
-    }
-    __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
-    {
-        int8_t casted_alpha = type_convert<int8_t>(alpha_);
-        y = x >= 0 ? x : x * casted_alpha;
-    }
-    __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
-    {
-        half_t casted_alpha = type_convert<half_t>(alpha_);
-        y = x >= 0 ? x : x * casted_alpha;
-    }
-    __host__ __device__ inline void operator()([[maybe_unused]] bhalf_t& y,
-                                               [[maybe_unused]] const bhalf_t& x) const final
-    {
-    }
+    LeakyRelu(float alpha = 0.01f) : alpha_(alpha){};
+
+    template <typename T>
+    __host__ __device__ void operator()(T& y, const T& x) const
+    {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
+                          is_same<T, int8_t>::value,
+                      "Data type is not supported by this operation!");
+        T casted_alpha = type_convert<T>(alpha_);
+        y = x >= 0 ? x : x * casted_alpha;
+    }
+    const float alpha_;
 };
-struct Elu final : public UnaryOpBase
+struct Elu
 {
-    __host__ __device__ constexpr Elu(const Elu&) = default;
-    __host__ __device__ constexpr Elu(Elu&&)      = default;
-    __host__ __device__ ~Elu()                    = default;
-    __host__ __device__ Elu(float alpha = 1.f) : alpha_(alpha) {}
-    __host__ __device__ float get_alpha() const { return alpha_; }
-    const float alpha_;
-    __host__ __device__ inline void operator()(float& y, const float& x) const final
-    {
-        float casted_alpha = type_convert<float>(alpha_);
-        y = x > 0 ? x : casted_alpha * ck::math::expm1(x);
-    }
-    __host__ __device__ inline void operator()(double& y, const double& x) const final
-    {
-        double casted_alpha = type_convert<double>(alpha_);
-        y = x > 0 ? x : casted_alpha * ck::math::expm1(x);
-    }
-    __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
-    {
-        int32_t casted_alpha = type_convert<int32_t>(alpha_);
-        y = x > 0 ? x : casted_alpha * ck::math::expm1(x);
-    }
-    __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
-    {
-        int8_t casted_alpha = type_convert<int8_t>(alpha_);
-        y = x > 0 ? x : casted_alpha * ck::math::expm1(x);
-    }
-    __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
-    {
-        half_t casted_alpha = type_convert<half_t>(alpha_);
-        y = x > 0 ? x : casted_alpha * math::expm1(x);
-    }
-    __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
-    {
-        bhalf_t casted_alpha = type_convert<bhalf_t>(alpha_);
-        y = x > 0 ? x : casted_alpha * ck::math::expm1(x);
-    }
+    Elu(float alpha = 1.f) : alpha_(alpha){};
+
+    template <typename T>
+    __host__ __device__ void operator()(T& y, const T& x) const
+    {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
+                          is_same<T, int8_t>::value,
+                      "Data type is not supported by this operation!");
+        T casted_alpha = type_convert<T>(alpha_);
+        y = x > 0 ? x : casted_alpha * math::expm1(x);
+    }
+    const float alpha_;
 };
-struct Logistic final : public UnaryOpBase
+struct Logistic
 {
-    __host__ __device__ constexpr Logistic(const Logistic&) = default;
-    __host__ __device__ constexpr Logistic(Logistic&&)      = default;
-    __host__ __device__ ~Logistic()                         = default;
-    __host__ __device__ Logistic(float alpha = 1.0f) : alpha_(alpha) {}
-    __host__ __device__ float get_alpha() const { return alpha_; }
-    const float alpha_;
-    __host__ __device__ inline void operator()(float& y, const float& x) const final
-    {
-        float casted_alpha = type_convert<float>(alpha_);
-        constexpr float one = type_convert<float>(1);
-        y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
-    }
-    __host__ __device__ inline void operator()(double& y, const double& x) const final
-    {
-        double casted_alpha = type_convert<double>(alpha_);
-        constexpr double one = type_convert<double>(1);
-        y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
-    }
-    __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
-    {
-        int32_t casted_alpha = type_convert<int32_t>(alpha_);
-        constexpr int32_t one = type_convert<int32_t>(1);
-        y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
-    }
-    __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
-    {
-        int8_t casted_alpha = type_convert<int8_t>(alpha_);
-        constexpr int8_t one = type_convert<int8_t>(1);
-        y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
-    }
-    __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
-    {
-        half_t casted_alpha = type_convert<half_t>(alpha_);
-        constexpr half_t one = type_convert<half_t>(1);
-        y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
-    }
-    __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
-    {
-        bhalf_t casted_alpha = type_convert<bhalf_t>(alpha_);
-        constexpr bhalf_t one = type_convert<bhalf_t>(1);
-        y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
-    }
+    Logistic(float alpha = 1.f) : alpha_(alpha){};
+
+    template <typename T>
+    __host__ __device__ void operator()(T& y, const T& x) const
+    {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
+                          is_same<T, int8_t>::value,
+                      "Data type is not supported by this operation!");
+        T casted_alpha  = type_convert<T>(alpha_);
+        constexpr T one = type_convert<T>(1);
+        y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
+    }
+    const float alpha_;
 };
 struct ConvInvscale
...
@@ -1728,7 +1334,7 @@ struct ConvScaleRelu
     __host__ __device__ void operator()<f8_t, float>(f8_t& e, const float& c) const
     {
         float x;
-        Relu{}(x, c * scale_in_ * scale_wei_);
+        Relu{}.template operator()<float>(x, c * scale_in_ * scale_wei_);
         e = type_convert<f8_t>(x * scale_out_);
     };
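The one-line change above follows from the operator() templating: the old call relied on a concrete overload set, while the new form names the member template explicitly and pins the instantiation to float. A standalone sketch of the syntax (illustrative types, not library code), since the `.template` disambiguator is required whenever the callee is a dependent object inside another template:

// Sketch of the call syntax used above.
struct ReluLike
{
    template <typename T>
    void operator()(T& y, const T& x) const { y = x > 0 ? x : T(0); }
};

template <typename Op>
void apply(Op op, float& y, float x)
{
    // Dependent context: `.template` is needed for this to parse.
    op.template operator()<float>(y, x);
}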
...
@@ -1809,225 +1415,138 @@ struct FastNumericArrayConverter<uint8_t, half_t, N>
 struct DynamicUnaryOp
 {
-    DynamicUnaryOp& operator=(const DynamicUnaryOp& other)
-    {
-        if(this != &other)
-        {
-            unary_op_ptr_  = other.unary_op_ptr_;
-            unary_op_type_ = other.unary_op_type_;
-        }
-        return *this;
-    }
     __host__ __device__ DynamicUnaryOp() = delete;
 
     __host__ __device__ DynamicUnaryOp(const Swish& swish)
+        : unary_op_type_(UnaryOpType::Swish), swish_{swish.beta_}
     {
-        unary_op_type_ = UnaryOpType::Swish;
-        beta           = swish.get_beta();
     }
 
     __host__ __device__ DynamicUnaryOp(const Swish&& swish)
+        : unary_op_type_(UnaryOpType::Swish), swish_{swish.beta_}
     {
-        unary_op_type_ = UnaryOpType::Swish;
-        beta           = swish.get_beta();
     }
 
-    __host__ __device__ DynamicUnaryOp(const Sigmoid&) { unary_op_type_ = UnaryOpType::Sigmoid; }
-    __host__ __device__ DynamicUnaryOp(const Sigmoid&&) { unary_op_type_ = UnaryOpType::Sigmoid; }
+    __host__ __device__ DynamicUnaryOp(const Sigmoid&) : unary_op_type_(UnaryOpType::Sigmoid) {}
+    __host__ __device__ DynamicUnaryOp(const Sigmoid&&) : unary_op_type_(UnaryOpType::Sigmoid) {}
 
     __host__ __device__ DynamicUnaryOp(const PassThrough&)
+        : unary_op_type_(UnaryOpType::PassThrough)
     {
-        unary_op_type_ = UnaryOpType::PassThrough;
     }
 
     __host__ __device__ DynamicUnaryOp(const PassThrough&&)
+        : unary_op_type_(UnaryOpType::PassThrough)
     {
-        unary_op_type_ = UnaryOpType::PassThrough;
     }
 
     __host__ __device__ DynamicUnaryOp(const Logistic& logistic)
+        : unary_op_type_(UnaryOpType::Logistic), logistic_{logistic.alpha_}
     {
-        unary_op_type_ = UnaryOpType::Logistic;
-        alpha          = logistic.get_alpha();
     }
 
     __host__ __device__ DynamicUnaryOp(const Logistic&& logistic)
+        : unary_op_type_(UnaryOpType::Logistic), logistic_{logistic.alpha_}
     {
-        unary_op_type_ = UnaryOpType::Logistic;
-        alpha          = logistic.get_alpha();
     }
 
-    __host__ __device__ DynamicUnaryOp(const TanH&) { unary_op_type_ = UnaryOpType::TanH; }
-    __host__ __device__ DynamicUnaryOp(const TanH&&) { unary_op_type_ = UnaryOpType::TanH; }
+    __host__ __device__ DynamicUnaryOp(const TanH&) : unary_op_type_(UnaryOpType::TanH) {}
+    __host__ __device__ DynamicUnaryOp(const TanH&&) : unary_op_type_(UnaryOpType::TanH) {}
 
-    __host__ __device__ DynamicUnaryOp(const Relu&) { unary_op_type_ = UnaryOpType::Relu; }
-    __host__ __device__ DynamicUnaryOp(const Relu&&) { unary_op_type_ = UnaryOpType::Relu; }
+    __host__ __device__ DynamicUnaryOp(const Relu&) : unary_op_type_(UnaryOpType::Relu) {}
+    __host__ __device__ DynamicUnaryOp(const Relu&&) : unary_op_type_(UnaryOpType::Relu) {}
 
     __host__ __device__ DynamicUnaryOp(const SoftRelu& softrelu)
+        : unary_op_type_(UnaryOpType::SoftRelu), soft_relu_{softrelu.alpha_}
     {
-        unary_op_type_ = UnaryOpType::SoftRelu;
-        alpha          = softrelu.get_alpha();
     }
 
     __host__ __device__ DynamicUnaryOp(const SoftRelu&& softrelu)
+        : unary_op_type_(UnaryOpType::SoftRelu), soft_relu_{softrelu.alpha_}
     {
-        unary_op_type_ = UnaryOpType::SoftRelu;
-        alpha          = softrelu.get_alpha();
     }
 
-    __host__ __device__ DynamicUnaryOp(const UnaryAbs&) { unary_op_type_ = UnaryOpType::UnaryAbs; }
-    __host__ __device__ DynamicUnaryOp(const UnaryAbs&&) { unary_op_type_ = UnaryOpType::UnaryAbs; }
+    __host__ __device__ DynamicUnaryOp(const UnaryAbs&) : unary_op_type_(UnaryOpType::UnaryAbs) {}
+    __host__ __device__ DynamicUnaryOp(const UnaryAbs&&) : unary_op_type_(UnaryOpType::UnaryAbs) {}
 
     __host__ __device__ DynamicUnaryOp(const Power& pow)
+        : unary_op_type_(UnaryOpType::Power), power_(pow.alpha_, pow.beta_, pow.gamma_)
     {
-        unary_op_type_ = UnaryOpType::Power;
-        alpha          = pow.get_alpha();
-        beta           = pow.get_beta();
-        gamma          = pow.get_gamma();
     }
 
    __host__ __device__ DynamicUnaryOp(const Power&& pow)
+        : unary_op_type_(UnaryOpType::Power), power_(pow.alpha_, pow.beta_, pow.gamma_)
     {
-        unary_op_type_ = UnaryOpType::Power;
-        alpha          = pow.get_alpha();
-        beta           = pow.get_beta();
-        gamma          = pow.get_gamma();
     }
 
     __host__ __device__ DynamicUnaryOp(const ClippedRelu& clippedrelu)
+        : unary_op_type_(UnaryOpType::ClippedRelu),
+          clipped_relu_{clippedrelu.alpha_, clippedrelu.beta_}
     {
-        unary_op_type_ = UnaryOpType::ClippedRelu;
-        alpha          = clippedrelu.get_alpha();
-        beta           = clippedrelu.get_beta();
     }
 
     __host__ __device__ DynamicUnaryOp(const ClippedRelu&& clippedrelu)
+        : unary_op_type_(UnaryOpType::ClippedRelu),
+          clipped_relu_{clippedrelu.alpha_, clippedrelu.beta_}
     {
-        unary_op_type_ = UnaryOpType::ClippedRelu;
-        alpha          = clippedrelu.get_alpha();
-        beta           = clippedrelu.get_beta();
     }
 
     __host__ __device__ DynamicUnaryOp(const LeakyRelu& leakyrelu)
+        : unary_op_type_(UnaryOpType::LeakyRelu), leaky_relu_{leakyrelu.alpha_}
     {
-        unary_op_type_ = UnaryOpType::LeakyRelu;
-        alpha          = leakyrelu.get_alpha();
     }
 
     __host__ __device__ DynamicUnaryOp(const LeakyRelu&& leakyrelu)
+        : unary_op_type_(UnaryOpType::LeakyRelu), leaky_relu_{leakyrelu.alpha_}
     {
-        unary_op_type_ = UnaryOpType::LeakyRelu;
-        alpha          = leakyrelu.get_alpha();
     }
 
     __host__ __device__ DynamicUnaryOp(const Elu& elu)
+        : unary_op_type_(UnaryOpType::Elu), elu_{elu.alpha_}
     {
-        unary_op_type_ = UnaryOpType::Elu;
-        alpha          = elu.get_alpha();
     }
 
     __host__ __device__ DynamicUnaryOp(const Elu&& elu)
+        : unary_op_type_(UnaryOpType::Elu), elu_{elu.alpha_}
     {
-        unary_op_type_ = UnaryOpType::Elu;
-        alpha          = elu.get_alpha();
     }
 
-    __host__ __device__ DynamicUnaryOp(const DynamicUnaryOp& dynamic_op)
-        : unary_op_type_(dynamic_op.unary_op_type_),
-          unary_op_ptr_(dynamic_op.unary_op_ptr_),
-          alpha(dynamic_op.alpha),
-          beta(dynamic_op.beta),
-          gamma(dynamic_op.gamma)
-    {
-    }
+    __host__ __device__ DynamicUnaryOp(const DynamicUnaryOp& dynamic_op) = default;
 
-    __host__ __device__ ~DynamicUnaryOp()
-    {
-        switch(unary_op_type_)
-        {
-        case(UnaryOpType::Swish): delete static_cast<Swish*>(unary_op_ptr_); break;
-        case(UnaryOpType::Sigmoid): delete static_cast<Sigmoid*>(unary_op_ptr_); break;
-        case(UnaryOpType::PassThrough): delete static_cast<PassThrough*>(unary_op_ptr_); break;
-        case(UnaryOpType::Logistic): delete static_cast<Logistic*>(unary_op_ptr_); break;
-        case(UnaryOpType::TanH): delete static_cast<TanH*>(unary_op_ptr_); break;
-        case(UnaryOpType::Relu): delete static_cast<Relu*>(unary_op_ptr_); break;
-        case(UnaryOpType::SoftRelu): delete static_cast<SoftRelu*>(unary_op_ptr_); break;
-        case(UnaryOpType::UnaryAbs): delete static_cast<UnaryAbs*>(unary_op_ptr_); break;
-        case(UnaryOpType::Power): delete static_cast<Power*>(unary_op_ptr_); break;
-        case(UnaryOpType::ClippedRelu): delete static_cast<ClippedRelu*>(unary_op_ptr_); break;
-        case(UnaryOpType::LeakyRelu): delete static_cast<LeakyRelu*>(unary_op_ptr_); break;
-        case(UnaryOpType::Elu): delete static_cast<Elu*>(unary_op_ptr_); break;
-        default: break;
-        }
-    }
+    __host__ __device__ ~DynamicUnaryOp() {}
 
-    __device__ void InitUnaryOpPtrOnDevice()
-    {
-        switch(unary_op_type_)
-        {
-        case(UnaryOpType::Swish): unary_op_ptr_ = new Swish(beta); break;
-        case(UnaryOpType::Sigmoid): unary_op_ptr_ = new Sigmoid; break;
-        case(UnaryOpType::PassThrough): unary_op_ptr_ = new PassThrough; break;
-        case(UnaryOpType::Logistic): unary_op_ptr_ = new Logistic(alpha); break;
-        case(UnaryOpType::TanH): unary_op_ptr_ = new TanH; break;
-        case(UnaryOpType::Relu): unary_op_ptr_ = new Relu; break;
-        case(UnaryOpType::SoftRelu): unary_op_ptr_ = new SoftRelu(alpha); break;
-        case(UnaryOpType::UnaryAbs): unary_op_ptr_ = new UnaryAbs; break;
-        case(UnaryOpType::Power): unary_op_ptr_ = new Power(alpha, beta, gamma); break;
-        case(UnaryOpType::ClippedRelu): unary_op_ptr_ = new ClippedRelu(alpha, beta); break;
-        case(UnaryOpType::LeakyRelu): unary_op_ptr_ = new LeakyRelu(alpha); break;
-        case(UnaryOpType::Elu): unary_op_ptr_ = new Elu(alpha); break;
-        default: unary_op_ptr_ = nullptr; break;
-        }
-    }
-
-    template <typename Y, typename X>
-    __device__ void operator()(Y& y, const X& x) const
-    {
-        isSupported<X, Y>();
-        unary_op_ptr_->operator()(y, x);
-    }
 
     template <typename Y, typename X>
-    __host__ void operator()(Y& y, const X& x) const
+    __host__ __device__ void operator()(Y& y, const X& x) const
     {
-        isSupported<X, Y>();
         switch(unary_op_type_)
        {
-        case(UnaryOpType::Swish): Swish{}.operator()(y, x); break;
-        case(UnaryOpType::Sigmoid): Sigmoid{}.operator()(y, x); break;
-        case(UnaryOpType::PassThrough): PassThrough{}.operator()(y, x); break;
-        case(UnaryOpType::Logistic): Logistic{}.operator()(y, x); break;
-        case(UnaryOpType::TanH): TanH{}.operator()(y, x); break;
-        case(UnaryOpType::Relu): Relu{}.operator()(y, x); break;
-        case(UnaryOpType::SoftRelu): SoftRelu{}.operator()(y, x); break;
-        case(UnaryOpType::UnaryAbs): UnaryAbs{}.operator()(y, x); break;
-        case(UnaryOpType::Power): Power{}.operator()(y, x); break;
-        case(UnaryOpType::ClippedRelu): ClippedRelu{}.operator()(y, x); break;
-        case(UnaryOpType::LeakyRelu): LeakyRelu{}.operator()(y, x); break;
-        case(UnaryOpType::Elu): Elu{}.operator()(y, x); break;
+        case(UnaryOpType::Swish): swish_(y, x); break;
+        case(UnaryOpType::Sigmoid): sigmoid_(y, x); break;
+        case(UnaryOpType::PassThrough): pass_through_(y, x); break;
+        case(UnaryOpType::Logistic): logistic_(y, x); break;
+        case(UnaryOpType::TanH): tanh_(y, x); break;
+        case(UnaryOpType::Relu): relu_(y, x); break;
+        case(UnaryOpType::SoftRelu): soft_relu_(y, x); break;
+        case(UnaryOpType::UnaryAbs): unary_abs_(y, x); break;
+        case(UnaryOpType::Power): power_(y, x); break;
+        case(UnaryOpType::ClippedRelu): clipped_relu_(y, x); break;
+        case(UnaryOpType::LeakyRelu): leaky_relu_(y, x); break;
+        case(UnaryOpType::Elu): elu_(y, x); break;
         default: break;
         }
     }
 
-    template <typename X, typename Y>
-    __device__ __host__ constexpr void isSupported() const
-    {
-        static_assert(std::is_same<X, Y>::value, "X and Y must be of the same type");
-        static_assert(is_same<X, float>::value || is_same<X, double>::value ||
-                          is_same<X, bhalf_t>::value || is_same<X, half_t>::value ||
-                          is_same<X, int32_t>::value || is_same<X, int8_t>::value,
-                      "Data type is not supported by this operation!");
-    }
+    template <>
+    __host__ __device__ void operator()<bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x) const
+    {
+        float y_float;
+        float x_float = type_convert<float>(x);
+        this->operator()(y_float, x_float);
+        y = type_convert<bhalf_t>(y_float);
+    }
 
     private:
...
@@ -2049,12 +1568,20 @@ struct DynamicUnaryOp
     public:
     UnaryOpType unary_op_type_;
-    UnaryOpBase* unary_op_ptr_ = nullptr;
-    float alpha;
-    float beta;
-    float gamma;
+
+    Swish swish_;
+    Sigmoid sigmoid_;
+    PassThrough pass_through_;
+    Logistic logistic_;
+    TanH tanh_;
+    Relu relu_;
+    SoftRelu soft_relu_;
+    UnaryAbs unary_abs_;
+    Power power_;
+    ClippedRelu clipped_relu_;
+    LeakyRelu leaky_relu_;
+    Elu elu_;
 };
-#pragma clang diagnostic pop
 
 } // namespace element_wise
 } // namespace tensor_operation
...
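Net effect of the DynamicUnaryOp rewrite above: the type tag plus heap-allocated `UnaryOpBase*` (which every kernel had to re-create via `InitUnaryOpPtrOnDevice()`) is replaced by value-stored op instances and one `switch` dispatch usable on both host and device; bf16 is handled by the `operator()<bhalf_t, bhalf_t>` specialization, which round-trips through float. A hedged sketch of the resulting calling pattern (kernel and names here are illustrative, not from the diff):

// Illustrative HIP kernel: the op is passed by value; after this change no
// device-side initialization step is needed before calling it.
using ck::tensor_operation::element_wise::DynamicUnaryOp;
using ck::tensor_operation::element_wise::LeakyRelu;

__global__ void apply_unary(const float* x, float* y, int n, DynamicUnaryOp op)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < n)
        op(y[i], x[i]); // switch on unary_op_type_, e.g. leaky_relu_(y, x)
}

// host side:
// apply_unary<<<grid, block>>>(x_d, y_d, n, DynamicUnaryOp{LeakyRelu{0.01f}});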
...
@@ -31,8 +31,6 @@ struct pk_i4_t
     type data;
     __host__ __device__ constexpr pk_i4_t() : data{type{}} {}
     __host__ __device__ constexpr pk_i4_t(type init) : data{init} {}
-
-    __host__ __device__ constexpr operator float() const { return static_cast<int8_t>(data); }
 };
 
 inline constexpr auto next_pow2(uint32_t x)
...
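Dropping `operator float()` removes an implicit conversion that was easy to trigger by accident: a pk_i4_t byte holds two 4-bit lanes, and casting the whole byte through `int8_t` yields a number corresponding to neither lane. A standalone sketch of the hazard (illustrative stand-in type and decode; the real unpacking lives in CK's conversion utilities):

#include <cstdint>

struct pk_i4_t_like { int8_t data; }; // stand-in: two 4-bit lanes per byte

int main()
{
    pk_i4_t_like p{int8_t(0x21)}; // lanes 1 (low) and 2 (high)
    // float f = p;               // pre-change: compiled silently, f == 33.f
    int lo = p.data & 0xF;        // 1 (sign handling omitted in this sketch)
    int hi = (p.data >> 4) & 0xF; // 2
    return (lo == 1 && hi == 2) ? 0 : 1;
}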
...
@@ -29,6 +29,13 @@ struct DynamicBuffer
     ElementSpaceSize element_space_size_;
     T invalid_element_value_ = T{0};
 
+    static constexpr index_t PackedSize = []() {
+        if constexpr(is_same_v<remove_cvref_t<T>, pk_i4_t>)
+            return 2;
+        else
+            return 1;
+    }();
+
     __host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size)
         : p_data_{p_data}, element_space_size_{element_space_size}
     {
...
@@ -82,14 +89,18 @@ struct DynamicBuffer
                 return amd_buffer_load_invalid_element_return_zero<remove_cvref_t<T>,
                                                                    t_per_x,
                                                                    coherence>(
-                    p_data_, i, is_valid_element, element_space_size_);
+                    p_data_, i, is_valid_element, element_space_size_ / PackedSize);
             }
             else
             {
                 return amd_buffer_load_invalid_element_return_customized_value<remove_cvref_t<T>,
                                                                                t_per_x,
                                                                                coherence>(
-                    p_data_, i, is_valid_element, element_space_size_, invalid_element_value_);
+                    p_data_,
+                    i,
+                    is_valid_element,
+                    element_space_size_ / PackedSize,
+                    invalid_element_value_);
             }
         }
         else
...
@@ -191,7 +202,7 @@ struct DynamicBuffer
                 dst_buf.p_data_,
                 dst_offset,
                 is_valid_element,
-                element_space_size_);
+                element_space_size_ / PackedSize);
         }
 
     template <typename X,
...
@@ -226,7 +237,7 @@ struct DynamicBuffer
             constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
 
             amd_buffer_store<remove_cvref_t<T>, t_per_x, coherence>(
-                x, p_data_, i, is_valid_element, element_space_size_);
+                x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
         }
         else if constexpr(GetAddressSpace() == AddressSpaceEnum::Lds &&
                           is_same<typename scalar_type<remove_cvref_t<T>>::type, int8_t>::value &&
...
@@ -378,7 +389,7 @@ struct DynamicBuffer
             constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
 
             amd_buffer_atomic_add<remove_cvref_t<T>, t_per_x>(
-                x, p_data_, i, is_valid_element, element_space_size_);
+                x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
         }
         else
         {
...
@@ -417,7 +428,7 @@ struct DynamicBuffer
             constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
 
             amd_buffer_atomic_max<remove_cvref_t<T>, t_per_x>(
-                x, p_data_, i, is_valid_element, element_space_size_);
+                x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
         }
         else if(is_valid_element)
         {
...
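All of the `/ PackedSize` edits above rescale the extent handed to the buffer intrinsics: for pk_i4_t the logical element count is twice the number of stored bytes, so the descriptor size must be halved, and for every other type the divisor is 1 and nothing changes. The immediately-invoked `constexpr` lambda computing `PackedSize` is a self-contained C++17 idiom, sketched here outside the library (names illustrative):

#include <type_traits>

struct pk_i4_t { signed char data; }; // stand-in for ck::pk_i4_t

template <typename T>
struct BufferLike
{
    // Immediately-invoked constexpr lambda, as in the diff above.
    static constexpr int PackedSize = []() {
        if constexpr(std::is_same_v<std::remove_cv_t<T>, pk_i4_t>)
            return 2;
        else
            return 1;
    }();
};

static_assert(BufferLike<pk_i4_t>::PackedSize == 2, "two int4 lanes per element");
static_assert(BufferLike<float>::PackedSize == 1, "all other types unpacked");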
...
@@ -14,6 +14,41 @@ namespace ck {
 #define __gfx94__
 #endif
 
+// Declare a template function for bf16 conversion using RTN
+template <typename Y, typename X>
+__host__ __device__ constexpr Y bf16_convert_rtn(X x);
+
+// Convert fp32 to bf16 with RTN if higher precision is needed
+template <>
+inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, float>(float x)
+{
+    // Nan check
+    if(x != x)
+    {
+        return uint16_t(0x7FC0);
+    }
+
+    union
+    {
+        float fp32;
+        uint32_t int32;
+    } u = {x};
+
+    const uint32_t first_bf16_mantisa_bit = ((u.int32 >> 16) & 1);
+    constexpr uint32_t rounding_bias      = uint32_t((1 << 15) - 1);
+
+    return uint16_t((u.int32 + first_bf16_mantisa_bit + rounding_bias) >> 16);
+}
+
+// convert fp16 to bfp16 via fp32 with RTN if higher precision is needed
+template <>
+inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(half_t x)
+{
+    float x_fp32 = static_cast<float>(x);
+    return bf16_convert_rtn<bhalf_t>(x_fp32);
+}
 // Convert X to Y, both X and Y are non-const data types.
 template <typename Y,
           typename X,
...
@@ -51,17 +86,15 @@ inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t
     return u.fp32;
 }
 
-// convert fp32 to bfp16
+// convert fp32 to bfp16, round to nearest even
 template <>
 inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float x)
 {
+#if CK_USE_RNE_BF16_CONVERSION
+    return bf16_convert_rtn<bhalf_t>(x);
+#else
     union
     {
         float fp32;
         uint32_t int32;
     } u = {x};
     return uint16_t(u.int32 >> 16);
+#endif
 }
 
 // convert bfp16 to fp16 via fp32
...
@@ -635,60 +668,4 @@ inline __host__ __device__ void array_convert(Array<Y, NumElems>& y, const Array
     }
 }
-// Declare a template function for bf16 conversion using RTN
-template <typename Y, typename X>
-__host__ __device__ constexpr Y bf16_convert_rtn(X x);
-
-// Convert fp32 to bf16 with RTN if higher precision is needed
-template <>
-inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, float>(float x)
-{
-    union
-    {
-        float fp32;
-        uint32_t int32;
-    } u = {x};
-
-    // When the exponent bits are not all 1s, then the value is zero, normal,
-    // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
-    // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
-    // This causes the bfloat16's mantissa to be incremented by 1 if the 16
-    // least significant bits of the float mantissa are greater than 0x8000,
-    // or if they are equal to 0x8000 and the least significant bit of the
-    // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
-    // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
-    // has the value 0x7f, then incrementing it causes it to become 0x00 and
-    // the exponent is incremented by one, which is the next higher FP value
-    // to the unrounded bfloat16 value. When the bfloat16 value is subnormal
-    // with an exponent of 0x00 and a mantissa of 0x7f, it may be rounded up
-    // to a normal value with an exponent of 0x01 and a mantissa of 0x00.
-    // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
-    // incrementing it causes it to become an exponent of 0xFF and a mantissa
-    // of 0x00, which is Inf, the next higher value to the unrounded value.
-    bool flag0 = ~u.int32 & 0x7f800000;
-
-    // When all of the exponent bits are 1, the value is Inf or NaN.
-    // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
-    // mantissa bit. Quiet NaN is indicated by the most significant mantissa
-    // bit being 1. Signaling NaN is indicated by the most significant
-    // mantissa bit being 0 but some other bit(s) being 1. If any of the
-    // lower 16 bits of the mantissa are 1, we set the least significant bit
-    // of the bfloat16 mantissa, in order to preserve signaling NaN in case
-    // the bfloat16's mantissa bits are all 0.
-    bool flag1 = !flag0 && (u.int32 & 0xffff);
-
-    u.int32 += flag0 ? 0x7fff + ((u.int32 >> 16) & 1) : 0; // Round to nearest, round to even
-    u.int32 |= flag1 ? 0x10000 : 0x0;                      // Preserve signaling NaN
-
-    return uint16_t(u.int32 >> 16);
-}
-
-// convert fp16 to bfp16 via fp32 with RTN if higher precision is needed
-template <>
-inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(half_t x)
-{
-    float x_fp32 = static_cast<float>(x);
-    return bf16_convert_rtn<bhalf_t>(x_fp32);
-}
-
 } // namespace ck
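The replacement RTN routine added earlier in this file reaches the same round-to-nearest-even result as the removed flag-based version for ordinary values, but trades the signaling-NaN preservation logic for a single quiet-NaN (0x7FC0) early return. The bias trick is easy to sanity-check in isolation (standalone sketch, not the library code; NaN handling omitted):

#include <cassert>
#include <cstdint>
#include <cstring>

static uint16_t bf16_rtn(float x)
{
    uint32_t u;
    std::memcpy(&u, &x, sizeof(u));      // bit-cast, like the union above
    const uint32_t odd  = (u >> 16) & 1; // LSB of the would-be bf16 mantissa
    const uint32_t bias = (1u << 15) - 1; // 0x7FFF
    return uint16_t((u + odd + bias) >> 16);
}

int main()
{
    // 1.0f = 0x3F800000: truncation and RNE agree, result 0x3F80.
    assert(bf16_rtn(1.0f) == 0x3F80);

    // 0x3F818000 is an exact tie with an odd bf16 mantissa, so RNE must
    // round up to the even neighbor 0x3F82 (truncation would give 0x3F81).
    float tie_odd;
    uint32_t bits = 0x3F818000u;
    std::memcpy(&tie_odd, &bits, sizeof(bits));
    assert(bf16_rtn(tie_odd) == 0x3F82);
    return 0;
}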
...
@@ -54,7 +54,6 @@
 #include "ck_tile/core/tensor/tile_window_linear.hpp"
 #include "ck_tile/core/tensor/tile_window_utils.hpp"
 #include "ck_tile/core/tensor/update_tile.hpp"
-#include "ck_tile/core/utility/amd_address_space.hpp"
 #include "ck_tile/core/utility/bit_cast.hpp"
 #include "ck_tile/core/utility/functional.hpp"
 #include "ck_tile/core/utility/functional_with_tuple.hpp"
...