Commit fae5170b authored by Shucai Xiao

Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into ref_op_name

parents a3906038 2a79a9ff
@@ -35,7 +35,7 @@ struct argmin
     shape normalize_compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(1).standard();
+        check_shapes{inputs, *this}.has(1);
         auto lens = inputs[0].lens();
         lens[axis] = 1;
...
@@ -7,7 +7,7 @@
 #include <migraphx/stringutils.hpp>
 #include <migraphx/streamutils.hpp>
 #include <migraphx/literal.hpp>
-#include <migraphx/shape_for_each.hpp>
+#include <migraphx/par_for.hpp>
 #include <migraphx/config.hpp>
 #include <cmath>
 #include <utility>
@@ -21,25 +21,26 @@ struct clip
 {
     std::string name() const { return "clip"; }
+    value attributes() const
+    {
+        return {{"pointwise", true},
+                {"point_op", "${function:min}(${function:max}(${1}, ${0}), ${2})"}};
+    }
+
     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(3).same_type();
+        check_shapes{inputs, *this}.has(3).same_type().same_dims();
         return inputs.front();
     }
     argument compute(const shape& output_shape, std::vector<argument> args) const
     {
         argument result{output_shape};
-        visit_all(result, args[0], args[1], args[2])(
-            [&](auto output, auto input, auto min_val, auto max_val) {
-                auto max = max_val.front();
-                auto min = min_val.front();
-                std::transform(input.begin(), input.end(), output.begin(), [max, min](auto x) {
-                    using type = decltype(x);
-                    return std::min(std::max(type(min), x), type(max));
-                });
-            });
+        visit_all(result, args[0], args[1], args[2])([&](auto output, auto x, auto min, auto max) {
+            par_for(output_shape.elements(),
+                    [&](auto i) { output[i] = std::min(std::max(min[i], x[i]), max[i]); });
+        });
         return result;
     }
 };
...
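The clip rewrite above replaces scalar bounds (taken via .front()) with full min/max tensors indexed per element, which is why compute_shape now also requires same_dims(). Below is a minimal serial sketch of the new element-wise semantics, assuming all three inputs are already broadcast to the output shape; plain loops stand in for MIGraphX's visit_all/par_for:

    #include <algorithm>
    #include <cstddef>
    #include <iostream>
    #include <vector>

    // Element-wise clip with per-element bounds, mirroring the new operator:
    // out[i] = min(max(min_vals[i], x[i]), max_vals[i]).
    // Serial stand-in for par_for; the real operator runs this in parallel.
    std::vector<float> clip(const std::vector<float>& x,
                            const std::vector<float>& min_vals,
                            const std::vector<float>& max_vals)
    {
        std::vector<float> out(x.size());
        for(std::size_t i = 0; i < x.size(); i++)
            out[i] = std::min(std::max(min_vals[i], x[i]), max_vals[i]);
        return out;
    }

    int main()
    {
        // Each element gets its own bounds, which the old scalar .front() could not express.
        auto out = clip({-2.f, 0.5f, 3.f}, {0.f, 0.f, 0.f}, {1.f, 1.f, 2.f});
        for(float v : out)
            std::cout << v << " "; // 0 0.5 2
        std::cout << "\n";
    }

The ONNX parsers later in this diff rely on exactly this shape contract: they multibroadcast scalar literals to the input shape before calling clip.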
@@ -32,6 +32,11 @@ struct convert : unary<convert>
         return {target_type, inputs.at(0).lens(), inputs.at(0).strides()};
     }
+
+    std::string point_op() const
+    {
+        return "${function:convert}<" + shape::cpp_type(target_type) + ">(${0})";
+    }
     auto apply() const
     {
         auto type = target_type;
...
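The new point_op() hook gives the pointwise fuser a template string to splice into generated kernels; the actual substitution is internal to MIGraphX's code generator, so the sketch below only illustrates the intended result. Assuming target_type maps to int32_t, the substituted expression behaves like a hypothetical convert<T> helper:

    #include <cstdint>
    #include <iostream>

    // Hypothetical stand-in for the device-side conversion that the generated
    // string "${function:convert}<T>(${0})" ultimately names: a plain value
    // conversion to the operator's target_type.
    template <class To, class From>
    To convert(From x)
    {
        return static_cast<To>(x);
    }

    int main()
    {
        // For target_type == int32, the point_op string would expand
        // (conceptually) to convert<int32_t>(input_element).
        float x = 3.7f;
        std::cout << convert<std::int32_t>(x) << "\n"; // 3
    }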
@@ -38,7 +38,7 @@ struct gather
     shape normalize_compute_shape(std::vector<shape> inputs) const
    {
-        check_shapes{inputs, *this}.has(2).standard();
+        check_shapes{inputs, *this}.has(2);
         auto lens = inputs[0].lens();
         auto type = inputs[0].type();
         lens.erase(lens.begin() + axis);
...
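argmin (above) and gather both drop the .standard() requirement, and the GPU argmax/argmin shape checks later in this diff do the same: these ops now accept inputs whose strides are not the packed row-major layout. A standalone sketch of what "standard" means in terms of strides; is_standard here is a hypothetical helper for illustration, not MIGraphX's check_shapes internals:

    #include <cstddef>
    #include <iostream>
    #include <vector>

    // A shape is "standard" (packed, row-major) when its strides are the
    // suffix products of its lengths.
    bool is_standard(const std::vector<std::size_t>& lens,
                     const std::vector<std::size_t>& strides)
    {
        std::size_t expected = 1;
        for(std::size_t i = lens.size(); i > 0; i--)
        {
            if(strides[i - 1] != expected)
                return false;
            expected *= lens[i - 1];
        }
        return true;
    }

    int main()
    {
        // A packed 2x3 tensor has strides {3, 1}...
        std::cout << is_standard({2, 3}, {3, 1}) << "\n"; // 1
        // ...while its transpose is a view with strides {1, 3}, which the old
        // .standard() check would have rejected as an input.
        std::cout << is_standard({3, 2}, {1, 3}) << "\n"; // 0
    }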
File mode changed from 100755 to 100644
@@ -179,6 +179,7 @@ instruction_ref module::insert_instruction(instruction_ref ins,
                                            const operation& op,
                                            std::vector<instruction_ref> args)
 {
+    assert(has_instruction(ins) or is_end(ins, this->end()));
     assert(not starts_with(op.name(), "@"));
     shape r = compute_shape(op, args);
     auto result = impl->insert(ins, {op, r, std::move(args)});
@@ -200,6 +201,7 @@ instruction_ref module::insert_instruction(instruction_ref ins,
                                            std::vector<instruction_ref> args,
                                            std::vector<module_ref> module_args)
 {
+    assert(has_instruction(ins) or is_end(ins, this->end()));
     assert(not starts_with(op.name(), "@"));
     auto out_shape = compute_shape(op, args, module_args);
     auto result = impl->insert(ins, {op, out_shape, std::move(args), std::move(module_args)});
@@ -212,6 +214,7 @@ instruction_ref module::replace_instruction(instruction_ref ins,
                                             const operation& op,
                                             std::vector<instruction_ref> args) MIGRAPHX_TIDY_CONST
 {
+    assert(has_instruction(ins));
     assert(not starts_with(op.name(), "@"));
     shape r = compute_shape(op, args);
@@ -225,6 +228,7 @@ instruction_ref module::replace_instruction(instruction_ref ins,
                                             std::vector<instruction_ref> args,
                                             std::vector<module_ref> module_args) MIGRAPHX_TIDY_CONST
 {
+    assert(has_instruction(ins));
     assert(not starts_with(op.name(), "@"));
     auto out_shape = compute_shape(op, args, module_args);
     instruction::replace(ins, op, out_shape, std::move(args), std::move(module_args));
@@ -291,6 +295,8 @@ instruction_ref module::remove_instructions(instruction_ref first, instruction_ref last)
 instruction_ref module::move_instruction(instruction_ref src, instruction_ref dst)
 {
+    assert(has_instruction(src));
+    assert(has_instruction(dst) or is_end(dst, this->end()));
     impl->instructions.splice(dst, impl->instructions, src);
     return src;
 }
...
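The new assertions make the iterator contract explicit: an insertion or move target must be an instruction of this module or the module's end iterator, and a replaced or moved source must already belong to the module. A toy model of the same contract on a std::list, with hypothetical has_node/is_end helpers standing in for has_instruction and is_end:

    #include <cassert>
    #include <iostream>
    #include <iterator>
    #include <list>

    // Linear scan, mirroring the spirit (not the implementation) of
    // module::has_instruction.
    bool has_node(const std::list<int>& l, std::list<int>::const_iterator it)
    {
        for(auto pos = l.begin(); pos != l.end(); ++pos)
            if(pos == it)
                return true;
        return false;
    }

    bool is_end(std::list<int>::const_iterator it, std::list<int>::const_iterator end)
    {
        return it == end;
    }

    int main()
    {
        std::list<int> mod{1, 2, 3};
        auto ins = std::next(mod.begin()); // a valid "instruction"
        // Insertion is allowed at an existing node or at end(), just like
        // assert(has_instruction(ins) or is_end(ins, this->end())).
        assert(has_node(mod, ins) or is_end(ins, mod.end()));
        mod.insert(ins, 42);
        for(int v : mod)
            std::cout << v << " "; // 1 42 2 3
        std::cout << "\n";
    }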
+#include <migraphx/onnx/op_parser.hpp>
+#include <migraphx/onnx/checks.hpp>
+#include <migraphx/instruction.hpp>
+#include <migraphx/make_op.hpp>
+
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {
+namespace onnx {
+
+struct parse_greaterorequal : op_parser<parse_greaterorequal>
+{
+    std::vector<op_desc> operators() const { return {{"GreaterOrEqual"}}; }
+
+    instruction_ref parse(const op_desc& /*opd*/,
+                          const onnx_parser& /*parser*/,
+                          const onnx_parser::node_info& info,
+                          std::vector<instruction_ref> args) const
+    {
+        auto in_res = info.add_broadcastable_binary_op("less", args[0], args[1]);
+        if(in_res->get_shape().type() != shape::bool_type)
+        {
+            in_res = info.add_instruction(make_op("convert", {{"target_type", shape::bool_type}}),
+                                          in_res);
+        }
+        return info.add_instruction(make_op("not"), in_res);
+    }
+};
+
+} // namespace onnx
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
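The new GreaterOrEqual parser lowers the op as not(less(a, b)), i.e. the identity a >= b <=> !(a < b) on ordered values. A quick standalone check:

    #include <iostream>

    int main()
    {
        // a >= b is equivalent to !(a < b) for ordered values, which is what
        // the less-then-not lowering relies on.
        float pairs[][2] = {{1.f, 2.f}, {2.f, 2.f}, {3.f, 2.f}};
        for(auto& p : pairs)
            std::cout << (p[0] >= p[1]) << " == " << !(p[0] < p[1]) << "\n";
    }

For IEEE floats the two sides differ only when an operand is NaN (NaN >= x is false while !(NaN < x) is true), the usual caveat of this lowering.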
+#include <migraphx/onnx/op_parser.hpp>
+#include <migraphx/onnx/checks.hpp>
+#include <migraphx/ranges.hpp>
+#include <migraphx/instruction.hpp>
+#include <migraphx/make_op.hpp>
+
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {
+namespace onnx {
+
+struct parse_hardsigmoid : op_parser<parse_hardsigmoid>
+{
+    std::vector<op_desc> operators() const { return {{"HardSigmoid"}, {"HardSwish"}}; }
+
+    instruction_ref parse(const op_desc& opd,
+                          const onnx_parser& /*parser*/,
+                          const onnx_parser::node_info& info,
+                          std::vector<instruction_ref> args) const
+    {
+        float alpha = 0.2;
+        float beta  = 0.5;
+        if(opd.onnx_name == "HardSwish")
+        {
+            alpha = 1.0 / 6.0;
+        }
+        else
+        {
+            if(contains(info.attributes, "alpha"))
+                alpha = info.attributes.at("alpha").f();
+            if(contains(info.attributes, "beta"))
+                beta = info.attributes.at("beta").f();
+        }
+
+        auto input_lens = args[0]->get_shape().lens();
+        auto input_type = args[0]->get_shape().type();
+
+        auto mb_alpha = info.add_instruction(
+            migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
+            info.add_literal(migraphx::literal{migraphx::shape{input_type}, {alpha}}));
+        auto mb_beta = info.add_instruction(
+            migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
+            info.add_literal(migraphx::literal{migraphx::shape{input_type}, {beta}}));
+        auto mb_zero = info.add_instruction(
+            migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
+            info.add_literal(migraphx::literal{migraphx::shape{input_type}, {0}}));
+        auto mb_one = info.add_instruction(
+            migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
+            info.add_literal(migraphx::literal{migraphx::shape{input_type}, {1}}));
+
+        auto mul         = info.add_instruction(migraphx::make_op("mul"), mb_alpha, args[0]);
+        auto add         = info.add_instruction(migraphx::make_op("add"), mb_beta, mul);
+        auto hardsigmoid = info.add_instruction(migraphx::make_op("clip"), add, mb_zero, mb_one);
+
+        if(opd.onnx_name == "HardSwish")
+            return info.add_instruction(migraphx::make_op("mul"), args[0], hardsigmoid);
+
+        return hardsigmoid;
+    }
+};
+
+} // namespace onnx
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
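The parser above builds HardSigmoid(x) = clip(alpha * x + beta, 0, 1) out of broadcast literals, with HardSwish fixing alpha = 1/6, beta = 0.5 and multiplying the result by the input. A scalar reference of the same formulas:

    #include <algorithm>
    #include <iostream>

    // Scalar reference for the graphs built by the parser.
    float hardsigmoid(float x, float alpha = 0.2f, float beta = 0.5f)
    {
        return std::min(std::max(alpha * x + beta, 0.0f), 1.0f);
    }

    // HardSwish fixes alpha = 1/6, beta = 0.5 and multiplies by the input.
    float hardswish(float x) { return x * hardsigmoid(x, 1.0f / 6.0f, 0.5f); }

    int main()
    {
        for(float x : {-4.0f, 0.0f, 4.0f})
            std::cout << hardsigmoid(x) << " " << hardswish(x) << "\n";
        // hardswish saturates to 0 for very negative x and to x for large x
    }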
+#include <migraphx/onnx/op_parser.hpp>
+#include <migraphx/onnx/checks.hpp>
+#include <migraphx/instruction.hpp>
+#include <migraphx/make_op.hpp>
+
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {
+namespace onnx {
+
+struct parse_mean : op_parser<parse_mean>
+{
+    std::vector<op_desc> operators() const { return {{"Mean"}}; }
+
+    /// Calculates the element-wise mean of n>=1 input tensors
+    instruction_ref parse(const op_desc& /*opd*/,
+                          const onnx_parser& /*parser*/,
+                          const onnx_parser::node_info& info,
+                          std::vector<instruction_ref> args) const
+    {
+        auto num_data = args.size();
+        if(num_data == 1)
+            return args[0];
+
+        auto divisor = info.add_literal(
+            migraphx::literal{migraphx::shape{args[0]->get_shape().type()}, {num_data}});
+
+        return std::accumulate(args.begin(), args.end(), args[0], [&](auto& mean, auto& data_i) {
+            // Pre-divide each tensor element-wise by n to reduce risk of overflow during summation
+            data_i = info.add_broadcastable_binary_op("div", data_i, divisor);
+            if(data_i != args[0])
+                return info.add_broadcastable_binary_op("add", mean, data_i);
+            return data_i;
+        });
+    }
+};
+
+} // namespace onnx
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
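Pre-dividing every input by n before summing keeps each partial sum near the magnitude of the final mean, at the cost of n divisions. The standalone sketch below shows the overflow the in-code comment refers to, using float pushed to its maximum; the parser applies the same idea with graph ops, where it matters most for narrow types such as fp16:

    #include <iostream>
    #include <limits>
    #include <vector>

    int main()
    {
        // Summing first can overflow even when the mean itself is representable.
        float big = std::numeric_limits<float>::max();
        std::vector<float> xs{big, big, big};

        float sum = 0;
        for(float x : xs)
            sum += x; // overflows to inf
        std::cout << sum / xs.size() << "\n"; // inf

        // Pre-dividing each term by n, as the Mean parser does, stays finite.
        float mean = 0;
        for(float x : xs)
            mean += x / xs.size(); // each term is big/3
        std::cout << mean << "\n"; // ~3.4e38 (the correct mean)
    }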
@@ -27,11 +27,6 @@ struct parse_multinomial : op_parser<parse_multinomial>
         if(contains(info.attributes, "sample_size"))
             sample_size = info.attributes.at("sample_size").i();
-        float seed = static_cast<float>(
-            std::chrono::high_resolution_clock::now().time_since_epoch().count());
-        if(contains(info.attributes, "seed"))
-            seed = info.attributes.at("seed").f();
         // Subtract the per-batch maximum log-probability, making the per-batch max 0
         auto maxes =
             info.add_instruction(migraphx::make_op("reduce_max", {{"axes", {1}}}), args[0]);
@@ -46,7 +41,10 @@ struct parse_multinomial : op_parser<parse_multinomial>
             migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}), cdf);
         // Pre-compute random distribution
-        std::mt19937 gen(seed);
+        std::mt19937 gen(std::chrono::high_resolution_clock::now().time_since_epoch().count());
+        if(contains(info.attributes, "seed"))
+            gen.seed(info.attributes.at("seed").f());
         std::uniform_real_distribution<> dis(0.0, 1.0);
         size_t batch_size = args[0]->get_shape().lens().front();
         migraphx::shape dist_shape{migraphx::shape::float_type, {batch_size, sample_size}};
...
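This change, repeated in the two random-op parsers below, seeds std::mt19937 directly instead of routing the clock tick count through a float first. A float represents integers exactly only up to 2^24, so the old path collapsed many distinct tick counts onto one seed value. A small sketch of both the precision loss and the new seeding pattern:

    #include <chrono>
    #include <cstdint>
    #include <iostream>
    #include <random>

    int main()
    {
        // Floats hold integers exactly only up to 2^24, so nearby large tick
        // counts round to the same float, i.e. the same seed.
        std::uint32_t a = 100000001;
        std::uint32_t b = 100000003;
        std::cout << (static_cast<float>(a) == static_cast<float>(b)) << "\n"; // 1

        // New pattern: seed the engine directly from the clock, and reseed
        // only when an ONNX "seed" attribute is present.
        std::mt19937 gen(
            std::chrono::high_resolution_clock::now().time_since_epoch().count());
        bool has_seed_attr = false; // stand-in for contains(info.attributes, "seed")
        if(has_seed_attr)
            gen.seed(12345);
        std::cout << std::uniform_real_distribution<>(0.0, 1.0)(gen) << "\n";
    }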
@@ -42,11 +42,6 @@ struct parse_randomnormal_ops : op_parser<parse_randomnormal_ops>
         if(contains(info.attributes, "scale"))
             scale = info.attributes.at("scale").f();
-        float seed = static_cast<float>(
-            std::chrono::high_resolution_clock::now().time_since_epoch().count());
-        if(contains(info.attributes, "seed"))
-            seed = info.attributes.at("seed").f();
         shape out_shape;
         if(contains(info.attributes, "shape"))
         {
@@ -75,7 +70,10 @@ struct parse_randomnormal_ops : op_parser<parse_randomnormal_ops>
                            ": cannot deduce shape without shape attribute or argument.");
         }
-        std::mt19937 gen(seed);
+        std::mt19937 gen(std::chrono::high_resolution_clock::now().time_since_epoch().count());
+        if(contains(info.attributes, "seed"))
+            gen.seed(info.attributes.at("seed").f());
         std::normal_distribution<> d(mean, scale);
         std::vector<double> rand_vals(out_shape.elements());
         std::generate(rand_vals.begin(), rand_vals.end(), [&]() { return d(gen); });
...
@@ -42,11 +42,6 @@ struct parse_randomuniform_ops : op_parser<parse_randomuniform_ops>
         if(contains(info.attributes, "low"))
             low = info.attributes.at("low").f();
-        float seed = static_cast<float>(
-            std::chrono::high_resolution_clock::now().time_since_epoch().count());
-        if(contains(info.attributes, "seed"))
-            seed = info.attributes.at("seed").f();
         shape out_shape;
         if(contains(info.attributes, "shape"))
         {
@@ -75,7 +70,10 @@ struct parse_randomuniform_ops : op_parser<parse_randomuniform_ops>
                            ": cannot deduce shape without shape attribute or argument.");
         }
-        std::mt19937 gen(seed);
+        std::mt19937 gen(std::chrono::high_resolution_clock::now().time_since_epoch().count());
+        if(contains(info.attributes, "seed"))
+            gen.seed(info.attributes.at("seed").f());
         std::uniform_real_distribution<> d(high, low);
         std::vector<double> rand_vals(out_shape.elements());
         std::generate(rand_vals.begin(), rand_vals.end(), [&]() { return d(gen); });
...
@@ -163,9 +163,9 @@ static std::string get_nearest_mode(const onnx_parser::attribute_map& attr)
 struct parse_resize : op_parser<parse_resize>
 {
-    std::vector<op_desc> operators() const { return {{"Resize"}}; }
+    std::vector<op_desc> operators() const { return {{"Resize"}, {"Upsample"}}; }
-    instruction_ref parse(const op_desc& /*opd*/,
+    instruction_ref parse(const op_desc& opd,
                           const onnx_parser& /*parser*/,
                           onnx_parser::node_info info,
                           std::vector<instruction_ref> args) const
@@ -183,7 +183,7 @@ struct parse_resize
         if(contains(info.attributes, "exclude_outside") and
            info.attributes.at("exclude_outside").i() == 1)
         {
-            MIGRAPHX_THROW("PARSE_RESIZE: exclude_outside 1 is not supported!");
+            MIGRAPHX_THROW("PARSE_" + opd.op_name + ": exclude_outside 1 is not supported!");
         }
         // input data shape info
@@ -215,12 +215,14 @@ struct parse_resize
         if(type == shape::int64_type)
         {
             auto arg_out_s = arg->eval();
-            check_arg_empty(arg_out_s, "PARSE_RESIZE: dynamic output size is not supported!");
+            check_arg_empty(arg_out_s,
+                            "PARSE_" + opd.op_name + ": dynamic output size is not supported!");
             arg_out_s.visit([&](auto ol) { out_lens.assign(ol.begin(), ol.end()); });
             if(out_lens.size() != in_lens.size())
             {
-                MIGRAPHX_THROW("PARSE_RESIZE: specified output size does not match input size");
+                MIGRAPHX_THROW("PARSE_" + opd.op_name +
+                               ": specified output size does not match input size");
             }
             // compute the scale
@@ -239,12 +241,14 @@ struct parse_resize
         {
             auto arg_scale = arg->eval();
-            check_arg_empty(arg_scale,
-                            "PARSE_RESIZE: dynamic input scale is not supported!");
+            check_arg_empty(arg_scale,
+                            "PARSE_" + opd.op_name +
+                                ": dynamic input scale is not supported!");
             arg_scale.visit([&](auto v) { vec_scale.assign(v.begin(), v.end()); });
             if(in_lens.size() != vec_scale.size())
             {
-                MIGRAPHX_THROW("PARSE_RESIZE: ranks of input and scale are different!");
+                MIGRAPHX_THROW("PARSE_" + opd.op_name +
+                               ": ranks of input and scale are different!");
             }
             std::transform(in_lens.begin(),
...
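With Upsample folded into parse_resize, both ops share the scales path and the error strings are parameterized on opd.op_name. The shared output-length computation is a per-dimension transform; a standalone sketch with made-up NCHW lengths and scales:

    #include <algorithm>
    #include <cstddef>
    #include <iostream>
    #include <vector>

    int main()
    {
        // The scales path shared by Resize and Upsample: out_len = in_len * scale,
        // truncated to an integer, one entry per dimension.
        std::vector<std::size_t> in_lens{1, 3, 4, 4};
        std::vector<float> vec_scale{1.0f, 1.0f, 2.0f, 2.0f};

        std::vector<std::size_t> out_lens(in_lens.size());
        std::transform(in_lens.begin(),
                       in_lens.end(),
                       vec_scale.begin(),
                       out_lens.begin(),
                       [](auto len, auto scale) { return static_cast<std::size_t>(len * scale); });

        for(auto l : out_lens)
            std::cout << l << " "; // 1 3 8 8
        std::cout << "\n";
    }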
+#include <migraphx/onnx/op_parser.hpp>
+#include <migraphx/onnx/checks.hpp>
+#include <migraphx/instruction.hpp>
+#include <migraphx/make_op.hpp>
+
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {
+namespace onnx {
+
+struct parse_softplus : op_parser<parse_softplus>
+{
+    std::vector<op_desc> operators() const { return {{"Softplus"}}; }
+
+    instruction_ref parse(const op_desc& /*opd*/,
+                          const onnx_parser& /*parser*/,
+                          const onnx_parser::node_info& info,
+                          std::vector<instruction_ref> args) const
+    {
+        // Apply pointwise formula: y = ln(exp(x) + 1)
+        auto mb_ones = info.add_instruction(
+            migraphx::make_op("multibroadcast", {{"out_lens", args[0]->get_shape().lens()}}),
+            info.add_literal(migraphx::literal{migraphx::shape{args[0]->get_shape().type()}, {1}}));
+        auto exp = info.add_instruction(migraphx::make_op("exp"), args[0]);
+        auto add = info.add_instruction(migraphx::make_op("add"), exp, mb_ones);
+        return info.add_instruction(migraphx::make_op("log"), add);
+    }
+};
+
+} // namespace onnx
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
+#include <migraphx/onnx/op_parser.hpp>
+#include <migraphx/onnx/checks.hpp>
+#include <migraphx/instruction.hpp>
+#include <migraphx/make_op.hpp>
+
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {
+namespace onnx {
+
+struct parse_softsign : op_parser<parse_softsign>
+{
+    std::vector<op_desc> operators() const { return {{"Softsign"}}; }
+
+    instruction_ref parse(const op_desc& /*opd*/,
+                          const onnx_parser& /*parser*/,
+                          const onnx_parser::node_info& info,
+                          std::vector<instruction_ref> args) const
+    {
+        // Apply pointwise formula: y = x / (1 + |x|)
+        auto mb_ones = info.add_instruction(
+            migraphx::make_op("multibroadcast", {{"out_lens", args[0]->get_shape().lens()}}),
+            info.add_literal(migraphx::literal{migraphx::shape{args[0]->get_shape().type()}, {1}}));
+        auto abs = info.add_instruction(migraphx::make_op("abs"), args[0]);
+        auto add = info.add_instruction(migraphx::make_op("add"), abs, mb_ones);
+        return info.add_instruction(migraphx::make_op("div"), args[0], add);
+    }
+};
+
+} // namespace onnx
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
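Softplus and Softsign (above) both expand to small pointwise graphs: y = ln(exp(x) + 1) and y = x / (1 + |x|) respectively, with the literal 1 multibroadcast to the input shape. Scalar references of the two formulas:

    #include <cmath>
    #include <iostream>

    // Scalar references for the graphs the two parsers emit.
    float softplus(float x) { return std::log(std::exp(x) + 1.0f); }
    float softsign(float x) { return x / (1.0f + std::fabs(x)); }

    int main()
    {
        for(float x : {-2.0f, 0.0f, 2.0f})
            std::cout << softplus(x) << " " << softsign(x) << "\n";
        // softplus(0) = ln(2) ~ 0.693; softsign is bounded in (-1, 1)
    }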
-#include <migraphx/onnx/op_parser.hpp>
-#include <migraphx/onnx/checks.hpp>
-#include <migraphx/ranges.hpp>
-#include <migraphx/instruction.hpp>
-#include <migraphx/shape_for_each.hpp>
-#include <migraphx/make_op.hpp>
-
-namespace migraphx {
-inline namespace MIGRAPHX_INLINE_NS {
-namespace onnx {
-
-struct parse_upsample : op_parser<parse_upsample>
-{
-    std::vector<op_desc> operators() const { return {{"Upsample"}}; }
-
-    instruction_ref parse(const op_desc& /*opd*/,
-                          const onnx_parser& /*parser*/,
-                          onnx_parser::node_info info,
-                          std::vector<instruction_ref> args) const
-    {
-        if(contains(info.attributes, "mode"))
-        {
-            auto mode = info.attributes.at("mode").s();
-            if(mode != "nearest")
-            {
-                MIGRAPHX_THROW("PARSE_UPSAMPLE: only nearest mode is supported!");
-            }
-        }
-
-        auto arg_scale = args[1]->eval();
-        check_arg_empty(arg_scale, "PARSE_UPSAMPLE: only constant scale is supported!");
-        std::vector<float> vec_scale;
-        arg_scale.visit([&](auto v) { vec_scale.assign(v.begin(), v.end()); });
-
-        auto in_s    = args[0]->get_shape();
-        auto in_lens = in_s.lens();
-        if(in_lens.size() != vec_scale.size())
-        {
-            MIGRAPHX_THROW("PARSE_UPSAMPLE: ranks of input and scale are different!");
-        }
-
-        std::vector<std::size_t> out_lens(in_lens.size());
-        std::transform(in_lens.begin(),
-                       in_lens.end(),
-                       vec_scale.begin(),
-                       out_lens.begin(),
-                       [&](auto idx, auto scale) { return static_cast<std::size_t>(idx * scale); });
-
-        std::vector<float> idx_scale(in_lens.size());
-        std::transform(
-            out_lens.begin(),
-            out_lens.end(),
-            in_lens.begin(),
-            idx_scale.begin(),
-            [](auto od, auto id) { return (od == id) ? 1.0f : (id - 1.0f) / (od - 1.0f); });
-
-        shape out_s{in_s.type(), out_lens};
-        std::vector<int> ind(out_s.elements());
-
-        // map out_idx to in_idx
-        shape_for_each(out_s, [&](auto idx) {
-            auto in_idx = idx;
-            std::transform(idx.begin(),
-                           idx.end(),
-                           idx_scale.begin(),
-                           in_idx.begin(),
-                           // nearest mode
-                           [](auto index, auto scale) {
-                               return static_cast<std::size_t>(std::round(index * scale));
-                           });
-            ind[out_s.index(idx)] = static_cast<int64_t>(in_s.index(in_idx));
-        });
-
-        // reshape input to one-dimension
-        std::vector<int64_t> rsp_lens = {static_cast<int64_t>(in_s.elements())};
-        shape ind_s{shape::int32_type, out_lens};
-
-        auto rsp     = info.add_instruction(make_op("reshape", {{"dims", rsp_lens}}), args[0]);
-        auto ins_ind = info.add_literal(literal(ind_s, ind));
-        return info.add_instruction(make_op("gather", {{"axis", 0}}), rsp, ins_ind);
-    }
-};
-
-} // namespace onnx
-} // namespace MIGRAPHX_INLINE_NS
-} // namespace migraphx
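parse_upsample is removed in favor of the extended parse_resize, but its lowering is worth seeing in isolation: nearest-neighbor upsampling becomes a reshape to 1-D plus a gather with precomputed indices, where each output coordinate maps back to an input coordinate by rounding index * scale. A minimal 1-D sketch of that index map, using the same (id - 1) / (od - 1) coordinate scaling:

    #include <cmath>
    #include <cstddef>
    #include <iostream>
    #include <vector>

    int main()
    {
        // 1-D nearest-neighbor upsample, mirroring the gather-based lowering:
        // build an index map from output positions to input positions, then gather.
        std::vector<float> input{10.f, 20.f, 30.f};
        std::size_t out_len = 6; // scale = 2

        // idx_scale maps an output coordinate back into input coordinates,
        // matching (id - 1) / (od - 1) from the removed parser.
        float idx_scale = (out_len == input.size())
                              ? 1.0f
                              : (input.size() - 1.0f) / (out_len - 1.0f);

        std::vector<std::size_t> ind(out_len);
        for(std::size_t o = 0; o < out_len; ++o)
            ind[o] = static_cast<std::size_t>(std::round(o * idx_scale));

        // gather(axis=0) over the flattened input
        for(auto i : ind)
            std::cout << input[i] << " "; // 10 10 20 20 30 30
        std::cout << "\n";
    }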
@@ -91,28 +91,34 @@ add_library(migraphx_device
     device/unary_not.cpp
     device/where.cpp
 )
-set_target_properties(migraphx_device PROPERTIES EXPORT_NAME device)
-rocm_set_soversion(migraphx_device ${MIGRAPHX_SO_VERSION})
-rocm_clang_tidy_check(migraphx_device)
-target_compile_options(migraphx_device PRIVATE -std=c++17 -fno-gpu-rdc -Wno-unused-command-line-argument -Xclang -fallow-half-arguments-and-returns)
-target_link_libraries(migraphx_device migraphx hip::device -fno-gpu-rdc -Wno-invalid-command-line-argument -Wno-unused-command-line-argument)
-if(CMAKE_CXX_COMPILER MATCHES ".*hcc")
-    set(AMDGPU_TARGETS "gfx803;gfx900;gfx906" CACHE STRING "")
-    foreach(AMDGPU_TARGET ${AMDGPU_TARGETS})
-        target_compile_options(migraphx_device PRIVATE -amdgpu-target=${AMDGPU_TARGET})
-        target_link_libraries(migraphx_device -amdgpu-target=${AMDGPU_TARGET})
-    endforeach()
-else()
-    target_compile_options(migraphx_device PRIVATE -Wno-cuda-compat)
-endif()
+add_library(compile_for_gpu INTERFACE)
+target_compile_options(compile_for_gpu INTERFACE -std=c++17 -fno-gpu-rdc -Wno-cuda-compat -Wno-unused-command-line-argument -Xclang -fallow-half-arguments-and-returns)
+target_link_libraries(compile_for_gpu INTERFACE hip::device -fno-gpu-rdc -Wno-invalid-command-line-argument -Wno-unused-command-line-argument)
 check_cxx_compiler_flag("--cuda-host-only -fhip-lambda-host-device -x hip" HAS_HIP_LAMBDA_HOST_DEVICE)
 if(HAS_HIP_LAMBDA_HOST_DEVICE)
     message(STATUS "Enable -fhip-lambda-host-device")
-    target_compile_options(migraphx_device PRIVATE -fhip-lambda-host-device)
+    target_compile_options(compile_for_gpu INTERFACE -fhip-lambda-host-device)
 endif()
+set_target_properties(migraphx_device PROPERTIES EXPORT_NAME device)
+rocm_set_soversion(migraphx_device ${MIGRAPHX_SO_VERSION})
+rocm_clang_tidy_check(migraphx_device)
+target_link_libraries(migraphx_device PUBLIC migraphx)
+target_link_libraries(migraphx_device PRIVATE compile_for_gpu)
 target_include_directories(migraphx_device PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>)
 target_include_directories(migraphx_device PRIVATE $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/device/include>)
+add_library(kernel_file_check EXCLUDE_FROM_ALL)
+foreach(KERNEL_FILE ${KERNEL_FILES})
+    get_filename_component(KERNEL_BASE_FILE ${KERNEL_FILE} NAME_WE)
+    file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/kernels/include/migraphx/kernels/${KERNEL_BASE_FILE}.cpp "#include <migraphx/kernels/${KERNEL_BASE_FILE}.hpp>\n")
+    target_sources(kernel_file_check PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/kernels/include/migraphx/kernels/${KERNEL_BASE_FILE}.cpp)
+endforeach()
+target_include_directories(kernel_file_check PRIVATE $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/>)
+target_link_libraries(kernel_file_check compile_for_gpu)
+rocm_clang_tidy_check(kernel_file_check)
 add_library(migraphx_gpu
     abs.cpp
     analyze_streams.cpp
@@ -310,8 +316,12 @@ target_flags(HIP_COMPILER_FLAGS hip::device)
 # Remove cuda arch flags
 string(REGEX REPLACE --cuda-gpu-arch=[a-z0-9]+ "" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
 string(REGEX REPLACE --offload-arch=[a-z0-9:+-]+ "" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
-string(REPLACE "$<LINK_LANGUAGE:CXX>" "1" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
-string(REPLACE "SHELL:" "" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
+# Skip library paths since hip will incorrectly treat it as a source file
+string(APPEND HIP_COMPILER_FLAGS " ")
+foreach(_unused RANGE 2)
+    string(REGEX REPLACE " /[^ ]+\\.(a|so) " " " HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
+endforeach()
 message(STATUS "Hip compiler flags: ${HIP_COMPILER_FLAGS}")
 target_compile_definitions(migraphx_gpu PRIVATE
     "-DMIGRAPHX_HIP_COMPILER=${CMAKE_CXX_COMPILER}"
@@ -341,7 +351,7 @@ target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels)
 add_subdirectory(driver)
 rocm_install_targets(
-    TARGETS migraphx_gpu migraphx_device
+    TARGETS migraphx_gpu migraphx_device compile_for_gpu
     INCLUDE
     ${CMAKE_CURRENT_SOURCE_DIR}/include
 )
...
@@ -9,7 +9,7 @@ namespace gpu {
 shape hip_argmax::compute_shape(const std::vector<shape>& inputs) const
 {
-    check_shapes{inputs, *this}.has(2).standard();
+    check_shapes{inputs, *this}.has(2);
     return op.normalize_compute_shape({inputs.at(0)});
 }
...
@@ -9,7 +9,7 @@ namespace gpu {
 shape hip_argmin::compute_shape(const std::vector<shape>& inputs) const
 {
-    check_shapes{inputs, *this}.has(2).standard();
+    check_shapes{inputs, *this}.has(2);
     return op.normalize_compute_shape({inputs.at(0)});
 }
...
@@ -108,12 +108,13 @@ operation compile_hip_code_object(const std::string& content, hip_compile_options
     srcs.push_back(src_file{fs::path{"main.cpp"},
                             std::make_pair(content.data(), content.data() + content.size())});
     auto args_hpp =
-        generate_args_hpp(options.reduced_inputs.empty() ? options.inputs : options.reduced_inputs);
+        generate_args_hpp(options.virtual_inputs.empty() ? options.inputs : options.virtual_inputs);
     srcs.push_back(src_file{fs::path{"args.hpp"},
                             std::make_pair(args_hpp.data(), args_hpp.data() + args_hpp.size())});
     options.params += " -DMIGRAPHX_NGLOBAL=" + std::to_string(options.global);
     options.params += " -DMIGRAPHX_NLOCAL=" + std::to_string(options.local);
     options.params += " " + join_strings(compiler_warnings(), " ");
+    options.params += " -ftemplate-backtrace-limit=0";
     options.params += " -Werror";
     auto cos = compile_hip_src(srcs, std::move(options.params), get_device_name());
     if(cos.size() != 1)
...