"example/vscode:/vscode.git/clone" did not exist on "7fcff6f51e8df71b919603c518c75e717b1b177e"
Unverified commit 5a1feb14, authored by Ted Themistokleous, committed by GitHub

Merge branch 'develop' into fix_parse_if

parents cfbd5e8b 03c39761
@@ -53,15 +53,15 @@ struct softmax
     std::string name() const { return "softmax"; }
     shape normalize_compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(1);
-        if(inputs.at(0).packed())
+        check_shapes{inputs, *this, true}.has(1);
+        auto s0 = inputs[0];
+        if(s0.dynamic() or s0.packed())
         {
-            return inputs.at(0);
+            return s0;
         }
         else
         {
-            auto lens = inputs.at(0).lens();
-            return {inputs.at(0).type(), lens};
+            return {s0.type(), s0.lens()};
         }
     }
......
@@ -29,6 +29,7 @@
 #include <migraphx/config.hpp>
 #include <migraphx/value.hpp>
 #include <migraphx/op/normalize_attribute.hpp>
+#include <migraphx/dyn_output.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -54,52 +55,85 @@ struct squeeze
     std::string name() const { return "squeeze"; }
     shape normalize_compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(1);
-        auto input_shape = inputs[0];
-        auto type        = input_shape.type();
-        auto old_lens    = input_shape.lens();
-        auto old_strides = input_shape.strides();
-        if(std::any_of(axes.begin(), axes.end(), [&](auto axis) { return old_lens[axis] != 1; }))
-        {
-            MIGRAPHX_THROW("squeeze axis dimension should be equal to 1");
-        }
-        std::vector<std::size_t> new_lens;
-        std::vector<std::size_t> new_strides;
-        if(axes.empty())
-        {
-            for(auto i : range(old_lens.size()))
-            {
-                if(old_lens[i] != 1)
-                {
-                    new_lens.push_back(old_lens[i]);
-                    new_strides.push_back(old_strides[i]);
-                }
-            }
-        }
-        else
-        {
-            for(auto i : range(old_lens.size()))
-            {
-                if(std::find(axes.begin(), axes.end(), i) == axes.end())
-                {
-                    new_lens.push_back(old_lens[i]);
-                    new_strides.push_back(old_strides[i]);
-                }
-            }
-        }
-        if(new_lens.empty())
-        {
-            return shape{type};
-        }
-        else
-        {
-            return shape{type, new_lens, new_strides};
-        }
+        check_shapes{inputs, *this, true}.has(1);
+        auto input_shape = inputs[0];
+        if(input_shape.dynamic())
+        {
+            if(std::any_of(axes.begin(), axes.end(), [&](auto axis) {
+                   return input_shape.dyn_dims()[axis] != 1;
+               }))
+            {
+                MIGRAPHX_THROW(
+                    "SQUEEZE: dynamic axis dimension should be equal to {1, 1, 0} or {1, 1, 1}");
+            }
+            std::vector<shape::dynamic_dimension> dyn_dims = {};
+            if(axes.empty())
+            {
+                std::copy_if(input_shape.dyn_dims().cbegin(),
+                             input_shape.dyn_dims().cend(),
+                             std::back_inserter(dyn_dims),
+                             [&](auto dd) { return dd != 1; });
+            }
+            else
+            {
+                for(auto i : range(input_shape.ndim()))
+                {
+                    if(std::find(axes.begin(), axes.end(), i) == axes.end())
+                    {
+                        dyn_dims.push_back(input_shape.dyn_dims()[i]);
+                    }
+                }
+            }
+            return {input_shape.type(), dyn_dims};
+        }
+        else
+        {
+            auto type        = input_shape.type();
+            auto old_lens    = input_shape.lens();
+            auto old_strides = input_shape.strides();
+            if(std::any_of(
+                   axes.begin(), axes.end(), [&](auto axis) { return old_lens[axis] != 1; }))
+            {
+                MIGRAPHX_THROW("SQUEEZE: static axis dimension should be equal to 1");
+            }
+            std::vector<std::size_t> new_lens;
+            std::vector<std::size_t> new_strides;
+            if(axes.empty())
+            {
+                for(auto i : range(old_lens.size()))
+                {
+                    if(old_lens[i] != 1)
+                    {
+                        new_lens.push_back(old_lens[i]);
+                        new_strides.push_back(old_strides[i]);
+                    }
+                }
+            }
+            else
+            {
+                for(auto i : range(old_lens.size()))
+                {
+                    if(std::find(axes.begin(), axes.end(), i) == axes.end())
+                    {
+                        new_lens.push_back(old_lens[i]);
+                        new_strides.push_back(old_strides[i]);
+                    }
+                }
+            }
+            if(new_lens.empty())
+            {
+                return shape{type};
+            }
+            else
+            {
+                return shape{type, new_lens, new_strides};
+            }
+        }
     }

-    argument compute(shape output_shape, std::vector<argument> args) const
+    argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
     {
-        return args[0].reshape(output_shape);
+        return args[0].reshape(dyn_out.computed_shape);
     }
     std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
 };
......
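For reference, a minimal sketch of the dynamic path added above, built with the standard program/module API. The `{min, max, opt}` aggregate layout of `shape::dynamic_dimension` is an assumption inferred from the error message in this hunk:

#include <iostream>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/shape.hpp>

int main()
{
    using migraphx::shape;
    migraphx::program p;
    auto* mm = p.get_main_module();
    // dynamic input {2..4} x 1 x 3; squeezing axis 1 drops the fixed 1
    // (assumes dynamic_dimension aggregates as {min, max, opt})
    std::vector<shape::dynamic_dimension> dds{{2, 4, 0}, {1, 1, 0}, {3, 3, 0}};
    auto x   = mm->add_parameter("x", shape{shape::float_type, dds});
    auto ins = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), x);
    std::cout << ins->get_shape() << "\n"; // expect a {2..4} x 3 dynamic shape
}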
@@ -29,6 +29,7 @@
 #include <migraphx/config.hpp>
 #include <migraphx/value.hpp>
 #include <migraphx/op/normalize_attribute.hpp>
+#include <migraphx/dyn_output.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -45,17 +46,15 @@ struct transpose
     }
     std::string name() const { return "transpose"; }
     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(1);
-        auto input         = inputs.at(0);
-        auto input_lens    = input.lens();
-        auto input_strides = input.strides();
-        auto t             = input.type();
-        if(dims.size() != input_lens.size())
+        check_shapes{inputs, *this, true}.has(1);
+        auto input = inputs.at(0);
+        if(dims.size() != input.ndim())
         {
-            MIGRAPHX_THROW("Permutation has wrong number of axes");
+            MIGRAPHX_THROW("TRANSPOSE: Permutation has wrong number of axes");
         }
         std::vector<int64_t> axes(dims.size());
         std::iota(axes.begin(), axes.end(), 0);
@@ -63,19 +62,36 @@ struct transpose
         {
             MIGRAPHX_THROW("TRANSPOSE: Invalid permutation");
         }
-        std::vector<size_t> output_lens(input_lens.size());
-        std::vector<size_t> output_strides(input_lens.size());
-        for(std::size_t i = 0; i < output_lens.size(); i++)
-        {
-            output_lens[i]    = input_lens[dims[i]];
-            output_strides[i] = input_strides[dims[i]];
-        }
-        return {t, output_lens, output_strides};
+
+        if(input.dynamic())
+        {
+            std::vector<shape::dynamic_dimension> output_dyn_dims(input.ndim());
+            std::transform(dims.cbegin(), dims.cend(), output_dyn_dims.begin(), [&](auto dim) {
+                return input.dyn_dims()[dim];
+            });
+            return {input.type(), output_dyn_dims};
+        }
+        else
+        {
+            auto input_lens    = input.lens();
+            auto input_strides = input.strides();
+            std::vector<size_t> output_lens(input.ndim());
+            std::vector<size_t> output_strides(input.ndim());
+            for(std::size_t i = 0; i < input.ndim(); i++)
+            {
+                output_lens[i]    = input_lens[dims[i]];
+                output_strides[i] = input_strides[dims[i]];
+            }
+            return {input.type(), output_lens, output_strides};
+        }
     }

-    argument compute(shape output_shape, std::vector<argument> args) const
+    argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
     {
-        return args[0].reshape(output_shape);
+        return args[0].reshape(dyn_out.computed_shape);
     }
     std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
 };
......
@@ -29,11 +29,20 @@
 #include <migraphx/config.hpp>
 #include <migraphx/value.hpp>
 #include <migraphx/op/normalize_attribute.hpp>
+#include <migraphx/dyn_output.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace op {

+/**
+ * Adds dimensions to a tensor based on the axes attribute.
+ * `axes` are based on the number of output shape dimensions and should not contain duplicates.
+ * `steps` are for modifying dimensions added to the middle of the original shape.
+ * Each step must be a factor of the original dimension.
+ * ex: unsqueeze(shape = [3, 4, 10], axes = [2, 4, 5], steps = [2]) -> shape = [3, 4, 2, 5, 1, 1]
+ * Dynamic shape version does not handle `steps`.
+ */
 struct unsqueeze
 {
     std::vector<int64_t> axes;
@@ -56,63 +65,89 @@ struct unsqueeze
     std::string name() const { return "unsqueeze"; }
     shape normalize_compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(1);
-        auto input_shape = inputs[0];
-        auto type        = input_shape.type();
-        auto old_lens    = input_shape.lens();
-        auto old_strides = input_shape.strides();
-        if(input_shape.scalar())
-        {
-            if(old_lens.size() == 1 and old_lens.front() == 1)
-                return shape{type, old_lens};
-            else
-                MIGRAPHX_THROW("UNSQUEEZE: Input must be a scalar");
-        }
-        if(steps.size() > axes.size())
-            MIGRAPHX_THROW("UNSQUEEZE: Steps provided with no axis");
-        std::size_t new_size = old_lens.size() + axes.size();
-        std::vector<std::size_t> new_lens(new_size);
-        std::vector<std::size_t> new_strides(new_size);
-        std::size_t p = 0;
-        for(auto i : range(new_size))
-        {
-            auto axis_idx = std::find(axes.begin(), axes.end(), i) - axes.begin();
-            if(axis_idx < axes.size())
-            {
-                std::int64_t step = 1;
-                if(axis_idx < steps.size())
-                    step = steps[axis_idx];
-                if(step == 0)
-                    MIGRAPHX_THROW("UNSQUEEZE: step must be non-zero");
-                new_lens[i] = step;
-                if(p < old_strides.size())
-                {
-                    if((old_lens[p] % step) != 0)
-                        MIGRAPHX_THROW("UNSQUEEZE: Axis dimension is not divisible by step");
-                    old_lens[p] /= step;
-                    new_strides[i] = old_strides[p] * old_lens[p];
-                }
-                else
-                {
-                    if(step != 1)
-                        MIGRAPHX_THROW("UNSQUEEZE: Step must be 1 for extra axes");
-                    new_strides[i] = 1;
-                }
-            }
-            else
-            {
-                new_lens[i]    = old_lens[p];
-                new_strides[i] = old_strides[p++];
-            }
-        }
-        return shape{type, new_lens, new_strides};
+        check_shapes{inputs, *this, true}.has(1);
+        auto input_shape = inputs[0];
+
+        if(input_shape.dynamic())
+        {
+            if(not steps.empty())
+            {
+                MIGRAPHX_THROW("UNSQUEEZE_dyn: nonempty steps attribute");
+            }
+            std::vector<shape::dynamic_dimension> dyn_dims = {};
+            auto new_ndim = input_shape.ndim() + axes.size();
+            std::size_t k = 0;
+            for(auto i : range(new_ndim))
+            {
+                if(std::find(axes.begin(), axes.end(), i) != axes.end())
+                {
+                    dyn_dims.push_back({1, 1, 0});
+                }
+                else
+                {
+                    dyn_dims.push_back(input_shape.dyn_dims().at(k++));
+                }
+            }
+            return {input_shape.type(), dyn_dims};
+        }
+        else
+        {
+            auto type        = input_shape.type();
+            auto old_lens    = input_shape.lens();
+            auto old_strides = input_shape.strides();
+            if(input_shape.scalar())
+            {
+                if(old_lens.size() == 1 and old_lens.front() == 1)
+                    return shape{type, old_lens};
+                else
+                    MIGRAPHX_THROW("UNSQUEEZE: Input must be a scalar");
+            }
+            if(steps.size() > axes.size())
+                MIGRAPHX_THROW("UNSQUEEZE: Steps provided with no axis");
+            std::size_t new_size = old_lens.size() + axes.size();
+            std::vector<std::size_t> new_lens(new_size);
+            std::vector<std::size_t> new_strides(new_size);
+            std::size_t p = 0;
+            for(auto i : range(new_size))
+            {
+                auto axis_idx = std::find(axes.begin(), axes.end(), i) - axes.begin();
+                if(axis_idx < axes.size())
+                {
+                    std::int64_t step = 1;
+                    if(axis_idx < steps.size())
+                        step = steps[axis_idx];
+                    if(step == 0)
+                        MIGRAPHX_THROW("UNSQUEEZE: step must be non-zero");
+                    new_lens[i] = step;
+                    if(p < old_strides.size())
+                    {
+                        if((old_lens[p] % step) != 0)
+                            MIGRAPHX_THROW("UNSQUEEZE: Axis dimension is not divisible by step");
+                        old_lens[p] /= step;
+                        new_strides[i] = old_strides[p] * old_lens[p];
+                    }
+                    else
+                    {
+                        if(step != 1)
+                            MIGRAPHX_THROW("UNSQUEEZE: Step must be 1 for extra axes");
+                        new_strides[i] = 1;
+                    }
+                }
+                else
+                {
+                    new_lens[i]    = old_lens[p];
+                    new_strides[i] = old_strides[p++];
+                }
+            }
+            return shape{type, new_lens, new_strides};
+        }
     }

-    argument compute(shape output_shape, std::vector<argument> args) const
+    argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
     {
-        return args[0].reshape(output_shape);
+        return args[0].reshape(dyn_out.computed_shape);
     }
     std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
 };
......
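The `steps` attribute documented above is easiest to see executed. A small sketch of the doc comment's static example, using the standard MIGraphX program API (the expected output lens come from tracing the loop above):

#include <iostream>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/shape.hpp>

int main()
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    // [3, 4, 10] with axes={2, 4, 5} and steps={2}: axis 2 splits the
    // original dim 10 into 2 x 5, and axes 4 and 5 become trailing 1s
    migraphx::shape s{migraphx::shape::float_type, {3, 4, 10}};
    auto x   = mm->add_parameter("x", s);
    auto ins = mm->add_instruction(
        migraphx::make_op("unsqueeze", {{"axes", {2, 4, 5}}, {"steps", {2}}}), x);
    std::cout << ins->get_shape() << "\n"; // lens: {3, 4, 2, 5, 1, 1}
}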
@@ -115,6 +115,7 @@ struct program
                     print_func) const;
     void print_graph(std::ostream& os, bool brief = false) const;
+    void print_py(std::ostream& os) const;
     void print_cpp(std::ostream& os) const;

     void dry_run(parameter_map params) const;
......
@@ -101,6 +101,12 @@ struct shape
         friend bool operator==(const dynamic_dimension& x, const dynamic_dimension& y);
         friend bool operator!=(const dynamic_dimension& x, const dynamic_dimension& y);
         friend std::ostream& operator<<(std::ostream& os, const dynamic_dimension& x);
+
+        // compare to fixed std::size_t dimension
+        friend bool operator==(const dynamic_dimension& x, const std::size_t& y);
+        friend bool operator==(const std::size_t& x, const dynamic_dimension& y);
+        friend bool operator!=(const dynamic_dimension& x, const std::size_t& y);
+        friend bool operator!=(const std::size_t& x, const dynamic_dimension& y);
     };

     static const std::vector<type_t>& types();
......
@@ -31,6 +31,9 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

+/**
+ * Iterates the given function over the indices from the shape in order.
+ */
 template <class F>
 void shape_for_each(const migraphx::shape& s, F f)
 {
@@ -51,7 +54,6 @@ void shape_for_each(const migraphx::shape& s, F f)
         call(indices);
     }
 }
-
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
......
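For reference, a small usage sketch of `shape_for_each`, assuming (from the `call(indices)` body above) that the callback receives the multi-index as a `std::vector<std::size_t>`:

#include <cstddef>
#include <iostream>
#include <vector>
#include <migraphx/shape.hpp>
#include <migraphx/shape_for_each.hpp>

int main()
{
    // visits (0,0) (0,1) (0,2) (1,0) (1,1) (1,2), last axis fastest
    migraphx::shape s{migraphx::shape::float_type, {2, 3}};
    migraphx::shape_for_each(s, [](const std::vector<std::size_t>& idx) {
        std::cout << "(" << idx[0] << "," << idx[1] << ")\n";
    });
}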
@@ -77,14 +77,14 @@ static void update_pooling(const instruction_ref& input, const instruction_ref&
     {
         return;
     }

-    auto kdims = input->get_shape().lens().size() - 2;
+    auto kdims = input->get_shape().ndim() - 2;
     if(std::equal(op.padding.begin(),
                   op.padding.begin() + kdims,
                   op.padding.begin() + kdims,
                   op.padding.end()))
         return;

-    std::vector<int64_t> padding(input->get_shape().lens().size() * 2, 0);
+    std::vector<int64_t> padding(input->get_shape().ndim() * 2, 0);
     std::vector<size_t> pads_l(op.padding.begin(), op.padding.begin() + kdims);
     std::vector<size_t> pads_r(op.padding.begin() + kdims, op.padding.end());
     op.padding = std::vector<size_t>(kdims * 2, 0);
......
@@ -302,6 +302,24 @@ void instruction::replace_mod_argument(module_ref old, module_ref new_mod)
     std::replace(module_args.begin(), module_args.end(), old, new_mod);
 }

+bool instruction::is_undefined() const
+{
+    if(op.name() == "undefined")
+    {
+        return true;
+    }
+    else if(this->inputs().empty())
+    {
+        return false;
+    }
+    else
+    {
+        return std::all_of(this->inputs().begin(), this->inputs().end(), [](auto arg) {
+            return arg->is_undefined();
+        });
+    }
+}
+
 bool instruction::can_eval() const
 {
     if(op.name() == "@literal")
......
@@ -789,6 +789,22 @@ static std::string cpp_var_name(const std::string& name)
     return to_c_id("x_" + replace_string(name, ":", "_module_"));
 }

+static void print_py_op(std::ostream& os, const operation& op)
+{
+    auto v = op.to_value();
+    os << "migraphx.op(" << enclose_name(op.name());
+    auto default_values = make_op(op.name()).to_value();
+    for(auto&& x : v)
+    {
+        auto name = x.get_key();
+        if(default_values[name] == x)
+            continue;
+        os << ", " << name << "=" << to_json_string(x.without_key());
+    }
+    os << ")";
+}
+
 static void print_make_op(std::ostream& os, const operation& op)
 {
     auto v = op.to_value();
@@ -804,6 +820,14 @@ static void print_make_op(std::ostream& os, const operation& op)
     os << ")";
 }

+static void print_py_shape(std::ostream& os, const migraphx::shape& s)
+{
+    os << "migraphx.shape(" << s.type_string() << ", lens=" << to_json_string(s.lens());
+    if(not s.standard())
+        os << ", strides=" << to_json_string(s.strides());
+    os << ")";
+}
+
 static void print_cpp_shape(std::ostream& os, const migraphx::shape& s)
 {
     os << "migraphx::shape{migraphx::shape::" << s.type_string();
@@ -813,6 +837,68 @@ static void print_cpp_shape(std::ostream& os, const migraphx::shape& s)
     os << "}";
 }

+std::unordered_map<instruction_ref, std::string>
+module::print_py(std::ostream& os,
+                 const std::string& mname,
+                 std::unordered_map<instruction_ref, std::string> names) const
+{
+    // cppcheck-suppress variableScope
+    unsigned long seed = names.size();
+    auto last          = std::prev(this->end());
+    names              = this->print(
+        [&](auto ins, auto ins_names) {
+            std::vector<std::string> input_vars;
+            std::transform(ins->inputs().begin(),
+                           ins->inputs().end(),
+                           std::back_inserter(input_vars),
+                           [&](auto input) { return cpp_var_name(ins_names.at(input)); });
+            if(ins != last)
+                os << cpp_var_name(ins_names.at(ins)) << " = ";
+            if(ins->name() == "@literal")
+            {
+                os << mname << ".add_literal(";
+                bool use_abs = false;
+                ins->get_literal().visit([&](auto v) {
+                    use_abs = std::none_of(v.begin(), v.end(), [](auto x) { return x < 0; });
+                });
+                // Disable abs for now
+                use_abs = false;
+                if(use_abs)
+                    os << "migraphx.abs_literal(";
+                os << "migraphx.generate_literal(";
+                print_py_shape(os, ins->get_shape());
+                os << ", " << seed << ")";
+                if(use_abs)
+                    os << ")";
+                os << ")" << std::endl;
+                seed++;
+            }
+            else if(ins->name() == "@param")
+            {
+                std::string name = any_cast<builtin::param>(ins->get_operator()).parameter;
+                os << mname << ".add_parameter(" << enclose_name(name) << ",";
+                print_py_shape(os, ins->get_shape());
+                os << ")" << std::endl;
+            }
+            else if(ins->name() == "@return")
+            {
+                os << mname << ".add_return([" << join_strings(input_vars, ", ") << "])"
+                   << std::endl;
+            }
+            else
+            {
+                assert(ins->name().front() != '@');
+                os << mname << ".add_instruction(";
+                print_py_op(os, ins->get_operator());
+                os << ", [" << join_strings(input_vars, ", ") << "]";
+                os << ")" << std::endl;
+            }
+        },
+        names);
+    return names;
+}
+
 std::unordered_map<instruction_ref, std::string>
 module::print_cpp(std::ostream& os,
                   const std::string& mname,
@@ -874,6 +960,8 @@ module::print_cpp(std::ostream& os,
     return names;
 }

+void module::print_py(std::ostream& os) const { this->print_py(os, this->name(), {}); }
+
 void module::print_cpp(std::ostream& os) const { this->print_cpp(os, this->name(), {}); }

 void module::annotate(std::ostream& os, std::function<void(instruction_ref)> a) const
......
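With the overloads added above, dumping a Python reconstruction of a program becomes a one-liner; a minimal sketch (the ONNX file name is a placeholder):

#include <iostream>
#include <migraphx/onnx.hpp>
#include <migraphx/program.hpp>

int main()
{
    // prints "p = migraphx.program()" followed by add_parameter /
    // add_instruction calls that rebuild the graph via the Python bindings
    migraphx::program p = migraphx::parse_onnx("model.onnx");
    p.print_py(std::cout);
}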
@@ -393,18 +393,31 @@ literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const
 literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
 {
     std::vector<std::size_t> dims(t.dims().begin(), t.dims().end());
-    if(not t.external_data().empty())
+    auto type = get_type(t.data_type());
+    shape tensor_shape(type, dims);
+    auto external_data = t.external_data();
+    if(not external_data.empty())
     {
-        const std::string& data_file = t.external_data().at(0).value();
-        auto raw_buffer              = read_buffer(path + "/" + data_file);
+        const std::string& data_file = external_data.at(0).value();
+        size_t num_data_fields       = external_data.size();
+        size_t offset                = 0;
+        size_t nbytes                = tensor_shape.bytes();
+        if(num_data_fields > 1) // if offset field is present
+        {
+            offset = std::stoul(t.external_data().at(1).value());
+        }
+        if(num_data_fields > 2) // if nbytes field is present
+        {
+            nbytes = std::stoul(t.external_data().at(2).value());
+        }
+        auto raw_buffer = read_buffer(path + "/" + data_file, offset, nbytes);
         std::string s(raw_buffer.begin(), raw_buffer.end());
-        auto type = get_type(t.data_type());
         return create_literal(type, dims, s.data());
     }
     if(t.has_raw_data())
     {
         const std::string& s = t.raw_data();
-        auto type            = get_type(t.data_type());
         return create_literal(type, dims, s.data());
     }
......
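The parser now honors the optional offset and length entries of an ONNX external tensor, assuming they follow the location entry positionally. A self-contained sketch of the partial read that the extended `read_buffer(path, offset, nbytes)` call performs (`read_at` is a hypothetical stand-in, not the library helper):

#include <fstream>
#include <string>
#include <vector>

// read nbytes starting at offset from a binary file, as read_buffer is
// now asked to do for externally stored ONNX tensors
std::vector<char> read_at(const std::string& file, std::size_t offset, std::size_t nbytes)
{
    std::ifstream is(file, std::ios::binary);
    is.seekg(static_cast<std::streamoff>(offset));
    std::vector<char> buf(nbytes);
    is.read(buf.data(), static_cast<std::streamsize>(nbytes));
    return buf;
}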
@@ -47,52 +47,42 @@ struct parse_pooling : op_parser<parse_pooling>
                 {"GlobalLpPool", "lpnorm"}};
     }

-    instruction_ref parse(const op_desc& opd,
-                          const onnx_parser& /*parser*/,
-                          onnx_parser::node_info info,
-                          std::vector<instruction_ref> args) const
-    {
-        const std::unordered_map<std::string, op::pooling_mode> mode_map = {
-            {"max", op::pooling_mode::max},
-            {"average", op::pooling_mode::average},
-            {"lpnorm", op::pooling_mode::lpnorm}};
-        std::string mode = opd.op_name;
-        if(not contains(mode_map, mode))
-        {
-            MIGRAPHX_THROW("onnx pooling mode must be [\"max\", \"average\", \"lpnorm\"]");
-        }
-        operation op = make_op("pooling", {{"mode", mode_map.at(mode)}});
-        value values = op.to_value();
-        auto l0      = args[0];
-        auto in_lens = l0->get_shape().lens();
-        assert(in_lens.size() > 2);
-        auto kdims = in_lens.size() - 2;
-
-        if(starts_with(opd.onnx_name, "Global"))
-        {
-            values["lengths"] = std::vector<size_t>(in_lens.begin() + 2, in_lens.end());
-        }
+    value handle_values(const op_desc& opd,
+                        onnx_parser::node_info info,
+                        const shape& in_shape,
+                        value values) const
+    {
+        auto kdims = in_shape.ndim() - 2;
+
+        if(starts_with(opd.onnx_name, "Global"))
+        {
+            // if spatial dimensions are dynamic, use the dyn_global flag
+            if(in_shape.dynamic() and std::any_of(in_shape.dyn_dims().cbegin() + 2,
+                                                  in_shape.dyn_dims().cend(),
+                                                  [](auto dd) { return not dd.is_fixed(); }))
+            {
+                values["dyn_global"] = true;
+                values["lengths"]    = std::vector<size_t>();
+            }
+            else
+            {
+                // works with static and fixed dynamic shape
+                auto m_lens       = in_shape.max_lens();
+                values["lengths"] = std::vector<size_t>(m_lens.begin() + 2, m_lens.end());
+            }
+        }

+        // does not support ceil_mode
         if(contains(info.attributes, "ceil_mode"))
         {
             values["ceil_mode"] = static_cast<bool>(info.attributes.at("ceil_mode").i());
         }

-        // count include padding; if count_include_pad is 1, we always use
-        // explicit pad
-        int count_include_pad = 0;
-        if(contains(info.attributes, "count_include_pad"))
-        {
-            count_include_pad = info.attributes.at("count_include_pad").i();
-        }
-
         if(contains(info.attributes, "strides"))
         {
             values["stride"].clear();
             copy(info.attributes["strides"].ints(), std::back_inserter(values["stride"]));
             check_attr_sizes(kdims, values["stride"].size(), "PARSE_POOLING: inconsistent strides");
         }
         if(contains(info.attributes, "kernel_shape"))
         {
             values["lengths"].clear();
@@ -110,6 +100,46 @@ struct parse_pooling : op_parser<parse_pooling>
         // ensure pads available only when auto_pad is "NOT_SET"
         check_padding_mode(info, "POOLING");

+        return values;
+    }
+
+    instruction_ref parse(const op_desc& opd,
+                          const onnx_parser& /*parser*/,
+                          onnx_parser::node_info info,
+                          std::vector<instruction_ref> args) const
+    {
+        std::string mode = opd.op_name;
+        const std::unordered_map<std::string, op::pooling_mode> mode_map = {
+            {"max", op::pooling_mode::max},
+            {"average", op::pooling_mode::average},
+            {"lpnorm", op::pooling_mode::lpnorm}};
+        if(not contains(mode_map, mode))
+        {
+            MIGRAPHX_THROW(
+                "PARSE_POOLING: onnx pooling mode must be [\"max\", \"average\", \"lpnorm\"]");
+        }
+        operation op = make_op("pooling", {{"mode", mode_map.at(mode)}});
+        value values = op.to_value();
+        auto l0       = args[0];
+        auto in_shape = l0->get_shape();
+        assert(in_shape.ndim() > 2);
+        auto kdims = in_shape.ndim() - 2;
+
+        values = handle_values(opd, info, in_shape, values);
+
+        // count include padding; if count_include_pad is 1, we always use
+        // explicit pad
+        int count_include_pad = 0;
+        if(contains(info.attributes, "count_include_pad"))
+        {
+            if(in_shape.dynamic())
+            {
+                MIGRAPHX_THROW("PARSE_POOLING: count_include_pad attribute is not supported for "
+                               "dynamic input shape");
+            }
+            count_include_pad = info.attributes.at("count_include_pad").i();
+        }
+
         std::vector<int64_t> paddings;
         float pad_val = ((mode == "max") ? std::numeric_limits<float>::lowest() : 0.0f);
@@ -123,14 +153,22 @@ struct parse_pooling : op_parser<parse_pooling>
         if(contains(info.attributes, "auto_pad"))
         {
-            values["padding"].clear();
-            // returned paddings could be empty; set to 0 for no padding
-            cal_auto_padding_size(info,
-                                  values,
-                                  values["lengths"].to_vector<std::size_t>(),
-                                  {1, 1},
-                                  in_lens,
-                                  paddings);
+            if(in_shape.dynamic())
+            {
+                MIGRAPHX_THROW(
+                    "PARSE_POOLING: Auto padding pooling with dynamic input shape not supported");
+            }
+            else
+            {
+                values["padding"].clear();
+                // returned paddings could be empty; set to 0 for no padding
+                cal_auto_padding_size(info,
+                                      values,
+                                      values["lengths"].to_vector<std::size_t>(),
+                                      {1, 1},
+                                      in_shape.lens(),
+                                      paddings);
+            }
         }

         if(paddings.size() != 2 * kdims)
@@ -150,6 +188,7 @@ struct parse_pooling : op_parser<parse_pooling>
             values["stride"].resize(kdims);
             std::fill_n(values["stride"].begin(), kdims, 1);
         }
+
         // used to calculate the supposed output shape
         std::vector<int64_t> orig_padding = paddings;
@@ -159,6 +198,11 @@ struct parse_pooling : op_parser<parse_pooling>
         if(not slice_start.empty())
         {
+            if(in_shape.dynamic())
+            {
+                MIGRAPHX_THROW(
+                    "PARSE_POOLING: asymmetric padding not supported for dynamic input shape");
+            }
             // calculate expected output shape
             orig_padding.insert(orig_padding.begin() + kdims, 2, 0);
             orig_padding.insert(orig_padding.begin(), 2, 0);
......
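A sketch of the `dyn_global` decision in `handle_values` above: a global pooling input whose spatial dimensions are ranged rather than fixed takes the flag path. The `{min, max, opt}` aggregate layout of `dynamic_dimension` is again an assumption:

#include <algorithm>
#include <iostream>
#include <vector>
#include <migraphx/shape.hpp>

int main()
{
    using migraphx::shape;
    // NCHW with N in {1..4}, C fixed at 3, H and W ranging over 16..32
    std::vector<shape::dynamic_dimension> dds{{1, 4, 0}, {3, 3, 0}, {16, 32, 0}, {16, 32, 0}};
    shape s{shape::float_type, dds};
    bool dyn_global = std::any_of(s.dyn_dims().cbegin() + 2,
                                  s.dyn_dims().cend(),
                                  [](const auto& dd) { return not dd.is_fixed(); });
    std::cout << std::boolalpha << dyn_global << "\n"; // true
}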
@@ -68,8 +68,7 @@ instruction_ref parse_reduce_oper(const std::string& op_name,
         }
         else
         {
-            std::size_t n_dim = args.front()->get_shape().lens().size();
-            axes.resize(n_dim);
+            axes.resize(args.front()->get_shape().ndim());
             std::iota(axes.begin(), axes.end(), 0);
         }
     }
......
@@ -47,7 +47,7 @@ struct parse_transpose : op_parser<parse_transpose>
     }

     // if perm is empty, use the default value
-    auto n_dim = args.front()->get_shape().lens().size();
+    auto n_dim = args.front()->get_shape().ndim();
     if(perm.empty())
     {
         perm.resize(n_dim);
......
@@ -854,6 +854,25 @@ void program::print_graph(std::ostream& os, bool brief) const
     mm->print_graph(os, brief);
 }

+void program::print_py(std::ostream& os) const
+{
+    auto vec_modules = this->get_modules();
+    std::unordered_map<instruction_ref, std::string> names;
+    os << "p = migraphx.program()\n";
+    for(auto& mod : vec_modules)
+    {
+        std::string var_name = "m" + mod->name();
+        os << var_name << " = ";
+        if(mod->name() == "main")
+            os << "p.get_main_module()";
+        else
+            os << "p.create_module(\"" << mod->name() << "\");";
+        os << std::endl;
+        names = mod->print_py(os, var_name, names);
+        os << std::endl;
+    }
+}
+
 void program::print_cpp(std::ostream& os) const
 {
     auto vec_modules = this->get_modules();
......
@@ -92,7 +92,7 @@ void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const
     // process sequence length
     instruction_ref seq_lens = m.end();
-    if((args.size() >= 5) && args[4]->name() != "undefined")
+    if((args.size() >= 5) and not args[4]->is_undefined())
     {
         seq_lens = args[4];
     }
@@ -117,7 +117,7 @@ void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const
     // process bias
     instruction_ref bias_forward = m.end();
     instruction_ref bias_reverse = m.end();
-    if(args.size() >= 4 && args[3]->name() != "undefined")
+    if(args.size() >= 4 and not args[3]->is_undefined())
     {
         bias_forward = m.insert_instruction(
             ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]);
@@ -129,7 +129,7 @@ void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const
     // or the 5th one (if the sequence len argument is ignored)
     instruction_ref ih_forward{};
     instruction_ref ih_reverse{};
-    if(args.size() == 6 && args[5]->name() != "undefined")
+    if(args.size() == 6 and not args[5]->is_undefined())
     {
         ih_forward = m.insert_instruction(
             ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]);
@@ -195,14 +195,14 @@ void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const
     // process bias and initial hidden state
     instruction_ref bias = m.end();
-    if(args.size() >= 4 && args[3]->name() != "undefined")
+    if(args.size() >= 4 and not args[3]->is_undefined())
     {
         bias = args[3];
     }

     // process initial hidden state
     instruction_ref ih;
-    if(args.size() == 6 && args[5]->name() != "undefined")
+    if(args.size() == 6 and not args[5]->is_undefined())
     {
         ih = args[5];
     }
@@ -398,7 +398,7 @@ void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const
     // process sequence length
     instruction_ref seq_lens = m.end();
-    if((args.size() >= 5) && args[4]->name() != "undefined")
+    if((args.size() >= 5) and not args[4]->is_undefined())
     {
         seq_lens = args[4];
     }
@@ -423,7 +423,7 @@ void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const
     // bias
     instruction_ref bias_forward = m.end();
     instruction_ref bias_reverse = m.end();
-    if(args.size() >= 4 && args[3]->name() != "undefined")
+    if(args.size() >= 4 and not args[3]->is_undefined())
     {
         bias_forward = m.insert_instruction(
             ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]);
@@ -434,7 +434,7 @@ void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const
     // initial hidden state
     instruction_ref ih_forward{};
     instruction_ref ih_reverse{};
-    if(args.size() == 6 && args[5]->name() != "undefined")
+    if(args.size() == 6 and not args[5]->is_undefined())
     {
         ih_forward = m.insert_instruction(
             ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]);
@@ -501,14 +501,14 @@ void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const
     // bias
     instruction_ref bias = m.end();
-    if(args.size() >= 4 && args[3]->name() != "undefined")
+    if(args.size() >= 4 and not args[3]->is_undefined())
     {
         bias = args[3];
     }

     // initial hidden state
     instruction_ref ih{};
-    if(args.size() == 6 && args[5]->name() != "undefined")
+    if(args.size() == 6 and not args[5]->is_undefined())
     {
         ih = args[5];
     }
@@ -784,7 +784,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
     // process sequence length
     instruction_ref seq_lens = m.end();
-    if((args.size() >= 5) && args[4]->name() != "undefined")
+    if((args.size() >= 5) and not args[4]->is_undefined())
     {
         seq_lens = args[4];
     }
@@ -813,7 +813,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
     // process bias
     instruction_ref bias_forward = m.end();
     instruction_ref bias_reverse = m.end();
-    if(args.size() >= 4 && args[3]->name() != "undefined")
+    if(args.size() >= 4 and not args[3]->is_undefined())
     {
         bias_forward = m.insert_instruction(
             ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]);
@@ -824,7 +824,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
     // process initial hidden state, it is the 6th argument
     instruction_ref ih_forward{};
     instruction_ref ih_reverse{};
-    if(args.size() >= 6 && args[5]->name() != "undefined")
+    if(args.size() >= 6 and not args[5]->is_undefined())
     {
         ih_forward = m.insert_instruction(
             ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]);
@@ -840,7 +840,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
     // process initial cell value
     instruction_ref ic_forward{};
     instruction_ref ic_reverse{};
-    if(args.size() >= 7 && args[6]->name() != "undefined")
+    if(args.size() >= 7 and not args[6]->is_undefined())
     {
         ic_forward = m.insert_instruction(
             ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[6]);
@@ -856,7 +856,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
     // process weight of the peephole
     instruction_ref pph_forward = m.end();
     instruction_ref pph_reverse = m.end();
-    if(args.size() == 8 && args[7]->name() != "undefined")
+    if(args.size() == 8 and not args[7]->is_undefined())
     {
         pph_forward = m.insert_instruction(
             ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[7]);
@@ -940,14 +940,14 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
     // bias
     instruction_ref bias = m.end();
-    if(args.size() >= 4 && args[3]->name() != "undefined")
+    if(args.size() >= 4 and not args[3]->is_undefined())
     {
         bias = args[3];
     }

     // initial hidden state
     instruction_ref ih{};
-    if(args.size() >= 6 && args[5]->name() != "undefined")
+    if(args.size() >= 6 and not args[5]->is_undefined())
     {
         ih = args[5];
     }
@@ -958,7 +958,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
     // initial cell value
     instruction_ref ic{};
-    if(args.size() >= 7 && args[6]->name() != "undefined")
+    if(args.size() >= 7 and not args[6]->is_undefined())
     {
         ic = args[6];
     }
@@ -969,7 +969,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
     // process weight of the peephole
     instruction_ref pph = m.end();
-    if(args.size() == 8 && args[7]->name() != "undefined")
+    if(args.size() == 8 and not args[7]->is_undefined())
     {
         pph = args[7];
     }
......
@@ -521,6 +521,14 @@ std::ostream& operator<<(std::ostream& os, const shape::dynamic_dimension& x)
     return os;
 }

+bool operator==(const shape::dynamic_dimension& x, const std::size_t& y)
+{
+    return x.min == y and x.max == y;
+}
+bool operator==(const std::size_t& x, const shape::dynamic_dimension& y) { return y == x; }
+bool operator!=(const shape::dynamic_dimension& x, const std::size_t& y) { return not(x == y); }
+bool operator!=(const std::size_t& x, const shape::dynamic_dimension& y) { return not(x == y); }
+
 bool operator==(const shape& x, const shape& y)
 {
     if(x.dynamic() and y.dynamic())
......
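The new mixed comparisons make fixed-dimension checks read naturally; a minimal sketch (same `{min, max, opt}` aggregate assumption as above):

#include <iostream>
#include <migraphx/shape.hpp>

int main()
{
    using migraphx::shape;
    shape::dynamic_dimension fixed{1, 1, 0};
    shape::dynamic_dimension ranged{1, 4, 0};
    // equal to a std::size_t only when min == max == that value
    std::cout << std::boolalpha << (fixed == std::size_t{1}) << "\n";  // true
    std::cout << std::boolalpha << (ranged != std::size_t{1}) << "\n"; // true
}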
@@ -233,11 +233,14 @@ get_target_property(MIOPEN_LOCATION MIOpen LOCATION)
 check_library_exists(MIOpen "miopenHiddenSetConvolutionFindMode" "${MIOPEN_LOCATION}" HAS_FIND_MODE_API)
 check_library_exists(MIOpen "miopenFindSolutions" "${MIOPEN_LOCATION}" HAS_FIND_2_API)

-if(HAS_FIND_2_API)
+# TODO: Set default to HAS_FIND_2_API
+set(MIGRAPHX_USE_FIND_2_API OFF CACHE BOOL "")
+
+if(MIGRAPHX_USE_FIND_2_API)
     target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_HAS_FIND_2_API)
     message(STATUS "MIGraphx is using Find-2.0 API of MIOpen")
 else()
-    message(STATUS "MIOpen does not have Find-2.0 API")
+    message(STATUS "MIGraphx is using legacy Find API in MIOpen")
 endif()

 if(HAS_FIND_MODE_API)
......
@@ -185,7 +185,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
     options.push_back("-fno-gpu-rdc");
     options.push_back(" -O" + string_value_of(MIGRAPHX_GPU_OPTIMIZE{}, "3"));
     options.push_back("-Wno-cuda-compat");
-    options.push_back("--cuda-gpu-arch=" + arch);
+    options.push_back("--offload-arch=" + arch);
     prog.compile(options);
     return {prog.get_code_obj()};
 }
@@ -237,7 +237,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
     }
     else if(is_hip_clang_compiler())
     {
-        params += " --cuda-gpu-arch=" + arch;
+        params += " --offload-arch=" + arch;
         params += " --cuda-device-only";
         params += " -O" + string_value_of(MIGRAPHX_GPU_OPTIMIZE{}, "3") + " ";
     }
......
@@ -105,7 +105,7 @@ struct hip_copy_to_gpu
     std::string name() const { return "hip::copy_to_gpu"; }
     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(1, 2);
+        check_shapes{inputs, *this}.has(1, 2).same_type();
         return inputs.at(0);
     }
     argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
@@ -131,7 +131,7 @@ struct hip_copy_from_gpu
     std::string name() const { return "hip::copy_from_gpu"; }
     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(1, 2);
+        check_shapes{inputs, *this}.has(1, 2).same_type();
         return inputs.at(0);
     }
     argument
@@ -159,7 +159,7 @@ struct hip_copy
     std::string name() const { return "hip::copy"; }
     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(2);
+        check_shapes{inputs, *this}.has(2).same_type();
         return inputs.at(1);
     }
     argument compute(context& ctx, const shape&, std::vector<argument> args) const
......