Commit 3405e5bc authored by Paul

Improve eliminate_contiguous to remove all args together

parent 2863aa3e
@@ -12,6 +12,8 @@ void auto_contiguous::apply(module& p) const
 {
     for(auto ins : iterator_for(p))
     {
+        if(ins->name() == "layout")
+            continue;
         shape s = ins->get_shape();
         if(not s.standard() and s.elements() != 0)
         {
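Why the guard: auto_contiguous wraps any non-standard result shape in a contiguous copy, and a layout instruction deliberately produces a non-standard, permuted shape, so skipping it presumably keeps this pass from undoing the NHWC conversion this commit enables further down. A minimal sketch of the skip pattern, using toy stand-ins (toy_instruction and toy_auto_contiguous are illustrative, not the MIGraphX API):

#include <iostream>
#include <string>
#include <vector>

// Toy stand-ins for illustration; not the MIGraphX types.
struct toy_instruction
{
    std::string name;
    bool standard_shape;
};

// Sketch of the guard: skip "layout" ops so their deliberately
// non-standard output never gets wrapped in a contiguous copy.
void toy_auto_contiguous(std::vector<toy_instruction>& prog)
{
    for(auto& ins : prog)
    {
        if(ins.name == "layout")
            continue; // leave the layout op's strided result alone
        if(not ins.standard_shape)
            std::cout << "insert contiguous after " << ins.name << "\n";
    }
}

int main()
{
    std::vector<toy_instruction> prog = {
        {"layout", false}, {"transpose", false}, {"add", true}};
    toy_auto_contiguous(prog); // prints only: insert contiguous after transpose
}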
@@ -73,29 +73,56 @@ void eliminate_contiguous::apply(module& p) const
         if(ins->name() == "@return")
             continue;
 
+        if(std::none_of(ins->inputs().begin(), ins->inputs().end(), [&](auto arg) {
+               return arg->name() == op_name;
+           }))
+            continue;
+
         // Make a copy so we can modify it while we iterate
         auto args = ins->inputs();
-        for(auto arg : ins->inputs())
+        auto new_args = args;
+        std::transform(new_args.begin(), new_args.end(), new_args.begin(), [&](auto arg) {
+            if(arg->name() == op_name)
+                return arg->inputs().front();
+            else
+                return arg;
+        });
+        assert(args.size() == new_args.size());
+        if(try_compute_shape(ins, new_args))
         {
-            if(arg->name() == op_name)
+            for(auto i : range(args.size()))
             {
-                auto new_args = args;
-                auto prev = arg->inputs().front();
-                replace(new_args, arg, prev);
-                if(try_compute_shape(ins, new_args))
-                {
-                    instruction::replace_argument(ins, arg, prev);
-                }
-                else if(prev->can_eval())
+                if(args[i] == new_args[i])
+                    continue;
+                instruction::replace_argument(ins, args[i], new_args[i]);
+            }
+        }
+        else
+        {
+            for(auto arg : ins->inputs())
+            {
+                if(arg->name() == op_name)
                 {
-                    auto c = op::contiguous{};
-                    auto r = c.compute(c.compute_shape({prev->get_shape()}), {prev->eval()});
-                    auto l = p.add_literal(r.get_shape(), r.data());
-                    p.replace_instruction(arg, l);
+                    new_args = args;
+                    auto prev = arg->inputs().front();
+                    replace(new_args, arg, prev);
+                    if(try_compute_shape(ins, new_args))
+                    {
+                        instruction::replace_argument(ins, arg, prev);
+                    }
+                    else if(prev->can_eval())
+                    {
+                        auto c = op::contiguous{};
+                        auto r = c.compute(c.compute_shape({prev->get_shape()}), {prev->eval()});
+                        auto l = p.add_literal(r.get_shape(), r.data());
+                        p.replace_instruction(arg, l);
+                    }
                 }
             }
         }
     }
 }
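The reworked pass now tries to drop every contiguous input at once: std::transform builds a parallel argument list with each matching arg replaced by its producer, and the replacements are committed together only if the instruction's shape still computes; otherwise it falls back to the previous one-argument-at-a-time path. A rough sketch of that try-all-then-fall-back structure, with a stub feasibility check standing in for try_compute_shape (all names here are illustrative):

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

// Stub standing in for try_compute_shape: here "feasible" simply means
// no argument is still named "contiguous". Purely illustrative.
bool feasible(const std::vector<std::string>& args)
{
    return std::none_of(
        args.begin(), args.end(), [](const std::string& a) { return a == "contiguous"; });
}

int main()
{
    std::vector<std::string> args = {"contiguous", "add", "contiguous"};

    // Build the fully replaced argument list in one pass.
    auto new_args = args;
    std::transform(new_args.begin(), new_args.end(), new_args.begin(), [](const std::string& a) {
        return a == "contiguous" ? std::string("producer") : a;
    });
    assert(args.size() == new_args.size());

    if(feasible(new_args))
    {
        // Commit every changed argument together.
        for(std::size_t i = 0; i < args.size(); i++)
            if(args[i] != new_args[i])
                std::cout << "replace arg " << i << ": " << args[i] << " -> " << new_args[i] << "\n";
    }
    else
    {
        // Fall back: retry each contiguous argument individually.
        for(std::size_t i = 0; i < args.size(); i++)
        {
            if(args[i] != "contiguous")
                continue;
            auto trial = args;
            trial[i] = "producer";
            if(feasible(trial))
                std::cout << "replace arg " << i << " individually\n";
        }
    }
}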
......
File mode changed from 100644 to 100755
@@ -26,9 +26,21 @@ struct dnnl_binary : dnnl_op<dnnl_binary, dnnl::binary>
         auto s0 = inputs.at(0);
         auto s1 = inputs.at(1);
         auto r = s0;
-        if(s0 != s1 or !s0.packed())
+        if(s0 == s1 and s0.packed())
         {
-            r = shape{s0.type(), s0.lens()};
+            r = s0;
+        }
+        else if(s0.packed() != s1.packed())
+        {
+            r = s0.packed() ? s0 : s1;
+        }
+        else if(s0.broadcasted() != s1.broadcasted())
+        {
+            r = s0.broadcasted() ? s1.with_lens(s0.lens()) : s0.with_lens(s0.lens());
+        }
+        else
+        {
+            r = {s0.type(), s0.lens()};
         }
         // Call to get_primitive to make sure an algo is available
         this->get_primitive(this->to_memory_desc(r, inputs));
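The new compute_shape chooses the binary op's output layout by precedence instead of almost always falling back to a standard shape: identical packed inputs keep their layout, a packed input beats a non-packed one, a non-broadcasted layout beats a broadcasted one (re-expressed over s0.lens() via with_lens), and only then does it default to shape{type, lens}. A rough sketch of that precedence, using toy flags in place of migraphx::shape:

#include <iostream>
#include <string>

// Toy stand-in for migraphx::shape, reduced to the two properties the
// precedence rules below consult.
struct toy_shape
{
    std::string layout;
    bool packed;
    bool broadcasted;
};

// Sketch of the output-layout precedence in dnnl_binary::compute_shape.
std::string pick_output(const toy_shape& s0, const toy_shape& s1)
{
    if(s0.packed and s1.packed and s0.layout == s1.layout)
        return s0.layout; // identical packed inputs: keep their layout
    if(s0.packed != s1.packed)
        return s0.packed ? s0.layout : s1.layout; // prefer the packed input
    if(s0.broadcasted != s1.broadcasted)
        return s0.broadcasted ? s1.layout : s0.layout; // prefer the non-broadcast layout
    return "standard"; // otherwise fall back to a standard shape
}

int main()
{
    std::cout << pick_output({"nhwc", true, false}, {"scalar", false, true}) << "\n"; // nhwc
    std::cout << pick_output({"a", false, true}, {"b", false, false}) << "\n";        // b
    std::cout << pick_output({"a", false, false}, {"b", false, false}) << "\n";       // standard
}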
@@ -17,9 +17,9 @@ struct dnnl_convolution
 {
     std::vector<int> arg_map(int) const { return {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS}; }
-    shape adjust_shape(const shape& x, int i) const
+    shape adjust_shape(const shape& x, int i, const shape& output) const
     {
-        auto s = base_adjust_shape(x);
+        auto s = base_adjust_shape(x, output);
         if(i == 1 and op.group > 1)
         {
             // TODO: Add support for transposed weights
@@ -11,9 +11,9 @@ struct dnnl_deconvolution
 {
     std::vector<int> arg_map(int) const { return {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS}; }
-    shape adjust_shape(const shape& x, int i) const
+    shape adjust_shape(const shape& x, int i, const shape& output) const
     {
-        auto s = base_adjust_shape(x);
+        auto s = base_adjust_shape(x, output);
         if(i == 1)
         {
             // The input and output channels are flipped for dnnl
@@ -127,7 +127,7 @@ struct dnnl_op : auto_register_op<Derived>
         std::iota(result.begin(), result.end(), DNNL_ARG_SRC_0);
         return result;
     }
-    shape base_adjust_shape(const shape& s) const
+    shape base_adjust_shape(const shape& s, const shape& output) const
     {
         if(s.broadcasted())
         {
@@ -143,7 +143,8 @@ struct dnnl_op : auto_register_op<Derived>
                 else
                     return len;
             });
-            return shape{s.type(), lens};
+            // Use the permutation of the output
+            return output.with_lens(s.type(), lens);
         }
         return s;
     }
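Threading the output shape into base_adjust_shape lets the broadcast branch rebuild an input with the output's dimension permutation (via with_lens) instead of defaulting to row-major strides, so a broadcast operand feeding an NHWC-laid-out op gets NHWC-consistent strides too. A sketch of the underlying idea, assuming with_lens works by recovering a permutation from strides and reapplying it (toy functions, not the migraphx::shape API):

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <numeric>
#include <vector>

// Recover the axis order of a layout from its strides: the axis with the
// largest stride varies slowest.
std::vector<std::size_t> find_permutation(const std::vector<std::size_t>& strides)
{
    std::vector<std::size_t> perm(strides.size());
    std::iota(perm.begin(), perm.end(), 0);
    std::stable_sort(perm.begin(), perm.end(),
                     [&](std::size_t a, std::size_t b) { return strides[a] > strides[b]; });
    return perm;
}

// Materialize strides for new lens, laying axes out in the given order;
// this mirrors the effect of reusing the output's permutation.
std::vector<std::size_t> strides_for(const std::vector<std::size_t>& lens,
                                     const std::vector<std::size_t>& perm)
{
    std::vector<std::size_t> strides(lens.size());
    std::size_t running = 1;
    for(auto it = perm.rbegin(); it != perm.rend(); ++it)
    {
        strides[*it] = running;
        running *= lens[*it];
    }
    return strides;
}

int main()
{
    // NHWC-style output: N=1, C=8, H=4, W=4 stored with strides {128, 1, 32, 8}.
    auto perm    = find_permutation({128, 1, 32, 8}); // N, H, W, C order
    auto strides = strides_for({1, 8, 4, 4}, perm);   // broadcast operand's lens
    for(auto s : strides)
        std::cout << s << " "; // prints: 128 1 32 8
    std::cout << "\n";
}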
@@ -164,7 +165,7 @@ struct dnnl_op : auto_register_op<Derived>
             i++;
         }
     }
-    shape adjust_shape(const shape& s, int) const { return base_adjust_shape(s); }
+    shape adjust_shape(const shape& s, int, const shape& output) const { return base_adjust_shape(s, output); }
     std::vector<int> create_arg_map(std::size_t input_size) const
     {
         const auto& self = static_cast<const Derived&>(*this);
@@ -183,12 +184,12 @@ struct dnnl_op : auto_register_op<Derived>
     {
         const auto& self = static_cast<const Derived&>(*this);
         std::unordered_map<int, dnnl::memory::desc> result;
-        result[DNNL_ARG_DST] = to_dnnl_memory_desc(self.adjust_shape(output_shape, inputs.size()));
+        result[DNNL_ARG_DST] = to_dnnl_memory_desc(self.adjust_shape(output_shape, inputs.size(), output_shape));
         auto m = create_arg_map(inputs.size());
         assert(m.size() >= inputs.size());
         for(int i = 0; i < inputs.size(); i++)
         {
-            result[m[i]] = to_dnnl_memory_desc(self.adjust_shape(inputs[i], i));
+            result[m[i]] = to_dnnl_memory_desc(self.adjust_shape(inputs[i], i, output_shape));
         }
         return result;
     }
@@ -9,7 +9,7 @@ struct dnnl_reorder : dnnl_op<dnnl_reorder, dnnl::reorder>
 {
     std::string name() const { return "dnnl::reorder"; }
-    shape adjust_shape(const shape& x, int) const { return x; }
+    shape adjust_shape(const shape& x, int, const shape&) const { return x; }
     shape compute_shape(const std::vector<shape>& inputs) const
     {
@@ -67,6 +67,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
         simplify_algebra{},
         simplify_reshapes{},
         layout_nhwc{},
+        dead_code_elimination{},
         simplify_reshapes{},
         simplify_algebra{},
         auto_contiguous{},
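The extra dead_code_elimination directly after layout_nhwc presumably clears the instructions that pass orphans, so the following simplify_reshapes/simplify_algebra round and auto_contiguous (with its new layout guard above) run on a clean module.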