Commit 6f768035 authored by Umang Yadav's avatar Umang Yadav
Browse files

Merge branch 'rocblas_mlir_fp8' into miopen_fp8

parents da7717ce b2542239
......@@ -49,6 +49,12 @@ std::string get_device_name()
return props.gcnArchName;
}
bool gfx_has_fp8_intrinsics()
{
const auto device_name = trim(split_string(get_device_name(), ':').front());
return (starts_with(device_name, "gfx9") and device_name >= "gfx940");
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
......@@ -21,49 +21,64 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "migraphx/serialize.hpp"
#include <iterator>
#include <utility>
#include <migraphx/eliminate_fp8.hpp>
#include <migraphx/gpu/driver/action.hpp>
#include <migraphx/gpu/time_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/gpu/compile_ops.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/ranges.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace driver {
void eliminate_fp8::apply(module& m) const
struct precompile_op : action<precompile_op>
{
for(auto ins : iterator_for(m))
static program create_preop_program(const operation& preop, std::vector<shape> inputs)
{
if(not contains(op_names, ins->name()) or
ins->get_shape().type() != migraphx::shape::fp8e4m3fnuz_type)
continue;
migraphx::shape::type_t orig_type = ins->get_shape().type();
std::vector<instruction_ref> orig_inputs = ins->inputs();
std::vector<instruction_ref> new_inputs;
std::transform(orig_inputs.begin(),
orig_inputs.end(),
std::back_inserter(new_inputs),
[&](const auto& i) {
return m.insert_instruction(
ins,
migraphx::make_op(
"convert", {{"target_type", migraphx::to_value(target_type)}}),
i);
});
program p;
auto* mm = p.get_main_module();
std::vector<instruction_ref> args;
inputs.pop_back();
transform(inputs, range(inputs.size()), std::back_inserter(args), [&](auto input, auto i) {
return mm->add_parameter("x" + std::to_string(i), input);
});
mm->add_instruction(preop, args);
return p;
}
auto new_ins = m.insert_instruction(ins, ins->get_operator(), {new_inputs});
auto convert_back_ins = m.insert_instruction(
ins,
migraphx::make_op("convert", {{"target_type", migraphx::to_value(orig_type)}}),
new_ins);
m.replace_instruction(ins, convert_back_ins);
static operation get_code_object(const program& p)
{
MIGRAPHX_TIDY_CONST auto* mm = p.get_main_module();
auto it = std::find_if(mm->begin(), mm->end(), [](const auto& ins) {
return (ins.name() == "gpu::code_object");
});
if(it == mm->end())
MIGRAPHX_THROW("Failed to create code object");
return it->get_operator();
}
static void apply(const parser& p, const value& v)
{
context ctx;
auto inputs = p.parse_shapes(v.at("inputs"));
auto name = v.at("name").to<std::string>();
auto preop = make_op(name);
if(v.contains("fields"))
preop.from_value(v.at("fields"));
bool exhaustive = v.get("exhaustive", false);
auto prog = create_preop_program(preop, inputs);
run_passes(prog, {lowering{}, compile_ops{&ctx, exhaustive}});
auto op = get_code_object(prog);
auto t = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
std::cout << preop << ": " << t << "ms" << std::endl;
}
}
};
} // namespace driver
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -38,6 +38,18 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_EXTRA_MLIR);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_MLIR);
/**
* @brief Declares a new MIGraphX environment variable which forces to generate
* only specific MLIR operations.
*
* The variable, if defined, forces MIGraphX to use only specific operations
* with MLIR regardless of the underlying GPU architecture. The variable accepts
* a list of operations separated by comma. The variable recognizes the following
* operations: "fused", "convolution", "dot". If the variable is not defined MIGraphX
* will decide by itself which operations to delegate to MLIR. The variable is
* intended to be primarily used by rocMLIR developers.
*/
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_USE_SPECIFIC_OPS);
bool mlir_enabled()
{
......@@ -49,6 +61,26 @@ bool mlir_enabled()
#endif
}
static bool is_requested(std::string_view option, bool fallback = false)
{
auto string_value = string_value_of(MIGRAPHX_MLIR_USE_SPECIFIC_OPS{}, "");
if(string_value.empty())
return fallback;
const auto options = split_string(string_value, ',');
return contains(options, option);
}
bool mlir_attention_enabled()
{
#ifdef MIGRAPHX_MLIR
if(not mlir_enabled())
return false;
return is_requested("attention");
#else
return false;
#endif
}
#ifdef MIGRAPHX_MLIR
struct mlir_op
......@@ -62,41 +94,27 @@ struct mlir_op
return pack(f(self.op, "op"));
}
shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
shape compute_shape(const std::vector<shape>& inputs, const std::vector<module_ref>& mods) const
{
module_ref mod = mods[0];
check_shapes{inputs, *this}.packed_or_broadcasted();
if(mods.size() != 1)
MIGRAPHX_THROW("should have one submodule.");
if(inputs.size() < 2)
MIGRAPHX_THROW("should have at least two inputs.");
module_ref mod = mods[0];
auto type = mod->get_output_shapes().front().type();
auto type = mod->get_output_shapes().front().type();
std::unordered_map<instruction_ref, shape> ins_shapes;
size_t param_cnt = 0;
std::vector<std::string> names = mod->get_parameter_names();
std::sort(names.begin(), names.end());
for(const std::string& param_name : names)
{
ins_shapes[mod->get_parameter(param_name)] = inputs[param_cnt++];
}
for(auto ins : iterator_for(*mod))
{
if(ins->name() == "@param")
{
continue;
}
if(ins->name() == "@literal")
if(ins->name() == "@literal" or ins->name() == "@param")
{
ins_shapes[ins] = ins->get_shape();
continue;
}
if(ins->name() == "@return")
{
auto s = ins_shapes[ins->inputs().at(0)].with_type(type);
if(not s.standard())
MIGRAPHX_THROW("MLIR doesnt support non-standard output");
return s;
return ins_shapes[ins->inputs().at(0)].with_type(type);
}
std::vector<shape> input_shapes;
input_shapes.resize(ins->inputs().size());
......@@ -112,38 +130,55 @@ struct mlir_op
MIGRAPHX_REGISTER_OP(mlir_op);
namespace {
std::tuple<instruction_ref, std::vector<operation>>
get_fusable_input_op_stream(instruction_ref lower_input)
{
instruction_ref upper_input = lower_input;
std::vector<operation> op_stream;
while(contains({"slice",
"transpose",
"multibroadcast",
"broadcast",
"contiguous",
"reshape",
"squeeze",
"flatten",
"unsqueeze"},
upper_input->name()))
{
operation op = upper_input->get_operator();
if(contains({"squeeze", "flatten", "unsqueeze"}, upper_input->name()))
{
op = migraphx::make_op("reshape", {{"dims", upper_input->get_shape().lens()}});
}
op_stream.push_back(op);
upper_input = upper_input->inputs().at(0);
}
return {upper_input, op_stream};
}
std::tuple<instruction_ref, std::vector<instruction_ref>>
fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op)
fuse_input_ops_and_gemm_based_op(module_ref mm,
const std::vector<instruction_ref>& gemm_based_op_inputs,
const operation& gemm_based_op)
{
std::vector<instruction_ref> top_inputs;
std::vector<instruction_ref> imm_inputs;
size_t input_cnt = 0;
for(instruction_ref input : gemm_based_op->inputs())
for(instruction_ref input : gemm_based_op_inputs)
{
std::vector<operation> op_stream;
while(contains(
{"slice", "transpose", "contiguous", "reshape", "squeeze", "flatten", "unsqueeze"},
input->name()))
{
operation op = input->get_operator();
if(contains({"squeeze", "flatten", "unsqueeze"}, input->name()))
{
op = migraphx::make_op("reshape", {{"dims", input->get_shape().lens()}});
}
op_stream.push_back(op);
input = input->inputs().at(0);
}
top_inputs.push_back(input);
auto [upper_input, op_stream] = get_fusable_input_op_stream(input);
top_inputs.push_back(upper_input);
instruction_ref prev_input =
mm->add_parameter("y" + std::to_string(input_cnt++), input->get_shape());
mm->add_parameter("y" + std::to_string(input_cnt++), upper_input->get_shape());
for(const auto& op : reverse(op_stream))
{
prev_input = mm->add_instruction(op, {prev_input});
}
imm_inputs.push_back(prev_input);
}
instruction_ref new_gemm_based_op =
mm->add_instruction(gemm_based_op->get_operator(), imm_inputs);
instruction_ref new_gemm_based_op = mm->add_instruction(gemm_based_op, imm_inputs);
return {new_gemm_based_op, top_inputs};
}
......@@ -183,6 +218,7 @@ auto is_mlir_conv(mlir_mode mode)
return false;
if(ins->name() != "convolution" and ins->name() != "quant_convolution")
return false;
auto input_arg_t = ins->inputs().front()->get_shape().type();
value v = ins->get_operator().to_value();
auto group = v.at("group").to<int>();
if(group != 1)
......@@ -190,6 +226,10 @@ auto is_mlir_conv(mlir_mode mode)
// Avoid MLIR assertion: Index < Length && "Invalid index!"
if(ins->get_shape().lens().size() != 4)
return false;
if(ins->get_shape().type() == shape::fp8e4m3fnuz_type)
return true;
if(ins->get_shape().type() == shape::float_type and input_arg_t == shape::fp8e4m3fnuz_type)
return true;
if(ins->get_shape().type() == shape::int8_type)
return true;
if(mode == mlir_mode::int8)
......@@ -205,103 +245,140 @@ auto is_mlir_conv(mlir_mode mode)
});
}
struct find_mlir_fused_ops
std::unordered_map<instruction_ref, instruction_ref>
create_param_map_with_literals(module_ref mm, const module* pm, const shape& shape)
{
mlir_mode conv_mode = mlir_mode::none;
mlir_mode dot_mode = mlir_mode::none;
auto matcher() const
std::unordered_map<instruction_ref, instruction_ref> ins_map;
for(auto ins : iterator_for(*pm))
{
auto dot_or_conv = match::skip(match::name("contiguous"))(
match::any_of(is_mlir_dot(dot_mode), is_mlir_conv(conv_mode)).bind("gemm_based_op"));
return match::name("pointwise")(match::any_of[match::inputs()](dot_or_conv.bind("x")));
}
std::unordered_map<instruction_ref, instruction_ref>
create_param_map_with_literals(module_ref mm, const module* pm, const shape& shape) const
{
std::unordered_map<instruction_ref, instruction_ref> ins_map;
for(auto ins : iterator_for(*pm))
if(ins->name() != "@literal")
{
if(ins->name() != "@literal")
{
continue;
}
literal r = ins->get_literal();
instruction_ref literal = mm->add_literal(r);
instruction_ref mbcast = mm->add_instruction(
make_op("multibroadcast", {{"out_lens", shape.lens()}}), literal);
ins_map[ins] = mbcast;
continue;
}
return ins_map;
literal r = ins->get_literal();
instruction_ref literal = mm->add_literal(r);
instruction_ref mbcast =
mm->add_instruction(make_op("multibroadcast", {{"out_lens", shape.lens()}}), literal);
ins_map[ins] = mbcast;
}
return ins_map;
}
// Whitelist supported fusion options, including imposing type constraints
// for cases where MLIR only supports an operation (usually a pointwise function)
// on particular types.
bool is_pointwise_op_supported_by_mlir(const instruction& i) const
std::vector<instruction_ref>
fold_pointwise_mod(instruction_ref pm_ins,
module_ref parent_mod,
const std::unordered_map<instruction_ref, instruction_ref>& ins_map)
{
auto* pm = pm_ins->module_inputs().front();
auto names = pm->get_parameter_names();
std::sort(names.begin(), names.end());
std::unordered_map<instruction_ref, instruction_ref> param_map =
create_param_map_with_literals(parent_mod, pm, pm_ins->get_shape());
std::transform(names.begin(),
names.end(),
pm_ins->inputs().begin(),
std::inserter(param_map, param_map.end()),
[&](auto name, auto input) {
if(ins_map.count(input))
return std::make_pair(pm->get_parameter(name), ins_map.at(input));
return std::make_pair(pm->get_parameter(name),
parent_mod->add_parameter(name, input->get_shape()));
});
return parent_mod->insert_instructions(parent_mod->end(), pm, param_map);
}
// Whitelist supported fusion options, including imposing type constraints
// for cases where MLIR only supports an operation (usually a pointwise function)
// on particular types.
bool is_pointwise_op_supported_by_mlir(const instruction& i)
{
using type_t = shape::type_t;
const auto& name = i.name();
const auto result_type = i.get_shape().type();
const std::initializer_list<type_t> allowed_types = {type_t::float_type,
type_t::half_type,
type_t::fp8e4m3fnuz_type,
type_t::int8_type,
type_t::int32_type,
type_t::bool_type};
// Preliminary type check.
if(not contains(allowed_types, result_type))
{
return false;
}
const std::initializer_list<std::string> any_type_ops = {"@literal", "@param", "@return"};
const std::initializer_list<std::string> no_bool_ops = {
"convolution",
"quant_convolution",
"dot",
"quant_dot",
"add",
"clip",
"relu",
"sub",
"mul",
"div",
"pow",
"where",
"quantizelinear",
"dequantizelinear",
"abs",
"neg",
};
const std::initializer_list<std::string> fp_only_ops = {
"ceil",
"erf",
"exp",
"floor",
"log",
"recip",
"rsqrt",
"sigmoid",
"softmax",
"tanh",
};
bool is_float =
contains({type_t::float_type, type_t::half_type, type_t::fp8e4m3fnuz_type}, result_type);
if(contains(any_type_ops, name))
return true;
if(result_type != type_t::bool_type and contains(no_bool_ops, name))
return true;
if(is_float and contains(fp_only_ops, name))
return true;
// Only conversions between floating types are known to be unambigiously
// supported.
if(is_float and name == "convert")
{
using type_t = shape::type_t;
const auto& name = i.name();
const auto result_type = i.get_shape().type();
const std::initializer_list<type_t> allowed_types = {type_t::float_type,
type_t::half_type,
type_t::int8_type,
type_t::fp8e4m3fnuz_type,
type_t::int32_type,
type_t::bool_type};
// Preliminary type check.
if(not contains(allowed_types, result_type))
if(result_type == shape::fp8e4m3fnuz_type)
{
return false;
}
const std::initializer_list<std::string> any_type_ops = {"@literal", "@param", "@return"};
const std::initializer_list<std::string> no_bool_ops = {
"convolution",
"quant_convolution",
"dot",
"quant_dot",
"add",
"clip",
"relu",
"sub",
"mul",
"div",
"pow",
"where",
"quantizelinear",
"dequantizelinear",
"abs",
"neg",
};
const std::initializer_list<std::string> fp_only_ops = {
"ceil",
"erf",
"exp",
"floor",
"log",
"recip",
"rsqrt",
"sigmoid",
"softmax",
"tanh",
};
bool is_float = contains({type_t::float_type, type_t::half_type, type_t::fp8e4m3fnuz_type},
result_type);
if(contains(any_type_ops, name))
return true;
if(result_type != type_t::bool_type and contains(no_bool_ops, name))
return true;
if(is_float and contains(fp_only_ops, name))
return true;
// Only conversions between floating types are known to be unambigiously
// supported.
if(is_float and name == "convert")
{
return std::all_of(i.inputs().begin(), i.inputs().end(), [](const auto& arg) {
return contains({type_t::float_type, type_t::half_type}, arg->get_shape().type());
});
}
} // else
return std::all_of(i.inputs().begin(), i.inputs().end(), [](const auto& arg) {
return contains({type_t::float_type, type_t::half_type}, arg->get_shape().type());
});
}
return false;
}
MIGRAPHX_PRED_MATCHER(mlir_pointwise, instruction_ref ins)
{
if(ins->name() != "pointwise")
return false;
auto* pm = ins->module_inputs().front();
return std::all_of(pm->begin(), pm->end(), [&](const auto& i) {
return is_pointwise_op_supported_by_mlir(i);
});
}
struct find_mlir_fused_ops
{
mlir_mode conv_mode = mlir_mode::none;
mlir_mode dot_mode = mlir_mode::none;
auto matcher() const
{
auto dot_or_conv = match::skip(match::name("contiguous"))(
match::any_of(is_mlir_dot(dot_mode), is_mlir_conv(conv_mode)).bind("gemm_based_op"));
return mlir_pointwise()(match::any_of[match::inputs()](dot_or_conv.bind("x")));
}
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
......@@ -311,29 +388,12 @@ struct find_mlir_fused_ops
auto x_ins = r.instructions["x"]; // input after contiguous
auto* pm = ins->module_inputs().front();
auto names = pm->get_parameter_names();
// Whitelist pointwise operators.
if(std::any_of(pm->begin(), pm->end(), [&](const auto& i) {
return not is_pointwise_op_supported_by_mlir(i);
}))
return;
std::sort(names.begin(), names.end());
module_ref mm = mpm.create_module("mlir_" + pm->name());
mm->set_bypass();
std::unordered_map<instruction_ref, instruction_ref> param_map =
create_param_map_with_literals(mm, pm, gemm_based_op->get_shape());
auto [anchor_op, top_inputs] = fuse_input_ops_and_gemm_based_op(mm, gemm_based_op);
std::transform(names.begin(),
names.end(),
ins->inputs().begin(),
std::inserter(param_map, param_map.end()),
[&, &anchor = anchor_op](auto name, auto input) {
if(input == x_ins)
return std::make_pair(pm->get_parameter(name), anchor);
return std::make_pair(pm->get_parameter(name),
mm->add_parameter(name, input->get_shape()));
});
mm->add_return(mm->insert_instructions(mm->end(), pm, param_map));
auto [anchor_op, top_inputs] = fuse_input_ops_and_gemm_based_op(
mm, gemm_based_op->inputs(), gemm_based_op->get_operator());
mm->add_return(fold_pointwise_mod(ins, mm, {{x_ins, anchor_op}}));
std::vector<instruction_ref> inputs;
std::copy_if(ins->inputs().begin(),
......@@ -351,54 +411,104 @@ struct find_mlir_standalone_op
{
mlir_mode mode = mlir_mode::none;
auto matcher() const { return Matcher(mode); }
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto conv_based_op = r.result;
// enable only for fp32/fp16/i8 types
if(std::any_of(conv_based_op->inputs().begin(), conv_based_op->inputs().end(), [&](auto i) {
auto gemm_based_op = r.result;
// enable only for fp32/fp16/i8/fp8 types
if(std::any_of(gemm_based_op->inputs().begin(), gemm_based_op->inputs().end(), [&](auto i) {
return not contains({shape::type_t::float_type,
shape::type_t::half_type,
shape::type_t::fp8e4m3fnuz_type,
shape::type_t::int8_type},
shape::type_t::int8_type,
shape::type_t::fp8e4m3fnuz_type},
i->get_shape().type());
}))
return;
static size_t counter = 0;
module_ref mm =
mpm.create_module("mlir_" + conv_based_op->name() + std::to_string(counter++));
mpm.create_module("mlir_" + gemm_based_op->name() + std::to_string(counter++));
mm->set_bypass();
auto [anchor_op, top_inputs] = fuse_input_ops_and_gemm_based_op(mm, conv_based_op);
auto [anchor_op, top_inputs] = fuse_input_ops_and_gemm_based_op(
mm, gemm_based_op->inputs(), gemm_based_op->get_operator());
mm->add_return({anchor_op});
mpm.get_module().replace_instruction(
conv_based_op, mlir_op{conv_based_op->get_operator()}, top_inputs, {mm});
gemm_based_op, mlir_op{gemm_based_op->get_operator()}, top_inputs, {mm});
}
};
using find_mlir_standalone_convolution_op = find_mlir_standalone_op<&is_mlir_conv>;
using find_mlir_standalone_dot_op = find_mlir_standalone_op<&is_mlir_dot>;
/**
* @brief Declares a new MIGraphX environment variable which forces to generate
* only specific MLIR operations.
*
* The variable, if defined, forces MIGraphX to use only specific operations
* with MLIR regardless of the underlying GPU architecture. The variable accepts
* a list of operations separated by comma. The variable recognizes the following
* operations: "fused", "convolution", "dot". If the variable is not defined MIGraphX
* will decide by itself which operations to delegate to MLIR. The variable is
* intended to be primarily used by rocMLIR developers.
*/
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_USE_SPECIFIC_OPS);
struct find_mlir_standalone_attention_op
{
auto matcher() const
{
return match::name("gpu::pre_gemm_softmax_gemm").bind("gemm_softmax_gemm");
}
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
static size_t counter = 0;
module_ref mm = mpm.create_module("mlir_" + std::to_string(counter++));
auto gemm_softmax_gemm = r.instructions["gemm_softmax_gemm"];
std::vector<instruction_ref> inputs;
mm->set_bypass();
bool is_requested(std::string_view option, bool fallback = false)
std::unordered_map<instruction_ref, instruction_ref> ins_map;
auto gemm0_inputs = gemm_softmax_gemm->inputs();
gemm0_inputs.pop_back();
auto [gemm0, top_gemm0_inputs] =
fuse_input_ops_and_gemm_based_op(mm, gemm0_inputs, make_op("dot"));
inputs.insert(inputs.begin(), top_gemm0_inputs.begin(), top_gemm0_inputs.end());
// handle scale
auto v = gemm_softmax_gemm->get_operator().to_value();
assert(v.contains("scale"));
auto scale = v.at("scale").to<float>();
auto scale_lit = mm->add_literal(literal{shape{gemm0->get_shape().type()}, {scale}});
instruction_ref scale_lit_mbcast = mm->add_instruction(
make_op("multibroadcast", {{"out_lens", gemm0->get_shape().lens()}}), scale_lit);
auto scaled_gemm0 = mm->add_instruction(make_op("mul"), gemm0, scale_lit_mbcast);
auto softmax = mm->add_instruction(
make_op("softmax", {{"axis", gemm0->get_shape().lens().size() - 1}}), scaled_gemm0);
auto [old_upper_v, upper_v_op_stream] =
get_fusable_input_op_stream(gemm_softmax_gemm->inputs()[2]);
instruction_ref new_upper_v = mm->add_parameter("z", old_upper_v->get_shape());
for(const auto& op : reverse(upper_v_op_stream))
{
new_upper_v = mm->add_instruction(op, {new_upper_v});
}
inputs.push_back(old_upper_v);
auto gemm1 = mm->add_instruction(make_op("dot"), {softmax, new_upper_v});
ins_map[gemm_softmax_gemm] = gemm1;
auto ins_to_replace = gemm1;
auto ins_to_be_replaced = gemm_softmax_gemm;
if(r.instructions.find("trailing_pm") != r.instructions.end())
{
ins_to_replace = fold_pointwise_mod(r.instructions["trailing_pm"], mm, ins_map)[0];
std::copy_if(r.instructions["trailing_pm"]->inputs().begin(),
r.instructions["trailing_pm"]->inputs().end(),
std::back_inserter(inputs),
[&](auto input) { return input != gemm_softmax_gemm; });
ins_to_be_replaced = r.instructions["trailing_pm"];
}
mm->add_return({ins_to_replace});
mpm.get_module().replace_instruction(
ins_to_be_replaced, mlir_op{gemm1->get_operator()}, inputs, {mm});
}
};
struct find_mlir_attention_fused_ops : public find_mlir_standalone_attention_op
{
auto string_value = string_value_of(MIGRAPHX_MLIR_USE_SPECIFIC_OPS{}, "");
if(string_value.empty())
return fallback;
const auto options = split_string(string_value, ',');
return contains(options, option);
}
auto matcher() const
{
auto standalone_matcher = find_mlir_standalone_attention_op::matcher();
return mlir_pointwise()(
match::any_of[match::inputs()](standalone_matcher).bind("trailing_pm"));
;
}
};
} // namespace
#endif // MIGRAPHX_MLIR
......@@ -420,13 +530,20 @@ void fuse_mlir::apply(module_pass_manager& mpm) const
mlir_mode mode =
(enabled(MIGRAPHX_ENABLE_EXTRA_MLIR{}) or enable_extra) ? mlir_mode::fast : mlir_mode::none;
// Attention offloads; default disabled
if(mlir_attention_enabled())
{
match::find_matches(mpm, find_mlir_attention_fused_ops{});
match::find_matches(mpm, find_mlir_standalone_attention_op{});
}
match::find_matches(mpm,
find_mlir_fused_ops{.conv_mode = get_mode("fused", mlir_mode::fast),
.dot_mode = get_mode("fused", mode)});
match::find_matches(
mpm,
find_mlir_standalone_convolution_op{get_mode("convolution", mlir_mode::int8)},
find_mlir_standalone_convolution_op{get_mode("convolution", mlir_mode::all)},
find_mlir_standalone_dot_op{get_mode("dot", mlir_mode::none)});
#else
(void)mpm;
......
......@@ -22,6 +22,7 @@
* THE SOFTWARE.
*/
#include <rocblas/internal/rocblas-types.h>
#include <rocblas/rocblas.h>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/gemm_impl.hpp>
......@@ -36,6 +37,20 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
/*
Regular rocBLAS API takes compute_type as `rocblas_datatype` enum value v/s "ex3" BETA API takes it
as `rocblas_computetype` enum value. `rb_compute_type` is faciliator to implictly cast integer enum
value to required type that can be used inside `common_args` generator.
*/
struct rb_compute_type
{
int type = 0;
rb_compute_type(rocblas_datatype t) : type(static_cast<int>(t)) {}
rb_compute_type(rocblas_computetype t) : type(static_cast<int>(t)) {}
operator rocblas_datatype() const { return static_cast<rocblas_datatype>(type); }
operator rocblas_computetype() const { return static_cast<rocblas_computetype>(type); }
};
// Convert rocBLAS datatypes to equivalent Migraphx data types
rocblas_datatype get_type(shape::type_t type)
{
......@@ -185,12 +200,17 @@ struct gemm_impl
{
output_type = rocblas_datatype_i32_r;
}
compute_type = output_type;
compute_type = rb_compute_type{output_type};
if(compute_fp32)
{
if(arg_type == rocblas_datatype_f16_r)
compute_type = rocblas_datatype_f32_r;
}
if(arg_type == rocblas_datatype_f8_r)
{
assert(get_type(input_shapes[1].type()) == rocblas_datatype_f8_r);
compute_type = rocblas_compute_type_f32;
}
auto a_lens = input_shapes[0].lens();
auto b_lens = input_shapes[1].lens();
......@@ -230,7 +250,6 @@ struct gemm_impl
auto common_args = create_strided_batched_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_strided_batched_ex3,
common_args,
rocblas_compute_type_f32,
rocblas_gemm_algo_standard,
solution_idx,
gemm_flags);
......@@ -240,7 +259,6 @@ struct gemm_impl
auto common_args = create_gemm_ex_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_ex3,
common_args,
rocblas_compute_type_f32,
rocblas_gemm_algo_standard,
solution_idx,
gemm_flags);
......@@ -254,7 +272,6 @@ struct gemm_impl
auto common_args = create_strided_batched_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_strided_batched_ex,
common_args,
compute_type,
rocblas_gemm_algo_solution_index,
solution_idx,
gemm_flags);
......@@ -264,7 +281,6 @@ struct gemm_impl
auto common_args = create_gemm_ex_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_ex,
common_args,
compute_type,
rocblas_gemm_algo_solution_index,
solution_idx,
gemm_flags);
......@@ -304,7 +320,6 @@ struct gemm_impl
auto common_args = create_strided_batched_args_common(ctx, input_args);
check_valid = rocblas_invoke(&rocblas_gemm_strided_batched_ex,
common_args,
compute_type,
rocblas_gemm_algo_solution_index,
solution_idx,
rocblas_gemm_flags_check_solution_index);
......@@ -314,7 +329,6 @@ struct gemm_impl
auto common_args = create_gemm_ex_args_common(ctx, input_args);
check_valid = rocblas_invoke(&rocblas_gemm_ex,
common_args,
compute_type,
rocblas_gemm_algo_solution_index,
solution_idx,
rocblas_gemm_flags_check_solution_index);
......@@ -365,7 +379,8 @@ struct gemm_impl
output_type,
ldd,
d_stride,
num_matrices);
num_matrices,
compute_type);
}
/**
* Helper method to create that subset of a long rocBLAS argument list that is common
......@@ -398,7 +413,8 @@ struct gemm_impl
ldc,
is_3inputs ? args[3].data() : args[2].data(),
output_type,
ldd);
ldd,
compute_type);
}
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
......@@ -428,7 +444,6 @@ struct gemm_impl
auto common_args = create_strided_batched_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_strided_batched_ex_get_solutions,
common_args,
compute_type,
rocblas_gemm_algo_solution_index,
gemm_flags,
nullptr,
......@@ -438,7 +453,6 @@ struct gemm_impl
auto common_sol_args = create_strided_batched_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_strided_batched_ex_get_solutions,
common_sol_args,
compute_type,
rocblas_gemm_algo_solution_index,
gemm_flags,
solution_indices.data(),
......@@ -449,7 +463,6 @@ struct gemm_impl
auto common_args = create_gemm_ex_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_ex_get_solutions,
common_args,
compute_type,
rocblas_gemm_algo_solution_index,
gemm_flags,
nullptr,
......@@ -459,7 +472,6 @@ struct gemm_impl
auto common_sol_args = create_gemm_ex_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_ex_get_solutions,
common_sol_args,
compute_type,
rocblas_gemm_algo_solution_index,
gemm_flags,
solution_indices.data(),
......@@ -521,7 +533,7 @@ struct gemm_impl
rocblas_int c_stride = 0;
rocblas_int d_stride = 0;
rocblas_datatype arg_type = rocblas_datatype_f32_r;
rocblas_datatype compute_type = rocblas_datatype_f32_r;
rb_compute_type compute_type = rocblas_datatype_f32_r;
rocblas_datatype output_type = rocblas_datatype_f32_r;
bool strided_batched = true;
bool is_3inputs = true;
......
......@@ -58,10 +58,10 @@ struct hiprtc_src_file
MIGRAPHX_GPU_EXPORT bool hip_has_flags(const std::vector<std::string>& flags);
MIGRAPHX_GPU_EXPORT std::vector<std::vector<char>> compile_hip_src_with_hiprtc(
std::vector<hiprtc_src_file> srcs, std::string params, const std::string& arch);
std::vector<hiprtc_src_file> srcs, const std::string& params, const std::string& arch);
MIGRAPHX_GPU_EXPORT std::vector<std::vector<char>>
compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch);
MIGRAPHX_GPU_EXPORT std::vector<std::vector<char>> compile_hip_src(
const std::vector<src_file>& srcs, const std::string& params, const std::string& arch);
MIGRAPHX_GPU_EXPORT std::string enum_params(std::size_t count, std::string param);
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_PAD_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_PAD_HPP
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
argument MIGRAPHX_DEVICE_EXPORT pad(hipStream_t stream,
argument result,
argument arg1,
float value,
std::vector<std::int64_t> pads);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -37,6 +37,8 @@ MIGRAPHX_GPU_EXPORT std::string get_device_name();
MIGRAPHX_GPU_EXPORT int get_device_id();
MIGRAPHX_GPU_EXPORT bool gfx_has_fp8_intrinsics();
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -34,10 +34,11 @@ struct module_pass_manager;
namespace gpu {
MIGRAPHX_GPU_EXPORT bool mlir_enabled();
MIGRAPHX_GPU_EXPORT bool mlir_attention_enabled();
struct MIGRAPHX_GPU_EXPORT fuse_mlir
{
context* ctx = nullptr;
context* ctx = nullptr;
bool enable_extra = false;
std::string name() const { return "gpu::fuse_mlir"; }
void apply(module_pass_manager& mpm) const;
......
......@@ -66,6 +66,10 @@ struct gemm_softmax_gemm
}
static bool is_ck_supported_type(shape::type_t t) { return contains({shape::half_type}, t); }
static bool is_mlir_supported_type(shape::type_t t)
{
return contains({shape::type_t::float_type, shape::half_type}, t);
}
};
} // namespace gpu
......
......@@ -217,6 +217,12 @@ inline pooling_descriptor make_pooling(const migraphx::op::pooling& op)
ss << op.mode;
MIGRAPHX_THROW(ss.str());
}
if(not std::all_of(
op.dilations.cbegin(), op.dilations.cend(), [](std::size_t d) { return d == 1; }))
{
MIGRAPHX_THROW("Unsupported dilations for pooling: [" + to_string_range(op.dilations) +
"]");
}
auto p = make_obj<pooling_descriptor>(&miopenCreatePoolingDescriptor);
int kdims = op.kdims();
......
......@@ -146,7 +146,6 @@ struct simple_reduce_compiler : compiler<simple_reduce_compiler>
vectorize vec{};
auto nelements = options.virtual_inputs.back().elements();
auto algo = v.get("algo", get_reduce_algo(options.virtual_inputs));
if(algo == "block")
{
// Vectorize if the axis is a reduction axis
......@@ -170,13 +169,13 @@ struct simple_reduce_compiler : compiler<simple_reduce_compiler>
options.kernel_name = "reduce_kernel";
std::string identity = "[](auto x) { return x; }";
auto src = interpolate_string(simple_reduce_kernel,
{{"reduction", v.at("reduction").to<std::string>()},
{"init", v.get("init", std::string{"0"})},
{"read", v.get("read", identity)},
{"write", v.get("write", identity)},
{"algo", algo},
{"transformers", make_transformer_args(vec)},
{"preamble", v.get("preamble", std::string{})}});
{{"reduction", v.at("reduction").to<std::string>()},
{"init", v.get("init", std::string{"0"})},
{"read", v.get("read", identity)},
{"write", v.get("write", identity)},
{"algo", algo},
{"transformers", make_transformer_args(vec)},
{"preamble", v.get("preamble", std::string{})}});
options.params += "-Wno-float-equal";
return compile_hip_code_object(src, options);
}
......@@ -267,13 +266,13 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler>
auto src = interpolate_string(
fused_reduce_kernel,
{{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")},
{"algo", algo},
{"reduced", "decltype(" + generate_make_shape(reduce_output_shape) + ")"},
{"lambda", v.at("lambda").to<std::string>()},
{"transformers", make_transformer_args(vec)},
{"preamble", v.get("preamble", std::string{})}});
{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")},
{"algo", algo},
{"reduced", "decltype(" + generate_make_shape(reduce_output_shape) + ")"},
{"lambda", v.at("lambda").to<std::string>()},
{"transformers", make_transformer_args(vec)},
{"preamble", v.get("preamble", std::string{})}});
options.params += "-Wno-float-equal";
return compile_hip_code_object(src, options);
}
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
......@@ -21,41 +21,58 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_PAD_HPP
#define MIGRAPHX_GUARD_RTGLIB_PAD_HPP
#ifndef MIGRAPHX_GUARD_JIT_SCATTER_HPP
#define MIGRAPHX_GUARD_JIT_SCATTER_HPP
#include <migraphx/argument.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
struct hip_pad
template <typename Derived>
struct scatter_compiler : compiler<Derived>
{
op::pad op;
template <class Self, class F>
static auto reflect(Self& self, F f)
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
return migraphx::reflect(self.op, f);
const auto inputs =
to_shapes(std::vector<instruction_ref>{ins->inputs().begin() + 1, ins->inputs().end()});
hip_compile_options options;
options.set_launch_params(op.to_value(), compute_global_for(ctx, inputs.at(1).elements()));
options.inputs = inputs;
options.output = inputs.back();
options.kernel_name = derived().get_kernel_name(op);
options.virtual_inputs = inputs;
// The compiler protests the inequality comparison in assign_mul when pertaining to floating
// point, despite it making sense in the context. Thus the warning removal.
options.params += "-Wno-float-equal";
const auto src = derived().make_interpolated_string(op);
return prepend_copy_data_to_output(compile_hip_code_object(src, options));
}
std::string name() const { return "gpu::pad"; }
shape compute_shape(std::vector<shape> inputs) const;
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
compiler_replace prepend_copy_data_to_output(const operation& co) const
{
return shapes.size() - 1;
return {co, [](module& m, instruction_ref ins, const operation& op) {
auto args = ins->inputs();
args.back() =
m.insert_instruction(ins, make_op("hip::copy"), args.front(), args.back());
args.erase(args.begin());
return m.replace_instruction(ins, op, args);
}};
}
std::string get_kernel_name(const operation& op) const { return op.name() + "_kernel"; }
const Derived& derived() const { return static_cast<const Derived&>(*this); }
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -21,11 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include "scatter.hpp"
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -55,46 +51,21 @@ MIGRAPHX_GLOBAL void scatternd_kernel(void* in_indices, void* in_updates, void*
)__migraphx__";
struct scatternd_compiler : compiler<scatternd_compiler>
struct scatternd_compiler : scatter_compiler<scatternd_compiler>
{
std::vector<std::string> names() const
{
return {"scatternd_none", "scatternd_add", "scatternd_mul"};
return {
"scatternd_none", "scatternd_add", "scatternd_mul", "scatternd_min", "scatternd_max"};
}
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
std::string make_interpolated_string(const operation& op) const
{
hip_compile_options options;
options.set_launch_params(v, compute_global_for(ctx, inputs.at(1).elements()));
options.inputs = inputs;
options.output = inputs.back();
options.kernel_name = "scatternd_kernel";
options.virtual_inputs = inputs;
auto reduction = "assign_" + v.get("reduction", std::string{"none"});
auto src = interpolate_string(scatternd_kernel, {{"reduction", reduction}});
return compile_hip_code_object(src, options);
const auto reduction = op.name().substr(std::char_traits<char>::length("scatternd_"));
return interpolate_string(scatternd_kernel, {{"reduction", "assign_" + reduction}});
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
assert(starts_with(op.name(), "scatternd_"));
auto reduction = op.name().substr(10);
return insert(compile_op(
ctx,
to_shapes(std::vector<instruction_ref>{ins->inputs().begin() + 1, ins->inputs().end()}),
{{"reduction", reduction}}));
}
compiler_replace insert(const operation& co) const
{
return {co, [](module& m, instruction_ref ins, const operation& op) {
auto args = ins->inputs();
args.back() =
m.insert_instruction(ins, make_op("hip::copy"), args.front(), args.back());
args.erase(args.begin());
return m.replace_instruction(ins, op, args);
}};
}
std::string get_kernel_name(const operation&) const { return "scatternd_kernel"; }
};
} // namespace gpu
......
......@@ -22,8 +22,12 @@
#ifndef MIGRAPHX_GUARD_KERNELS_BITCAST_HPP
#define MIGRAPHX_GUARD_KERNELS_BITCAST_HPP
#include <migraphx/kernels/type_traits.hpp>
namespace migraphx {
template <typename To, typename From>
template <typename To,
typename From,
MIGRAPHX_REQUIRES(is_trivially_copyable<To>{} and is_trivially_copyable<From>{})>
inline constexpr To bit_cast(From fr) noexcept
{
static_assert(sizeof(To) == sizeof(From));
......
......@@ -365,15 +365,6 @@ struct float8
inline __device__ constexpr float8& operator=(const float8& rhs) = default;
inline __device__ constexpr float8& operator=(float8&& rhs) noexcept = default;
inline __device__ constexpr bool operator==(const float8& rhs) const
{
if(rhs.is_nan() or rhs.is_inf() or this->is_nan() or this->is_inf())
return false;
else if((rhs.is_zero() and this->is_zero()) or (this->data == rhs.data))
return true;
return false;
}
inline __device__ constexpr bool operator<(const float8& rhs) const
{
const auto we = static_cast<float>(*this);
......@@ -403,12 +394,20 @@ using fp8e5m2fnuz = float8<migraphx::fp8::f8_type::bf8, true>;
}
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_FABS(T) \
inline constexpr __device__ T fabs(T v) \
{ \
/*NOLINTNEXTLINE*/ \
v.data = v.data & 0x7f; \
return v; \
#define MIGRAPHX_FP8_OTHER_OPS(T) \
inline constexpr __device__ T fabs(T v) \
{ \
/*NOLINTNEXTLINE*/ \
v.data = v.data & 0x7f; \
return v; \
} \
inline __device__ constexpr bool operator==(const T& lhs, const T& rhs) \
{ \
if(rhs.is_nan() or rhs.is_inf() or lhs.is_nan() or lhs.is_inf()) \
return false; \
else if((rhs.is_zero() and lhs.is_zero()) or (lhs.data == rhs.data)) \
return true; \
return false; \
}
// NOLINTNEXTLINE
......@@ -417,11 +416,10 @@ using fp8e5m2fnuz = float8<migraphx::fp8::f8_type::bf8, true>;
MIGRAPHX_FP8_BINARY_OP(-, T, T) \
MIGRAPHX_FP8_BINARY_OP(/, T, T) \
MIGRAPHX_FP8_BINARY_OP(+, T, T) \
MIGRAPHX_FP8_BINARY_OP(==, T, bool) \
MIGRAPHX_FP8_BINARY_OP(>=, T, bool) \
MIGRAPHX_FP8_BINARY_OP(<=, T, bool) \
MIGRAPHX_FP8_BINARY_OP(!=, T, bool) \
MIGRAPHX_FP8_FABS(T)
MIGRAPHX_FP8_OTHER_OPS(T)
MIGRAPHX_FP8_GEN_OP_OVERLOADS(fp8e5m2)
MIGRAPHX_FP8_GEN_OP_OVERLOADS(fp8e5m2fnuz)
......@@ -447,7 +445,7 @@ class numeric_limits<fp8e4m3fnuz>
{
return fp8e4m3fnuz(0x7F, fp8e4m3fnuz::from_bits());
}
// this is min value that is not DeNorm. DeNorm min is 0x01
// this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01
static constexpr __device__ fp8e4m3fnuz min()
{
return fp8e4m3fnuz(0x08, fp8e4m3fnuz::from_bits());
......@@ -475,7 +473,7 @@ class numeric_limits<fp8e4m3fn>
}
static constexpr __device__ fp8e4m3fn max() { return fp8e4m3fn(0x7E, fp8e4m3fn::from_bits()); }
// this is min value that is not DeNorm. DeNorm min is 0x01
// this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01
static constexpr __device__ fp8e4m3fn min() { return fp8e4m3fn(0x08, fp8e4m3fn::from_bits()); }
static constexpr __device__ fp8e4m3fn lowest()
......@@ -503,8 +501,7 @@ class numeric_limits<fp8e5m2fnuz>
{
return fp8e5m2fnuz(0x7F, fp8e5m2fnuz::from_bits());
}
// this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make
// this distinction. For the floating points we would end up using lowest most of the times.
// this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01.
static constexpr __device__ fp8e5m2fnuz min()
{
return fp8e5m2fnuz(0x4, fp8e5m2fnuz::from_bits());
......@@ -529,8 +526,7 @@ class numeric_limits<fp8e5m2>
}
static constexpr __device__ fp8e5m2 max() { return fp8e5m2(0x7B, fp8e5m2::from_bits()); }
// this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make
// this distinction. For the floating points we would end up using lowest most of the times.
// this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01.
static constexpr __device__ fp8e5m2 min() { return fp8e5m2(0x4, fp8e5m2::from_bits()); }
static constexpr __device__ fp8e5m2 lowest() { return fp8e5m2(0xFB, fp8e5m2::from_bits()); }
......@@ -539,24 +535,26 @@ class numeric_limits<fp8e5m2>
};
} // namespace fp8
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_MIN_MAX(T) \
template <> \
constexpr T numeric_max<T, void>() \
{ \
return fp8::numeric_limits<T>::max(); \
} \
template <> \
constexpr T numeric_lowest<T>() \
{ \
return fp8::numeric_limits<T>::lowest(); \
}
MIGRAPHX_FP8_MIN_MAX(fp8::fp8e4m3fnuz);
MIGRAPHX_FP8_MIN_MAX(fp8::fp8e5m2fnuz);
MIGRAPHX_FP8_MIN_MAX(fp8::fp8e4m3fn);
MIGRAPHX_FP8_MIN_MAX(fp8::fp8e5m2);
template <class T,
MIGRAPHX_REQUIRES(is_same<T, fp8::fp8e4m3fnuz>{} or is_same<T, fp8::fp8e5m2fnuz>{} or
is_same<T, fp8::fp8e4m3fn>{} or is_same<T, fp8::fp8e5m2>{})>
constexpr T numeric_max(migraphx::fp8::f8_type unused = migraphx::fp8::f8_type::fp8)
{
// unused parameter is added to make this numeric_max different overload definition
// compared to numeric_max defined in type_traits.hpp
(void)(unused);
return fp8::numeric_limits<T>::max();
}
template <class T,
MIGRAPHX_REQUIRES(is_same<T, fp8::fp8e4m3fnuz>{} or is_same<T, fp8::fp8e5m2fnuz>{} or
is_same<T, fp8::fp8e4m3fn>{} or is_same<T, fp8::fp8e5m2>{})>
constexpr T numeric_lowest(migraphx::fp8::f8_type unused = migraphx::fp8::f8_type::fp8)
{
// unused parameter is added to make this numeric_lowest different overload definition
// compared to numeric_lowest defined in type_traits.hpp
(void)(unused);
return fp8::numeric_limits<T>::lowest();
}
} // namespace migraphx
// =================================================================================================
#if defined(__clang__)
......
......@@ -53,35 +53,35 @@ __device__ void gathernd(const T& data_t, const U& indices_t, const V& output_t,
auto indices_shape_lens = indices_shape.lens;
auto data_shape_lens = data_shape.lens;
auto num_slice_dims = indices_shape_lens.back();
std::size_t num_slices =
size_t num_slices =
accumulate(indices_shape_lens.begin(), indices_shape_lens.end() - 1, 1, op::product{});
std::size_t slice_size = accumulate(data_shape_lens.begin() + num_slice_dims + batch_dims,
data_shape_lens.end(),
1,
op::product{});
const std::size_t num_batches =
size_t slice_size = accumulate(data_shape_lens.begin() + num_slice_dims + batch_dims,
data_shape_lens.end(),
1,
op::product{});
const size_t num_batches =
accumulate(data_shape_lens.begin(), data_shape_lens.begin() + batch_dims, 1, op::product{});
const std::size_t data_batch_stride =
const size_t data_batch_stride =
accumulate(data_shape_lens.begin() + batch_dims, data_shape_lens.end(), 1, op::product{});
const auto num_slices_per_batch = num_slices / num_batches;
ind.global_stride(output_shape.elements(), [&](auto i) {
const auto* indices_ptr = indices_t.data();
const std::size_t j = i / slice_size;
const std::size_t batch_idx = j / num_slices_per_batch;
const size_t j = i / slice_size;
const size_t batch_idx = j / num_slices_per_batch;
auto* slice_indices = indices_ptr + (j * num_slice_dims);
std::size_t relative_slice_offset = 0;
for(std::size_t idx = 0; idx < num_slice_dims; ++idx)
size_t relative_slice_offset = 0;
for(size_t idx = 0; idx < num_slice_dims; ++idx)
{
int64_t index = slice_indices[idx];
const std::size_t input_dim_idx = batch_dims + idx;
const size_t input_dim_idx = batch_dims + idx;
const auto input_dim = data_shape_lens[input_dim_idx];
MIGRAPHX_ASSERT(index >= -static_cast<int64_t>(input_dim) and
index < static_cast<int64_t>(input_dim));
if(index < 0)
index += input_dim;
std::size_t size_from_slice_dims =
size_t size_from_slice_dims =
accumulate(data_shape_lens.begin() + batch_dims + idx + 1,
data_shape_lens.begin() + batch_dims + num_slice_dims,
slice_size,
......
......@@ -54,12 +54,11 @@ __device__ void generic_binary_layernorm(
using value_type = typename Input1::type;
using vec_value_type = vec_type<value_type>;
constexpr auto relements = r.template elements<Input1>();
constexpr auto relements_r = static_cast<vec_value_type>(1.0 / relements);
constexpr auto relements_r = vec_value_type{1.0 / relements};
auto relements_rsqrt = sqrt(relements_r);
auto means = r.reduce(op::sum{},
make_array<vec_value_type>(static_cast<vec_value_type>(0),
static_cast<vec_value_type>(0)),
make_array<vec_value_type>(vec_value_type{0}, vec_value_type{0}),
[&](auto x) {
auto x_out = x * relements_r;
// dividing x by sqrt(relements) before squaring allows computing
......@@ -71,7 +70,7 @@ __device__ void generic_binary_layernorm(
auto mean_x = means[0];
auto mean_x2 = means[1];
auto variance = mean_x2 - (mean_x * mean_x);
value_type eps_val = static_cast<value_type>(eps);
value_type eps_val = implicit_conversion(eps);
r.inner([&](auto& y, auto x, auto... xs) {
auto m = x - mean_x;
......
......@@ -290,7 +290,7 @@ MIGRAPHX_DEVICE_MATH_VEC(where)
template <class T, class U>
constexpr auto convert(U v)
{
return vec_transform(v)([](auto x) { return static_cast<T>(x); });
return vec_transform(v)([](auto x) -> T { return static_cast<T>(x); });
}
} // namespace migraphx
......
......@@ -118,7 +118,7 @@ struct highest
template <class T>
constexpr operator T() const
{
return numeric_max<vec_type<T>, void>();
return numeric_max<vec_type<T>>();
}
};
} // namespace migraphx
......
......@@ -28,6 +28,7 @@
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/ranges.hpp>
#include <migraphx/kernels/vec.hpp>
namespace migraphx {
......@@ -39,7 +40,6 @@ __device__ void pad(const index& idx,
const PadVal& pad_val)
{
auto output_shape = output.get_shape();
using otype = typename Output::type;
idx.global_stride(output_shape.elements(), [&](auto i) {
// 1. get current multi-index for output
// 2. get the size of the input to determine input boundaries
......@@ -54,9 +54,9 @@ __device__ void pad(const index& idx,
if(any_of(range_multi.begin(), range_multi.end(), [&](auto j) {
return multi[j] < offsets[j] or input_idx[j] >= input_bounds[j];
}))
output[multi] = otype(pad_val);
output[multi] = implicit_conversion(pad_val);
else
output[multi] = otype(input[input_idx]);
output[multi] = implicit_conversion(input[input_idx]);
});
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment