"testing/vscode:/vscode.git/clone" did not exist on "b10d49b2c0197d74a2c2864e57a2f67a9d880345"
Unverified Commit 25e8cf0b authored by Ted Themistokleous's avatar Ted Themistokleous Committed by GitHub
Browse files

Merge branch 'develop' into test_onnx_zoo

parents a313a68e 635502be
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/dyn_output.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -46,14 +47,60 @@ struct reshape ...@@ -46,14 +47,60 @@ struct reshape
value attributes() const { return {{"require_std_shape", true}}; } value attributes() const { return {{"require_std_shape", true}}; }
std::string name() const { return "reshape"; } std::string name() const { return "reshape"; }
shape compute_shape(std::vector<shape> inputs) const
shape dyn_compute_shape(shape s0) const
{
auto dyn_dims = s0.dyn_dims();
auto num_not_fixed = std::count_if(
dyn_dims.cbegin(), dyn_dims.cend(), [](auto dd) { return not dd.is_fixed(); });
if(num_not_fixed != 1)
{
MIGRAPHX_THROW("Reshape: Only supports one non-fixed dynamic_dimension");
}
// track number of fixed elements in input and output
std::size_t num_dims_ele = 1;
std::size_t num_dd_ele = 1;
for(std::size_t i = 0; i < dyn_dims.size(); ++i)
{
if(dyn_dims[i].is_fixed())
{
num_dims_ele *= dims[i];
num_dd_ele *= dyn_dims[i].min;
}
else
{
if(dims[i] != 0 and dims[i] != -1)
{
MIGRAPHX_THROW(
"Reshape: Non-fixed dynamic_dimension doesn't match with 0 or -1 "
"output dimension");
}
}
}
if(num_dims_ele != num_dd_ele)
{
MIGRAPHX_THROW("Reshape: Number of fixed elements must match. Input: " +
std::to_string(num_dd_ele) + " Output: " + std::to_string(num_dims_ele));
}
// construct output dynamic shape from dims attribute
std::vector<shape::dynamic_dimension> output_dyn_dims(dims.size());
std::transform(dims.cbegin(),
dims.cend(),
dyn_dims.cbegin(),
output_dyn_dims.begin(),
[](std::size_t dim, auto dyn_dim) {
if(not dyn_dim.is_fixed())
return dyn_dim;
return shape::dynamic_dimension{dim, dim};
});
return {s0.type(), output_dyn_dims};
}
shape static_compute_shape(std::vector<shape> inputs, std::size_t n_neg_dims) const
{ {
check_shapes{inputs, *this}.has(1).standard(); check_shapes{inputs, *this}.standard();
auto&& idims = inputs.front().lens(); auto&& idims = inputs.front().lens();
std::vector<std::size_t> rdims(dims.begin(), dims.end()); std::vector<std::size_t> rdims(dims.begin(), dims.end());
auto n_neg_dims = std::count(dims.begin(), dims.end(), -1);
if(n_neg_dims > 1)
MIGRAPHX_THROW("Reshape: Dimensions for reshape can only have one -1 dim");
for(std::size_t i = 0; i < dims.size(); i++) for(std::size_t i = 0; i < dims.size(); i++)
{ {
...@@ -86,9 +133,26 @@ struct reshape ...@@ -86,9 +133,26 @@ struct reshape
return s; return s;
} }
argument compute(shape output_shape, std::vector<argument> args) const shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this, true}.has(1);
auto n_neg_dims = std::count(dims.begin(), dims.end(), -1);
if(n_neg_dims > 1)
MIGRAPHX_THROW("Reshape: Dimensions for reshape can only have one -1 dim");
auto s0 = inputs[0];
if(s0.dynamic())
{
return dyn_compute_shape(s0);
}
else
{
return static_compute_shape(inputs, n_neg_dims);
}
}
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ {
return args[0].reshape(output_shape); return args[0].reshape(dyn_out.computed_shape);
} }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
......
...@@ -53,15 +53,15 @@ struct softmax ...@@ -53,15 +53,15 @@ struct softmax
std::string name() const { return "softmax"; } std::string name() const { return "softmax"; }
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this, true}.has(1);
if(inputs.at(0).packed()) auto s0 = inputs[0];
if(s0.dynamic() or s0.packed())
{ {
return inputs.at(0); return s0;
} }
else else
{ {
auto lens = inputs.at(0).lens(); return {s0.type(), s0.lens()};
return {inputs.at(0).type(), lens};
} }
} }
......
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp> #include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/dyn_output.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -54,52 +55,85 @@ struct squeeze ...@@ -54,52 +55,85 @@ struct squeeze
std::string name() const { return "squeeze"; } std::string name() const { return "squeeze"; }
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this, true}.has(1);
auto input_shape = inputs[0]; auto input_shape = inputs[0];
auto type = input_shape.type(); if(input_shape.dynamic())
auto old_lens = input_shape.lens();
auto old_strides = input_shape.strides();
if(std::any_of(axes.begin(), axes.end(), [&](auto axis) { return old_lens[axis] != 1; }))
{ {
MIGRAPHX_THROW("squeeze axis dimension should be equal to 1"); if(std::any_of(axes.begin(), axes.end(), [&](auto axis) {
} return input_shape.dyn_dims()[axis] != 1;
std::vector<std::size_t> new_lens; }))
std::vector<std::size_t> new_strides; {
if(axes.empty()) MIGRAPHX_THROW(
{ "SQUEEZE: dynamic axis dimension should be equal to {1, 1, 0} or {1, 1, 1}");
for(auto i : range(old_lens.size())) }
std::vector<shape::dynamic_dimension> dyn_dims = {};
if(axes.empty())
{
std::copy_if(input_shape.dyn_dims().cbegin(),
input_shape.dyn_dims().cend(),
std::back_inserter(dyn_dims),
[&](auto dd) { return dd != 1; });
}
else
{ {
if(old_lens[i] != 1) for(auto i : range(input_shape.ndim()))
{ {
new_lens.push_back(old_lens[i]); if(std::find(axes.begin(), axes.end(), i) == axes.end())
new_strides.push_back(old_strides[i]); {
dyn_dims.push_back(input_shape.dyn_dims()[i]);
}
} }
} }
return {input_shape.type(), dyn_dims};
} }
else else
{ {
for(auto i : range(old_lens.size())) auto type = input_shape.type();
auto old_lens = input_shape.lens();
auto old_strides = input_shape.strides();
if(std::any_of(
axes.begin(), axes.end(), [&](auto axis) { return old_lens[axis] != 1; }))
{ {
if(std::find(axes.begin(), axes.end(), i) == axes.end()) MIGRAPHX_THROW("SQUEEZE: static axis dimension should be equal to 1");
}
std::vector<std::size_t> new_lens;
std::vector<std::size_t> new_strides;
if(axes.empty())
{
for(auto i : range(old_lens.size()))
{ {
new_lens.push_back(old_lens[i]); if(old_lens[i] != 1)
new_strides.push_back(old_strides[i]); {
new_lens.push_back(old_lens[i]);
new_strides.push_back(old_strides[i]);
}
} }
} }
} else
if(new_lens.empty()) {
{ for(auto i : range(old_lens.size()))
return shape{type}; {
} if(std::find(axes.begin(), axes.end(), i) == axes.end())
else {
{ new_lens.push_back(old_lens[i]);
return shape{type, new_lens, new_strides}; new_strides.push_back(old_strides[i]);
}
}
}
if(new_lens.empty())
{
return shape{type};
}
else
{
return shape{type, new_lens, new_strides};
}
} }
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ {
return args[0].reshape(output_shape); return args[0].reshape(dyn_out.computed_shape);
} }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp> #include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/dyn_output.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -45,17 +46,15 @@ struct transpose ...@@ -45,17 +46,15 @@ struct transpose
} }
std::string name() const { return "transpose"; } std::string name() const { return "transpose"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this, true}.has(1);
auto input = inputs.at(0); auto input = inputs.at(0);
auto input_lens = input.lens();
auto input_strides = input.strides();
auto t = input.type();
if(dims.size() != input_lens.size()) if(dims.size() != input.ndim())
{ {
MIGRAPHX_THROW("Permutation has wrong number of axes"); MIGRAPHX_THROW("TRANSPOSE: Permutation has wrong number of axes");
} }
std::vector<int64_t> axes(dims.size()); std::vector<int64_t> axes(dims.size());
std::iota(axes.begin(), axes.end(), 0); std::iota(axes.begin(), axes.end(), 0);
...@@ -63,19 +62,36 @@ struct transpose ...@@ -63,19 +62,36 @@ struct transpose
{ {
MIGRAPHX_THROW("TRANSPOSE: Invalid permutation"); MIGRAPHX_THROW("TRANSPOSE: Invalid permutation");
} }
std::vector<size_t> output_lens(input_lens.size());
std::vector<size_t> output_strides(input_lens.size()); if(input.dynamic())
for(std::size_t i = 0; i < output_lens.size(); i++)
{ {
output_lens[i] = input_lens[dims[i]]; std::vector<shape::dynamic_dimension> output_dyn_dims(input.ndim());
output_strides[i] = input_strides[dims[i]]; std::transform(dims.cbegin(), dims.cend(), output_dyn_dims.begin(), [&](auto dim) {
return input.dyn_dims()[dim];
});
return {input.type(), output_dyn_dims};
}
else
{
auto input_lens = input.lens();
auto input_strides = input.strides();
std::vector<size_t> output_lens(input.ndim());
std::vector<size_t> output_strides(input.ndim());
for(std::size_t i = 0; i < input.ndim(); i++)
{
output_lens[i] = input_lens[dims[i]];
output_strides[i] = input_strides[dims[i]];
}
return {input.type(), output_lens, output_strides};
} }
return {t, output_lens, output_strides};
} }
argument compute(shape output_shape, std::vector<argument> args) const
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ {
return args[0].reshape(output_shape); return args[0].reshape(dyn_out.computed_shape);
} }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -29,11 +29,20 @@ ...@@ -29,11 +29,20 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp> #include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/dyn_output.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
/**
* Adds dimensions to a tensor based on the axes attribute.
* `axes` are based on the number of output shape dimensions and should not contain duplicates.
* `steps` are for modifying dimensions added to the middle of the original shape.
* Each step must be a factor of the original dimension.
* ex: unsqueeze(shape = [3, 4, 10], axes = [2, 4, 5], steps = [2]) -> shape = [3, 4, 2, 5, 1, 1]
* Dynamic shape version does not handle `steps`.
*/
struct unsqueeze struct unsqueeze
{ {
std::vector<int64_t> axes; std::vector<int64_t> axes;
...@@ -56,63 +65,89 @@ struct unsqueeze ...@@ -56,63 +65,89 @@ struct unsqueeze
std::string name() const { return "unsqueeze"; } std::string name() const { return "unsqueeze"; }
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this, true}.has(1);
auto input_shape = inputs[0]; auto input_shape = inputs[0];
auto type = input_shape.type();
auto old_lens = input_shape.lens(); if(input_shape.dynamic())
auto old_strides = input_shape.strides();
if(input_shape.scalar())
{ {
if(old_lens.size() == 1 and old_lens.front() == 1) if(not steps.empty())
return shape{type, old_lens}; {
else MIGRAPHX_THROW("UNSQUEEZE_dyn: nonempty steps attribute");
MIGRAPHX_THROW("UNSQUEEZE: Input must be a scalar"); }
std::vector<shape::dynamic_dimension> dyn_dims = {};
auto new_ndim = input_shape.ndim() + axes.size();
std::size_t k = 0;
for(auto i : range(new_ndim))
{
if(std::find(axes.begin(), axes.end(), i) != axes.end())
{
dyn_dims.push_back({1, 1, 0});
}
else
{
dyn_dims.push_back(input_shape.dyn_dims().at(k++));
}
}
return {input_shape.type(), dyn_dims};
} }
else
{
auto type = input_shape.type();
auto old_lens = input_shape.lens();
auto old_strides = input_shape.strides();
if(input_shape.scalar())
{
if(old_lens.size() == 1 and old_lens.front() == 1)
return shape{type, old_lens};
else
MIGRAPHX_THROW("UNSQUEEZE: Input must be a scalar");
}
if(steps.size() > axes.size()) if(steps.size() > axes.size())
MIGRAPHX_THROW("UNSQUEEZE: Steps provided with no axis"); MIGRAPHX_THROW("UNSQUEEZE: Steps provided with no axis");
std::size_t new_size = old_lens.size() + axes.size(); std::size_t new_size = old_lens.size() + axes.size();
std::vector<std::size_t> new_lens(new_size); std::vector<std::size_t> new_lens(new_size);
std::vector<std::size_t> new_strides(new_size); std::vector<std::size_t> new_strides(new_size);
std::size_t p = 0; std::size_t p = 0;
for(auto i : range(new_size)) for(auto i : range(new_size))
{
auto axis_idx = std::find(axes.begin(), axes.end(), i) - axes.begin();
if(axis_idx < axes.size())
{ {
std::int64_t step = 1; auto axis_idx = std::find(axes.begin(), axes.end(), i) - axes.begin();
if(axis_idx < steps.size()) if(axis_idx < axes.size())
step = steps[axis_idx];
if(step == 0)
MIGRAPHX_THROW("UNSQUEEZE: step must be non-zero");
new_lens[i] = step;
if(p < old_strides.size())
{ {
if((old_lens[p] % step) != 0) std::int64_t step = 1;
MIGRAPHX_THROW("UNSQUEEZE: Axis dimenstion is not divisible by step"); if(axis_idx < steps.size())
old_lens[p] /= step; step = steps[axis_idx];
new_strides[i] = old_strides[p] * old_lens[p]; if(step == 0)
MIGRAPHX_THROW("UNSQUEEZE: step must be non-zero");
new_lens[i] = step;
if(p < old_strides.size())
{
if((old_lens[p] % step) != 0)
MIGRAPHX_THROW("UNSQUEEZE: Axis dimenstion is not divisible by step");
old_lens[p] /= step;
new_strides[i] = old_strides[p] * old_lens[p];
}
else
{
if(step != 1)
MIGRAPHX_THROW("UNSQUEEZE: Step must be 1 for extra axes");
new_strides[i] = 1;
}
} }
else else
{ {
if(step != 1) new_lens[i] = old_lens[p];
MIGRAPHX_THROW("UNSQUEEZE: Step must be 1 for extra axes"); new_strides[i] = old_strides[p++];
new_strides[i] = 1;
} }
} }
else return shape{type, new_lens, new_strides};
{
new_lens[i] = old_lens[p];
new_strides[i] = old_strides[p++];
}
} }
return shape{type, new_lens, new_strides};
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ {
return args[0].reshape(output_shape); return args[0].reshape(dyn_out.computed_shape);
} }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -115,6 +115,7 @@ struct program ...@@ -115,6 +115,7 @@ struct program
print_func) const; print_func) const;
void print_graph(std::ostream& os, bool brief = false) const; void print_graph(std::ostream& os, bool brief = false) const;
void print_py(std::ostream& os) const;
void print_cpp(std::ostream& os) const; void print_cpp(std::ostream& os) const;
void dry_run(parameter_map params) const; void dry_run(parameter_map params) const;
......
...@@ -101,6 +101,19 @@ struct shape ...@@ -101,6 +101,19 @@ struct shape
friend bool operator==(const dynamic_dimension& x, const dynamic_dimension& y); friend bool operator==(const dynamic_dimension& x, const dynamic_dimension& y);
friend bool operator!=(const dynamic_dimension& x, const dynamic_dimension& y); friend bool operator!=(const dynamic_dimension& x, const dynamic_dimension& y);
friend std::ostream& operator<<(std::ostream& os, const dynamic_dimension& x); friend std::ostream& operator<<(std::ostream& os, const dynamic_dimension& x);
// compare to fixed std::size_t dimension
friend bool operator==(const dynamic_dimension& x, const std::size_t& y);
friend bool operator==(const std::size_t& x, const dynamic_dimension& y);
friend bool operator!=(const dynamic_dimension& x, const std::size_t& y);
friend bool operator!=(const std::size_t& x, const dynamic_dimension& y);
// add and subtract fixed std::size_t dimension
dynamic_dimension& operator+=(const std::size_t& x);
dynamic_dimension& operator-=(const std::size_t& x);
friend dynamic_dimension operator+(const dynamic_dimension& x, const std::size_t& y);
friend dynamic_dimension operator+(const std::size_t& x, const dynamic_dimension& y);
friend dynamic_dimension operator-(const dynamic_dimension& x, const std::size_t& y);
}; };
static const std::vector<type_t>& types(); static const std::vector<type_t>& types();
......
...@@ -31,6 +31,9 @@ ...@@ -31,6 +31,9 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
/**
* Iterates the given function over the indices from the shape in order.
*/
template <class F> template <class F>
void shape_for_each(const migraphx::shape& s, F f) void shape_for_each(const migraphx::shape& s, F f)
{ {
...@@ -51,7 +54,6 @@ void shape_for_each(const migraphx::shape& s, F f) ...@@ -51,7 +54,6 @@ void shape_for_each(const migraphx::shape& s, F f)
call(indices); call(indices);
} }
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -77,14 +77,14 @@ static void update_pooling(const instruction_ref& input, const instruction_ref& ...@@ -77,14 +77,14 @@ static void update_pooling(const instruction_ref& input, const instruction_ref&
{ {
return; return;
} }
auto kdims = input->get_shape().lens().size() - 2; auto kdims = input->get_shape().ndim() - 2;
if(std::equal(op.padding.begin(), if(std::equal(op.padding.begin(),
op.padding.begin() + kdims, op.padding.begin() + kdims,
op.padding.begin() + kdims, op.padding.begin() + kdims,
op.padding.end())) op.padding.end()))
return; return;
std::vector<int64_t> padding(input->get_shape().lens().size() * 2, 0); std::vector<int64_t> padding(input->get_shape().ndim() * 2, 0);
std::vector<size_t> pads_l(op.padding.begin(), op.padding.begin() + kdims); std::vector<size_t> pads_l(op.padding.begin(), op.padding.begin() + kdims);
std::vector<size_t> pads_r(op.padding.begin() + kdims, op.padding.end()); std::vector<size_t> pads_r(op.padding.begin() + kdims, op.padding.end());
op.padding = std::vector<size_t>(kdims * 2, 0); op.padding = std::vector<size_t>(kdims * 2, 0);
......
...@@ -302,6 +302,24 @@ void instruction::replace_mod_argument(module_ref old, module_ref new_mod) ...@@ -302,6 +302,24 @@ void instruction::replace_mod_argument(module_ref old, module_ref new_mod)
std::replace(module_args.begin(), module_args.end(), old, new_mod); std::replace(module_args.begin(), module_args.end(), old, new_mod);
} }
bool instruction::is_undefined() const
{
if(op.name() == "undefined")
{
return true;
}
else if(this->inputs().empty())
{
return false;
}
else
{
return std::all_of(this->inputs().begin(), this->inputs().end(), [](auto arg) {
return arg->is_undefined();
});
}
}
bool instruction::can_eval() const bool instruction::can_eval() const
{ {
if(op.name() == "@literal") if(op.name() == "@literal")
......
...@@ -789,6 +789,22 @@ static std::string cpp_var_name(const std::string& name) ...@@ -789,6 +789,22 @@ static std::string cpp_var_name(const std::string& name)
return to_c_id("x_" + replace_string(name, ":", "_module_")); return to_c_id("x_" + replace_string(name, ":", "_module_"));
} }
static void print_py_op(std::ostream& os, const operation& op)
{
auto v = op.to_value();
os << "migraphx.op(" << enclose_name(op.name());
auto default_values = make_op(op.name()).to_value();
for(auto&& x : v)
{
auto name = x.get_key();
if(default_values[name] == x)
continue;
os << ", " << name << "=" << to_json_string(x.without_key());
}
os << ")";
}
static void print_make_op(std::ostream& os, const operation& op) static void print_make_op(std::ostream& os, const operation& op)
{ {
auto v = op.to_value(); auto v = op.to_value();
...@@ -804,6 +820,15 @@ static void print_make_op(std::ostream& os, const operation& op) ...@@ -804,6 +820,15 @@ static void print_make_op(std::ostream& os, const operation& op)
os << ")"; os << ")";
} }
static void print_py_shape(std::ostream& os, const migraphx::shape& s)
{
os << "migraphx.shape(type=" << to_json_string(s.type_string())
<< ", lens=" << to_json_string(s.lens());
if(not s.standard())
os << ", strides=" << to_json_string(s.strides());
os << ")";
}
static void print_cpp_shape(std::ostream& os, const migraphx::shape& s) static void print_cpp_shape(std::ostream& os, const migraphx::shape& s)
{ {
os << "migraphx::shape{migraphx::shape::" << s.type_string(); os << "migraphx::shape{migraphx::shape::" << s.type_string();
...@@ -813,6 +838,68 @@ static void print_cpp_shape(std::ostream& os, const migraphx::shape& s) ...@@ -813,6 +838,68 @@ static void print_cpp_shape(std::ostream& os, const migraphx::shape& s)
os << "}"; os << "}";
} }
std::unordered_map<instruction_ref, std::string>
module::print_py(std::ostream& os,
const std::string& mname,
std::unordered_map<instruction_ref, std::string> names) const
{
// cppcheck-suppress variableScope
unsigned long seed = names.size();
auto last = std::prev(this->end());
names = this->print(
[&](auto ins, auto ins_names) {
std::vector<std::string> input_vars;
std::transform(ins->inputs().begin(),
ins->inputs().end(),
std::back_inserter(input_vars),
[&](auto input) { return cpp_var_name(ins_names.at(input)); });
if(ins != last)
os << cpp_var_name(ins_names.at(ins)) << " = ";
if(ins->name() == "@literal")
{
os << mname << ".add_literal(";
bool use_abs = false;
ins->get_literal().visit([&](auto v) {
use_abs = std::none_of(v.begin(), v.end(), [](auto x) { return x < 0; });
});
// Disable abs for now
use_abs = false;
if(use_abs)
os << "migraphx.abs_literal(";
os << "migraphx.generate_literal(";
print_py_shape(os, ins->get_shape());
os << ", " << seed << ")";
if(use_abs)
os << ")";
os << ")" << std::endl;
seed++;
}
else if(ins->name() == "@param")
{
std::string name = any_cast<builtin::param>(ins->get_operator()).parameter;
os << mname << ".add_parameter(" << enclose_name(name) << ",";
print_py_shape(os, ins->get_shape());
os << ")" << std::endl;
}
else if(ins->name() == "@return")
{
os << mname << ".add_return([" << join_strings(input_vars, ", ") << "])"
<< std::endl;
}
else
{
assert(ins->name().front() != '@');
os << mname << ".add_instruction(";
print_py_op(os, ins->get_operator());
os << ", [" << join_strings(input_vars, ", ") << "]";
os << ")" << std::endl;
}
},
names);
return names;
}
std::unordered_map<instruction_ref, std::string> std::unordered_map<instruction_ref, std::string>
module::print_cpp(std::ostream& os, module::print_cpp(std::ostream& os,
const std::string& mname, const std::string& mname,
...@@ -874,6 +961,8 @@ module::print_cpp(std::ostream& os, ...@@ -874,6 +961,8 @@ module::print_cpp(std::ostream& os,
return names; return names;
} }
void module::print_py(std::ostream& os) const { this->print_py(os, this->name(), {}); }
void module::print_cpp(std::ostream& os) const { this->print_cpp(os, this->name(), {}); } void module::print_cpp(std::ostream& os) const { this->print_cpp(os, this->name(), {}); }
void module::annotate(std::ostream& os, std::function<void(instruction_ref)> a) const void module::annotate(std::ostream& os, std::function<void(instruction_ref)> a) const
......
...@@ -110,9 +110,19 @@ instruction_ref onnx_parser::node_info::add_bias(const std::vector<instruction_r ...@@ -110,9 +110,19 @@ instruction_ref onnx_parser::node_info::add_bias(const std::vector<instruction_r
{ {
if(args.size() == 3) if(args.size() == 3)
{ {
auto bias_bcast = mod->add_instruction( instruction_ref bias_bcast;
make_op("broadcast", {{"axis", axis}, {"out_lens", curr_ins->get_shape().lens()}}), // if curr_ins has a dynamic output shape use 2 input broadcast
args[2]); if(curr_ins->get_shape().dynamic())
{
bias_bcast =
mod->add_instruction(make_op("broadcast", {{"axis", axis}}), args[2], curr_ins);
}
else
{
bias_bcast = mod->add_instruction(
make_op("broadcast", {{"axis", axis}, {"out_lens", curr_ins->get_shape().lens()}}),
args[2]);
}
return mod->add_instruction(make_op("add"), curr_ins, bias_bcast); return mod->add_instruction(make_op("add"), curr_ins, bias_bcast);
} }
return curr_ins; return curr_ins;
...@@ -393,18 +403,31 @@ literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const ...@@ -393,18 +403,31 @@ literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const
literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
{ {
std::vector<std::size_t> dims(t.dims().begin(), t.dims().end()); std::vector<std::size_t> dims(t.dims().begin(), t.dims().end());
if(not t.external_data().empty()) auto type = get_type(t.data_type());
shape tensor_shape(type, dims);
auto external_data = t.external_data();
if(not external_data.empty())
{ {
const std::string& data_file = t.external_data().at(0).value(); const std::string& data_file = external_data.at(0).value();
auto raw_buffer = read_buffer(path + "/" + data_file); size_t num_data_fields = external_data.size();
size_t offset = 0;
size_t nbytes = tensor_shape.bytes();
if(num_data_fields > 1) // if offset field is present
{
offset = std::stoul(t.external_data().at(1).value());
}
if(num_data_fields > 2) // if nbytes field is present
{
nbytes = std::stoul(t.external_data().at(2).value());
}
auto raw_buffer = read_buffer(path + "/" + data_file, offset, nbytes);
std::string s(raw_buffer.begin(), raw_buffer.end()); std::string s(raw_buffer.begin(), raw_buffer.end());
auto type = get_type(t.data_type());
return create_literal(type, dims, s.data()); return create_literal(type, dims, s.data());
} }
if(t.has_raw_data()) if(t.has_raw_data())
{ {
const std::string& s = t.raw_data(); const std::string& s = t.raw_data();
auto type = get_type(t.data_type());
return create_literal(type, dims, s.data()); return create_literal(type, dims, s.data());
} }
......
...@@ -39,10 +39,19 @@ struct parse_gemm : op_parser<parse_gemm> ...@@ -39,10 +39,19 @@ struct parse_gemm : op_parser<parse_gemm>
onnx_parser::node_info info, onnx_parser::node_info info,
std::vector<instruction_ref> args) const std::vector<instruction_ref> args) const
{ {
float alpha = 1.0f; auto a_arg = args[0];
float beta = 1.0f; auto b_arg = args[1];
bool transa = false; if(a_arg->get_shape().ndim() != 2 or b_arg->get_shape().ndim() != 2)
bool transb = false; {
MIGRAPHX_THROW("PARSE_GEMM: A and B should be rank 2, A is rank " +
std::to_string(a_arg->get_shape().ndim()) + ", B is rank " +
std::to_string(b_arg->get_shape().ndim()));
}
float alpha = 1.0f;
float beta = 1.0f;
bool trans_a = false;
bool trans_b = false;
if(contains(info.attributes, "alpha")) if(contains(info.attributes, "alpha"))
{ {
alpha = parser.parse_value(info.attributes.at("alpha")).at<float>(); alpha = parser.parse_value(info.attributes.at("alpha")).at<float>();
...@@ -53,65 +62,73 @@ struct parse_gemm : op_parser<parse_gemm> ...@@ -53,65 +62,73 @@ struct parse_gemm : op_parser<parse_gemm>
} }
if(contains(info.attributes, "transA")) if(contains(info.attributes, "transA"))
{ {
transa = parser.parse_value(info.attributes.at("transA")).at<bool>(); trans_a = parser.parse_value(info.attributes.at("transA")).at<bool>();
} }
if(contains(info.attributes, "transB")) if(contains(info.attributes, "transB"))
{ {
transb = parser.parse_value(info.attributes.at("transB")).at<bool>(); trans_b = parser.parse_value(info.attributes.at("transB")).at<bool>();
} }
std::vector<int64_t> perm(args[0]->get_shape().lens().size()); std::vector<int64_t> perm = {1, 0};
std::iota(perm.begin(), perm.end(), int64_t{0}); auto dot_type = a_arg->get_shape().type();
// swap the last two elements
std::swap(*perm.rbegin(), *(perm.rbegin() + 1));
auto l1 = args[0];
auto dot_type = l1->get_shape().type();
if(alpha != 1.0f) if(alpha != 1.0f)
{ {
auto alpha_literal = info.add_literal(alpha); auto alpha_literal = info.add_literal(alpha);
l1 = info.add_broadcastable_binary_op("mul", alpha_literal, l1); a_arg = info.add_broadcastable_binary_op("mul", alpha_literal, a_arg);
if(l1->get_shape().type() != dot_type)
if(a_arg->get_shape().type() != dot_type)
{ {
l1 = info.add_instruction(make_op("convert", {{"target_type", dot_type}}), l1); a_arg =
info.add_instruction(make_op("convert", {{"target_type", dot_type}}), a_arg);
} }
} }
l1 = a_arg = (trans_a)
(transa) ? info.add_instruction(make_op("transpose", {{"permutation", perm}}), l1) : l1; ? info.add_instruction(make_op("transpose", {{"permutation", perm}}), a_arg)
auto l2 = (transb) : a_arg;
? info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[1]) b_arg = (trans_b)
: args[1]; ? info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[1])
: args[1];
auto ret = info.add_instruction(make_op("dot"), l1, l2); auto dot_ins = info.add_instruction(make_op("dot"), a_arg, b_arg);
if(args.size() == 3) if(args.size() == 3)
{ {
if(not float_equal(beta, 0.0f) && args[2]->get_shape().elements() > 0) if(not float_equal(beta, 0.0f))
{ {
auto out_lens = l1->get_shape().lens(); auto c_arg = args[2];
out_lens.back() = l2->get_shape().lens().back(); if(dot_ins->get_shape().dynamic())
auto l3 = args[2];
auto l3_lens = l3->get_shape().lens();
if(not std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end()))
{ {
l3 = info.add_instruction(make_op("multibroadcast", {{"out_lens", out_lens}}), c_arg = info.add_instruction(make_op("multibroadcast"), args[2], dot_ins);
args[2]);
} }
auto beta_literal = info.add_literal(beta); else
auto beta_l3 = info.add_broadcastable_binary_op("mul", l3, beta_literal);
if(beta_l3->get_shape().type() != dot_type)
{ {
beta_l3 = info.add_instruction(make_op("convert", {{"target_type", dot_type}}), auto out_lens = a_arg->get_shape().lens();
beta_l3); out_lens.back() = b_arg->get_shape().lens().back();
auto c_lens = c_arg->get_shape().lens();
if(not std::equal(
out_lens.begin(), out_lens.end(), c_lens.begin(), c_lens.end()))
{
c_arg = info.add_instruction(
make_op("multibroadcast", {{"out_lens", out_lens}}), args[2]);
}
} }
return info.add_instruction(make_op("add"), ret, beta_l3); if(not float_equal(beta, 1.0f))
{
auto beta_literal = info.add_literal(beta);
c_arg = info.add_broadcastable_binary_op("mul", c_arg, beta_literal);
if(c_arg->get_shape().type() != dot_type)
{
c_arg = info.add_instruction(
make_op("convert", {{"target_type", dot_type}}), c_arg);
}
}
return info.add_instruction(make_op("add"), dot_ins, c_arg);
} }
} }
return dot_ins;
return ret;
} }
}; };
......
...@@ -43,55 +43,79 @@ struct parse_matmul : op_parser<parse_matmul> ...@@ -43,55 +43,79 @@ struct parse_matmul : op_parser<parse_matmul>
const onnx_parser::node_info& info, const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const std::vector<instruction_ref> args) const
{ {
auto l0 = args[0]; auto a0 = args[0];
auto l1 = args[1]; auto a1 = args[1];
auto l0_lens = l0->get_shape().lens(); auto s0 = a0->get_shape();
auto l1_lens = l1->get_shape().lens(); auto s1 = a1->get_shape();
// args[0] is a vector, prepend 1 to the shape instruction_ref dot_res;
bool is_a_prepended = false; bool is_a_prepended = false;
if(l0_lens.size() == 1) bool is_b_appended = false;
if(s0.ndim() == 1)
{ {
is_a_prepended = true; is_a_prepended = true;
l0_lens.insert(l0_lens.begin(), 1); a0 = info.add_instruction(make_op("unsqueeze", {{"axes", {0}}}), args[0]);
l0 = info.add_instruction(make_op("unsqueeze", {{"axes", {0}}}), args[0]);
} }
if(s1.ndim() == 1)
bool is_b_appended = false;
if(l1_lens.size() == 1)
{ {
is_b_appended = true; is_b_appended = true;
l1_lens.push_back(1); a1 = info.add_instruction(make_op("unsqueeze", {{"axes", {1}}}), args[1]);
l1 = info.add_instruction(make_op("unsqueeze", {{"axes", {1}}}), args[1]);
} }
instruction_ref bl0 = l0; if(s0.dynamic() or s1.dynamic())
instruction_ref bl1 = l1;
if(not std::equal(
l0_lens.rbegin() + 2, l0_lens.rend(), l1_lens.rbegin() + 2, l1_lens.rend()))
{ {
auto l0_it = l0_lens.begin() + l0_lens.size() - 2; if(opd.op_name == "quant_dot")
std::vector<std::size_t> l0_broadcasted_lens(l0_lens.begin(), l0_it); {
auto l1_it = l1_lens.begin() + l1_lens.size() - 2; MIGRAPHX_THROW("PARSE_MATMUL: dynamic MatMulInteger not supported");
std::vector<std::size_t> l1_broadcasted_lens(l1_lens.begin(), l1_it); }
auto output_lens = compute_broadcasted_lens(l0_broadcasted_lens, l1_broadcasted_lens); auto s0_dds = a0->get_shape().to_dynamic().dyn_dims();
l0_broadcasted_lens = output_lens; auto s1_dds = a1->get_shape().to_dynamic().dyn_dims();
l0_broadcasted_lens.insert(l0_broadcasted_lens.end(), l0_it, l0_lens.end());
l1_broadcasted_lens = output_lens; // TODO: handling this case requires a new multibroadcast mode
l1_broadcasted_lens.insert(l1_broadcasted_lens.end(), l1_it, l1_lens.end()); if(not std::equal(
if(l0_lens != l0_broadcasted_lens) s0_dds.rbegin() + 2, s0_dds.rend(), s1_dds.rbegin() + 2, s1_dds.rend()))
{ {
bl0 = info.add_instruction( MIGRAPHX_THROW("PARSE_MATMUL: dynamic shape broadcasting not supported");
make_op("multibroadcast", {{"out_lens", l0_broadcasted_lens}}), l0);
} }
if(l1_lens != l1_broadcasted_lens)
dot_res = info.add_instruction(make_op(opd.op_name), a0, a1);
}
else
{
auto s0_lens = a0->get_shape().lens();
auto s1_lens = a1->get_shape().lens();
instruction_ref ba0 = a0;
instruction_ref ba1 = a1;
// try broadcasting if dimensions other than last two do not match
if(not std::equal(
s0_lens.rbegin() + 2, s0_lens.rend(), s1_lens.rbegin() + 2, s1_lens.rend()))
{ {
bl1 = info.add_instruction( auto l0_it = s0_lens.begin() + s0_lens.size() - 2;
make_op("multibroadcast", {{"out_lens", l1_broadcasted_lens}}), l1); std::vector<std::size_t> l0_broadcasted_lens(s0_lens.begin(), l0_it);
auto l1_it = s1_lens.begin() + s1_lens.size() - 2;
std::vector<std::size_t> l1_broadcasted_lens(s1_lens.begin(), l1_it);
auto output_lens =
compute_broadcasted_lens(l0_broadcasted_lens, l1_broadcasted_lens);
l0_broadcasted_lens = output_lens;
l0_broadcasted_lens.insert(l0_broadcasted_lens.end(), l0_it, s0_lens.end());
l1_broadcasted_lens = output_lens;
l1_broadcasted_lens.insert(l1_broadcasted_lens.end(), l1_it, s1_lens.end());
if(s0_lens != l0_broadcasted_lens)
{
ba0 = info.add_instruction(
make_op("multibroadcast", {{"out_lens", l0_broadcasted_lens}}), a0);
}
if(s1_lens != l1_broadcasted_lens)
{
ba1 = info.add_instruction(
make_op("multibroadcast", {{"out_lens", l1_broadcasted_lens}}), a1);
}
} }
dot_res = info.add_instruction(make_op(opd.op_name), ba0, ba1);
} }
instruction_ref dot_res = info.add_instruction(make_op(opd.op_name), bl0, bl1);
int64_t num_axis = static_cast<int64_t>(dot_res->get_shape().lens().size()); // squeeze the appended or prepended dimensions
int64_t num_axis = dot_res->get_shape().ndim();
if(is_a_prepended) if(is_a_prepended)
{ {
dot_res = info.add_instruction(make_op("squeeze", {{"axes", {num_axis - 2}}}), dot_res); dot_res = info.add_instruction(make_op("squeeze", {{"axes", {num_axis - 2}}}), dot_res);
......
...@@ -147,7 +147,13 @@ struct parse_pad : op_parser<parse_pad> ...@@ -147,7 +147,13 @@ struct parse_pad : op_parser<parse_pad>
{ {
auto mode = info.attributes.at("mode").s(); auto mode = info.attributes.at("mode").s();
if(mode == "reflect") if(mode == "reflect")
{
if(args.front()->get_shape().dynamic())
{
MIGRAPHX_THROW("PARSE_PAD: reflect padding with dynamic shape not supported");
}
return reflect_pad(info, pads, args.front()); return reflect_pad(info, pads, args.front());
}
if(mode != "constant") if(mode != "constant")
{ {
MIGRAPHX_THROW( MIGRAPHX_THROW(
......
...@@ -47,52 +47,42 @@ struct parse_pooling : op_parser<parse_pooling> ...@@ -47,52 +47,42 @@ struct parse_pooling : op_parser<parse_pooling>
{"GlobalLpPool", "lpnorm"}}; {"GlobalLpPool", "lpnorm"}};
} }
instruction_ref parse(const op_desc& opd, value handle_values(const op_desc& opd,
const onnx_parser& /*parser*/, onnx_parser::node_info info,
onnx_parser::node_info info, const shape& in_shape,
std::vector<instruction_ref> args) const value values) const
{ {
const std::unordered_map<std::string, op::pooling_mode> mode_map = { auto kdims = in_shape.ndim() - 2;
{"max", op::pooling_mode::max},
{"average", op::pooling_mode::average},
{"lpnorm", op::pooling_mode::lpnorm}};
std::string mode = opd.op_name;
if(not contains(mode_map, mode))
{
MIGRAPHX_THROW("onnx pooling mode must be [\"max\", \"average\", \"lpnorm\"]");
}
operation op = make_op("pooling", {{"mode", mode_map.at(mode)}});
value values = op.to_value();
auto l0 = args[0];
auto in_lens = l0->get_shape().lens();
assert(in_lens.size() > 2);
auto kdims = in_lens.size() - 2;
if(starts_with(opd.onnx_name, "Global")) if(starts_with(opd.onnx_name, "Global"))
{ {
values["lengths"] = std::vector<size_t>(in_lens.begin() + 2, in_lens.end()); // if spatial dimensions are dynamic use dyn_global flag
if(in_shape.dynamic() and std::any_of(in_shape.dyn_dims().cbegin() + 2,
in_shape.dyn_dims().cend(),
[](auto dd) { return not dd.is_fixed(); }))
{
values["dyn_global"] = true;
values["lengths"] = std::vector<size_t>();
}
else
{
// works with static and fixed dynamic shape
auto m_lens = in_shape.max_lens();
values["lengths"] = std::vector<size_t>(m_lens.begin() + 2, m_lens.end());
}
} }
// does not support ceil_mode
if(contains(info.attributes, "ceil_mode")) if(contains(info.attributes, "ceil_mode"))
{ {
values["ceil_mode"] = static_cast<bool>(info.attributes.at("ceil_mode").i()); values["ceil_mode"] = static_cast<bool>(info.attributes.at("ceil_mode").i());
} }
// count include padding, if count include pad is 1, we always use
// explicit pad
int count_include_pad = 0;
if(contains(info.attributes, "count_include_pad"))
{
count_include_pad = info.attributes.at("count_include_pad").i();
}
if(contains(info.attributes, "strides")) if(contains(info.attributes, "strides"))
{ {
values["stride"].clear(); values["stride"].clear();
copy(info.attributes["strides"].ints(), std::back_inserter(values["stride"])); copy(info.attributes["strides"].ints(), std::back_inserter(values["stride"]));
check_attr_sizes(kdims, values["stride"].size(), "PARSE_POOLING: inconsistent strides"); check_attr_sizes(kdims, values["stride"].size(), "PARSE_POOLING: inconsistent strides");
} }
if(contains(info.attributes, "kernel_shape")) if(contains(info.attributes, "kernel_shape"))
{ {
values["lengths"].clear(); values["lengths"].clear();
...@@ -110,6 +100,46 @@ struct parse_pooling : op_parser<parse_pooling> ...@@ -110,6 +100,46 @@ struct parse_pooling : op_parser<parse_pooling>
// ensure pads availabe only when auto_pad is "NOT_SET" // ensure pads availabe only when auto_pad is "NOT_SET"
check_padding_mode(info, "POOLING"); check_padding_mode(info, "POOLING");
return values;
}
instruction_ref parse(const op_desc& opd,
const onnx_parser& /*parser*/,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
std::string mode = opd.op_name;
const std::unordered_map<std::string, op::pooling_mode> mode_map = {
{"max", op::pooling_mode::max},
{"average", op::pooling_mode::average},
{"lpnorm", op::pooling_mode::lpnorm}};
if(not contains(mode_map, mode))
{
MIGRAPHX_THROW(
"PARSE_POOLING: onnx pooling mode must be [\"max\", \"average\", \"lpnorm\"]");
}
operation op = make_op("pooling", {{"mode", mode_map.at(mode)}});
value values = op.to_value();
auto l0 = args[0];
auto in_shape = l0->get_shape();
assert(in_shape.ndim() > 2);
auto kdims = in_shape.ndim() - 2;
values = handle_values(opd, info, in_shape, values);
// count include padding, if count include pad is 1, we always use
// explicit pad
int count_include_pad = 0;
if(contains(info.attributes, "count_include_pad"))
{
if(in_shape.dynamic())
{
MIGRAPHX_THROW("PARSE_POOLING: count_include_pad attribute is not supported for "
"dynamic input shape");
}
count_include_pad = info.attributes.at("count_include_pad").i();
}
std::vector<int64_t> paddings; std::vector<int64_t> paddings;
float pad_val = ((mode == "max") ? std::numeric_limits<float>::lowest() : 0.0f); float pad_val = ((mode == "max") ? std::numeric_limits<float>::lowest() : 0.0f);
...@@ -123,14 +153,22 @@ struct parse_pooling : op_parser<parse_pooling> ...@@ -123,14 +153,22 @@ struct parse_pooling : op_parser<parse_pooling>
if(contains(info.attributes, "auto_pad")) if(contains(info.attributes, "auto_pad"))
{ {
values["padding"].clear(); if(in_shape.dynamic())
// return paddings could be empty, then setting to 0 for no padding {
cal_auto_padding_size(info, MIGRAPHX_THROW(
values, "PARSE_POOLING: Auto padding pooling with dynamic input shape not supported");
values["lengths"].to_vector<std::size_t>(), }
{1, 1}, else
in_lens, {
paddings); values["padding"].clear();
// return paddings could be empty, then setting to 0 for no padding
cal_auto_padding_size(info,
values,
values["lengths"].to_vector<std::size_t>(),
{1, 1},
in_shape.lens(),
paddings);
}
} }
if(paddings.size() != 2 * kdims) if(paddings.size() != 2 * kdims)
...@@ -150,6 +188,7 @@ struct parse_pooling : op_parser<parse_pooling> ...@@ -150,6 +188,7 @@ struct parse_pooling : op_parser<parse_pooling>
values["stride"].resize(kdims); values["stride"].resize(kdims);
std::fill_n(values["stride"].begin(), kdims, 1); std::fill_n(values["stride"].begin(), kdims, 1);
} }
// used to calculate the supposed output shape // used to calculate the supposed output shape
std::vector<int64_t> orig_padding = paddings; std::vector<int64_t> orig_padding = paddings;
...@@ -159,6 +198,11 @@ struct parse_pooling : op_parser<parse_pooling> ...@@ -159,6 +198,11 @@ struct parse_pooling : op_parser<parse_pooling>
if(not slice_start.empty()) if(not slice_start.empty())
{ {
if(in_shape.dynamic())
{
MIGRAPHX_THROW(
"PARSE_POOLING: asymmetric padding not supported for dynamic input shape");
}
// calculate expected output shape // calculate expected output shape
orig_padding.insert(orig_padding.begin() + kdims, 2, 0); orig_padding.insert(orig_padding.begin() + kdims, 2, 0);
orig_padding.insert(orig_padding.begin(), 2, 0); orig_padding.insert(orig_padding.begin(), 2, 0);
......
...@@ -68,8 +68,7 @@ instruction_ref parse_reduce_oper(const std::string& op_name, ...@@ -68,8 +68,7 @@ instruction_ref parse_reduce_oper(const std::string& op_name,
} }
else else
{ {
std::size_t n_dim = args.front()->get_shape().lens().size(); axes.resize(args.front()->get_shape().ndim());
axes.resize(n_dim);
std::iota(axes.begin(), axes.end(), 0); std::iota(axes.begin(), axes.end(), 0);
} }
} }
......
...@@ -49,7 +49,7 @@ struct parse_reshape : op_parser<parse_reshape> ...@@ -49,7 +49,7 @@ struct parse_reshape : op_parser<parse_reshape>
if(args.size() == 2) if(args.size() == 2)
{ {
auto s = args[1]->eval(); auto s = args[1]->eval();
check_arg_empty(s, "Reshape: dynamic shape is not supported"); check_arg_empty(s, "Reshape: non-constant shape input is not supported");
s.visit([&](auto v) { copy(v, std::back_inserter(dims)); }); s.visit([&](auto v) { copy(v, std::back_inserter(dims)); });
} }
......
...@@ -47,7 +47,7 @@ struct parse_transpose : op_parser<parse_transpose> ...@@ -47,7 +47,7 @@ struct parse_transpose : op_parser<parse_transpose>
} }
// if perm is empty, use the default value // if perm is empty, use the default value
auto n_dim = args.front()->get_shape().lens().size(); auto n_dim = args.front()->get_shape().ndim();
if(perm.empty()) if(perm.empty())
{ {
perm.resize(n_dim); perm.resize(n_dim);
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_trilu : op_parser<parse_trilu>
{
std::vector<op_desc> operators() const { return {{"Trilu"}}; }
instruction_ref parse(const op_desc&,
const onnx_parser&,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
auto input_shape = args[0]->get_shape();
assert(input_shape.ndim() >= 2);
auto input_lens = input_shape.lens();
size_t num_rows = *(input_lens.rbegin() + 1);
size_t num_cols = input_lens.back();
int k = 0;
bool upper = true;
if(args.size() > 1)
{
auto arg_k = args[1]->eval();
check_arg_empty(arg_k, "PARSE_TRILU: dynamic k not supported");
k = arg_k.at<int>();
}
if(k < 0)
MIGRAPHX_THROW("PARSE_TRILU: negative k values not supported");
if(contains(info.attributes, "upper"))
{
upper = static_cast<bool>(info.attributes.at("upper").i());
}
shape::type_t output_type = args[0]->get_shape().type();
// when creating the mask, if upper == 1,
// the inner triangle will have values set to 0
std::vector<bool> mask_mat(num_rows * num_cols, upper);
for(size_t i = 0; i < num_rows; i++)
{
for(size_t j = 0; j < std::min(k, static_cast<int>(num_cols)); j++)
{
mask_mat[i * num_cols + j] = not upper;
}
k++;
}
auto mask = info.add_literal(
migraphx::literal{migraphx::shape{output_type, {num_rows, num_cols}}, mask_mat});
return info.add_broadcastable_binary_op("mul", mask, args[0]);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment