"tasks/vscode:/vscode.git/clone" did not exist on "97913a1fb90613fbb623ae902a44c52f19542df7"
Commit 81b0ff5d authored by Paul Fultz II's avatar Paul Fultz II Committed by mvermeulen
Browse files

Add option to do offload copying automatically (#403)

* Add compiler options

* Add copy operators

* Formatting

* Use run_passes in tests

* Formatting

* Use run_pass in schedule test

* Formatting

* Add compile_options to get_passes in target

* Formatting

* Offload copy option

* Formatting

* Copy using pinned memory

* Formatting

* Improve performance of gpu copying

* Formatting

* Dont copy

* Formatting

* Always make an extra copy

* Formatting

* Remove unused write op

* Add missing include

* Remove copy_to_gpu function in python api

* Make offload copy disabled by default on C++

* Formatting

* Fix tidy issues

* Formatting

* Fix namespace

* Fix python tests

* Turn clang format off since its broken

* Fix compile error on gcc 5

* Remove commented code
parent e814cffb
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/compile_options.hpp>
#include <sstream> #include <sstream>
#include "test.hpp" #include "test.hpp"
#include <basic_ops.hpp> #include <basic_ops.hpp>
...@@ -15,7 +16,11 @@ struct id_target ...@@ -15,7 +16,11 @@ struct id_target
}; };
migraphx::context ctx = context{}; migraphx::context ctx = context{};
std::string name() const { return "id"; } std::string name() const { return "id"; }
std::vector<migraphx::pass> get_passes(migraphx::context&) const { return {}; } std::vector<migraphx::pass> get_passes(migraphx::context&,
const migraphx::compile_options&) const
{
return {};
}
migraphx::context get_context() const { return ctx; } migraphx::context get_context() const { return ctx; }
}; };
...@@ -72,7 +77,11 @@ struct reverse_pass ...@@ -72,7 +77,11 @@ struct reverse_pass
struct reverse_target struct reverse_target
{ {
std::string name() const { return "reverse"; } std::string name() const { return "reverse"; }
std::vector<migraphx::pass> get_passes(migraphx::context&) const { return {reverse_pass{}}; } std::vector<migraphx::pass> get_passes(migraphx::context&,
const migraphx::compile_options&) const
{
return {reverse_pass{}};
}
migraphx::context get_context() const { return {}; } migraphx::context get_context() const { return {}; }
}; };
...@@ -99,14 +108,19 @@ struct invert_pass ...@@ -99,14 +108,19 @@ struct invert_pass
struct invert_target struct invert_target
{ {
std::string name() const { return "invert"; } std::string name() const { return "invert"; }
std::vector<migraphx::pass> get_passes(migraphx::context&) const { return {invert_pass{}}; } std::vector<migraphx::pass> get_passes(migraphx::context&,
const migraphx::compile_options&) const
{
return {invert_pass{}};
}
migraphx::context get_context() const { return {}; } migraphx::context get_context() const { return {}; }
}; };
struct double_invert_target struct double_invert_target
{ {
std::string name() const { return "double_invert"; } std::string name() const { return "double_invert"; }
std::vector<migraphx::pass> get_passes(migraphx::context&) const std::vector<migraphx::pass> get_passes(migraphx::context&,
const migraphx::compile_options&) const
{ {
return {invert_pass{}, invert_pass{}}; return {invert_pass{}, invert_pass{}};
} }
......
...@@ -15,20 +15,16 @@ ...@@ -15,20 +15,16 @@
#include <basic_ops.hpp> #include <basic_ops.hpp>
#include <test.hpp> #include <test.hpp>
struct lowering_target void run_lowering(migraphx::program& p)
{ {
std::string name() const { return "gpu::lowering"; } auto ctx = migraphx::gpu::context{};
std::vector<migraphx::pass> get_passes(migraphx::context& gctx) const migraphx::run_passes(p,
{ {migraphx::auto_contiguous{},
auto& ctx = migraphx::any_cast<migraphx::gpu::context>(gctx); migraphx::gpu::lowering{&ctx, false},
return {migraphx::auto_contiguous{},
migraphx::gpu::lowering{ctx},
migraphx::dead_code_elimination{}, migraphx::dead_code_elimination{},
migraphx::eliminate_contiguous{}, migraphx::eliminate_contiguous{},
migraphx::dead_code_elimination{}}; migraphx::dead_code_elimination{}});
} }
migraphx::gpu::context get_context() const { return migraphx::gpu::context{}; }
};
TEST_CASE(tanh_shape) TEST_CASE(tanh_shape)
{ {
...@@ -48,8 +44,8 @@ TEST_CASE(tanh_shape) ...@@ -48,8 +44,8 @@ TEST_CASE(tanh_shape)
auto p2 = create_program(); auto p2 = create_program();
EXPECT(p1 == p2); EXPECT(p1 == p2);
p1.compile(lowering_target{}); run_lowering(p1);
p2.compile(lowering_target()); run_lowering(p2);
EXPECT(p1 == p2); EXPECT(p1 == p2);
......
...@@ -90,7 +90,9 @@ void compile_check(migraphx::program& p, const migraphx::target& t, bool show_tr ...@@ -90,7 +90,9 @@ void compile_check(migraphx::program& p, const migraphx::target& t, bool show_tr
auto name = t.name(); auto name = t.name();
auto s = p.get_shape(); auto s = p.get_shape();
std::stringstream ss; std::stringstream ss;
p.compile(t, migraphx::tracer{ss}); migraphx::compile_options options;
options.trace = migraphx::tracer{ss};
p.compile(t, options);
if(p.get_shape() != s) if(p.get_shape() != s)
{ {
std::cout << ss.str() << std::endl; std::cout << ss.str() << std::endl;
......
#include <migraphx/memory_coloring.hpp> #include <migraphx/memory_coloring.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
#include <test.hpp> #include <test.hpp>
struct memory_coloring_target void run_pass(migraphx::program& p)
{ {
std::string name() const { return "memory_coloring"; } migraphx::run_passes(p, {migraphx::memory_coloring{"allocate", true}});
std::vector<migraphx::pass> get_passes(migraphx::context&) const }
{
return {migraphx::memory_coloring{"allocate", true}};
}
migraphx::context get_context() const { return {}; }
};
struct allocate struct allocate
{ {
...@@ -56,7 +52,7 @@ TEST_CASE(test1) ...@@ -56,7 +52,7 @@ TEST_CASE(test1)
auto p1 = p.add_instruction(pass_op{}, a1); auto p1 = p.add_instruction(pass_op{}, a1);
auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}}); auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}});
p.add_instruction(pass_op{}, a2, p1); p.add_instruction(pass_op{}, a2, p1);
p.compile(memory_coloring_target{}); run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 192); CHECK(p.get_parameter_shape("scratch").bytes() == 192);
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
...@@ -70,7 +66,7 @@ TEST_CASE(test2) ...@@ -70,7 +66,7 @@ TEST_CASE(test2)
auto p1 = p.add_instruction(pass_op{}, a1, input); auto p1 = p.add_instruction(pass_op{}, a1, input);
auto p2 = add_alloc(p, {migraphx::shape::float_type, {40}}); auto p2 = add_alloc(p, {migraphx::shape::float_type, {40}});
p.add_instruction(pass_op{}, p2, p1); p.add_instruction(pass_op{}, p2, p1);
p.compile(memory_coloring_target{}); run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 672); CHECK(p.get_parameter_shape("scratch").bytes() == 672);
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
...@@ -83,7 +79,7 @@ TEST_CASE(test3) ...@@ -83,7 +79,7 @@ TEST_CASE(test3)
auto p1 = p.add_instruction(pass_op{}, p2, a1); auto p1 = p.add_instruction(pass_op{}, p2, a1);
auto p3 = add_alloc(p, {migraphx::shape::float_type, {40}}); auto p3 = add_alloc(p, {migraphx::shape::float_type, {40}});
p.add_instruction(pass_op{}, p3, p1); p.add_instruction(pass_op{}, p3, p1);
p.compile(memory_coloring_target{}); run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 672); CHECK(p.get_parameter_shape("scratch").bytes() == 672);
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
...@@ -96,7 +92,7 @@ TEST_CASE(test4) ...@@ -96,7 +92,7 @@ TEST_CASE(test4)
auto p1 = p.add_instruction(pass_op{}, p2, a1); auto p1 = p.add_instruction(pass_op{}, p2, a1);
auto p3 = add_alloc(p, {migraphx::shape::float_type, {40}}); auto p3 = add_alloc(p, {migraphx::shape::float_type, {40}});
p.add_instruction(pass_op{}, p3, p1); p.add_instruction(pass_op{}, p3, p1);
p.compile(memory_coloring_target{}); run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 672); CHECK(p.get_parameter_shape("scratch").bytes() == 672);
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
...@@ -108,7 +104,7 @@ TEST_CASE(test5) ...@@ -108,7 +104,7 @@ TEST_CASE(test5)
auto p1 = p.add_instruction(pass_op{}, a1); auto p1 = p.add_instruction(pass_op{}, a1);
auto p2 = add_alloc(p, {migraphx::shape::float_type, {8}}); auto p2 = add_alloc(p, {migraphx::shape::float_type, {8}});
p.add_instruction(pass_op{}, p2, p1); p.add_instruction(pass_op{}, p2, p1);
p.compile(memory_coloring_target{}); run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 192); CHECK(p.get_parameter_shape("scratch").bytes() == 192);
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
...@@ -121,7 +117,7 @@ TEST_CASE(test6) ...@@ -121,7 +117,7 @@ TEST_CASE(test6)
auto p2 = add_alloc(p, {migraphx::shape::float_type, {40}}); auto p2 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto p3 = add_alloc(p, {migraphx::shape::float_type, {40}}); auto p3 = add_alloc(p, {migraphx::shape::float_type, {40}});
p.add_instruction(pass_op{}, p3, p2, p1); p.add_instruction(pass_op{}, p3, p2, p1);
p.compile(memory_coloring_target{}); run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 352); CHECK(p.get_parameter_shape("scratch").bytes() == 352);
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
...@@ -134,7 +130,7 @@ TEST_CASE(test7) ...@@ -134,7 +130,7 @@ TEST_CASE(test7)
auto p2 = add_alloc(p, {migraphx::shape::float_type, {40}}); auto p2 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto p3 = add_alloc(p, {migraphx::shape::float_type, {8}}); auto p3 = add_alloc(p, {migraphx::shape::float_type, {8}});
p.add_instruction(pass_op{}, p3, p2, p1); p.add_instruction(pass_op{}, p3, p2, p1);
p.compile(memory_coloring_target{}); run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 224); CHECK(p.get_parameter_shape("scratch").bytes() == 224);
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
...@@ -147,7 +143,7 @@ TEST_CASE(test8) ...@@ -147,7 +143,7 @@ TEST_CASE(test8)
auto p2 = add_alloc(p, {migraphx::shape::float_type, {40}}); auto p2 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto p3 = add_alloc(p, {migraphx::shape::float_type, {192}}); auto p3 = add_alloc(p, {migraphx::shape::float_type, {192}});
p.add_instruction(pass_op{}, p3, p2, p1); p.add_instruction(pass_op{}, p3, p2, p1);
p.compile(memory_coloring_target{}); run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 960); CHECK(p.get_parameter_shape("scratch").bytes() == 960);
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
...@@ -160,7 +156,7 @@ TEST_CASE(test9) ...@@ -160,7 +156,7 @@ TEST_CASE(test9)
auto p2 = add_alloc(p, {migraphx::shape::float_type, {8}}); auto p2 = add_alloc(p, {migraphx::shape::float_type, {8}});
auto p3 = add_alloc(p, {migraphx::shape::float_type, {8}}); auto p3 = add_alloc(p, {migraphx::shape::float_type, {8}});
p.add_instruction(pass_op{}, p3, p2, p1); p.add_instruction(pass_op{}, p3, p2, p1);
p.compile(memory_coloring_target{}); run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 96); CHECK(p.get_parameter_shape("scratch").bytes() == 96);
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
...@@ -170,7 +166,7 @@ TEST_CASE(test10) ...@@ -170,7 +166,7 @@ TEST_CASE(test10)
migraphx::program p; migraphx::program p;
auto a1 = add_alloc(p, {migraphx::shape::float_type, {8}}); auto a1 = add_alloc(p, {migraphx::shape::float_type, {8}});
p.add_instruction(pass_op{}, a1); p.add_instruction(pass_op{}, a1);
p.compile(memory_coloring_target{}); run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 32); CHECK(p.get_parameter_shape("scratch").bytes() == 32);
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
...@@ -184,7 +180,7 @@ TEST_CASE(test11) ...@@ -184,7 +180,7 @@ TEST_CASE(test11)
auto a3 = add_alloc(p, {migraphx::shape::float_type, {8}}); auto a3 = add_alloc(p, {migraphx::shape::float_type, {8}});
auto p2 = p.add_instruction(pass_op{}, a2, p1); auto p2 = p.add_instruction(pass_op{}, a2, p1);
p.add_instruction(pass_op{}, a3, p2); p.add_instruction(pass_op{}, a3, p2);
p.compile(memory_coloring_target{}); run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 224); CHECK(p.get_parameter_shape("scratch").bytes() == 224);
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
...@@ -198,7 +194,7 @@ TEST_CASE(test12) ...@@ -198,7 +194,7 @@ TEST_CASE(test12)
auto a3 = add_alloc(p, {migraphx::shape::float_type, {40}}); auto a3 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto p2 = p.add_instruction(pass_op{}, a2, p1); auto p2 = p.add_instruction(pass_op{}, a2, p1);
p.add_instruction(pass_op{}, a3, p2); p.add_instruction(pass_op{}, a3, p2);
p.compile(memory_coloring_target{}); run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 352); CHECK(p.get_parameter_shape("scratch").bytes() == 352);
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
...@@ -212,7 +208,7 @@ TEST_CASE(test13) ...@@ -212,7 +208,7 @@ TEST_CASE(test13)
auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}}); auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto p2 = p.add_instruction(pass_op{}, a2, p1); auto p2 = p.add_instruction(pass_op{}, a2, p1);
p.add_instruction(pass_op{}, a3, p2); p.add_instruction(pass_op{}, a3, p2);
p.compile(memory_coloring_target{}); run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 224); CHECK(p.get_parameter_shape("scratch").bytes() == 224);
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
...@@ -226,7 +222,7 @@ TEST_CASE(test14) ...@@ -226,7 +222,7 @@ TEST_CASE(test14)
auto p1 = p.add_instruction(pass_op{}, a1); auto p1 = p.add_instruction(pass_op{}, a1);
auto p2 = p.add_instruction(pass_op{}, a2, p1); auto p2 = p.add_instruction(pass_op{}, a2, p1);
p.add_instruction(pass_op{}, a3, p2); p.add_instruction(pass_op{}, a3, p2);
p.compile(memory_coloring_target{}); run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 224); CHECK(p.get_parameter_shape("scratch").bytes() == 224);
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
...@@ -240,7 +236,7 @@ TEST_CASE(test15) ...@@ -240,7 +236,7 @@ TEST_CASE(test15)
auto p2 = p.add_instruction(pass_op{}, a2); auto p2 = p.add_instruction(pass_op{}, a2);
auto a3 = add_alloc(p, {migraphx::shape::float_type, {40}}); auto a3 = add_alloc(p, {migraphx::shape::float_type, {40}});
p.add_instruction(pass_op{}, a3, p1, p2); p.add_instruction(pass_op{}, a3, p1, p2);
p.compile(memory_coloring_target{}); run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 352); CHECK(p.get_parameter_shape("scratch").bytes() == 352);
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
...@@ -254,7 +250,7 @@ TEST_CASE(test16) ...@@ -254,7 +250,7 @@ TEST_CASE(test16)
auto p2 = p.add_instruction(pass_op{}, a2); auto p2 = p.add_instruction(pass_op{}, a2);
auto a3 = add_alloc(p, {migraphx::shape::float_type, {40}}); auto a3 = add_alloc(p, {migraphx::shape::float_type, {40}});
p.add_instruction(pass_op{}, a3, p1, p2); p.add_instruction(pass_op{}, a3, p1, p2);
p.compile(memory_coloring_target{}); run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 160); CHECK(p.get_parameter_shape("scratch").bytes() == 160);
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
...@@ -268,7 +264,7 @@ TEST_CASE(test17) ...@@ -268,7 +264,7 @@ TEST_CASE(test17)
auto a2 = p.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {40}})); auto a2 = p.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {40}}));
auto p2 = p.add_instruction(pass_op{}, a2); auto p2 = p.add_instruction(pass_op{}, a2);
p.add_instruction(pass_op{}, a3, p1, p2); p.add_instruction(pass_op{}, a3, p1, p2);
p.compile(memory_coloring_target{}); run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 160); CHECK(p.get_parameter_shape("scratch").bytes() == 160);
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
...@@ -282,7 +278,7 @@ TEST_CASE(test18) ...@@ -282,7 +278,7 @@ TEST_CASE(test18)
auto p3 = p.add_instruction(pass_op{}, p2, p1); auto p3 = p.add_instruction(pass_op{}, p2, p1);
auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}}); auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}});
p.add_instruction(pass_op{}, a2, p1, p2, p3); p.add_instruction(pass_op{}, a2, p1, p2, p3);
p.compile(memory_coloring_target{}); run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 192); CHECK(p.get_parameter_shape("scratch").bytes() == 192);
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
...@@ -296,7 +292,7 @@ TEST_CASE(test19) ...@@ -296,7 +292,7 @@ TEST_CASE(test19)
auto p2 = p.add_instruction(pass_op{}, a2, p1); auto p2 = p.add_instruction(pass_op{}, a2, p1);
auto a3 = add_alloc(p, {migraphx::shape::float_type, {40}}); auto a3 = add_alloc(p, {migraphx::shape::float_type, {40}});
p.add_instruction(pass_op{}, a3, p2, p1); p.add_instruction(pass_op{}, a3, p2, p1);
p.compile(memory_coloring_target{}); run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 352); CHECK(p.get_parameter_shape("scratch").bytes() == 352);
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
...@@ -310,7 +306,7 @@ TEST_CASE(test20) ...@@ -310,7 +306,7 @@ TEST_CASE(test20)
auto p1 = p.add_instruction(pass_op{}, a1, a2, a3); auto p1 = p.add_instruction(pass_op{}, a1, a2, a3);
auto a4 = add_alloc(p, {migraphx::shape::float_type, {32}}); auto a4 = add_alloc(p, {migraphx::shape::float_type, {32}});
p.add_instruction(pass_op{}, a4, p1); p.add_instruction(pass_op{}, a4, p1);
p.compile(memory_coloring_target{}); run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 384); CHECK(p.get_parameter_shape("scratch").bytes() == 384);
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
...@@ -324,7 +320,7 @@ TEST_CASE(test21) ...@@ -324,7 +320,7 @@ TEST_CASE(test21)
auto p1 = p.add_instruction(pass_op{}, a1, a2, a3); auto p1 = p.add_instruction(pass_op{}, a1, a2, a3);
auto a4 = add_alloc(p, {migraphx::shape::float_type, {8}}); auto a4 = add_alloc(p, {migraphx::shape::float_type, {8}});
p.add_instruction(pass_op{}, a4, p1); p.add_instruction(pass_op{}, a4, p1);
p.compile(memory_coloring_target{}); run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 288); CHECK(p.get_parameter_shape("scratch").bytes() == 288);
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
...@@ -338,7 +334,7 @@ TEST_CASE(test22) ...@@ -338,7 +334,7 @@ TEST_CASE(test22)
auto p1 = p.add_instruction(pass_op{}, a1, a2, a3); auto p1 = p.add_instruction(pass_op{}, a1, a2, a3);
auto a4 = add_alloc(p, {migraphx::shape::float_type, {8}}); auto a4 = add_alloc(p, {migraphx::shape::float_type, {8}});
p.add_instruction(pass_op{}, a4, p1); p.add_instruction(pass_op{}, a4, p1);
p.compile(memory_coloring_target{}); run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 288); CHECK(p.get_parameter_shape("scratch").bytes() == 288);
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
...@@ -352,7 +348,7 @@ TEST_CASE(test23) ...@@ -352,7 +348,7 @@ TEST_CASE(test23)
auto p1 = p.add_instruction(pass_op{}, a1, a2, a3); auto p1 = p.add_instruction(pass_op{}, a1, a2, a3);
auto a4 = add_alloc(p, {migraphx::shape::float_type, {8}}); auto a4 = add_alloc(p, {migraphx::shape::float_type, {8}});
p.add_instruction(pass_op{}, a4, p1); p.add_instruction(pass_op{}, a4, p1);
p.compile(memory_coloring_target{}); run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 288); CHECK(p.get_parameter_shape("scratch").bytes() == 288);
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
...@@ -366,7 +362,7 @@ TEST_CASE(test24) ...@@ -366,7 +362,7 @@ TEST_CASE(test24)
auto p1 = p.add_instruction(pass_op{}, a1, a2, a3); auto p1 = p.add_instruction(pass_op{}, a1, a2, a3);
auto a4 = add_alloc(p, {migraphx::shape::float_type, {8}}); auto a4 = add_alloc(p, {migraphx::shape::float_type, {8}});
p.add_instruction(pass_op{}, a4, p1); p.add_instruction(pass_op{}, a4, p1);
p.compile(memory_coloring_target{}); run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 384); CHECK(p.get_parameter_shape("scratch").bytes() == 384);
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
...@@ -380,7 +376,7 @@ TEST_CASE(test25) ...@@ -380,7 +376,7 @@ TEST_CASE(test25)
p.add_instruction(nop{}); p.add_instruction(nop{});
auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}}); auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}});
p.add_instruction(pass_op{}, a2, p1); p.add_instruction(pass_op{}, a2, p1);
p.compile(memory_coloring_target{}); run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 192); CHECK(p.get_parameter_shape("scratch").bytes() == 192);
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
...@@ -394,7 +390,7 @@ TEST_CASE(test26) ...@@ -394,7 +390,7 @@ TEST_CASE(test26)
p.add_instruction(nop{}, a1, p1); p.add_instruction(nop{}, a1, p1);
auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}}); auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}});
p.add_instruction(pass_op{}, a2, p1); p.add_instruction(pass_op{}, a2, p1);
p.compile(memory_coloring_target{}); run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 192); CHECK(p.get_parameter_shape("scratch").bytes() == 192);
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
...@@ -406,7 +402,7 @@ TEST_CASE(test27) ...@@ -406,7 +402,7 @@ TEST_CASE(test27)
auto p1 = p.add_instruction(pass_op{}, a1); auto p1 = p.add_instruction(pass_op{}, a1);
auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}}); auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}});
p.add_instruction(nop{}, a2, p1); p.add_instruction(nop{}, a2, p1);
p.compile(memory_coloring_target{}); run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 192); CHECK(p.get_parameter_shape("scratch").bytes() == 192);
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
...@@ -420,7 +416,7 @@ TEST_CASE(test28) ...@@ -420,7 +416,7 @@ TEST_CASE(test28)
auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}}); auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto p2 = p.add_instruction(pass_op{}, a2, p1); auto p2 = p.add_instruction(pass_op{}, a2, p1);
p.add_instruction(pass_op{}, p2, output); p.add_instruction(pass_op{}, p2, output);
p.compile(memory_coloring_target{}); run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 192); CHECK(p.get_parameter_shape("scratch").bytes() == 192);
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
...@@ -435,7 +431,7 @@ TEST_CASE(test29) ...@@ -435,7 +431,7 @@ TEST_CASE(test29)
auto p2 = p.add_instruction(pass_op{}, a2, p1); auto p2 = p.add_instruction(pass_op{}, a2, p1);
p.move_instruction(output, p2); p.move_instruction(output, p2);
p.add_instruction(pass_op{}, p2, output); p.add_instruction(pass_op{}, p2, output);
p.compile(memory_coloring_target{}); run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 192); CHECK(p.get_parameter_shape("scratch").bytes() == 192);
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
...@@ -450,7 +446,7 @@ TEST_CASE(test30) ...@@ -450,7 +446,7 @@ TEST_CASE(test30)
auto p2 = p.add_instruction(pass_op{}, a2, p1); auto p2 = p.add_instruction(pass_op{}, a2, p1);
p.move_instruction(output, p2); p.move_instruction(output, p2);
p.add_instruction(pass_op{}, p2, output); p.add_instruction(pass_op{}, p2, output);
p.compile(memory_coloring_target{}); run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 192); CHECK(p.get_parameter_shape("scratch").bytes() == 192);
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
...@@ -464,7 +460,7 @@ TEST_CASE(test31) ...@@ -464,7 +460,7 @@ TEST_CASE(test31)
auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}}); auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}});
p.move_instruction(output, a2); p.move_instruction(output, a2);
p.add_instruction(pass_op{}, a2, p1); p.add_instruction(pass_op{}, a2, p1);
p.compile(memory_coloring_target{}); run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 192); CHECK(p.get_parameter_shape("scratch").bytes() == 192);
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
...@@ -478,7 +474,7 @@ TEST_CASE(test32) ...@@ -478,7 +474,7 @@ TEST_CASE(test32)
auto p1 = p.add_instruction(pass_op{}, a2, a1, a3); auto p1 = p.add_instruction(pass_op{}, a2, a1, a3);
auto a5 = add_alloc(p, {migraphx::shape::float_type, {40}}); auto a5 = add_alloc(p, {migraphx::shape::float_type, {40}});
p.add_instruction(pass_op{}, a5, p1); p.add_instruction(pass_op{}, a5, p1);
p.compile(memory_coloring_target{}); run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 352); CHECK(p.get_parameter_shape("scratch").bytes() == 352);
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
...@@ -492,7 +488,7 @@ TEST_CASE(test33) ...@@ -492,7 +488,7 @@ TEST_CASE(test33)
auto p1 = p.add_instruction(pass_op{}, a2, a1, a3); auto p1 = p.add_instruction(pass_op{}, a2, a1, a3);
auto a5 = add_alloc(p, {migraphx::shape::float_type, {40}}); auto a5 = add_alloc(p, {migraphx::shape::float_type, {40}});
p.add_instruction(pass_op{}, a5, p1); p.add_instruction(pass_op{}, a5, p1);
p.compile(memory_coloring_target{}); run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 192); CHECK(p.get_parameter_shape("scratch").bytes() == 192);
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
...@@ -506,7 +502,7 @@ TEST_CASE(test34) ...@@ -506,7 +502,7 @@ TEST_CASE(test34)
auto p1 = p.add_instruction(pass_op{}, a2, a1, a3); auto p1 = p.add_instruction(pass_op{}, a2, a1, a3);
auto a5 = add_alloc(p, {migraphx::shape::float_type, {8}}); auto a5 = add_alloc(p, {migraphx::shape::float_type, {8}});
p.add_instruction(pass_op{}, a5, p1); p.add_instruction(pass_op{}, a5, p1);
p.compile(memory_coloring_target{}); run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 480); CHECK(p.get_parameter_shape("scratch").bytes() == 480);
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
...@@ -520,7 +516,7 @@ TEST_CASE(test35) ...@@ -520,7 +516,7 @@ TEST_CASE(test35)
auto p1 = p.add_instruction(pass_op{}, a2, a1, a3); auto p1 = p.add_instruction(pass_op{}, a2, a1, a3);
auto a5 = add_alloc(p, {migraphx::shape::float_type, {8}}); auto a5 = add_alloc(p, {migraphx::shape::float_type, {8}});
p.add_instruction(pass_op{}, a5, p1); p.add_instruction(pass_op{}, a5, p1);
p.compile(memory_coloring_target{}); run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 224); CHECK(p.get_parameter_shape("scratch").bytes() == 224);
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
...@@ -537,7 +533,7 @@ TEST_CASE(test36) ...@@ -537,7 +533,7 @@ TEST_CASE(test36)
auto a4 = add_alloc(p, {migraphx::shape::float_type, {40}}); auto a4 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto p3 = p.add_instruction(pass_op{}, a4, p2); auto p3 = p.add_instruction(pass_op{}, a4, p2);
p.add_instruction(pass_op{}, output, p3); p.add_instruction(pass_op{}, output, p3);
p.compile(memory_coloring_target{}); run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 320); CHECK(p.get_parameter_shape("scratch").bytes() == 320);
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
...@@ -554,7 +550,7 @@ TEST_CASE(test37) ...@@ -554,7 +550,7 @@ TEST_CASE(test37)
auto a4 = add_alloc(p, {migraphx::shape::float_type, {40}}); auto a4 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto p3 = p.add_instruction(pass_op{}, a4, p2); auto p3 = p.add_instruction(pass_op{}, a4, p2);
p.add_instruction(pass_op{}, output, p3); p.add_instruction(pass_op{}, output, p3);
p.compile(memory_coloring_target{}); run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 320); CHECK(p.get_parameter_shape("scratch").bytes() == 320);
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
...@@ -599,7 +595,7 @@ TEST_CASE(test38) ...@@ -599,7 +595,7 @@ TEST_CASE(test38)
auto p78 = add_alloc(p, {migraphx::shape::float_type, {1, 64, 56, 56}}); auto p78 = add_alloc(p, {migraphx::shape::float_type, {1, 64, 56, 56}});
auto p83 = p.add_instruction(pass_op{}, p78, p77); auto p83 = p.add_instruction(pass_op{}, p78, p77);
p.add_instruction(pass_op{}, output, p83, p63); p.add_instruction(pass_op{}, output, p83, p63);
p.compile(memory_coloring_target{}); run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 7225344); // Optimal solution is 6422528 CHECK(p.get_parameter_shape("scratch").bytes() == 7225344); // Optimal solution is 6422528
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
...@@ -609,7 +605,7 @@ TEST_CASE(literal_test) ...@@ -609,7 +605,7 @@ TEST_CASE(literal_test)
migraphx::program p; migraphx::program p;
auto lit = generate_literal(migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); auto lit = generate_literal(migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
p.add_literal(lit); p.add_literal(lit);
p.compile(memory_coloring_target{}); run_pass(p);
auto result = p.eval({}); auto result = p.eval({});
CHECK(lit == result); CHECK(lit == result);
} }
......
#include <migraphx/propagate_constant.hpp> #include <migraphx/propagate_constant.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/op/add.hpp> #include <migraphx/op/add.hpp>
#include <migraphx/op/scalar.hpp> #include <migraphx/op/scalar.hpp>
#include <migraphx/op/mul.hpp> #include <migraphx/op/mul.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
#include <test.hpp> #include <test.hpp>
struct const_prop_target void run_pass(migraphx::program& p)
{ {
std::string name() const { return "const_prop"; } migraphx::run_passes(p, {migraphx::propagate_constant{}, migraphx::dead_code_elimination{}});
std::vector<migraphx::pass> get_passes(migraphx::context&) const }
{
return {migraphx::propagate_constant{}, migraphx::dead_code_elimination{}};
}
migraphx::context get_context() const { return {}; }
};
TEST_CASE(const_add) TEST_CASE(const_add)
{ {
...@@ -23,7 +19,7 @@ TEST_CASE(const_add) ...@@ -23,7 +19,7 @@ TEST_CASE(const_add)
auto two = p1.add_literal(2); auto two = p1.add_literal(2);
auto sum = p1.add_instruction(migraphx::op::add{}, one, two); auto sum = p1.add_instruction(migraphx::op::add{}, one, two);
p1.add_instruction(pass_op{}, sum); p1.add_instruction(pass_op{}, sum);
p1.compile(const_prop_target{}); run_pass(p1);
migraphx::program p2; migraphx::program p2;
auto total = p2.add_literal(3); auto total = p2.add_literal(3);
...@@ -38,7 +34,7 @@ TEST_CASE(const_add_parameter) ...@@ -38,7 +34,7 @@ TEST_CASE(const_add_parameter)
auto two = p1.add_literal(2); auto two = p1.add_literal(2);
auto sum = p1.add_instruction(migraphx::op::add{}, one, two); auto sum = p1.add_instruction(migraphx::op::add{}, one, two);
p1.add_instruction(pass_op{}, sum); p1.add_instruction(pass_op{}, sum);
p1.compile(const_prop_target{}); run_pass(p1);
migraphx::program p2; migraphx::program p2;
auto total = p2.add_literal(3); auto total = p2.add_literal(3);
...@@ -54,7 +50,7 @@ TEST_CASE(const_multiadd) ...@@ -54,7 +50,7 @@ TEST_CASE(const_multiadd)
auto sum1 = p1.add_instruction(migraphx::op::add{}, one, two); auto sum1 = p1.add_instruction(migraphx::op::add{}, one, two);
auto sum2 = p1.add_instruction(migraphx::op::add{}, sum1, two); auto sum2 = p1.add_instruction(migraphx::op::add{}, sum1, two);
p1.add_instruction(pass_op{}, sum2); p1.add_instruction(pass_op{}, sum2);
p1.compile(const_prop_target{}); run_pass(p1);
migraphx::program p2; migraphx::program p2;
auto total = p2.add_literal(5); auto total = p2.add_literal(5);
...@@ -71,7 +67,7 @@ TEST_CASE(const_add_mul) ...@@ -71,7 +67,7 @@ TEST_CASE(const_add_mul)
auto sum1 = p1.add_instruction(migraphx::op::add{}, one, mul); auto sum1 = p1.add_instruction(migraphx::op::add{}, one, mul);
auto sum2 = p1.add_instruction(migraphx::op::add{}, sum1, two); auto sum2 = p1.add_instruction(migraphx::op::add{}, sum1, two);
p1.add_instruction(pass_op{}, sum2); p1.add_instruction(pass_op{}, sum2);
p1.compile(const_prop_target{}); run_pass(p1);
migraphx::program p2; migraphx::program p2;
auto total = p2.add_literal(7); auto total = p2.add_literal(7);
...@@ -86,7 +82,7 @@ TEST_CASE(const_add_scalar) ...@@ -86,7 +82,7 @@ TEST_CASE(const_add_scalar)
auto two = p1.add_instruction(migraphx::op::scalar{{2, 2}}, p1.add_literal(2)); auto two = p1.add_instruction(migraphx::op::scalar{{2, 2}}, p1.add_literal(2));
auto sum = p1.add_instruction(migraphx::op::add{}, one, two); auto sum = p1.add_instruction(migraphx::op::add{}, one, two);
p1.add_instruction(pass_op{}, sum); p1.add_instruction(pass_op{}, sum);
p1.compile(const_prop_target{}); run_pass(p1);
migraphx::program p2; migraphx::program p2;
auto total = auto total =
...@@ -102,7 +98,7 @@ TEST_CASE(const_scalar) ...@@ -102,7 +98,7 @@ TEST_CASE(const_scalar)
auto one = p1.add_instruction(migraphx::op::scalar{{2, 2}}, p1.add_literal(1)); auto one = p1.add_instruction(migraphx::op::scalar{{2, 2}}, p1.add_literal(1));
p1.add_instruction(pass_op{}, one); p1.add_instruction(pass_op{}, one);
} }
p1.compile(const_prop_target{}); run_pass(p1);
migraphx::program p2; migraphx::program p2;
{ {
......
...@@ -19,6 +19,7 @@ add_dependencies(check migraphx_py) ...@@ -19,6 +19,7 @@ add_dependencies(check migraphx_py)
add_py_test(cpu test_cpu.py WORKING_DIRECTORY ${TEST_ONNX_DIR}) add_py_test(cpu test_cpu.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
if(MIGRAPHX_ENABLE_GPU) if(MIGRAPHX_ENABLE_GPU)
add_py_test(gpu_offload test_gpu_offload.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_py_test(gpu test_gpu.py WORKING_DIRECTORY ${TEST_ONNX_DIR}) add_py_test(gpu test_gpu.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_py_test(array test_array.py WORKING_DIRECTORY ${TEST_ONNX_DIR}) add_py_test(array test_array.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
endif() endif()
...@@ -52,9 +52,9 @@ def check_shapes(r, m): ...@@ -52,9 +52,9 @@ def check_shapes(r, m):
def run(p): def run(p):
params = {} params = {}
for key, value in p.get_parameter_shapes().items(): for key, value in p.get_parameter_shapes().items():
params[key] = migraphx.to_gpu(migraphx.generate_argument(value)) params[key] = migraphx.generate_argument(value)
return migraphx.from_gpu(p.run(params)) return p.run(params)
def test_shape(shape): def test_shape(shape):
......
...@@ -9,7 +9,7 @@ params = {} ...@@ -9,7 +9,7 @@ params = {}
for key, value in p.get_parameter_shapes().items(): for key, value in p.get_parameter_shapes().items():
print("Parameter {} -> {}".format(key, value)) print("Parameter {} -> {}".format(key, value))
params[key] = migraphx.to_gpu(migraphx.generate_argument(value)) params[key] = migraphx.generate_argument(value)
r = migraphx.from_gpu(p.run(params)) r = p.run(params)
print(r) print(r)
import migraphx
p = migraphx.parse_onnx("conv_relu_maxpool_test.onnx")
print(p)
print("Compiling ...")
p.compile(migraphx.get_target("gpu"), offload_copy=False)
print(p)
params = {}
for key, value in p.get_parameter_shapes().items():
print("Parameter {} -> {}".format(key, value))
params[key] = migraphx.to_gpu(migraphx.generate_argument(value))
r = migraphx.from_gpu(p.run(params))
print(r)
#include <migraphx/schedule.hpp> #include <migraphx/schedule.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/op/identity.hpp> #include <migraphx/op/identity.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
...@@ -155,15 +156,9 @@ bool check_conflicts(migraphx::program& p, migraphx::instruction_ref x, migraphx ...@@ -155,15 +156,9 @@ bool check_conflicts(migraphx::program& p, migraphx::instruction_ref x, migraphx
return false; return false;
} }
struct schedule_target struct scheduler
{ {
schedule_model_test model{}; schedule_model_test model{};
std::string name() const { return "schedule"; }
std::vector<migraphx::pass> get_passes(migraphx::context&) const
{
return {migraphx::schedule{model}};
}
migraphx::context get_context() const { return {}; }
std::size_t get_stream(migraphx::instruction_ref ins) { return model.ins2stream->at(ins); } std::size_t get_stream(migraphx::instruction_ref ins) { return model.ins2stream->at(ins); }
...@@ -176,6 +171,8 @@ struct schedule_target ...@@ -176,6 +171,8 @@ struct schedule_target
return result; return result;
} }
void run_pass(migraphx::program& p) { migraphx::run_passes(p, {migraphx::schedule{model}}); }
bool has_stream(migraphx::instruction_ref ins) { return model.ins2stream->count(ins) > 0; } bool has_stream(migraphx::instruction_ref ins) { return model.ins2stream->count(ins) > 0; }
void check_conflicts(migraphx::program& p, void check_conflicts(migraphx::program& p,
...@@ -253,13 +250,13 @@ chain(migraphx::program& p, std::size_t n, T x, migraphx::instruction_ref input) ...@@ -253,13 +250,13 @@ chain(migraphx::program& p, std::size_t n, T x, migraphx::instruction_ref input)
} }
TEST_CASE(single_entry) TEST_CASE(single_entry)
{ {
schedule_target t{}; scheduler t{};
migraphx::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto onep1 = p.add_instruction(unary_op{}, one); auto onep1 = p.add_instruction(unary_op{}, one);
auto onep2 = p.add_instruction(unary_op{}, one); auto onep2 = p.add_instruction(unary_op{}, one);
auto binary = p.add_instruction(nary_op{}, onep1, onep2); auto binary = p.add_instruction(nary_op{}, onep1, onep2);
p.compile(t); t.run_pass(p);
EXPECT(not t.has_stream(one)); EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(onep1) != t.get_stream(onep2)); EXPECT(t.get_stream(onep1) != t.get_stream(onep2));
EXPECT(t.get_stream(binary) == 0); EXPECT(t.get_stream(binary) == 0);
...@@ -270,13 +267,13 @@ TEST_CASE(single_entry) ...@@ -270,13 +267,13 @@ TEST_CASE(single_entry)
TEST_CASE(stream_free) TEST_CASE(stream_free)
{ {
schedule_target t{}; scheduler t{};
migraphx::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto onep1 = p.add_instruction(stream_free_op{}, one); auto onep1 = p.add_instruction(stream_free_op{}, one);
auto onep2 = p.add_instruction(stream_free_op{}, one); auto onep2 = p.add_instruction(stream_free_op{}, one);
auto binary = p.add_instruction(nary_op{}, onep1, onep2); auto binary = p.add_instruction(nary_op{}, onep1, onep2);
p.compile(t); t.run_pass(p);
EXPECT(not t.has_stream(one)); EXPECT(not t.has_stream(one));
EXPECT(not t.has_stream(onep1)); EXPECT(not t.has_stream(onep1));
EXPECT(not t.has_stream(onep2)); EXPECT(not t.has_stream(onep2));
...@@ -285,7 +282,7 @@ TEST_CASE(stream_free) ...@@ -285,7 +282,7 @@ TEST_CASE(stream_free)
TEST_CASE(zero_record) TEST_CASE(zero_record)
{ {
schedule_target t{}; scheduler t{};
migraphx::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto onep1 = p.add_instruction(unary_op{}, one); auto onep1 = p.add_instruction(unary_op{}, one);
...@@ -293,7 +290,7 @@ TEST_CASE(zero_record) ...@@ -293,7 +290,7 @@ TEST_CASE(zero_record)
auto onei1 = p.add_instruction(migraphx::op::identity{}, onep1); auto onei1 = p.add_instruction(migraphx::op::identity{}, onep1);
auto onei2 = p.add_instruction(migraphx::op::identity{}, onep2); auto onei2 = p.add_instruction(migraphx::op::identity{}, onep2);
auto binary = p.add_instruction(nary_op{}, onei1, onei2); auto binary = p.add_instruction(nary_op{}, onei1, onei2);
p.compile(t); t.run_pass(p);
EXPECT(not t.has_stream(one)); EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(onep1) != t.get_stream(onep2)); EXPECT(t.get_stream(onep1) != t.get_stream(onep2));
EXPECT(t.has_stream(binary)); EXPECT(t.has_stream(binary));
...@@ -305,13 +302,13 @@ TEST_CASE(zero_record) ...@@ -305,13 +302,13 @@ TEST_CASE(zero_record)
TEST_CASE(zero_merge1) TEST_CASE(zero_merge1)
{ {
schedule_target t{}; scheduler t{};
migraphx::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto onep1 = p.add_instruction(unary_op{}, one); auto onep1 = p.add_instruction(unary_op{}, one);
auto onep2 = p.add_instruction(unary_op{}, one); auto onep2 = p.add_instruction(unary_op{}, one);
auto binary = p.add_instruction(migraphx::op::identity{}, onep1, onep2); auto binary = p.add_instruction(migraphx::op::identity{}, onep1, onep2);
p.compile(t); t.run_pass(p);
EXPECT(not t.has_stream(one)); EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(onep1) != t.get_stream(onep2)); EXPECT(t.get_stream(onep1) != t.get_stream(onep2));
// No stream assignment // No stream assignment
...@@ -323,7 +320,7 @@ TEST_CASE(zero_merge1) ...@@ -323,7 +320,7 @@ TEST_CASE(zero_merge1)
TEST_CASE(zero_merge2) TEST_CASE(zero_merge2)
{ {
schedule_target t{}; scheduler t{};
migraphx::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto onep1 = p.add_instruction(unary_op{}, one); auto onep1 = p.add_instruction(unary_op{}, one);
...@@ -331,7 +328,7 @@ TEST_CASE(zero_merge2) ...@@ -331,7 +328,7 @@ TEST_CASE(zero_merge2)
auto binary = p.add_instruction(migraphx::op::identity{}, auto binary = p.add_instruction(migraphx::op::identity{},
p.add_instruction(migraphx::op::identity{}, onep1), p.add_instruction(migraphx::op::identity{}, onep1),
p.add_instruction(migraphx::op::identity{}, onep2)); p.add_instruction(migraphx::op::identity{}, onep2));
p.compile(t); t.run_pass(p);
EXPECT(not t.has_stream(one)); EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(onep1) != t.get_stream(onep2)); EXPECT(t.get_stream(onep1) != t.get_stream(onep2));
// No stream assignment // No stream assignment
...@@ -343,14 +340,14 @@ TEST_CASE(zero_merge2) ...@@ -343,14 +340,14 @@ TEST_CASE(zero_merge2)
TEST_CASE(zero_merge3) TEST_CASE(zero_merge3)
{ {
schedule_target t{}; scheduler t{};
migraphx::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto onep1 = p.add_instruction(unary_op{}, one); auto onep1 = p.add_instruction(unary_op{}, one);
auto onep2 = p.add_instruction(unary_op{}, one); auto onep2 = p.add_instruction(unary_op{}, one);
auto id = p.add_instruction(migraphx::op::identity{}, onep1, onep2); auto id = p.add_instruction(migraphx::op::identity{}, onep1, onep2);
auto final = p.add_instruction(unary_op{}, id); auto final = p.add_instruction(unary_op{}, id);
p.compile(t); t.run_pass(p);
EXPECT(not t.has_stream(one)); EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(onep1) != t.get_stream(onep2)); EXPECT(t.get_stream(onep1) != t.get_stream(onep2));
// No stream assignment // No stream assignment
...@@ -366,7 +363,7 @@ TEST_CASE(zero_merge3) ...@@ -366,7 +363,7 @@ TEST_CASE(zero_merge3)
TEST_CASE(zero_merge4) TEST_CASE(zero_merge4)
{ {
schedule_target t{}; scheduler t{};
migraphx::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto onep1 = p.add_instruction(unary_op{}, one); auto onep1 = p.add_instruction(unary_op{}, one);
...@@ -375,7 +372,7 @@ TEST_CASE(zero_merge4) ...@@ -375,7 +372,7 @@ TEST_CASE(zero_merge4)
p.add_instruction(migraphx::op::identity{}, onep1), p.add_instruction(migraphx::op::identity{}, onep1),
p.add_instruction(migraphx::op::identity{}, onep2)); p.add_instruction(migraphx::op::identity{}, onep2));
auto final = p.add_instruction(unary_op{}, id); auto final = p.add_instruction(unary_op{}, id);
p.compile(t); t.run_pass(p);
EXPECT(not t.has_stream(one)); EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(onep1) != t.get_stream(onep2)); EXPECT(t.get_stream(onep1) != t.get_stream(onep2));
// No stream assignment // No stream assignment
...@@ -391,14 +388,14 @@ TEST_CASE(zero_merge4) ...@@ -391,14 +388,14 @@ TEST_CASE(zero_merge4)
TEST_CASE(double_entry) TEST_CASE(double_entry)
{ {
schedule_target t{}; scheduler t{};
migraphx::program p; migraphx::program p;
auto one = p.add_instruction(stream_free_op{}, p.add_literal(1)); auto one = p.add_instruction(stream_free_op{}, p.add_literal(1));
auto two = p.add_instruction(stream_free_op{}, p.add_literal(2)); auto two = p.add_instruction(stream_free_op{}, p.add_literal(2));
auto onep = p.add_instruction(unary_op{}, one); auto onep = p.add_instruction(unary_op{}, one);
auto twop = p.add_instruction(unary_op{}, two); auto twop = p.add_instruction(unary_op{}, two);
auto binary = p.add_instruction(nary_op{}, onep, twop); auto binary = p.add_instruction(nary_op{}, onep, twop);
p.compile(t); t.run_pass(p);
EXPECT(not t.has_stream(one)); EXPECT(not t.has_stream(one));
EXPECT(not t.has_stream(two)); EXPECT(not t.has_stream(two));
EXPECT(t.get_stream(onep) != t.get_stream(twop)); EXPECT(t.get_stream(onep) != t.get_stream(twop));
...@@ -410,13 +407,13 @@ TEST_CASE(double_entry) ...@@ -410,13 +407,13 @@ TEST_CASE(double_entry)
TEST_CASE(two_branches) TEST_CASE(two_branches)
{ {
schedule_target t{}; scheduler t{};
migraphx::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto c1 = chain(p, 2, unary_op{}, one); auto c1 = chain(p, 2, unary_op{}, one);
auto i1 = p.add_instruction(unary_op{}, one); auto i1 = p.add_instruction(unary_op{}, one);
auto binary = p.add_instruction(nary_op{}, i1, c1.back()); auto binary = p.add_instruction(nary_op{}, i1, c1.back());
p.compile(t); t.run_pass(p);
EXPECT(not t.has_stream(one)); EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(i1) == 1); EXPECT(t.get_stream(i1) == 1);
for(auto ins : c1) for(auto ins : c1)
...@@ -429,7 +426,7 @@ TEST_CASE(two_branches) ...@@ -429,7 +426,7 @@ TEST_CASE(two_branches)
TEST_CASE(four_branches) TEST_CASE(four_branches)
{ {
schedule_target t{}; scheduler t{};
migraphx::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto c1 = chain(p, 4, unary_op{}, one); auto c1 = chain(p, 4, unary_op{}, one);
...@@ -437,7 +434,7 @@ TEST_CASE(four_branches) ...@@ -437,7 +434,7 @@ TEST_CASE(four_branches)
auto c3 = chain(p, 2, unary_op{}, one); auto c3 = chain(p, 2, unary_op{}, one);
auto i1 = p.add_instruction(unary_op{}, one); auto i1 = p.add_instruction(unary_op{}, one);
auto binary = p.add_instruction(nary_op{}, i1, c1.back(), c2.back(), c3.back()); auto binary = p.add_instruction(nary_op{}, i1, c1.back(), c2.back(), c3.back());
p.compile(t); t.run_pass(p);
EXPECT(not t.has_stream(one)); EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(i1) == 3); EXPECT(t.get_stream(i1) == 3);
for(auto ins : c1) for(auto ins : c1)
...@@ -457,7 +454,7 @@ TEST_CASE(four_branches) ...@@ -457,7 +454,7 @@ TEST_CASE(four_branches)
TEST_CASE(five_branches) TEST_CASE(five_branches)
{ {
schedule_target t{}; scheduler t{};
migraphx::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto c1 = chain(p, 5, unary_op{}, one); auto c1 = chain(p, 5, unary_op{}, one);
...@@ -466,7 +463,7 @@ TEST_CASE(five_branches) ...@@ -466,7 +463,7 @@ TEST_CASE(five_branches)
auto c4 = chain(p, 2, unary_op{}, one); auto c4 = chain(p, 2, unary_op{}, one);
auto i1 = p.add_instruction(unary_op{}, one); auto i1 = p.add_instruction(unary_op{}, one);
auto binary = p.add_instruction(nary_op{}, i1, c1.back(), c2.back(), c3.back(), c4.back()); auto binary = p.add_instruction(nary_op{}, i1, c1.back(), c2.back(), c3.back(), c4.back());
p.compile(t); t.run_pass(p);
EXPECT(not t.has_stream(one)); EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(i1) == 3); EXPECT(t.get_stream(i1) == 3);
for(auto ins : c1) for(auto ins : c1)
...@@ -489,7 +486,7 @@ TEST_CASE(five_branches) ...@@ -489,7 +486,7 @@ TEST_CASE(five_branches)
TEST_CASE(four_branches_eq) TEST_CASE(four_branches_eq)
{ {
schedule_target t{}; scheduler t{};
migraphx::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto onep1 = p.add_instruction(unary_op{}, one); auto onep1 = p.add_instruction(unary_op{}, one);
...@@ -497,7 +494,7 @@ TEST_CASE(four_branches_eq) ...@@ -497,7 +494,7 @@ TEST_CASE(four_branches_eq)
auto onep3 = p.add_instruction(unary_op{}, one); auto onep3 = p.add_instruction(unary_op{}, one);
auto onep4 = p.add_instruction(unary_op{}, one); auto onep4 = p.add_instruction(unary_op{}, one);
auto binary = p.add_instruction(nary_op{}, onep1, onep2, onep3, onep4); auto binary = p.add_instruction(nary_op{}, onep1, onep2, onep3, onep4);
p.compile(t); t.run_pass(p);
EXPECT(not t.has_stream(one)); EXPECT(not t.has_stream(one));
EXPECT( EXPECT(
sorted<std::size_t>( sorted<std::size_t>(
...@@ -515,7 +512,7 @@ TEST_CASE(four_branches_eq) ...@@ -515,7 +512,7 @@ TEST_CASE(four_branches_eq)
TEST_CASE(seq_merge) TEST_CASE(seq_merge)
{ {
schedule_target t{}; scheduler t{};
migraphx::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto c1 = chain(p, 2, unary_op{}, one); auto c1 = chain(p, 2, unary_op{}, one);
...@@ -526,7 +523,7 @@ TEST_CASE(seq_merge) ...@@ -526,7 +523,7 @@ TEST_CASE(seq_merge)
auto i2 = p.add_instruction(unary_op{}, binary1); auto i2 = p.add_instruction(unary_op{}, binary1);
auto binary2 = p.add_instruction(nary_op{}, i2, c2.back()); auto binary2 = p.add_instruction(nary_op{}, i2, c2.back());
p.compile(t); t.run_pass(p);
EXPECT(not t.has_stream(one)); EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(i1) != t.get_stream(c1.back())); EXPECT(t.get_stream(i1) != t.get_stream(c1.back()));
...@@ -548,7 +545,7 @@ TEST_CASE(seq_merge) ...@@ -548,7 +545,7 @@ TEST_CASE(seq_merge)
TEST_CASE(par_merge) TEST_CASE(par_merge)
{ {
schedule_target t{}; scheduler t{};
migraphx::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto start1 = p.add_instruction(unary_op{}, one); auto start1 = p.add_instruction(unary_op{}, one);
...@@ -563,7 +560,7 @@ TEST_CASE(par_merge) ...@@ -563,7 +560,7 @@ TEST_CASE(par_merge)
auto binary3 = p.add_instruction(nary_op{}, binary1, binary2); auto binary3 = p.add_instruction(nary_op{}, binary1, binary2);
p.compile(t); t.run_pass(p);
EXPECT(not t.has_stream(one)); EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(binary3) == 0); EXPECT(t.get_stream(binary3) == 0);
...@@ -589,7 +586,7 @@ TEST_CASE(par_merge) ...@@ -589,7 +586,7 @@ TEST_CASE(par_merge)
TEST_CASE(inner_par_merge) TEST_CASE(inner_par_merge)
{ {
schedule_target t{}; scheduler t{};
migraphx::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto start1 = p.add_instruction(unary_op{}, one); auto start1 = p.add_instruction(unary_op{}, one);
...@@ -607,7 +604,7 @@ TEST_CASE(inner_par_merge) ...@@ -607,7 +604,7 @@ TEST_CASE(inner_par_merge)
auto output = p.add_instruction(nary_op{}, binary1, binary2, outer1, outer2); auto output = p.add_instruction(nary_op{}, binary1, binary2, outer1, outer2);
p.compile(t); t.run_pass(p);
EXPECT(not t.has_stream(one)); EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(output) == 0); EXPECT(t.get_stream(output) == 0);
EXPECT(get_wait_for(output) == get_wait_for(t.get_stream(output), EXPECT(get_wait_for(output) == get_wait_for(t.get_stream(output),
...@@ -642,7 +639,7 @@ TEST_CASE(inner_par_merge) ...@@ -642,7 +639,7 @@ TEST_CASE(inner_par_merge)
TEST_CASE(par_merge_multi_entry) TEST_CASE(par_merge_multi_entry)
{ {
schedule_target t{}; scheduler t{};
migraphx::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto start1 = p.add_instruction(unary_op{}, one); auto start1 = p.add_instruction(unary_op{}, one);
...@@ -658,7 +655,7 @@ TEST_CASE(par_merge_multi_entry) ...@@ -658,7 +655,7 @@ TEST_CASE(par_merge_multi_entry)
auto binary3 = p.add_instruction(nary_op{}, binary1, binary2); auto binary3 = p.add_instruction(nary_op{}, binary1, binary2);
p.compile(t); t.run_pass(p);
EXPECT(not t.has_stream(one)); EXPECT(not t.has_stream(one));
EXPECT(not t.has_stream(two)); EXPECT(not t.has_stream(two));
EXPECT(t.get_stream(binary3) == 0); EXPECT(t.get_stream(binary3) == 0);
...@@ -685,7 +682,7 @@ TEST_CASE(par_merge_multi_entry) ...@@ -685,7 +682,7 @@ TEST_CASE(par_merge_multi_entry)
TEST_CASE(inner_split1) TEST_CASE(inner_split1)
{ {
schedule_target t{}; scheduler t{};
migraphx::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto c1 = chain(p, 2, unary_op{}, one); auto c1 = chain(p, 2, unary_op{}, one);
...@@ -693,7 +690,7 @@ TEST_CASE(inner_split1) ...@@ -693,7 +690,7 @@ TEST_CASE(inner_split1)
auto s1 = p.add_instruction(unary_op{}, c1); auto s1 = p.add_instruction(unary_op{}, c1);
auto s2 = p.add_instruction(unary_op{}, c1); auto s2 = p.add_instruction(unary_op{}, c1);
auto output = p.add_instruction(nary_op{}, i1, s1, s2); auto output = p.add_instruction(nary_op{}, i1, s1, s2);
p.compile(t); t.run_pass(p);
EXPECT(not t.has_stream(one)); EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(i1) != t.get_stream(s1)); EXPECT(t.get_stream(i1) != t.get_stream(s1));
EXPECT(t.get_stream(i1) != t.get_stream(s2)); EXPECT(t.get_stream(i1) != t.get_stream(s2));
...@@ -712,7 +709,7 @@ TEST_CASE(inner_split1) ...@@ -712,7 +709,7 @@ TEST_CASE(inner_split1)
TEST_CASE(inner_split2) TEST_CASE(inner_split2)
{ {
schedule_target t{}; scheduler t{};
migraphx::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto c1 = chain(p, 2, unary_op{}, one); auto c1 = chain(p, 2, unary_op{}, one);
...@@ -720,7 +717,7 @@ TEST_CASE(inner_split2) ...@@ -720,7 +717,7 @@ TEST_CASE(inner_split2)
auto s1 = chain(p, 3, unary_op{}, c1.back()); auto s1 = chain(p, 3, unary_op{}, c1.back());
auto s2 = chain(p, 4, unary_op{}, c1.back()); auto s2 = chain(p, 4, unary_op{}, c1.back());
auto output = p.add_instruction(nary_op{}, i1, s1.back(), s2.back()); auto output = p.add_instruction(nary_op{}, i1, s1.back(), s2.back());
p.compile(t); t.run_pass(p);
EXPECT(not t.has_stream(one)); EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(i1) != t.get_stream(s1.back())); EXPECT(t.get_stream(i1) != t.get_stream(s1.back()));
EXPECT(t.get_stream(i1) != t.get_stream(s2.back())); EXPECT(t.get_stream(i1) != t.get_stream(s2.back()));
...@@ -738,7 +735,7 @@ TEST_CASE(inner_split2) ...@@ -738,7 +735,7 @@ TEST_CASE(inner_split2)
TEST_CASE(inception_resnet) TEST_CASE(inception_resnet)
{ {
schedule_target t{}; scheduler t{};
migraphx::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto input = p.add_instruction(unary_op{}, one); auto input = p.add_instruction(unary_op{}, one);
...@@ -746,7 +743,7 @@ TEST_CASE(inception_resnet) ...@@ -746,7 +743,7 @@ TEST_CASE(inception_resnet)
auto i1 = p.add_instruction(unary_op{}, input); auto i1 = p.add_instruction(unary_op{}, input);
auto binary = p.add_instruction(nary_op{}, i1, c1.back()); auto binary = p.add_instruction(nary_op{}, i1, c1.back());
auto output = p.add_instruction(nary_op{}, binary, input); auto output = p.add_instruction(nary_op{}, binary, input);
p.compile(t); t.run_pass(p);
EXPECT(not t.has_stream(one)); EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(i1) != 0); EXPECT(t.get_stream(i1) != 0);
for(auto ins : c1) for(auto ins : c1)
...@@ -761,7 +758,7 @@ TEST_CASE(inception_resnet) ...@@ -761,7 +758,7 @@ TEST_CASE(inception_resnet)
TEST_CASE(inception1) TEST_CASE(inception1)
{ {
schedule_target t{}; scheduler t{};
migraphx::program p; migraphx::program p;
auto i1 = p.add_literal(0); auto i1 = p.add_literal(0);
...@@ -854,7 +851,7 @@ TEST_CASE(inception1) ...@@ -854,7 +851,7 @@ TEST_CASE(inception1)
auto i101 = p.add_literal(2); auto i101 = p.add_literal(2);
auto output = p.add_instruction(nary_op{"output"}, i96, i101, i100, i98, i99); auto output = p.add_instruction(nary_op{"output"}, i96, i101, i100, i98, i99);
p.compile(t); t.run_pass(p);
EXPECT(t.get_streams({i7, i11, i17, i23, i25, i31, i37, i39}) == EXPECT(t.get_streams({i7, i11, i17, i23, i25, i31, i37, i39}) ==
t.get_streams({i7, i7, i7, i7, i7, i7, i7, i7})); t.get_streams({i7, i7, i7, i7, i7, i7, i7, i7}));
......
#include <migraphx/simplify_algebra.hpp> #include <migraphx/simplify_algebra.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
...@@ -7,15 +8,10 @@ ...@@ -7,15 +8,10 @@
#include <basic_ops.hpp> #include <basic_ops.hpp>
#include <test.hpp> #include <test.hpp>
struct simplify_algebra_target void run_pass(migraphx::program& p)
{ {
std::string name() const { return "simplify_algebra"; } migraphx::run_passes(p, {migraphx::simplify_algebra{}, migraphx::dead_code_elimination{}});
std::vector<migraphx::pass> get_passes(migraphx::context&) const }
{
return {migraphx::simplify_algebra{}, migraphx::dead_code_elimination{}};
}
migraphx::context get_context() const { return {}; }
};
TEST_CASE(simplify_add1) TEST_CASE(simplify_add1)
{ {
...@@ -30,7 +26,7 @@ TEST_CASE(simplify_add1) ...@@ -30,7 +26,7 @@ TEST_CASE(simplify_add1)
auto sum3 = p1.add_instruction(migraphx::op::add{}, sum1, sum2); auto sum3 = p1.add_instruction(migraphx::op::add{}, sum1, sum2);
p1.add_instruction(pass_op{}, sum3); p1.add_instruction(pass_op{}, sum3);
} }
p1.compile(simplify_algebra_target{}); run_pass(p1);
migraphx::program p2; migraphx::program p2;
{ {
...@@ -59,7 +55,7 @@ TEST_CASE(simplify_add2) ...@@ -59,7 +55,7 @@ TEST_CASE(simplify_add2)
auto sum3 = p1.add_instruction(migraphx::op::add{}, sum1, sum2); auto sum3 = p1.add_instruction(migraphx::op::add{}, sum1, sum2);
p1.add_instruction(pass_op{}, sum3); p1.add_instruction(pass_op{}, sum3);
} }
p1.compile(simplify_algebra_target{}); run_pass(p1);
migraphx::program p2; migraphx::program p2;
{ {
...@@ -87,7 +83,7 @@ TEST_CASE(simplify_add3) ...@@ -87,7 +83,7 @@ TEST_CASE(simplify_add3)
auto sum3 = p1.add_instruction(migraphx::op::add{}, sum1, sum2); auto sum3 = p1.add_instruction(migraphx::op::add{}, sum1, sum2);
p1.add_instruction(pass_op{}, sum3); p1.add_instruction(pass_op{}, sum3);
} }
p1.compile(simplify_algebra_target{}); run_pass(p1);
migraphx::program p2; migraphx::program p2;
{ {
...@@ -120,7 +116,7 @@ TEST_CASE(simplify_add_broadcast1) ...@@ -120,7 +116,7 @@ TEST_CASE(simplify_add_broadcast1)
auto sum3 = p1.add_instruction(migraphx::op::add{}, sum1, sum2); auto sum3 = p1.add_instruction(migraphx::op::add{}, sum1, sum2);
p1.add_instruction(pass_op{}, sum3); p1.add_instruction(pass_op{}, sum3);
} }
p1.compile(simplify_algebra_target{}); run_pass(p1);
migraphx::program p2; migraphx::program p2;
{ {
...@@ -156,7 +152,7 @@ TEST_CASE(simplify_add_broadcast2) ...@@ -156,7 +152,7 @@ TEST_CASE(simplify_add_broadcast2)
return p; return p;
}; };
migraphx::program p1 = create_program(); migraphx::program p1 = create_program();
p1.compile(simplify_algebra_target{}); run_pass(p1);
migraphx::program p2 = create_program(); migraphx::program p2 = create_program();
EXPECT(p1 == p2); EXPECT(p1 == p2);
...@@ -177,7 +173,7 @@ void simplify_add4() ...@@ -177,7 +173,7 @@ void simplify_add4()
auto sum3 = p1.add_instruction(migraphx::op::add{}, sum2, two); auto sum3 = p1.add_instruction(migraphx::op::add{}, sum2, two);
p1.add_instruction(pass_op{}, sum3); p1.add_instruction(pass_op{}, sum3);
} }
p1.compile(simplify_algebra_target{}); run_pass(p1);
migraphx::program p2; migraphx::program p2;
{ {
...@@ -205,7 +201,7 @@ TEST_CASE(simplify_mul_conv1) ...@@ -205,7 +201,7 @@ TEST_CASE(simplify_mul_conv1)
auto mul = p.add_instruction(migraphx::op::mul{}, conv, b); auto mul = p.add_instruction(migraphx::op::mul{}, conv, b);
p.add_instruction(pass_op{}, mul); p.add_instruction(pass_op{}, mul);
EXPECT(conv->outputs().front()->name() == "mul"); EXPECT(conv->outputs().front()->name() == "mul");
p.compile(simplify_algebra_target{}); run_pass(p);
auto new_conv = auto new_conv =
std::find_if(p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; }); std::find_if(p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; });
EXPECT(new_conv->outputs().front()->name() != "mul"); EXPECT(new_conv->outputs().front()->name() != "mul");
...@@ -222,7 +218,7 @@ TEST_CASE(simplify_mul_add) ...@@ -222,7 +218,7 @@ TEST_CASE(simplify_mul_add)
auto mul = p1.add_instruction(migraphx::op::mul{}, sum, two); auto mul = p1.add_instruction(migraphx::op::mul{}, sum, two);
p1.add_instruction(pass_op{}, mul); p1.add_instruction(pass_op{}, mul);
} }
p1.compile(simplify_algebra_target{}); run_pass(p1);
migraphx::program p2; migraphx::program p2;
{ {
...@@ -249,7 +245,7 @@ TEST_CASE(simplify_inner_broadcast) ...@@ -249,7 +245,7 @@ TEST_CASE(simplify_inner_broadcast)
auto sum = p1.add_instruction(migraphx::op::add{}, xb, yb); auto sum = p1.add_instruction(migraphx::op::add{}, xb, yb);
p1.add_instruction(pass_op{}, sum); p1.add_instruction(pass_op{}, sum);
} }
p1.compile(simplify_algebra_target{}); run_pass(p1);
migraphx::program p2; migraphx::program p2;
{ {
...@@ -276,7 +272,7 @@ TEST_CASE(simplify_add_conv1) ...@@ -276,7 +272,7 @@ TEST_CASE(simplify_add_conv1)
auto sum = p.add_instruction(migraphx::op::add{}, conv1, conv2); auto sum = p.add_instruction(migraphx::op::add{}, conv1, conv2);
p.add_instruction(pass_op{}, sum); p.add_instruction(pass_op{}, sum);
auto s = p.get_shape(); auto s = p.get_shape();
p.compile(simplify_algebra_target{}); run_pass(p);
EXPECT(s == p.get_shape()); EXPECT(s == p.get_shape());
EXPECT(std::count_if( EXPECT(std::count_if(
p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 1); p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 1);
...@@ -296,7 +292,7 @@ TEST_CASE(simplify_add_conv_no_fusion_7x7_diff_strides) ...@@ -296,7 +292,7 @@ TEST_CASE(simplify_add_conv_no_fusion_7x7_diff_strides)
auto sum = p.add_instruction(migraphx::op::add{}, conv1, conv2); auto sum = p.add_instruction(migraphx::op::add{}, conv1, conv2);
p.add_instruction(pass_op{}, sum); p.add_instruction(pass_op{}, sum);
auto s = p.get_shape(); auto s = p.get_shape();
p.compile(simplify_algebra_target{}); run_pass(p);
EXPECT(s == p.get_shape()); EXPECT(s == p.get_shape());
// No fusion // No fusion
EXPECT(std::count_if( EXPECT(std::count_if(
...@@ -317,7 +313,7 @@ TEST_CASE(simplify_add_conv_1x1_diff_strides1) ...@@ -317,7 +313,7 @@ TEST_CASE(simplify_add_conv_1x1_diff_strides1)
auto sum = p.add_instruction(migraphx::op::add{}, conv1, conv2); auto sum = p.add_instruction(migraphx::op::add{}, conv1, conv2);
p.add_instruction(pass_op{}, sum); p.add_instruction(pass_op{}, sum);
auto s = p.get_shape(); auto s = p.get_shape();
p.compile(simplify_algebra_target{}); run_pass(p);
EXPECT(s == p.get_shape()); EXPECT(s == p.get_shape());
EXPECT(std::count_if( EXPECT(std::count_if(
p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 1); p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 1);
...@@ -337,7 +333,7 @@ TEST_CASE(simplify_add_conv_1x1_diff_strides2) ...@@ -337,7 +333,7 @@ TEST_CASE(simplify_add_conv_1x1_diff_strides2)
auto sum = p.add_instruction(migraphx::op::add{}, conv1, conv2); auto sum = p.add_instruction(migraphx::op::add{}, conv1, conv2);
p.add_instruction(pass_op{}, sum); p.add_instruction(pass_op{}, sum);
auto s = p.get_shape(); auto s = p.get_shape();
p.compile(simplify_algebra_target{}); run_pass(p);
EXPECT(s == p.get_shape()); EXPECT(s == p.get_shape());
EXPECT(std::count_if( EXPECT(std::count_if(
p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 1); p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 1);
...@@ -357,7 +353,7 @@ TEST_CASE(simplify_add_conv_no_fusion_asymetrical_strides1) ...@@ -357,7 +353,7 @@ TEST_CASE(simplify_add_conv_no_fusion_asymetrical_strides1)
auto sum = p.add_instruction(migraphx::op::add{}, conv1, conv2); auto sum = p.add_instruction(migraphx::op::add{}, conv1, conv2);
p.add_instruction(pass_op{}, sum); p.add_instruction(pass_op{}, sum);
auto s = p.get_shape(); auto s = p.get_shape();
p.compile(simplify_algebra_target{}); run_pass(p);
EXPECT(s == p.get_shape()); EXPECT(s == p.get_shape());
// No fusion // No fusion
EXPECT(std::count_if( EXPECT(std::count_if(
...@@ -378,7 +374,7 @@ TEST_CASE(simplify_add_conv_no_fusion_asymetrical_strides2) ...@@ -378,7 +374,7 @@ TEST_CASE(simplify_add_conv_no_fusion_asymetrical_strides2)
auto sum = p.add_instruction(migraphx::op::add{}, conv1, conv2); auto sum = p.add_instruction(migraphx::op::add{}, conv1, conv2);
p.add_instruction(pass_op{}, sum); p.add_instruction(pass_op{}, sum);
auto s = p.get_shape(); auto s = p.get_shape();
p.compile(simplify_algebra_target{}); run_pass(p);
EXPECT(s == p.get_shape()); EXPECT(s == p.get_shape());
// No fusion // No fusion
EXPECT(std::count_if( EXPECT(std::count_if(
......
#include <migraphx/simplify_reshapes.hpp> #include <migraphx/simplify_reshapes.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
#include <test.hpp> #include <test.hpp>
struct simplify_reshapes_target void run_pass(migraphx::program& p)
{ {
std::string name() const { return "simplify_reshapes"; } migraphx::run_passes(p, {migraphx::simplify_reshapes{}, migraphx::dead_code_elimination{}});
std::vector<migraphx::pass> get_passes(migraphx::context&) const }
{
return {migraphx::simplify_reshapes{}, migraphx::dead_code_elimination{}};
}
migraphx::context get_context() const { return {}; }
};
TEST_CASE(double_contig) TEST_CASE(double_contig)
{ {
...@@ -26,7 +22,7 @@ TEST_CASE(double_contig) ...@@ -26,7 +22,7 @@ TEST_CASE(double_contig)
p.add_instruction(pass_op{}, c2); p.add_instruction(pass_op{}, c2);
EXPECT(p.get_shape().standard()); EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed()); EXPECT(not p.get_shape().transposed());
p.compile(simplify_reshapes_target{}); run_pass(p);
EXPECT(p.get_shape().standard()); EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed()); EXPECT(not p.get_shape().transposed());
EXPECT(std::distance(p.begin(), p.end()) == 4); EXPECT(std::distance(p.begin(), p.end()) == 4);
...@@ -43,7 +39,7 @@ TEST_CASE(double_transpose) ...@@ -43,7 +39,7 @@ TEST_CASE(double_transpose)
p.add_instruction(pass_op{}, t2); p.add_instruction(pass_op{}, t2);
EXPECT(p.get_shape().standard()); EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed()); EXPECT(not p.get_shape().transposed());
p.compile(simplify_reshapes_target{}); run_pass(p);
EXPECT(p.get_shape().standard()); EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed()); EXPECT(not p.get_shape().transposed());
EXPECT(std::distance(p.begin(), p.end()) == 2); EXPECT(std::distance(p.begin(), p.end()) == 2);
...@@ -62,7 +58,7 @@ TEST_CASE(double_transpose_contig) ...@@ -62,7 +58,7 @@ TEST_CASE(double_transpose_contig)
p.add_instruction(pass_op{}, c2); p.add_instruction(pass_op{}, c2);
EXPECT(p.get_shape().standard()); EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed()); EXPECT(not p.get_shape().transposed());
p.compile(simplify_reshapes_target{}); run_pass(p);
EXPECT(p.get_shape().standard()); EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed()); EXPECT(not p.get_shape().transposed());
EXPECT(std::distance(p.begin(), p.end()) == 2); EXPECT(std::distance(p.begin(), p.end()) == 2);
...@@ -78,7 +74,7 @@ TEST_CASE(single_transpose) ...@@ -78,7 +74,7 @@ TEST_CASE(single_transpose)
p.add_instruction(pass_op{}, t1); p.add_instruction(pass_op{}, t1);
EXPECT(not p.get_shape().standard()); EXPECT(not p.get_shape().standard());
EXPECT(p.get_shape().transposed()); EXPECT(p.get_shape().transposed());
p.compile(simplify_reshapes_target{}); run_pass(p);
EXPECT(not p.get_shape().standard()); EXPECT(not p.get_shape().standard());
EXPECT(p.get_shape().transposed()); EXPECT(p.get_shape().transposed());
EXPECT(std::distance(p.begin(), p.end()) == 3); EXPECT(std::distance(p.begin(), p.end()) == 3);
...@@ -94,7 +90,7 @@ TEST_CASE(double_transpose_sin_pass) ...@@ -94,7 +90,7 @@ TEST_CASE(double_transpose_sin_pass)
p.add_instruction(migraphx::op::transpose{{1, 0}}, t1); p.add_instruction(migraphx::op::transpose{{1, 0}}, t1);
EXPECT(p.get_shape().standard()); EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed()); EXPECT(not p.get_shape().transposed());
p.compile(simplify_reshapes_target{}); run_pass(p);
EXPECT(p.get_shape().standard()); EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed()); EXPECT(not p.get_shape().transposed());
// TODO: Fix this // TODO: Fix this
...@@ -110,7 +106,7 @@ TEST_CASE(single_transpose_sin_pass) ...@@ -110,7 +106,7 @@ TEST_CASE(single_transpose_sin_pass)
p.add_instruction(migraphx::op::transpose{{1, 0}}, l); p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
EXPECT(not p.get_shape().standard()); EXPECT(not p.get_shape().standard());
EXPECT(p.get_shape().transposed()); EXPECT(p.get_shape().transposed());
p.compile(simplify_reshapes_target{}); run_pass(p);
EXPECT(not p.get_shape().standard()); EXPECT(not p.get_shape().standard());
EXPECT(p.get_shape().transposed()); EXPECT(p.get_shape().transposed());
EXPECT(std::distance(p.begin(), p.end()) == 2); EXPECT(std::distance(p.begin(), p.end()) == 2);
...@@ -130,7 +126,7 @@ TEST_CASE(reshape_transpose) ...@@ -130,7 +126,7 @@ TEST_CASE(reshape_transpose)
p.add_instruction(pass_op{}, r2); p.add_instruction(pass_op{}, r2);
EXPECT(p.get_shape() == s); EXPECT(p.get_shape() == s);
auto n = std::distance(p.begin(), p.end()); auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{}); run_pass(p);
EXPECT(p.get_shape() == s); EXPECT(p.get_shape() == s);
EXPECT(std::distance(p.begin(), p.end()) == n); EXPECT(std::distance(p.begin(), p.end()) == n);
} }
...@@ -145,7 +141,7 @@ TEST_CASE(transpose_contiguous) ...@@ -145,7 +141,7 @@ TEST_CASE(transpose_contiguous)
p.add_instruction(pass_op{}, c1); p.add_instruction(pass_op{}, c1);
auto out_shape = p.get_shape(); auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end()); auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{}); run_pass(p);
EXPECT(p.get_shape() == out_shape); EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n); EXPECT(std::distance(p.begin(), p.end()) == n);
} }
...@@ -161,7 +157,7 @@ TEST_CASE(transpose_double_contiguous) ...@@ -161,7 +157,7 @@ TEST_CASE(transpose_double_contiguous)
p.add_instruction(pass_op{}, c2); p.add_instruction(pass_op{}, c2);
auto out_shape = p.get_shape(); auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end()); auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{}); run_pass(p);
EXPECT(p.get_shape() == out_shape); EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 1); EXPECT(std::distance(p.begin(), p.end()) == n - 1);
EXPECT(p.has_instruction(t)); EXPECT(p.has_instruction(t));
...@@ -177,7 +173,7 @@ TEST_CASE(transpose_partial1) ...@@ -177,7 +173,7 @@ TEST_CASE(transpose_partial1)
p.add_instruction(pass_op{}, t2); p.add_instruction(pass_op{}, t2);
auto out_shape = p.get_shape(); auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end()); auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{}); run_pass(p);
EXPECT(p.get_shape() == out_shape); EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 1); EXPECT(std::distance(p.begin(), p.end()) == n - 1);
} }
...@@ -193,7 +189,7 @@ TEST_CASE(transpose_partial2) ...@@ -193,7 +189,7 @@ TEST_CASE(transpose_partial2)
p.add_instruction(pass_op{}, t3); p.add_instruction(pass_op{}, t3);
auto out_shape = p.get_shape(); auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end()); auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{}); run_pass(p);
EXPECT(p.get_shape() == out_shape); EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 2); EXPECT(std::distance(p.begin(), p.end()) == n - 2);
} }
...@@ -210,7 +206,7 @@ TEST_CASE(transpose_partial3) ...@@ -210,7 +206,7 @@ TEST_CASE(transpose_partial3)
p.add_instruction(pass_op{}, t4); p.add_instruction(pass_op{}, t4);
auto out_shape = p.get_shape(); auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end()); auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{}); run_pass(p);
EXPECT(p.get_shape() == out_shape); EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 3); EXPECT(std::distance(p.begin(), p.end()) == n - 3);
} }
...@@ -224,7 +220,7 @@ TEST_CASE(nop_transpose1) ...@@ -224,7 +220,7 @@ TEST_CASE(nop_transpose1)
p.add_instruction(pass_op{}, t); p.add_instruction(pass_op{}, t);
auto out_shape = p.get_shape(); auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end()); auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{}); run_pass(p);
EXPECT(p.get_shape() == out_shape); EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 1); EXPECT(std::distance(p.begin(), p.end()) == n - 1);
} }
...@@ -241,7 +237,7 @@ TEST_CASE(nop_transpose2) ...@@ -241,7 +237,7 @@ TEST_CASE(nop_transpose2)
p.add_instruction(pass_op{}, t4); p.add_instruction(pass_op{}, t4);
auto out_shape = p.get_shape(); auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end()); auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{}); run_pass(p);
EXPECT(p.get_shape() == out_shape); EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 4); EXPECT(std::distance(p.begin(), p.end()) == n - 4);
} }
...@@ -258,7 +254,7 @@ TEST_CASE(nop_transpose3) ...@@ -258,7 +254,7 @@ TEST_CASE(nop_transpose3)
p.add_instruction(pass_op{}, t2); p.add_instruction(pass_op{}, t2);
auto out_shape = p.get_shape(); auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end()); auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{}); run_pass(p);
EXPECT(p.get_shape() == out_shape); EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 1); EXPECT(std::distance(p.begin(), p.end()) == n - 1);
} }
...@@ -276,7 +272,7 @@ TEST_CASE(concat_transpose1) ...@@ -276,7 +272,7 @@ TEST_CASE(concat_transpose1)
p.add_instruction(pass_op{}, t); p.add_instruction(pass_op{}, t);
auto out_shape = p.get_shape(); auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end()); auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{}); run_pass(p);
EXPECT(p.get_shape().lens() == out_shape.lens()); EXPECT(p.get_shape().lens() == out_shape.lens());
EXPECT(std::distance(p.begin(), p.end()) == n - 3); EXPECT(std::distance(p.begin(), p.end()) == n - 3);
auto new_concat = auto new_concat =
...@@ -298,7 +294,7 @@ TEST_CASE(concat_transpose2) ...@@ -298,7 +294,7 @@ TEST_CASE(concat_transpose2)
p.add_instruction(pass_op{}, t); p.add_instruction(pass_op{}, t);
auto out_shape = p.get_shape(); auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end()); auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{}); run_pass(p);
EXPECT(p.get_shape().lens() == out_shape.lens()); EXPECT(p.get_shape().lens() == out_shape.lens());
EXPECT(std::distance(p.begin(), p.end()) == n - 2); EXPECT(std::distance(p.begin(), p.end()) == n - 2);
auto new_concat = auto new_concat =
...@@ -320,7 +316,7 @@ TEST_CASE(concat_transpose3) ...@@ -320,7 +316,7 @@ TEST_CASE(concat_transpose3)
p.add_instruction(pass_op{}, t); p.add_instruction(pass_op{}, t);
auto out_shape = p.get_shape(); auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end()); auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{}); run_pass(p);
EXPECT(p.get_shape().lens() == out_shape.lens()); EXPECT(p.get_shape().lens() == out_shape.lens());
EXPECT(std::distance(p.begin(), p.end()) == n - 2); EXPECT(std::distance(p.begin(), p.end()) == n - 2);
auto new_concat = auto new_concat =
...@@ -341,7 +337,7 @@ TEST_CASE(nested_concat) ...@@ -341,7 +337,7 @@ TEST_CASE(nested_concat)
p.add_instruction(pass_op{}, concat3); p.add_instruction(pass_op{}, concat3);
auto out_shape = p.get_shape(); auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end()); auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{}); run_pass(p);
EXPECT(p.get_shape().lens() == out_shape.lens()); EXPECT(p.get_shape().lens() == out_shape.lens());
EXPECT(std::distance(p.begin(), p.end()) == n - 2); EXPECT(std::distance(p.begin(), p.end()) == n - 2);
EXPECT(std::count_if(p.begin(), p.end(), [](auto ins) { return ins.name() == "concat"; }) == 1); EXPECT(std::count_if(p.begin(), p.end(), [](auto ins) { return ins.name() == "concat"; }) == 1);
...@@ -361,7 +357,7 @@ TEST_CASE(nested_concat_partial) ...@@ -361,7 +357,7 @@ TEST_CASE(nested_concat_partial)
p.add_instruction(pass_op{}, concat3); p.add_instruction(pass_op{}, concat3);
auto out_shape = p.get_shape(); auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end()); auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{}); run_pass(p);
EXPECT(p.get_shape().lens() == out_shape.lens()); EXPECT(p.get_shape().lens() == out_shape.lens());
EXPECT(std::distance(p.begin(), p.end()) == n - 2); EXPECT(std::distance(p.begin(), p.end()) == n - 2);
EXPECT(std::count_if(p.begin(), p.end(), [](auto ins) { return ins.name() == "concat"; }) == 1); EXPECT(std::count_if(p.begin(), p.end(), [](auto ins) { return ins.name() == "concat"; }) == 1);
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include <migraphx/context.hpp> #include <migraphx/context.hpp>
#include <migraphx/pass.hpp> #include <migraphx/pass.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/compile_options.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/rank.hpp> #include <migraphx/rank.hpp>
...@@ -28,9 +29,10 @@ struct target ...@@ -28,9 +29,10 @@ struct target
* @brief The transformation pass to be run during compilation. * @brief The transformation pass to be run during compilation.
* *
* @param ctx This is the target-dependent context that is created by `get_context` * @param ctx This is the target-dependent context that is created by `get_context`
* @param options Compiling options passed in by the user
* @return The passes to be ran * @return The passes to be ran
*/ */
std::vector<pass> get_passes(context& ctx) const; std::vector<pass> get_passes(context& ctx, const compile_options& options) const;
/** /**
* @brief Construct a context for the target. * @brief Construct a context for the target.
* @return The context to be used during compilation and execution. * @return The context to be used during compilation and execution.
...@@ -119,7 +121,7 @@ argument copy_from_target(T& x, const argument& arg) ...@@ -119,7 +121,7 @@ argument copy_from_target(T& x, const argument& arg)
<% <%
interface('target', interface('target',
virtual('name', returns='std::string', const=True), virtual('name', returns='std::string', const=True),
virtual('get_passes', ctx='context&', returns='std::vector<pass>', const=True), virtual('get_passes', ctx='context&', options='const compile_options&', returns='std::vector<pass>', const=True),
virtual('get_context', returns='context', const=True), virtual('get_context', returns='context', const=True),
virtual('copy_to', virtual('copy_to',
returns = 'argument', returns = 'argument',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment