"yaml-reader/vscode:/vscode.git/clone" did not exist on "d1ef1e8ef14265d7c8ae41155cf84abfe52f147c"
Commit a7ae4f8e authored by Astha Rai

Merge branch 'codegen_hiprtc' of github.com:ROCm/composable_kernel into codegen_hiprtc

parents a6055c3c 781005a5
@@ -41,6 +41,7 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf
     t.prec_sq,
     t.prec_kw,
     t.block_m,
+    t.activation,
     t.gate_only,
     t.fused_quant};
 auto a1 = fused_moegemm_args{
...
@@ -17,15 +17,67 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
 // clang-format off
 float r = -1;
 if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
-   t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1)
+   t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 && t.activation == 0)
 {
-    using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>;
+    constexpr ck_tile::index_t act_ = 0;
+    constexpr ck_tile::index_t go_ = 1;
+    using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
+    r = fused_moegemm_<t_>(s, a);
+}
+else if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
+   t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0 && t.activation == 0)
+{
+    constexpr ck_tile::index_t act_ = 0;
+    constexpr ck_tile::index_t go_ = 0;
+    using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
+    r = fused_moegemm_<t_>(s, a);
+}
+else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
+   t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 && t.activation == 0)
+{
+    constexpr ck_tile::index_t act_ = 0;
+    constexpr ck_tile::index_t go_ = 1;
+    using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
+    r = fused_moegemm_<t_>(s, a);
+}
+else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
+   t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0 && t.activation == 0)
+{
+    constexpr ck_tile::index_t act_ = 0;
+    constexpr ck_tile::index_t go_ = 0;
+    using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
+    r = fused_moegemm_<t_>(s, a);
+}
+else if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
+   t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 && t.activation == 1)
+{
+    constexpr ck_tile::index_t act_ = 1;
+    constexpr ck_tile::index_t go_ = 1;
+    using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
+    r = fused_moegemm_<t_>(s, a);
+}
+else if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
+   t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0 && t.activation == 1)
+{
+    constexpr ck_tile::index_t act_ = 1;
+    constexpr ck_tile::index_t go_ = 0;
+    using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
+    r = fused_moegemm_<t_>(s, a);
+}
+else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
+   t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 && t.activation == 1)
+{
+    constexpr ck_tile::index_t act_ = 1;
+    constexpr ck_tile::index_t go_ = 1;
+    using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
     r = fused_moegemm_<t_>(s, a);
 }
 else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
-   t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1)
+   t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0 && t.activation == 1)
 {
-    using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>;
+    constexpr ck_tile::index_t act_ = 1;
+    constexpr ck_tile::index_t go_ = 0;
+    using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
     r = fused_moegemm_<t_>(s, a);
 }
 // clang-format on
...
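The dispatch above follows a recurring pattern in this example: runtime trait values (precision strings, block_m, gate_only, and the new activation flag) are mapped onto compile-time template parameters through an explicit if/else chain, so each supported combination gets its own fully specialized instantiation. Below is a minimal, self-contained sketch of that pattern; the names Traits, run_kernel and dispatch are hypothetical stand-ins, not ck_tile APIs.

#include <cstdio>

// Hypothetical compile-time configuration: activation (0: gelu, 1: silu) and gate_only.
template <int Activation_, int GateOnly_>
struct Traits
{
    static constexpr int Activation = Activation_;
    static constexpr int GateOnly   = GateOnly_;
};

// Hypothetical "kernel" specialized per configuration at compile time.
template <typename T>
float run_kernel()
{
    std::printf("instantiated with act=%d gate_only=%d\n", T::Activation, T::GateOnly);
    return 0.0f;
}

// Runtime values pick exactly one instantiation; unsupported combinations fall
// through to the same r = -1 sentinel used above.
float dispatch(int activation, int gate_only)
{
    if(activation == 0 && gate_only == 1)
        return run_kernel<Traits<0, 1>>();
    else if(activation == 0 && gate_only == 0)
        return run_kernel<Traits<0, 0>>();
    else if(activation == 1 && gate_only == 1)
        return run_kernel<Traits<1, 1>>();
    else if(activation == 1 && gate_only == 0)
        return run_kernel<Traits<1, 0>>();
    return -1.0f;
}

int main() { return dispatch(1, 0) < 0.0f ? 1 : 0; }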
@@ -21,21 +21,31 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
     typename Ts_::BlockTile_1,
     typename Ts_::WarpPerBlock_0,
     typename Ts_::WarpTile_0>;
-    using f_problem =
-        ck_tile::FusedMoeGemmPipelineProblem<typename Ts_::ADataType,
-                                             typename Ts_::GDataType,
-                                             typename Ts_::DDataType,
-                                             typename Ts_::AccDataType,
-                                             typename Ts_::ODataType,
-                                             typename Ts_::AScaleDataType,
-                                             typename Ts_::GScaleDataType,
-                                             typename Ts_::DScaleDataType,
-                                             typename Ts_::YSmoothScaleDataType,
-                                             typename Ts_::TopkWeightDataType,
-                                             typename Ts_::IndexDataType,
-                                             ck_tile::element_wise::FastGeluAsm, // TODO: hardcoded
-                                             f_shape,
-                                             f_traits>;
+    constexpr auto get_activation_ = []() {
+        if constexpr(Ts_::Activation == 0)
+        {
+            return ck_tile::element_wise::FastGeluAsm{};
+        }
+        else
+            return ck_tile::element_wise::Silu{};
+    };
+    using f_act_ = ck_tile::remove_cvref_t<decltype(get_activation_())>;
+
+    using f_problem = ck_tile::FusedMoeGemmPipelineProblem<typename Ts_::ADataType,
+                                                           typename Ts_::GDataType,
+                                                           typename Ts_::DDataType,
+                                                           typename Ts_::AccDataType,
+                                                           typename Ts_::ODataType,
+                                                           typename Ts_::AScaleDataType,
+                                                           typename Ts_::GScaleDataType,
+                                                           typename Ts_::DScaleDataType,
+                                                           typename Ts_::YSmoothScaleDataType,
+                                                           typename Ts_::TopkWeightDataType,
+                                                           typename Ts_::IndexDataType,
+                                                           f_act_, // TODO: hardcoded
+                                                           f_shape,
+                                                           f_traits>;
 // using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmEx<f_problem>;
 using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmUk<f_problem>;
...
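The get_activation_ lambda above picks a functor type at compile time with if constexpr and then recovers that type through decltype. A small standalone sketch of the same technique, assuming simplified FastGelu/Silu placeholder functors rather than the ck_tile::element_wise implementations:

#include <cmath>
#include <type_traits>

// Simplified placeholder functors (not the ck_tile implementations).
struct FastGelu
{
    float operator()(float x) const
    {
        return 0.5f * x * (1.0f + std::tanh(0.7978845608f * (x + 0.044715f * x * x * x)));
    }
};
struct Silu
{
    float operator()(float x) const { return x / (1.0f + std::exp(-x)); }
};

// if constexpr resolves the branch during instantiation, so the discarded
// return statement does not participate in return-type deduction.
template <int Activation>
auto make_activation()
{
    if constexpr(Activation == 0)
        return FastGelu{};
    else
        return Silu{};
}

static_assert(std::is_same_v<decltype(make_activation<0>()), FastGelu>);
static_assert(std::is_same_v<decltype(make_activation<1>()), Silu>);

int main() { return 0; }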
@@ -15,7 +15,8 @@ template <typename I,
           typename KW,
           typename BlockTIle_, // seq<b_token, b_interm, b_hidden, b_down>
           typename WarpPerBlock_,
           typename WarpTile_, // seq<*,*,*>, used to select mfma
+          ck_tile::index_t Activation_ = 0, // 0: Gelu 1: Silu
           ck_tile::index_t GateOnly_ = 0,
           ck_tile::index_t FusedQuant_ = 0>
 struct fmoe_ // traits, ugly name, only used for internal
@@ -44,10 +45,11 @@ struct fmoe_ // traits, ugly name, only used for internal
     using WarpPerBlock_0 = ck_tile::remove_cvref_t<WarpPerBlock_>;
     using WarpTile_0 = ck_tile::remove_cvref_t<WarpTile_>;
-    using BlockTile_1 = ck_tile::sequence<BT_, BD_, BI_ / (GateOnly_ ? 1 : 2)>;
+    using BlockTile_1 = ck_tile::sequence<BT_, BD_, BI_>;
     using WarpPerBlock_1 = ck_tile::remove_cvref_t<WarpPerBlock_>;
     using WarpTile_1 = ck_tile::remove_cvref_t<WarpTile_>;
+    static constexpr ck_tile::index_t Activation = Activation_; // 0: Gelu 1: Silu
     static constexpr ck_tile::index_t GateOnly = GateOnly_;
     static constexpr ck_tile::index_t FusedQuant = FusedQuant_;
 };
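The fmoe_ traits aggregate simply re-exports its template parameters (now including Activation_) as static constexpr members so that downstream code such as the get_activation_ lambda can query Ts_::Activation. A compact sketch of that idiom with hypothetical names:

#include <type_traits>

// Hypothetical traits aggregate in the spirit of fmoe_ above: template
// parameters are re-exported as static constexpr members so later stages
// can query them generically.
template <typename InputT, int Activation_ = 0, int GateOnly_ = 0, int FusedQuant_ = 0>
struct moe_traits
{
    using ADataType = InputT;
    static constexpr int Activation = Activation_; // 0: Gelu, 1: Silu
    static constexpr int GateOnly   = GateOnly_;
    static constexpr int FusedQuant = FusedQuant_;
};

using T0 = moe_traits<float, 1, 0, 0>;
static_assert(T0::Activation == 1 && T0::GateOnly == 0, "traits forward their parameters");
static_assert(std::is_same_v<T0::ADataType, float>, "data type is exposed as a member alias");

int main() { return 0; }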
@@ -8,7 +8,18 @@
 // clang-format off
 template float fused_moegemm_<
-    fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>
+    fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0, 0>
 >(const ck_tile::stream_config& s, fused_moegemm_args a);
+template float fused_moegemm_<
+    fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 1, 0>
+>(const ck_tile::stream_config& s, fused_moegemm_args a);
+template float fused_moegemm_<
+    fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0, 0>
+>(const ck_tile::stream_config& s, fused_moegemm_args a);
+template float fused_moegemm_<
+    fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 1, 0>
+>(const ck_tile::stream_config& s, fused_moegemm_args a);
 // clang-format on
@@ -8,7 +8,19 @@
 // clang-format off
 template float fused_moegemm_<
-    fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>
+    fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0, 0>
+>(const ck_tile::stream_config& s, fused_moegemm_args a);
+template float fused_moegemm_<
+    fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 1, 0>
+>(const ck_tile::stream_config& s, fused_moegemm_args a);
+template float fused_moegemm_<
+    fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0, 0>
+>(const ck_tile::stream_config& s, fused_moegemm_args a);
+template float fused_moegemm_<
+    fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 1, 0>
 >(const ck_tile::stream_config& s, fused_moegemm_args a);
 // clang-format on
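The two files above pre-instantiate fused_moegemm_ for every supported (activation, gate_only) combination, so the runtime dispatch in fused_moegemm.cpp only needs declarations and the heavy kernel template is compiled once per configuration. A minimal sketch of the explicit-instantiation pattern, with hypothetical cfg/run_kernel names; in the real code the instantiation lists and the dispatch branches must stay in sync:

// Sketch of the explicit-instantiation pattern used by the instance files above
// (hypothetical names; the real files instantiate fused_moegemm_<fmoe_<...>>).
template <int Activation, int GateOnly>
struct cfg
{
    static constexpr int act = Activation;
    static constexpr int go  = GateOnly;
};

// The heavy template body would normally live in a header included only by the
// instance translation units; other files just declare the specializations
// (e.g. `extern template float run_kernel<cfg<0, 0>>(float);`) and link.
template <typename Cfg>
float run_kernel(float x)
{
    return x + Cfg::act * 10 + Cfg::go; // stand-in for launching the real kernel
}

template float run_kernel<cfg<0, 0>>(float);
template float run_kernel<cfg<0, 1>>(float);
template float run_kernel<cfg<1, 0>>(float);
template float run_kernel<cfg<1, 1>>(float);

int main() { return run_kernel<cfg<1, 1>>(0.0f) > 0.0f ? 0 : 1; }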
@@ -108,12 +108,14 @@ auto create_args(int argc, char* argv[])
     .insert(
         "gate_only", "1", "w0(gate/up) style, 0:gate+up will double interm size, 1:only gate")
     .insert("api", "0", "benchmark api set: 0:fused-moe(moe-gemm+moe-sorting), 1:moe-gemm")
+    .insert("act", "0", "activation after first gemm. 0:gelu, 1:silu")
     .insert("balance",
             "0",
             "if set to 1, will try balance the expert in topk-ids(convenient for testing)")
     .insert("init",
-            "2",
-            "init method. 0:random stepped float(fast). 1: random uniform, 2:rand normalized"
+            "1",
+            "init method. 0:random stepped float(fast). 1: random uniform[-0.5, 0.5], 2:rand "
+            "normalized[0, 1]"
             "normalized(slow)")
     .insert("seed", "11939", "seed used to do random")
     .insert("warmup", "5", "cold iter")
@@ -135,6 +137,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     ck_tile::index_t intermediate_size = arg_parser.get_int("i");
     ck_tile::index_t stride = arg_parser.get_int("stride");
     ck_tile::index_t block_m = arg_parser.get_int("bm");
+    ck_tile::index_t activation = arg_parser.get_int("act");
     if(stride < 0)
         stride = hidden_size;
     std::string prec_i = arg_parser.get_str("prec_i");
@@ -194,11 +197,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
         return std::string(", st:") + std::to_string(stride);
     }();
-    std::cout << "[" << api_str << "|" << prec_str << "]"
-              << " t:" << tokens << ", e:" << experts << ", k:" << topk << stride_str
-              << ", hidden:" << hidden_size << ", interm:" << intermediate_size << ", tp:" << tp
-              << ", shrd_interm:" << shared_intermediate_size_0 << "|" << shared_intermediate_size_1
-              << ", go:" << gate_only << ", q:" << fused_quant << std::flush;
+    std::cout
+        << "[" << api_str << "|" << prec_str << "]"
+        << " t:" << tokens << ", e:" << experts << ", k:" << topk << stride_str
+        << ", hidden:" << hidden_size << ", interm:" << intermediate_size << ", tp:" << tp
+        << ", act:"
+        << activation
+        // << ", shrd_interm:" << shared_intermediate_size_0 << "|" << shared_intermediate_size_1
+        << (gate_only ? ", g1u0" : ", g1u1") << ", q:" << fused_quant << std::flush;
 using TypeConfig = FusedMoeGemmTypeConfig<I, W, O, ST, SW, SQ, KW>;
 using ADataType = typename TypeConfig::ADataType;
@@ -370,6 +376,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     prec_sq,
     prec_kw,
     block_m,
+    activation,
     gate_only,
     fused_quant};
@@ -389,7 +396,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     num_sorted_tiles_buf.GetDeviceBuffer(),
     block_m,
     hidden_size,
-    shared_intermediate_size_0,
+    intermediate_size / tp,
     tokens,
     experts,
     topk,
@@ -408,6 +415,28 @@ bool run(const ck_tile::ArgParser& arg_parser)
               << cal_tbps(ave_time) << " TB/s" << std::flush;
     bool pass = true;
+
+#define CPU_FUSED_MOE(act_type_) \
+    ck_tile::reference_fused_moe<AccDataType, act_type_>(a_host, \
+                                                         g_host, \
+                                                         d_host, \
+                                                         sa_host, \
+                                                         sg_host, \
+                                                         sd_host, \
+                                                         sy_host, \
+                                                         o_host, \
+                                                         sorted_token_ids_host, \
+                                                         sorted_weight_host, \
+                                                         sorted_expert_ids_host, \
+                                                         num_sorted_tiles_host, \
+                                                         topk_ids_host, \
+                                                         block_m, \
+                                                         tokens, \
+                                                         experts, \
+                                                         hidden_size, \
+                                                         intermediate_size / tp, \
+                                                         topk, \
+                                                         gate_only)
     if(do_validation)
     {
         ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>(
@@ -419,28 +448,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
             num_sorted_tiles_host.mData[0],
             experts,
             block_m);
-        ck_tile::reference_fused_moe<AccDataType, ck_tile::element_wise::Gelu>(
-            a_host,
-            g_host,
-            d_host,
-            sa_host,
-            sg_host,
-            sd_host,
-            sy_host,
-            o_host,
-            sorted_token_ids_host,
-            sorted_weight_host,
-            sorted_expert_ids_host,
-            num_sorted_tiles_host,
-            topk_ids_host,
-            block_m,
-            tokens,
-            experts,
-            hidden_size,
-            shared_intermediate_size_0,
-            topk,
-            gate_only);
+        if(activation == 0)
+        {
+            CPU_FUSED_MOE(ck_tile::element_wise::Gelu);
+        }
+        else
+        {
+            CPU_FUSED_MOE(ck_tile::element_wise::Silu);
+        }
         auto o_dev = o_buf.ToHost<ODataType>();
         // o_dev.savetxt("gpu-out.txt", "float");
@@ -491,6 +506,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     prec_sq,
     prec_kw,
     block_m,
+    activation,
     gate_only,
     fused_quant};
@@ -507,7 +523,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     sorted_expert_ids_buf.GetDeviceBuffer(),
     num_sorted_tiles_buf.GetDeviceBuffer(),
     hidden_size,
-    shared_intermediate_size_0,
+    intermediate_size / tp,
     tokens,
     experts,
     topk,
@@ -529,27 +545,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
     if(do_validation)
     {
-        ck_tile::reference_fused_moe<AccDataType, ck_tile::element_wise::Gelu>(
-            a_host,
-            g_host,
-            d_host,
-            sa_host,
-            sg_host,
-            sd_host,
-            sy_host,
-            o_host,
-            sorted_token_ids_host,
-            sorted_weight_host,
-            sorted_expert_ids_host,
-            num_sorted_tiles_host,
-            topk_ids_host,
-            block_m,
-            tokens,
-            experts,
-            hidden_size,
-            shared_intermediate_size_0,
-            topk,
-            gate_only);
+        if(activation == 0)
+        {
+            CPU_FUSED_MOE(ck_tile::element_wise::Gelu);
+        }
+        else
+        {
+            CPU_FUSED_MOE(ck_tile::element_wise::Silu);
+        }
         auto o_dev = o_buf.ToHost<ODataType>();
         // o_dev.savetxt("gpu-out.txt", "float");
...
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
 #include <hip/hip_runtime.h>
@@ -51,7 +51,7 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
         ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
         ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
-    using TilePartitioner = ck_tile::GemmTilePartitioner<CodegenGemmShape>;
+    using TilePartitioner = ck_tile::GemmTile2DPartitioner<CodegenGemmShape>;
     using GemmEpilogue = std::conditional_t<
         CShuffleEpilogue,
@@ -63,8 +63,8 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
             kOutputRank,
             1,
             0,
-            TilePartitioner::kM,
-            TilePartitioner::kN>>,
+            TilePartitioner::MPerBlock,
+            TilePartitioner::NPerBlock>>,
         ck_tile::Default2DEpilogue<
             ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>>;
@@ -72,9 +72,7 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
         ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
     using CodegenPipelineProblem = ck_tile::
         GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
-    using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy;
-    using CodegenGemmPipeline =
-        ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem, CodegenGemmPolicy>;
+    using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
     // ToDo: Will add the codegen part to test different pipeline policies in GEMM.
     // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
     using Kernel = ck_tile::BatchedGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
...
@@ -39,7 +39,7 @@ auto create_args(int argc, char* argv[])
     .insert("stride_b", "0", "Tensor B stride")
     .insert("stride_c", "0", "Tensor C stride")
     .insert("a_layout", "R", "A tensor data layout - Row by default")
-    .insert("b_layout", "R", "B tensor data layout - Row by default")
+    .insert("b_layout", "C", "B tensor data layout - Row by default")
     .insert("c_layout", "R", "C tensor data layout - Row by default")
     .insert("batch_stride_a", "32768", "Batch A stride")
     .insert("batch_stride_b", "16384", "Batch B stride")
...
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
+auto calculate_rtol_atol(const ck_tile::index_t K,
+                         const ck_tile::index_t kbatch,
+                         const float max_accumulated_value)
+{
+    using ComputeType =
+        std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
+    // Calculate thresholds
+    const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
+        ck_tile::integer_divide_ceil(K, kbatch));
+    const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
+        max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
+    // Calculate error due to split_k accumulation
+    const auto rtol_split_k =
+        ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
+    const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
+        max_accumulated_value, kbatch);
+    // Use higher threshold
+    return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
+}
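calculate_rtol_atol above derives validation tolerances from the narrower compute type, the number of accumulations per output element (K split across kbatch partial results), and the largest accumulated magnitude. The sketch below illustrates only the general error model behind that idea; it is not ck_tile's exact get_relative_threshold/get_absolute_threshold formula.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <limits>
#include <utility>

// Rough model: after n accumulations in a type with machine epsilon eps, allow a
// relative error of about n * eps, and an absolute error scaled by the largest
// accumulated magnitude; split-k adds kbatch further additions when partial
// results are combined, so the larger of the two bounds is used.
template <typename ComputeT>
std::pair<double, double> rtol_atol(int K, int kbatch, double max_accumulated_value)
{
    const double eps = std::numeric_limits<ComputeT>::epsilon();
    const int n      = (K + kbatch - 1) / kbatch; // accumulation length per partial result
    const double rtol         = n * eps;
    const double atol         = n * eps * std::abs(max_accumulated_value) / kbatch;
    const double rtol_split_k = kbatch * eps;
    const double atol_split_k = kbatch * eps * std::abs(max_accumulated_value);
    return {std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)};
}

int main()
{
    const auto [rtol, atol] = rtol_atol<float>(/*K=*/4096, /*kbatch=*/4, /*max=*/37.5);
    std::printf("rtol=%g atol=%g\n", rtol, atol);
    return 0;
}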
 template <typename ALayout, typename BLayout, typename CLayout>
 float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
                           ck_tile::DeviceMem& b_k_n_dev_buf,
@@ -179,8 +199,18 @@ int run_batched_gemm_example_with_layouts(int argc,
         ck_tile::reference_batched_gemm<ADataType, BDataType, AccDataType, CDataType>(
             a_m_k, b_n_k, c_m_n_host_ref);
-        pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_host_ref);
+        const float max_accumulated_value =
+            *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
+        const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_accumulated_value);
+        pass = ck_tile::check_err(c_m_n_dev_result,
+                                  c_m_n_host_ref,
+                                  "Error: Incorrect results!",
+                                  rtol_atol.at(ck_tile::number<0>{}),
+                                  rtol_atol.at(ck_tile::number<1>{}));
+        std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
+                  << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
+                  << std::endl;
         std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl;
     }
@@ -240,7 +270,18 @@ int run_batched_gemm_example_with_layouts(int argc,
         ck_tile::hip_check_error(hipFree(d_C));
         c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
-        pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_gpu_ref);
+        const float max_accumulated_value =
+            *std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end());
+        const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_accumulated_value);
+        pass = ck_tile::check_err(c_m_n_dev_result,
+                                  c_m_n_gpu_ref,
+                                  "Error: Incorrect results!",
+                                  rtol_atol.at(ck_tile::number<0>{}),
+                                  rtol_atol.at(ck_tile::number<1>{}));
+        std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
+                  << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
+                  << std::endl;
         std::cout << "The GPU verification result is: " << (pass ? "correct" : "fail") << std::endl;
     }
@@ -260,11 +301,11 @@ int run_batched_gemm_example(int argc, char* argv[])
     std::string a_layout = arg_parser.get_str("a_layout");
     std::string b_layout = arg_parser.get_str("b_layout");
-    if(a_layout == "R" && b_layout == "R")
-    {
-        return run_batched_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{});
-    }
-    else if(a_layout == "R" && b_layout == "C")
+    // if(a_layout == "R" && b_layout == "R")
+    // {
+    //     return run_batched_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{});
+    // }
+    if(a_layout == "R" && b_layout == "C")
     {
         return run_batched_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
     }
...
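run_batched_gemm_example narrows the supported layout combinations by translating the runtime a_layout/b_layout strings into tag types before calling the templated implementation, and the R/R path is now commented out in favour of R/C. A small sketch of that string-to-tag dispatch, with placeholder RowMajor/ColMajor tags and a placeholder gemm_impl (not the ck_tile definitions):

#include <iostream>
#include <stdexcept>
#include <string>

struct RowMajor {};
struct ColMajor {};

// Placeholder for the layout-templated implementation.
template <typename ALayout, typename BLayout, typename CLayout>
int gemm_impl()
{
    std::cout << "running layout-specialized implementation\n";
    return 0;
}

// Runtime strings select exactly one compile-time specialization; everything
// else is rejected, mirroring the commented-out R/R branch above.
int run(const std::string& a_layout, const std::string& b_layout)
{
    if(a_layout == "R" && b_layout == "C")
        return gemm_impl<RowMajor, ColMajor, RowMajor>();
    throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
}

int main() { return run("R", "C"); }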
@@ -15,7 +15,6 @@
 #include "ck_tile/ops/gemm.hpp"
 #include "ck_tile/host.hpp"
 #include "grouped_gemm.hpp"
-#include "utils.hpp"
 namespace {
@@ -89,12 +88,9 @@ using CodegenPipelineProblem =
                               CodegenGemmShape,
                               CodegenGemmTraits<ALayout, BLayout, CLayout>>;
-using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy;
 template <typename ALayout, typename BLayout, typename CLayout>
 using CodegenGemmPipeline =
-    ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem<ALayout, BLayout, CLayout>,
-                                          CodegenGemmPolicy>;
+    ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem<ALayout, BLayout, CLayout>>;
 template <typename ALayout, typename BLayout, typename CLayout>
 using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner,
@@ -102,7 +98,7 @@ using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner,
                                           GemmEpilogue<CLayout>>;
 }; // namespace
-std::size_t GetWorkspaceSize(const std::vector<grouped_gemm_kargs>& gemm_descs)
+std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
 {
     return ::Kernel<std::nullptr_t, std::nullptr_t, std::nullptr_t>::GetWorkSpaceSize(gemm_descs);
 }
...
@@ -41,7 +41,7 @@ auto create_args(int argc, char* argv[])
     .insert("stride_Bs", "", "Tensor B strides - it is empty by default.")
     .insert("stride_Cs", "", "Tensor C strides - it is empty by default.")
     .insert("a_layout", "R", "A tensor data layout - Row by default.")
-    .insert("b_layout", "R", "B tensor data layout - Row by default.")
+    .insert("b_layout", "C", "B tensor data layout - Row by default.")
     .insert("c_layout", "R", "C tensor data layout - Row by default.")
     .insert("validate", "1", "0. No validation, 1. Validation on CPU.")
     .insert("warmup", "10", "number of iterations before benchmark the kernel.")
@@ -52,8 +52,8 @@ auto create_args(int argc, char* argv[])
     return std::make_tuple(result, arg_parser);
 }
-std::size_t GetWorkspaceSize(const std::vector<grouped_gemm_kargs>& gemm_descs);
+std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs);
-float grouped_gemm_calc(const std::vector<grouped_gemm_kargs>& gemm_descs,
+float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
                         const ck_tile::stream_config& s,
                         void* p_workspace_);
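get_workspace_size and grouped_gemm follow a query-then-allocate convention: the caller first asks how much scratch memory the group descriptors need, allocates it once, then passes the raw pointer into the launch. A generic sketch of that calling pattern, using a toy DeviceMem stand-in rather than ck_tile::DeviceMem:

#include <cstddef>
#include <vector>

struct gemm_desc
{
    int M, N, K;
};

// Toy stand-in for a device allocation (the real example uses ck_tile::DeviceMem).
struct DeviceMem
{
    std::vector<unsigned char> storage;
    void Realloc(std::size_t bytes) { storage.resize(bytes); }
    void* GetDeviceBuffer() { return storage.data(); }
};

// Step 1: report how much scratch space the group descriptors need.
std::size_t get_workspace_size(const std::vector<gemm_desc>& descs)
{
    return descs.size() * sizeof(gemm_desc);
}

// Step 2: the launch receives the pre-allocated workspace pointer; staging the
// per-group arguments and the actual kernel launch are elided in this sketch.
float grouped_gemm(const std::vector<gemm_desc>& descs, void* p_workspace)
{
    (void)descs;
    (void)p_workspace;
    return 0.0f; // would return the measured average time
}

int main()
{
    std::vector<gemm_desc> descs{{128, 128, 64}, {256, 128, 32}};
    DeviceMem gemm_workspace;
    gemm_workspace.Realloc(get_workspace_size(descs));
    return grouped_gemm(descs, gemm_workspace.GetDeviceBuffer()) < 0.0f ? 1 : 0;
}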
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
+template <typename Layout>
+static constexpr inline auto is_row_major(Layout layout_)
+{
+    return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
+                                                 ck_tile::tensor_layout::gemm::RowMajor>>{};
+}
+
+auto calculate_rtol_atol(const ck_tile::index_t K,
+                         const ck_tile::index_t kbatch,
+                         const float max_accumulated_value)
+{
+    using ComputeType =
+        std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
+    // Calculate thresholds
+    const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
+        ck_tile::integer_divide_ceil(K, kbatch));
+    const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
+        max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
+    // Calculate error due to split_k accumulation
+    const auto rtol_split_k =
+        ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
+    const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
+        max_accumulated_value, kbatch);
+    // Use higher threshold
+    return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
+}
 template <typename ALayout, typename BLayout, typename CLayout>
 float invoke_gemm(int n_warmup,
                   int n_repeat,
@@ -11,7 +38,7 @@ float invoke_gemm(int n_warmup,
 {
     ck_tile::DeviceMem gemm_workspace;
-    gemm_workspace.Realloc(GetWorkspaceSize(args));
+    gemm_workspace.Realloc(get_workspace_size(args));
     float ave_time = grouped_gemm<ALayout, BLayout, CLayout>(
         args,
@@ -108,16 +135,16 @@ int run_grouped_gemm_example_with_layouts(int argc,
         const ck_tile::index_t N = Ns[i];
         const ck_tile::index_t K = Ks[i];
-        stride_As[i] = f_get_default_stride(M, N, stride_As[i], a_layout);
-        stride_Bs[i] = f_get_default_stride(K, N, stride_Bs[i], b_layout);
-        stride_Cs[i] = f_get_default_stride(M, N, stride_Cs[i], CLayout{});
+        stride_As[i] = ck_tile::get_default_stride(M, N, stride_As[i], is_row_major(a_layout));
+        stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(b_layout));
+        stride_Cs[i] = ck_tile::get_default_stride(M, N, stride_Cs[i], is_row_major(CLayout{}));
-        a_m_k_tensors.push_back(
-            ck_tile::HostTensor<ADataType>(f_host_tensor_descriptor(M, K, stride_As[i], a_layout)));
-        b_k_n_tensors.push_back(
-            ck_tile::HostTensor<BDataType>(f_host_tensor_descriptor(K, N, stride_Bs[i], b_layout)));
+        a_m_k_tensors.push_back(ck_tile::HostTensor<ADataType>(
+            ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(a_layout))));
+        b_k_n_tensors.push_back(ck_tile::HostTensor<BDataType>(
+            ck_tile::host_tensor_descriptor(K, N, stride_Bs[i], is_row_major(b_layout))));
         c_m_n_tensors.push_back(ck_tile::HostTensor<CDataType>(
-            f_host_tensor_descriptor(M, N, stride_Cs[i], CLayout{})));
+            ck_tile::host_tensor_descriptor(M, N, stride_Cs[i], is_row_major(CLayout{}))));
         std::cout << "gemm[" << i << "]"
                   << " a_m_k: " << a_m_k_tensors[i].mDesc << " b_k_n: " << b_k_n_tensors[i].mDesc
@@ -157,12 +184,23 @@ int run_grouped_gemm_example_with_layouts(int argc,
     {
         for(int i = 0; i < group_count; ++i)
         {
-            ck_tile::HostTensor<CDataType> c_m_n_host_ref(
-                f_host_tensor_descriptor(Ms[i], Ns[i], stride_Cs[i], CLayout{}));
+            ck_tile::HostTensor<CDataType> c_m_n_host_ref(ck_tile::host_tensor_descriptor(
+                Ms[i], Ns[i], stride_Cs[i], is_row_major(CLayout{})));
             c_m_n_host_ref.SetZero();
             ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
                 a_m_k_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
-            pass &= ck_tile::check_err(c_m_n_tensors[i], c_m_n_host_ref);
+            const float max_accumulated_value =
+                *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
+            const auto rtol_atol = calculate_rtol_atol(Ks[i], 1 /*kbatch*/, max_accumulated_value);
+            pass &= ck_tile::check_err(c_m_n_tensors[i],
+                                       c_m_n_host_ref,
+                                       "Error: Incorrect results!",
+                                       rtol_atol.at(ck_tile::number<0>{}),
+                                       rtol_atol.at(ck_tile::number<1>{}));
+            std::cout << "gemm[" << i
+                      << "] Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
+                      << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
+                      << std::endl;
         }
         std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl;
     }
@@ -188,10 +226,10 @@ int run_grouped_gemm_example(int argc, char* argv[])
     {
         return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
     }
-    else if(a_layout == "R" && b_layout == "R")
-    {
-        return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{});
-    }
+    // else if(a_layout == "R" && b_layout == "R")
+    // {
+    //     return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{});
+    // }
     else
     {
         throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
...
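The grouped-GEMM example now resolves strides and host tensor descriptors through is_row_major, which turns a layout tag into a compile-time boolean. The sketch below shows stride selection driven by such a tag; default_stride and the tags are simplified stand-ins for ck_tile::get_default_stride and the tensor_layout types.

#include <cstddef>
#include <type_traits>

struct RowMajor {};
struct ColMajor {};

// Turn a layout tag into a compile-time boolean, like is_row_major above.
template <typename Layout>
constexpr auto is_row_major(Layout)
{
    return std::bool_constant<std::is_same_v<Layout, RowMajor>>{};
}

// Row-major: the default leading dimension spans a row (column count);
// column-major: it spans a column (row count).
template <typename RowMajorTag>
std::size_t default_stride(std::size_t rows, std::size_t cols, std::size_t stride, RowMajorTag)
{
    if(stride != 0)
        return stride;
    return RowMajorTag::value ? cols : rows;
}

static_assert(decltype(is_row_major(RowMajor{}))::value, "RowMajor tag maps to true");

int main()
{
    const auto s = default_stride(64, 128, 0, is_row_major(RowMajor{}));
    return s == 128 ? 0 : 1;
}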
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
template <typename TLayout>
constexpr auto
f_host_tensor_descriptor(std::size_t row, std::size_t col, std::size_t stride, TLayout layout)
{
using namespace ck_tile::literals;
if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
{
return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride});
}
}
template <typename TLayout>
constexpr auto
f_get_default_stride(std::size_t row, std::size_t col, std::size_t stride, TLayout layout)
{
if(stride == 0)
{
if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
{
return col;
}
else
{
return row;
}
}
else
return stride;
}
@@ -17,7 +17,9 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
 #endif
 // to do: add various levels of logging with CK_LOG_LEVEL
+#ifndef CK_TIME_KERNEL
 #define CK_TIME_KERNEL 1
+#endif
 // constant address space for kernel parameter
 // https://llvm.org/docs/AMDGPUUsage.html#address-spaces
@@ -155,6 +157,9 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
 // LDS direct loads using inline assembly
 #define CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM 0
+// set rounding to nearest even as default for bf16 conversions
+#define CK_USE_RNE_BF16_CONVERSION 1
 // set rounding to nearest even as default for f8 conversions
 #define CK_USE_SR_F8_CONVERSION 0
...
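Wrapping CK_TIME_KERNEL in #ifndef lets a build override the default from the compiler command line instead of editing the header. A tiny sketch of the same guard pattern, using a hypothetical ENABLE_TIMING macro:

// Header-style default that a build can override with, e.g., -DENABLE_TIMING=0
// (ENABLE_TIMING is a hypothetical name; the header above guards CK_TIME_KERNEL).
#ifndef ENABLE_TIMING
#define ENABLE_TIMING 1
#endif

#include <cstdio>

int main()
{
#if ENABLE_TIMING
    std::printf("kernel timing enabled\n");
#else
    std::printf("kernel timing disabled\n");
#endif
    return 0;
}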
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
@@ -122,19 +122,6 @@ __global__ void
         static_for<0, NumDTensor, 1>{}(
             [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_group_offset[i]; });
-        if constexpr(is_same_v<AElementwiseOperation, element_wise::DynamicUnaryOp>)
-        {
-            a_element_op.InitUnaryOpPtrOnDevice();
-        }
-        if constexpr(is_same_v<BElementwiseOperation, element_wise::DynamicUnaryOp>)
-        {
-            b_element_op.InitUnaryOpPtrOnDevice();
-        }
-        if constexpr(is_same_v<CDEElementwiseOperation, element_wise::DynamicUnaryOp>)
-        {
-            cde_element_op.InitUnaryOpPtrOnDevice();
-        }
         if constexpr(isMultiA || isMultiB)
         {
             AsPointer p_as_grid_grp;
...
@@ -247,32 +247,6 @@ struct DequantPack8
     constexpr const static bool is_pack8_invocable = true;
 };
-#pragma clang diagnostic push
-#pragma clang diagnostic ignored "-Wnon-virtual-dtor"
-struct UnaryOpBase
-{
-    public:
-    __host__ __device__ ~UnaryOpBase() = default;
-    __host__ __device__ constexpr UnaryOpBase() = default;
-    __host__ __device__ constexpr UnaryOpBase(const UnaryOpBase&) = default;
-    __host__ __device__ constexpr UnaryOpBase(UnaryOpBase&&) = default;
-    __host__ __device__ UnaryOpBase& operator=(const UnaryOpBase&) = default;
-    __host__ __device__ UnaryOpBase& operator=(UnaryOpBase&&) = default;
-    __host__ __device__ virtual inline void operator()(float& y, const float& x) const = 0;
-    __host__ __device__ virtual inline void operator()(double& y, const double& x) const = 0;
-    __host__ __device__ virtual inline void operator()(int32_t& y, const int32_t& x) const = 0;
-    __host__ __device__ virtual inline void operator()(int8_t& y, const int8_t& x) const = 0;
-    __host__ __device__ virtual inline void operator()(half_t& y, const half_t& x) const = 0;
-    __host__ __device__ virtual inline void operator()(bhalf_t& y, const bhalf_t& x) const = 0;
-};
 struct PassThroughPack2
 {
     template <typename Y, typename X>
@@ -304,27 +278,8 @@ struct PassThroughPack2
     constexpr const static bool is_pack2_invocable = true;
 };
-struct PassThrough final : public UnaryOpBase
+struct PassThrough
 {
-    __host__ __device__ constexpr PassThrough() = default;
-    __host__ __device__ constexpr PassThrough(const PassThrough&) = default;
-    __host__ __device__ constexpr PassThrough(PassThrough&&) = default;
-    __host__ __device__ PassThrough& operator=(const PassThrough&) = default;
-    __host__ __device__ PassThrough& operator=(PassThrough&&) = default;
-    __host__ __device__ ~PassThrough() = default;
-    __host__ __device__ inline void operator()(float& y, const float& x) const final { y = x; }
-    __host__ __device__ inline void operator()(double& y, const double& x) const final { y = x; }
-    __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final { y = x; }
-    __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final { y = x; }
-    __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final { y = x; }
-    __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final { y = x; }
     template <typename Y, typename X>
     __host__ __device__ void operator()(Y& y, const X& x) const;
@@ -334,6 +289,12 @@ struct PassThrough final : public UnaryOpBase
         y = x;
     }
+    template <>
+    __host__ __device__ void operator()<double, double>(double& y, const double& x) const
+    {
+        y = x;
+    }
     template <>
     __host__ __device__ void operator()<float, double>(float& y, const double& x) const
     {
@@ -346,12 +307,36 @@ struct PassThrough final : public UnaryOpBase
         y = type_convert<double>(x);
     }
+    template <>
+    __host__ __device__ void operator()<float, float>(float& y, const float& x) const
+    {
+        y = x;
+    }
+    template <>
+    __host__ __device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
+    {
+        y = x;
+    }
     template <>
     __host__ __device__ void operator()<half_t, float>(half_t& y, const float& x) const
     {
         y = type_convert<half_t>(x);
     }
+    template <>
+    __host__ __device__ void operator()<bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x) const
+    {
+        y = x;
+    }
+    template <>
+    __host__ __device__ void operator()<int32_t, int32_t>(int32_t& y, const int32_t& x) const
+    {
+        y = x;
+    }
     template <>
     __host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
     {
@@ -376,6 +361,12 @@ struct PassThrough final : public UnaryOpBase
         y = type_convert<float>(x);
     }
+    template <>
+    __host__ __device__ void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
+    {
+        y = x;
+    }
     template <>
     __host__ __device__ void operator()<half_t, int8_t>(half_t& y, const int8_t& x) const
     {
@@ -675,45 +666,21 @@ struct UnarySquare
 };
 };
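This file replaces the virtual UnaryOpBase hierarchy with plain functors: one template operator() guarded by a static_assert over the supported types, plus ordinary overloads or specializations for special cases such as f8_t or bhalf_t. A standalone sketch of that functor style; Abs here is a simplified stand-in, not the ck implementation:

#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <type_traits>

struct Abs
{
    // One template call operator instead of one virtual override per type;
    // unsupported types are rejected at compile time.
    template <typename T>
    void operator()(T& y, const T& x) const
    {
        static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
                          std::is_same_v<T, std::int32_t>,
                      "Data type is not supported by this operation!");
        y = std::abs(x);
    }
};

int main()
{
    Abs op;
    float yf        = 0.0f;
    std::int32_t yi = 0;
    op(yf, -3.5f);
    op(yi, std::int32_t{-7});
    return (yf == 3.5f && yi == 7) ? 0 : 1;
}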
-struct UnaryAbs final : public UnaryOpBase
+struct UnaryAbs
 {
-    __host__ __device__ constexpr UnaryAbs() = default;
-    __host__ __device__ constexpr UnaryAbs(const UnaryAbs&) = default;
-    __host__ __device__ constexpr UnaryAbs(UnaryAbs&&) = default;
-    __host__ __device__ UnaryAbs& operator=(const UnaryAbs&) = default;
-    __host__ __device__ UnaryAbs& operator=(UnaryAbs&&) = default;
-    __host__ __device__ ~UnaryAbs() = default;
-    __host__ __device__ inline void operator()(float& y, const float& x) const final
-    {
-        y = ck::math::abs(x);
-    }
-    __host__ __device__ inline void operator()(double& y, const double& x) const final
-    {
-        y = ck::math::abs(x);
-    }
-    __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
-    {
-        y = ck::math::abs(x);
-    }
-    __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
-    {
-        y = ck::math::abs(x);
-    }
-    __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
-    {
+    template <typename T>
+    __host__ __device__ void operator()(T& y, const T& x) const
+    {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
+                          is_same<T, int8_t>::value,
+                      "Data type is not supported by this operation!");
         y = math::abs(x);
-    }
-    __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
-    {
-        y = ck::math::abs(x);
-    }
+    };
+    template <>
     __host__ __device__ void operator()(f8_t& y, const f8_t& x) const
     {
         y = ck::type_convert<f8_t>(ck::math::abs(ck::type_convert<float>(x)));
@@ -732,41 +699,20 @@ struct UnarySqrt
 };
 };
-struct Relu final : public UnaryOpBase
+struct Relu
 {
-    __host__ __device__ constexpr Relu() = default;
-    __host__ __device__ constexpr Relu(const Relu&) = default;
-    __host__ __device__ constexpr Relu(Relu&&) = default;
-    __host__ __device__ Relu& operator=(const Relu&) = default;
-    __host__ __device__ Relu& operator=(Relu&&) = default;
-    __host__ __device__ ~Relu() = default;
-    __host__ __device__ inline void operator()(float& y, const float& x) const final
-    {
-        y = x > 0 ? x : 0;
-    }
-    __host__ __device__ inline void operator()(double& y, const double& x) const final
-    {
-        y = x > 0 ? x : 0;
-    }
-    __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
-    {
-        y = x > 0 ? x : 0;
-    }
-    __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
-    {
-        y = x > 0 ? x : 0;
-    }
-    __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
+    template <typename T>
+    __host__ __device__ void operator()(T& y, const T& x) const
     {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
+                          is_same<T, int8_t>::value,
+                      "Data type is not supported by this operation!");
         y = x > 0 ? x : 0;
     }
-    __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
+    template <>
+    __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const
     {
         float x_f32 = type_convert<float>(x);
         float y_f32 = x_f32 > 0 ? x_f32 : 0;
@@ -913,52 +859,18 @@ struct Gelu
     }
 };
struct Sigmoid final : public UnaryOpBase struct Sigmoid
{ {
__host__ __device__ constexpr Sigmoid() = default; template <typename T>
__host__ __device__ constexpr Sigmoid(const Sigmoid&) = default; __host__ __device__ void operator()(T& y, const T& x) const
__host__ __device__ constexpr Sigmoid(Sigmoid&&) = default;
__host__ __device__ Sigmoid& operator=(const Sigmoid&) = default;
__host__ __device__ Sigmoid& operator=(Sigmoid&&) = default;
__host__ __device__ ~Sigmoid() = default;
__host__ __device__ inline void operator()(float& y, const float& x) const final
{
constexpr float one = type_convert<float>(1);
y = one / (one + math::exp(-x));
}
__host__ __device__ inline void operator()(double& y, const double& x) const final
{
constexpr double one = type_convert<double>(1);
y = one / (one + ck::math::exp(-x));
}
__host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
{
constexpr int32_t one = type_convert<int32_t>(1);
y = one / (one + ck::math::exp(-x));
}
__host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
{
constexpr int8_t one = type_convert<int8_t>(1);
y = one / (one + ck::math::exp(-x));
}
__host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
{
constexpr half_t one = type_convert<half_t>(1);
y = one / (one + math::exp(-x));
}
__host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
{ {
constexpr float one = type_convert<float>(1); static_assert(is_same<T, float>::value || is_same<T, double>::value ||
float x_f32 = ck::type_convert<float>(x); is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
float y_f32 = one / (one + ck::math::exp(x_f32)); is_same<T, int32_t>::value,
y = ck::type_convert<bhalf_t>(y_f32); "Data type is not supported by this operation!");
} constexpr T one = type_convert<T>(1);
y = one / (one + math::exp(-x));
};
}; };
struct Silu struct Silu
...@@ -974,44 +886,18 @@ struct Silu ...@@ -974,44 +886,18 @@ struct Silu
}; };
}; };
struct TanH final : public UnaryOpBase struct TanH
{ {
__host__ __device__ constexpr TanH() = default; template <typename T>
__host__ __device__ constexpr TanH(const TanH&) = default; __host__ __device__ void operator()(T& y, const T& x) const
__host__ __device__ constexpr TanH(TanH&&) = default;
__host__ __device__ TanH& operator=(const TanH&) = default;
__host__ __device__ TanH& operator=(TanH&&) = default;
__host__ __device__ ~TanH() = default;
__host__ __device__ inline void operator()(float& y, const float& x) const final
{
y = math::tanh(x);
}
__host__ __device__ inline void operator()(double& y, const double& x) const final
{
y = ck::math::tanh(x);
}
__host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
{
y = ck::math::tanh(x);
}
__host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
{
y = ck::math::tanh(x);
}
__host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
{ {
y = ck::math::tanh(x); static_assert(is_same<T, float>::value || is_same<T, double>::value ||
} is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value,
"Data type is not supported by this operation!");
__host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final y = math::tanh(x);
{ };
y = ck::math::tanh(x);
}
}; };
struct ACos struct ACos
...@@ -1252,418 +1138,138 @@ struct Rcp ...@@ -1252,418 +1138,138 @@ struct Rcp
}; };
}; };
struct Swish final : public UnaryOpBase struct Swish
{ {
__host__ __device__ constexpr Swish(const Swish&) = default; Swish(float beta = 1.0f) : beta_(beta) {}
__host__ __device__ constexpr Swish(Swish&&) = default;
__host__ __device__ ~Swish() = default;
__host__ __device__ Swish(float beta = 1.0f) : beta_(beta) {}
__host__ __device__ float get_beta() const { return beta_; }
const float beta_;
__host__ __device__ inline void operator()(float& y, const float& x) const final
{
float bx = -beta_ * type_convert<float>(x);
y = type_convert<float>(x / (1.f + ck::math::exp(bx)));
}
__host__ __device__ inline void operator()(double& y, const double& x) const final
{
float bx = -beta_ * type_convert<float>(x);
y = type_convert<double>(x / (1.f + ck::math::exp(bx)));
}
__host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
{
float bx = -beta_ * type_convert<float>(x);
y = type_convert<int32_t>(x / (1.f + ck::math::exp(bx)));
}
__host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
{
float bx = -beta_ * type_convert<float>(x);
y = type_convert<int8_t>(x / (1.f + ck::math::exp(bx)));
}
__host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
{
float bx = -beta_ * type_convert<float>(x);
y = type_convert<half_t>(x / (1.f + ck::math::exp(bx)));
}
__host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
{
float bx = -beta_ * type_convert<float>(x);
y = type_convert<bhalf_t>(x / (1.f + ck::math::exp(bx)));
}
template <typename Y, typename X> template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const __host__ __device__ void operator()(Y& y, const X& x) const
{ {
static_assert(is_same<X, float>::value || is_same<X, double>::value || static_assert(is_same<X, float>::value || is_same<X, double>::value ||
is_same<X, half_t>::value, is_same<X, ck::half_t>::value || is_same<X, int8_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
static_assert(is_same<Y, float>::value || is_same<Y, double>::value || static_assert(is_same<Y, float>::value || is_same<Y, double>::value ||
is_same<Y, half_t>::value, is_same<Y, ck::half_t>::value || is_same<Y, int8_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
float bx = -beta_ * type_convert<float>(x); float bx = -beta_ * type_convert<float>(x);
y = type_convert<Y>(x / (1.f + math::exp(bx))); y = type_convert<Y>(x / (1.f + math::exp(bx)));
} };
const float beta_;
}; };
struct SoftRelu final : public UnaryOpBase struct SoftRelu
{ {
__host__ __device__ constexpr SoftRelu(const SoftRelu&) = default; SoftRelu(float alpha = 1.f) : alpha_(alpha){};
__host__ __device__ constexpr SoftRelu(SoftRelu&&) = default;
__host__ __device__ ~SoftRelu() = default;
__host__ __device__ SoftRelu(float alpha = 1.0f) : alpha_(alpha) {}
__host__ __device__ float get_alpha() const { return alpha_; }
const float alpha_;
__host__ __device__ inline void operator()(float& y, const float& x) const final
{
float casted_alpha = type_convert<float>(alpha_);
constexpr float one = type_convert<float>(1);
y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha;
}
__host__ __device__ inline void operator()(double& y, const double& x) const final
{
double casted_alpha = type_convert<double>(alpha_);
constexpr double one = type_convert<double>(1);
y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha;
}
__host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
{
int32_t casted_alpha = type_convert<int32_t>(alpha_);
constexpr int32_t one = type_convert<int32_t>(1);
y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha;
}
__host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
{
int8_t casted_alpha = type_convert<int8_t>(alpha_);
constexpr int8_t one = type_convert<int8_t>(1);
y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha;
}
__host__ __device__ inline void operator()(half_t& y, const half_t& x) const final template <typename T>
{ __host__ __device__ void operator()(T& y, const T& x) const
half_t casted_alpha = type_convert<half_t>(alpha_);
constexpr half_t one = type_convert<half_t>(1);
y = math::log(one + math::exp(x * casted_alpha)) / casted_alpha;
}
__host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
{ {
bhalf_t casted_alpha = type_convert<bhalf_t>(alpha_); static_assert(is_same<T, float>::value || is_same<T, double>::value ||
constexpr bhalf_t one = type_convert<bhalf_t>(1); is_same<T, half_t>::value || is_same<T, int32_t>::value ||
y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha; is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
constexpr T one = type_convert<T>(1);
y = math::log(one + math::exp(x * casted_alpha)) / casted_alpha;
} }
const float alpha_;
}; };
struct Power final : public UnaryOpBase struct Power
{ {
__host__ __device__ constexpr Power(const Power&) = default; Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f)
__host__ __device__ constexpr Power(Power&&) = default; : alpha_(alpha), beta_(beta), gamma_(gamma){};
__host__ __device__ ~Power() = default;
__host__ __device__ Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f) template <typename T>
: alpha_(alpha), beta_(beta), gamma_(gamma) __host__ __device__ void operator()(T& y, const T& x) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
T casted_beta = type_convert<T>(beta_);
T casted_gamma = type_convert<T>(gamma_);
T shifted_scaled_x = casted_alpha + casted_beta * x;
y = math::pow(shifted_scaled_x, casted_gamma);
} }
__host__ __device__ float get_alpha() const { return alpha_; }
__host__ __device__ float get_beta() const { return beta_; }
__host__ __device__ float get_gamma() const { return gamma_; }
const float alpha_; const float alpha_;
const float beta_; const float beta_;
const float gamma_; const float gamma_;
__host__ __device__ inline void operator()(float& y, const float& x) const final
{
float casted_alpha = type_convert<float>(alpha_);
float casted_beta = type_convert<float>(beta_);
float casted_gamma = type_convert<float>(gamma_);
float shifted_scaled_x = casted_alpha + casted_beta * x;
y = ck::math::pow(shifted_scaled_x, casted_gamma);
}
__host__ __device__ inline void operator()(double& y, const double& x) const final
{
double casted_alpha = type_convert<double>(alpha_);
double casted_beta = type_convert<double>(beta_);
double casted_gamma = type_convert<double>(gamma_);
double shifted_scaled_x = casted_alpha + casted_beta * x;
y = ck::math::pow(shifted_scaled_x, casted_gamma);
}
__host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
{
int32_t casted_alpha = type_convert<int32_t>(alpha_);
int32_t casted_beta = type_convert<int32_t>(beta_);
int32_t casted_gamma = type_convert<int32_t>(gamma_);
int32_t shifted_scaled_x = casted_alpha + casted_beta * x;
y = ck::math::pow(shifted_scaled_x, casted_gamma);
}
__host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
{
int8_t casted_alpha = type_convert<int8_t>(alpha_);
int8_t casted_beta = type_convert<int8_t>(beta_);
int8_t casted_gamma = type_convert<int8_t>(gamma_);
int8_t shifted_scaled_x = casted_alpha + casted_beta * x;
y = ck::math::pow(shifted_scaled_x, casted_gamma);
}
__host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
{
half_t casted_alpha = type_convert<half_t>(alpha_);
half_t casted_beta = type_convert<half_t>(beta_);
half_t casted_gamma = type_convert<half_t>(gamma_);
half_t shifted_scaled_x = casted_alpha + casted_beta * x;
y = math::pow(shifted_scaled_x, casted_gamma);
}
__host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
{
bhalf_t casted_alpha = type_convert<bhalf_t>(alpha_);
bhalf_t casted_beta = type_convert<bhalf_t>(beta_);
bhalf_t casted_gamma = type_convert<bhalf_t>(gamma_);
bhalf_t shifted_scaled_x = casted_alpha + casted_beta * x;
y = ck::math::pow(shifted_scaled_x, casted_gamma);
}
}; };
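Power above evaluates y = (alpha + beta*x)^gamma; with the defaults it simply squares the input. A small standalone sketch (values are illustrative):

```cpp
#include <cmath>
#include <cstdio>

// Power reference: y = (alpha + beta * x) ^ gamma.
inline float power_ref(float x, float alpha = 0.0f, float beta = 1.0f, float gamma = 2.0f)
{
    return std::pow(alpha + beta * x, gamma);
}

int main()
{
    // With the defaults (alpha = 0, beta = 1, gamma = 2) the op squares its input.
    std::printf("power_ref(3)          = %f\n", power_ref(3.0f));                    // 9
    std::printf("power_ref(2, 1, 2, 3) = %f\n", power_ref(2.0f, 1.0f, 2.0f, 3.0f)); // (1 + 4)^3 = 125
    return 0;
}
```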
struct ClippedRelu final : public UnaryOpBase struct ClippedRelu
{ {
__host__ __device__ constexpr ClippedRelu(const ClippedRelu&) = default; ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta){};
__host__ __device__ constexpr ClippedRelu(ClippedRelu&&) = default;
__host__ __device__ ~ClippedRelu() = default;
__host__ __device__ ClippedRelu(float alpha = 0.f, float beta = 1.f) template <typename T>
: alpha_(alpha), beta_(beta) __host__ __device__ void operator()(T& y, const T& x) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
T casted_beta = type_convert<T>(beta_);
y = math::min(casted_beta, math::max(casted_alpha, x));
} }
__host__ __device__ float get_alpha() const { return alpha_; }
__host__ __device__ float get_beta() const { return beta_; }
const float alpha_; const float alpha_;
const float beta_; const float beta_;
__host__ __device__ inline void operator()(float& y, const float& x) const final
{
float casted_alpha = type_convert<float>(alpha_);
float casted_beta = type_convert<float>(beta_);
y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x));
}
__host__ __device__ inline void operator()(double& y, const double& x) const final
{
double casted_alpha = type_convert<double>(alpha_);
double casted_beta = type_convert<double>(beta_);
y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x));
}
__host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
{
int32_t casted_alpha = type_convert<int32_t>(alpha_);
int32_t casted_beta = type_convert<int32_t>(beta_);
y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x));
}
__host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
{
int8_t casted_alpha = type_convert<int8_t>(alpha_);
int8_t casted_beta = type_convert<int8_t>(beta_);
y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x));
}
__host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
{
half_t casted_alpha = type_convert<half_t>(alpha_);
half_t casted_beta = type_convert<half_t>(beta_);
y = math::min(casted_beta, math::max(casted_alpha, x));
}
__host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
{
bhalf_t casted_alpha = type_convert<bhalf_t>(alpha_);
bhalf_t casted_beta = type_convert<bhalf_t>(beta_);
y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x));
}
}; };
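ClippedRelu clamps x into [alpha, beta]. The same behaviour expressed with the standard library, as a standalone sketch (the alpha/beta defaults here are illustrative, not the struct's defaults):

```cpp
#include <algorithm>
#include <cassert>

// ClippedRelu reference: y = min(beta, max(alpha, x)), i.e. clamp x into [alpha, beta].
inline float clipped_relu_ref(float x, float alpha = 0.0f, float beta = 6.0f)
{
    return std::clamp(x, alpha, beta);
}

int main()
{
    assert(clipped_relu_ref(-1.0f) == 0.0f); // below the range -> alpha
    assert(clipped_relu_ref(3.0f) == 3.0f);  // inside the range -> unchanged
    assert(clipped_relu_ref(9.0f) == 6.0f);  // above the range -> beta
    return 0;
}
```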
struct LeakyRelu final : public UnaryOpBase struct LeakyRelu
{ {
__host__ __device__ constexpr LeakyRelu(const LeakyRelu&) = default; LeakyRelu(float alpha = 0.01f) : alpha_(alpha){};
__host__ __device__ constexpr LeakyRelu(LeakyRelu&&) = default;
__host__ __device__ ~LeakyRelu() = default;
__host__ __device__ LeakyRelu(float alpha = 0.f) : alpha_(alpha) {}
__host__ __device__ float get_alpha() const { return alpha_; }
const float alpha_;
__host__ __device__ inline void operator()(float& y, const float& x) const final
{
float casted_alpha = type_convert<float>(alpha_);
y = x >= 0 ? x : x * casted_alpha;
}
__host__ __device__ inline void operator()(double& y, const double& x) const final
{
double casted_alpha = type_convert<double>(alpha_);
y = x >= 0 ? x : x * casted_alpha;
}
__host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
{
int32_t casted_alpha = type_convert<int32_t>(alpha_);
y = x >= 0 ? x : x * casted_alpha;
}
__host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
{
int8_t casted_alpha = type_convert<int8_t>(alpha_);
y = x >= 0 ? x : x * casted_alpha;
}
__host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
{
half_t casted_alpha = type_convert<half_t>(alpha_);
y = x >= 0 ? x : x * casted_alpha;
}
__host__ __device__ inline void operator()([[maybe_unused]] bhalf_t& y, template <typename T>
[[maybe_unused]] const bhalf_t& x) const final __host__ __device__ void operator()(T& y, const T& x) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
y = x >= 0 ? x : x * casted_alpha;
} }
const float alpha_;
}; };
struct Elu final : public UnaryOpBase struct Elu
{ {
__host__ __device__ constexpr Elu(const Elu&) = default; Elu(float alpha = 1.f) : alpha_(alpha){};
__host__ __device__ constexpr Elu(Elu&&) = default;
__host__ __device__ ~Elu() = default;
__host__ __device__ Elu(float alpha = 1.f) : alpha_(alpha) {}
__host__ __device__ float get_alpha() const { return alpha_; }
const float alpha_;
__host__ __device__ inline void operator()(float& y, const float& x) const final
{
float casted_alpha = type_convert<float>(alpha_);
y = x > 0 ? x : casted_alpha * ck::math::expm1(x);
}
__host__ __device__ inline void operator()(double& y, const double& x) const final
{
double casted_alpha = type_convert<double>(alpha_);
y = x > 0 ? x : casted_alpha * ck::math::expm1(x);
}
__host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
{
int32_t casted_alpha = type_convert<int32_t>(alpha_);
y = x > 0 ? x : casted_alpha * ck::math::expm1(x);
}
__host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
{
int8_t casted_alpha = type_convert<int8_t>(alpha_);
y = x > 0 ? x : casted_alpha * ck::math::expm1(x);
}
__host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
{
half_t casted_alpha = type_convert<half_t>(alpha_);
y = x > 0 ? x : casted_alpha * math::expm1(x);
}
__host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{ {
bhalf_t casted_alpha = type_convert<bhalf_t>(alpha_); static_assert(is_same<T, float>::value || is_same<T, double>::value ||
y = x > 0 ? x : casted_alpha * ck::math::expm1(x); is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
y = x > 0 ? x : casted_alpha * math::expm1(x);
} }
const float alpha_;
}; };
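Elu uses expm1 rather than exp(x) - 1 on the negative branch, which matters for tiny |x| where the subtraction cancels. A standalone sketch of the difference (values are illustrative):

```cpp
#include <cmath>
#include <cstdio>

// ELU reference: y = x for x > 0, alpha * (exp(x) - 1) otherwise.
// expm1(x) keeps full precision for tiny |x|, where exp(x) - 1 cancels badly.
inline double elu_ref(double x, double alpha = 1.0)
{
    return x > 0.0 ? x : alpha * std::expm1(x);
}

int main()
{
    const double x = -1e-12;
    std::printf("expm1(x)   = %.17e\n", std::expm1(x));     // ~ -1.0e-12, accurate
    std::printf("exp(x) - 1 = %.17e\n", std::exp(x) - 1.0); // loses digits to cancellation
    std::printf("elu_ref(x) = %.17e\n", elu_ref(x));
    return 0;
}
```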
struct Logistic final : public UnaryOpBase struct Logistic
{ {
__host__ __device__ constexpr Logistic(const Logistic&) = default; Logistic(float alpha = 1.f) : alpha_(alpha){};
__host__ __device__ constexpr Logistic(Logistic&&) = default;
__host__ __device__ ~Logistic() = default;
__host__ __device__ Logistic(float alpha = 1.0f) : alpha_(alpha) {}
__host__ __device__ float get_alpha() const { return alpha_; }
const float alpha_;
__host__ __device__ inline void operator()(float& y, const float& x) const final
{
float casted_alpha = type_convert<float>(alpha_);
constexpr float one = type_convert<float>(1);
y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
}
__host__ __device__ inline void operator()(double& y, const double& x) const final
{
double casted_alpha = type_convert<double>(alpha_);
constexpr double one = type_convert<double>(1);
y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
}
__host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
{
int32_t casted_alpha = type_convert<int32_t>(alpha_);
constexpr int32_t one = type_convert<int32_t>(1);
y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
}
__host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final template <typename T>
{ __host__ __device__ void operator()(T& y, const T& x) const
int8_t casted_alpha = type_convert<int8_t>(alpha_);
constexpr int8_t one = type_convert<int8_t>(1);
y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
}
__host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
{
half_t casted_alpha = type_convert<half_t>(alpha_);
constexpr half_t one = type_convert<half_t>(1);
y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
}
__host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
{ {
bhalf_t casted_alpha = type_convert<bhalf_t>(alpha_); static_assert(is_same<T, float>::value || is_same<T, double>::value ||
constexpr bhalf_t one = type_convert<bhalf_t>(1); is_same<T, half_t>::value || is_same<T, int32_t>::value ||
y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha); is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
constexpr T one = type_convert<T>(1);
y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
} }
const float alpha_;
}; };
struct ConvInvscale struct ConvInvscale
...@@ -1728,7 +1334,7 @@ struct ConvScaleRelu
__host__ __device__ void operator()<f8_t, float>(f8_t& e, const float& c) const __host__ __device__ void operator()<f8_t, float>(f8_t& e, const float& c) const
{ {
float x; float x;
Relu{}(x, c * scale_in_ * scale_wei_); Relu{}.template operator()<float>(x, c * scale_in_ * scale_wei_);
e = type_convert<f8_t>(x * scale_out_); e = type_convert<f8_t>(x * scale_out_);
}; };
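Now that Relu's operator() is a member template, the call site above spells the template argument explicitly. A standalone sketch of both call forms with a simplified stand-in functor (ReluLike is illustrative, not the CK type):

```cpp
#include <algorithm>
#include <cstdio>

// Minimal stand-in for a functor whose call operator is a member template.
struct ReluLike
{
    template <typename T>
    void operator()(T& y, const T& x) const
    {
        y = std::max(x, T{0});
    }
};

int main()
{
    ReluLike relu;
    float y = 0.0f;

    relu(y, -2.5f);                           // template argument deduced as float
    relu.template operator()<float>(y, 3.0f); // explicit form, as at the call site above

    std::printf("y = %f\n", y); // 3.0
    return 0;
}
```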
...@@ -1809,225 +1415,138 @@ struct FastNumericArrayConverter<uint8_t, half_t, N>
struct DynamicUnaryOp struct DynamicUnaryOp
{ {
DynamicUnaryOp& operator=(const DynamicUnaryOp& other)
{
if(this != &other)
{
unary_op_ptr_ = other.unary_op_ptr_;
unary_op_type_ = other.unary_op_type_;
}
return *this;
}
__host__ __device__ DynamicUnaryOp() = delete; __host__ __device__ DynamicUnaryOp() = delete;
__host__ __device__ DynamicUnaryOp(const Swish& swish) __host__ __device__ DynamicUnaryOp(const Swish& swish)
: unary_op_type_(UnaryOpType::Swish), swish_{swish.beta_}
{ {
unary_op_type_ = UnaryOpType::Swish;
beta = swish.get_beta();
} }
__host__ __device__ DynamicUnaryOp(const Swish&& swish) __host__ __device__ DynamicUnaryOp(const Swish&& swish)
: unary_op_type_(UnaryOpType::Swish), swish_{swish.beta_}
{ {
unary_op_type_ = UnaryOpType::Swish;
beta = swish.get_beta();
} }
__host__ __device__ DynamicUnaryOp(const Sigmoid&) { unary_op_type_ = UnaryOpType::Sigmoid; } __host__ __device__ DynamicUnaryOp(const Sigmoid&) : unary_op_type_(UnaryOpType::Sigmoid) {}
__host__ __device__ DynamicUnaryOp(const Sigmoid&&) { unary_op_type_ = UnaryOpType::Sigmoid; } __host__ __device__ DynamicUnaryOp(const Sigmoid&&) : unary_op_type_(UnaryOpType::Sigmoid) {}
__host__ __device__ DynamicUnaryOp(const PassThrough&) __host__ __device__ DynamicUnaryOp(const PassThrough&)
: unary_op_type_(UnaryOpType::PassThrough)
{ {
unary_op_type_ = UnaryOpType::PassThrough;
} }
__host__ __device__ DynamicUnaryOp(const PassThrough&&) __host__ __device__ DynamicUnaryOp(const PassThrough&&)
: unary_op_type_(UnaryOpType::PassThrough)
{ {
unary_op_type_ = UnaryOpType::PassThrough;
} }
__host__ __device__ DynamicUnaryOp(const Logistic& logistic) __host__ __device__ DynamicUnaryOp(const Logistic& logistic)
: unary_op_type_(UnaryOpType::Logistic), logistic_{logistic.alpha_}
{ {
unary_op_type_ = UnaryOpType::Logistic;
alpha = logistic.get_alpha();
} }
__host__ __device__ DynamicUnaryOp(const Logistic&& logistic) __host__ __device__ DynamicUnaryOp(const Logistic&& logistic)
: unary_op_type_(UnaryOpType::Logistic), logistic_{logistic.alpha_}
{ {
unary_op_type_ = UnaryOpType::Logistic;
alpha = logistic.get_alpha();
} }
__host__ __device__ DynamicUnaryOp(const TanH&) { unary_op_type_ = UnaryOpType::TanH; } __host__ __device__ DynamicUnaryOp(const TanH&) : unary_op_type_(UnaryOpType::TanH) {}
__host__ __device__ DynamicUnaryOp(const TanH&&) { unary_op_type_ = UnaryOpType::TanH; } __host__ __device__ DynamicUnaryOp(const TanH&&) : unary_op_type_(UnaryOpType::TanH) {}
__host__ __device__ DynamicUnaryOp(const Relu&) { unary_op_type_ = UnaryOpType::Relu; } __host__ __device__ DynamicUnaryOp(const Relu&) : unary_op_type_(UnaryOpType::Relu) {}
__host__ __device__ DynamicUnaryOp(const Relu&&) { unary_op_type_ = UnaryOpType::Relu; } __host__ __device__ DynamicUnaryOp(const Relu&&) : unary_op_type_(UnaryOpType::Relu) {}
__host__ __device__ DynamicUnaryOp(const SoftRelu& softrelu) __host__ __device__ DynamicUnaryOp(const SoftRelu& softrelu)
: unary_op_type_(UnaryOpType::SoftRelu), soft_relu_{softrelu.alpha_}
{ {
unary_op_type_ = UnaryOpType::SoftRelu;
alpha = softrelu.get_alpha();
} }
__host__ __device__ DynamicUnaryOp(const SoftRelu&& softrelu) __host__ __device__ DynamicUnaryOp(const SoftRelu&& softrelu)
: unary_op_type_(UnaryOpType::SoftRelu), soft_relu_{softrelu.alpha_}
{ {
unary_op_type_ = UnaryOpType::SoftRelu;
alpha = softrelu.get_alpha();
} }
__host__ __device__ DynamicUnaryOp(const UnaryAbs&) { unary_op_type_ = UnaryOpType::UnaryAbs; } __host__ __device__ DynamicUnaryOp(const UnaryAbs&) : unary_op_type_(UnaryOpType::UnaryAbs) {}
__host__ __device__ DynamicUnaryOp(const UnaryAbs&&) { unary_op_type_ = UnaryOpType::UnaryAbs; } __host__ __device__ DynamicUnaryOp(const UnaryAbs&&) : unary_op_type_(UnaryOpType::UnaryAbs) {}
__host__ __device__ DynamicUnaryOp(const Power& pow) __host__ __device__ DynamicUnaryOp(const Power& pow)
: unary_op_type_(UnaryOpType::Power), power_(pow.alpha_, pow.beta_, pow.gamma_)
{ {
unary_op_type_ = UnaryOpType::Power;
alpha = pow.get_alpha();
beta = pow.get_beta();
gamma = pow.get_gamma();
} }
__host__ __device__ DynamicUnaryOp(const Power&& pow) __host__ __device__ DynamicUnaryOp(const Power&& pow)
: unary_op_type_(UnaryOpType::Power), power_(pow.alpha_, pow.beta_, pow.gamma_)
{ {
unary_op_type_ = UnaryOpType::Power;
alpha = pow.get_alpha();
beta = pow.get_beta();
gamma = pow.get_gamma();
} }
__host__ __device__ DynamicUnaryOp(const ClippedRelu& clippedrelu) __host__ __device__ DynamicUnaryOp(const ClippedRelu& clippedrelu)
: unary_op_type_(UnaryOpType::ClippedRelu),
clipped_relu_{clippedrelu.alpha_, clippedrelu.beta_}
{ {
unary_op_type_ = UnaryOpType::ClippedRelu;
alpha = clippedrelu.get_alpha();
beta = clippedrelu.get_beta();
} }
__host__ __device__ DynamicUnaryOp(const ClippedRelu&& clippedrelu) __host__ __device__ DynamicUnaryOp(const ClippedRelu&& clippedrelu)
: unary_op_type_(UnaryOpType::ClippedRelu),
clipped_relu_{clippedrelu.alpha_, clippedrelu.beta_}
{ {
unary_op_type_ = UnaryOpType::ClippedRelu;
alpha = clippedrelu.get_alpha();
beta = clippedrelu.get_beta();
} }
__host__ __device__ DynamicUnaryOp(const LeakyRelu& leakyrelu) __host__ __device__ DynamicUnaryOp(const LeakyRelu& leakyrelu)
: unary_op_type_(UnaryOpType::LeakyRelu), leaky_relu_{leakyrelu.alpha_}
{ {
unary_op_type_ = UnaryOpType::LeakyRelu;
alpha = leakyrelu.get_alpha();
} }
__host__ __device__ DynamicUnaryOp(const LeakyRelu&& leakyrelu) __host__ __device__ DynamicUnaryOp(const LeakyRelu&& leakyrelu)
: unary_op_type_(UnaryOpType::LeakyRelu), leaky_relu_{leakyrelu.alpha_}
{ {
unary_op_type_ = UnaryOpType::LeakyRelu;
alpha = leakyrelu.get_alpha();
} }
__host__ __device__ DynamicUnaryOp(const Elu& elu) __host__ __device__ DynamicUnaryOp(const Elu& elu)
: unary_op_type_(UnaryOpType::Elu), elu_{elu.alpha_}
{ {
unary_op_type_ = UnaryOpType::Elu;
alpha = elu.get_alpha();
} }
__host__ __device__ DynamicUnaryOp(const Elu&& elu) __host__ __device__ DynamicUnaryOp(const Elu&& elu)
: unary_op_type_(UnaryOpType::Elu), elu_{elu.alpha_}
{ {
unary_op_type_ = UnaryOpType::Elu;
alpha = elu.get_alpha();
} }
__host__ __device__ DynamicUnaryOp(const DynamicUnaryOp& dynamic_op) __host__ __device__ DynamicUnaryOp(const DynamicUnaryOp& dynamic_op) = default;
: unary_op_type_(dynamic_op.unary_op_type_),
unary_op_ptr_(dynamic_op.unary_op_ptr_),
alpha(dynamic_op.alpha),
beta(dynamic_op.beta),
gamma(dynamic_op.gamma)
{
}
__host__ __device__ ~DynamicUnaryOp() __host__ __device__ ~DynamicUnaryOp() {}
{
switch(unary_op_type_)
{
case(UnaryOpType::Swish): delete static_cast<Swish*>(unary_op_ptr_); break;
case(UnaryOpType::Sigmoid): delete static_cast<Sigmoid*>(unary_op_ptr_); break;
case(UnaryOpType::PassThrough): delete static_cast<PassThrough*>(unary_op_ptr_); break;
case(UnaryOpType::Logistic): delete static_cast<Logistic*>(unary_op_ptr_); break;
case(UnaryOpType::TanH): delete static_cast<TanH*>(unary_op_ptr_); break;
case(UnaryOpType::Relu): delete static_cast<Relu*>(unary_op_ptr_); break;
case(UnaryOpType::SoftRelu): delete static_cast<SoftRelu*>(unary_op_ptr_); break;
case(UnaryOpType::UnaryAbs): delete static_cast<UnaryAbs*>(unary_op_ptr_); break;
case(UnaryOpType::Power): delete static_cast<Power*>(unary_op_ptr_); break;
case(UnaryOpType::ClippedRelu): delete static_cast<ClippedRelu*>(unary_op_ptr_); break;
case(UnaryOpType::LeakyRelu): delete static_cast<LeakyRelu*>(unary_op_ptr_); break;
case(UnaryOpType::Elu): delete static_cast<Elu*>(unary_op_ptr_); break;
default: break;
}
}
__device__ void InitUnaryOpPtrOnDevice()
{
switch(unary_op_type_)
{
case(UnaryOpType::Swish): unary_op_ptr_ = new Swish(beta); break;
case(UnaryOpType::Sigmoid): unary_op_ptr_ = new Sigmoid; break;
case(UnaryOpType::PassThrough): unary_op_ptr_ = new PassThrough; break;
case(UnaryOpType::Logistic): unary_op_ptr_ = new Logistic(alpha); break;
case(UnaryOpType::TanH): unary_op_ptr_ = new TanH; break;
case(UnaryOpType::Relu): unary_op_ptr_ = new Relu; break;
case(UnaryOpType::SoftRelu): unary_op_ptr_ = new SoftRelu(alpha); break;
case(UnaryOpType::UnaryAbs): unary_op_ptr_ = new UnaryAbs; break;
case(UnaryOpType::Power): unary_op_ptr_ = new Power(alpha, beta, gamma); break;
case(UnaryOpType::ClippedRelu): unary_op_ptr_ = new ClippedRelu(alpha, beta); break;
case(UnaryOpType::LeakyRelu): unary_op_ptr_ = new LeakyRelu(alpha); break;
case(UnaryOpType::Elu): unary_op_ptr_ = new Elu(alpha); break;
default: unary_op_ptr_ = nullptr; break;
}
}
template <typename Y, typename X>
__device__ void operator()(Y& y, const X& x) const
{
isSupported<X, Y>();
unary_op_ptr_->operator()(y, x);
}
template <typename Y, typename X> template <typename Y, typename X>
__host__ void operator()(Y& y, const X& x) const __host__ __device__ void operator()(Y& y, const X& x) const
{ {
isSupported<X, Y>();
switch(unary_op_type_) switch(unary_op_type_)
{ {
case(UnaryOpType::Swish): Swish{}.operator()(y, x); break; case(UnaryOpType::Swish): swish_(y, x); break;
case(UnaryOpType::Sigmoid): Sigmoid{}.operator()(y, x); break; case(UnaryOpType::Sigmoid): sigmoid_(y, x); break;
case(UnaryOpType::PassThrough): PassThrough{}.operator()(y, x); break; case(UnaryOpType::PassThrough): pass_through_(y, x); break;
case(UnaryOpType::Logistic): Logistic{}.operator()(y, x); break; case(UnaryOpType::Logistic): logistic_(y, x); break;
case(UnaryOpType::TanH): TanH{}.operator()(y, x); break; case(UnaryOpType::TanH): tanh_(y, x); break;
case(UnaryOpType::Relu): Relu{}.operator()(y, x); break; case(UnaryOpType::Relu): relu_(y, x); break;
case(UnaryOpType::SoftRelu): SoftRelu{}.operator()(y, x); break; case(UnaryOpType::SoftRelu): soft_relu_(y, x); break;
case(UnaryOpType::UnaryAbs): UnaryAbs{}.operator()(y, x); break; case(UnaryOpType::UnaryAbs): unary_abs_(y, x); break;
case(UnaryOpType::Power): Power{}.operator()(y, x); break; case(UnaryOpType::Power): power_(y, x); break;
case(UnaryOpType::ClippedRelu): ClippedRelu{}.operator()(y, x); break; case(UnaryOpType::ClippedRelu): clipped_relu_(y, x); break;
case(UnaryOpType::LeakyRelu): LeakyRelu{}.operator()(y, x); break; case(UnaryOpType::LeakyRelu): leaky_relu_(y, x); break;
case(UnaryOpType::Elu): Elu{}.operator()(y, x); break; case(UnaryOpType::Elu): elu_(y, x); break;
default: break; default: break;
} }
} }
template <typename X, typename Y> template <>
__device__ __host__ constexpr void isSupported() const __host__ __device__ void operator()<bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x) const
{ {
float y_float;
static_assert(std::is_same<X, Y>::value, "X and Y must be of the same type"); float x_float = type_convert<float>(x);
this->operator()(y_float, x_float);
static_assert(is_same<X, float>::value || is_same<X, double>::value || y = type_convert<bhalf_t>(y_float);
is_same<X, bhalf_t>::value || is_same<X, half_t>::value ||
is_same<X, int32_t>::value || is_same<X, int8_t>::value,
"Data type is not supported by this operation!");
} }
private: private:
...@@ -2049,12 +1568,20 @@ struct DynamicUnaryOp
public: public:
UnaryOpType unary_op_type_; UnaryOpType unary_op_type_;
UnaryOpBase* unary_op_ptr_ = nullptr;
float alpha; Swish swish_;
float beta; Sigmoid sigmoid_;
float gamma; PassThrough pass_through_;
Logistic logistic_;
TanH tanh_;
Relu relu_;
SoftRelu soft_relu_;
UnaryAbs unary_abs_;
Power power_;
ClippedRelu clipped_relu_;
LeakyRelu leaky_relu_;
Elu elu_;
}; };
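After this refactor DynamicUnaryOp stores one value member per concrete functor plus a type tag and dispatches through a switch, instead of constructing the functor behind a device-side base-class pointer. A rough standalone sketch of that tagged-dispatch pattern (types and names are simplified stand-ins, not the CK API):

```cpp
#include <cstdio>

// Simplified stand-ins for two concrete element-wise functors.
struct ReluOp
{
    void operator()(float& y, float x) const { y = x > 0.0f ? x : 0.0f; }
};
struct LeakyReluOp
{
    float alpha;
    void operator()(float& y, float x) const { y = x >= 0.0f ? x : x * alpha; }
};

// Tag plus value members: copyable by value, no heap allocation, switch-based dispatch.
struct DynamicOp
{
    enum class Kind { Relu, LeakyRelu } kind;
    ReluOp relu{};
    LeakyReluOp leaky_relu{0.01f};

    void operator()(float& y, float x) const
    {
        switch(kind)
        {
        case Kind::Relu: relu(y, x); break;
        case Kind::LeakyRelu: leaky_relu(y, x); break;
        }
    }
};

int main()
{
    const DynamicOp op{DynamicOp::Kind::LeakyRelu};
    float y = 0.0f;
    op(y, -4.0f);
    std::printf("leaky_relu(-4) = %f\n", y); // -0.04
    return 0;
}
```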
#pragma clang diagnostic pop
} // namespace element_wise } // namespace element_wise
} // namespace tensor_operation } // namespace tensor_operation
......
...@@ -31,8 +31,6 @@ struct pk_i4_t
type data; type data;
__host__ __device__ constexpr pk_i4_t() : data{type{}} {} __host__ __device__ constexpr pk_i4_t() : data{type{}} {}
__host__ __device__ constexpr pk_i4_t(type init) : data{init} {} __host__ __device__ constexpr pk_i4_t(type init) : data{init} {}
__host__ __device__ constexpr operator float() const { return static_cast<int8_t>(data); }
}; };
inline constexpr auto next_pow2(uint32_t x) inline constexpr auto next_pow2(uint32_t x)
......
...@@ -29,6 +29,13 @@ struct DynamicBuffer
ElementSpaceSize element_space_size_; ElementSpaceSize element_space_size_;
T invalid_element_value_ = T{0}; T invalid_element_value_ = T{0};
static constexpr index_t PackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<T>, pk_i4_t>)
return 2;
else
return 1;
}();
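PackedSize captures that pk_i4_t packs two 4-bit values into one byte, which is why the buffer extents below are divided by PackedSize before being handed to the buffer intrinsics. A rough standalone illustration of that bookkeeping (pk_int4 and packed_size are stand-ins, not the CK types):

```cpp
#include <cstdio>
#include <type_traits>

// Stand-in for a packed-int4 element type: two 4-bit values share one byte.
struct pk_int4
{
    unsigned char data;
};

template <typename T>
constexpr int packed_size = std::is_same<T, pk_int4>::value ? 2 : 1;

template <typename T>
void report(int element_count)
{
    // Extent handed to a byte-addressed buffer descriptor.
    std::printf("elements: %d -> storage units: %d\n",
                element_count,
                element_count / packed_size<T>);
}

int main()
{
    report<float>(1024);   // 1024 storage units
    report<pk_int4>(1024); // 512 storage units: two int4 values per byte
    return 0;
}
```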
__host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size) __host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size)
: p_data_{p_data}, element_space_size_{element_space_size} : p_data_{p_data}, element_space_size_{element_space_size}
{ {
...@@ -82,14 +89,18 @@ struct DynamicBuffer
return amd_buffer_load_invalid_element_return_zero<remove_cvref_t<T>, return amd_buffer_load_invalid_element_return_zero<remove_cvref_t<T>,
t_per_x, t_per_x,
coherence>( coherence>(
p_data_, i, is_valid_element, element_space_size_); p_data_, i, is_valid_element, element_space_size_ / PackedSize);
} }
else else
{ {
return amd_buffer_load_invalid_element_return_customized_value<remove_cvref_t<T>, return amd_buffer_load_invalid_element_return_customized_value<remove_cvref_t<T>,
t_per_x, t_per_x,
coherence>( coherence>(
p_data_, i, is_valid_element, element_space_size_, invalid_element_value_); p_data_,
i,
is_valid_element,
element_space_size_ / PackedSize,
invalid_element_value_);
} }
} }
else else
...@@ -191,7 +202,7 @@ struct DynamicBuffer
dst_buf.p_data_, dst_buf.p_data_,
dst_offset, dst_offset,
is_valid_element, is_valid_element,
element_space_size_); element_space_size_ / PackedSize);
} }
template <typename X, template <typename X,
...@@ -226,7 +237,7 @@ struct DynamicBuffer
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_store<remove_cvref_t<T>, t_per_x, coherence>( amd_buffer_store<remove_cvref_t<T>, t_per_x, coherence>(
x, p_data_, i, is_valid_element, element_space_size_); x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
} }
else if constexpr(GetAddressSpace() == AddressSpaceEnum::Lds && else if constexpr(GetAddressSpace() == AddressSpaceEnum::Lds &&
is_same<typename scalar_type<remove_cvref_t<T>>::type, int8_t>::value && is_same<typename scalar_type<remove_cvref_t<T>>::type, int8_t>::value &&
...@@ -378,7 +389,7 @@ struct DynamicBuffer
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_atomic_add<remove_cvref_t<T>, t_per_x>( amd_buffer_atomic_add<remove_cvref_t<T>, t_per_x>(
x, p_data_, i, is_valid_element, element_space_size_); x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
} }
else else
{ {
...@@ -417,7 +428,7 @@ struct DynamicBuffer
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_atomic_max<remove_cvref_t<T>, t_per_x>( amd_buffer_atomic_max<remove_cvref_t<T>, t_per_x>(
x, p_data_, i, is_valid_element, element_space_size_); x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
} }
else if(is_valid_element) else if(is_valid_element)
{ {
......
...@@ -14,6 +14,41 @@ namespace ck {
#define __gfx94__ #define __gfx94__
#endif #endif
// Declare a template function for bf16 conversion using RTN
template <typename Y, typename X>
__host__ __device__ constexpr Y bf16_convert_rtn(X x);
// Convert fp32 to bf16 with RTN if higher precision is needed
template <>
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, float>(float x)
{
// NaN check
if(x != x)
{
return uint16_t(0x7FC0);
}
union
{
float fp32;
uint32_t int32;
} u = {x};
const uint32_t first_bf16_mantissa_bit = ((u.int32 >> 16) & 1);
constexpr uint32_t rounding_bias = uint32_t((1 << 15) - 1);
return uint16_t((u.int32 + first_bf16_mantissa_bit + rounding_bias) >> 16);
}
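The rounding above adds the current low bf16 mantissa bit plus 0x7FFF before truncating, i.e. round-to-nearest-even. A standalone host-side sketch showing how RNE and plain truncation differ on a halfway value (not the CK converter itself):

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

// Round-to-nearest-even fp32 -> bf16: add (low bf16 mantissa bit) + 0x7FFF, then truncate.
static uint16_t f32_to_bf16_rne(float x)
{
    uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    const uint32_t lsb = (bits >> 16) & 1u;
    return static_cast<uint16_t>((bits + lsb + 0x7FFFu) >> 16);
}

// Truncating fp32 -> bf16: simply drop the low 16 bits.
static uint16_t f32_to_bf16_truncate(float x)
{
    uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    return static_cast<uint16_t>(bits >> 16);
}

int main()
{
    // 1 + 3 * 2^-8 lies exactly halfway between the bf16 values 1.0078125 (odd
    // mantissa) and 1.015625 (even mantissa): RNE rounds up to the even one,
    // truncation always rounds down.
    const float x = 1.01171875f;
    std::printf("RNE:      0x%04x\n", f32_to_bf16_rne(x));      // 0x3f82
    std::printf("truncate: 0x%04x\n", f32_to_bf16_truncate(x)); // 0x3f81
    return 0;
}
```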
// convert fp16 to bfp16 via fp32 with RTN if higher precision is needed
template <>
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(half_t x)
{
float x_fp32 = static_cast<float>(x);
return bf16_convert_rtn<bhalf_t>(x_fp32);
}
// Convert X to Y, both X and Y are non-const data types. // Convert X to Y, both X and Y are non-const data types.
template <typename Y, template <typename Y,
typename X, typename X,
...@@ -51,17 +86,15 @@ inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t
return u.fp32; return u.fp32;
} }
// convert fp32 to bfp16 // convert fp32 to bfp16, round to nearest even
template <> template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float x) inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float x)
{ {
union #if CK_USE_RNE_BF16_CONVERSION
{ return bf16_convert_rtn<bhalf_t>(x);
float fp32; #else
uint32_t int32;
} u = {x};
return uint16_t(u.int32 >> 16); return uint16_t(u.int32 >> 16);
#endif
} }
// convert bfp16 to fp16 via fp32 // convert bfp16 to fp16 via fp32
...@@ -635,60 +668,4 @@ inline __host__ __device__ void array_convert(Array<Y, NumElems>& y, const Array
} }
} }
// Declare a template function for bf16 conversion using RTN
template <typename Y, typename X>
__host__ __device__ constexpr Y bf16_convert_rtn(X x);
// Convert fp32 to bf16 with RTN if higher precision is needed
template <>
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, float>(float x)
{
union
{
float fp32;
uint32_t int32;
} u = {x};
// When the exponent bits are not all 1s, then the value is zero, normal,
// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
// 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
// This causes the bfloat16's mantissa to be incremented by 1 if the 16
// least significant bits of the float mantissa are greater than 0x8000,
// or if they are equal to 0x8000 and the least significant bit of the
// bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
// the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
// has the value 0x7f, then incrementing it causes it to become 0x00 and
// the exponent is incremented by one, which is the next higher FP value
// to the unrounded bfloat16 value. When the bfloat16 value is subnormal
// with an exponent of 0x00 and a mantissa of 0x7f, it may be rounded up
// to a normal value with an exponent of 0x01 and a mantissa of 0x00.
// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
// incrementing it causes it to become an exponent of 0xFF and a mantissa
// of 0x00, which is Inf, the next higher value to the unrounded value.
bool flag0 = ~u.int32 & 0x7f800000;
// When all of the exponent bits are 1, the value is Inf or NaN.
// Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
// mantissa bit. Quiet NaN is indicated by the most significant mantissa
// bit being 1. Signaling NaN is indicated by the most significant
// mantissa bit being 0 but some other bit(s) being 1. If any of the
// lower 16 bits of the mantissa are 1, we set the least significant bit
// of the bfloat16 mantissa, in order to preserve signaling NaN in case
// the bfloat16's mantissa bits are all 0.
bool flag1 = !flag0 && (u.int32 & 0xffff);
u.int32 += flag0 ? 0x7fff + ((u.int32 >> 16) & 1) : 0; // Round to nearest, round to even
u.int32 |= flag1 ? 0x10000 : 0x0; // Preserve signaling NaN
return uint16_t(u.int32 >> 16);
}
// convert fp16 to bfp16 via fp32 with RTN if higher precision is needed
template <>
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(half_t x)
{
float x_fp32 = static_cast<float>(x);
return bf16_convert_rtn<bhalf_t>(x_fp32);
}
} // namespace ck } // namespace ck