Commit 6f768035 authored by Umang Yadav

Merge branch 'rocblas_mlir_fp8' into miopen_fp8

parents da7717ce b2542239
@@ -49,6 +49,12 @@ std::string get_device_name()
     return props.gcnArchName;
 }
 
+bool gfx_has_fp8_intrinsics()
+{
+    const auto device_name = trim(split_string(get_device_name(), ':').front());
+    return (starts_with(device_name, "gfx9") and device_name >= "gfx940");
+}
+
 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
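For reference, a gcnArchName string such as "gfx940:sramecc+:xnack-" trims to "gfx940" here, which passes both the "gfx9" prefix check and the lexicographic comparison against "gfx940", while names like "gfx906" or "gfx90a" fail the comparison; only gfx94x-class devices report fp8 intrinsics.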
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -21,49 +21,64 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  * THE SOFTWARE.
  */
-#include "migraphx/serialize.hpp"
-#include <iterator>
-#include <utility>
-#include <migraphx/eliminate_fp8.hpp>
+#include <migraphx/gpu/driver/action.hpp>
+#include <migraphx/gpu/time_op.hpp>
+#include <migraphx/gpu/context.hpp>
+#include <migraphx/gpu/lowering.hpp>
+#include <migraphx/gpu/compile_ops.hpp>
 #include <migraphx/make_op.hpp>
+#include <migraphx/pass_manager.hpp>
 #include <migraphx/program.hpp>
 #include <migraphx/instruction.hpp>
-#include <migraphx/iterator_for.hpp>
-#include <migraphx/stringutils.hpp>
-#include <migraphx/ranges.hpp>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
+namespace gpu {
+namespace driver {
 
-void eliminate_fp8::apply(module& m) const
+struct precompile_op : action<precompile_op>
 {
-    for(auto ins : iterator_for(m))
+    static program create_preop_program(const operation& preop, std::vector<shape> inputs)
     {
-        if(not contains(op_names, ins->name()) or
-           ins->get_shape().type() != migraphx::shape::fp8e4m3fnuz_type)
-            continue;
-        migraphx::shape::type_t orig_type        = ins->get_shape().type();
-        std::vector<instruction_ref> orig_inputs = ins->inputs();
-        std::vector<instruction_ref> new_inputs;
-        std::transform(orig_inputs.begin(),
-                       orig_inputs.end(),
-                       std::back_inserter(new_inputs),
-                       [&](const auto& i) {
-                           return m.insert_instruction(
-                               ins,
-                               migraphx::make_op(
-                                   "convert", {{"target_type", migraphx::to_value(target_type)}}),
-                               i);
-                       });
-        auto new_ins          = m.insert_instruction(ins, ins->get_operator(), {new_inputs});
-        auto convert_back_ins = m.insert_instruction(
-            ins,
-            migraphx::make_op("convert", {{"target_type", migraphx::to_value(orig_type)}}),
-            new_ins);
-        m.replace_instruction(ins, convert_back_ins);
+        program p;
+        auto* mm = p.get_main_module();
+        std::vector<instruction_ref> args;
+        inputs.pop_back();
+        transform(inputs, range(inputs.size()), std::back_inserter(args), [&](auto input, auto i) {
+            return mm->add_parameter("x" + std::to_string(i), input);
+        });
+        mm->add_instruction(preop, args);
+        return p;
     }
-}
 
+    static operation get_code_object(const program& p)
+    {
+        MIGRAPHX_TIDY_CONST auto* mm = p.get_main_module();
+        auto it = std::find_if(mm->begin(), mm->end(), [](const auto& ins) {
+            return (ins.name() == "gpu::code_object");
+        });
+        if(it == mm->end())
+            MIGRAPHX_THROW("Failed to create code object");
+        return it->get_operator();
+    }
+
+    static void apply(const parser& p, const value& v)
+    {
+        context ctx;
+        auto inputs = p.parse_shapes(v.at("inputs"));
+        auto name   = v.at("name").to<std::string>();
+        auto preop  = make_op(name);
+        if(v.contains("fields"))
+            preop.from_value(v.at("fields"));
+        bool exhaustive = v.get("exhaustive", false);
+        auto prog       = create_preop_program(preop, inputs);
+        run_passes(prog, {lowering{}, compile_ops{&ctx, exhaustive}});
+        auto op = get_code_object(prog);
+        auto t  = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
+        std::cout << preop << ": " << t << "ms" << std::endl;
+    }
+};
+
+} // namespace driver
+} // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
@@ -38,6 +38,18 @@ namespace gpu {
 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_EXTRA_MLIR);
 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_MLIR);
 
+/**
+ * @brief Declares a new MIGraphX environment variable which forces generation of
+ * only specific MLIR operations.
+ *
+ * The variable, if defined, forces MIGraphX to use only specific operations
+ * with MLIR regardless of the underlying GPU architecture. The variable accepts
+ * a comma-separated list of operations. The variable recognizes the following
+ * operations: "fused", "convolution", "dot". If the variable is not defined, MIGraphX
+ * will decide by itself which operations to delegate to MLIR. The variable is
+ * intended primarily for use by rocMLIR developers.
+ */
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_USE_SPECIFIC_OPS);
+
 bool mlir_enabled()
 {
@@ -49,6 +61,26 @@ bool mlir_enabled()
 #endif
 }
 
+static bool is_requested(std::string_view option, bool fallback = false)
+{
+    auto string_value = string_value_of(MIGRAPHX_MLIR_USE_SPECIFIC_OPS{}, "");
+    if(string_value.empty())
+        return fallback;
+    const auto options = split_string(string_value, ',');
+    return contains(options, option);
+}
+
+bool mlir_attention_enabled()
+{
+#ifdef MIGRAPHX_MLIR
+    if(not mlir_enabled())
+        return false;
+    return is_requested("attention");
+#else
+    return false;
+#endif
+}
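As a usage note: setting MIGRAPHX_MLIR_USE_SPECIFIC_OPS=fused,dot restricts MLIR offload to the "fused" and "dot" paths, and MIGRAPHX_MLIR_USE_SPECIFIC_OPS=attention is the value that mlir_attention_enabled() looks for through is_requested("attention") above.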
 #ifdef MIGRAPHX_MLIR
 
 struct mlir_op
@@ -62,41 +94,27 @@ struct mlir_op
         return pack(f(self.op, "op"));
     }
 
-    shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
+    shape compute_shape(const std::vector<shape>& inputs, const std::vector<module_ref>& mods) const
     {
+        module_ref mod = mods[0];
         check_shapes{inputs, *this}.packed_or_broadcasted();
         if(mods.size() != 1)
            MIGRAPHX_THROW("should have one submodule.");
         if(inputs.size() < 2)
            MIGRAPHX_THROW("should have at least two inputs.");
-        module_ref mod = mods[0];
         auto type      = mod->get_output_shapes().front().type();
         std::unordered_map<instruction_ref, shape> ins_shapes;
-        size_t param_cnt               = 0;
-        std::vector<std::string> names = mod->get_parameter_names();
-        std::sort(names.begin(), names.end());
-        for(const std::string& param_name : names)
-        {
-            ins_shapes[mod->get_parameter(param_name)] = inputs[param_cnt++];
-        }
         for(auto ins : iterator_for(*mod))
         {
-            if(ins->name() == "@param")
-            {
-                continue;
-            }
-            if(ins->name() == "@literal")
+            if(ins->name() == "@literal" or ins->name() == "@param")
             {
                 ins_shapes[ins] = ins->get_shape();
                 continue;
             }
             if(ins->name() == "@return")
             {
-                auto s = ins_shapes[ins->inputs().at(0)].with_type(type);
-                if(not s.standard())
-                    MIGRAPHX_THROW("MLIR doesnt support non-standard output");
-                return s;
+                return ins_shapes[ins->inputs().at(0)].with_type(type);
             }
             std::vector<shape> input_shapes;
             input_shapes.resize(ins->inputs().size());
@@ -112,38 +130,55 @@ struct mlir_op
 
 MIGRAPHX_REGISTER_OP(mlir_op);
 namespace {
 
+std::tuple<instruction_ref, std::vector<operation>>
+get_fusable_input_op_stream(instruction_ref lower_input)
+{
+    instruction_ref upper_input = lower_input;
+    std::vector<operation> op_stream;
+    while(contains({"slice",
+                    "transpose",
+                    "multibroadcast",
+                    "broadcast",
+                    "contiguous",
+                    "reshape",
+                    "squeeze",
+                    "flatten",
+                    "unsqueeze"},
+                   upper_input->name()))
+    {
+        operation op = upper_input->get_operator();
+        if(contains({"squeeze", "flatten", "unsqueeze"}, upper_input->name()))
+        {
+            op = migraphx::make_op("reshape", {{"dims", upper_input->get_shape().lens()}});
+        }
+        op_stream.push_back(op);
+        upper_input = upper_input->inputs().at(0);
+    }
+    return {upper_input, op_stream};
+}
+
 std::tuple<instruction_ref, std::vector<instruction_ref>>
-fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op)
+fuse_input_ops_and_gemm_based_op(module_ref mm,
+                                 const std::vector<instruction_ref>& gemm_based_op_inputs,
+                                 const operation& gemm_based_op)
 {
     std::vector<instruction_ref> top_inputs;
     std::vector<instruction_ref> imm_inputs;
     size_t input_cnt = 0;
-    for(instruction_ref input : gemm_based_op->inputs())
+    for(instruction_ref input : gemm_based_op_inputs)
     {
-        std::vector<operation> op_stream;
-        while(contains(
-            {"slice", "transpose", "contiguous", "reshape", "squeeze", "flatten", "unsqueeze"},
-            input->name()))
-        {
-            operation op = input->get_operator();
-            if(contains({"squeeze", "flatten", "unsqueeze"}, input->name()))
-            {
-                op = migraphx::make_op("reshape", {{"dims", input->get_shape().lens()}});
-            }
-            op_stream.push_back(op);
-            input = input->inputs().at(0);
-        }
-        top_inputs.push_back(input);
+        auto [upper_input, op_stream] = get_fusable_input_op_stream(input);
+        top_inputs.push_back(upper_input);
         instruction_ref prev_input =
-            mm->add_parameter("y" + std::to_string(input_cnt++), input->get_shape());
+            mm->add_parameter("y" + std::to_string(input_cnt++), upper_input->get_shape());
         for(const auto& op : reverse(op_stream))
         {
            prev_input = mm->add_instruction(op, {prev_input});
         }
         imm_inputs.push_back(prev_input);
     }
-    instruction_ref new_gemm_based_op =
-        mm->add_instruction(gemm_based_op->get_operator(), imm_inputs);
+    instruction_ref new_gemm_based_op = mm->add_instruction(gemm_based_op, imm_inputs);
     return {new_gemm_based_op, top_inputs};
 }
@@ -183,6 +218,7 @@ auto is_mlir_conv(mlir_mode mode)
             return false;
         if(ins->name() != "convolution" and ins->name() != "quant_convolution")
             return false;
+        auto input_arg_t = ins->inputs().front()->get_shape().type();
         value v    = ins->get_operator().to_value();
         auto group = v.at("group").to<int>();
         if(group != 1)
@@ -190,6 +226,10 @@ auto is_mlir_conv(mlir_mode mode)
         // Avoid MLIR assertion: Index < Length && "Invalid index!"
         if(ins->get_shape().lens().size() != 4)
             return false;
+        if(ins->get_shape().type() == shape::fp8e4m3fnuz_type)
+            return true;
+        if(ins->get_shape().type() == shape::float_type and input_arg_t == shape::fp8e4m3fnuz_type)
+            return true;
         if(ins->get_shape().type() == shape::int8_type)
             return true;
         if(mode == mlir_mode::int8)
@@ -205,103 +245,140 @@ auto is_mlir_conv(mlir_mode mode)
     });
 }
-struct find_mlir_fused_ops
-{
-    mlir_mode conv_mode = mlir_mode::none;
-    mlir_mode dot_mode  = mlir_mode::none;
-    auto matcher() const
-    {
-        auto dot_or_conv = match::skip(match::name("contiguous"))(
-            match::any_of(is_mlir_dot(dot_mode), is_mlir_conv(conv_mode)).bind("gemm_based_op"));
-        return match::name("pointwise")(match::any_of[match::inputs()](dot_or_conv.bind("x")));
-    }
-
-    std::unordered_map<instruction_ref, instruction_ref>
-    create_param_map_with_literals(module_ref mm, const module* pm, const shape& shape) const
+std::unordered_map<instruction_ref, instruction_ref>
+create_param_map_with_literals(module_ref mm, const module* pm, const shape& shape)
 {
     std::unordered_map<instruction_ref, instruction_ref> ins_map;
     for(auto ins : iterator_for(*pm))
     {
         if(ins->name() != "@literal")
         {
             continue;
         }
         literal r               = ins->get_literal();
         instruction_ref literal = mm->add_literal(r);
-        instruction_ref mbcast  = mm->add_instruction(
-            make_op("multibroadcast", {{"out_lens", shape.lens()}}), literal);
+        instruction_ref mbcast =
+            mm->add_instruction(make_op("multibroadcast", {{"out_lens", shape.lens()}}), literal);
         ins_map[ins] = mbcast;
     }
     return ins_map;
 }
 
+std::vector<instruction_ref>
+fold_pointwise_mod(instruction_ref pm_ins,
+                   module_ref parent_mod,
+                   const std::unordered_map<instruction_ref, instruction_ref>& ins_map)
+{
+    auto* pm   = pm_ins->module_inputs().front();
+    auto names = pm->get_parameter_names();
+    std::sort(names.begin(), names.end());
+    std::unordered_map<instruction_ref, instruction_ref> param_map =
+        create_param_map_with_literals(parent_mod, pm, pm_ins->get_shape());
+    std::transform(names.begin(),
+                   names.end(),
+                   pm_ins->inputs().begin(),
+                   std::inserter(param_map, param_map.end()),
+                   [&](auto name, auto input) {
+                       if(ins_map.count(input))
+                           return std::make_pair(pm->get_parameter(name), ins_map.at(input));
+                       return std::make_pair(pm->get_parameter(name),
+                                             parent_mod->add_parameter(name, input->get_shape()));
+                   });
+    return parent_mod->insert_instructions(parent_mod->end(), pm, param_map);
+}
+
 // Whitelist supported fusion options, including imposing type constraints
 // for cases where MLIR only supports an operation (usually a pointwise function)
 // on particular types.
-bool is_pointwise_op_supported_by_mlir(const instruction& i) const
+bool is_pointwise_op_supported_by_mlir(const instruction& i)
 {
     using type_t           = shape::type_t;
     const auto& name       = i.name();
     const auto result_type = i.get_shape().type();
     const std::initializer_list<type_t> allowed_types = {type_t::float_type,
                                                          type_t::half_type,
-                                                         type_t::int8_type,
                                                          type_t::fp8e4m3fnuz_type,
+                                                         type_t::int8_type,
                                                          type_t::int32_type,
                                                          type_t::bool_type};
     // Preliminary type check.
     if(not contains(allowed_types, result_type))
     {
         return false;
     }
     const std::initializer_list<std::string> any_type_ops = {"@literal", "@param", "@return"};
     const std::initializer_list<std::string> no_bool_ops  = {
         "convolution",
         "quant_convolution",
         "dot",
         "quant_dot",
         "add",
         "clip",
         "relu",
         "sub",
         "mul",
         "div",
         "pow",
         "where",
         "quantizelinear",
         "dequantizelinear",
         "abs",
         "neg",
     };
     const std::initializer_list<std::string> fp_only_ops = {
         "ceil",
         "erf",
         "exp",
         "floor",
         "log",
         "recip",
         "rsqrt",
         "sigmoid",
         "softmax",
         "tanh",
     };
     bool is_float =
         contains({type_t::float_type, type_t::half_type, type_t::fp8e4m3fnuz_type}, result_type);
     if(contains(any_type_ops, name))
         return true;
     if(result_type != type_t::bool_type and contains(no_bool_ops, name))
         return true;
     if(is_float and contains(fp_only_ops, name))
         return true;
     // Only conversions between floating types are known to be unambiguously
     // supported.
     if(is_float and name == "convert")
     {
+        if(result_type == shape::fp8e4m3fnuz_type)
+        {
+            return false;
+        } // else
         return std::all_of(i.inputs().begin(), i.inputs().end(), [](const auto& arg) {
             return contains({type_t::float_type, type_t::half_type}, arg->get_shape().type());
         });
     }
     return false;
 }
 
+MIGRAPHX_PRED_MATCHER(mlir_pointwise, instruction_ref ins)
+{
+    if(ins->name() != "pointwise")
+        return false;
+    auto* pm = ins->module_inputs().front();
+    return std::all_of(pm->begin(), pm->end(), [&](const auto& i) {
+        return is_pointwise_op_supported_by_mlir(i);
+    });
+}
+
+struct find_mlir_fused_ops
+{
+    mlir_mode conv_mode = mlir_mode::none;
+    mlir_mode dot_mode  = mlir_mode::none;
+    auto matcher() const
+    {
+        auto dot_or_conv = match::skip(match::name("contiguous"))(
+            match::any_of(is_mlir_dot(dot_mode), is_mlir_conv(conv_mode)).bind("gemm_based_op"));
+        return mlir_pointwise()(match::any_of[match::inputs()](dot_or_conv.bind("x")));
+    }
     void apply(module_pass_manager& mpm, const match::matcher_result& r) const
@@ -311,29 +388,12 @@ struct find_mlir_fused_ops
         auto x_ins = r.instructions["x"]; // input after contiguous
         auto* pm   = ins->module_inputs().front();
         auto names = pm->get_parameter_names();
-        // Whitelist pointwise operators.
-        if(std::any_of(pm->begin(), pm->end(), [&](const auto& i) {
-               return not is_pointwise_op_supported_by_mlir(i);
-           }))
-            return;
         std::sort(names.begin(), names.end());
         module_ref mm = mpm.create_module("mlir_" + pm->name());
         mm->set_bypass();
-        std::unordered_map<instruction_ref, instruction_ref> param_map =
-            create_param_map_with_literals(mm, pm, gemm_based_op->get_shape());
-        auto [anchor_op, top_inputs] = fuse_input_ops_and_gemm_based_op(mm, gemm_based_op);
-        std::transform(names.begin(),
-                       names.end(),
-                       ins->inputs().begin(),
-                       std::inserter(param_map, param_map.end()),
-                       [&, &anchor = anchor_op](auto name, auto input) {
-                           if(input == x_ins)
-                               return std::make_pair(pm->get_parameter(name), anchor);
-                           return std::make_pair(pm->get_parameter(name),
-                                                 mm->add_parameter(name, input->get_shape()));
-                       });
-        mm->add_return(mm->insert_instructions(mm->end(), pm, param_map));
+        auto [anchor_op, top_inputs] = fuse_input_ops_and_gemm_based_op(
+            mm, gemm_based_op->inputs(), gemm_based_op->get_operator());
+        mm->add_return(fold_pointwise_mod(ins, mm, {{x_ins, anchor_op}}));
 
         std::vector<instruction_ref> inputs;
         std::copy_if(ins->inputs().begin(),
@@ -351,54 +411,104 @@ struct find_mlir_standalone_op
 {
     mlir_mode mode = mlir_mode::none;
     auto matcher() const { return Matcher(mode); }
 
     void apply(module_pass_manager& mpm, const match::matcher_result& r) const
     {
-        auto conv_based_op = r.result;
-        // enable only for fp32/fp16/i8 types
-        if(std::any_of(conv_based_op->inputs().begin(), conv_based_op->inputs().end(), [&](auto i) {
+        auto gemm_based_op = r.result;
+        // enable only for fp32/fp16/i8/fp8 types
+        if(std::any_of(gemm_based_op->inputs().begin(), gemm_based_op->inputs().end(), [&](auto i) {
               return not contains({shape::type_t::float_type,
                                    shape::type_t::half_type,
-                                   shape::type_t::fp8e4m3fnuz_type,
-                                   shape::type_t::int8_type},
+                                   shape::type_t::int8_type,
+                                   shape::type_t::fp8e4m3fnuz_type},
                                   i->get_shape().type());
           }))
            return;
         static size_t counter = 0;
         module_ref mm =
-            mpm.create_module("mlir_" + conv_based_op->name() + std::to_string(counter++));
+            mpm.create_module("mlir_" + gemm_based_op->name() + std::to_string(counter++));
         mm->set_bypass();
-        auto [anchor_op, top_inputs] = fuse_input_ops_and_gemm_based_op(mm, conv_based_op);
+        auto [anchor_op, top_inputs] = fuse_input_ops_and_gemm_based_op(
+            mm, gemm_based_op->inputs(), gemm_based_op->get_operator());
         mm->add_return({anchor_op});
         mpm.get_module().replace_instruction(
-            conv_based_op, mlir_op{conv_based_op->get_operator()}, top_inputs, {mm});
+            gemm_based_op, mlir_op{gemm_based_op->get_operator()}, top_inputs, {mm});
     }
 };
 
 using find_mlir_standalone_convolution_op = find_mlir_standalone_op<&is_mlir_conv>;
 using find_mlir_standalone_dot_op         = find_mlir_standalone_op<&is_mlir_dot>;
-/**
- * @brief Declares a new MIGraphX environment variable which forces generation of
- * only specific MLIR operations.
- *
- * The variable, if defined, forces MIGraphX to use only specific operations
- * with MLIR regardless of the underlying GPU architecture. The variable accepts
- * a comma-separated list of operations. The variable recognizes the following
- * operations: "fused", "convolution", "dot". If the variable is not defined, MIGraphX
- * will decide by itself which operations to delegate to MLIR. The variable is
- * intended primarily for use by rocMLIR developers.
- */
-MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_USE_SPECIFIC_OPS);
-
-bool is_requested(std::string_view option, bool fallback = false)
-{
-    auto string_value = string_value_of(MIGRAPHX_MLIR_USE_SPECIFIC_OPS{}, "");
-    if(string_value.empty())
-        return fallback;
-    const auto options = split_string(string_value, ',');
-    return contains(options, option);
-}
+struct find_mlir_standalone_attention_op
+{
+    auto matcher() const
+    {
+        return match::name("gpu::pre_gemm_softmax_gemm").bind("gemm_softmax_gemm");
+    }
+
+    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
+    {
+        static size_t counter  = 0;
+        module_ref mm          = mpm.create_module("mlir_" + std::to_string(counter++));
+        auto gemm_softmax_gemm = r.instructions["gemm_softmax_gemm"];
+        std::vector<instruction_ref> inputs;
+        mm->set_bypass();
+
+        std::unordered_map<instruction_ref, instruction_ref> ins_map;
+        auto gemm0_inputs = gemm_softmax_gemm->inputs();
+        gemm0_inputs.pop_back();
+        auto [gemm0, top_gemm0_inputs] =
+            fuse_input_ops_and_gemm_based_op(mm, gemm0_inputs, make_op("dot"));
+        inputs.insert(inputs.begin(), top_gemm0_inputs.begin(), top_gemm0_inputs.end());
+
+        // handle scale
+        auto v = gemm_softmax_gemm->get_operator().to_value();
+        assert(v.contains("scale"));
+        auto scale     = v.at("scale").to<float>();
+        auto scale_lit = mm->add_literal(literal{shape{gemm0->get_shape().type()}, {scale}});
+        instruction_ref scale_lit_mbcast = mm->add_instruction(
+            make_op("multibroadcast", {{"out_lens", gemm0->get_shape().lens()}}), scale_lit);
+        auto scaled_gemm0 = mm->add_instruction(make_op("mul"), gemm0, scale_lit_mbcast);
+
+        auto softmax = mm->add_instruction(
+            make_op("softmax", {{"axis", gemm0->get_shape().lens().size() - 1}}), scaled_gemm0);
+
+        auto [old_upper_v, upper_v_op_stream] =
+            get_fusable_input_op_stream(gemm_softmax_gemm->inputs()[2]);
+        instruction_ref new_upper_v = mm->add_parameter("z", old_upper_v->get_shape());
+        for(const auto& op : reverse(upper_v_op_stream))
+        {
+            new_upper_v = mm->add_instruction(op, {new_upper_v});
+        }
+        inputs.push_back(old_upper_v);
+        auto gemm1 = mm->add_instruction(make_op("dot"), {softmax, new_upper_v});
+
+        ins_map[gemm_softmax_gemm] = gemm1;
+        auto ins_to_replace        = gemm1;
+        auto ins_to_be_replaced    = gemm_softmax_gemm;
+        if(r.instructions.find("trailing_pm") != r.instructions.end())
+        {
+            ins_to_replace = fold_pointwise_mod(r.instructions["trailing_pm"], mm, ins_map)[0];
+            std::copy_if(r.instructions["trailing_pm"]->inputs().begin(),
+                         r.instructions["trailing_pm"]->inputs().end(),
+                         std::back_inserter(inputs),
+                         [&](auto input) { return input != gemm_softmax_gemm; });
+            ins_to_be_replaced = r.instructions["trailing_pm"];
+        }
+        mm->add_return({ins_to_replace});
+        mpm.get_module().replace_instruction(
+            ins_to_be_replaced, mlir_op{gemm1->get_operator()}, inputs, {mm});
+    }
+};
+
+struct find_mlir_attention_fused_ops : public find_mlir_standalone_attention_op
+{
+    auto matcher() const
+    {
+        auto standalone_matcher = find_mlir_standalone_attention_op::matcher();
+        return mlir_pointwise()(
+            match::any_of[match::inputs()](standalone_matcher).bind("trailing_pm"));
+    }
+};
 
 } // namespace
 #endif // MIGRAPHX_MLIR
@@ -420,13 +530,20 @@ void fuse_mlir::apply(module_pass_manager& mpm) const
     mlir_mode mode =
         (enabled(MIGRAPHX_ENABLE_EXTRA_MLIR{}) or enable_extra) ? mlir_mode::fast : mlir_mode::none;
+
+    // Attention offloads; default disabled
+    if(mlir_attention_enabled())
+    {
+        match::find_matches(mpm, find_mlir_attention_fused_ops{});
+        match::find_matches(mpm, find_mlir_standalone_attention_op{});
+    }
+
     match::find_matches(mpm,
                         find_mlir_fused_ops{.conv_mode = get_mode("fused", mlir_mode::fast),
                                             .dot_mode  = get_mode("fused", mode)});
 
     match::find_matches(
         mpm,
-        find_mlir_standalone_convolution_op{get_mode("convolution", mlir_mode::int8)},
+        find_mlir_standalone_convolution_op{get_mode("convolution", mlir_mode::all)},
         find_mlir_standalone_dot_op{get_mode("dot", mlir_mode::none)});
 #else
     (void)mpm;
......
@@ -22,6 +22,7 @@
  * THE SOFTWARE.
  */
+#include <rocblas/internal/rocblas-types.h>
 #include <rocblas/rocblas.h>
 #include <migraphx/gpu/rocblas.hpp>
 #include <migraphx/gpu/gemm_impl.hpp>
@@ -36,6 +37,20 @@ namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 
+/*
+The regular rocBLAS API takes compute_type as a `rocblas_datatype` enum value, whereas the "ex3"
+BETA API takes it as a `rocblas_computetype` enum value. `rb_compute_type` is a facilitator that
+implicitly casts the stored integer enum value to the type required inside the `common_args`
+generator.
+*/
+struct rb_compute_type
+{
+    int type = 0;
+    rb_compute_type(rocblas_datatype t) : type(static_cast<int>(t)) {}
+    rb_compute_type(rocblas_computetype t) : type(static_cast<int>(t)) {}
+    operator rocblas_datatype() const { return static_cast<rocblas_datatype>(type); }
+    operator rocblas_computetype() const { return static_cast<rocblas_computetype>(type); }
+};
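A minimal sketch of what the dual conversions above buy (illustrative only; it assumes nothing beyond the two rocBLAS enum values already referenced in this diff):

    rb_compute_type ct = rocblas_datatype_f32_r; // construct from rocblas_datatype
    rocblas_datatype legacy = ct;                // implicit cast for rocblas_gemm_ex and friends
    ct = rocblas_compute_type_f32;               // reassign from rocblas_computetype
    rocblas_computetype ex3 = ct;                // implicit cast for the rocblas_gemm_ex3 BETA API

The same compute_type member can therefore be appended once to the common argument list and consumed by either API family.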
 // Convert rocBLAS datatypes to equivalent MIGraphX data types
 rocblas_datatype get_type(shape::type_t type)
 {
@@ -185,12 +200,17 @@ struct gemm_impl
         {
             output_type = rocblas_datatype_i32_r;
         }
-        compute_type = output_type;
+        compute_type = rb_compute_type{output_type};
         if(compute_fp32)
         {
             if(arg_type == rocblas_datatype_f16_r)
                 compute_type = rocblas_datatype_f32_r;
         }
+        if(arg_type == rocblas_datatype_f8_r)
+        {
+            assert(get_type(input_shapes[1].type()) == rocblas_datatype_f8_r);
+            compute_type = rocblas_compute_type_f32;
+        }
 
         auto a_lens = input_shapes[0].lens();
         auto b_lens = input_shapes[1].lens();
@@ -230,7 +250,6 @@ struct gemm_impl
             auto common_args = create_strided_batched_args_common(ctx, input_args);
             rocblas_invoke(&rocblas_gemm_strided_batched_ex3,
                            common_args,
-                           rocblas_compute_type_f32,
                            rocblas_gemm_algo_standard,
                            solution_idx,
                            gemm_flags);
@@ -240,7 +259,6 @@ struct gemm_impl
             auto common_args = create_gemm_ex_args_common(ctx, input_args);
             rocblas_invoke(&rocblas_gemm_ex3,
                            common_args,
-                           rocblas_compute_type_f32,
                            rocblas_gemm_algo_standard,
                            solution_idx,
                            gemm_flags);
@@ -254,7 +272,6 @@ struct gemm_impl
             auto common_args = create_strided_batched_args_common(ctx, input_args);
             rocblas_invoke(&rocblas_gemm_strided_batched_ex,
                            common_args,
-                           compute_type,
                            rocblas_gemm_algo_solution_index,
                            solution_idx,
                            gemm_flags);
@@ -264,7 +281,6 @@ struct gemm_impl
             auto common_args = create_gemm_ex_args_common(ctx, input_args);
             rocblas_invoke(&rocblas_gemm_ex,
                            common_args,
-                           compute_type,
                            rocblas_gemm_algo_solution_index,
                            solution_idx,
                            gemm_flags);
@@ -304,7 +320,6 @@ struct gemm_impl
             auto common_args = create_strided_batched_args_common(ctx, input_args);
             check_valid      = rocblas_invoke(&rocblas_gemm_strided_batched_ex,
                                          common_args,
-                                         compute_type,
                                          rocblas_gemm_algo_solution_index,
                                          solution_idx,
                                          rocblas_gemm_flags_check_solution_index);
@@ -314,7 +329,6 @@ struct gemm_impl
             auto common_args = create_gemm_ex_args_common(ctx, input_args);
             check_valid      = rocblas_invoke(&rocblas_gemm_ex,
                                          common_args,
-                                         compute_type,
                                          rocblas_gemm_algo_solution_index,
                                          solution_idx,
                                          rocblas_gemm_flags_check_solution_index);
@@ -365,7 +379,8 @@ struct gemm_impl
                                output_type,
                                ldd,
                                d_stride,
-                               num_matrices);
+                               num_matrices,
+                               compute_type);
     }
 
     /**
      * Helper method to create that subset of a long rocBLAS argument list that is common
@@ -398,7 +413,8 @@ struct gemm_impl
                                ldc,
                                is_3inputs ? args[3].data() : args[2].data(),
                                output_type,
-                               ldd);
+                               ldd,
+                               compute_type);
     }
 
 #ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
@@ -428,7 +444,6 @@ struct gemm_impl
         auto common_args = create_strided_batched_args_common(ctx, input_args);
         rocblas_invoke(&rocblas_gemm_strided_batched_ex_get_solutions,
                        common_args,
-                       compute_type,
                        rocblas_gemm_algo_solution_index,
                        gemm_flags,
                        nullptr,
@@ -438,7 +453,6 @@ struct gemm_impl
         auto common_sol_args = create_strided_batched_args_common(ctx, input_args);
         rocblas_invoke(&rocblas_gemm_strided_batched_ex_get_solutions,
                        common_sol_args,
-                       compute_type,
                        rocblas_gemm_algo_solution_index,
                        gemm_flags,
                        solution_indices.data(),
@@ -449,7 +463,6 @@ struct gemm_impl
         auto common_args = create_gemm_ex_args_common(ctx, input_args);
         rocblas_invoke(&rocblas_gemm_ex_get_solutions,
                        common_args,
-                       compute_type,
                        rocblas_gemm_algo_solution_index,
                        gemm_flags,
                        nullptr,
@@ -459,7 +472,6 @@ struct gemm_impl
         auto common_sol_args = create_gemm_ex_args_common(ctx, input_args);
         rocblas_invoke(&rocblas_gemm_ex_get_solutions,
                        common_sol_args,
-                       compute_type,
                        rocblas_gemm_algo_solution_index,
                        gemm_flags,
                        solution_indices.data(),
@@ -521,7 +533,7 @@ struct gemm_impl
     rocblas_int c_stride          = 0;
     rocblas_int d_stride          = 0;
     rocblas_datatype arg_type     = rocblas_datatype_f32_r;
-    rocblas_datatype compute_type = rocblas_datatype_f32_r;
+    rb_compute_type compute_type  = rocblas_datatype_f32_r;
     rocblas_datatype output_type  = rocblas_datatype_f32_r;
     bool strided_batched          = true;
     bool is_3inputs               = true;
......
@@ -58,10 +58,10 @@ struct hiprtc_src_file
 MIGRAPHX_GPU_EXPORT bool hip_has_flags(const std::vector<std::string>& flags);
 
 MIGRAPHX_GPU_EXPORT std::vector<std::vector<char>> compile_hip_src_with_hiprtc(
-    std::vector<hiprtc_src_file> srcs, std::string params, const std::string& arch);
+    std::vector<hiprtc_src_file> srcs, const std::string& params, const std::string& arch);
 
-MIGRAPHX_GPU_EXPORT std::vector<std::vector<char>>
-compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch);
+MIGRAPHX_GPU_EXPORT std::vector<std::vector<char>> compile_hip_src(
+    const std::vector<src_file>& srcs, const std::string& params, const std::string& arch);
 
 MIGRAPHX_GPU_EXPORT std::string enum_params(std::size_t count, std::string param);
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_PAD_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_PAD_HPP
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
argument MIGRAPHX_DEVICE_EXPORT pad(hipStream_t stream,
argument result,
argument arg1,
float value,
std::vector<std::int64_t> pads);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
@@ -37,6 +37,8 @@ MIGRAPHX_GPU_EXPORT std::string get_device_name();
 
 MIGRAPHX_GPU_EXPORT int get_device_id();
 
+MIGRAPHX_GPU_EXPORT bool gfx_has_fp8_intrinsics();
+
 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
......
@@ -34,10 +34,11 @@ struct module_pass_manager;
 namespace gpu {
 
 MIGRAPHX_GPU_EXPORT bool mlir_enabled();
+MIGRAPHX_GPU_EXPORT bool mlir_attention_enabled();
 
 struct MIGRAPHX_GPU_EXPORT fuse_mlir
 {
     context* ctx      = nullptr;
     bool enable_extra = false;
     std::string name() const { return "gpu::fuse_mlir"; }
     void apply(module_pass_manager& mpm) const;
......
@@ -66,6 +66,10 @@ struct gemm_softmax_gemm
     }
 
     static bool is_ck_supported_type(shape::type_t t) { return contains({shape::half_type}, t); }
+    static bool is_mlir_supported_type(shape::type_t t)
+    {
+        return contains({shape::type_t::float_type, shape::half_type}, t);
+    }
 };
 
 } // namespace gpu
......
@@ -217,6 +217,12 @@ inline pooling_descriptor make_pooling(const migraphx::op::pooling& op)
         ss << op.mode;
         MIGRAPHX_THROW(ss.str());
     }
+    if(not std::all_of(
+           op.dilations.cbegin(), op.dilations.cend(), [](std::size_t d) { return d == 1; }))
+    {
+        MIGRAPHX_THROW("Unsupported dilations for pooling: [" + to_string_range(op.dilations) +
+                       "]");
+    }
     auto p = make_obj<pooling_descriptor>(&miopenCreatePoolingDescriptor);
 
     int kdims = op.kdims();
......
@@ -146,7 +146,6 @@ struct simple_reduce_compiler : compiler<simple_reduce_compiler>
         vectorize vec{};
         auto nelements = options.virtual_inputs.back().elements();
         auto algo      = v.get("algo", get_reduce_algo(options.virtual_inputs));
-
         if(algo == "block")
         {
             // Vectorize if the axis is a reduction axis
@@ -170,13 +169,13 @@ struct simple_reduce_compiler : compiler<simple_reduce_compiler>
         options.kernel_name  = "reduce_kernel";
         std::string identity = "[](auto x) { return x; }";
         auto src             = interpolate_string(simple_reduce_kernel,
                                                   {{"reduction", v.at("reduction").to<std::string>()},
                                                    {"init", v.get("init", std::string{"0"})},
                                                    {"read", v.get("read", identity)},
                                                    {"write", v.get("write", identity)},
                                                    {"algo", algo},
                                                    {"transformers", make_transformer_args(vec)},
                                                    {"preamble", v.get("preamble", std::string{})}});
         options.params += "-Wno-float-equal";
         return compile_hip_code_object(src, options);
     }
@@ -267,13 +266,13 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler>
         auto src = interpolate_string(
             fused_reduce_kernel,
             {{"kernel", options.kernel_name},
              {"params", enum_params(inputs.size(), "void * private_p")},
              {"args", enum_params(inputs.size(), "private_p")},
              {"algo", algo},
              {"reduced", "decltype(" + generate_make_shape(reduce_output_shape) + ")"},
              {"lambda", v.at("lambda").to<std::string>()},
              {"transformers", make_transformer_args(vec)},
              {"preamble", v.get("preamble", std::string{})}});
         options.params += "-Wno-float-equal";
         return compile_hip_code_object(src, options);
     }
......
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -21,41 +21,58 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  * THE SOFTWARE.
  */
-#ifndef MIGRAPHX_GUARD_RTGLIB_PAD_HPP
-#define MIGRAPHX_GUARD_RTGLIB_PAD_HPP
+#ifndef MIGRAPHX_GUARD_JIT_SCATTER_HPP
+#define MIGRAPHX_GUARD_JIT_SCATTER_HPP
 
-#include <migraphx/argument.hpp>
-#include <migraphx/reflect.hpp>
-#include <migraphx/op/pad.hpp>
+#include <migraphx/gpu/compiler.hpp>
+#include <migraphx/make_op.hpp>
+#include <migraphx/gpu/context.hpp>
+#include <migraphx/gpu/compile_hip_code_object.hpp>
+#include <migraphx/gpu/compile_hip.hpp>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 
-struct context;
-
-struct hip_pad
+template <typename Derived>
+struct scatter_compiler : compiler<Derived>
 {
-    op::pad op;
-    template <class Self, class F>
-    static auto reflect(Self& self, F f)
-    {
-        return migraphx::reflect(self.op, f);
-    }
-    std::string name() const { return "gpu::pad"; }
-    shape compute_shape(std::vector<shape> inputs) const;
-    argument
-    compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
-    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
-    {
-        return shapes.size() - 1;
-    }
+    compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
+    {
+        const auto inputs =
+            to_shapes(std::vector<instruction_ref>{ins->inputs().begin() + 1, ins->inputs().end()});
+        hip_compile_options options;
+        options.set_launch_params(op.to_value(), compute_global_for(ctx, inputs.at(1).elements()));
+        options.inputs         = inputs;
+        options.output         = inputs.back();
+        options.kernel_name    = derived().get_kernel_name(op);
+        options.virtual_inputs = inputs;
+        // The compiler protests the inequality comparison in assign_mul when pertaining to
+        // floating point, despite it making sense in the context. Thus the warning removal.
+        options.params += "-Wno-float-equal";
+        const auto src = derived().make_interpolated_string(op);
+        return prepend_copy_data_to_output(compile_hip_code_object(src, options));
+    }
+
+    compiler_replace prepend_copy_data_to_output(const operation& co) const
+    {
+        return {co, [](module& m, instruction_ref ins, const operation& op) {
+                    auto args   = ins->inputs();
+                    args.back() =
+                        m.insert_instruction(ins, make_op("hip::copy"), args.front(), args.back());
+                    args.erase(args.begin());
+                    return m.replace_instruction(ins, op, args);
+                }};
+    }
+
+    std::string get_kernel_name(const operation& op) const { return op.name() + "_kernel"; }
+
+    const Derived& derived() const { return static_cast<const Derived&>(*this); }
 };
 
 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
 #endif
@@ -21,11 +21,7 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  * THE SOFTWARE.
  */
-#include <migraphx/gpu/compiler.hpp>
-#include <migraphx/make_op.hpp>
-#include <migraphx/gpu/context.hpp>
-#include <migraphx/gpu/compile_hip_code_object.hpp>
-#include <migraphx/gpu/compile_hip.hpp>
+#include "scatter.hpp"
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -55,46 +51,21 @@ MIGRAPHX_GLOBAL void scatternd_kernel(void* in_indices, void* in_updates, void*
 
 )__migraphx__";
 
-struct scatternd_compiler : compiler<scatternd_compiler>
+struct scatternd_compiler : scatter_compiler<scatternd_compiler>
 {
     std::vector<std::string> names() const
     {
-        return {"scatternd_none", "scatternd_add", "scatternd_mul"};
+        return {
+            "scatternd_none", "scatternd_add", "scatternd_mul", "scatternd_min", "scatternd_max"};
     }
 
-    operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
+    std::string make_interpolated_string(const operation& op) const
     {
-        hip_compile_options options;
-        options.set_launch_params(v, compute_global_for(ctx, inputs.at(1).elements()));
-        options.inputs         = inputs;
-        options.output         = inputs.back();
-        options.kernel_name    = "scatternd_kernel";
-        options.virtual_inputs = inputs;
-        auto reduction         = "assign_" + v.get("reduction", std::string{"none"});
-        auto src               = interpolate_string(scatternd_kernel, {{"reduction", reduction}});
-        return compile_hip_code_object(src, options);
+        const auto reduction = op.name().substr(std::char_traits<char>::length("scatternd_"));
+        return interpolate_string(scatternd_kernel, {{"reduction", "assign_" + reduction}});
     }
 
-    compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
-    {
-        assert(starts_with(op.name(), "scatternd_"));
-        auto reduction = op.name().substr(10);
-        return insert(compile_op(
-            ctx,
-            to_shapes(std::vector<instruction_ref>{ins->inputs().begin() + 1, ins->inputs().end()}),
-            {{"reduction", reduction}}));
-    }
-
-    compiler_replace insert(const operation& co) const
-    {
-        return {co, [](module& m, instruction_ref ins, const operation& op) {
-                    auto args   = ins->inputs();
-                    args.back() =
-                        m.insert_instruction(ins, make_op("hip::copy"), args.front(), args.back());
-                    args.erase(args.begin());
-                    return m.replace_instruction(ins, op, args);
-                }};
-    }
+    std::string get_kernel_name(const operation&) const { return "scatternd_kernel"; }
 };
 
 } // namespace gpu
......
@@ -22,8 +22,12 @@
 #ifndef MIGRAPHX_GUARD_KERNELS_BITCAST_HPP
 #define MIGRAPHX_GUARD_KERNELS_BITCAST_HPP
 
+#include <migraphx/kernels/type_traits.hpp>
+
 namespace migraphx {
 
-template <typename To, typename From>
+template <typename To,
+          typename From,
+          MIGRAPHX_REQUIRES(is_trivially_copyable<To>{} and is_trivially_copyable<From>{})>
 inline constexpr To bit_cast(From fr) noexcept
 {
     static_assert(sizeof(To) == sizeof(From));
......
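A hedged usage sketch of the constraint added above: both types below are trivially copyable and equal in size, so the call compiles, while a type failing is_trivially_copyable would now be rejected at overload resolution by MIGRAPHX_REQUIRES.

    float f   = 1.0f;
    auto bits = migraphx::bit_cast<uint32_t>(f); // ok: sizeof(float) == sizeof(uint32_t)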
@@ -365,15 +365,6 @@ struct float8
     inline __device__ constexpr float8& operator=(const float8& rhs)     = default;
     inline __device__ constexpr float8& operator=(float8&& rhs) noexcept = default;
 
-    inline __device__ constexpr bool operator==(const float8& rhs) const
-    {
-        if(rhs.is_nan() or rhs.is_inf() or this->is_nan() or this->is_inf())
-            return false;
-        else if((rhs.is_zero() and this->is_zero()) or (this->data == rhs.data))
-            return true;
-        return false;
-    }
-
     inline __device__ constexpr bool operator<(const float8& rhs) const
     {
         const auto we = static_cast<float>(*this);
@@ -403,12 +394,20 @@ using fp8e5m2fnuz = float8<migraphx::fp8::f8_type::bf8, true>;
 }
 
 // NOLINTNEXTLINE
-#define MIGRAPHX_FP8_FABS(T)                                                \
+#define MIGRAPHX_FP8_OTHER_OPS(T)                                           \
     inline constexpr __device__ T fabs(T v)                                 \
     {                                                                       \
         /*NOLINTNEXTLINE*/                                                  \
         v.data = v.data & 0x7f;                                             \
         return v;                                                           \
+    }                                                                       \
+    inline __device__ constexpr bool operator==(const T& lhs, const T& rhs) \
+    {                                                                       \
+        if(rhs.is_nan() or rhs.is_inf() or lhs.is_nan() or lhs.is_inf())    \
+            return false;                                                   \
+        else if((rhs.is_zero() and lhs.is_zero()) or (lhs.data == rhs.data)) \
+            return true;                                                    \
+        return false;                                                       \
     }
 
 // NOLINTNEXTLINE
@@ -417,11 +416,10 @@ using fp8e5m2fnuz = float8<migraphx::fp8::f8_type::bf8, true>;
     MIGRAPHX_FP8_BINARY_OP(-, T, T)     \
     MIGRAPHX_FP8_BINARY_OP(/, T, T)     \
    MIGRAPHX_FP8_BINARY_OP(+, T, T)     \
-    MIGRAPHX_FP8_BINARY_OP(==, T, bool) \
     MIGRAPHX_FP8_BINARY_OP(>=, T, bool) \
     MIGRAPHX_FP8_BINARY_OP(<=, T, bool) \
     MIGRAPHX_FP8_BINARY_OP(!=, T, bool) \
-    MIGRAPHX_FP8_FABS(T)
+    MIGRAPHX_FP8_OTHER_OPS(T)
 
 MIGRAPHX_FP8_GEN_OP_OVERLOADS(fp8e5m2)
 MIGRAPHX_FP8_GEN_OP_OVERLOADS(fp8e5m2fnuz)
...@@ -447,7 +445,7 @@ class numeric_limits<fp8e4m3fnuz> ...@@ -447,7 +445,7 @@ class numeric_limits<fp8e4m3fnuz>
{ {
return fp8e4m3fnuz(0x7F, fp8e4m3fnuz::from_bits()); return fp8e4m3fnuz(0x7F, fp8e4m3fnuz::from_bits());
} }
// this is min value that is not DeNorm. DeNorm min is 0x01 // this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01
static constexpr __device__ fp8e4m3fnuz min() static constexpr __device__ fp8e4m3fnuz min()
{ {
return fp8e4m3fnuz(0x08, fp8e4m3fnuz::from_bits()); return fp8e4m3fnuz(0x08, fp8e4m3fnuz::from_bits());
...@@ -475,7 +473,7 @@ class numeric_limits<fp8e4m3fn> ...@@ -475,7 +473,7 @@ class numeric_limits<fp8e4m3fn>
} }
static constexpr __device__ fp8e4m3fn max() { return fp8e4m3fn(0x7E, fp8e4m3fn::from_bits()); } static constexpr __device__ fp8e4m3fn max() { return fp8e4m3fn(0x7E, fp8e4m3fn::from_bits()); }
// this is min value that is not DeNorm. DeNorm min is 0x01 // this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01
static constexpr __device__ fp8e4m3fn min() { return fp8e4m3fn(0x08, fp8e4m3fn::from_bits()); } static constexpr __device__ fp8e4m3fn min() { return fp8e4m3fn(0x08, fp8e4m3fn::from_bits()); }
static constexpr __device__ fp8e4m3fn lowest() static constexpr __device__ fp8e4m3fn lowest()
...@@ -503,8 +501,7 @@ class numeric_limits<fp8e5m2fnuz> ...@@ -503,8 +501,7 @@ class numeric_limits<fp8e5m2fnuz>
{ {
return fp8e5m2fnuz(0x7F, fp8e5m2fnuz::from_bits()); return fp8e5m2fnuz(0x7F, fp8e5m2fnuz::from_bits());
} }
// this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make // this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01.
// this distinction. For the floating points we would end up using lowest most of the times.
static constexpr __device__ fp8e5m2fnuz min() static constexpr __device__ fp8e5m2fnuz min()
{ {
return fp8e5m2fnuz(0x4, fp8e5m2fnuz::from_bits()); return fp8e5m2fnuz(0x4, fp8e5m2fnuz::from_bits());
...@@ -529,8 +526,7 @@ class numeric_limits<fp8e5m2> ...@@ -529,8 +526,7 @@ class numeric_limits<fp8e5m2>
} }
static constexpr __device__ fp8e5m2 max() { return fp8e5m2(0x7B, fp8e5m2::from_bits()); } static constexpr __device__ fp8e5m2 max() { return fp8e5m2(0x7B, fp8e5m2::from_bits()); }
// this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make // this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01.
// this distinction. For the floating points we would end up using lowest most of the times.
static constexpr __device__ fp8e5m2 min() { return fp8e5m2(0x4, fp8e5m2::from_bits()); } static constexpr __device__ fp8e5m2 min() { return fp8e5m2(0x4, fp8e5m2::from_bits()); }
static constexpr __device__ fp8e5m2 lowest() { return fp8e5m2(0xFB, fp8e5m2::from_bits()); } static constexpr __device__ fp8e5m2 lowest() { return fp8e5m2(0xFB, fp8e5m2::from_bits()); }
...@@ -539,24 +535,26 @@ class numeric_limits<fp8e5m2> ...@@ -539,24 +535,26 @@ class numeric_limits<fp8e5m2>
}; };
} // namespace fp8 } // namespace fp8
-// NOLINTNEXTLINE
-#define MIGRAPHX_FP8_MIN_MAX(T)                  \
-    template <>                                  \
-    constexpr T numeric_max<T, void>()           \
-    {                                            \
-        return fp8::numeric_limits<T>::max();    \
-    }                                            \
-    template <>                                  \
-    constexpr T numeric_lowest<T>()              \
-    {                                            \
-        return fp8::numeric_limits<T>::lowest(); \
-    }
-
-MIGRAPHX_FP8_MIN_MAX(fp8::fp8e4m3fnuz);
-MIGRAPHX_FP8_MIN_MAX(fp8::fp8e5m2fnuz);
-MIGRAPHX_FP8_MIN_MAX(fp8::fp8e4m3fn);
-MIGRAPHX_FP8_MIN_MAX(fp8::fp8e5m2);
+template <class T,
+          MIGRAPHX_REQUIRES(is_same<T, fp8::fp8e4m3fnuz>{} or is_same<T, fp8::fp8e5m2fnuz>{} or
+                            is_same<T, fp8::fp8e4m3fn>{} or is_same<T, fp8::fp8e5m2>{})>
+constexpr T numeric_max(migraphx::fp8::f8_type unused = migraphx::fp8::f8_type::fp8)
+{
+    // the unused parameter makes this a distinct overload from the
+    // numeric_max defined in type_traits.hpp
+    (void)(unused);
+    return fp8::numeric_limits<T>::max();
+}
+
+template <class T,
+          MIGRAPHX_REQUIRES(is_same<T, fp8::fp8e4m3fnuz>{} or is_same<T, fp8::fp8e5m2fnuz>{} or
+                            is_same<T, fp8::fp8e4m3fn>{} or is_same<T, fp8::fp8e5m2>{})>
+constexpr T numeric_lowest(migraphx::fp8::f8_type unused = migraphx::fp8::f8_type::fp8)
+{
+    // the unused parameter makes this a distinct overload from the
+    // numeric_lowest defined in type_traits.hpp
+    (void)(unused);
+    return fp8::numeric_limits<T>::lowest();
+}
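As a sketch of the disambiguation trick used above: the defaulted f8_type parameter gives the fp8 overloads a signature distinct from the numeric_max/numeric_lowest in type_traits.hpp, while MIGRAPHX_REQUIRES routes fp8 types here. The standalone analogue below shows the same mechanism; fake_fp8, the placeholder values, and std::enable_if in place of MIGRAPHX_REQUIRES are all hypothetical:

#include <cstdio>
#include <type_traits>

enum class f8_type { fp8 };

struct fake_fp8 { int v; };

// generic overload (stands in for the one in type_traits.hpp)
template <class T, std::enable_if_t<std::is_arithmetic<T>{}, int> = 0>
constexpr T numeric_max()
{
    return T{127}; // placeholder value
}

// fp8 overload: the extra defaulted parameter gives it a distinct signature
template <class T, std::enable_if_t<std::is_same<T, fake_fp8>{}, int> = 0>
constexpr T numeric_max(f8_type unused = f8_type::fp8)
{
    (void)unused;
    return fake_fp8{240};
}

int main()
{
    // each call has exactly one viable overload, so no ambiguity
    std::printf("%d %d\n", numeric_max<int>(), numeric_max<fake_fp8>().v); // 127 240
}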
} // namespace migraphx
// =================================================================================================
#if defined(__clang__)
...
@@ -53,35 +53,35 @@ __device__ void gathernd(const T& data_t, const U& indices_t, const V& output_t,
auto indices_shape_lens = indices_shape.lens;
auto data_shape_lens    = data_shape.lens;
auto num_slice_dims     = indices_shape_lens.back();
-std::size_t num_slices =
+size_t num_slices =
    accumulate(indices_shape_lens.begin(), indices_shape_lens.end() - 1, 1, op::product{});
-std::size_t slice_size = accumulate(data_shape_lens.begin() + num_slice_dims + batch_dims,
+size_t slice_size = accumulate(data_shape_lens.begin() + num_slice_dims + batch_dims,
                                    data_shape_lens.end(),
                                    1,
                                    op::product{});
-const std::size_t num_batches =
+const size_t num_batches =
    accumulate(data_shape_lens.begin(), data_shape_lens.begin() + batch_dims, 1, op::product{});
-const std::size_t data_batch_stride =
+const size_t data_batch_stride =
    accumulate(data_shape_lens.begin() + batch_dims, data_shape_lens.end(), 1, op::product{});
const auto num_slices_per_batch = num_slices / num_batches;
ind.global_stride(output_shape.elements(), [&](auto i) {
    const auto* indices_ptr = indices_t.data();
-   const std::size_t j         = i / slice_size;
-   const std::size_t batch_idx = j / num_slices_per_batch;
+   const size_t j         = i / slice_size;
+   const size_t batch_idx = j / num_slices_per_batch;
    auto* slice_indices = indices_ptr + (j * num_slice_dims);
-   std::size_t relative_slice_offset = 0;
-   for(std::size_t idx = 0; idx < num_slice_dims; ++idx)
+   size_t relative_slice_offset = 0;
+   for(size_t idx = 0; idx < num_slice_dims; ++idx)
    {
        int64_t index = slice_indices[idx];
-       const std::size_t input_dim_idx = batch_dims + idx;
+       const size_t input_dim_idx = batch_dims + idx;
        const auto input_dim = data_shape_lens[input_dim_idx];
        MIGRAPHX_ASSERT(index >= -static_cast<int64_t>(input_dim) and
                        index < static_cast<int64_t>(input_dim));
        if(index < 0)
            index += input_dim;
-       std::size_t size_from_slice_dims =
+       size_t size_from_slice_dims =
            accumulate(data_shape_lens.begin() + batch_dims + idx + 1,
                       data_shape_lens.begin() + batch_dims + num_slice_dims,
                       slice_size,
...
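To see the offset arithmetic above on concrete numbers, here is a host-side sketch with hypothetical shapes (data lens {2, 3, 4}, indices lens {2, 2, 1}, batch_dims = 1), using std::accumulate and std::multiplies in place of the kernel's accumulate and op::product:

#include <cstddef>
#include <cstdio>
#include <functional>
#include <numeric>
#include <vector>

int main()
{
    // Hypothetical shapes: data {2, 3, 4}, indices {2, 2, 1}, batch_dims = 1.
    const std::vector<std::size_t> data_lens{2, 3, 4};
    const std::vector<std::size_t> indices_lens{2, 2, 1};
    const std::size_t batch_dims = 1;

    // Same products as the kernel computes above.
    const std::size_t num_slice_dims = indices_lens.back();
    const std::size_t num_slices     = std::accumulate(
        indices_lens.begin(), indices_lens.end() - 1, std::size_t{1}, std::multiplies<>{});
    const std::size_t slice_size = std::accumulate(data_lens.begin() + num_slice_dims + batch_dims,
                                                   data_lens.end(),
                                                   std::size_t{1},
                                                   std::multiplies<>{});
    const std::size_t num_batches = std::accumulate(
        data_lens.begin(), data_lens.begin() + batch_dims, std::size_t{1}, std::multiplies<>{});
    const std::size_t data_batch_stride = std::accumulate(
        data_lens.begin() + batch_dims, data_lens.end(), std::size_t{1}, std::multiplies<>{});

    // Prints: num_slices=4 slice_size=4 num_batches=2 data_batch_stride=12
    std::printf("num_slices=%zu slice_size=%zu num_batches=%zu data_batch_stride=%zu\n",
                num_slices, slice_size, num_batches, data_batch_stride);
}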
@@ -54,12 +54,11 @@ __device__ void generic_binary_layernorm(
using value_type     = typename Input1::type;
using vec_value_type = vec_type<value_type>;
constexpr auto relements   = r.template elements<Input1>();
-constexpr auto relements_r = static_cast<vec_value_type>(1.0 / relements);
+constexpr auto relements_r = vec_value_type{1.0 / relements};
auto relements_rsqrt = sqrt(relements_r);
auto means = r.reduce(op::sum{},
-                     make_array<vec_value_type>(static_cast<vec_value_type>(0),
-                                                static_cast<vec_value_type>(0)),
+                     make_array<vec_value_type>(vec_value_type{0}, vec_value_type{0}),
                      [&](auto x) {
                          auto x_out = x * relements_r;
                          // dividing x by sqrt(relements) before squaring allows computing
@@ -71,7 +70,7 @@ __device__ void generic_binary_layernorm(
auto mean_x   = means[0];
auto mean_x2  = means[1];
auto variance = mean_x2 - (mean_x * mean_x);
-value_type eps_val = static_cast<value_type>(eps);
+value_type eps_val = implicit_conversion(eps);
r.inner([&](auto& y, auto x, auto... xs) {
    auto m = x - mean_x;
...
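The reduction above accumulates mean(x) and the mean of (x/sqrt(n))^2 in a single pass and then applies Var(x) = E[x^2] - E[x]^2; squaring after the division by sqrt(n) keeps the accumulator small, which the truncated comment above alludes to. A scalar sketch of the same identity on hypothetical data:

#include <cmath>
#include <cstdio>

int main()
{
    // Hypothetical reduction over n elements.
    const double x[] = {1.0, 2.0, 3.0, 4.0};
    const int n      = 4;
    const double relements_r     = 1.0 / n;
    const double relements_rsqrt = std::sqrt(relements_r);

    double mean_x = 0, mean_x2 = 0;
    for(double v : x)
    {
        mean_x += v * relements_r;      // running E[x]
        double s = v * relements_rsqrt; // divide by sqrt(n) before squaring
        mean_x2 += s * s;               // accumulates E[x^2] directly
    }
    const double variance = mean_x2 - mean_x * mean_x;
    std::printf("mean=%g variance=%g\n", mean_x, variance); // mean=2.5 variance=1.25
}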
@@ -290,7 +290,7 @@ MIGRAPHX_DEVICE_MATH_VEC(where)
template <class T, class U>
constexpr auto convert(U v)
{
-   return vec_transform(v)([](auto x) { return static_cast<T>(x); });
+   return vec_transform(v)([](auto x) -> T { return static_cast<T>(x); });
}
} // namespace migraphx
...
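The change above pins the lambda's return type to T so each lane of the input is cast explicitly. A standalone analogue over std::array (illustrative only; the real vec_transform is elided, and this loop stands in for its per-lane application):

#include <array>
#include <cstdio>

// Apply a per-element static_cast with an explicit -> T return type.
template <class T, class U, std::size_t N>
constexpr std::array<T, N> convert(const std::array<U, N>& v)
{
    std::array<T, N> out{};
    for(std::size_t i = 0; i < N; ++i)
        out[i] = [](auto x) -> T { return static_cast<T>(x); }(v[i]);
    return out;
}

int main()
{
    constexpr std::array<float, 3> x{1.5f, 2.5f, 3.5f};
    const auto y = convert<int>(x); // {1, 2, 3}
    std::printf("%d %d %d\n", y[0], y[1], y[2]);
}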
@@ -118,7 +118,7 @@ struct highest
template <class T>
constexpr operator T() const
{
-   return numeric_max<vec_type<T>, void>();
+   return numeric_max<vec_type<T>>();
}
};
} // namespace migraphx
...
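highest converts implicitly to whichever type it initializes, and after this change it resolves through the constrained numeric_max overload rather than the old <T, void> specialization form. A standalone analogue, using std::numeric_limits in place of migraphx::numeric_max:

#include <cstdio>
#include <limits>

// Standalone analogue of the highest helper (illustrative only).
struct highest
{
    template <class T>
    constexpr operator T() const
    {
        return std::numeric_limits<T>::max();
    }
};

int main()
{
    float f = highest{}; // operator T instantiated with T = float
    int i   = highest{}; // operator T instantiated with T = int
    std::printf("%g %d\n", f, i); // 3.40282e+38 2147483647
}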
@@ -28,6 +28,7 @@
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/ranges.hpp>
+#include <migraphx/kernels/vec.hpp>
namespace migraphx {
@@ -39,7 +40,6 @@ __device__ void pad(const index& idx,
                    const PadVal& pad_val)
{
    auto output_shape = output.get_shape();
-   using otype       = typename Output::type;
    idx.global_stride(output_shape.elements(), [&](auto i) {
        // 1. get current multi-index for output
        // 2. get the size of the input to determine input boundaries
@@ -54,9 +54,9 @@ __device__ void pad(const index& idx,
        if(any_of(range_multi.begin(), range_multi.end(), [&](auto j) {
               return multi[j] < offsets[j] or input_idx[j] >= input_bounds[j];
           }))
-           output[multi] = otype(pad_val);
+           output[multi] = implicit_conversion(pad_val);
        else
-           output[multi] = otype(input[input_idx]);
+           output[multi] = implicit_conversion(input[input_idx]);
    });
}
...
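The guard above classifies an output coordinate as padding when it lies before the offsets or past the input bounds in any dimension; otherwise the input element is copied through. A one-dimensional sketch of the same test with hypothetical sizes (input of 4 elements, pad of 2 on each side):

#include <cstdio>

int main()
{
    // Hypothetical 1-D pad: input length 4, pad 2 on each side, pad value 9.
    const int input[4] = {1, 2, 3, 4};
    const int offset   = 2; // leading pad amount
    const int in_len   = 4;
    const int out_len  = 8;
    const int pad_val  = 9;

    for(int multi = 0; multi < out_len; ++multi)
    {
        const int input_idx = multi - offset; // same shift the kernel applies
        // Same condition as the any_of() above, in one dimension:
        const bool is_pad = multi < offset or input_idx >= in_len;
        std::printf("%d ", is_pad ? pad_val : input[input_idx]);
    }
    std::printf("\n"); // prints: 9 9 1 2 3 4 9 9
}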