Commit 81b0ff5d authored by Paul Fultz II's avatar Paul Fultz II Committed by mvermeulen
Browse files

Add option to do offload copying automatically (#403)

* Add compiler options

* Add copy operators

* Formatting

* Use run_passes in tests

* Formatting

* Use run_pass in schedule test

* Formatting

* Add compile_options to get_passes in target

* Formatting

* Offload copy option

* Formatting

* Copy using pinned memory

* Formatting

* Improve performance of gpu copying

* Formatting

* Dont copy

* Formatting

* Always make an extra copy

* Formatting

* Remove unused write op

* Add missing include

* Remove copy_to_gpu function in python api

* Make offload copy disabled by default on C++

* Formatting

* Fix tidy issues

* Formatting

* Fix namespace

* Fix python tests

* Turn clang format off since its broken

* Fix compile error on gcc 5

* Remove commented code
parent e814cffb
......@@ -3,6 +3,7 @@
#include <migraphx/iterator_for.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/compile_options.hpp>
#include <sstream>
#include "test.hpp"
#include <basic_ops.hpp>
......@@ -15,7 +16,11 @@ struct id_target
};
migraphx::context ctx = context{};
std::string name() const { return "id"; }
std::vector<migraphx::pass> get_passes(migraphx::context&) const { return {}; }
std::vector<migraphx::pass> get_passes(migraphx::context&,
const migraphx::compile_options&) const
{
return {};
}
migraphx::context get_context() const { return ctx; }
};
......@@ -72,7 +77,11 @@ struct reverse_pass
struct reverse_target
{
std::string name() const { return "reverse"; }
std::vector<migraphx::pass> get_passes(migraphx::context&) const { return {reverse_pass{}}; }
std::vector<migraphx::pass> get_passes(migraphx::context&,
const migraphx::compile_options&) const
{
return {reverse_pass{}};
}
migraphx::context get_context() const { return {}; }
};
......@@ -99,14 +108,19 @@ struct invert_pass
struct invert_target
{
std::string name() const { return "invert"; }
std::vector<migraphx::pass> get_passes(migraphx::context&) const { return {invert_pass{}}; }
std::vector<migraphx::pass> get_passes(migraphx::context&,
const migraphx::compile_options&) const
{
return {invert_pass{}};
}
migraphx::context get_context() const { return {}; }
};
struct double_invert_target
{
std::string name() const { return "double_invert"; }
std::vector<migraphx::pass> get_passes(migraphx::context&) const
std::vector<migraphx::pass> get_passes(migraphx::context&,
const migraphx::compile_options&) const
{
return {invert_pass{}, invert_pass{}};
}
......
......@@ -15,20 +15,16 @@
#include <basic_ops.hpp>
#include <test.hpp>
struct lowering_target
void run_lowering(migraphx::program& p)
{
std::string name() const { return "gpu::lowering"; }
std::vector<migraphx::pass> get_passes(migraphx::context& gctx) const
{
auto& ctx = migraphx::any_cast<migraphx::gpu::context>(gctx);
return {migraphx::auto_contiguous{},
migraphx::gpu::lowering{ctx},
migraphx::dead_code_elimination{},
migraphx::eliminate_contiguous{},
migraphx::dead_code_elimination{}};
}
migraphx::gpu::context get_context() const { return migraphx::gpu::context{}; }
};
auto ctx = migraphx::gpu::context{};
migraphx::run_passes(p,
{migraphx::auto_contiguous{},
migraphx::gpu::lowering{&ctx, false},
migraphx::dead_code_elimination{},
migraphx::eliminate_contiguous{},
migraphx::dead_code_elimination{}});
}
TEST_CASE(tanh_shape)
{
......@@ -48,8 +44,8 @@ TEST_CASE(tanh_shape)
auto p2 = create_program();
EXPECT(p1 == p2);
p1.compile(lowering_target{});
p2.compile(lowering_target());
run_lowering(p1);
run_lowering(p2);
EXPECT(p1 == p2);
......
......@@ -90,7 +90,9 @@ void compile_check(migraphx::program& p, const migraphx::target& t, bool show_tr
auto name = t.name();
auto s = p.get_shape();
std::stringstream ss;
p.compile(t, migraphx::tracer{ss});
migraphx::compile_options options;
options.trace = migraphx::tracer{ss};
p.compile(t, options);
if(p.get_shape() != s)
{
std::cout << ss.str() << std::endl;
......
#include <migraphx/memory_coloring.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/instruction.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
struct memory_coloring_target
void run_pass(migraphx::program& p)
{
std::string name() const { return "memory_coloring"; }
std::vector<migraphx::pass> get_passes(migraphx::context&) const
{
return {migraphx::memory_coloring{"allocate", true}};
}
migraphx::context get_context() const { return {}; }
};
migraphx::run_passes(p, {migraphx::memory_coloring{"allocate", true}});
}
struct allocate
{
......@@ -56,7 +52,7 @@ TEST_CASE(test1)
auto p1 = p.add_instruction(pass_op{}, a1);
auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}});
p.add_instruction(pass_op{}, a2, p1);
p.compile(memory_coloring_target{});
run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 192);
CHECK(no_allocate(p));
}
......@@ -70,7 +66,7 @@ TEST_CASE(test2)
auto p1 = p.add_instruction(pass_op{}, a1, input);
auto p2 = add_alloc(p, {migraphx::shape::float_type, {40}});
p.add_instruction(pass_op{}, p2, p1);
p.compile(memory_coloring_target{});
run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 672);
CHECK(no_allocate(p));
}
......@@ -83,7 +79,7 @@ TEST_CASE(test3)
auto p1 = p.add_instruction(pass_op{}, p2, a1);
auto p3 = add_alloc(p, {migraphx::shape::float_type, {40}});
p.add_instruction(pass_op{}, p3, p1);
p.compile(memory_coloring_target{});
run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 672);
CHECK(no_allocate(p));
}
......@@ -96,7 +92,7 @@ TEST_CASE(test4)
auto p1 = p.add_instruction(pass_op{}, p2, a1);
auto p3 = add_alloc(p, {migraphx::shape::float_type, {40}});
p.add_instruction(pass_op{}, p3, p1);
p.compile(memory_coloring_target{});
run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 672);
CHECK(no_allocate(p));
}
......@@ -108,7 +104,7 @@ TEST_CASE(test5)
auto p1 = p.add_instruction(pass_op{}, a1);
auto p2 = add_alloc(p, {migraphx::shape::float_type, {8}});
p.add_instruction(pass_op{}, p2, p1);
p.compile(memory_coloring_target{});
run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 192);
CHECK(no_allocate(p));
}
......@@ -121,7 +117,7 @@ TEST_CASE(test6)
auto p2 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto p3 = add_alloc(p, {migraphx::shape::float_type, {40}});
p.add_instruction(pass_op{}, p3, p2, p1);
p.compile(memory_coloring_target{});
run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 352);
CHECK(no_allocate(p));
}
......@@ -134,7 +130,7 @@ TEST_CASE(test7)
auto p2 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto p3 = add_alloc(p, {migraphx::shape::float_type, {8}});
p.add_instruction(pass_op{}, p3, p2, p1);
p.compile(memory_coloring_target{});
run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 224);
CHECK(no_allocate(p));
}
......@@ -147,7 +143,7 @@ TEST_CASE(test8)
auto p2 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto p3 = add_alloc(p, {migraphx::shape::float_type, {192}});
p.add_instruction(pass_op{}, p3, p2, p1);
p.compile(memory_coloring_target{});
run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 960);
CHECK(no_allocate(p));
}
......@@ -160,7 +156,7 @@ TEST_CASE(test9)
auto p2 = add_alloc(p, {migraphx::shape::float_type, {8}});
auto p3 = add_alloc(p, {migraphx::shape::float_type, {8}});
p.add_instruction(pass_op{}, p3, p2, p1);
p.compile(memory_coloring_target{});
run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 96);
CHECK(no_allocate(p));
}
......@@ -170,7 +166,7 @@ TEST_CASE(test10)
migraphx::program p;
auto a1 = add_alloc(p, {migraphx::shape::float_type, {8}});
p.add_instruction(pass_op{}, a1);
p.compile(memory_coloring_target{});
run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 32);
CHECK(no_allocate(p));
}
......@@ -184,7 +180,7 @@ TEST_CASE(test11)
auto a3 = add_alloc(p, {migraphx::shape::float_type, {8}});
auto p2 = p.add_instruction(pass_op{}, a2, p1);
p.add_instruction(pass_op{}, a3, p2);
p.compile(memory_coloring_target{});
run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 224);
CHECK(no_allocate(p));
}
......@@ -198,7 +194,7 @@ TEST_CASE(test12)
auto a3 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto p2 = p.add_instruction(pass_op{}, a2, p1);
p.add_instruction(pass_op{}, a3, p2);
p.compile(memory_coloring_target{});
run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 352);
CHECK(no_allocate(p));
}
......@@ -212,7 +208,7 @@ TEST_CASE(test13)
auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto p2 = p.add_instruction(pass_op{}, a2, p1);
p.add_instruction(pass_op{}, a3, p2);
p.compile(memory_coloring_target{});
run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 224);
CHECK(no_allocate(p));
}
......@@ -226,7 +222,7 @@ TEST_CASE(test14)
auto p1 = p.add_instruction(pass_op{}, a1);
auto p2 = p.add_instruction(pass_op{}, a2, p1);
p.add_instruction(pass_op{}, a3, p2);
p.compile(memory_coloring_target{});
run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 224);
CHECK(no_allocate(p));
}
......@@ -240,7 +236,7 @@ TEST_CASE(test15)
auto p2 = p.add_instruction(pass_op{}, a2);
auto a3 = add_alloc(p, {migraphx::shape::float_type, {40}});
p.add_instruction(pass_op{}, a3, p1, p2);
p.compile(memory_coloring_target{});
run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 352);
CHECK(no_allocate(p));
}
......@@ -254,7 +250,7 @@ TEST_CASE(test16)
auto p2 = p.add_instruction(pass_op{}, a2);
auto a3 = add_alloc(p, {migraphx::shape::float_type, {40}});
p.add_instruction(pass_op{}, a3, p1, p2);
p.compile(memory_coloring_target{});
run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 160);
CHECK(no_allocate(p));
}
......@@ -268,7 +264,7 @@ TEST_CASE(test17)
auto a2 = p.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {40}}));
auto p2 = p.add_instruction(pass_op{}, a2);
p.add_instruction(pass_op{}, a3, p1, p2);
p.compile(memory_coloring_target{});
run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 160);
CHECK(no_allocate(p));
}
......@@ -282,7 +278,7 @@ TEST_CASE(test18)
auto p3 = p.add_instruction(pass_op{}, p2, p1);
auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}});
p.add_instruction(pass_op{}, a2, p1, p2, p3);
p.compile(memory_coloring_target{});
run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 192);
CHECK(no_allocate(p));
}
......@@ -296,7 +292,7 @@ TEST_CASE(test19)
auto p2 = p.add_instruction(pass_op{}, a2, p1);
auto a3 = add_alloc(p, {migraphx::shape::float_type, {40}});
p.add_instruction(pass_op{}, a3, p2, p1);
p.compile(memory_coloring_target{});
run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 352);
CHECK(no_allocate(p));
}
......@@ -310,7 +306,7 @@ TEST_CASE(test20)
auto p1 = p.add_instruction(pass_op{}, a1, a2, a3);
auto a4 = add_alloc(p, {migraphx::shape::float_type, {32}});
p.add_instruction(pass_op{}, a4, p1);
p.compile(memory_coloring_target{});
run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 384);
CHECK(no_allocate(p));
}
......@@ -324,7 +320,7 @@ TEST_CASE(test21)
auto p1 = p.add_instruction(pass_op{}, a1, a2, a3);
auto a4 = add_alloc(p, {migraphx::shape::float_type, {8}});
p.add_instruction(pass_op{}, a4, p1);
p.compile(memory_coloring_target{});
run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 288);
CHECK(no_allocate(p));
}
......@@ -338,7 +334,7 @@ TEST_CASE(test22)
auto p1 = p.add_instruction(pass_op{}, a1, a2, a3);
auto a4 = add_alloc(p, {migraphx::shape::float_type, {8}});
p.add_instruction(pass_op{}, a4, p1);
p.compile(memory_coloring_target{});
run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 288);
CHECK(no_allocate(p));
}
......@@ -352,7 +348,7 @@ TEST_CASE(test23)
auto p1 = p.add_instruction(pass_op{}, a1, a2, a3);
auto a4 = add_alloc(p, {migraphx::shape::float_type, {8}});
p.add_instruction(pass_op{}, a4, p1);
p.compile(memory_coloring_target{});
run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 288);
CHECK(no_allocate(p));
}
......@@ -366,7 +362,7 @@ TEST_CASE(test24)
auto p1 = p.add_instruction(pass_op{}, a1, a2, a3);
auto a4 = add_alloc(p, {migraphx::shape::float_type, {8}});
p.add_instruction(pass_op{}, a4, p1);
p.compile(memory_coloring_target{});
run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 384);
CHECK(no_allocate(p));
}
......@@ -380,7 +376,7 @@ TEST_CASE(test25)
p.add_instruction(nop{});
auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}});
p.add_instruction(pass_op{}, a2, p1);
p.compile(memory_coloring_target{});
run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 192);
CHECK(no_allocate(p));
}
......@@ -394,7 +390,7 @@ TEST_CASE(test26)
p.add_instruction(nop{}, a1, p1);
auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}});
p.add_instruction(pass_op{}, a2, p1);
p.compile(memory_coloring_target{});
run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 192);
CHECK(no_allocate(p));
}
......@@ -406,7 +402,7 @@ TEST_CASE(test27)
auto p1 = p.add_instruction(pass_op{}, a1);
auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}});
p.add_instruction(nop{}, a2, p1);
p.compile(memory_coloring_target{});
run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 192);
CHECK(no_allocate(p));
}
......@@ -420,7 +416,7 @@ TEST_CASE(test28)
auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto p2 = p.add_instruction(pass_op{}, a2, p1);
p.add_instruction(pass_op{}, p2, output);
p.compile(memory_coloring_target{});
run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 192);
CHECK(no_allocate(p));
}
......@@ -435,7 +431,7 @@ TEST_CASE(test29)
auto p2 = p.add_instruction(pass_op{}, a2, p1);
p.move_instruction(output, p2);
p.add_instruction(pass_op{}, p2, output);
p.compile(memory_coloring_target{});
run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 192);
CHECK(no_allocate(p));
}
......@@ -450,7 +446,7 @@ TEST_CASE(test30)
auto p2 = p.add_instruction(pass_op{}, a2, p1);
p.move_instruction(output, p2);
p.add_instruction(pass_op{}, p2, output);
p.compile(memory_coloring_target{});
run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 192);
CHECK(no_allocate(p));
}
......@@ -464,7 +460,7 @@ TEST_CASE(test31)
auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}});
p.move_instruction(output, a2);
p.add_instruction(pass_op{}, a2, p1);
p.compile(memory_coloring_target{});
run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 192);
CHECK(no_allocate(p));
}
......@@ -478,7 +474,7 @@ TEST_CASE(test32)
auto p1 = p.add_instruction(pass_op{}, a2, a1, a3);
auto a5 = add_alloc(p, {migraphx::shape::float_type, {40}});
p.add_instruction(pass_op{}, a5, p1);
p.compile(memory_coloring_target{});
run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 352);
CHECK(no_allocate(p));
}
......@@ -492,7 +488,7 @@ TEST_CASE(test33)
auto p1 = p.add_instruction(pass_op{}, a2, a1, a3);
auto a5 = add_alloc(p, {migraphx::shape::float_type, {40}});
p.add_instruction(pass_op{}, a5, p1);
p.compile(memory_coloring_target{});
run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 192);
CHECK(no_allocate(p));
}
......@@ -506,7 +502,7 @@ TEST_CASE(test34)
auto p1 = p.add_instruction(pass_op{}, a2, a1, a3);
auto a5 = add_alloc(p, {migraphx::shape::float_type, {8}});
p.add_instruction(pass_op{}, a5, p1);
p.compile(memory_coloring_target{});
run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 480);
CHECK(no_allocate(p));
}
......@@ -520,7 +516,7 @@ TEST_CASE(test35)
auto p1 = p.add_instruction(pass_op{}, a2, a1, a3);
auto a5 = add_alloc(p, {migraphx::shape::float_type, {8}});
p.add_instruction(pass_op{}, a5, p1);
p.compile(memory_coloring_target{});
run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 224);
CHECK(no_allocate(p));
}
......@@ -537,7 +533,7 @@ TEST_CASE(test36)
auto a4 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto p3 = p.add_instruction(pass_op{}, a4, p2);
p.add_instruction(pass_op{}, output, p3);
p.compile(memory_coloring_target{});
run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 320);
CHECK(no_allocate(p));
}
......@@ -554,7 +550,7 @@ TEST_CASE(test37)
auto a4 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto p3 = p.add_instruction(pass_op{}, a4, p2);
p.add_instruction(pass_op{}, output, p3);
p.compile(memory_coloring_target{});
run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 320);
CHECK(no_allocate(p));
}
......@@ -599,7 +595,7 @@ TEST_CASE(test38)
auto p78 = add_alloc(p, {migraphx::shape::float_type, {1, 64, 56, 56}});
auto p83 = p.add_instruction(pass_op{}, p78, p77);
p.add_instruction(pass_op{}, output, p83, p63);
p.compile(memory_coloring_target{});
run_pass(p);
CHECK(p.get_parameter_shape("scratch").bytes() == 7225344); // Optimal solution is 6422528
CHECK(no_allocate(p));
}
......@@ -609,7 +605,7 @@ TEST_CASE(literal_test)
migraphx::program p;
auto lit = generate_literal(migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
p.add_literal(lit);
p.compile(memory_coloring_target{});
run_pass(p);
auto result = p.eval({});
CHECK(lit == result);
}
......
#include <migraphx/propagate_constant.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/scalar.hpp>
#include <migraphx/op/mul.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
struct const_prop_target
void run_pass(migraphx::program& p)
{
std::string name() const { return "const_prop"; }
std::vector<migraphx::pass> get_passes(migraphx::context&) const
{
return {migraphx::propagate_constant{}, migraphx::dead_code_elimination{}};
}
migraphx::context get_context() const { return {}; }
};
migraphx::run_passes(p, {migraphx::propagate_constant{}, migraphx::dead_code_elimination{}});
}
TEST_CASE(const_add)
{
......@@ -23,7 +19,7 @@ TEST_CASE(const_add)
auto two = p1.add_literal(2);
auto sum = p1.add_instruction(migraphx::op::add{}, one, two);
p1.add_instruction(pass_op{}, sum);
p1.compile(const_prop_target{});
run_pass(p1);
migraphx::program p2;
auto total = p2.add_literal(3);
......@@ -38,7 +34,7 @@ TEST_CASE(const_add_parameter)
auto two = p1.add_literal(2);
auto sum = p1.add_instruction(migraphx::op::add{}, one, two);
p1.add_instruction(pass_op{}, sum);
p1.compile(const_prop_target{});
run_pass(p1);
migraphx::program p2;
auto total = p2.add_literal(3);
......@@ -54,7 +50,7 @@ TEST_CASE(const_multiadd)
auto sum1 = p1.add_instruction(migraphx::op::add{}, one, two);
auto sum2 = p1.add_instruction(migraphx::op::add{}, sum1, two);
p1.add_instruction(pass_op{}, sum2);
p1.compile(const_prop_target{});
run_pass(p1);
migraphx::program p2;
auto total = p2.add_literal(5);
......@@ -71,7 +67,7 @@ TEST_CASE(const_add_mul)
auto sum1 = p1.add_instruction(migraphx::op::add{}, one, mul);
auto sum2 = p1.add_instruction(migraphx::op::add{}, sum1, two);
p1.add_instruction(pass_op{}, sum2);
p1.compile(const_prop_target{});
run_pass(p1);
migraphx::program p2;
auto total = p2.add_literal(7);
......@@ -86,7 +82,7 @@ TEST_CASE(const_add_scalar)
auto two = p1.add_instruction(migraphx::op::scalar{{2, 2}}, p1.add_literal(2));
auto sum = p1.add_instruction(migraphx::op::add{}, one, two);
p1.add_instruction(pass_op{}, sum);
p1.compile(const_prop_target{});
run_pass(p1);
migraphx::program p2;
auto total =
......@@ -102,7 +98,7 @@ TEST_CASE(const_scalar)
auto one = p1.add_instruction(migraphx::op::scalar{{2, 2}}, p1.add_literal(1));
p1.add_instruction(pass_op{}, one);
}
p1.compile(const_prop_target{});
run_pass(p1);
migraphx::program p2;
{
......
......@@ -19,6 +19,7 @@ add_dependencies(check migraphx_py)
add_py_test(cpu test_cpu.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
if(MIGRAPHX_ENABLE_GPU)
add_py_test(gpu_offload test_gpu_offload.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_py_test(gpu test_gpu.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_py_test(array test_array.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
endif()
......@@ -52,9 +52,9 @@ def check_shapes(r, m):
def run(p):
params = {}
for key, value in p.get_parameter_shapes().items():
params[key] = migraphx.to_gpu(migraphx.generate_argument(value))
params[key] = migraphx.generate_argument(value)
return migraphx.from_gpu(p.run(params))
return p.run(params)
def test_shape(shape):
......
......@@ -9,7 +9,7 @@ params = {}
for key, value in p.get_parameter_shapes().items():
print("Parameter {} -> {}".format(key, value))
params[key] = migraphx.to_gpu(migraphx.generate_argument(value))
params[key] = migraphx.generate_argument(value)
r = migraphx.from_gpu(p.run(params))
r = p.run(params)
print(r)
import migraphx
p = migraphx.parse_onnx("conv_relu_maxpool_test.onnx")
print(p)
print("Compiling ...")
p.compile(migraphx.get_target("gpu"), offload_copy=False)
print(p)
params = {}
for key, value in p.get_parameter_shapes().items():
print("Parameter {} -> {}".format(key, value))
params[key] = migraphx.to_gpu(migraphx.generate_argument(value))
r = migraphx.from_gpu(p.run(params))
print(r)
#include <migraphx/schedule.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/instruction.hpp>
......@@ -155,15 +156,9 @@ bool check_conflicts(migraphx::program& p, migraphx::instruction_ref x, migraphx
return false;
}
struct schedule_target
struct scheduler
{
schedule_model_test model{};
std::string name() const { return "schedule"; }
std::vector<migraphx::pass> get_passes(migraphx::context&) const
{
return {migraphx::schedule{model}};
}
migraphx::context get_context() const { return {}; }
std::size_t get_stream(migraphx::instruction_ref ins) { return model.ins2stream->at(ins); }
......@@ -176,6 +171,8 @@ struct schedule_target
return result;
}
void run_pass(migraphx::program& p) { migraphx::run_passes(p, {migraphx::schedule{model}}); }
bool has_stream(migraphx::instruction_ref ins) { return model.ins2stream->count(ins) > 0; }
void check_conflicts(migraphx::program& p,
......@@ -253,13 +250,13 @@ chain(migraphx::program& p, std::size_t n, T x, migraphx::instruction_ref input)
}
TEST_CASE(single_entry)
{
schedule_target t{};
scheduler t{};
migraphx::program p;
auto one = p.add_literal(1);
auto onep1 = p.add_instruction(unary_op{}, one);
auto onep2 = p.add_instruction(unary_op{}, one);
auto binary = p.add_instruction(nary_op{}, onep1, onep2);
p.compile(t);
t.run_pass(p);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(onep1) != t.get_stream(onep2));
EXPECT(t.get_stream(binary) == 0);
......@@ -270,13 +267,13 @@ TEST_CASE(single_entry)
TEST_CASE(stream_free)
{
schedule_target t{};
scheduler t{};
migraphx::program p;
auto one = p.add_literal(1);
auto onep1 = p.add_instruction(stream_free_op{}, one);
auto onep2 = p.add_instruction(stream_free_op{}, one);
auto binary = p.add_instruction(nary_op{}, onep1, onep2);
p.compile(t);
t.run_pass(p);
EXPECT(not t.has_stream(one));
EXPECT(not t.has_stream(onep1));
EXPECT(not t.has_stream(onep2));
......@@ -285,7 +282,7 @@ TEST_CASE(stream_free)
TEST_CASE(zero_record)
{
schedule_target t{};
scheduler t{};
migraphx::program p;
auto one = p.add_literal(1);
auto onep1 = p.add_instruction(unary_op{}, one);
......@@ -293,7 +290,7 @@ TEST_CASE(zero_record)
auto onei1 = p.add_instruction(migraphx::op::identity{}, onep1);
auto onei2 = p.add_instruction(migraphx::op::identity{}, onep2);
auto binary = p.add_instruction(nary_op{}, onei1, onei2);
p.compile(t);
t.run_pass(p);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(onep1) != t.get_stream(onep2));
EXPECT(t.has_stream(binary));
......@@ -305,13 +302,13 @@ TEST_CASE(zero_record)
TEST_CASE(zero_merge1)
{
schedule_target t{};
scheduler t{};
migraphx::program p;
auto one = p.add_literal(1);
auto onep1 = p.add_instruction(unary_op{}, one);
auto onep2 = p.add_instruction(unary_op{}, one);
auto binary = p.add_instruction(migraphx::op::identity{}, onep1, onep2);
p.compile(t);
t.run_pass(p);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(onep1) != t.get_stream(onep2));
// No stream assignment
......@@ -323,7 +320,7 @@ TEST_CASE(zero_merge1)
TEST_CASE(zero_merge2)
{
schedule_target t{};
scheduler t{};
migraphx::program p;
auto one = p.add_literal(1);
auto onep1 = p.add_instruction(unary_op{}, one);
......@@ -331,7 +328,7 @@ TEST_CASE(zero_merge2)
auto binary = p.add_instruction(migraphx::op::identity{},
p.add_instruction(migraphx::op::identity{}, onep1),
p.add_instruction(migraphx::op::identity{}, onep2));
p.compile(t);
t.run_pass(p);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(onep1) != t.get_stream(onep2));
// No stream assignment
......@@ -343,14 +340,14 @@ TEST_CASE(zero_merge2)
TEST_CASE(zero_merge3)
{
schedule_target t{};
scheduler t{};
migraphx::program p;
auto one = p.add_literal(1);
auto onep1 = p.add_instruction(unary_op{}, one);
auto onep2 = p.add_instruction(unary_op{}, one);
auto id = p.add_instruction(migraphx::op::identity{}, onep1, onep2);
auto final = p.add_instruction(unary_op{}, id);
p.compile(t);
t.run_pass(p);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(onep1) != t.get_stream(onep2));
// No stream assignment
......@@ -366,7 +363,7 @@ TEST_CASE(zero_merge3)
TEST_CASE(zero_merge4)
{
schedule_target t{};
scheduler t{};
migraphx::program p;
auto one = p.add_literal(1);
auto onep1 = p.add_instruction(unary_op{}, one);
......@@ -375,7 +372,7 @@ TEST_CASE(zero_merge4)
p.add_instruction(migraphx::op::identity{}, onep1),
p.add_instruction(migraphx::op::identity{}, onep2));
auto final = p.add_instruction(unary_op{}, id);
p.compile(t);
t.run_pass(p);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(onep1) != t.get_stream(onep2));
// No stream assignment
......@@ -391,14 +388,14 @@ TEST_CASE(zero_merge4)
TEST_CASE(double_entry)
{
schedule_target t{};
scheduler t{};
migraphx::program p;
auto one = p.add_instruction(stream_free_op{}, p.add_literal(1));
auto two = p.add_instruction(stream_free_op{}, p.add_literal(2));
auto onep = p.add_instruction(unary_op{}, one);
auto twop = p.add_instruction(unary_op{}, two);
auto binary = p.add_instruction(nary_op{}, onep, twop);
p.compile(t);
t.run_pass(p);
EXPECT(not t.has_stream(one));
EXPECT(not t.has_stream(two));
EXPECT(t.get_stream(onep) != t.get_stream(twop));
......@@ -410,13 +407,13 @@ TEST_CASE(double_entry)
TEST_CASE(two_branches)
{
schedule_target t{};
scheduler t{};
migraphx::program p;
auto one = p.add_literal(1);
auto c1 = chain(p, 2, unary_op{}, one);
auto i1 = p.add_instruction(unary_op{}, one);
auto binary = p.add_instruction(nary_op{}, i1, c1.back());
p.compile(t);
t.run_pass(p);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(i1) == 1);
for(auto ins : c1)
......@@ -429,7 +426,7 @@ TEST_CASE(two_branches)
TEST_CASE(four_branches)
{
schedule_target t{};
scheduler t{};
migraphx::program p;
auto one = p.add_literal(1);
auto c1 = chain(p, 4, unary_op{}, one);
......@@ -437,7 +434,7 @@ TEST_CASE(four_branches)
auto c3 = chain(p, 2, unary_op{}, one);
auto i1 = p.add_instruction(unary_op{}, one);
auto binary = p.add_instruction(nary_op{}, i1, c1.back(), c2.back(), c3.back());
p.compile(t);
t.run_pass(p);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(i1) == 3);
for(auto ins : c1)
......@@ -457,7 +454,7 @@ TEST_CASE(four_branches)
TEST_CASE(five_branches)
{
schedule_target t{};
scheduler t{};
migraphx::program p;
auto one = p.add_literal(1);
auto c1 = chain(p, 5, unary_op{}, one);
......@@ -466,7 +463,7 @@ TEST_CASE(five_branches)
auto c4 = chain(p, 2, unary_op{}, one);
auto i1 = p.add_instruction(unary_op{}, one);
auto binary = p.add_instruction(nary_op{}, i1, c1.back(), c2.back(), c3.back(), c4.back());
p.compile(t);
t.run_pass(p);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(i1) == 3);
for(auto ins : c1)
......@@ -489,7 +486,7 @@ TEST_CASE(five_branches)
TEST_CASE(four_branches_eq)
{
schedule_target t{};
scheduler t{};
migraphx::program p;
auto one = p.add_literal(1);
auto onep1 = p.add_instruction(unary_op{}, one);
......@@ -497,7 +494,7 @@ TEST_CASE(four_branches_eq)
auto onep3 = p.add_instruction(unary_op{}, one);
auto onep4 = p.add_instruction(unary_op{}, one);
auto binary = p.add_instruction(nary_op{}, onep1, onep2, onep3, onep4);
p.compile(t);
t.run_pass(p);
EXPECT(not t.has_stream(one));
EXPECT(
sorted<std::size_t>(
......@@ -515,7 +512,7 @@ TEST_CASE(four_branches_eq)
TEST_CASE(seq_merge)
{
schedule_target t{};
scheduler t{};
migraphx::program p;
auto one = p.add_literal(1);
auto c1 = chain(p, 2, unary_op{}, one);
......@@ -526,7 +523,7 @@ TEST_CASE(seq_merge)
auto i2 = p.add_instruction(unary_op{}, binary1);
auto binary2 = p.add_instruction(nary_op{}, i2, c2.back());
p.compile(t);
t.run_pass(p);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(i1) != t.get_stream(c1.back()));
......@@ -548,7 +545,7 @@ TEST_CASE(seq_merge)
TEST_CASE(par_merge)
{
schedule_target t{};
scheduler t{};
migraphx::program p;
auto one = p.add_literal(1);
auto start1 = p.add_instruction(unary_op{}, one);
......@@ -563,7 +560,7 @@ TEST_CASE(par_merge)
auto binary3 = p.add_instruction(nary_op{}, binary1, binary2);
p.compile(t);
t.run_pass(p);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(binary3) == 0);
......@@ -589,7 +586,7 @@ TEST_CASE(par_merge)
TEST_CASE(inner_par_merge)
{
schedule_target t{};
scheduler t{};
migraphx::program p;
auto one = p.add_literal(1);
auto start1 = p.add_instruction(unary_op{}, one);
......@@ -607,7 +604,7 @@ TEST_CASE(inner_par_merge)
auto output = p.add_instruction(nary_op{}, binary1, binary2, outer1, outer2);
p.compile(t);
t.run_pass(p);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(output) == 0);
EXPECT(get_wait_for(output) == get_wait_for(t.get_stream(output),
......@@ -642,7 +639,7 @@ TEST_CASE(inner_par_merge)
TEST_CASE(par_merge_multi_entry)
{
schedule_target t{};
scheduler t{};
migraphx::program p;
auto one = p.add_literal(1);
auto start1 = p.add_instruction(unary_op{}, one);
......@@ -658,7 +655,7 @@ TEST_CASE(par_merge_multi_entry)
auto binary3 = p.add_instruction(nary_op{}, binary1, binary2);
p.compile(t);
t.run_pass(p);
EXPECT(not t.has_stream(one));
EXPECT(not t.has_stream(two));
EXPECT(t.get_stream(binary3) == 0);
......@@ -685,7 +682,7 @@ TEST_CASE(par_merge_multi_entry)
TEST_CASE(inner_split1)
{
schedule_target t{};
scheduler t{};
migraphx::program p;
auto one = p.add_literal(1);
auto c1 = chain(p, 2, unary_op{}, one);
......@@ -693,7 +690,7 @@ TEST_CASE(inner_split1)
auto s1 = p.add_instruction(unary_op{}, c1);
auto s2 = p.add_instruction(unary_op{}, c1);
auto output = p.add_instruction(nary_op{}, i1, s1, s2);
p.compile(t);
t.run_pass(p);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(i1) != t.get_stream(s1));
EXPECT(t.get_stream(i1) != t.get_stream(s2));
......@@ -712,7 +709,7 @@ TEST_CASE(inner_split1)
TEST_CASE(inner_split2)
{
schedule_target t{};
scheduler t{};
migraphx::program p;
auto one = p.add_literal(1);
auto c1 = chain(p, 2, unary_op{}, one);
......@@ -720,7 +717,7 @@ TEST_CASE(inner_split2)
auto s1 = chain(p, 3, unary_op{}, c1.back());
auto s2 = chain(p, 4, unary_op{}, c1.back());
auto output = p.add_instruction(nary_op{}, i1, s1.back(), s2.back());
p.compile(t);
t.run_pass(p);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(i1) != t.get_stream(s1.back()));
EXPECT(t.get_stream(i1) != t.get_stream(s2.back()));
......@@ -738,7 +735,7 @@ TEST_CASE(inner_split2)
TEST_CASE(inception_resnet)
{
schedule_target t{};
scheduler t{};
migraphx::program p;
auto one = p.add_literal(1);
auto input = p.add_instruction(unary_op{}, one);
......@@ -746,7 +743,7 @@ TEST_CASE(inception_resnet)
auto i1 = p.add_instruction(unary_op{}, input);
auto binary = p.add_instruction(nary_op{}, i1, c1.back());
auto output = p.add_instruction(nary_op{}, binary, input);
p.compile(t);
t.run_pass(p);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(i1) != 0);
for(auto ins : c1)
......@@ -761,7 +758,7 @@ TEST_CASE(inception_resnet)
TEST_CASE(inception1)
{
schedule_target t{};
scheduler t{};
migraphx::program p;
auto i1 = p.add_literal(0);
......@@ -854,7 +851,7 @@ TEST_CASE(inception1)
auto i101 = p.add_literal(2);
auto output = p.add_instruction(nary_op{"output"}, i96, i101, i100, i98, i99);
p.compile(t);
t.run_pass(p);
EXPECT(t.get_streams({i7, i11, i17, i23, i25, i31, i37, i39}) ==
t.get_streams({i7, i7, i7, i7, i7, i7, i7, i7}));
......
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/ranges.hpp>
......@@ -7,15 +8,10 @@
#include <basic_ops.hpp>
#include <test.hpp>
struct simplify_algebra_target
void run_pass(migraphx::program& p)
{
std::string name() const { return "simplify_algebra"; }
std::vector<migraphx::pass> get_passes(migraphx::context&) const
{
return {migraphx::simplify_algebra{}, migraphx::dead_code_elimination{}};
}
migraphx::context get_context() const { return {}; }
};
migraphx::run_passes(p, {migraphx::simplify_algebra{}, migraphx::dead_code_elimination{}});
}
TEST_CASE(simplify_add1)
{
......@@ -30,7 +26,7 @@ TEST_CASE(simplify_add1)
auto sum3 = p1.add_instruction(migraphx::op::add{}, sum1, sum2);
p1.add_instruction(pass_op{}, sum3);
}
p1.compile(simplify_algebra_target{});
run_pass(p1);
migraphx::program p2;
{
......@@ -59,7 +55,7 @@ TEST_CASE(simplify_add2)
auto sum3 = p1.add_instruction(migraphx::op::add{}, sum1, sum2);
p1.add_instruction(pass_op{}, sum3);
}
p1.compile(simplify_algebra_target{});
run_pass(p1);
migraphx::program p2;
{
......@@ -87,7 +83,7 @@ TEST_CASE(simplify_add3)
auto sum3 = p1.add_instruction(migraphx::op::add{}, sum1, sum2);
p1.add_instruction(pass_op{}, sum3);
}
p1.compile(simplify_algebra_target{});
run_pass(p1);
migraphx::program p2;
{
......@@ -120,7 +116,7 @@ TEST_CASE(simplify_add_broadcast1)
auto sum3 = p1.add_instruction(migraphx::op::add{}, sum1, sum2);
p1.add_instruction(pass_op{}, sum3);
}
p1.compile(simplify_algebra_target{});
run_pass(p1);
migraphx::program p2;
{
......@@ -156,7 +152,7 @@ TEST_CASE(simplify_add_broadcast2)
return p;
};
migraphx::program p1 = create_program();
p1.compile(simplify_algebra_target{});
run_pass(p1);
migraphx::program p2 = create_program();
EXPECT(p1 == p2);
......@@ -177,7 +173,7 @@ void simplify_add4()
auto sum3 = p1.add_instruction(migraphx::op::add{}, sum2, two);
p1.add_instruction(pass_op{}, sum3);
}
p1.compile(simplify_algebra_target{});
run_pass(p1);
migraphx::program p2;
{
......@@ -205,7 +201,7 @@ TEST_CASE(simplify_mul_conv1)
auto mul = p.add_instruction(migraphx::op::mul{}, conv, b);
p.add_instruction(pass_op{}, mul);
EXPECT(conv->outputs().front()->name() == "mul");
p.compile(simplify_algebra_target{});
run_pass(p);
auto new_conv =
std::find_if(p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; });
EXPECT(new_conv->outputs().front()->name() != "mul");
......@@ -222,7 +218,7 @@ TEST_CASE(simplify_mul_add)
auto mul = p1.add_instruction(migraphx::op::mul{}, sum, two);
p1.add_instruction(pass_op{}, mul);
}
p1.compile(simplify_algebra_target{});
run_pass(p1);
migraphx::program p2;
{
......@@ -249,7 +245,7 @@ TEST_CASE(simplify_inner_broadcast)
auto sum = p1.add_instruction(migraphx::op::add{}, xb, yb);
p1.add_instruction(pass_op{}, sum);
}
p1.compile(simplify_algebra_target{});
run_pass(p1);
migraphx::program p2;
{
......@@ -276,7 +272,7 @@ TEST_CASE(simplify_add_conv1)
auto sum = p.add_instruction(migraphx::op::add{}, conv1, conv2);
p.add_instruction(pass_op{}, sum);
auto s = p.get_shape();
p.compile(simplify_algebra_target{});
run_pass(p);
EXPECT(s == p.get_shape());
EXPECT(std::count_if(
p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 1);
......@@ -296,7 +292,7 @@ TEST_CASE(simplify_add_conv_no_fusion_7x7_diff_strides)
auto sum = p.add_instruction(migraphx::op::add{}, conv1, conv2);
p.add_instruction(pass_op{}, sum);
auto s = p.get_shape();
p.compile(simplify_algebra_target{});
run_pass(p);
EXPECT(s == p.get_shape());
// No fusion
EXPECT(std::count_if(
......@@ -317,7 +313,7 @@ TEST_CASE(simplify_add_conv_1x1_diff_strides1)
auto sum = p.add_instruction(migraphx::op::add{}, conv1, conv2);
p.add_instruction(pass_op{}, sum);
auto s = p.get_shape();
p.compile(simplify_algebra_target{});
run_pass(p);
EXPECT(s == p.get_shape());
EXPECT(std::count_if(
p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 1);
......@@ -337,7 +333,7 @@ TEST_CASE(simplify_add_conv_1x1_diff_strides2)
auto sum = p.add_instruction(migraphx::op::add{}, conv1, conv2);
p.add_instruction(pass_op{}, sum);
auto s = p.get_shape();
p.compile(simplify_algebra_target{});
run_pass(p);
EXPECT(s == p.get_shape());
EXPECT(std::count_if(
p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 1);
......@@ -357,7 +353,7 @@ TEST_CASE(simplify_add_conv_no_fusion_asymetrical_strides1)
auto sum = p.add_instruction(migraphx::op::add{}, conv1, conv2);
p.add_instruction(pass_op{}, sum);
auto s = p.get_shape();
p.compile(simplify_algebra_target{});
run_pass(p);
EXPECT(s == p.get_shape());
// No fusion
EXPECT(std::count_if(
......@@ -378,7 +374,7 @@ TEST_CASE(simplify_add_conv_no_fusion_asymetrical_strides2)
auto sum = p.add_instruction(migraphx::op::add{}, conv1, conv2);
p.add_instruction(pass_op{}, sum);
auto s = p.get_shape();
p.compile(simplify_algebra_target{});
run_pass(p);
EXPECT(s == p.get_shape());
// No fusion
EXPECT(std::count_if(
......
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/generate.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
struct simplify_reshapes_target
void run_pass(migraphx::program& p)
{
std::string name() const { return "simplify_reshapes"; }
std::vector<migraphx::pass> get_passes(migraphx::context&) const
{
return {migraphx::simplify_reshapes{}, migraphx::dead_code_elimination{}};
}
migraphx::context get_context() const { return {}; }
};
migraphx::run_passes(p, {migraphx::simplify_reshapes{}, migraphx::dead_code_elimination{}});
}
TEST_CASE(double_contig)
{
......@@ -26,7 +22,7 @@ TEST_CASE(double_contig)
p.add_instruction(pass_op{}, c2);
EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed());
p.compile(simplify_reshapes_target{});
run_pass(p);
EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed());
EXPECT(std::distance(p.begin(), p.end()) == 4);
......@@ -43,7 +39,7 @@ TEST_CASE(double_transpose)
p.add_instruction(pass_op{}, t2);
EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed());
p.compile(simplify_reshapes_target{});
run_pass(p);
EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed());
EXPECT(std::distance(p.begin(), p.end()) == 2);
......@@ -62,7 +58,7 @@ TEST_CASE(double_transpose_contig)
p.add_instruction(pass_op{}, c2);
EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed());
p.compile(simplify_reshapes_target{});
run_pass(p);
EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed());
EXPECT(std::distance(p.begin(), p.end()) == 2);
......@@ -78,7 +74,7 @@ TEST_CASE(single_transpose)
p.add_instruction(pass_op{}, t1);
EXPECT(not p.get_shape().standard());
EXPECT(p.get_shape().transposed());
p.compile(simplify_reshapes_target{});
run_pass(p);
EXPECT(not p.get_shape().standard());
EXPECT(p.get_shape().transposed());
EXPECT(std::distance(p.begin(), p.end()) == 3);
......@@ -94,7 +90,7 @@ TEST_CASE(double_transpose_sin_pass)
p.add_instruction(migraphx::op::transpose{{1, 0}}, t1);
EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed());
p.compile(simplify_reshapes_target{});
run_pass(p);
EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed());
// TODO: Fix this
......@@ -110,7 +106,7 @@ TEST_CASE(single_transpose_sin_pass)
p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
EXPECT(not p.get_shape().standard());
EXPECT(p.get_shape().transposed());
p.compile(simplify_reshapes_target{});
run_pass(p);
EXPECT(not p.get_shape().standard());
EXPECT(p.get_shape().transposed());
EXPECT(std::distance(p.begin(), p.end()) == 2);
......@@ -130,7 +126,7 @@ TEST_CASE(reshape_transpose)
p.add_instruction(pass_op{}, r2);
EXPECT(p.get_shape() == s);
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
run_pass(p);
EXPECT(p.get_shape() == s);
EXPECT(std::distance(p.begin(), p.end()) == n);
}
......@@ -145,7 +141,7 @@ TEST_CASE(transpose_contiguous)
p.add_instruction(pass_op{}, c1);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
run_pass(p);
EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n);
}
......@@ -161,7 +157,7 @@ TEST_CASE(transpose_double_contiguous)
p.add_instruction(pass_op{}, c2);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
run_pass(p);
EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 1);
EXPECT(p.has_instruction(t));
......@@ -177,7 +173,7 @@ TEST_CASE(transpose_partial1)
p.add_instruction(pass_op{}, t2);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
run_pass(p);
EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 1);
}
......@@ -193,7 +189,7 @@ TEST_CASE(transpose_partial2)
p.add_instruction(pass_op{}, t3);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
run_pass(p);
EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 2);
}
......@@ -210,7 +206,7 @@ TEST_CASE(transpose_partial3)
p.add_instruction(pass_op{}, t4);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
run_pass(p);
EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 3);
}
......@@ -224,7 +220,7 @@ TEST_CASE(nop_transpose1)
p.add_instruction(pass_op{}, t);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
run_pass(p);
EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 1);
}
......@@ -241,7 +237,7 @@ TEST_CASE(nop_transpose2)
p.add_instruction(pass_op{}, t4);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
run_pass(p);
EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 4);
}
......@@ -258,7 +254,7 @@ TEST_CASE(nop_transpose3)
p.add_instruction(pass_op{}, t2);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
run_pass(p);
EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 1);
}
......@@ -276,7 +272,7 @@ TEST_CASE(concat_transpose1)
p.add_instruction(pass_op{}, t);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
run_pass(p);
EXPECT(p.get_shape().lens() == out_shape.lens());
EXPECT(std::distance(p.begin(), p.end()) == n - 3);
auto new_concat =
......@@ -298,7 +294,7 @@ TEST_CASE(concat_transpose2)
p.add_instruction(pass_op{}, t);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
run_pass(p);
EXPECT(p.get_shape().lens() == out_shape.lens());
EXPECT(std::distance(p.begin(), p.end()) == n - 2);
auto new_concat =
......@@ -320,7 +316,7 @@ TEST_CASE(concat_transpose3)
p.add_instruction(pass_op{}, t);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
run_pass(p);
EXPECT(p.get_shape().lens() == out_shape.lens());
EXPECT(std::distance(p.begin(), p.end()) == n - 2);
auto new_concat =
......@@ -341,7 +337,7 @@ TEST_CASE(nested_concat)
p.add_instruction(pass_op{}, concat3);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
run_pass(p);
EXPECT(p.get_shape().lens() == out_shape.lens());
EXPECT(std::distance(p.begin(), p.end()) == n - 2);
EXPECT(std::count_if(p.begin(), p.end(), [](auto ins) { return ins.name() == "concat"; }) == 1);
......@@ -361,7 +357,7 @@ TEST_CASE(nested_concat_partial)
p.add_instruction(pass_op{}, concat3);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
run_pass(p);
EXPECT(p.get_shape().lens() == out_shape.lens());
EXPECT(std::distance(p.begin(), p.end()) == n - 2);
EXPECT(std::count_if(p.begin(), p.end(), [](auto ins) { return ins.name() == "concat"; }) == 1);
......
......@@ -11,6 +11,7 @@
#include <migraphx/context.hpp>
#include <migraphx/pass.hpp>
#include <migraphx/config.hpp>
#include <migraphx/compile_options.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/rank.hpp>
......@@ -28,9 +29,10 @@ struct target
* @brief The transformation pass to be run during compilation.
*
* @param ctx This is the target-dependent context that is created by `get_context`
* @param options Compiling options passed in by the user
* @return The passes to be ran
*/
std::vector<pass> get_passes(context& ctx) const;
std::vector<pass> get_passes(context& ctx, const compile_options& options) const;
/**
* @brief Construct a context for the target.
* @return The context to be used during compilation and execution.
......@@ -119,7 +121,7 @@ argument copy_from_target(T& x, const argument& arg)
<%
interface('target',
virtual('name', returns='std::string', const=True),
virtual('get_passes', ctx='context&', returns='std::vector<pass>', const=True),
virtual('get_passes', ctx='context&', options='const compile_options&', returns='std::vector<pass>', const=True),
virtual('get_context', returns='context', const=True),
virtual('copy_to',
returns = 'argument',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment