Commit 879fa6ed authored by Paul's avatar Paul
Browse files

Add changes

parent d8011adf
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
namespace migraphx { namespace migraphx {
...@@ -58,17 +59,30 @@ void auto_contiguous::apply(module& m) const ...@@ -58,17 +59,30 @@ void auto_contiguous::apply(module& m) const
auto last = std::prev(m.end()); auto last = std::prev(m.end());
for(auto ins : iterator_for(m)) for(auto ins : iterator_for(m))
{ {
if(ins->name() == "layout") if(contains({"layout", "contiguous", "@return", "@param", "@outline"}, ins->name()))
continue; continue;
auto outputs = ins->outputs();
// for last instruction that is NOT a return // for last instruction that is NOT a return
if(ins->outputs().empty() and ins != last) if(outputs.empty() and ins != last)
continue; continue;
if(not outputs.empty())
// if contiguous was already inserted, skip
if(std::all_of(outputs.begin(), outputs.end(), [](auto output) {
return output->name() == "contiguous";
}))
continue;
shape s = ins->get_shape(); shape s = ins->get_shape();
if(not s.dynamic() and not s.standard() and s.elements() != 0) if(s.dynamic())
{ continue;
auto c = m.insert_instruction(std::next(ins), make_op("contiguous"), ins); if(s.type() == shape::tuple_type)
m.replace_instruction(ins, c); continue;
} if(s.standard() and ins->name() == "@literal")
continue;
if(s.scalar() and not contains(ins->name(), "broadcast"))
continue;
auto c = m.insert_instruction(std::next(ins), make_op("contiguous"), ins);
m.replace_instruction(ins, c);
} }
} }
......
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include <migraphx/op/contiguous.hpp> #include <migraphx/op/contiguous.hpp>
#include <migraphx/op/identity.hpp> #include <migraphx/op/identity.hpp>
#include <migraphx/par_for.hpp> #include <migraphx/par_for.hpp>
#include <type_traits>
#include <utility> #include <utility>
namespace migraphx { namespace migraphx {
...@@ -180,6 +181,18 @@ static void remove_contiguous(const std::string& op_name, module& m, F f) ...@@ -180,6 +181,18 @@ static void remove_contiguous(const std::string& op_name, module& m, F f)
} }
} }
static void remove_contiguous_nops(const std::string& op_name, module& m)
{
for(auto ins : iterator_for(m))
{
if(ins->name() != op_name)
continue;
if(ins->inputs().front()->get_shape() != ins->get_shape())
continue;
m.replace_instruction(ins, ins->inputs().front());
}
}
void eliminate_contiguous::apply(module& m) const void eliminate_contiguous::apply(module& m) const
{ {
// Skip contiguous from splits first // Skip contiguous from splits first
...@@ -189,6 +202,7 @@ void eliminate_contiguous::apply(module& m) const ...@@ -189,6 +202,7 @@ void eliminate_contiguous::apply(module& m) const
return (ins->inputs().front()->outputs().size() == 1); return (ins->inputs().front()->outputs().size() == 1);
}); });
remove_contiguous(op_name, m, [](auto) { return true; }); remove_contiguous(op_name, m, [](auto) { return true; });
remove_contiguous_nops(op_name, m);
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -134,7 +134,6 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -134,7 +134,6 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{}, dead_code_elimination{},
prefuse_ops{}, prefuse_ops{},
dead_code_elimination{}, dead_code_elimination{},
auto_contiguous{},
optimize_module{}, optimize_module{},
fuse_pointwise{}, fuse_pointwise{},
dead_code_elimination{}, dead_code_elimination{},
...@@ -146,6 +145,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -146,6 +145,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{}, dead_code_elimination{},
enable_pass(mlir_enabled(), fuse_mlir{&ctx}), enable_pass(mlir_enabled(), fuse_mlir{&ctx}),
dead_code_elimination{}, dead_code_elimination{},
auto_contiguous{},
lowering{&ctx, options.offload_copy}, lowering{&ctx, options.offload_copy},
eliminate_contiguous{"gpu::contiguous"}, eliminate_contiguous{"gpu::contiguous"},
dead_code_elimination{}, dead_code_elimination{},
......
...@@ -148,11 +148,13 @@ TEST_CASE(two_transpose_gather) ...@@ -148,11 +148,13 @@ TEST_CASE(two_transpose_gather)
migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), data); migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), data);
auto ctd = m2.add_instruction(migraphx::make_op("contiguous"), td); auto ctd = m2.add_instruction(migraphx::make_op("contiguous"), td);
auto sd = m2.add_instruction(migraphx::make_op("softmax", {{"axis", 2}}), ctd); auto sd = m2.add_instruction(migraphx::make_op("softmax", {{"axis", 2}}), ctd);
auto bd = auto csd = m2.add_instruction(migraphx::make_op("contiguous"), sd);
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), sd); auto bd = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), csd);
auto cbd = m2.add_instruction(migraphx::make_op("contiguous"), bd); auto cbd = m2.add_instruction(migraphx::make_op("contiguous"), bd);
auto r = m2.add_instruction(migraphx::make_op("gather", {{"axis", 2}}), cbd, ind); auto r = m2.add_instruction(migraphx::make_op("gather", {{"axis", 2}}), cbd, ind);
m2.add_return({r}); auto cr = m2.add_instruction(migraphx::make_op("contiguous"), r);
m2.add_return({cr});
} }
EXPECT(m1 == m2); EXPECT(m1 == m2);
...@@ -177,7 +179,8 @@ TEST_CASE(standard_reshape_lazy) ...@@ -177,7 +179,8 @@ TEST_CASE(standard_reshape_lazy)
auto ca = m2.add_instruction(migraphx::make_op("contiguous"), add); auto ca = m2.add_instruction(migraphx::make_op("contiguous"), add);
auto r = auto r =
m2.add_instruction(migraphx::make_op("reshape_lazy", {{"dims", {2, 1, 12, 5}}}), ca); m2.add_instruction(migraphx::make_op("reshape_lazy", {{"dims", {2, 1, 12, 5}}}), ca);
m2.add_return({r}); auto cr = m2.add_instruction(migraphx::make_op("contiguous"), r);
m2.add_return({cr});
} }
EXPECT(m1 == m2); EXPECT(m1 == m2);
...@@ -198,8 +201,10 @@ TEST_CASE(standard_reshape) ...@@ -198,8 +201,10 @@ TEST_CASE(standard_reshape)
{ {
auto data = m2.add_parameter("2x2", {migraphx::shape::float_type, {2, 3, 4, 5}}); auto data = m2.add_parameter("2x2", {migraphx::shape::float_type, {2, 3, 4, 5}});
auto add = m2.add_instruction(migraphx::make_op("add"), data, data); auto add = m2.add_instruction(migraphx::make_op("add"), data, data);
auto r = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 1, 12, 5}}}), add); auto ca = m2.add_instruction(migraphx::make_op("contiguous"), add);
m2.add_return({r}); auto r = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 1, 12, 5}}}), ca);
auto cr = m2.add_instruction(migraphx::make_op("contiguous"), r);
m2.add_return({cr});
} }
EXPECT(m1 == m2); EXPECT(m1 == m2);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment