Commit 3405e5bc authored by Paul

Improve eliminate_contiguous to remove all args together

parent 2863aa3e
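
In outline, eliminate_contiguous now builds an argument list with every contiguous input replaced by its source and applies it in a single step, falling back to handling arguments one at a time only when the combined replacement does not produce a valid shape. A minimal sketch of that control flow, reusing the helper names visible in the diff below (try_compute_shape, instruction::replace_argument, range); this is an illustration, not the verbatim source:

    // For each instruction `ins`: first try to drop all contiguous inputs at once.
    auto args     = ins->inputs();
    auto new_args = args;
    std::transform(new_args.begin(), new_args.end(), new_args.begin(), [&](auto arg) {
        // Replace a contiguous input by the instruction it wraps.
        return arg->name() == op_name ? arg->inputs().front() : arg;
    });
    if(try_compute_shape(ins, new_args))
    {
        // All replacements are shape-compatible: commit them together.
        for(auto i : range(args.size()))
            if(args[i] != new_args[i])
                instruction::replace_argument(ins, args[i], new_args[i]);
    }
    else
    {
        // Otherwise fall back to removing (or constant-folding) one contiguous input at a time.
    }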
......@@ -12,6 +12,8 @@ void auto_contiguous::apply(module& p) const
 {
     for(auto ins : iterator_for(p))
     {
+        if(ins->name() == "layout")
+            continue;
         shape s = ins->get_shape();
         if(not s.standard() and s.elements() != 0)
         {
......
......@@ -73,29 +73,56 @@ void eliminate_contiguous::apply(module& p) const
         if(ins->name() == "@return")
             continue;
+        if(std::none_of(ins->inputs().begin(), ins->inputs().end(), [&](auto arg) {
+               return arg->name() == op_name;
+           }))
+            continue;
         // Make a copy so we can modify it while we iterate
         auto args = ins->inputs();
-        for(auto arg : ins->inputs())
-        {
-            if(arg->name() == op_name)
-            {
-                auto new_args = args;
-                auto prev = arg->inputs().front();
-                replace(new_args, arg, prev);
-                if(try_compute_shape(ins, new_args))
-                {
-                    instruction::replace_argument(ins, arg, prev);
-                }
-                else if(prev->can_eval())
-                {
-                    auto c = op::contiguous{};
-                    auto r = c.compute(c.compute_shape({prev->get_shape()}), {prev->eval()});
-                    auto l = p.add_literal(r.get_shape(), r.data());
-                    p.replace_instruction(arg, l);
-                }
-            }
-        }
+        auto new_args = args;
+        std::transform(new_args.begin(), new_args.end(), new_args.begin(), [&](auto arg) {
+            if(arg->name() == op_name)
+                return arg->inputs().front();
+            else
+                return arg;
+        });
+        assert(args.size() == new_args.size());
+        if(try_compute_shape(ins, new_args))
+        {
+            for(auto i : range(args.size()))
+            {
+                if(args[i] == new_args[i])
+                    continue;
+                instruction::replace_argument(ins, args[i], new_args[i]);
+            }
+        }
+        else
+        {
+            for(auto arg : ins->inputs())
+            {
+                if(arg->name() == op_name)
+                {
+                    new_args = args;
+                    auto prev = arg->inputs().front();
+                    replace(new_args, arg, prev);
+                    if(try_compute_shape(ins, new_args))
+                    {
+                        instruction::replace_argument(ins, arg, prev);
+                    }
+                    else if(prev->can_eval())
+                    {
+                        auto c = op::contiguous{};
+                        auto r = c.compute(c.compute_shape({prev->get_shape()}), {prev->eval()});
+                        auto l = p.add_literal(r.get_shape(), r.data());
+                        p.replace_instruction(arg, l);
+                    }
+                }
+            }
+        }
......
File mode changed from 100644 to 100755
......@@ -26,9 +26,21 @@ struct dnnl_binary : dnnl_op<dnnl_binary, dnnl::binary>
         auto s0 = inputs.at(0);
         auto s1 = inputs.at(1);
         auto r = s0;
-        if(s0 != s1 or !s0.packed())
+        if(s0 == s1 and s0.packed())
         {
-            r = shape{s0.type(), s0.lens()};
+            r = s0;
+        }
+        else if(s0.packed() != s1.packed())
+        {
+            r = s0.packed() ? s0 : s1;
+        }
+        else if(s0.broadcasted() != s1.broadcasted())
+        {
+            r = s0.broadcasted() ? s1.with_lens(s0.lens()) : s0.with_lens(s0.lens());
+        }
+        else
+        {
+            r = {s0.type(), s0.lens()};
         }
         // Call to get_primitive to make sure an algo is available
         this->get_primitive(this->to_memory_desc(r, inputs));
......
......@@ -17,9 +17,9 @@ struct dnnl_convolution
 {
     std::vector<int> arg_map(int) const { return {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS}; }
-    shape adjust_shape(const shape& x, int i) const
+    shape adjust_shape(const shape& x, int i, const shape& output) const
     {
-        auto s = base_adjust_shape(x);
+        auto s = base_adjust_shape(x, output);
         if(i == 1 and op.group > 1)
         {
             // TODO: Add support for transposed weights
......
......@@ -11,9 +11,9 @@ struct dnnl_deconvolution
 {
     std::vector<int> arg_map(int) const { return {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS}; }
-    shape adjust_shape(const shape& x, int i) const
+    shape adjust_shape(const shape& x, int i, const shape& output) const
     {
-        auto s = base_adjust_shape(x);
+        auto s = base_adjust_shape(x, output);
         if(i == 1)
         {
             // The input and output channels are flipped for dnnl
......
......@@ -127,7 +127,7 @@ struct dnnl_op : auto_register_op<Derived>
         std::iota(result.begin(), result.end(), DNNL_ARG_SRC_0);
         return result;
     }
-    shape base_adjust_shape(const shape& s) const
+    shape base_adjust_shape(const shape& s, const shape& output) const
     {
         if(s.broadcasted())
         {
......@@ -143,7 +143,8 @@ struct dnnl_op : auto_register_op<Derived>
                 else
                     return len;
             });
-            return shape{s.type(), lens};
+            // Use the permutation of the output
+            return output.with_lens(s.type(), lens);
         }
         return s;
     }
......@@ -164,7 +165,7 @@ struct dnnl_op : auto_register_op<Derived>
             i++;
         }
     }
-    shape adjust_shape(const shape& s, int) const { return base_adjust_shape(s); }
+    shape adjust_shape(const shape& s, int, const shape& output) const { return base_adjust_shape(s, output); }
     std::vector<int> create_arg_map(std::size_t input_size) const
     {
         const auto& self = static_cast<const Derived&>(*this);
......@@ -183,12 +184,12 @@ struct dnnl_op : auto_register_op<Derived>
     {
         const auto& self = static_cast<const Derived&>(*this);
         std::unordered_map<int, dnnl::memory::desc> result;
-        result[DNNL_ARG_DST] = to_dnnl_memory_desc(self.adjust_shape(output_shape, inputs.size()));
+        result[DNNL_ARG_DST] = to_dnnl_memory_desc(self.adjust_shape(output_shape, inputs.size(), output_shape));
         auto m = create_arg_map(inputs.size());
         assert(m.size() >= inputs.size());
         for(int i = 0; i < inputs.size(); i++)
         {
-            result[m[i]] = to_dnnl_memory_desc(self.adjust_shape(inputs[i], i));
+            result[m[i]] = to_dnnl_memory_desc(self.adjust_shape(inputs[i], i, output_shape));
         }
         return result;
     }
......
......@@ -9,7 +9,7 @@ struct dnnl_reorder : dnnl_op<dnnl_reorder, dnnl::reorder>
 {
     std::string name() const { return "dnnl::reorder"; }
-    shape adjust_shape(const shape& x, int) const { return x; }
+    shape adjust_shape(const shape& x, int, const shape&) const { return x; }
     shape compute_shape(const std::vector<shape>& inputs) const
     {
......
......@@ -67,6 +67,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
         simplify_algebra{},
         simplify_reshapes{},
         layout_nhwc{},
+        dead_code_elimination{},
         simplify_reshapes{},
         simplify_algebra{},
         auto_contiguous{},
......