"src/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "2ec8ba6aabed5275e7a8b1c114fd0217fc92a7ca"
Unverified commit 97d4bb6c authored by Ted Themistokleous, committed by GitHub

Merge branch 'develop' into add_parity_check_ci

parents 39b097c7 bdbc38bc
@@ -69,7 +69,7 @@ struct multibroadcast
     auto make_bcast_strides = [&](std::vector<std::size_t> bcast_lens, std::size_t offset) {
         std::vector<size_t> bcast_strides(bcast_lens.size(), 0);
-        for(std::ptrdiff_t i = s0.lens().size() - 1; i >= 0; i--)
+        for(std::ptrdiff_t i = s0.ndim() - 1; i >= 0; i--)
         {
             if(bcast_lens[i + offset] == s0.lens()[i])
             {
@@ -84,13 +84,13 @@ struct multibroadcast
         if(s0.dynamic())
             MIGRAPHX_THROW(
                 "MULTIBROADCAST: Single dynamic input shape not supported. Use two inputs.");
-        if(s0.lens().size() > output_lens.size())
+        if(s0.ndim() > output_lens.size())
         {
             MIGRAPHX_THROW("MULTIBROADCAST: input dimensions should <= output size");
         }
-        auto offset = output_lens.size() - s0.lens().size();
-        for(std::ptrdiff_t i = s0.lens().size() - 1; i >= 0; i--)
+        auto offset = output_lens.size() - s0.ndim();
+        for(std::ptrdiff_t i = s0.ndim() - 1; i >= 0; i--)
         {
             if(output_lens[i + offset] != s0.lens()[i] and s0.lens()[i] != 1)
             {
@@ -119,7 +119,7 @@ struct multibroadcast
         {
             // output_lens will not be set for 2+ input version
             auto bcast_lens = compute_common_lens(inputs);
-            auto offset = bcast_lens.size() - s0.lens().size();
+            auto offset = bcast_lens.size() - s0.ndim();
             auto bcast_strides = make_bcast_strides(bcast_lens, offset);
             return {t, std::move(bcast_lens), std::move(bcast_strides)};
         }
...
@@ -143,7 +143,7 @@ auto compute_shape_op(rank<2>, const T& x, const std::vector<shape>& inputs)
     if(inputs.empty())
         MIGRAPHX_THROW("At least one input is required for " + x.name());
     dependent_type<operation, T> y = x;
-    normalize_attributes(y, inputs[0].max_lens());
+    normalize_attributes(y, inputs[0]);
     return any_cast<T>(y).normalize_compute_shape(inputs);
 }
@@ -251,9 +251,10 @@ auto compute_op(rank<1>,
                 const shape& output,
                 const std::vector<argument>& inputs,
                 const std::vector<module_ref>& module_args,
-                F f)
-    -> decltype(
-        x.compute(make_compute_output_shape(pack(x, output, inputs)), inputs, module_args, f))
+                F f) -> decltype(x.compute(make_compute_output_shape(pack(x, output, inputs)),
+                                           inputs,
+                                           module_args,
+                                           f))
 {
     return x.compute(make_compute_output_shape(pack(x, output, inputs)), inputs, module_args, f);
 }
@@ -309,9 +310,10 @@ auto compute_op(rank<3>,
                 const shape& output,
                 const std::vector<argument>& inputs,
                 const std::vector<module_ref>& module_args,
-                F f)
-    -> decltype(
-        x.compute(make_compute_output_shape(pack(x, output, inputs)), inputs, module_args, f))
+                F f) -> decltype(x.compute(make_compute_output_shape(pack(x, output, inputs)),
+                                           inputs,
+                                           module_args,
+                                           f))
 {
     return x.compute(make_compute_output_shape(pack(x, output, inputs)), inputs, module_args, f);
 }
...
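Note: the compute_op overloads above use expression SFINAE together with rank<N> tag dispatch; the trailing "-> decltype(x.compute(...))" return type drops an overload from consideration when the operator lacks a matching compute member, and the change in this hunk only re-wraps the arguments of that return type. A minimal standalone sketch of the same pattern (toy names, not the MIGraphX sources):

#include <iostream>

// Tag type used to rank overloads; rank<1> is preferred over rank<0>.
template <int N>
struct rank : rank<N - 1> {};
template <>
struct rank<0> {};

struct with_compute
{
    int compute(int x) const { return x * 2; }
};
struct without_compute {};

// Preferred overload: only viable when T has compute(int) (expression SFINAE).
template <class T>
auto call_compute(rank<1>, const T& t, int x) -> decltype(t.compute(x))
{
    return t.compute(x);
}

// Fallback overload: always viable.
template <class T>
int call_compute(rank<0>, const T&, int x)
{
    return x;
}

int main()
{
    std::cout << call_compute(rank<1>{}, with_compute{}, 21) << "\n";    // 42, uses compute()
    std::cout << call_compute(rank<1>{}, without_compute{}, 21) << "\n"; // 21, falls back
}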
@@ -45,9 +45,10 @@
 #include <migraphx/op/contiguous.hpp>
 #include <migraphx/op/convert.hpp>
 #include <migraphx/op/convolution.hpp>
+#include <migraphx/op/convolution_backwards.hpp>
 #include <migraphx/op/cosh.hpp>
 #include <migraphx/op/cos.hpp>
-#include <migraphx/op/deconvolution.hpp>
+#include <migraphx/op/dimensions_of.hpp>
 #include <migraphx/op/div.hpp>
 #include <migraphx/op/dot.hpp>
 #include <migraphx/op/elu.hpp>
...
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -240,6 +240,10 @@ struct MIGRAPHX_EXPORT shape
     template <class Iterator>
     std::size_t index(Iterator start, Iterator last) const
     {
+        if(this->dynamic())
+        {
+            MIGRAPHX_THROW("SHAPE: index() called on dynamic shape");
+        }
         assert(std::distance(start, last) <= this->lens().size());
         assert(this->lens().size() == this->strides().size());
         return std::inner_product(start, last, this->strides().begin(), std::size_t{0}); // NOLINT
...
@@ -35,6 +35,7 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
+namespace verify {
 // Compute the value of a range
 template <class R>
@@ -196,6 +197,7 @@ bool verify_range(const R1& r1, const R2& r2, double tolerance = 80, double* out
     return error <= threshold;
 }
+} // namespace verify
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
 #endif
@@ -467,7 +467,7 @@ operation instruction::normalized_operator() const
     if(this->need_normalization())
     {
         auto s = this->inputs().front()->get_shape();
-        if(not normalize_attributes(o, s.max_lens()))
+        if(not normalize_attributes(o, s))
             return this->get_operator();
     }
     return o;
...
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -35,8 +35,9 @@ inline namespace MIGRAPHX_INLINE_NS {
  * vec: the vector attribute to normalize
  * axes: the operator's axes attribute if it exists, empty otherwise
  * val: the normalize_axes key and options. Ex: normalize["axes"] =
- * value::array{normalize_attribute::include_min}; lens: shape dimensions passed when calling
- * normalize_attributes(op&, lens)
+ * value::array{normalize_attribute::include_min};
+ * input_shape: input shape passed when calling
+ * normalize_attributes(op&, input_shape)
  *
  * See normalize_attribute.hpp for explaining the options.
  */
@@ -44,11 +45,11 @@ template <class Message>
 auto tune_attribute(const std::vector<int64_t>& vec,
                     const std::vector<int64_t>& axes,
                     const value& val,
-                    const std::vector<std::size_t>& lens,
+                    const shape& input_shape,
                     Message m)
 {
     std::vector<int64_t> result(vec);
-    int64_t n_rank = lens.size();
+    int64_t n_rank = input_shape.ndim();
     std::vector<op::normalize_attribute> vec_attrs = val.to_vector<op::normalize_attribute>();
     if(contains(vec_attrs, op::normalize_attribute::use_output))
     {
@@ -56,9 +57,28 @@ auto tune_attribute(const std::vector<int64_t>& vec,
     }
     std::vector<int64_t> max_vals(vec.size(), n_rank);
     if(contains(vec_attrs, op::normalize_attribute::use_len))
     {
-        std::transform(axes.begin(), axes.end(), max_vals.begin(), [&](auto i) { return lens[i]; });
+        if(input_shape.dynamic())
+        {
+            std::transform(axes.begin(), axes.end(), max_vals.begin(), [&](auto i) {
+                const auto& dd = input_shape.dyn_dims().at(i);
+                if(not dd.is_fixed())
+                {
+                    MIGRAPHX_THROW(
+                        "NORMALIZE_ATTR: 'use_lens' on a non-fixed dynamic dimension, axis=" +
+                        std::to_string(i));
+                }
+                return dd.max;
+            });
+        }
+        else
+        {
+            std::transform(axes.begin(), axes.end(), max_vals.begin(), [&](auto i) {
+                return input_shape.lens().at(i);
+            });
+        }
     }
     if(contains(vec_attrs, op::normalize_attribute::clip_max))
@@ -159,9 +179,9 @@ auto tune_pad_attribute(const value& val)
 /**
  * Assumptions:
  * Dimensions to pad start from the third dimension (index 2).
- * Called by compute_shape_op() with the `lens` of the first input.
+ * Called by compute_shape_op() with the shape of the first input.
  */
-bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
+bool normalize_attributes(operation& op, const shape& input_shape)
 {
     bool tuned = false;
     auto attrs = op.attributes();
@@ -172,9 +192,9 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
         auto padding_size = padding.size();
         auto padding_start = 2;
-        if(padding_size == 2 * (lens.size() - padding_start))
+        if(padding_size == 2 * (input_shape.ndim() - padding_start))
             tuned = true;
-        else if(padding_size != (lens.size() - padding_start))
+        else if(padding_size != (input_shape.ndim() - padding_start))
             MIGRAPHX_THROW("inconsistent padding size");
         else
         {
@@ -205,7 +225,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
             axes = val.at("axes").without_key().to_vector<int64_t>();
         }
         auto vec = vv.to_vector<int64_t>();
-        auto result = tune_attribute(vec, axes, rv.without_key(), lens, message);
+        auto result = tune_attribute(vec, axes, rv.without_key(), input_shape, message);
         val[key] = result;
         op.from_value(val);
         val = op.to_value();
@@ -214,7 +234,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
     else
     {
         auto num = vv.to<int64_t>();
-        auto result = tune_attribute({num}, {num}, rv.without_key(), lens, message);
+        auto result = tune_attribute({num}, {num}, rv.without_key(), input_shape, message);
         val[key] = result.front();
         op.from_value(val);
         val = op.to_value();
...
@@ -45,7 +45,7 @@ void normalize_ops::apply(module& m) const
         auto s = inputs[0]->get_shape();
         migraphx::operation tuned_op = ins->get_operator();
-        if(normalize_attributes(tuned_op, s.max_lens()))
+        if(normalize_attributes(tuned_op, s))
         {
             m.replace_instruction(ins, tuned_op, inputs);
             ins->set_normalized();
...
@@ -57,13 +57,12 @@ struct parse_batchnorm : op_parser<parse_batchnorm>
         auto x_rank = x_lens.size();
         if(x_rank == 1 or x_rank == 2)
         {
-            auto rt      = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {0.5}});
-            auto eps     = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {epsilon}});
-            auto numer   = info.add_broadcastable_binary_op("sub", args[0], args[3]);
-            auto var_eps = info.add_broadcastable_binary_op("add", args[4], eps);
-            auto denom   = info.add_broadcastable_binary_op("pow", var_eps, rt);
-            auto div0    = info.add_broadcastable_binary_op("div", numer, denom);
-            auto r0      = info.add_broadcastable_binary_op("mul", div0, args[1]);
+            auto eps        = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {epsilon}});
+            auto x_sub_mean = info.add_broadcastable_binary_op("sub", args[0], args[3]);
+            auto var_eps    = info.add_broadcastable_binary_op("add", args[4], eps);
+            auto rsqrt      = info.add_instruction(make_op("rsqrt"), var_eps);
+            auto mul0       = info.add_broadcastable_binary_op("mul", args[1], rsqrt);
+            auto r0         = info.add_broadcastable_binary_op("mul", x_sub_mean, mul0);
             return info.add_broadcastable_binary_op("add", r0, args[2]);
         }
         else if(x_rank > 2)
@@ -71,7 +70,6 @@ struct parse_batchnorm : op_parser<parse_batchnorm>
             // unsqueeze tensors of shape (C) to broadcast correctly
             std::vector<int64_t> unsqueeze_axes(x_lens.size() - 2);
             std::iota(unsqueeze_axes.begin(), unsqueeze_axes.end(), 1);
-            auto rt  = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {0.5}});
             auto eps = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {epsilon}});
             auto scale_unsqueeze = info.add_instruction(
                 migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[1]);
@@ -81,11 +79,11 @@ struct parse_batchnorm : op_parser<parse_batchnorm>
                 migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[3]);
             auto var_unsqueeze = info.add_instruction(
                 migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[4]);
-            auto numer   = info.add_broadcastable_binary_op("sub", args[0], mean_unsqueeze);
+            auto x_sub_mean = info.add_broadcastable_binary_op("sub", args[0], mean_unsqueeze);
             auto var_eps = info.add_broadcastable_binary_op("add", var_unsqueeze, eps);
-            auto denom   = info.add_broadcastable_binary_op("pow", var_eps, rt);
-            auto div0    = info.add_broadcastable_binary_op("div", numer, denom);
-            auto r0      = info.add_broadcastable_binary_op("mul", div0, scale_unsqueeze);
+            auto rsqrt = info.add_instruction(make_op("rsqrt"), var_eps);
+            auto mul0  = info.add_broadcastable_binary_op("mul", scale_unsqueeze, rsqrt);
+            auto r0    = info.add_broadcastable_binary_op("mul", x_sub_mean, mul0);
             return info.add_broadcastable_binary_op("add", r0, bias_unsqueeze);
         }
         else
...
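Note: the batchnorm rewrite above is an algebraic reshuffle of the standard formula, replacing the pow(var + eps, 0.5) and div pair with a single rsqrt and two muls:

\[
y \;=\; \gamma\,\frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta
  \;=\; (x - \mu)\cdot\bigl(\gamma\,(\sigma^2 + \epsilon)^{-1/2}\bigr) + \beta
\]

Here \(\gamma\) is the scale (args[1]), \(\beta\) the bias (args[2]), \(\mu\) the mean (args[3]) and \(\sigma^2\) the variance (args[4]); both factorizations are mathematically equivalent, so only the instruction mix changes.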
@@ -42,7 +42,7 @@ std::vector<int64_t> to_int64_vector(const std::vector<T>& input_vector)
     return output_vector;
 }

-struct parse_deconvolution : op_parser<parse_deconvolution>
+struct parse_conv_transpose : op_parser<parse_conv_transpose>
 {
     std::vector<op_desc> operators() const { return {{"ConvTranspose"}}; }
@@ -51,17 +51,15 @@ struct parse_deconvolution : op_parser<parse_deconvolution>
                           onnx_parser::node_info info,
                           std::vector<instruction_ref> args) const
     {
-        operation op = make_op("deconvolution");
+        operation op = make_op("convolution_backwards");
         value values = op.to_value();
-        // op::deconvolution op;
         auto l0 = args[0];
         std::vector<std::int64_t> padding;
         bool asym_padding = false;
-        auto in_lens = l0->get_shape().lens();
-        assert(in_lens.size() > 2);
-        auto kdims = in_lens.size() - 2;
+        assert(l0->get_shape().ndim() > 2);
+        auto kdims = l0->get_shape().ndim() - 2;

-        // ensure pads availabe only when auto_pad is "NOT_SET"
+        // ensure pads available only when auto_pad is "NOT_SET"
         check_padding_mode(info, "CONV_TRANSPOSE");

         if(contains(info.attributes, "pads"))
@@ -70,9 +68,9 @@ struct parse_deconvolution : op_parser<parse_deconvolution>
             asym_padding = is_asym_padding(padding);

+            size_t pad_ndims = padding.size() / 2;
             if(not asym_padding)
             {
-                size_t pad_ndims = padding.size() / 2;
                 check_attr_sizes(kdims, pad_ndims, "PARSE_CONV_TRANSPOSE: inconsistent paddings");
                 values["padding"].clear();
                 std::transform(padding.begin(),
@@ -80,7 +78,19 @@ struct parse_deconvolution : op_parser<parse_deconvolution>
                                std::back_inserter(values["padding"]),
                                [](auto pad_val) { return pad_val; });
             }
+            else if(l0->get_shape().dynamic())
+            {
+                MIGRAPHX_THROW("PARSE_CONV_TRANSPOSE: asymmetric padding (padding_L != padding_R) "
+                               "not supported with dynamic shapes");
+            }
+            else
+            {
+                // set padding to 0s, asym_padding handled by parser with slice
+                // TODO changing parser and op to do asym padding in op
+                values["padding"] = std::vector<std::size_t>(pad_ndims, 0);
+            }
         }

         if(contains(info.attributes, "strides"))
         {
             values["stride"].clear();
@@ -88,6 +98,7 @@ struct parse_deconvolution : op_parser<parse_deconvolution>
             check_attr_sizes(
                 kdims, values["stride"].size(), "PARSE_CONV_TRANSPOSE: inconsistent strides");
         }
+
         if(contains(info.attributes, "dilations"))
         {
             values["dilation"].clear();
@@ -97,21 +108,10 @@ struct parse_deconvolution : op_parser<parse_deconvolution>
         }

         // TODO: auto padding needs to be implemented for this parser and operator
-        if(contains(info.attributes, "auto_pad"))
+        if(contains(info.attributes, "auto_pad") and
+           to_upper(info.attributes.at("auto_pad").s()) != "NOTSET")
         {
-            auto s = info.attributes["auto_pad"].s();
-            if(contains(info.attributes, "pads") and to_upper(s) != "NOTSET")
-            {
-                MIGRAPHX_THROW("PARSE_CONV_TRANSPOSE: auto_pad and padding cannot be specified "
-                               "simultaneously");
-            }
-            if(s.find("SAME") != std::string::npos)
-            {
-                bool is_same_upper = (s.find("SAME_UPPER") != std::string::npos);
-                values["padding_mode"] = is_same_upper ? to_value(op::padding_mode_t::same_upper)
-                                                       : to_value(op::padding_mode_t::same_lower);
-            }
+            MIGRAPHX_THROW("PARSE_CONV_TRANSPOSE: auto padding not supported");
         }

         if(contains(info.attributes, "group"))
@@ -122,11 +122,11 @@ struct parse_deconvolution : op_parser<parse_deconvolution>
         recalc_conv_attributes(values, kdims);

         op.from_value(values);
         auto l1 = info.add_instruction(op, l0, args[1]);
-        std::vector<int64_t> dims = to_int64_vector(l1->get_shape().lens());
-        std::vector<int64_t> curr_shape(dims.begin() + 2, dims.end());
         if(asym_padding)
         {
+            std::vector<int64_t> dims = to_int64_vector(l1->get_shape().lens());
+            std::vector<int64_t> curr_shape(dims.begin() + 2, dims.end());
             std::vector<int64_t> axes(kdims);
             std::iota(axes.begin(), axes.end(), 2); // ignore first 2 dims
@@ -144,9 +144,11 @@ struct parse_deconvolution : op_parser<parse_deconvolution>
                 make_op("slice", {{"axes", axes}, {"starts", starts}, {"ends", ends}}), l1);
         }

-        if(contains(info.attributes, "output_padding"))
+        // TODO, should check output_padding < (strides or dilations)
+        if(contains(info.attributes, "output_padding") and
+           not contains(info.attributes, "output_shape"))
         {
-            size_t non_kdims = dims.size() * 2 - kdims;
+            size_t non_kdims = l1->get_shape().ndim() * 2 - kdims;
             std::vector<int64_t> output_padding(non_kdims, 0);
             copy(info.attributes["output_padding"].ints(), std::back_inserter(output_padding));
             check_attr_sizes(kdims,
@@ -155,14 +157,21 @@ struct parse_deconvolution : op_parser<parse_deconvolution>
             l1 = info.add_instruction(make_op("pad", {{"pads", output_padding}}), l1);
         }

+        // TODO, doing unnecessary calcuations with this. Could instead
+        // calculate the padding to conv_transpose that would give the output_shape.
         if(contains(info.attributes, "output_shape"))
         {
+            if(l1->get_shape().dynamic())
+            {
+                MIGRAPHX_THROW("PARSE_CONV_TRANSPOSE: output_shape attribute and dynamic shapes "
+                               "not supported");
+            }
+            std::vector<int64_t> dims = to_int64_vector(l1->get_shape().lens());
+            std::vector<int64_t> curr_shape(dims.begin() + 2, dims.end());
             std::vector<int64_t> output_shape;
             copy(info.attributes["output_shape"].ints(), std::back_inserter(output_shape));
             check_attr_sizes(
                 kdims, output_shape.size(), "PARSE_CONV_TRANSPOSE: inconsistent output shape");
-            dims = to_int64_vector(l1->get_shape().lens());
-            copy(dims.begin() + 2, dims.end(), curr_shape.begin());
             if(curr_shape != output_shape)
             {
                 std::vector<int64_t> target_padding(dims.size() * 2 - kdims, 0);
...
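Note: for context on the output_padding/output_shape handling above, the standard ONNX ConvTranspose output-size relation (stated here for reference, not taken from this diff) is

\[
\text{out}_i = \text{stride}_i\,(\text{in}_i - 1) + \text{output\_padding}_i
             + (\text{kernel}_i - 1)\,\text{dilation}_i + 1
             - \text{pad\_begin}_i - \text{pad\_end}_i
\]

When output_shape is given, the parser compares it with the shape produced by convolution_backwards and pads out the difference; with a dynamic input shape that comparison cannot be made at parse time, hence the new throw.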
@@ -79,13 +79,11 @@ struct parse_instancenorm : op_parser<parse_instancenorm>
         auto x     = args[0];
         auto scale = args[1];
         auto bias  = args[2];
-        auto dims  = x->get_shape().lens();

         if(not contains(valid_types, dtype))
             MIGRAPHX_THROW(opd.op_name + ": invalid output type: " + std::to_string(dtype) +
                            ". Valid types are 1 (float), 10 (half), and 11 (double).");
-        bool dyn_input = x->get_shape().dynamic();
-        auto ndims     = x->get_shape().ndim();
+        auto ndims = x->get_shape().ndim();
         assert(ndims >= 2);
         auto kdims = ndims - 2;
         std::vector<int64_t> axes(kdims);
@@ -102,6 +100,12 @@ struct parse_instancenorm : op_parser<parse_instancenorm>
             (dtype == shape::half_type and not convert_fp16) ? "reduce_sum" : "reduce_mean";
         if(dtype == shape::half_type and not convert_fp16)
         {
+            if(x->get_shape().dynamic())
+            {
+                MIGRAPHX_THROW("PARSE_INSTANCENORM: half type not supported with dynamic shape "
+                               "unless convert_fp16 is TRUE");
+            }
+            auto dims = x->get_shape().lens();
             double n =
                 std::accumulate(dims.begin() + 2, dims.end(), 1, [&](const auto& i, const auto& j) {
                     return i * j;
@@ -122,13 +126,14 @@ struct parse_instancenorm : op_parser<parse_instancenorm>
         // both scale and bias.
         instruction_ref scale_bcast;
         instruction_ref bias_bcast;
-        if(dyn_input)
+        if(x->get_shape().dynamic())
         {
             scale_bcast = info.add_instruction(make_op("broadcast", {{"axis", 1}}), scale, x);
             bias_bcast  = info.add_instruction(make_op("broadcast", {{"axis", 1}}), bias, x);
         }
         else
         {
+            auto dims   = x->get_shape().lens();
             scale_bcast = info.add_instruction(
                 make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), scale);
             bias_bcast =
...
@@ -30,8 +30,11 @@ namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace onnx {

-// Use a literal instruction to replace the shape since, output of
-// shape operator are literals in migraphx
+/**
+ * If static shape input, creates a literal in migraphx.
+ * If dynamic shape input, creates a dimensions_of operator in migraphx (runtime evaluation of
+ * shape).
+ */
 struct parse_shape : op_parser<parse_shape>
 {
     std::vector<op_desc> operators() const { return {{"Shape"}}; }
@@ -43,13 +46,54 @@ struct parse_shape : op_parser<parse_shape>
     {
         if(args.size() != 1)
             MIGRAPHX_THROW("Shape: operator should have 1 operand");
-        std::vector<std::size_t> arg_shape = args[0]->get_shape().lens();
-        std::vector<int64_t> vec_shape(arg_shape.size());
-        migraphx::shape s(migraphx::shape::int64_type, {arg_shape.size()});
-        std::transform(arg_shape.begin(), arg_shape.end(), vec_shape.begin(), [](auto i) {
-            return int64_t(i);
-        });
-        return info.add_literal(migraphx::literal{s, vec_shape});
+        auto input_shape  = args[0]->get_shape();
+        int input_ndim    = input_shape.ndim();
+        std::size_t start = 0;
+        std::size_t end   = input_ndim;
+
+        // Normalizing the start and end is handled here because of how the static shape version
+        // works. Clamping to [-r, r], where r is ndim of input and then making positive.
+        auto normalize_ind = [&](int64_t ind) {
+            if(ind < (-1 * input_ndim))
+            {
+                ind = -1 * input_ndim;
+            }
+            if(ind > input_ndim)
+            {
+                ind = input_ndim;
+            }
+            return (ind >= 0) ? ind : input_ndim + ind;
+        };
+
+        if(contains(info.attributes, "end"))
+        {
+            end = normalize_ind(info.attributes.at("end").i());
+        }
+        if(contains(info.attributes, "start"))
+        {
+            start = normalize_ind(info.attributes.at("start").i());
+        }
+        if(end <= start)
+        {
+            MIGRAPHX_THROW("PARSE_SHAPE: ending axis <= starting axis, end: " +
+                           std::to_string(end) + " start: " + std::to_string(start));
+        }
+
+        if(input_shape.dynamic())
+        {
+            return info.add_instruction(make_op("dimensions_of", {{"start", start}, {"end", end}}),
+                                        args[0]);
+        }
+        else
+        {
+            std::size_t output_ndim = end - start;
+            std::vector<int64_t> vec_shape(output_ndim);
+            migraphx::shape s(migraphx::shape::int64_type, {output_ndim});
+            std::vector<std::size_t> input_lens = input_shape.lens();
+            std::transform(input_lens.begin() + start,
+                           input_lens.begin() + end,
+                           vec_shape.begin(),
+                           [](auto i) { return int64_t(i); });
+            return info.add_literal(migraphx::literal{s, vec_shape});
+        }
     }
 };
...
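Note: the normalize_ind lambda above follows the ONNX Shape-15 convention of clamping start/end to [-r, r] (r = input rank) and then wrapping negative values. A standalone sketch of the same clamping (toy code, not the parser itself):

#include <algorithm>
#include <cstdint>
#include <iostream>

// Clamp an ONNX Shape start/end attribute to [-rank, rank], then make it non-negative.
std::int64_t normalize_ind(std::int64_t ind, std::int64_t input_ndim)
{
    ind = std::max(ind, -input_ndim);
    ind = std::min(ind, input_ndim);
    return (ind >= 0) ? ind : input_ndim + ind;
}

int main()
{
    // For a rank-4 input:
    std::cout << normalize_ind(-1, 4) << "\n"; // 3
    std::cout << normalize_ind(10, 4) << "\n"; // 4 (clamped to rank)
    std::cout << normalize_ind(-9, 4) << "\n"; // 0 (clamped to -rank, then wrapped)
}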
@@ -29,6 +29,7 @@
 #include <migraphx/simplify_reshapes.hpp>
 #include <migraphx/simplify_qdq.hpp>
 #include <migraphx/eliminate_common_subexpression.hpp>
+#include <migraphx/optimize_module.hpp>
 #include <migraphx/dead_code_elimination.hpp>
 #include <migraphx/program.hpp>
 #include <migraphx/instruction.hpp>
@@ -48,19 +49,12 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_INT8_QUANTIZATION_PARAMS)

 // This function is to convert any instructions specified in the input
 // from double or float to float16 by inserting a convert operator.
-// For the conversion, there could be cases of overflowing, but it
-// is very rare in the area of deeping learning, so we just do a
-// truncate of the input to get the fp16.
+// For the conversion, there could be cases of overflowing or underflowing, but it
+// is uncommon. Run optimize_module() before converting to fp16 to const eval and fold in FP32 to
+// avoid loss of precision.
 void quantize_fp16(program& prog, const std::vector<std::string>& ins_names)
 {
-    run_passes(prog,
-               {quantize_fp16_pass{ins_names},
-                eliminate_common_subexpression{},
-                dead_code_elimination{},
-                simplify_reshapes{},
-                dead_code_elimination{},
-                simplify_qdq{},
-                dead_code_elimination{}});
+    run_passes(prog, {optimize_module{}, quantize_fp16_pass{ins_names}, optimize_module{}});
 }

 void quantize_int8(program& prog,
...
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -273,9 +273,23 @@ shape shape::from_permutation(type_t t,
 shape::type_t shape::type() const { return impl->m_type; }

-const std::vector<std::size_t>& shape::lens() const { return impl->m_lens; }
+const std::vector<std::size_t>& shape::lens() const
+{
+    if(this->dynamic())
+    {
+        MIGRAPHX_THROW("SHAPE: lens() called on a dynamic shape");
+    }
+    return impl->m_lens;
+}

-const std::vector<std::size_t>& shape::strides() const { return impl->m_strides; }
+const std::vector<std::size_t>& shape::strides() const
+{
+    if(this->dynamic())
+    {
+        MIGRAPHX_THROW("SHAPE: strides() called on a dynamic shape");
+    }
+    return impl->m_strides;
+}

 std::size_t shape::ndim() const
 {
@@ -535,7 +549,14 @@ bool shape::any_of_dynamic() const
     });
 }

-const std::vector<shape::dynamic_dimension>& shape::dyn_dims() const { return impl->m_dyn_dims; }
+const std::vector<shape::dynamic_dimension>& shape::dyn_dims() const
+{
+    if(not this->dynamic())
+    {
+        MIGRAPHX_THROW("SHAPE: dyn_dims() called on a static shape");
+    }
+    return impl->m_dyn_dims;
+}

 std::vector<std::size_t> shape::min_lens() const
 {
@@ -679,12 +700,22 @@ const std::vector<shape>& shape::sub_shapes() const { return impl->m_shapes; }

 void migraphx_to_value(value& v, const shape& s)
 {
     value result;
     result["type"] = migraphx::to_value(s.type_string());
-    result["lens"]               = migraphx::to_value(s.lens());
-    result["strides"]            = migraphx::to_value(s.strides());
-    result["sub_shapes"]         = migraphx::to_value(s.sub_shapes());
-    result["dynamic_dimensions"] = migraphx::to_value(s.dyn_dims());
-    v                            = result;
+    result["sub_shapes"] = migraphx::to_value(s.sub_shapes());
+    // avoid calling functions that will throw
+    if(s.dynamic())
+    {
+        result["lens"]               = {};
+        result["strides"]            = {};
+        result["dynamic_dimensions"] = migraphx::to_value(s.dyn_dims());
+    }
+    else
+    {
+        result["lens"]               = migraphx::to_value(s.lens());
+        result["strides"]            = migraphx::to_value(s.strides());
+        result["dynamic_dimensions"] = {};
+    }
+    v = result;
 }

 void migraphx_from_value(const value& v, shape& s)
...
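Note: the accessor guards added above make misuse of a dynamic shape fail loudly instead of silently returning stale data, which is also why the rest of this commit swaps lens().size() for ndim() wherever only the rank is needed. A minimal sketch of the same pattern with a toy class (illustrative only, not migraphx::shape):

#include <cstddef>
#include <stdexcept>
#include <utility>
#include <vector>

// Toy shape: either the static lens or the dynamic {min, max} dims are meaningful, never both.
struct toy_shape
{
    std::vector<std::size_t> static_lens;
    std::vector<std::pair<std::size_t, std::size_t>> dynamic_dims; // {min, max} per axis
    bool is_dynamic = false;

    const std::vector<std::size_t>& lens() const
    {
        if(is_dynamic)
            throw std::runtime_error("lens() called on a dynamic shape");
        return static_lens;
    }

    const std::vector<std::pair<std::size_t, std::size_t>>& dyn_dims() const
    {
        if(not is_dynamic)
            throw std::runtime_error("dyn_dims() called on a static shape");
        return dynamic_dims;
    }

    // The one rank query that is valid for both static and dynamic shapes.
    std::size_t ndim() const { return is_dynamic ? dynamic_dims.size() : static_lens.size(); }
};

Callers such as migraphx_to_value above first branch on the dynamic flag and only then pick the matching accessor.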
@@ -89,38 +89,23 @@ struct find_reshaper
 {
     auto matcher() const
     {
-        return match::name(reshaper_names())(
-            match::any_of[match::outputs()](match::name(reshaper_names())));
+        auto reshaper          = match::name(reshaper_names());
+        auto contiguous        = match::name("contiguous");
+        auto no_output_reshape = match::none_of[match::outputs()](reshaper);
+        auto input_reshape     = match::arg(0)(match::skip(contiguous)(reshaper));
+        auto input             = match::skip(reshaper, contiguous)(match::any().bind("x"));
+        return reshaper(no_output_reshape, input_reshape, input);
     }

     void apply(module& m, const match::matcher_result& mr) const
     {
         auto ins = mr.result;
-        std::vector<instruction_ref> reshapes{ins};
-        while(is_reshaper(reshapes.back()))
-        {
-            assert(not reshapes.back()->inputs().empty());
-            assert(m.has_instruction(reshapes.back()->inputs().front()));
-            auto input = reshapes.back()->inputs().front();
-            reshapes.push_back(input);
-        }
-
-        std::pair<instruction_ref, instruction_ref> r{m.end(), m.end()};
-        for(auto start : iterator_for(reshapes))
-        {
-            auto last = std::find_if(reshapes.rbegin(), reshapes.rend(), [&](auto&& i) {
-                return i->get_shape() == (*start)->get_shape() and i != (*start);
-            });
-            if(last != reshapes.rend())
-            {
-                r = std::make_pair(*start, *last);
-                break;
-            }
-        }
-        if(r.first != r.second)
-        {
-            m.replace_instruction(r.first, r.second);
-        }
+        auto input = mr.instructions["x"];
+        auto dims  = ins->get_shape().lens();
+        if(not input->get_shape().standard())
+            input = m.insert_instruction(ins, make_op("contiguous"), input);
+        m.replace_instruction(ins, make_op("reshape", {{"dims", dims}}), input);
     }
 };
@@ -804,9 +789,9 @@ void simplify_reshapes::apply(module& m) const
     match::find_matches(m,
                         find_where_op{},
                         find_resize{},
-                        find_reshape_cont{},
                         find_nop_reshapes{},
                         find_reshaper{},
+                        find_reshape_cont{},
                         find_transpose{},
                         find_concat_transpose{},
                         find_concat_multibroadcasts{},
...
@@ -23,14 +23,14 @@
  */
 #include <migraphx/config.hpp>
 #include <migraphx/cpu/dnnl.hpp>
-#include <migraphx/op/deconvolution.hpp>
+#include <migraphx/op/convolution_backwards.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace cpu {

 struct dnnl_deconvolution
-    : dnnl_extend_op<dnnl_deconvolution, dnnl::deconvolution_forward, op::deconvolution>
+    : dnnl_extend_op<dnnl_deconvolution, dnnl::deconvolution_forward, op::convolution_backwards>
 {
     std::vector<int> arg_map(int) const
     {
...
@@ -27,7 +27,7 @@
 #include <migraphx/dfor.hpp>
 #include <migraphx/op/identity.hpp>
 #include <migraphx/op/convolution.hpp>
-#include <migraphx/op/deconvolution.hpp>
+#include <migraphx/op/convolution_backwards.hpp>
 #include <migraphx/op/quant_convolution.hpp>
 #include <migraphx/op/dot.hpp>
 #include <migraphx/op/quant_dot.hpp>
@@ -345,7 +345,7 @@ struct cpu_apply
         extend_op("contiguous", "dnnl::reorder");
         extend_op("convolution", "dnnl::convolution");
 #ifndef MIGRAPHX_ENABLE_ZENDNN
-        extend_op("deconvolution", "dnnl::deconvolution");
+        extend_op("convolution_backwards", "dnnl::convolution_backwards");
         extend_op("dot", "dnnl::dot");
 #endif
         extend_op("erf", "cpu::erf");
...
@@ -176,7 +176,7 @@ register_op(migraphx_gpu
   OPERATORS gpu::rocblas_gemm<op::dot> gpu::rocblas_gemm<op::quant_dot>
   INCLUDES migraphx/gpu/context.hpp)
 register_op(migraphx_gpu HEADER migraphx/gpu/convolution.hpp
-  OPERATORS gpu::miopen_convolution<op::convolution> gpu::miopen_convolution<op::deconvolution> gpu::miopen_convolution<op::quant_convolution>
+  OPERATORS gpu::miopen_convolution<op::convolution> gpu::miopen_convolution<op::convolution_backwards> gpu::miopen_convolution<op::quant_convolution>
   INCLUDES migraphx/gpu/context.hpp)
 rocm_set_soversion(migraphx_gpu ${MIGRAPHX_SO_VERSION})
 rocm_clang_tidy_check(migraphx_gpu)
@@ -188,7 +188,9 @@ if(MIGRAPHX_ENABLE_MLIR)
   find_package(rocMLIR 1.0.0 CONFIG REQUIRED)
   message(STATUS "Build with rocMLIR::rockCompiler ${rocMLIR_VERSION}")
   target_compile_definitions(migraphx_gpu PRIVATE "-DMIGRAPHX_MLIR")
-  target_link_libraries(migraphx_gpu PUBLIC rocMLIR::rockCompiler)
+  # Make this private to avoid multiple inclusions of LLVM symbols.
+  # TODO: Fix rocMLIR's library to hide LLVM internals.
+  target_link_libraries(migraphx_gpu PRIVATE rocMLIR::rockCompiler)
 endif()

 if(MIGRAPHX_USE_HIPRTC)
@@ -234,7 +236,12 @@ check_library_exists(MIOpen "miopenFindSolutions" "${MIOPEN_LOCATION}" HAS_FIND_
 set(MIGRAPHX_USE_FIND_2_API "${HAS_FIND_2_API}" CACHE BOOL "")

 if(MIGRAPHX_USE_FIND_2_API)
-    target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_HAS_FIND_2_API)
+    check_library_exists(MIOpen "miopenSetFindOptionPreallocatedTensor" "${MIOPEN_LOCATION}" HAS_PREALLOCATION_API)
+    if(HAS_PREALLOCATION_API)
+        target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_HAS_FIND_2_API -DMIGRAPHX_PREALLOCATE_MIOPEN_BUFFERS)
+    else()
+        target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_HAS_FIND_2_API)
+    endif()
     message(STATUS "MIGraphx is using Find-2.0 API of MIOpen")
 else()
     message(STATUS "MIGraphx is using legacy Find API in MIOpen")
...
@@ -135,14 +135,13 @@ compute_global_for(context& ctx, std::size_t n, std::size_t over)
     std::size_t max_global = ctx.get_current_device().get_cu_count() *
                              ctx.get_current_device().get_max_workitems_per_cu();
     return [n, over, max_global](std::size_t local) {
-        std::size_t num_elements = n;
+        // hip require global workitems multiple of local workitems. It may degrade performance.
+        // [TODO]: consider adding "fno-hip-uniform-block" flag when it becomes available.
+        // https://reviews.llvm.org/D155213
+        std::size_t num_elements = ((n + local - 1) / local) * local;
         std::size_t groups       = (num_elements + local - 1) / local;
         std::size_t max_blocks   = max_global / local;
         std::size_t nglobal      = std::min(max_blocks * over, groups) * local;
-#ifdef MIGRAPHX_USE_HIPRTC
-        if(enabled(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS{}))
-            num_elements = ((num_elements + local - 1) / local) * local;
-#endif
         return std::min(nglobal, num_elements);
     };
 }
...
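Note: the change above rounds the element count up to a multiple of the workgroup (local) size before sizing the launch, since HIP expects the global workitem count to be a multiple of the local count. A standalone sketch of the arithmetic (plain C++, toy values):

#include <algorithm>
#include <cstddef>
#include <iostream>

// Global launch size for n elements with a given workgroup (local) size,
// capped by max_global workitems and an oversubscription factor `over`.
std::size_t global_workitems(std::size_t n, std::size_t local, std::size_t max_global, std::size_t over)
{
    std::size_t num_elements = ((n + local - 1) / local) * local; // round n up to a multiple of local
    std::size_t groups       = num_elements / local;              // exact after rounding
    std::size_t max_blocks   = max_global / local;
    std::size_t nglobal      = std::min(max_blocks * over, groups) * local;
    return std::min(nglobal, num_elements); // never exceed the rounded element count
}

int main()
{
    // 1000 elements with a workgroup of 256 now launch 1024 workitems instead of 1000.
    std::cout << global_workitems(1000, 256, 1u << 20, 1) << "\n"; // 1024
}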
@@ -79,7 +79,7 @@ void compile_miopen::apply(module& m) const
         std::size_t ws = 0;
         try
         {
-            // for the regular convolution and deconvolution, this try would always succeed
+            // for the regular convolution and convolution_backwards, this try would always succeed
             ws = compile(op, ins, int8_x4_format);
         }
         catch(migraphx::exception&)
...