Unverified Commit 65c5581f authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Merge branch 'master' into identity

parents 453fa37a f04a3ba6
#include <migraph/auto_contiguous.hpp> #include <migraphx/auto_contiguous.hpp>
#include <migraph/operators.hpp> #include <migraphx/operators.hpp>
#include <migraph/instruction.hpp> #include <migraphx/instruction.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
#include <test.hpp> #include <test.hpp>
struct contiguous_target struct contiguous_target
{ {
std::string name() const { return "contiguous"; } std::string name() const { return "contiguous"; }
std::vector<migraph::pass> get_passes(migraph::context&) const std::vector<migraphx::pass> get_passes(migraphx::context&) const
{ {
return {migraph::auto_contiguous{}}; return {migraphx::auto_contiguous{}};
} }
migraph::context get_context() const { return {}; } migraphx::context get_context() const { return {}; }
}; };
// TODO: Add this test case // TODO: Add this test case
void literal_broadcast() void literal_broadcast()
{ {
migraph::program p; migraphx::program p;
p.add_literal(get_2_broadcasted()); p.add_literal(get_2_broadcasted());
EXPECT(not p.get_shape().standard()); EXPECT(not p.get_shape().standard());
EXPECT(p.get_shape().broadcasted()); EXPECT(p.get_shape().broadcasted());
...@@ -28,7 +28,7 @@ void literal_broadcast() ...@@ -28,7 +28,7 @@ void literal_broadcast()
TEST_CASE(literal_transpose) TEST_CASE(literal_transpose)
{ {
migraph::program p; migraphx::program p;
p.add_literal(get_2x2_transposed()); p.add_literal(get_2x2_transposed());
EXPECT(not p.get_shape().standard()); EXPECT(not p.get_shape().standard());
EXPECT(p.get_shape().transposed()); EXPECT(p.get_shape().transposed());
...@@ -39,11 +39,11 @@ TEST_CASE(literal_transpose) ...@@ -39,11 +39,11 @@ TEST_CASE(literal_transpose)
TEST_CASE(after_literal_transpose) TEST_CASE(after_literal_transpose)
{ {
migraph::program p; migraphx::program p;
auto l = p.add_literal(get_2x2()); auto l = p.add_literal(get_2x2());
EXPECT(p.get_shape().standard()); EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed()); EXPECT(not p.get_shape().transposed());
auto t = p.add_instruction(migraph::op::transpose{{1, 0}}, l); auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
p.add_instruction(pass_op{}, t); p.add_instruction(pass_op{}, t);
EXPECT(not p.get_shape().standard()); EXPECT(not p.get_shape().standard());
EXPECT(p.get_shape().transposed()); EXPECT(p.get_shape().transposed());
...@@ -54,12 +54,12 @@ TEST_CASE(after_literal_transpose) ...@@ -54,12 +54,12 @@ TEST_CASE(after_literal_transpose)
TEST_CASE(after_literal_broadcast) TEST_CASE(after_literal_broadcast)
{ {
migraph::program p; migraphx::program p;
auto l1 = p.add_literal(get_2x2()); auto l1 = p.add_literal(get_2x2());
auto l2 = p.add_literal(get_2()); auto l2 = p.add_literal(get_2());
EXPECT(p.get_shape().standard()); EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().broadcasted()); EXPECT(not p.get_shape().broadcasted());
auto b = p.add_instruction(migraph::op::broadcast{0, l1->get_shape()}, l2); auto b = p.add_instruction(migraphx::op::broadcast{0, l1->get_shape()}, l2);
p.add_instruction(pass_op{}, b); p.add_instruction(pass_op{}, b);
EXPECT(not p.get_shape().standard()); EXPECT(not p.get_shape().standard());
EXPECT(p.get_shape().broadcasted()); EXPECT(p.get_shape().broadcasted());
...@@ -70,11 +70,11 @@ TEST_CASE(after_literal_broadcast) ...@@ -70,11 +70,11 @@ TEST_CASE(after_literal_broadcast)
TEST_CASE(after_param_transpose) TEST_CASE(after_param_transpose)
{ {
migraph::program p; migraphx::program p;
auto l = p.add_parameter("2x2", {migraph::shape::float_type, {2, 2}}); auto l = p.add_parameter("2x2", {migraphx::shape::float_type, {2, 2}});
EXPECT(p.get_shape().standard()); EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed()); EXPECT(not p.get_shape().transposed());
auto t = p.add_instruction(migraph::op::transpose{{1, 0}}, l); auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
p.add_instruction(pass_op{}, t); p.add_instruction(pass_op{}, t);
EXPECT(not p.get_shape().standard()); EXPECT(not p.get_shape().standard());
EXPECT(p.get_shape().transposed()); EXPECT(p.get_shape().transposed());
...@@ -85,12 +85,12 @@ TEST_CASE(after_param_transpose) ...@@ -85,12 +85,12 @@ TEST_CASE(after_param_transpose)
TEST_CASE(after_param_broadcast) TEST_CASE(after_param_broadcast)
{ {
migraph::program p; migraphx::program p;
auto l1 = p.add_parameter("2x2", {migraph::shape::float_type, {2, 2}}); auto l1 = p.add_parameter("2x2", {migraphx::shape::float_type, {2, 2}});
auto l2 = p.add_parameter("2", {migraph::shape::float_type, {2}}); auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {2}});
EXPECT(p.get_shape().standard()); EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().broadcasted()); EXPECT(not p.get_shape().broadcasted());
auto b = p.add_instruction(migraph::op::broadcast{0, l1->get_shape()}, l2); auto b = p.add_instruction(migraphx::op::broadcast{0, l1->get_shape()}, l2);
p.add_instruction(pass_op{}, b); p.add_instruction(pass_op{}, b);
EXPECT(not p.get_shape().standard()); EXPECT(not p.get_shape().standard());
EXPECT(p.get_shape().broadcasted()); EXPECT(p.get_shape().broadcasted());
......
#include <migraph/common_subexpression_elimination.hpp> #include <migraphx/common_subexpression_elimination.hpp>
#include <migraph/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraph/operators.hpp> #include <migraphx/operators.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
#include <test.hpp> #include <test.hpp>
struct cse_target struct cse_target
{ {
std::string name() const { return "dce"; } std::string name() const { return "dce"; }
std::vector<migraph::pass> get_passes(migraph::context&) const std::vector<migraphx::pass> get_passes(migraphx::context&) const
{ {
return {migraph::common_subexpression_elimination{}, migraph::dead_code_elimination{}}; return {migraphx::common_subexpression_elimination{}, migraphx::dead_code_elimination{}};
} }
migraph::context get_context() const { return {}; } migraphx::context get_context() const { return {}; }
}; };
TEST_CASE(cse_test1) TEST_CASE(cse_test1)
{ {
migraph::program p1; migraphx::program p1;
{ {
auto one = p1.add_literal(1); auto one = p1.add_literal(1);
auto two = p1.add_literal(2); auto two = p1.add_literal(2);
auto sum1 = p1.add_instruction(migraph::op::add{}, one, two); auto sum1 = p1.add_instruction(migraphx::op::add{}, one, two);
auto sum2 = p1.add_instruction(migraph::op::add{}, one, two); auto sum2 = p1.add_instruction(migraphx::op::add{}, one, two);
auto sum3 = p1.add_instruction(migraph::op::add{}, sum1, sum2); auto sum3 = p1.add_instruction(migraphx::op::add{}, sum1, sum2);
p1.add_instruction(pass_op{}, sum3); p1.add_instruction(pass_op{}, sum3);
} }
p1.compile(cse_target{}); p1.compile(cse_target{});
migraph::program p2; migraphx::program p2;
{ {
auto one = p2.add_literal(1); auto one = p2.add_literal(1);
auto two = p2.add_literal(2); auto two = p2.add_literal(2);
auto sum1 = p2.add_instruction(migraph::op::add{}, one, two); auto sum1 = p2.add_instruction(migraphx::op::add{}, one, two);
auto sum3 = p2.add_instruction(migraph::op::add{}, sum1, sum1); auto sum3 = p2.add_instruction(migraphx::op::add{}, sum1, sum1);
p2.add_instruction(pass_op{}, sum3); p2.add_instruction(pass_op{}, sum3);
} }
EXPECT(p1 == p2); EXPECT(p1 == p2);
...@@ -40,24 +40,24 @@ TEST_CASE(cse_test1) ...@@ -40,24 +40,24 @@ TEST_CASE(cse_test1)
TEST_CASE(cse_test2) TEST_CASE(cse_test2)
{ {
migraph::program p1; migraphx::program p1;
{ {
auto one = p1.add_literal(1); auto one = p1.add_literal(1);
auto two = p1.add_literal(2); auto two = p1.add_literal(2);
auto sum1 = p1.add_instruction(migraph::op::add{}, one, two); auto sum1 = p1.add_instruction(migraphx::op::add{}, one, two);
auto sum2 = p1.add_instruction(migraph::op::add{}, two, one); auto sum2 = p1.add_instruction(migraphx::op::add{}, two, one);
auto sum3 = p1.add_instruction(migraph::op::add{}, sum1, sum2); auto sum3 = p1.add_instruction(migraphx::op::add{}, sum1, sum2);
p1.add_instruction(pass_op{}, sum3); p1.add_instruction(pass_op{}, sum3);
} }
p1.compile(cse_target{}); p1.compile(cse_target{});
migraph::program p2; migraphx::program p2;
{ {
auto one = p2.add_literal(1); auto one = p2.add_literal(1);
auto two = p2.add_literal(2); auto two = p2.add_literal(2);
auto sum1 = p2.add_instruction(migraph::op::add{}, one, two); auto sum1 = p2.add_instruction(migraphx::op::add{}, one, two);
auto sum2 = p2.add_instruction(migraph::op::add{}, two, one); auto sum2 = p2.add_instruction(migraphx::op::add{}, two, one);
auto sum3 = p2.add_instruction(migraph::op::add{}, sum1, sum2); auto sum3 = p2.add_instruction(migraphx::op::add{}, sum1, sum2);
p2.add_instruction(pass_op{}, sum3); p2.add_instruction(pass_op{}, sum3);
} }
EXPECT(p1 == p2); EXPECT(p1 == p2);
...@@ -65,22 +65,22 @@ TEST_CASE(cse_test2) ...@@ -65,22 +65,22 @@ TEST_CASE(cse_test2)
TEST_CASE(cse_test3) TEST_CASE(cse_test3)
{ {
migraph::program p1; migraphx::program p1;
{ {
auto one = p1.add_literal(1); auto one = p1.add_literal(1);
auto two = p1.add_literal(1); auto two = p1.add_literal(1);
auto sum1 = p1.add_instruction(migraph::op::add{}, one, two); auto sum1 = p1.add_instruction(migraphx::op::add{}, one, two);
auto sum2 = p1.add_instruction(migraph::op::add{}, two, one); auto sum2 = p1.add_instruction(migraphx::op::add{}, two, one);
auto sum3 = p1.add_instruction(migraph::op::add{}, sum1, sum2); auto sum3 = p1.add_instruction(migraphx::op::add{}, sum1, sum2);
p1.add_instruction(pass_op{}, sum3); p1.add_instruction(pass_op{}, sum3);
} }
p1.compile(cse_target{}); p1.compile(cse_target{});
migraph::program p2; migraphx::program p2;
{ {
auto one = p2.add_literal(1); auto one = p2.add_literal(1);
auto sum1 = p2.add_instruction(migraph::op::add{}, one, one); auto sum1 = p2.add_instruction(migraphx::op::add{}, one, one);
auto sum3 = p2.add_instruction(migraph::op::add{}, sum1, sum1); auto sum3 = p2.add_instruction(migraphx::op::add{}, sum1, sum1);
p2.add_instruction(pass_op{}, sum3); p2.add_instruction(pass_op{}, sum3);
} }
EXPECT(p1 == p2); EXPECT(p1 == p2);
...@@ -88,25 +88,25 @@ TEST_CASE(cse_test3) ...@@ -88,25 +88,25 @@ TEST_CASE(cse_test3)
TEST_CASE(cse_test4) TEST_CASE(cse_test4)
{ {
migraph::program p1; migraphx::program p1;
{ {
auto one = p1.add_literal(1); auto one = p1.add_literal(1);
auto two = p1.add_literal(1); auto two = p1.add_literal(1);
auto sum1 = p1.add_instruction(migraph::op::add{}, one, two); auto sum1 = p1.add_instruction(migraphx::op::add{}, one, two);
auto sum2 = p1.add_instruction(migraph::op::add{}, two, one); auto sum2 = p1.add_instruction(migraphx::op::add{}, two, one);
auto sum3 = p1.add_instruction(migraph::op::add{}, sum1, one); auto sum3 = p1.add_instruction(migraphx::op::add{}, sum1, one);
auto sum4 = p1.add_instruction(migraph::op::add{}, sum2, two); auto sum4 = p1.add_instruction(migraphx::op::add{}, sum2, two);
auto sum5 = p1.add_instruction(migraph::op::add{}, sum4, sum3); auto sum5 = p1.add_instruction(migraphx::op::add{}, sum4, sum3);
p1.add_instruction(pass_op{}, sum5); p1.add_instruction(pass_op{}, sum5);
} }
p1.compile(cse_target{}); p1.compile(cse_target{});
migraph::program p2; migraphx::program p2;
{ {
auto one = p2.add_literal(1); auto one = p2.add_literal(1);
auto sum1 = p2.add_instruction(migraph::op::add{}, one, one); auto sum1 = p2.add_instruction(migraphx::op::add{}, one, one);
auto sum3 = p2.add_instruction(migraph::op::add{}, sum1, one); auto sum3 = p2.add_instruction(migraphx::op::add{}, sum1, one);
auto sum5 = p2.add_instruction(migraph::op::add{}, sum3, sum3); auto sum5 = p2.add_instruction(migraphx::op::add{}, sum3, sum3);
p2.add_instruction(pass_op{}, sum5); p2.add_instruction(pass_op{}, sum5);
} }
EXPECT(p1 == p2); EXPECT(p1 == p2);
......
#include <migraph/constant_propagate.hpp> #include <migraphx/constant_propagate.hpp>
#include <migraph/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraph/operators.hpp> #include <migraphx/operators.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
#include <test.hpp> #include <test.hpp>
struct const_prop_target struct const_prop_target
{ {
std::string name() const { return "const_prop"; } std::string name() const { return "const_prop"; }
std::vector<migraph::pass> get_passes(migraph::context&) const std::vector<migraphx::pass> get_passes(migraphx::context&) const
{ {
return {migraph::constant_propagate{}, migraph::dead_code_elimination{}}; return {migraphx::constant_propagate{}, migraphx::dead_code_elimination{}};
} }
migraph::context get_context() const { return {}; } migraphx::context get_context() const { return {}; }
}; };
TEST_CASE(const_add1) TEST_CASE(const_add1)
{ {
migraph::program p1; migraphx::program p1;
auto one = p1.add_literal(1); auto one = p1.add_literal(1);
auto two = p1.add_literal(2); auto two = p1.add_literal(2);
auto sum = p1.add_instruction(migraph::op::add{}, one, two); auto sum = p1.add_instruction(migraphx::op::add{}, one, two);
p1.add_instruction(pass_op{}, sum); p1.add_instruction(pass_op{}, sum);
p1.compile(const_prop_target{}); p1.compile(const_prop_target{});
migraph::program p2; migraphx::program p2;
auto total = p2.add_literal(3); auto total = p2.add_literal(3);
p2.add_instruction(pass_op{}, total); p2.add_instruction(pass_op{}, total);
EXPECT(p1 == p2); EXPECT(p1 == p2);
...@@ -31,14 +31,14 @@ TEST_CASE(const_add1) ...@@ -31,14 +31,14 @@ TEST_CASE(const_add1)
TEST_CASE(const_add2) TEST_CASE(const_add2)
{ {
migraph::program p1; migraphx::program p1;
auto one = p1.add_parameter("one", {migraph::shape::int32_type, {1}}); auto one = p1.add_parameter("one", {migraphx::shape::int32_type, {1}});
auto two = p1.add_literal(2); auto two = p1.add_literal(2);
auto sum = p1.add_instruction(migraph::op::add{}, one, two); auto sum = p1.add_instruction(migraphx::op::add{}, one, two);
p1.add_instruction(pass_op{}, sum); p1.add_instruction(pass_op{}, sum);
p1.compile(const_prop_target{}); p1.compile(const_prop_target{});
migraph::program p2; migraphx::program p2;
auto total = p2.add_literal(3); auto total = p2.add_literal(3);
p2.add_instruction(pass_op{}, total); p2.add_instruction(pass_op{}, total);
EXPECT(p1 != p2); EXPECT(p1 != p2);
...@@ -46,15 +46,15 @@ TEST_CASE(const_add2) ...@@ -46,15 +46,15 @@ TEST_CASE(const_add2)
TEST_CASE(const_add3) TEST_CASE(const_add3)
{ {
migraph::program p1; migraphx::program p1;
auto one = p1.add_literal(1); auto one = p1.add_literal(1);
auto two = p1.add_literal(2); auto two = p1.add_literal(2);
auto sum1 = p1.add_instruction(migraph::op::add{}, one, two); auto sum1 = p1.add_instruction(migraphx::op::add{}, one, two);
auto sum2 = p1.add_instruction(migraph::op::add{}, sum1, two); auto sum2 = p1.add_instruction(migraphx::op::add{}, sum1, two);
p1.add_instruction(pass_op{}, sum2); p1.add_instruction(pass_op{}, sum2);
p1.compile(const_prop_target{}); p1.compile(const_prop_target{});
migraph::program p2; migraphx::program p2;
auto total = p2.add_literal(5); auto total = p2.add_literal(5);
p2.add_instruction(pass_op{}, total); p2.add_instruction(pass_op{}, total);
EXPECT(p1 == p2); EXPECT(p1 == p2);
......
#include <iostream> #include <iostream>
#include <vector> #include <vector>
#include <migraph/literal.hpp> #include <migraphx/literal.hpp>
#include <migraph/operators.hpp> #include <migraphx/operators.hpp>
#include <migraph/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraph/cpu/target.hpp> #include <migraphx/cpu/target.hpp>
#include <migraph/verify.hpp> #include <migraphx/verify.hpp>
#include "test.hpp" #include "test.hpp"
TEST_CASE(slice_test) TEST_CASE(slice_test)
{ {
{ {
migraph::program p; migraphx::program p;
std::vector<int> data(2 * 2 * 3); std::vector<int> data(2 * 2 * 3);
std::iota(data.begin(), data.end(), 0); std::iota(data.begin(), data.end(), 0);
migraph::shape s{migraph::shape::int32_type, {2, 2, 3}}; migraphx::shape s{migraphx::shape::int32_type, {2, 2, 3}};
auto l0 = p.add_literal(migraph::literal{s, data}); auto l0 = p.add_literal(migraphx::literal{s, data});
p.add_instruction(migraph::op::slice{{2}, {1}, {3}}, l0); p.add_instruction(migraphx::op::slice{{2}, {1}, {3}}, l0);
migraph::shape s2{migraph::shape::int32_type, {2, 2, 2}, {6, 3, 1}}; migraphx::shape s2{migraphx::shape::int32_type, {2, 2, 2}, {6, 3, 1}};
EXPECT(p.get_shape() == s2); EXPECT(p.get_shape() == s2);
p.compile(migraph::cpu::target{}); p.compile(migraphx::cpu::target{});
migraph::shape sresult{migraph::shape::int32_type, {2, 2, 2}, {4, 2, 1}}; migraphx::shape sresult{migraphx::shape::int32_type, {2, 2, 2}, {4, 2, 1}};
auto result = p.eval({}); auto result = p.eval({});
std::vector<int> gold = {1, 2, 4, 5, 7, 8, 10, 11}; std::vector<int> gold = {1, 2, 4, 5, 7, 8, 10, 11};
std::vector<int> results_vector(2 * 2 * 2); std::vector<int> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraph::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
EXPECT(result.get_shape() == sresult); EXPECT(result.get_shape() == sresult);
} }
{ {
migraph::program p; migraphx::program p;
std::vector<int> data(2 * 2 * 3); std::vector<int> data(2 * 2 * 3);
std::iota(data.begin(), data.end(), 0); std::iota(data.begin(), data.end(), 0);
migraph::shape s{migraph::shape::int32_type, {2, 2, 3}}; migraphx::shape s{migraphx::shape::int32_type, {2, 2, 3}};
auto l0 = p.add_literal(migraph::literal{s, data}); auto l0 = p.add_literal(migraphx::literal{s, data});
p.add_instruction(migraph::op::slice{{0, 1, 2}, {0, 0, 0}, {2, 2, 2}}, l0); p.add_instruction(migraphx::op::slice{{0, 1, 2}, {0, 0, 0}, {2, 2, 2}}, l0);
migraph::shape s2{migraph::shape::int32_type, {2, 2, 2}, {6, 3, 1}}; migraphx::shape s2{migraphx::shape::int32_type, {2, 2, 2}, {6, 3, 1}};
EXPECT(p.get_shape() == s2); EXPECT(p.get_shape() == s2);
p.compile(migraph::cpu::target{}); p.compile(migraphx::cpu::target{});
migraph::shape sresult{migraph::shape::int32_type, {2, 2, 2}, {4, 2, 1}}; migraphx::shape sresult{migraphx::shape::int32_type, {2, 2, 2}, {4, 2, 1}};
auto result = p.eval({}); auto result = p.eval({});
std::vector<int> gold = {0, 1, 3, 4, 6, 7, 9, 10}; std::vector<int> gold = {0, 1, 3, 4, 6, 7, 9, 10};
std::vector<int> results_vector(2 * 2 * 2); std::vector<int> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraph::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
EXPECT(result.get_shape() == sresult); EXPECT(result.get_shape() == sresult);
} }
} }
...@@ -50,85 +50,85 @@ TEST_CASE(slice_test) ...@@ -50,85 +50,85 @@ TEST_CASE(slice_test)
TEST_CASE(concat_test) TEST_CASE(concat_test)
{ {
{ {
migraph::program p; migraphx::program p;
std::size_t axis = 1; std::size_t axis = 1;
std::vector<int> data0 = {0, 1, 5, 6}; std::vector<int> data0 = {0, 1, 5, 6};
std::vector<int> data1 = {2, 3, 4, 7, 8, 9}; std::vector<int> data1 = {2, 3, 4, 7, 8, 9};
std::vector<int> data2 = {10, 20}; std::vector<int> data2 = {10, 20};
migraph::shape s0{migraph::shape::int32_type, {2, 2}}; migraphx::shape s0{migraphx::shape::int32_type, {2, 2}};
migraph::shape s1{migraph::shape::int32_type, {2, 3}}; migraphx::shape s1{migraphx::shape::int32_type, {2, 3}};
migraph::shape s2{migraph::shape::int32_type, {2, 1}}; migraphx::shape s2{migraphx::shape::int32_type, {2, 1}};
auto l0 = p.add_literal(migraph::literal{s0, data0}); auto l0 = p.add_literal(migraphx::literal{s0, data0});
auto l1 = p.add_literal(migraph::literal{s1, data1}); auto l1 = p.add_literal(migraphx::literal{s1, data1});
auto l2 = p.add_literal(migraph::literal{s2, data2}); auto l2 = p.add_literal(migraphx::literal{s2, data2});
p.add_instruction(migraph::op::concat{axis}, l0, l1, l2); p.add_instruction(migraphx::op::concat{axis}, l0, l1, l2);
p.compile(migraph::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<int> gold = {0, 1, 2, 3, 4, 10, 5, 6, 7, 8, 9, 20}; std::vector<int> gold = {0, 1, 2, 3, 4, 10, 5, 6, 7, 8, 9, 20};
std::vector<int> results_vector(2 * 6); std::vector<int> results_vector(2 * 6);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraph::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
EXPECT(migraph::verify_range(result.get_shape().lens(), std::vector<std::size_t>({2, 6}))); EXPECT(migraphx::verify_range(result.get_shape().lens(), std::vector<std::size_t>({2, 6})));
EXPECT( EXPECT(
migraph::verify_range(result.get_shape().strides(), std::vector<std::size_t>({6, 1}))); migraphx::verify_range(result.get_shape().strides(), std::vector<std::size_t>({6, 1})));
} }
{ {
migraph::program p; migraphx::program p;
std::size_t axis = 0; std::size_t axis = 0;
std::vector<int> data0 = {0, 1, 2, 3}; std::vector<int> data0 = {0, 1, 2, 3};
std::vector<int> data1 = {4, 5, 6, 7, 8, 9}; std::vector<int> data1 = {4, 5, 6, 7, 8, 9};
std::vector<int> data2 = {10, 11}; std::vector<int> data2 = {10, 11};
migraph::shape s0{migraph::shape::int32_type, {2, 2}}; migraphx::shape s0{migraphx::shape::int32_type, {2, 2}};
migraph::shape s1{migraph::shape::int32_type, {3, 2}}; migraphx::shape s1{migraphx::shape::int32_type, {3, 2}};
migraph::shape s2{migraph::shape::int32_type, {1, 2}}; migraphx::shape s2{migraphx::shape::int32_type, {1, 2}};
auto l0 = p.add_literal(migraph::literal{s0, data0}); auto l0 = p.add_literal(migraphx::literal{s0, data0});
auto l1 = p.add_literal(migraph::literal{s1, data1}); auto l1 = p.add_literal(migraphx::literal{s1, data1});
auto l2 = p.add_literal(migraph::literal{s2, data2}); auto l2 = p.add_literal(migraphx::literal{s2, data2});
p.add_instruction(migraph::op::concat{axis}, l0, l1, l2); p.add_instruction(migraphx::op::concat{axis}, l0, l1, l2);
p.compile(migraph::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<int> gold = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; std::vector<int> gold = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
std::vector<int> results_vector(6 * 2); std::vector<int> results_vector(6 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraph::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
EXPECT(migraph::verify_range(result.get_shape().lens(), std::vector<std::size_t>({6, 2}))); EXPECT(migraphx::verify_range(result.get_shape().lens(), std::vector<std::size_t>({6, 2})));
EXPECT( EXPECT(
migraph::verify_range(result.get_shape().strides(), std::vector<std::size_t>({2, 1}))); migraphx::verify_range(result.get_shape().strides(), std::vector<std::size_t>({2, 1})));
} }
} }
TEST_CASE(squeeze_test) TEST_CASE(squeeze_test)
{ {
{ {
migraph::program p; migraphx::program p;
std::vector<float> data(4 * 3 * 3); std::vector<float> data(4 * 3 * 3);
migraph::shape s1{migraph::shape::float_type, {4, 1, 3, 1, 3}}; migraphx::shape s1{migraphx::shape::float_type, {4, 1, 3, 1, 3}};
migraph::shape s2{migraph::shape::float_type, {4, 3, 1, 3}}; migraphx::shape s2{migraphx::shape::float_type, {4, 3, 1, 3}};
auto l0 = p.add_literal(migraph::literal{s1, data}); auto l0 = p.add_literal(migraphx::literal{s1, data});
p.add_instruction(migraph::op::squeeze{{1}}, l0); p.add_instruction(migraphx::op::squeeze{{1}}, l0);
p.compile(migraph::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
EXPECT(result.get_shape() == s2); EXPECT(result.get_shape() == s2);
} }
{ {
migraph::program p; migraphx::program p;
std::vector<float> data(4 * 3 * 3); std::vector<float> data(4 * 3 * 3);
migraph::shape s1{migraph::shape::float_type, {4, 1, 3, 1, 3}}; migraphx::shape s1{migraphx::shape::float_type, {4, 1, 3, 1, 3}};
migraph::shape s2{migraph::shape::float_type, {4, 1, 3, 3}}; migraphx::shape s2{migraphx::shape::float_type, {4, 1, 3, 3}};
auto l0 = p.add_literal(migraph::literal{s1, data}); auto l0 = p.add_literal(migraphx::literal{s1, data});
p.add_instruction(migraph::op::squeeze{{3}}, l0); p.add_instruction(migraphx::op::squeeze{{3}}, l0);
p.compile(migraph::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
EXPECT(result.get_shape() == s2); EXPECT(result.get_shape() == s2);
} }
{ {
migraph::program p; migraphx::program p;
std::vector<float> data(4 * 3 * 3); std::vector<float> data(4 * 3 * 3);
migraph::shape s1{migraph::shape::float_type, {4, 1, 3, 1, 3}}; migraphx::shape s1{migraphx::shape::float_type, {4, 1, 3, 1, 3}};
migraph::shape s2{migraph::shape::float_type, {4, 3, 3}}; migraphx::shape s2{migraphx::shape::float_type, {4, 3, 3}};
auto l0 = p.add_literal(migraph::literal{s1, data}); auto l0 = p.add_literal(migraphx::literal{s1, data});
p.add_instruction(migraph::op::squeeze{}, l0); p.add_instruction(migraphx::op::squeeze{}, l0);
p.compile(migraph::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
EXPECT(result.get_shape() == s2); EXPECT(result.get_shape() == s2);
} }
...@@ -137,24 +137,24 @@ TEST_CASE(squeeze_test) ...@@ -137,24 +137,24 @@ TEST_CASE(squeeze_test)
TEST_CASE(unsqueeze_test) TEST_CASE(unsqueeze_test)
{ {
{ {
migraph::program p; migraphx::program p;
std::vector<float> data(4 * 3 * 3); std::vector<float> data(4 * 3 * 3);
migraph::shape s1{migraph::shape::float_type, {4, 3, 3}}; migraphx::shape s1{migraphx::shape::float_type, {4, 3, 3}};
migraph::shape s2{migraph::shape::float_type, {4, 1, 3, 3}}; migraphx::shape s2{migraphx::shape::float_type, {4, 1, 3, 3}};
auto l0 = p.add_literal(migraph::literal{s1, data}); auto l0 = p.add_literal(migraphx::literal{s1, data});
p.add_instruction(migraph::op::unsqueeze{{1}}, l0); p.add_instruction(migraphx::op::unsqueeze{{1}}, l0);
p.compile(migraph::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
EXPECT(result.get_shape() == s2); EXPECT(result.get_shape() == s2);
} }
{ {
migraph::program p; migraphx::program p;
std::vector<float> data(4 * 3 * 3); std::vector<float> data(4 * 3 * 3);
migraph::shape s1{migraph::shape::float_type, {4, 3, 3}}; migraphx::shape s1{migraphx::shape::float_type, {4, 3, 3}};
migraph::shape s2{migraph::shape::float_type, {4, 3, 1, 3}}; migraphx::shape s2{migraphx::shape::float_type, {4, 3, 1, 3}};
auto l0 = p.add_literal(migraph::literal{s1, data}); auto l0 = p.add_literal(migraphx::literal{s1, data});
p.add_instruction(migraph::op::unsqueeze{{2}}, l0); p.add_instruction(migraphx::op::unsqueeze{{2}}, l0);
p.compile(migraph::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
EXPECT(result.get_shape() == s2); EXPECT(result.get_shape() == s2);
} }
...@@ -162,42 +162,42 @@ TEST_CASE(unsqueeze_test) ...@@ -162,42 +162,42 @@ TEST_CASE(unsqueeze_test)
TEST_CASE(globalavgpool_test) TEST_CASE(globalavgpool_test)
{ {
migraph::program p; migraphx::program p;
auto s = migraph::shape{migraph::shape::float_type, {1, 3, 2, 2}}; auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 2, 2}};
auto op = migraph::op::pooling{"average"}; auto op = migraphx::op::pooling{"average"};
auto lens = s.lens(); auto lens = s.lens();
op.lengths = {lens[2], lens[3]}; op.lengths = {lens[2], lens[3]};
std::vector<float> data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6}; std::vector<float> data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6};
auto l0 = p.add_literal(migraph::literal{s, data}); auto l0 = p.add_literal(migraphx::literal{s, data});
p.add_instruction(op, l0); p.add_instruction(op, l0);
p.compile(migraph::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.25, 0.575, 0.375}; std::vector<float> gold{0.25, 0.575, 0.375};
EXPECT(migraph::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(globalmaxpool_test) TEST_CASE(globalmaxpool_test)
{ {
migraph::program p; migraphx::program p;
auto s = migraph::shape{migraph::shape::float_type, {1, 3, 2, 2}}; auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 2, 2}};
auto op = migraph::op::pooling{"max"}; auto op = migraphx::op::pooling{"max"};
auto lens = s.lens(); auto lens = s.lens();
op.lengths = {lens[2], lens[3]}; op.lengths = {lens[2], lens[3]};
std::vector<float> data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6}; std::vector<float> data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6};
auto l0 = p.add_literal(migraph::literal{s, data}); auto l0 = p.add_literal(migraphx::literal{s, data});
p.add_instruction(op, l0); p.add_instruction(op, l0);
p.compile(migraph::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.4, 0.9, 0.7}; std::vector<float> gold{0.4, 0.9, 0.7};
EXPECT(migraph::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(im2col_3x3_no_pad_identity_test) TEST_CASE(im2col_3x3_no_pad_identity_test)
...@@ -213,20 +213,20 @@ TEST_CASE(im2col_3x3_no_pad_identity_test) ...@@ -213,20 +213,20 @@ TEST_CASE(im2col_3x3_no_pad_identity_test)
std::vector<int32_t> input(channels * size[0] * size[1]); std::vector<int32_t> input(channels * size[0] * size[1]);
std::iota(input.begin(), input.end(), 0); std::iota(input.begin(), input.end(), 0);
migraph::program p; migraphx::program p;
migraph::shape s_image{migraph::shape::int32_type, {1, channels, size[0], size[1]}}; migraphx::shape s_image{migraphx::shape::int32_type, {1, channels, size[0], size[1]}};
migraph::shape s_weights{migraph::shape::int32_type, {1, channels, f[0], f[1]}}; migraphx::shape s_weights{migraphx::shape::int32_type, {1, channels, f[0], f[1]}};
auto l_image = p.add_literal(migraph::literal{s_image, input}); auto l_image = p.add_literal(migraphx::literal{s_image, input});
auto l_weights = p.add_literal(migraph::literal{s_weights, weights}); auto l_weights = p.add_literal(migraphx::literal{s_weights, weights});
p.add_instruction(migraph::op::im2col{padding, stride, dilation}, l_image, l_weights); p.add_instruction(migraphx::op::im2col{padding, stride, dilation}, l_image, l_weights);
p.compile(migraph::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
std::size_t col_height = (size[0] - f[0] + 2 * padding[0]) / stride[0] + 1; std::size_t col_height = (size[0] - f[0] + 2 * padding[0]) / stride[0] + 1;
std::size_t col_width = (size[1] - f[1] + 2 * padding[1]) / stride[1] + 1; std::size_t col_width = (size[1] - f[1] + 2 * padding[1]) / stride[1] + 1;
std::vector<float> results_vector(channels * f[0] * f[1] * col_height * col_width); std::vector<float> results_vector(channels * f[0] * f[1] * col_height * col_width);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraph::verify_range(results_vector, input)); EXPECT(migraphx::verify_range(results_vector, input));
} }
TEST_CASE(im2col_3x3_no_pad_test) TEST_CASE(im2col_3x3_no_pad_test)
...@@ -242,13 +242,13 @@ TEST_CASE(im2col_3x3_no_pad_test) ...@@ -242,13 +242,13 @@ TEST_CASE(im2col_3x3_no_pad_test)
std::vector<int32_t> input(channels * size[0] * size[1]); std::vector<int32_t> input(channels * size[0] * size[1]);
std::iota(input.begin(), input.end(), 0); std::iota(input.begin(), input.end(), 0);
migraph::program p; migraphx::program p;
migraph::shape s_image{migraph::shape::int32_type, {1, channels, size[0], size[1]}}; migraphx::shape s_image{migraphx::shape::int32_type, {1, channels, size[0], size[1]}};
migraph::shape s_weights{migraph::shape::int32_type, {1, channels, f[0], f[1]}}; migraphx::shape s_weights{migraphx::shape::int32_type, {1, channels, f[0], f[1]}};
auto l_image = p.add_literal(migraph::literal{s_image, input}); auto l_image = p.add_literal(migraphx::literal{s_image, input});
auto l_weights = p.add_literal(migraph::literal{s_weights, weights}); auto l_weights = p.add_literal(migraphx::literal{s_weights, weights});
p.add_instruction(migraph::op::im2col{padding, stride, dilation}, l_image, l_weights); p.add_instruction(migraphx::op::im2col{padding, stride, dilation}, l_image, l_weights);
p.compile(migraph::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<int> correct = {0, 1, 2, 4, 5, 6, 8, 9, 10, 1, 2, 3, 5, 6, 7, 9, 10, 11, std::vector<int> correct = {0, 1, 2, 4, 5, 6, 8, 9, 10, 1, 2, 3, 5, 6, 7, 9, 10, 11,
...@@ -258,7 +258,7 @@ TEST_CASE(im2col_3x3_no_pad_test) ...@@ -258,7 +258,7 @@ TEST_CASE(im2col_3x3_no_pad_test)
std::size_t col_width = (size[1] - f[1] + 2 * padding[1]) / stride[1] + 1; std::size_t col_width = (size[1] - f[1] + 2 * padding[1]) / stride[1] + 1;
std::vector<float> results_vector(channels * f[0] * f[1] * col_height * col_width); std::vector<float> results_vector(channels * f[0] * f[1] * col_height * col_width);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraph::verify_range(results_vector, correct)); EXPECT(migraphx::verify_range(results_vector, correct));
} }
TEST_CASE(im2col_3x3_stride_2_no_pad_test) TEST_CASE(im2col_3x3_stride_2_no_pad_test)
...@@ -274,13 +274,13 @@ TEST_CASE(im2col_3x3_stride_2_no_pad_test) ...@@ -274,13 +274,13 @@ TEST_CASE(im2col_3x3_stride_2_no_pad_test)
std::vector<int32_t> input(channels * size[0] * size[1]); std::vector<int32_t> input(channels * size[0] * size[1]);
std::iota(input.begin(), input.end(), 0); std::iota(input.begin(), input.end(), 0);
migraph::program p; migraphx::program p;
migraph::shape s_image{migraph::shape::int32_type, {1, channels, size[0], size[1]}}; migraphx::shape s_image{migraphx::shape::int32_type, {1, channels, size[0], size[1]}};
migraph::shape s_weights{migraph::shape::int32_type, {1, channels, f[0], f[1]}}; migraphx::shape s_weights{migraphx::shape::int32_type, {1, channels, f[0], f[1]}};
auto l_image = p.add_literal(migraph::literal{s_image, input}); auto l_image = p.add_literal(migraphx::literal{s_image, input});
auto l_weights = p.add_literal(migraph::literal{s_weights, weights}); auto l_weights = p.add_literal(migraphx::literal{s_weights, weights});
p.add_instruction(migraph::op::im2col{padding, stride, dilation}, l_image, l_weights); p.add_instruction(migraphx::op::im2col{padding, stride, dilation}, l_image, l_weights);
p.compile(migraph::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<int> correct = {0, 1, 2, 6, 7, 8, 12, 13, 14, 2, 3, 4, std::vector<int> correct = {0, 1, 2, 6, 7, 8, 12, 13, 14, 2, 3, 4,
...@@ -291,7 +291,7 @@ TEST_CASE(im2col_3x3_stride_2_no_pad_test) ...@@ -291,7 +291,7 @@ TEST_CASE(im2col_3x3_stride_2_no_pad_test)
std::size_t col_width = (size[1] - f[1] + 2 * padding[1]) / stride[1] + 1; std::size_t col_width = (size[1] - f[1] + 2 * padding[1]) / stride[1] + 1;
std::vector<float> results_vector(channels * f[0] * f[1] * col_height * col_width); std::vector<float> results_vector(channels * f[0] * f[1] * col_height * col_width);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraph::verify_range(results_vector, correct)); EXPECT(migraphx::verify_range(results_vector, correct));
} }
TEST_CASE(im2col_3x3_with_padding_test) TEST_CASE(im2col_3x3_with_padding_test)
...@@ -307,13 +307,13 @@ TEST_CASE(im2col_3x3_with_padding_test) ...@@ -307,13 +307,13 @@ TEST_CASE(im2col_3x3_with_padding_test)
std::vector<int32_t> input(channels * size[0] * size[1]); std::vector<int32_t> input(channels * size[0] * size[1]);
std::iota(input.begin(), input.end(), 0); std::iota(input.begin(), input.end(), 0);
migraph::program p; migraphx::program p;
migraph::shape s_image{migraph::shape::int32_type, {1, channels, size[0], size[1]}}; migraphx::shape s_image{migraphx::shape::int32_type, {1, channels, size[0], size[1]}};
migraph::shape s_weights{migraph::shape::int32_type, {1, channels, f[0], f[1]}}; migraphx::shape s_weights{migraphx::shape::int32_type, {1, channels, f[0], f[1]}};
auto l_image = p.add_literal(migraph::literal{s_image, input}); auto l_image = p.add_literal(migraphx::literal{s_image, input});
auto l_weights = p.add_literal(migraph::literal{s_weights, weights}); auto l_weights = p.add_literal(migraphx::literal{s_weights, weights});
p.add_instruction(migraph::op::im2col{padding, stride, dilation}, l_image, l_weights); p.add_instruction(migraphx::op::im2col{padding, stride, dilation}, l_image, l_weights);
p.compile(migraph::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<int> correct = {0, 0, 0, 0, 0, 1, 0, 2, 3, 0, 0, 0, 0, 1, 0, 2, 3, 0, std::vector<int> correct = {0, 0, 0, 0, 0, 1, 0, 2, 3, 0, 0, 0, 0, 1, 0, 2, 3, 0,
...@@ -323,19 +323,19 @@ TEST_CASE(im2col_3x3_with_padding_test) ...@@ -323,19 +323,19 @@ TEST_CASE(im2col_3x3_with_padding_test)
std::size_t col_width = (size[1] - f[1] + 2 * padding[1]) / stride[1] + 1; std::size_t col_width = (size[1] - f[1] + 2 * padding[1]) / stride[1] + 1;
std::vector<float> results_vector(channels * f[0] * f[1] * col_height * col_width); std::vector<float> results_vector(channels * f[0] * f[1] * col_height * col_width);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraph::verify_range(results_vector, correct)); EXPECT(migraphx::verify_range(results_vector, correct));
} }
TEST_CASE(batch_norm_inference_test) TEST_CASE(batch_norm_inference_test)
{ {
migraph::program p; migraphx::program p;
const size_t width = 2, height = 2, channels = 4, batches = 2; const size_t width = 2, height = 2, channels = 4, batches = 2;
const float x_val = 8.0f, mean_val = 2.0f, variance_val = 4.0f, scale_val = 2.0f, const float x_val = 8.0f, mean_val = 2.0f, variance_val = 4.0f, scale_val = 2.0f,
bias_val = 1.0f; bias_val = 1.0f;
const float output_val = scale_val * (x_val - mean_val) / (std::sqrt(variance_val)) + bias_val; const float output_val = scale_val * (x_val - mean_val) / (std::sqrt(variance_val)) + bias_val;
migraph::shape s{migraph::shape::float_type, {batches, channels, height, width}}; migraphx::shape s{migraphx::shape::float_type, {batches, channels, height, width}};
migraph::shape vars{migraph::shape::float_type, {channels}}; migraphx::shape vars{migraphx::shape::float_type, {channels}};
std::vector<float> x_data(width * height * channels * batches); std::vector<float> x_data(width * height * channels * batches);
std::vector<float> scale_data(channels); std::vector<float> scale_data(channels);
std::vector<float> bias_data(channels); std::vector<float> bias_data(channels);
...@@ -348,14 +348,14 @@ TEST_CASE(batch_norm_inference_test) ...@@ -348,14 +348,14 @@ TEST_CASE(batch_norm_inference_test)
std::fill(scale_data.begin(), scale_data.end(), scale_val); std::fill(scale_data.begin(), scale_data.end(), scale_val);
std::fill(bias_data.begin(), bias_data.end(), bias_val); std::fill(bias_data.begin(), bias_data.end(), bias_val);
auto x = p.add_literal(migraph::literal{s, x_data}); auto x = p.add_literal(migraphx::literal{s, x_data});
auto scale = p.add_literal(migraph::literal{vars, scale_data}); auto scale = p.add_literal(migraphx::literal{vars, scale_data});
auto bias = p.add_literal(migraph::literal{vars, bias_data}); auto bias = p.add_literal(migraphx::literal{vars, bias_data});
auto mean = p.add_literal(migraph::literal{vars, mean_data}); auto mean = p.add_literal(migraphx::literal{vars, mean_data});
auto variance = p.add_literal(migraph::literal{vars, variance_data}); auto variance = p.add_literal(migraphx::literal{vars, variance_data});
p.add_instruction(migraph::op::batch_norm_inference{}, x, scale, bias, mean, variance); p.add_instruction(migraphx::op::batch_norm_inference{}, x, scale, bias, mean, variance);
p.compile(migraph::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> result_vector(width * height * channels * batches); std::vector<float> result_vector(width * height * channels * batches);
...@@ -363,7 +363,7 @@ TEST_CASE(batch_norm_inference_test) ...@@ -363,7 +363,7 @@ TEST_CASE(batch_norm_inference_test)
std::fill(gold.begin(), gold.end(), output_val); std::fill(gold.begin(), gold.end(), output_val);
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
EXPECT(migraph::verify_range(result_vector, gold)); EXPECT(migraphx::verify_range(result_vector, gold));
} }
TEST_CASE(im2col_3x3_with_channels_identity_test) TEST_CASE(im2col_3x3_with_channels_identity_test)
...@@ -379,105 +379,105 @@ TEST_CASE(im2col_3x3_with_channels_identity_test) ...@@ -379,105 +379,105 @@ TEST_CASE(im2col_3x3_with_channels_identity_test)
std::vector<int32_t> input(channels * size[0] * size[1]); std::vector<int32_t> input(channels * size[0] * size[1]);
std::iota(input.begin(), input.end(), 0); std::iota(input.begin(), input.end(), 0);
migraph::program p; migraphx::program p;
migraph::shape s_image{migraph::shape::int32_type, {1, channels, size[0], size[1]}}; migraphx::shape s_image{migraphx::shape::int32_type, {1, channels, size[0], size[1]}};
migraph::shape s_weights{migraph::shape::int32_type, {1, channels, f[0], f[1]}}; migraphx::shape s_weights{migraphx::shape::int32_type, {1, channels, f[0], f[1]}};
auto l_image = p.add_literal(migraph::literal{s_image, input}); auto l_image = p.add_literal(migraphx::literal{s_image, input});
auto l_weights = p.add_literal(migraph::literal{s_weights, weights}); auto l_weights = p.add_literal(migraphx::literal{s_weights, weights});
p.add_instruction(migraph::op::im2col{padding, stride, dilation}, l_image, l_weights); p.add_instruction(migraphx::op::im2col{padding, stride, dilation}, l_image, l_weights);
p.compile(migraph::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
std::size_t col_height = (size[0] - f[0] + 2 * padding[0]) / stride[0] + 1; std::size_t col_height = (size[0] - f[0] + 2 * padding[0]) / stride[0] + 1;
std::size_t col_width = (size[1] - f[1] + 2 * padding[1]) / stride[1] + 1; std::size_t col_width = (size[1] - f[1] + 2 * padding[1]) / stride[1] + 1;
std::vector<float> results_vector(channels * f[0] * f[1] * col_height * col_width); std::vector<float> results_vector(channels * f[0] * f[1] * col_height * col_width);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraph::verify_range(results_vector, input)); EXPECT(migraphx::verify_range(results_vector, input));
} }
TEST_CASE(exp_test) TEST_CASE(exp_test)
{ {
migraph::program p; migraphx::program p;
migraph::shape s{migraph::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto l = p.add_literal(migraph::literal{s, {-1, 0, 1}}); auto l = p.add_literal(migraphx::literal{s, {-1, 0, 1}});
p.add_instruction(migraph::op::exp{}, l); p.add_instruction(migraphx::op::exp{}, l);
p.compile(migraph::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.36787944f, 1.f, 2.71828183f}; std::vector<float> gold = {0.36787944f, 1.f, 2.71828183f};
EXPECT(migraph::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(sin_test) TEST_CASE(sin_test)
{ {
migraph::program p; migraphx::program p;
migraph::shape s{migraph::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto l = p.add_literal(migraph::literal{s, {-1, 0, 1}}); auto l = p.add_literal(migraphx::literal{s, {-1, 0, 1}});
p.add_instruction(migraph::op::sin{}, l); p.add_instruction(migraphx::op::sin{}, l);
p.compile(migraph::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-0.84147098f, 0.f, 0.84147098f}; std::vector<float> gold = {-0.84147098f, 0.f, 0.84147098f};
EXPECT(migraph::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(cos_test) TEST_CASE(cos_test)
{ {
migraph::program p; migraphx::program p;
migraph::shape s{migraph::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto l = p.add_literal(migraph::literal{s, {-1, 0, 1}}); auto l = p.add_literal(migraphx::literal{s, {-1, 0, 1}});
p.add_instruction(migraph::op::cos{}, l); p.add_instruction(migraphx::op::cos{}, l);
p.compile(migraph::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.54030231f, 1.f, 0.54030231f}; std::vector<float> gold = {0.54030231f, 1.f, 0.54030231f};
EXPECT(migraph::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(tan_test) TEST_CASE(tan_test)
{ {
migraph::program p; migraphx::program p;
migraph::shape s{migraph::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto l = p.add_literal(migraph::literal{s, {-1, 0, 1}}); auto l = p.add_literal(migraphx::literal{s, {-1, 0, 1}});
p.add_instruction(migraph::op::tan{}, l); p.add_instruction(migraphx::op::tan{}, l);
p.compile(migraph::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-1.55740772f, 0.0f, 1.55740772f}; std::vector<float> gold = {-1.55740772f, 0.0f, 1.55740772f};
EXPECT(migraph::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(add_test) TEST_CASE(add_test)
{ {
migraph::program p; migraphx::program p;
migraph::shape s{migraph::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto l1 = p.add_literal(migraph::literal{s, {-1, 0, 1}}); auto l1 = p.add_literal(migraphx::literal{s, {-1, 0, 1}});
auto l2 = p.add_literal(migraph::literal{s, {1, 2, 3}}); auto l2 = p.add_literal(migraphx::literal{s, {1, 2, 3}});
p.add_instruction(migraph::op::add{}, l1, l2); p.add_instruction(migraphx::op::add{}, l1, l2);
p.compile(migraph::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 2, 4}; std::vector<float> gold = {0, 2, 4};
EXPECT(migraph::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(broadcast_test) TEST_CASE(broadcast_test)
{ {
migraph::program p; migraphx::program p;
migraph::shape a_shape{migraph::shape::int32_type, {2, 2}}; migraphx::shape a_shape{migraphx::shape::int32_type, {2, 2}};
std::vector<int32_t> a_data{0, 0, 0, 0}; std::vector<int32_t> a_data{0, 0, 0, 0};
migraph::shape b_shape{migraph::shape::int32_type, {2}}; migraphx::shape b_shape{migraphx::shape::int32_type, {2}};
std::vector<int32_t> b_data{-2, -3}; std::vector<int32_t> b_data{-2, -3};
uint64_t axis = 0; uint64_t axis = 0;
auto l1 = p.add_literal(migraph::literal{a_shape, a_data}); auto l1 = p.add_literal(migraphx::literal{a_shape, a_data});
auto l2 = p.add_literal(migraph::literal{b_shape, b_data}); auto l2 = p.add_literal(migraphx::literal{b_shape, b_data});
p.add_instruction(migraph::op::broadcast{axis, l1->get_shape()}, l2); p.add_instruction(migraphx::op::broadcast{axis, l1->get_shape()}, l2);
p.compile(migraph::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
auto output = result.get<int32_t>(); auto output = result.get<int32_t>();
EXPECT(output(0, 0) == -2); EXPECT(output(0, 0) == -2);
...@@ -488,145 +488,145 @@ TEST_CASE(broadcast_test) ...@@ -488,145 +488,145 @@ TEST_CASE(broadcast_test)
TEST_CASE(add_broadcast_test) TEST_CASE(add_broadcast_test)
{ {
{ {
migraph::program p; migraphx::program p;
migraph::shape a_shape{migraph::shape::float_type, {2, 2, 3}}; migraphx::shape a_shape{migraphx::shape::float_type, {2, 2, 3}};
std::vector<float> a_data{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; std::vector<float> a_data{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
migraph::shape b_shape{migraph::shape::float_type, {2, 2}}; migraphx::shape b_shape{migraphx::shape::float_type, {2, 2}};
std::vector<float> b_data{0, -1, -2, -3}; std::vector<float> b_data{0, -1, -2, -3};
uint64_t axis = 0; uint64_t axis = 0;
auto l1 = p.add_literal(migraph::literal{a_shape, a_data}); auto l1 = p.add_literal(migraphx::literal{a_shape, a_data});
auto l2 = p.add_literal(migraph::literal{b_shape, b_data}); auto l2 = p.add_literal(migraphx::literal{b_shape, b_data});
auto l3 = p.add_instruction(migraph::op::broadcast{axis, l1->get_shape()}, l2); auto l3 = p.add_instruction(migraphx::op::broadcast{axis, l1->get_shape()}, l2);
p.add_instruction(migraph::op::add{}, l1, l3); p.add_instruction(migraphx::op::add{}, l1, l3);
p.compile(migraph::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
EXPECT(result.get_shape().packed()); EXPECT(result.get_shape().packed());
std::vector<float> results_vector(12); std::vector<float> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 1, 2, 2, 3, 4, 4, 5, 6, 6, 7, 8}; std::vector<float> gold = {0, 1, 2, 2, 3, 4, 4, 5, 6, 6, 7, 8};
EXPECT(migraph::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
{ {
migraph::program p; migraphx::program p;
migraph::shape a_shape{migraph::shape::float_type, {2, 2, 3}}; migraphx::shape a_shape{migraphx::shape::float_type, {2, 2, 3}};
std::vector<float> a_data{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; std::vector<float> a_data{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
migraph::shape b_shape{migraph::shape::float_type, {2, 2, 1}}; migraphx::shape b_shape{migraphx::shape::float_type, {2, 2, 1}};
std::vector<float> b_data{0, -1, -2, -3}; std::vector<float> b_data{0, -1, -2, -3};
auto l1 = p.add_literal(migraph::literal{a_shape, a_data}); auto l1 = p.add_literal(migraphx::literal{a_shape, a_data});
auto l2 = p.add_literal(migraph::literal{b_shape, b_data}); auto l2 = p.add_literal(migraphx::literal{b_shape, b_data});
auto l3 = p.add_instruction(migraph::op::multibroadcast{{2, 2, 3}}, l1); auto l3 = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 3}}, l1);
auto l4 = p.add_instruction(migraph::op::multibroadcast{{2, 2, 3}}, l2); auto l4 = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 3}}, l2);
p.add_instruction(migraph::op::add{}, l3, l4); p.add_instruction(migraphx::op::add{}, l3, l4);
p.compile(migraph::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
EXPECT(result.get_shape().packed()); EXPECT(result.get_shape().packed());
std::vector<float> results_vector(12); std::vector<float> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 1, 2, 2, 3, 4, 4, 5, 6, 6, 7, 8}; std::vector<float> gold = {0, 1, 2, 2, 3, 4, 4, 5, 6, 6, 7, 8};
EXPECT(migraph::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
} }
TEST_CASE(sub_test) TEST_CASE(sub_test)
{ {
migraph::program p; migraphx::program p;
migraph::shape s{migraph::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto l1 = p.add_literal(migraph::literal{s, {-1, 0, 1}}); auto l1 = p.add_literal(migraphx::literal{s, {-1, 0, 1}});
auto l2 = p.add_literal(migraph::literal{s, {1, 2, 3}}); auto l2 = p.add_literal(migraphx::literal{s, {1, 2, 3}});
p.add_instruction(migraph::op::sub{}, l1, l2); p.add_instruction(migraphx::op::sub{}, l1, l2);
p.compile(migraph::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-2, -2, -2}; std::vector<float> gold = {-2, -2, -2};
EXPECT(migraph::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(mul_test) TEST_CASE(mul_test)
{ {
migraph::program p; migraphx::program p;
migraph::shape s{migraph::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto l1 = p.add_literal(migraph::literal{s, {-1, 0, 1}}); auto l1 = p.add_literal(migraphx::literal{s, {-1, 0, 1}});
auto l2 = p.add_literal(migraph::literal{s, {1, 2, 3}}); auto l2 = p.add_literal(migraphx::literal{s, {1, 2, 3}});
p.add_instruction(migraph::op::mul{}, l1, l2); p.add_instruction(migraphx::op::mul{}, l1, l2);
p.compile(migraph::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-1, 0, 3}; std::vector<float> gold = {-1, 0, 3};
EXPECT(migraph::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(div_test) TEST_CASE(div_test)
{ {
migraph::program p; migraphx::program p;
migraph::shape s{migraph::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto l1 = p.add_literal(migraph::literal{s, {-1.0f, 0.5f, 1.0f}}); auto l1 = p.add_literal(migraphx::literal{s, {-1.0f, 0.5f, 1.0f}});
auto l2 = p.add_literal(migraph::literal{s, {1.0f, 2.0f, 4.0f}}); auto l2 = p.add_literal(migraphx::literal{s, {1.0f, 2.0f, 4.0f}});
p.add_instruction(migraph::op::div{}, l1, l2); p.add_instruction(migraphx::op::div{}, l1, l2);
p.compile(migraph::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-1.f, 0.25f, 0.25f}; std::vector<float> gold = {-1.f, 0.25f, 0.25f};
EXPECT(migraph::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(relu_test) TEST_CASE(relu_test)
{ {
migraph::program p; migraphx::program p;
migraph::shape s{migraph::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto l = p.add_literal(migraph::literal{s, {-1.f, 0.f, 1.f}}); auto l = p.add_literal(migraphx::literal{s, {-1.f, 0.f, 1.f}});
p.add_instruction(migraph::op::relu{}, l); p.add_instruction(migraphx::op::relu{}, l);
p.compile(migraph::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.f, 0.f, 1.f}; std::vector<float> gold = {0.f, 0.f, 1.f};
EXPECT(migraph::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(leaky_relu_test) TEST_CASE(leaky_relu_test)
{ {
migraph::program p; migraphx::program p;
migraph::shape s{migraph::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto l = p.add_literal(migraph::literal{s, {-1.f, 0.f, 1.f}}); auto l = p.add_literal(migraphx::literal{s, {-1.f, 0.f, 1.f}});
p.add_instruction(migraph::op::leaky_relu{0.01}, l); p.add_instruction(migraphx::op::leaky_relu{0.01}, l);
p.compile(migraph::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-0.01f, 0.f, 1.f}; std::vector<float> gold = {-0.01f, 0.f, 1.f};
EXPECT(migraph::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(imagescaler_test) TEST_CASE(imagescaler_test)
{ {
migraph::program p; migraphx::program p;
migraph::shape s{migraph::shape::float_type, {1, 3, 2, 2}}; migraphx::shape s{migraphx::shape::float_type, {1, 3, 2, 2}};
auto img = p.add_literal(migraph::literal{s, auto img = p.add_literal(migraphx::literal{s,
{0.2, {0.2,
0.3, 0.3,
0.5, 0.5,
0.4, 0.4,
0.7, 0.7,
0.8, 0.8,
0.1, 0.1,
0.9, 0.9,
0.15, 0.15,
0.25, 0.25,
0.35, 0.35,
0.45}}); 0.45}});
auto scale_val = p.add_literal(2.f); auto scale_val = p.add_literal(2.f);
auto scaled_tensor = p.add_instruction(migraph::op::scalar{s}, scale_val); auto scaled_tensor = p.add_instruction(migraphx::op::scalar{s}, scale_val);
auto img_scaled = p.add_instruction(migraph::op::mul{}, img, scaled_tensor); auto img_scaled = p.add_instruction(migraphx::op::mul{}, img, scaled_tensor);
auto bias_vals = p.add_literal( auto bias_vals = p.add_literal(
migraph::literal{migraph::shape{migraph::shape::float_type, {3}}, {0.01, 0.02, 0.03}}); migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3}}, {0.01, 0.02, 0.03}});
auto bias_bcast = p.add_instruction(migraph::op::broadcast{1, s}, bias_vals); auto bias_bcast = p.add_instruction(migraphx::op::broadcast{1, s}, bias_vals);
p.add_instruction(migraph::op::add{}, img_scaled, bias_bcast); p.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast);
p.compile(migraph::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(12); std::vector<float> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
...@@ -644,53 +644,53 @@ TEST_CASE(imagescaler_test) ...@@ -644,53 +644,53 @@ TEST_CASE(imagescaler_test)
0.53, 0.53,
0.73, 0.73,
0.93}; 0.93};
EXPECT(migraph::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(reshape_test) TEST_CASE(reshape_test)
{ {
migraph::shape a_shape{migraph::shape::float_type, {24, 1, 1, 1}}; migraphx::shape a_shape{migraphx::shape::float_type, {24, 1, 1, 1}};
std::vector<float> data(24); std::vector<float> data(24);
std::iota(data.begin(), data.end(), -3); std::iota(data.begin(), data.end(), -3);
{ {
migraph::program p; migraphx::program p;
auto l = p.add_literal(migraph::literal{a_shape, data}); auto l = p.add_literal(migraphx::literal{a_shape, data});
std::vector<int64_t> new_shape = {8, 3, 1, 1}; std::vector<int64_t> new_shape = {8, 3, 1, 1};
p.add_instruction(migraph::op::reshape{new_shape}, l); p.add_instruction(migraphx::op::reshape{new_shape}, l);
p.compile(migraph::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraph::verify_range(results_vector, data)); EXPECT(migraphx::verify_range(results_vector, data));
} }
{ {
migraph::program p; migraphx::program p;
auto l = p.add_literal(migraph::literal{a_shape, data}); auto l = p.add_literal(migraphx::literal{a_shape, data});
std::vector<int64_t> new_shape = {1, 3, 4, 2}; std::vector<int64_t> new_shape = {1, 3, 4, 2};
p.add_instruction(migraph::op::reshape{new_shape}, l); p.add_instruction(migraphx::op::reshape{new_shape}, l);
p.compile(migraph::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraph::verify_range(results_vector, data)); EXPECT(migraphx::verify_range(results_vector, data));
} }
{ {
migraph::program p; migraphx::program p;
auto l = p.add_literal(migraph::literal{a_shape, data}); auto l = p.add_literal(migraphx::literal{a_shape, data});
std::vector<int64_t> new_shape = {1, 3, 4, 2}; std::vector<int64_t> new_shape = {1, 3, 4, 2};
p.add_instruction(migraph::op::reshape{new_shape}, l); p.add_instruction(migraphx::op::reshape{new_shape}, l);
p.compile(migraph::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraph::verify_range(results_vector, data)); EXPECT(migraphx::verify_range(results_vector, data));
} }
} }
template <class T> template <class T>
void gemm_test() void gemm_test()
{ {
migraph::program p; migraphx::program p;
std::vector<T> a = {-0.00925222, 0.56250403, 0.70107397, 0.75402161, -0.505885, std::vector<T> a = {-0.00925222, 0.56250403, 0.70107397, 0.75402161, -0.505885,
1.33628943, -0.11413, -0.31270559, 1.59336732, -0.19361027, 1.33628943, -0.11413, -0.31270559, 1.59336732, -0.19361027,
-0.91620867, 0.40108416, -0.06969921, 0.68483471, -0.39906632, -0.91620867, 0.40108416, -0.06969921, 0.68483471, -0.39906632,
...@@ -722,12 +722,12 @@ void gemm_test() ...@@ -722,12 +722,12 @@ void gemm_test()
-1.29885596e+00, -1.29885596e+00,
2.16294914e+00, 2.16294914e+00,
-1.48101497e-01}; -1.48101497e-01};
migraph::shape a_shape{migraph::shape::get_type<T>{}, {4, 5}}; migraphx::shape a_shape{migraphx::shape::get_type<T>{}, {4, 5}};
auto al = p.add_literal(migraph::literal{a_shape, a}); auto al = p.add_literal(migraphx::literal{a_shape, a});
migraph::shape b_shape{migraph::shape::get_type<T>{}, {5, 3}}; migraphx::shape b_shape{migraphx::shape::get_type<T>{}, {5, 3}};
auto bl = p.add_literal(migraph::literal{b_shape, b}); auto bl = p.add_literal(migraphx::literal{b_shape, b});
p.add_instruction(migraph::op::dot{}, al, bl); p.add_instruction(migraphx::op::dot{}, al, bl);
p.compile(migraph::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<T> results_vector(12); std::vector<T> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
...@@ -742,7 +742,7 @@ TEST_CASE_REGISTER(gemm_test<double>) ...@@ -742,7 +742,7 @@ TEST_CASE_REGISTER(gemm_test<double>)
TEST_CASE(maxpool_test) TEST_CASE(maxpool_test)
{ {
migraph::program p; migraphx::program p;
std::vector<float> a = { std::vector<float> a = {
-2.1314404, -1.63041711, 1.54562736, 1.04625261, -1.42931843, -0.48703974, 0.4065806, -2.1314404, -1.63041711, 1.54562736, 1.04625261, -1.42931843, -0.48703974, 0.4065806,
-0.1524526, 1.30775225, 0.45538983, -0.06631992, -1.75332725, 1.33493888, 0.47327688, -0.1524526, 1.30775225, 0.45538983, -0.06631992, -1.75332725, 1.33493888, 0.47327688,
...@@ -781,10 +781,10 @@ TEST_CASE(maxpool_test) ...@@ -781,10 +781,10 @@ TEST_CASE(maxpool_test)
1.95433736, 2.46601582, 1.53285873, 1.95433736, 1.06763375, 1.4545635, 1.95433736, 2.46601582, 1.53285873, 1.95433736, 1.06763375, 1.4545635,
1.33624589, 1.16736257, 0.6126079, 1.36892557, 2.40126371, 1.53441942, 1.33624589, 1.16736257, 0.6126079, 1.36892557, 2.40126371, 1.53441942,
0.52119428, 2.07681108, 0.88494766, 1.51522756, 0.54275119, 0.6629802}; 0.52119428, 2.07681108, 0.88494766, 1.51522756, 0.54275119, 0.6629802};
migraph::shape a_shape{migraph::shape::float_type, {2, 3, 6, 6}}; migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 6, 6}};
auto al = p.add_literal(migraph::literal{a_shape, a}); auto al = p.add_literal(migraphx::literal{a_shape, a});
p.add_instruction(migraph::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{3, 2}}}, al); p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{3, 2}}}, al);
p.compile(migraph::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
// std::cout << result.get_shape() << std::endl; // std::cout << result.get_shape() << std::endl;
std::vector<float> results_vector(36); std::vector<float> results_vector(36);
...@@ -799,7 +799,7 @@ TEST_CASE(maxpool_test) ...@@ -799,7 +799,7 @@ TEST_CASE(maxpool_test)
TEST_CASE(softmax_test) TEST_CASE(softmax_test)
{ {
migraph::program p; migraphx::program p;
std::vector<float> a = { std::vector<float> a = {
-5.61869681e-01, 9.07827199e-01, 1.29255986e+00, 3.18533443e-02, -1.22183852e-03, -5.61869681e-01, 9.07827199e-01, 1.29255986e+00, 3.18533443e-02, -1.22183852e-03,
-2.83830553e-01, -1.03245842e+00, -9.28322077e-01, -8.82696748e-01, 1.11327164e-01, -2.83830553e-01, -1.03245842e+00, -9.28322077e-01, -8.82696748e-01, 1.11327164e-01,
...@@ -846,19 +846,19 @@ TEST_CASE(softmax_test) ...@@ -846,19 +846,19 @@ TEST_CASE(softmax_test)
0.17377149, 0.76075399, 0.20071237, 0.32632929, 0.36892858, 0.09416146, 0.26656723, 0.17377149, 0.76075399, 0.20071237, 0.32632929, 0.36892858, 0.09416146, 0.26656723,
0.42914796}; 0.42914796};
migraph::shape a_shape{migraph::shape::float_type, {5, 3, 4, 2}}; migraphx::shape a_shape{migraphx::shape::float_type, {5, 3, 4, 2}};
auto al = p.add_literal(migraph::literal{a_shape, a}); auto al = p.add_literal(migraphx::literal{a_shape, a});
p.add_instruction(migraph::op::softmax{}, al); p.add_instruction(migraphx::op::softmax{}, al);
p.compile(migraph::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(120); std::vector<float> results_vector(120);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraph::verify_range(results_vector, s)); EXPECT(migraphx::verify_range(results_vector, s));
} }
TEST_CASE(conv2d_test) TEST_CASE(conv2d_test)
{ {
migraph::program p; migraphx::program p;
std::vector<float> a = { std::vector<float> a = {
2.71567607, -0.9960829, 0.91671127, 0.28140706, 0.63235772, 0.08077253, 0.80927712, 2.71567607, -0.9960829, 0.91671127, 0.28140706, 0.63235772, 0.08077253, 0.80927712,
-0.59108931, -1.05421555, -2.76622486, -0.85044265, -0.52049929, 0.67726439, -0.65290606, -0.59108931, -1.05421555, -2.76622486, -0.85044265, -0.52049929, 0.67726439, -0.65290606,
...@@ -904,24 +904,24 @@ TEST_CASE(conv2d_test) ...@@ -904,24 +904,24 @@ TEST_CASE(conv2d_test)
0.71606487, 0.71606487,
-0.55201721, -0.55201721,
-0.46427044}; -0.46427044};
migraph::shape a_shape{migraph::shape::float_type, {2, 3, 4, 4}}; migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 4, 4}};
auto al = p.add_literal(migraph::literal{a_shape, a}); auto al = p.add_literal(migraphx::literal{a_shape, a});
migraph::shape c_shape{migraph::shape::float_type, {2, 3, 3, 3}}; migraphx::shape c_shape{migraphx::shape::float_type, {2, 3, 3, 3}};
auto cl = p.add_literal(migraph::literal{c_shape, c}); auto cl = p.add_literal(migraphx::literal{c_shape, c});
p.add_instruction(migraph::op::convolution{}, al, cl); p.add_instruction(migraphx::op::convolution{}, al, cl);
p.compile(migraph::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(16); std::vector<float> results_vector(16);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraph::verify_range(results_vector, s)); EXPECT(migraphx::verify_range(results_vector, s));
} }
TEST_CASE(conv2d_padding_test) TEST_CASE(conv2d_padding_test)
{ {
migraph::program p; migraphx::program p;
std::vector<float> a = { std::vector<float> a = {
2.71567607, -0.9960829, 0.91671127, 0.28140706, 0.63235772, 0.08077253, 0.80927712, 2.71567607, -0.9960829, 0.91671127, 0.28140706, 0.63235772, 0.08077253, 0.80927712,
-0.59108931, -1.05421555, -2.76622486, -0.85044265, -0.52049929, 0.67726439, -0.65290606, -0.59108931, -1.05421555, -2.76622486, -0.85044265, -0.52049929, 0.67726439, -0.65290606,
...@@ -960,24 +960,24 @@ TEST_CASE(conv2d_padding_test) ...@@ -960,24 +960,24 @@ TEST_CASE(conv2d_padding_test)
-0.20369984, -0.83037728, -1.40423918, -0.46160448, -0.22944322, 0.36074194, 0.49579027, -0.20369984, -0.83037728, -1.40423918, -0.46160448, -0.22944322, 0.36074194, 0.49579027,
0.46527559}; 0.46527559};
migraph::shape a_shape{migraph::shape::float_type, {2, 3, 4, 4}}; migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 4, 4}};
auto al = p.add_literal(migraph::literal{a_shape, a}); auto al = p.add_literal(migraphx::literal{a_shape, a});
migraph::shape c_shape{migraph::shape::float_type, {2, 3, 3, 3}}; migraphx::shape c_shape{migraphx::shape::float_type, {2, 3, 3, 3}};
auto cl = p.add_literal(migraph::literal{c_shape, c}); auto cl = p.add_literal(migraphx::literal{c_shape, c});
p.add_instruction(migraph::op::convolution{{{1, 1}}, {{1, 1}}}, al, cl); p.add_instruction(migraphx::op::convolution{{{1, 1}}, {{1, 1}}}, al, cl);
p.compile(migraph::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(64); std::vector<float> results_vector(64);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraph::verify_range(results_vector, s)); EXPECT(migraphx::verify_range(results_vector, s));
} }
TEST_CASE(conv2d_padding_stride_test) TEST_CASE(conv2d_padding_stride_test)
{ {
migraph::program p; migraphx::program p;
std::vector<float> a = { std::vector<float> a = {
2.71567607, -0.9960829, 0.91671127, 0.28140706, 0.63235772, 0.08077253, 0.80927712, 2.71567607, -0.9960829, 0.91671127, 0.28140706, 0.63235772, 0.08077253, 0.80927712,
-0.59108931, -1.05421555, -2.76622486, -0.85044265, -0.52049929, 0.67726439, -0.65290606, -0.59108931, -1.05421555, -2.76622486, -0.85044265, -0.52049929, 0.67726439, -0.65290606,
...@@ -1021,33 +1021,33 @@ TEST_CASE(conv2d_padding_stride_test) ...@@ -1021,33 +1021,33 @@ TEST_CASE(conv2d_padding_stride_test)
-0.16138598, -0.16138598,
0.79344082}; 0.79344082};
migraph::shape a_shape{migraph::shape::float_type, {2, 3, 4, 4}}; migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 4, 4}};
auto al = p.add_literal(migraph::literal{a_shape, a}); auto al = p.add_literal(migraphx::literal{a_shape, a});
migraph::shape c_shape{migraph::shape::float_type, {2, 3, 3, 3}}; migraphx::shape c_shape{migraphx::shape::float_type, {2, 3, 3, 3}};
auto cl = p.add_literal(migraph::literal{c_shape, c}); auto cl = p.add_literal(migraphx::literal{c_shape, c});
p.add_instruction(migraph::op::convolution{{{1, 1}}, {{2, 2}}}, al, cl); p.add_instruction(migraphx::op::convolution{{{1, 1}}, {{2, 2}}}, al, cl);
p.compile(migraph::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(16); std::vector<float> results_vector(16);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraph::verify_range(results_vector, s)); EXPECT(migraphx::verify_range(results_vector, s));
} }
TEST_CASE(transpose_test) TEST_CASE(transpose_test)
{ {
migraph::shape a_shape{migraph::shape::float_type, {1, 2, 2, 3}}; migraphx::shape a_shape{migraphx::shape::float_type, {1, 2, 2, 3}};
std::vector<float> data(12); std::vector<float> data(12);
std::iota(data.begin(), data.end(), 0); std::iota(data.begin(), data.end(), 0);
{ {
migraph::program p; migraphx::program p;
auto l = p.add_literal(migraph::literal{a_shape, data}); auto l = p.add_literal(migraphx::literal{a_shape, data});
std::vector<int64_t> perm = {0, 3, 1, 2}; std::vector<int64_t> perm = {0, 3, 1, 2};
p.add_instruction(migraph::op::transpose{perm}, l); p.add_instruction(migraphx::op::transpose{perm}, l);
p.compile(migraph::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
result.visit([&](auto output) { result.visit([&](auto output) {
...@@ -1056,31 +1056,31 @@ TEST_CASE(transpose_test) ...@@ -1056,31 +1056,31 @@ TEST_CASE(transpose_test)
}); });
} }
{ {
migraph::program p; migraphx::program p;
auto l = p.add_literal(migraph::literal{a_shape, data}); auto l = p.add_literal(migraphx::literal{a_shape, data});
std::vector<int64_t> perm = {0, 3, 1, 2}; std::vector<int64_t> perm = {0, 3, 1, 2};
auto result = p.add_instruction(migraph::op::transpose{perm}, l); auto result = p.add_instruction(migraphx::op::transpose{perm}, l);
p.add_instruction(migraph::op::contiguous{}, result); p.add_instruction(migraphx::op::contiguous{}, result);
p.compile(migraph::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result2 = p.eval({}); auto result2 = p.eval({});
std::vector<float> results_vector(12); std::vector<float> results_vector(12);
result2.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result2.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11}; std::vector<float> gold = {0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11};
EXPECT(migraph::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
} }
TEST_CASE(contiguous_test) TEST_CASE(contiguous_test)
{ {
migraph::shape a_shape{migraph::shape::float_type, {1, 3, 2, 2}, {12, 1, 6, 3}}; migraphx::shape a_shape{migraphx::shape::float_type, {1, 3, 2, 2}, {12, 1, 6, 3}};
std::vector<float> data(12); std::vector<float> data(12);
std::iota(data.begin(), data.end(), 0); std::iota(data.begin(), data.end(), 0);
migraph::program p; migraphx::program p;
auto l = p.add_literal(migraph::literal{a_shape, data}); auto l = p.add_literal(migraphx::literal{a_shape, data});
p.add_instruction(migraph::op::contiguous{}, l); p.add_instruction(migraphx::op::contiguous{}, l);
p.compile(migraph::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(12); std::vector<float> results_vector(12);
...@@ -1088,7 +1088,7 @@ TEST_CASE(contiguous_test) ...@@ -1088,7 +1088,7 @@ TEST_CASE(contiguous_test)
std::vector<size_t> new_lens = {1, 3, 2, 2}; std::vector<size_t> new_lens = {1, 3, 2, 2};
std::vector<size_t> new_strides = {12, 1, 6, 3}; std::vector<size_t> new_strides = {12, 1, 6, 3};
std::vector<float> gold = {0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11}; std::vector<float> gold = {0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11};
EXPECT(migraph::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(identity_test) TEST_CASE(identity_test)
......
#include <migraph/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
#include <test.hpp> #include <test.hpp>
struct dce_target struct dce_target
{ {
std::string name() const { return "dce"; } std::string name() const { return "dce"; }
std::vector<migraph::pass> get_passes(migraph::context&) const std::vector<migraphx::pass> get_passes(migraphx::context&) const
{ {
return {migraph::dead_code_elimination{}}; return {migraphx::dead_code_elimination{}};
} }
migraph::context get_context() const { return {}; } migraphx::context get_context() const { return {}; }
}; };
TEST_CASE(simple_test) TEST_CASE(simple_test)
{ {
migraph::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
...@@ -23,13 +23,13 @@ TEST_CASE(simple_test) ...@@ -23,13 +23,13 @@ TEST_CASE(simple_test)
p.compile(dce_target{}); p.compile(dce_target{});
EXPECT(std::distance(p.begin(), p.end()) == count); EXPECT(std::distance(p.begin(), p.end()) == count);
auto result = p.eval({}); auto result = p.eval({});
EXPECT(result == migraph::literal{3}); EXPECT(result == migraphx::literal{3});
EXPECT(result != migraph::literal{4}); EXPECT(result != migraphx::literal{4});
} }
TEST_CASE(simple_test_nop) TEST_CASE(simple_test_nop)
{ {
migraph::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
...@@ -39,13 +39,13 @@ TEST_CASE(simple_test_nop) ...@@ -39,13 +39,13 @@ TEST_CASE(simple_test_nop)
p.compile(dce_target{}); p.compile(dce_target{});
EXPECT(std::distance(p.begin(), p.end()) == count); EXPECT(std::distance(p.begin(), p.end()) == count);
auto result = p.eval({}); auto result = p.eval({});
EXPECT(result == migraph::literal{3}); EXPECT(result == migraphx::literal{3});
EXPECT(result != migraph::literal{4}); EXPECT(result != migraphx::literal{4});
} }
TEST_CASE(simple_test_nop2) TEST_CASE(simple_test_nop2)
{ {
migraph::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
...@@ -55,13 +55,13 @@ TEST_CASE(simple_test_nop2) ...@@ -55,13 +55,13 @@ TEST_CASE(simple_test_nop2)
p.compile(dce_target{}); p.compile(dce_target{});
EXPECT(std::distance(p.begin(), p.end()) == 2); EXPECT(std::distance(p.begin(), p.end()) == 2);
auto result = p.eval({}); auto result = p.eval({});
EXPECT(result == migraph::literal{}); EXPECT(result == migraphx::literal{});
EXPECT(result != migraph::literal{4}); EXPECT(result != migraphx::literal{4});
} }
TEST_CASE(duplicate_test1) TEST_CASE(duplicate_test1)
{ {
migraph::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
...@@ -71,13 +71,13 @@ TEST_CASE(duplicate_test1) ...@@ -71,13 +71,13 @@ TEST_CASE(duplicate_test1)
p.compile(dce_target{}); p.compile(dce_target{});
EXPECT(std::distance(p.begin(), p.end()) == (count - 1)); EXPECT(std::distance(p.begin(), p.end()) == (count - 1));
auto result = p.eval({}); auto result = p.eval({});
EXPECT(result == migraph::literal{3}); EXPECT(result == migraphx::literal{3});
EXPECT(result != migraph::literal{4}); EXPECT(result != migraphx::literal{4});
} }
TEST_CASE(duplicate_test2) TEST_CASE(duplicate_test2)
{ {
migraph::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
...@@ -88,13 +88,13 @@ TEST_CASE(duplicate_test2) ...@@ -88,13 +88,13 @@ TEST_CASE(duplicate_test2)
p.compile(dce_target{}); p.compile(dce_target{});
EXPECT(std::distance(p.begin(), p.end()) == (count - 2)); EXPECT(std::distance(p.begin(), p.end()) == (count - 2));
auto result = p.eval({}); auto result = p.eval({});
EXPECT(result == migraph::literal{3}); EXPECT(result == migraphx::literal{3});
EXPECT(result != migraph::literal{4}); EXPECT(result != migraphx::literal{4});
} }
TEST_CASE(depth_test) TEST_CASE(depth_test)
{ {
migraph::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
...@@ -107,8 +107,8 @@ TEST_CASE(depth_test) ...@@ -107,8 +107,8 @@ TEST_CASE(depth_test)
p.compile(dce_target{}); p.compile(dce_target{});
EXPECT(std::distance(p.begin(), p.end()) == (count - 4)); EXPECT(std::distance(p.begin(), p.end()) == (count - 4));
auto result = p.eval({}); auto result = p.eval({});
EXPECT(result == migraph::literal{3}); EXPECT(result == migraphx::literal{3});
EXPECT(result != migraph::literal{4}); EXPECT(result != migraphx::literal{4});
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <migraph/eliminate_allocation.hpp> #include <migraphx/eliminate_allocation.hpp>
#include <migraph/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraph/operators.hpp> #include <migraphx/operators.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
#include <test.hpp> #include <test.hpp>
...@@ -8,25 +8,26 @@ struct eliminate_allocation_target ...@@ -8,25 +8,26 @@ struct eliminate_allocation_target
{ {
std::size_t align = 32; std::size_t align = 32;
std::string name() const { return "eliminate_allocation"; } std::string name() const { return "eliminate_allocation"; }
std::vector<migraph::pass> get_passes(migraph::context&) const std::vector<migraphx::pass> get_passes(migraphx::context&) const
{ {
return {migraph::eliminate_allocation{"allocate", align}, migraph::dead_code_elimination{}}; return {migraphx::eliminate_allocation{"allocate", align},
migraphx::dead_code_elimination{}};
} }
migraph::context get_context() const { return {}; } migraphx::context get_context() const { return {}; }
}; };
struct allocate struct allocate
{ {
migraph::shape s{}; migraphx::shape s{};
std::string name() const { return "allocate"; } std::string name() const { return "allocate"; }
migraph::shape compute_shape(const std::vector<migraph::shape>& inputs) const migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const
{ {
migraph::check_shapes{inputs}.has(0); migraphx::check_shapes{inputs}.has(0);
return s; return s;
} }
migraph::argument compute(migraph::context&, migraphx::argument compute(migraphx::context&,
const migraph::shape& output_shape, const migraphx::shape& output_shape,
const std::vector<migraph::argument>&) const const std::vector<migraphx::argument>&) const
{ {
return {output_shape}; return {output_shape};
} }
...@@ -34,69 +35,69 @@ struct allocate ...@@ -34,69 +35,69 @@ struct allocate
TEST_CASE(basic) TEST_CASE(basic)
{ {
migraph::program p; migraphx::program p;
auto a1 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {8}}}); auto a1 = p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {8}}});
auto p1 = p.add_instruction(pass_op{}, a1); auto p1 = p.add_instruction(pass_op{}, a1);
auto a2 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {40}}}); auto a2 = p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {40}}});
auto p2 = p.add_instruction(pass_op{}, a2, p1); auto p2 = p.add_instruction(pass_op{}, a2, p1);
auto a3 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {200}}}); auto a3 = p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {200}}});
p.add_instruction(pass_op{}, a3, p2); p.add_instruction(pass_op{}, a3, p2);
p.compile(eliminate_allocation_target{}); p.compile(eliminate_allocation_target{});
EXPECT(p.get_shape() == migraph::shape{migraph::shape::float_type, {200}}); EXPECT(p.get_shape() == migraphx::shape{migraphx::shape::float_type, {200}});
EXPECT(p.get_parameter_shape("memory").bytes() == (8 * 4 + 40 * 4 + 200 * 4)); EXPECT(p.get_parameter_shape("memory").bytes() == (8 * 4 + 40 * 4 + 200 * 4));
} }
TEST_CASE(aligned) TEST_CASE(aligned)
{ {
migraph::program p; migraphx::program p;
auto a1 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1}}}); auto a1 = p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1}}});
auto p1 = p.add_instruction(pass_op{}, a1); auto p1 = p.add_instruction(pass_op{}, a1);
auto a2 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2}}}); auto a2 = p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2}}});
auto p2 = p.add_instruction(pass_op{}, a2, p1); auto p2 = p.add_instruction(pass_op{}, a2, p1);
auto a3 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {200}}}); auto a3 = p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {200}}});
p.add_instruction(pass_op{}, a3, p2); p.add_instruction(pass_op{}, a3, p2);
p.compile(eliminate_allocation_target{}); p.compile(eliminate_allocation_target{});
EXPECT(p.get_shape() == migraph::shape{migraph::shape::float_type, {200}}); EXPECT(p.get_shape() == migraphx::shape{migraphx::shape::float_type, {200}});
EXPECT(p.get_parameter_shape("memory").bytes() == (32 + 32 + 200 * 4)); EXPECT(p.get_parameter_shape("memory").bytes() == (32 + 32 + 200 * 4));
} }
TEST_CASE(unaligned) TEST_CASE(unaligned)
{ {
migraph::program p; migraphx::program p;
auto a1 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1}}}); auto a1 = p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1}}});
auto p1 = p.add_instruction(pass_op{}, a1); auto p1 = p.add_instruction(pass_op{}, a1);
auto a2 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2}}}); auto a2 = p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2}}});
auto p2 = p.add_instruction(pass_op{}, a2, p1); auto p2 = p.add_instruction(pass_op{}, a2, p1);
auto a3 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {200}}}); auto a3 = p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {200}}});
p.add_instruction(pass_op{}, a3, p2); p.add_instruction(pass_op{}, a3, p2);
p.compile(eliminate_allocation_target{1}); p.compile(eliminate_allocation_target{1});
EXPECT(p.get_shape() == migraph::shape{migraph::shape::float_type, {200}}); EXPECT(p.get_shape() == migraphx::shape{migraphx::shape::float_type, {200}});
EXPECT(p.get_parameter_shape("memory").bytes() == (1 * 4 + 2 * 4 + 200 * 4)); EXPECT(p.get_parameter_shape("memory").bytes() == (1 * 4 + 2 * 4 + 200 * 4));
} }
TEST_CASE(float_aligned) TEST_CASE(float_aligned)
{ {
migraph::program p; migraphx::program p;
auto a1 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1}}}); auto a1 = p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1}}});
auto p1 = p.add_instruction(pass_op{}, a1); auto p1 = p.add_instruction(pass_op{}, a1);
auto a2 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2}}}); auto a2 = p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2}}});
auto p2 = p.add_instruction(pass_op{}, a2, p1); auto p2 = p.add_instruction(pass_op{}, a2, p1);
auto a3 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {200}}}); auto a3 = p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {200}}});
p.add_instruction(pass_op{}, a3, p2); p.add_instruction(pass_op{}, a3, p2);
p.compile(eliminate_allocation_target{4}); p.compile(eliminate_allocation_target{4});
EXPECT(p.get_shape() == migraph::shape{migraph::shape::float_type, {200}}); EXPECT(p.get_shape() == migraphx::shape{migraphx::shape::float_type, {200}});
EXPECT(p.get_parameter_shape("memory").bytes() == (1 * 4 + 2 * 4 + 200 * 4)); EXPECT(p.get_parameter_shape("memory").bytes() == (1 * 4 + 2 * 4 + 200 * 4));
} }
......
#include <migraph/eliminate_concat.hpp> #include <migraphx/eliminate_concat.hpp>
#include <migraph/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraph/operators.hpp> #include <migraphx/operators.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
#include <test.hpp> #include <test.hpp>
struct concat struct concat
{ {
concat(std::size_t axis) { op.axis = axis; } concat(std::size_t axis) { op.axis = axis; }
migraph::op::concat op; migraphx::op::concat op;
std::string name() const { return "eliminate_concat::concat"; } std::string name() const { return "eliminate_concat::concat"; }
migraph::shape compute_shape(std::vector<migraph::shape> inputs) const migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
{ {
return op.compute_shape(std::move(inputs)); return op.compute_shape(std::move(inputs));
} }
migraph::argument compute(migraph::context&, migraphx::argument compute(migraphx::context&,
const migraph::shape& output_shape, const migraphx::shape& output_shape,
const std::vector<migraph::argument>&) const const std::vector<migraphx::argument>&) const
{ {
return {output_shape}; return {output_shape};
} }
...@@ -28,9 +28,9 @@ struct concat_test_optimization ...@@ -28,9 +28,9 @@ struct concat_test_optimization
/// A unique name used to identify the allocate operator /// A unique name used to identify the allocate operator
std::string allocate() const { return "allocate"; } std::string allocate() const { return "allocate"; }
/// Return the lowered concat operator /// Return the lowered concat operator
migraph::op::concat get_concat(const migraph::operation& op) const migraphx::op::concat get_concat(const migraphx::operation& op) const
{ {
return migraph::any_cast<concat>(op).op; return migraphx::any_cast<concat>(op).op;
} }
}; };
...@@ -38,26 +38,26 @@ struct eliminate_concat_target ...@@ -38,26 +38,26 @@ struct eliminate_concat_target
{ {
std::size_t align = 32; std::size_t align = 32;
std::string name() const { return "eliminate_target"; } std::string name() const { return "eliminate_target"; }
std::vector<migraph::pass> get_passes(migraph::context&) const std::vector<migraphx::pass> get_passes(migraphx::context&) const
{ {
return {migraph::eliminate_concat{concat_test_optimization{}}, return {migraphx::eliminate_concat{concat_test_optimization{}},
migraph::dead_code_elimination{}}; migraphx::dead_code_elimination{}};
} }
migraph::context get_context() const { return {}; } migraphx::context get_context() const { return {}; }
}; };
struct allocate struct allocate
{ {
migraph::shape s{}; migraphx::shape s{};
std::string name() const { return "allocate"; } std::string name() const { return "allocate"; }
migraph::shape compute_shape(const std::vector<migraph::shape>& inputs) const migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const
{ {
migraph::check_shapes{inputs}.has(0); migraphx::check_shapes{inputs}.has(0);
return s; return s;
} }
migraph::argument compute(migraph::context&, migraphx::argument compute(migraphx::context&,
const migraph::shape& output_shape, const migraphx::shape& output_shape,
const std::vector<migraph::argument>&) const const std::vector<migraphx::argument>&) const
{ {
return {output_shape}; return {output_shape};
} }
...@@ -66,14 +66,14 @@ struct allocate ...@@ -66,14 +66,14 @@ struct allocate
struct fred_op struct fred_op
{ {
std::string name() const { return "fred_op"; } std::string name() const { return "fred_op"; }
migraph::shape compute_shape(const std::vector<migraph::shape>& inputs) const migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const
{ {
migraph::check_shapes{inputs}.has(1); migraphx::check_shapes{inputs}.has(1);
return inputs.at(0); return inputs.at(0);
} }
migraph::argument compute(migraph::context&, migraphx::argument compute(migraphx::context&,
const migraph::shape&, const migraphx::shape&,
const std::vector<migraph::argument>& args) const const std::vector<migraphx::argument>& args) const
{ {
return args.at(0); return args.at(0);
} }
...@@ -82,37 +82,39 @@ struct fred_op ...@@ -82,37 +82,39 @@ struct fred_op
TEST_CASE(basic) TEST_CASE(basic)
{ {
auto create_test_program = []() { auto create_test_program = []() {
migraph::program p; migraphx::program p;
auto a1 = auto a1 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1, 2, 8, 8}}}); p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1, 2, 8, 8}}});
auto p1 = p.add_instruction(fred_op{}, a1); auto p1 = p.add_instruction(fred_op{}, a1);
auto a2 = auto a2 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1, 3, 8, 8}}}); p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1, 3, 8, 8}}});
auto p2 = p.add_instruction(fred_op{}, a2); auto p2 = p.add_instruction(fred_op{}, a2);
auto a3 = auto a3 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1, 5, 8, 8}}}); p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1, 5, 8, 8}}});
auto p3 = p.add_instruction(fred_op{}, a3); auto p3 = p.add_instruction(fred_op{}, a3);
std::size_t axis = 1; std::size_t axis = 1;
auto a4 = auto a4 = p.add_instruction(
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1, 10, 8, 8}}}); allocate{migraphx::shape{migraphx::shape::float_type, {1, 10, 8, 8}}});
p.add_instruction(concat(axis), p1, p2, p3, a4); p.add_instruction(concat(axis), p1, p2, p3, a4);
return p; return p;
}; };
auto create_control_program = []() { auto create_control_program = []() {
migraph::program p; migraphx::program p;
auto a1 = auto a1 = p.add_instruction(
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1, 10, 8, 8}}}); allocate{migraphx::shape{migraphx::shape::float_type, {1, 10, 8, 8}}});
auto l1 = p.add_instruction( auto l1 = p.add_instruction(
migraph::op::load{migraph::shape{migraph::shape::float_type, {1, 2, 8, 8}}, 0}, {a1}); migraphx::op::load{migraphx::shape{migraphx::shape::float_type, {1, 2, 8, 8}}, 0},
{a1});
auto p1 = p.add_instruction(fred_op{}, l1); auto p1 = p.add_instruction(fred_op{}, l1);
auto l2 = p.add_instruction( auto l2 = p.add_instruction(
migraph::op::load{migraph::shape{migraph::shape::float_type, {1, 3, 8, 8}}, 512}, {a1}); migraphx::op::load{migraphx::shape{migraphx::shape::float_type, {1, 3, 8, 8}}, 512},
{a1});
auto p2 = p.add_instruction(fred_op{}, l2); auto p2 = p.add_instruction(fred_op{}, l2);
auto l3 = p.add_instruction( auto l3 = p.add_instruction(
migraph::op::load{migraph::shape{migraph::shape::float_type, {1, 5, 8, 8}}, 1280}, migraphx::op::load{migraphx::shape{migraphx::shape::float_type, {1, 5, 8, 8}}, 1280},
{a1}); {a1});
auto p3 = p.add_instruction(fred_op{}, l3); auto p3 = p.add_instruction(fred_op{}, l3);
p.add_instruction(migraph::op::identity{}, {a1, p1, p2, p3}); p.add_instruction(migraphx::op::identity{}, {a1, p1, p2, p3});
return p; return p;
}; };
...@@ -126,36 +128,36 @@ TEST_CASE(basic) ...@@ -126,36 +128,36 @@ TEST_CASE(basic)
TEST_CASE(wont_work) TEST_CASE(wont_work)
{ {
auto create_test_program = []() { auto create_test_program = []() {
migraph::program p; migraphx::program p;
auto a1 = auto a1 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2, 2, 8, 8}}}); p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 2, 8, 8}}});
auto p1 = p.add_instruction(fred_op{}, a1); auto p1 = p.add_instruction(fred_op{}, a1);
auto a2 = auto a2 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2, 3, 8, 8}}}); p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 3, 8, 8}}});
auto p2 = p.add_instruction(fred_op{}, a2); auto p2 = p.add_instruction(fred_op{}, a2);
auto a3 = auto a3 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2, 5, 8, 8}}}); p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 5, 8, 8}}});
auto p3 = p.add_instruction(fred_op{}, a3); auto p3 = p.add_instruction(fred_op{}, a3);
std::size_t axis = 1; std::size_t axis = 1;
auto a4 = auto a4 = p.add_instruction(
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2, 10, 8, 8}}}); allocate{migraphx::shape{migraphx::shape::float_type, {2, 10, 8, 8}}});
p.add_instruction(concat(axis), p1, p2, p3, a4); p.add_instruction(concat(axis), p1, p2, p3, a4);
return p; return p;
}; };
auto create_control_program = []() { auto create_control_program = []() {
migraph::program p; migraphx::program p;
auto a1 = auto a1 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2, 2, 8, 8}}}); p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 2, 8, 8}}});
auto p1 = p.add_instruction(fred_op{}, a1); auto p1 = p.add_instruction(fred_op{}, a1);
auto a2 = auto a2 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2, 3, 8, 8}}}); p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 3, 8, 8}}});
auto p2 = p.add_instruction(fred_op{}, a2); auto p2 = p.add_instruction(fred_op{}, a2);
auto a3 = auto a3 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2, 5, 8, 8}}}); p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 5, 8, 8}}});
auto p3 = p.add_instruction(fred_op{}, a3); auto p3 = p.add_instruction(fred_op{}, a3);
std::size_t axis = 1; std::size_t axis = 1;
auto a4 = auto a4 = p.add_instruction(
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2, 10, 8, 8}}}); allocate{migraphx::shape{migraphx::shape::float_type, {2, 10, 8, 8}}});
p.add_instruction(concat(axis), p1, p2, p3, a4); p.add_instruction(concat(axis), p1, p2, p3, a4);
return p; return p;
}; };
......
#include <migraph/eliminate_contiguous.hpp> #include <migraphx/eliminate_contiguous.hpp>
#include <migraph/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraph/operators.hpp> #include <migraphx/operators.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
#include <test.hpp> #include <test.hpp>
struct eliminate_contiguous_target struct eliminate_contiguous_target
{ {
std::string name() const { return "eliminate_contiguous"; } std::string name() const { return "eliminate_contiguous"; }
std::vector<migraph::pass> get_passes(migraph::context&) const std::vector<migraphx::pass> get_passes(migraphx::context&) const
{ {
return {migraph::eliminate_contiguous{}, migraph::dead_code_elimination{}}; return {migraphx::eliminate_contiguous{}, migraphx::dead_code_elimination{}};
} }
migraph::context get_context() const { return {}; } migraphx::context get_context() const { return {}; }
}; };
TEST_CASE(standard_op) TEST_CASE(standard_op)
{ {
migraph::program p; migraphx::program p;
auto l = p.add_literal(get_2x2()); auto l = p.add_literal(get_2x2());
auto t = p.add_instruction(migraph::op::transpose{{1, 0}}, l); auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = p.add_instruction(migraph::op::contiguous{}, t); auto c = p.add_instruction(migraphx::op::contiguous{}, t);
p.add_instruction(pass_standard_op{}, c); p.add_instruction(pass_standard_op{}, c);
auto count = std::distance(p.begin(), p.end()); auto count = std::distance(p.begin(), p.end());
p.compile(eliminate_contiguous_target{}); p.compile(eliminate_contiguous_target{});
...@@ -28,10 +28,10 @@ TEST_CASE(standard_op) ...@@ -28,10 +28,10 @@ TEST_CASE(standard_op)
TEST_CASE(non_standard_op) TEST_CASE(non_standard_op)
{ {
migraph::program p; migraphx::program p;
auto l = p.add_literal(get_2x2()); auto l = p.add_literal(get_2x2());
auto t = p.add_instruction(migraph::op::transpose{{1, 0}}, l); auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = p.add_instruction(migraph::op::contiguous{}, t); auto c = p.add_instruction(migraphx::op::contiguous{}, t);
p.add_instruction(pass_op{}, c); p.add_instruction(pass_op{}, c);
auto count = std::distance(p.begin(), p.end()); auto count = std::distance(p.begin(), p.end());
p.compile(eliminate_contiguous_target{}); p.compile(eliminate_contiguous_target{});
......
#include <migraph/program.hpp> #include <migraphx/program.hpp>
#include <migraph/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraph/instruction.hpp> #include <migraphx/instruction.hpp>
#include <sstream> #include <sstream>
#include "test.hpp" #include "test.hpp"
#include <basic_ops.hpp> #include <basic_ops.hpp>
...@@ -9,17 +9,17 @@ ...@@ -9,17 +9,17 @@
struct id_target struct id_target
{ {
std::string name() const { return "id"; } std::string name() const { return "id"; }
std::vector<migraph::pass> get_passes(migraph::context&) const { return {}; } std::vector<migraphx::pass> get_passes(migraphx::context&) const { return {}; }
migraph::context get_context() const { return {}; } migraphx::context get_context() const { return {}; }
}; };
struct reverse_pass struct reverse_pass
{ {
std::string name() const { return "reverse_pass"; } std::string name() const { return "reverse_pass"; }
void apply(migraph::program& p) const void apply(migraphx::program& p) const
{ {
for(auto ins : migraph::iterator_for(p)) for(auto ins : migraphx::iterator_for(p))
{ {
if(ins->name() == "sum") if(ins->name() == "sum")
{ {
...@@ -36,35 +36,35 @@ struct reverse_pass ...@@ -36,35 +36,35 @@ struct reverse_pass
struct reverse_target struct reverse_target
{ {
std::string name() const { return "reverse"; } std::string name() const { return "reverse"; }
std::vector<migraph::pass> get_passes(migraph::context&) const { return {reverse_pass{}}; } std::vector<migraphx::pass> get_passes(migraphx::context&) const { return {reverse_pass{}}; }
migraph::context get_context() const { return {}; } migraphx::context get_context() const { return {}; }
}; };
struct double_reverse_target struct double_reverse_target
{ {
std::string name() const { return "double_reverse"; } std::string name() const { return "double_reverse"; }
std::vector<migraph::pass> get_passes(migraph::context&) const std::vector<migraphx::pass> get_passes(migraphx::context&) const
{ {
return {reverse_pass{}, reverse_pass{}}; return {reverse_pass{}, reverse_pass{}};
} }
migraph::context get_context() const { return {}; } migraphx::context get_context() const { return {}; }
}; };
TEST_CASE(literal_test1) TEST_CASE(literal_test1)
{ {
migraph::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
p.add_instruction(sum_op{}, one, two); p.add_instruction(sum_op{}, one, two);
auto result = p.eval({}); auto result = p.eval({});
EXPECT(result == migraph::literal{3}); EXPECT(result == migraphx::literal{3});
EXPECT(result != migraph::literal{4}); EXPECT(result != migraphx::literal{4});
} }
TEST_CASE(literal_test2) TEST_CASE(literal_test2)
{ {
migraph::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
...@@ -72,15 +72,15 @@ TEST_CASE(literal_test2) ...@@ -72,15 +72,15 @@ TEST_CASE(literal_test2)
p.add_instruction(sum_op{}, sum1, two); p.add_instruction(sum_op{}, sum1, two);
auto result = p.eval({}); auto result = p.eval({});
EXPECT(result == migraph::literal{5}); EXPECT(result == migraphx::literal{5});
EXPECT(result != migraph::literal{3}); EXPECT(result != migraphx::literal{3});
} }
TEST_CASE(print_test) TEST_CASE(print_test)
{ {
migraph::program p; migraphx::program p;
auto x = p.add_parameter("x", {migraph::shape::int64_type}); auto x = p.add_parameter("x", {migraphx::shape::int64_type});
auto two = p.add_literal(2); auto two = p.add_literal(2);
p.add_instruction(sum_op{}, x, two); p.add_instruction(sum_op{}, x, two);
...@@ -92,36 +92,36 @@ TEST_CASE(print_test) ...@@ -92,36 +92,36 @@ TEST_CASE(print_test)
TEST_CASE(param_test) TEST_CASE(param_test)
{ {
migraph::program p; migraphx::program p;
auto x = p.add_parameter("x", {migraph::shape::int64_type}); auto x = p.add_parameter("x", {migraphx::shape::int64_type});
auto y = p.add_parameter("y", {migraph::shape::int64_type}); auto y = p.add_parameter("y", {migraphx::shape::int64_type});
p.add_instruction(sum_op{}, x, y); p.add_instruction(sum_op{}, x, y);
auto result = p.eval( auto result = p.eval(
{{"x", migraph::literal{1}.get_argument()}, {"y", migraph::literal{2}.get_argument()}}); {{"x", migraphx::literal{1}.get_argument()}, {"y", migraphx::literal{2}.get_argument()}});
EXPECT(result == migraph::literal{3}); EXPECT(result == migraphx::literal{3});
EXPECT(result != migraph::literal{4}); EXPECT(result != migraphx::literal{4});
} }
TEST_CASE(param_error_test) TEST_CASE(param_error_test)
{ {
migraph::program p; migraphx::program p;
auto x = p.add_parameter("x", {migraph::shape::int64_type}); auto x = p.add_parameter("x", {migraphx::shape::int64_type});
auto y = p.add_parameter("y", {migraph::shape::int64_type}); auto y = p.add_parameter("y", {migraphx::shape::int64_type});
p.add_instruction(sum_op{}, x, y); p.add_instruction(sum_op{}, x, y);
EXPECT(test::throws<migraph::exception>( EXPECT(test::throws<migraphx::exception>(
[&] { [&] {
p.eval({{"x", migraph::literal{1}.get_argument()}}); p.eval({{"x", migraphx::literal{1}.get_argument()}});
}, },
"Parameter not found: y")); "Parameter not found: y"));
} }
TEST_CASE(replace_test) TEST_CASE(replace_test)
{ {
migraph::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
...@@ -130,13 +130,13 @@ TEST_CASE(replace_test) ...@@ -130,13 +130,13 @@ TEST_CASE(replace_test)
EXPECT(bool{p.validate() == p.end()}); EXPECT(bool{p.validate() == p.end()});
auto result = p.eval({}); auto result = p.eval({});
EXPECT(result == migraph::literal{1}); EXPECT(result == migraphx::literal{1});
EXPECT(result != migraph::literal{3}); EXPECT(result != migraphx::literal{3});
} }
TEST_CASE(replace_ins_test) TEST_CASE(replace_ins_test)
{ {
migraph::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
...@@ -146,13 +146,13 @@ TEST_CASE(replace_ins_test) ...@@ -146,13 +146,13 @@ TEST_CASE(replace_ins_test)
EXPECT(bool{p.validate() == p.end()}); EXPECT(bool{p.validate() == p.end()});
auto result = p.eval({}); auto result = p.eval({});
EXPECT(result == migraph::literal{1}); EXPECT(result == migraphx::literal{1});
EXPECT(result != migraph::literal{3}); EXPECT(result != migraphx::literal{3});
} }
TEST_CASE(replace_ins_test2) TEST_CASE(replace_ins_test2)
{ {
migraph::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
...@@ -163,13 +163,13 @@ TEST_CASE(replace_ins_test2) ...@@ -163,13 +163,13 @@ TEST_CASE(replace_ins_test2)
EXPECT(bool{p.validate() == p.end()}); EXPECT(bool{p.validate() == p.end()});
auto result = p.eval({}); auto result = p.eval({});
EXPECT(result == migraph::literal{2}); EXPECT(result == migraphx::literal{2});
EXPECT(result != migraph::literal{3}); EXPECT(result != migraphx::literal{3});
} }
TEST_CASE(insert_replace_test) TEST_CASE(insert_replace_test)
{ {
migraph::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
...@@ -181,47 +181,47 @@ TEST_CASE(insert_replace_test) ...@@ -181,47 +181,47 @@ TEST_CASE(insert_replace_test)
EXPECT(bool{p.validate() == p.end()}); EXPECT(bool{p.validate() == p.end()});
auto result = p.eval({}); auto result = p.eval({});
EXPECT(result == migraph::literal{4}); EXPECT(result == migraphx::literal{4});
EXPECT(result != migraph::literal{5}); EXPECT(result != migraphx::literal{5});
} }
TEST_CASE(target_test) TEST_CASE(target_test)
{ {
migraph::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
p.add_instruction(sum_op{}, one, two); p.add_instruction(sum_op{}, one, two);
p.compile(id_target{}); p.compile(id_target{});
auto result = p.eval({}); auto result = p.eval({});
EXPECT(result == migraph::literal{3}); EXPECT(result == migraphx::literal{3});
EXPECT(result != migraph::literal{4}); EXPECT(result != migraphx::literal{4});
} }
TEST_CASE(reverse_target_test) TEST_CASE(reverse_target_test)
{ {
migraph::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
p.add_instruction(sum_op{}, two, one); p.add_instruction(sum_op{}, two, one);
p.compile(reverse_target{}); p.compile(reverse_target{});
auto result = p.eval({}); auto result = p.eval({});
EXPECT(result == migraph::literal{1}); EXPECT(result == migraphx::literal{1});
EXPECT(result != migraph::literal{4}); EXPECT(result != migraphx::literal{4});
} }
TEST_CASE(double_reverse_target_test) TEST_CASE(double_reverse_target_test)
{ {
migraph::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
p.add_instruction(sum_op{}, two, one); p.add_instruction(sum_op{}, two, one);
p.compile(double_reverse_target{}); p.compile(double_reverse_target{});
auto result = p.eval({}); auto result = p.eval({});
EXPECT(result == migraph::literal{3}); EXPECT(result == migraphx::literal{3});
EXPECT(result != migraph::literal{4}); EXPECT(result != migraphx::literal{4});
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <migraph/fwd_conv_batchnorm_rewrite.hpp> #include <migraphx/fwd_conv_batchnorm_rewrite.hpp>
#include <migraph/program.hpp> #include <migraphx/program.hpp>
#include <migraph/cpu/target.hpp> #include <migraphx/cpu/target.hpp>
#include <migraph/operators.hpp> #include <migraphx/operators.hpp>
#include <migraph/instruction.hpp> #include <migraphx/instruction.hpp>
#include <test.hpp> #include <test.hpp>
#include <migraph/verify.hpp> #include <migraphx/verify.hpp>
TEST_CASE(fwd_conv_batchnorm_rewrite_test) TEST_CASE(fwd_conv_batchnorm_rewrite_test)
{ {
...@@ -30,29 +30,30 @@ TEST_CASE(fwd_conv_batchnorm_rewrite_test) ...@@ -30,29 +30,30 @@ TEST_CASE(fwd_conv_batchnorm_rewrite_test)
-0.62146691, -2.40572931, -1.47175612, 1.49654601, -1.07070376, -0.65908074, -0.28457694, -0.62146691, -2.40572931, -1.47175612, 1.49654601, -1.07070376, -0.65908074, -0.28457694,
1.60046717, 0.20677642, -1.51844486, 0.41203847, -0.01285751, 0.07948031, -0.91507006, 1.60046717, 0.20677642, -1.51844486, 0.41203847, -0.01285751, 0.07948031, -0.91507006,
-1.59481079, -0.12856238, 0.39970482, -1.89015158, 0.66969754, 0.10312618}; -1.59481079, -0.12856238, 0.39970482, -1.89015158, 0.66969754, 0.10312618};
migraph::shape xs{migraph::shape::float_type, {1, 3, 6, 6}}; migraphx::shape xs{migraphx::shape::float_type, {1, 3, 6, 6}};
migraph::shape ws{migraph::shape::float_type, {1, 3, 3, 3}}; migraphx::shape ws{migraphx::shape::float_type, {1, 3, 3, 3}};
migraph::shape vars{migraph::shape::float_type, {1}}; migraphx::shape vars{migraphx::shape::float_type, {1}};
auto create_program = [&]() { auto create_program = [&]() {
migraph::program p; migraphx::program p;
auto x = p.add_literal(xs, xdata); auto x = p.add_literal(xs, xdata);
auto w = p.add_literal(ws, wdata); auto w = p.add_literal(ws, wdata);
auto conv = p.add_instruction(migraph::op::convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}}, x, w); auto conv =
auto scale = p.add_literal(migraph::literal{vars, {3.0f}}); p.add_instruction(migraphx::op::convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}}, x, w);
auto bias = p.add_literal(migraph::literal{vars, {8.1f}}); auto scale = p.add_literal(migraphx::literal{vars, {3.0f}});
auto mean = p.add_literal(migraph::literal{vars, {4.0f}}); auto bias = p.add_literal(migraphx::literal{vars, {8.1f}});
auto variance = p.add_literal(migraph::literal{vars, {37.11f}}); auto mean = p.add_literal(migraphx::literal{vars, {4.0f}});
p.add_instruction(migraph::op::batch_norm_inference{}, conv, scale, bias, mean, variance); auto variance = p.add_literal(migraphx::literal{vars, {37.11f}});
p.add_instruction(migraphx::op::batch_norm_inference{}, conv, scale, bias, mean, variance);
return p; return p;
}; };
migraph::program p1 = create_program(); migraphx::program p1 = create_program();
migraph::program p2 = create_program(); migraphx::program p2 = create_program();
migraph::fwd_conv_batchnorm_rewrite opt; migraphx::fwd_conv_batchnorm_rewrite opt;
opt.apply(p2); opt.apply(p2);
p1.compile(migraph::cpu::target{}); p1.compile(migraphx::cpu::target{});
p2.compile(migraph::cpu::target{}); p2.compile(migraphx::cpu::target{});
auto result1 = p1.eval({}); auto result1 = p1.eval({});
auto result2 = p2.eval({}); auto result2 = p2.eval({});
...@@ -61,7 +62,7 @@ TEST_CASE(fwd_conv_batchnorm_rewrite_test) ...@@ -61,7 +62,7 @@ TEST_CASE(fwd_conv_batchnorm_rewrite_test)
std::vector<float> results_vector2; std::vector<float> results_vector2;
result1.visit([&](auto output) { results_vector1.assign(output.begin(), output.end()); }); result1.visit([&](auto output) { results_vector1.assign(output.begin(), output.end()); });
result2.visit([&](auto output) { results_vector2.assign(output.begin(), output.end()); }); result2.visit([&](auto output) { results_vector2.assign(output.begin(), output.end()); });
EXPECT(migraph::verify_range(results_vector1, results_vector2)); EXPECT(migraphx::verify_range(results_vector1, results_vector2));
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <test.hpp> #include <test.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
#include <migraph/program.hpp> #include <migraphx/program.hpp>
#include <migraph/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraph/generate.hpp> #include <migraphx/generate.hpp>
#include <migraph/gpu/target.hpp> #include <migraphx/gpu/target.hpp>
#include <migraph/gpu/hip.hpp> #include <migraphx/gpu/hip.hpp>
void gpu_literal_test() void gpu_literal_test()
{ {
migraph::program p; migraphx::program p;
auto lit = generate_literal(migraph::shape{migraph::shape::float_type, {4, 3, 3, 3}}); auto lit = generate_literal(migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
p.add_literal(lit); p.add_literal(lit);
p.compile(migraph::gpu::target{}); p.compile(migraphx::gpu::target{});
auto scratch = p.get_parameter("scratch"); auto scratch = p.get_parameter("scratch");
if(scratch == p.end()) if(scratch == p.end())
{ {
auto result = p.eval({}); auto result = p.eval({});
EXPECT(lit == migraph::gpu::from_gpu(result)); EXPECT(lit == migraphx::gpu::from_gpu(result));
} }
else else
{ {
......
#include <migraph/program.hpp> #include <migraphx/program.hpp>
#include <migraph/operators.hpp> #include <migraphx/operators.hpp>
#include <migraph/generate.hpp> #include <migraphx/generate.hpp>
#include <migraph/cpu/target.hpp> #include <migraphx/cpu/target.hpp>
#include <migraph/gpu/target.hpp> #include <migraphx/gpu/target.hpp>
#include <migraph/gpu/miopen.hpp> #include <migraphx/gpu/miopen.hpp>
#include <migraph/gpu/hip.hpp> #include <migraphx/gpu/hip.hpp>
#include <migraph/manage_ptr.hpp> #include <migraphx/manage_ptr.hpp>
#include <migraph/type_name.hpp> #include <migraphx/type_name.hpp>
#include <migraph/verify_args.hpp> #include <migraphx/verify_args.hpp>
#include <migraph/instruction.hpp> #include <migraphx/instruction.hpp>
#include <miopen/miopen.h> #include <miopen/miopen.h>
...@@ -81,12 +81,12 @@ auto get_hash(const T& x) ...@@ -81,12 +81,12 @@ auto get_hash(const T& x)
return std::hash<T>{}(x); return std::hash<T>{}(x);
} }
void compile_check(migraph::program& p, const migraph::target& t) void compile_check(migraphx::program& p, const migraphx::target& t)
{ {
auto name = t.name(); auto name = t.name();
auto s = p.get_shape(); auto s = p.get_shape();
std::stringstream ss; std::stringstream ss;
p.compile(t, migraph::tracer{ss}); p.compile(t, migraphx::tracer{ss});
if(p.get_shape() != s) if(p.get_shape() != s)
{ {
std::cout << ss.str() << std::endl; std::cout << ss.str() << std::endl;
...@@ -95,47 +95,48 @@ void compile_check(migraph::program& p, const migraph::target& t) ...@@ -95,47 +95,48 @@ void compile_check(migraph::program& p, const migraph::target& t)
} }
template <class V> template <class V>
migraph::argument run_cpu(migraph::program& p) migraphx::argument run_cpu(migraphx::program& p)
{ {
V v; V v;
p = v.create_program(); p = v.create_program();
auto_print pp{p, 0}; auto_print pp{p, 0};
compile_check(p, migraph::cpu::target{}); compile_check(p, migraphx::cpu::target{});
migraph::program::parameter_map m; migraphx::program::parameter_map m;
for(auto&& x : p.get_parameter_shapes()) for(auto&& x : p.get_parameter_shapes())
{ {
m[x.first] = migraph::generate_argument(x.second, get_hash(x.first)); m[x.first] = migraphx::generate_argument(x.second, get_hash(x.first));
} }
return p.eval(m); return p.eval(m);
} }
template <class V> template <class V>
migraph::argument run_gpu(migraph::program& p) migraphx::argument run_gpu(migraphx::program& p)
{ {
V v; V v;
p = v.create_program(); p = v.create_program();
auto_print pp{p, 1}; auto_print pp{p, 1};
compile_check(p, migraph::gpu::target{}); compile_check(p, migraphx::gpu::target{});
migraph::program::parameter_map m; migraphx::program::parameter_map m;
for(auto&& x : p.get_parameter_shapes()) for(auto&& x : p.get_parameter_shapes())
{ {
m[x.first] = migraph::gpu::to_gpu(migraph::generate_argument(x.second, get_hash(x.first))); m[x.first] =
migraphx::gpu::to_gpu(migraphx::generate_argument(x.second, get_hash(x.first)));
} }
EXPECT(bool{m.find("output") != m.end()}); EXPECT(bool{m.find("output") != m.end()});
return migraph::gpu::from_gpu(p.eval(m)); return migraphx::gpu::from_gpu(p.eval(m));
} }
template <class V> template <class V>
void verify_program() void verify_program()
{ {
auto_print::set_terminate_handler(migraph::get_type_name<V>()); auto_print::set_terminate_handler(migraphx::get_type_name<V>());
// std::cout << migraph::get_type_name<V>() << std::endl; // std::cout << migraphx::get_type_name<V>() << std::endl;
migraph::program cpu_prog; migraphx::program cpu_prog;
migraph::program gpu_prog; migraphx::program gpu_prog;
auto cpu_arg_f = detach_async([&] { return run_cpu<V>(cpu_prog); }); auto cpu_arg_f = detach_async([&] { return run_cpu<V>(cpu_prog); });
auto gpu_arg = run_gpu<V>(gpu_prog); auto gpu_arg = run_gpu<V>(gpu_prog);
auto cpu_arg = cpu_arg_f.get(); auto cpu_arg = cpu_arg_f.get();
bool passed = verify_args(migraph::get_type_name<V>(), cpu_arg, gpu_arg); bool passed = verify_args(migraphx::get_type_name<V>(), cpu_arg, gpu_arg);
if(not passed) if(not passed)
{ {
V v; V v;
...@@ -150,82 +151,82 @@ void verify_program() ...@@ -150,82 +151,82 @@ void verify_program()
struct test_literals struct test_literals
{ {
migraph::program create_program() const migraphx::program create_program() const
{ {
migraph::program p; migraphx::program p;
auto input = p.add_literal( auto input = p.add_literal(
generate_literal(migraph::shape{migraph::shape::float_type, {4, 3, 3, 3}})); generate_literal(migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}));
auto weights = p.add_literal( auto weights = p.add_literal(
generate_literal(migraph::shape{migraph::shape::float_type, {4, 3, 3, 3}})); generate_literal(migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}));
auto conv = p.add_instruction(migraph::op::convolution{}, input, weights); auto conv = p.add_instruction(migraphx::op::convolution{}, input, weights);
p.add_instruction(migraph::op::relu{}, conv); p.add_instruction(migraphx::op::relu{}, conv);
return p; return p;
} }
}; };
struct test_add struct test_add
{ {
migraph::program create_program() const migraphx::program create_program() const
{ {
migraph::program p; migraphx::program p;
migraph::shape s{migraph::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", s); auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s); auto y = p.add_parameter("y", s);
p.add_instruction(migraph::op::add{}, x, y); p.add_instruction(migraphx::op::add{}, x, y);
return p; return p;
} }
}; };
struct test_add_half struct test_add_half
{ {
migraph::program create_program() const migraphx::program create_program() const
{ {
migraph::program p; migraphx::program p;
migraph::shape s{migraph::shape::half_type, {3}}; migraphx::shape s{migraphx::shape::half_type, {3}};
auto x = p.add_parameter("x", s); auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s); auto y = p.add_parameter("y", s);
p.add_instruction(migraph::op::add{}, x, y); p.add_instruction(migraphx::op::add{}, x, y);
return p; return p;
} }
}; };
struct test_mul struct test_mul
{ {
migraph::program create_program() const migraphx::program create_program() const
{ {
migraph::program p; migraphx::program p;
migraph::shape s{migraph::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", s); auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s); auto y = p.add_parameter("y", s);
p.add_instruction(migraph::op::mul{}, x, y); p.add_instruction(migraphx::op::mul{}, x, y);
return p; return p;
} }
}; };
struct test_scale struct test_scale
{ {
migraph::program create_program() const migraphx::program create_program() const
{ {
migraph::program p; migraphx::program p;
migraph::shape s{migraph::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", s); auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", migraph::shape::float_type); auto y = p.add_parameter("y", migraphx::shape::float_type);
auto scale = p.add_instruction(migraph::op::scalar{s}, y); auto scale = p.add_instruction(migraphx::op::scalar{s}, y);
p.add_instruction(migraph::op::mul{}, x, scale); p.add_instruction(migraphx::op::mul{}, x, scale);
return p; return p;
} }
}; };
struct test_slice struct test_slice
{ {
migraph::program create_program() const migraphx::program create_program() const
{ {
migraph::program p; migraphx::program p;
migraph::shape s{migraph::shape::int32_type, {2, 2, 4}}; migraphx::shape s{migraphx::shape::int32_type, {2, 2, 4}};
auto x = p.add_parameter("x", s); auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", {migraph::shape::int32_type, {2, 2, 2}}); auto y = p.add_parameter("y", {migraphx::shape::int32_type, {2, 2, 2}});
auto slice0 = p.add_instruction(migraph::op::slice{{2}, {0}, {2}}, x); auto slice0 = p.add_instruction(migraphx::op::slice{{2}, {0}, {2}}, x);
p.add_instruction(migraph::op::add{}, y, slice0); p.add_instruction(migraphx::op::add{}, y, slice0);
return p; return p;
} }
...@@ -233,247 +234,251 @@ struct test_slice ...@@ -233,247 +234,251 @@ struct test_slice
struct test_triadd struct test_triadd
{ {
migraph::program create_program() const migraphx::program create_program() const
{ {
migraph::program p; migraphx::program p;
migraph::shape s{migraph::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", s); auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s); auto y = p.add_parameter("y", s);
auto z = p.add_parameter("z", s); auto z = p.add_parameter("z", s);
auto sum = p.add_instruction(migraph::op::add{}, x, y); auto sum = p.add_instruction(migraphx::op::add{}, x, y);
p.add_instruction(migraph::op::add{}, sum, z); p.add_instruction(migraphx::op::add{}, sum, z);
return p; return p;
} }
}; };
struct test_triadd2 struct test_triadd2
{ {
migraph::program create_program() const migraphx::program create_program() const
{ {
migraph::program p; migraphx::program p;
migraph::shape s{migraph::shape::float_type, {2, 3}}; migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraph::shape b{migraph::shape::float_type, {3}}; migraphx::shape b{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", s); auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s); auto y = p.add_parameter("y", s);
auto z = p.add_parameter("z", b); auto z = p.add_parameter("z", b);
auto zb = p.add_instruction(migraph::op::broadcast{1, s}, z); auto zb = p.add_instruction(migraphx::op::broadcast{1, s}, z);
auto sum = p.add_instruction(migraph::op::add{}, x, y); auto sum = p.add_instruction(migraphx::op::add{}, x, y);
p.add_instruction(migraph::op::add{}, sum, zb); p.add_instruction(migraphx::op::add{}, sum, zb);
return p; return p;
} }
}; };
struct test_add_broadcast struct test_add_broadcast
{ {
migraph::program create_program() const migraphx::program create_program() const
{ {
migraph::program p; migraphx::program p;
migraph::shape s{migraph::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", {migraph::shape::float_type, {2, 2, 3}}); auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 2, 3}});
auto y = p.add_parameter("y", {migraph::shape::float_type, {2, 2}}); auto y = p.add_parameter("y", {migraphx::shape::float_type, {2, 2}});
auto by = p.add_instruction(migraph::op::broadcast{0, x->get_shape()}, y); auto by = p.add_instruction(migraphx::op::broadcast{0, x->get_shape()}, y);
p.add_instruction(migraph::op::add{}, x, by); p.add_instruction(migraphx::op::add{}, x, by);
return p; return p;
} }
}; };
struct test_add_broadcast2 struct test_add_broadcast2
{ {
migraph::program create_program() const migraphx::program create_program() const
{ {
migraph::program p; migraphx::program p;
migraph::shape s{migraph::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", {migraph::shape::float_type, {2, 3, 4}}); auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 3, 4}});
auto y = p.add_parameter("y", {migraph::shape::float_type, {3}}); auto y = p.add_parameter("y", {migraphx::shape::float_type, {3}});
auto by = p.add_instruction(migraph::op::broadcast{1, x->get_shape()}, y); auto by = p.add_instruction(migraphx::op::broadcast{1, x->get_shape()}, y);
p.add_instruction(migraph::op::add{}, x, by); p.add_instruction(migraphx::op::add{}, x, by);
return p; return p;
} }
}; };
struct test_add_broadcast3 struct test_add_broadcast3
{ {
migraph::program create_program() const migraphx::program create_program() const
{ {
migraph::program p; migraphx::program p;
migraph::shape s{migraph::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", {migraph::shape::float_type, {2, 4, 5}}); auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 4, 5}});
auto y = p.add_parameter("y", {migraph::shape::float_type, {4}}); auto y = p.add_parameter("y", {migraphx::shape::float_type, {4}});
auto by = p.add_instruction(migraph::op::broadcast{1, x->get_shape()}, y); auto by = p.add_instruction(migraphx::op::broadcast{1, x->get_shape()}, y);
p.add_instruction(migraph::op::add{}, x, by); p.add_instruction(migraphx::op::add{}, x, by);
return p; return p;
} }
}; };
struct test_add_broadcast4 struct test_add_broadcast4
{ {
migraph::program create_program() const migraphx::program create_program() const
{ {
migraph::program p; migraphx::program p;
migraph::shape s{migraph::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", {migraph::shape::float_type, {2, 3, 5}}); auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 3, 5}});
auto y = p.add_parameter("y", {migraph::shape::float_type, {3}}); auto y = p.add_parameter("y", {migraphx::shape::float_type, {3}});
auto by = p.add_instruction(migraph::op::broadcast{1, x->get_shape()}, y); auto by = p.add_instruction(migraphx::op::broadcast{1, x->get_shape()}, y);
p.add_instruction(migraph::op::add{}, x, by); p.add_instruction(migraphx::op::add{}, x, by);
return p; return p;
} }
}; };
struct test_add_broadcast5 struct test_add_broadcast5
{ {
migraph::program create_program() const migraphx::program create_program() const
{ {
migraph::program p; migraphx::program p;
migraph::shape s{migraph::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", {migraph::shape::float_type, {2, 4, 8}}); auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 4, 8}});
auto y = p.add_parameter("y", {migraph::shape::float_type, {4}}); auto y = p.add_parameter("y", {migraphx::shape::float_type, {4}});
auto by = p.add_instruction(migraph::op::broadcast{1, x->get_shape()}, y); auto by = p.add_instruction(migraphx::op::broadcast{1, x->get_shape()}, y);
p.add_instruction(migraph::op::add{}, x, by); p.add_instruction(migraphx::op::add{}, x, by);
return p; return p;
} }
}; };
struct test_triadd_broadcast struct test_triadd_broadcast
{ {
migraph::program create_program() const migraphx::program create_program() const
{ {
migraph::program p; migraphx::program p;
migraph::shape s{migraph::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", {migraph::shape::float_type, {2, 2, 3}}); auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 2, 3}});
auto y = p.add_parameter("y", {migraph::shape::float_type, {2, 2}}); auto y = p.add_parameter("y", {migraphx::shape::float_type, {2, 2}});
auto z = p.add_parameter("z", {migraph::shape::float_type, {2, 2, 3}}); auto z = p.add_parameter("z", {migraphx::shape::float_type, {2, 2, 3}});
auto by = p.add_instruction(migraph::op::broadcast{0, x->get_shape()}, y); auto by = p.add_instruction(migraphx::op::broadcast{0, x->get_shape()}, y);
auto sum = p.add_instruction(migraph::op::add{}, x, by); auto sum = p.add_instruction(migraphx::op::add{}, x, by);
p.add_instruction(migraph::op::add{}, sum, z); p.add_instruction(migraphx::op::add{}, sum, z);
return p; return p;
} }
}; };
struct test_softmax struct test_softmax
{ {
migraph::program create_program() const migraphx::program create_program() const
{ {
migraph::program p; migraphx::program p;
auto x = p.add_parameter("x", migraph::shape{migraph::shape::float_type, {5, 3, 4, 2}}); auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {5, 3, 4, 2}});
p.add_instruction(migraph::op::softmax{}, x); p.add_instruction(migraphx::op::softmax{}, x);
return p; return p;
} }
}; };
struct test_softmax2 struct test_softmax2
{ {
migraph::program create_program() const migraphx::program create_program() const
{ {
migraph::program p; migraphx::program p;
auto x = p.add_parameter("x", migraph::shape{migraph::shape::float_type, {1, 1000, 1, 1}}); auto x =
p.add_instruction(migraph::op::softmax{}, x); p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1000, 1, 1}});
p.add_instruction(migraphx::op::softmax{}, x);
return p; return p;
} }
}; };
struct test_conv struct test_conv
{ {
migraph::program create_program() const migraphx::program create_program() const
{ {
migraph::program p; migraphx::program p;
auto input = p.add_parameter("x", migraph::shape{migraph::shape::float_type, {4, 3, 3, 3}}); auto input =
p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto weights = auto weights =
p.add_parameter("w", migraph::shape{migraph::shape::float_type, {4, 3, 3, 3}}); p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
p.add_instruction(migraph::op::convolution{}, input, weights); p.add_instruction(migraphx::op::convolution{}, input, weights);
return p; return p;
} }
}; };
struct test_conv2 struct test_conv2
{ {
migraph::program create_program() const migraphx::program create_program() const
{ {
migraph::program p; migraphx::program p;
auto input = auto input =
p.add_parameter("x", migraph::shape{migraph::shape::float_type, {1, 512, 28, 28}}); p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 512, 28, 28}});
auto weights = auto weights =
p.add_parameter("w", migraph::shape{migraph::shape::float_type, {256, 512, 1, 1}}); p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {256, 512, 1, 1}});
p.add_instruction(migraph::op::convolution{{0, 0}, {1, 1}, {1, 1}}, input, weights); p.add_instruction(migraphx::op::convolution{{0, 0}, {1, 1}, {1, 1}}, input, weights);
return p; return p;
} }
}; };
struct test_conv_relu struct test_conv_relu
{ {
migraph::program create_program() const migraphx::program create_program() const
{ {
migraph::program p; migraphx::program p;
auto input = p.add_parameter("x", migraph::shape{migraph::shape::float_type, {4, 3, 3, 3}}); auto input =
p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto weights = auto weights =
p.add_parameter("w", migraph::shape{migraph::shape::float_type, {4, 3, 3, 3}}); p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto conv = p.add_instruction(migraph::op::convolution{}, input, weights); auto conv = p.add_instruction(migraphx::op::convolution{}, input, weights);
p.add_instruction(migraph::op::relu{}, conv); p.add_instruction(migraphx::op::relu{}, conv);
return p; return p;
} }
}; };
struct test_conv_relu_half struct test_conv_relu_half
{ {
migraph::program create_program() const migraphx::program create_program() const
{ {
migraph::program p; migraphx::program p;
auto input = p.add_parameter("x", migraph::shape{migraph::shape::half_type, {4, 3, 3, 3}}); auto input =
p.add_parameter("x", migraphx::shape{migraphx::shape::half_type, {4, 3, 3, 3}});
auto weights = auto weights =
p.add_parameter("w", migraph::shape{migraph::shape::half_type, {4, 3, 3, 3}}); p.add_parameter("w", migraphx::shape{migraphx::shape::half_type, {4, 3, 3, 3}});
auto conv = p.add_instruction(migraph::op::convolution{}, input, weights); auto conv = p.add_instruction(migraphx::op::convolution{}, input, weights);
p.add_instruction(migraph::op::relu{}, conv); p.add_instruction(migraphx::op::relu{}, conv);
return p; return p;
} }
}; };
struct test_add_relu struct test_add_relu
{ {
migraph::program create_program() const migraphx::program create_program() const
{ {
migraph::program p; migraphx::program p;
auto x = p.add_parameter("x", migraph::shape{migraph::shape::float_type, {4, 3, 3, 3}}); auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto y = p.add_parameter("y", migraph::shape{migraph::shape::float_type, {4, 3, 3, 3}}); auto y = p.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto add = p.add_instruction(migraph::op::add{}, x, y); auto add = p.add_instruction(migraphx::op::add{}, x, y);
p.add_instruction(migraph::op::relu{}, add); p.add_instruction(migraphx::op::relu{}, add);
return p; return p;
} }
}; };
struct test_leaky_relu struct test_leaky_relu
{ {
migraph::program create_program() const migraphx::program create_program() const
{ {
migraph::program p; migraphx::program p;
auto x = p.add_parameter("x", migraph::shape{migraph::shape::float_type, {4, 3, 3, 3}}); auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
p.add_instruction(migraph::op::leaky_relu{0.01}, x); p.add_instruction(migraphx::op::leaky_relu{0.01}, x);
return p; return p;
} }
}; };
struct test_conv_pooling struct test_conv_pooling
{ {
migraph::program create_program() const migraphx::program create_program() const
{ {
migraph::program p; migraphx::program p;
auto input = auto input =
p.add_parameter("x", migraph::shape{migraph::shape::float_type, {4, 3, 32, 32}}); p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 32, 32}});
auto weights = auto weights =
p.add_parameter("w", migraph::shape{migraph::shape::float_type, {4, 3, 3, 3}}); p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto conv = p.add_instruction(migraph::op::convolution{}, input, weights); auto conv = p.add_instruction(migraphx::op::convolution{}, input, weights);
auto pooling = p.add_instruction(migraph::op::pooling{"max"}, conv); auto pooling = p.add_instruction(migraphx::op::pooling{"max"}, conv);
p.add_instruction(migraph::op::relu{}, pooling); p.add_instruction(migraphx::op::relu{}, pooling);
return p; return p;
} }
}; };
struct test_global_avg_pooling struct test_global_avg_pooling
{ {
migraph::program create_program() const migraphx::program create_program() const
{ {
migraph::program p; migraphx::program p;
auto input = auto input =
p.add_parameter("x", migraph::shape{migraph::shape::float_type, {1, 3, 16, 16}}); p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
auto op = migraph::op::pooling{"average"}; auto op = migraphx::op::pooling{"average"};
auto lens = input->get_shape().lens(); auto lens = input->get_shape().lens();
op.lengths = {lens[2], lens[3]}; op.lengths = {lens[2], lens[3]};
p.add_instruction(op, input); p.add_instruction(op, input);
...@@ -483,12 +488,12 @@ struct test_global_avg_pooling ...@@ -483,12 +488,12 @@ struct test_global_avg_pooling
struct test_global_max_pooling struct test_global_max_pooling
{ {
migraph::program create_program() const migraphx::program create_program() const
{ {
migraph::program p; migraphx::program p;
auto input = auto input =
p.add_parameter("x", migraph::shape{migraph::shape::float_type, {1, 3, 16, 16}}); p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
auto op = migraph::op::pooling{"max"}; auto op = migraphx::op::pooling{"max"};
auto lens = input->get_shape().lens(); auto lens = input->get_shape().lens();
op.lengths = {lens[2], lens[3]}; op.lengths = {lens[2], lens[3]};
p.add_instruction(op, input); p.add_instruction(op, input);
...@@ -498,88 +503,90 @@ struct test_global_max_pooling ...@@ -498,88 +503,90 @@ struct test_global_max_pooling
struct test_gemm struct test_gemm
{ {
migraph::program create_program() const migraphx::program create_program() const
{ {
migraph::program p; migraphx::program p;
auto a = p.add_parameter("a", migraph::shape{migraph::shape::float_type, {4, 5}}); auto a = p.add_parameter("a", migraphx::shape{migraphx::shape::float_type, {4, 5}});
auto b = p.add_parameter("b", migraph::shape{migraph::shape::float_type, {5, 3}}); auto b = p.add_parameter("b", migraphx::shape{migraphx::shape::float_type, {5, 3}});
p.add_instruction(migraph::op::dot{}, a, b); p.add_instruction(migraphx::op::dot{}, a, b);
return p; return p;
} }
}; };
struct test_gemm_half struct test_gemm_half
{ {
migraph::program create_program() const migraphx::program create_program() const
{ {
migraph::program p; migraphx::program p;
auto a = p.add_parameter("a", migraph::shape{migraph::shape::half_type, {4, 5}}); auto a = p.add_parameter("a", migraphx::shape{migraphx::shape::half_type, {4, 5}});
auto b = p.add_parameter("b", migraph::shape{migraph::shape::half_type, {5, 3}}); auto b = p.add_parameter("b", migraphx::shape{migraphx::shape::half_type, {5, 3}});
p.add_instruction(migraph::op::dot{}, a, b); p.add_instruction(migraphx::op::dot{}, a, b);
return p; return p;
} }
}; };
struct test_gemm_ld struct test_gemm_ld
{ {
migraph::program create_program() const migraphx::program create_program() const
{ {
migraph::program p; migraphx::program p;
auto a = p.add_parameter("a", migraph::shape{migraph::shape::float_type, {4, 5}, {10, 1}}); auto a =
auto b = p.add_parameter("b", migraph::shape{migraph::shape::float_type, {5, 3}, {20, 1}}); p.add_parameter("a", migraphx::shape{migraphx::shape::float_type, {4, 5}, {10, 1}});
p.add_instruction(migraph::op::dot{}, a, b); auto b =
p.add_parameter("b", migraphx::shape{migraphx::shape::float_type, {5, 3}, {20, 1}});
p.add_instruction(migraphx::op::dot{}, a, b);
return p; return p;
} }
}; };
struct test_gemm_transposeb struct test_gemm_transposeb
{ {
migraph::program create_program() const migraphx::program create_program() const
{ {
migraph::program p; migraphx::program p;
auto a = p.add_parameter("a", migraph::shape{migraph::shape::float_type, {4, 5}}); auto a = p.add_parameter("a", migraphx::shape{migraphx::shape::float_type, {4, 5}});
auto b = p.add_parameter("b", migraph::shape{migraph::shape::float_type, {3, 5}}); auto b = p.add_parameter("b", migraphx::shape{migraphx::shape::float_type, {3, 5}});
auto bt = p.add_instruction(migraph::op::transpose{{1, 0}}, b); auto bt = p.add_instruction(migraphx::op::transpose{{1, 0}}, b);
p.add_instruction(migraph::op::dot{}, a, bt); p.add_instruction(migraphx::op::dot{}, a, bt);
return p; return p;
} }
}; };
struct test_gemm_transposea struct test_gemm_transposea
{ {
migraph::program create_program() const migraphx::program create_program() const
{ {
migraph::program p; migraphx::program p;
auto a = p.add_parameter("a", migraph::shape{migraph::shape::float_type, {5, 4}}); auto a = p.add_parameter("a", migraphx::shape{migraphx::shape::float_type, {5, 4}});
auto b = p.add_parameter("b", migraph::shape{migraph::shape::float_type, {5, 3}}); auto b = p.add_parameter("b", migraphx::shape{migraphx::shape::float_type, {5, 3}});
auto at = p.add_instruction(migraph::op::transpose{{1, 0}}, a); auto at = p.add_instruction(migraphx::op::transpose{{1, 0}}, a);
p.add_instruction(migraph::op::dot{}, at, b); p.add_instruction(migraphx::op::dot{}, at, b);
return p; return p;
} }
}; };
struct test_gemm_transposeab struct test_gemm_transposeab
{ {
migraph::program create_program() const migraphx::program create_program() const
{ {
migraph::program p; migraphx::program p;
auto a = p.add_parameter("a", migraph::shape{migraph::shape::float_type, {5, 4}}); auto a = p.add_parameter("a", migraphx::shape{migraphx::shape::float_type, {5, 4}});
auto b = p.add_parameter("b", migraph::shape{migraph::shape::float_type, {3, 5}}); auto b = p.add_parameter("b", migraphx::shape{migraphx::shape::float_type, {3, 5}});
auto at = p.add_instruction(migraph::op::transpose{{1, 0}}, a); auto at = p.add_instruction(migraphx::op::transpose{{1, 0}}, a);
auto bt = p.add_instruction(migraph::op::transpose{{1, 0}}, b); auto bt = p.add_instruction(migraphx::op::transpose{{1, 0}}, b);
p.add_instruction(migraph::op::dot{}, at, bt); p.add_instruction(migraphx::op::dot{}, at, bt);
return p; return p;
} }
}; };
struct test_contiguous struct test_contiguous
{ {
migraph::program create_program() const migraphx::program create_program() const
{ {
migraph::program p; migraphx::program p;
migraph::shape s{migraph::shape::float_type, {4, 4, 4, 3}, {48, 4, 1, 16}}; migraphx::shape s{migraphx::shape::float_type, {4, 4, 4, 3}, {48, 4, 1, 16}};
auto x = p.add_parameter("x", s); auto x = p.add_parameter("x", s);
p.add_instruction(migraph::op::contiguous{}, x); p.add_instruction(migraphx::op::contiguous{}, x);
EXPECT(p.get_shape().standard()); EXPECT(p.get_shape().standard());
return p; return p;
} }
...@@ -587,14 +594,14 @@ struct test_contiguous ...@@ -587,14 +594,14 @@ struct test_contiguous
struct test_transpose struct test_transpose
{ {
migraph::program create_program() const migraphx::program create_program() const
{ {
migraph::program p; migraphx::program p;
migraph::shape s{migraph::shape::float_type, {4, 3, 4, 4}}; migraphx::shape s{migraphx::shape::float_type, {4, 3, 4, 4}};
auto x = p.add_parameter("x", s); auto x = p.add_parameter("x", s);
std::vector<int64_t> perm = {0, 2, 3, 1}; std::vector<int64_t> perm = {0, 2, 3, 1};
auto l = p.add_instruction(migraph::op::transpose{perm}, x); auto l = p.add_instruction(migraphx::op::transpose{perm}, x);
p.add_instruction(migraph::op::contiguous{}, l); p.add_instruction(migraphx::op::contiguous{}, l);
return p; return p;
} }
}; };
...@@ -606,18 +613,18 @@ struct test_batchnorm_inference_2 ...@@ -606,18 +613,18 @@ struct test_batchnorm_inference_2
const size_t channels = 256; const size_t channels = 256;
const size_t batches = 1; const size_t batches = 1;
migraph::program create_program() const migraphx::program create_program() const
{ {
migraph::program p; migraphx::program p;
migraph::shape s{migraph::shape::float_type, {batches, channels, height, width}}; migraphx::shape s{migraphx::shape::float_type, {batches, channels, height, width}};
migraph::shape vars{migraph::shape::float_type, {channels}}; migraphx::shape vars{migraphx::shape::float_type, {channels}};
auto x = p.add_parameter("x", s); auto x = p.add_parameter("x", s);
auto scale = p.add_literal(migraph::abs(migraph::generate_literal(vars, 1))); auto scale = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1)));
auto bias = p.add_literal(migraph::abs(migraph::generate_literal(vars, 2))); auto bias = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2)));
auto mean = p.add_literal(migraph::abs(migraph::generate_literal(vars, 3))); auto mean = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3)));
auto variance = p.add_literal(migraph::abs(migraph::generate_literal(vars, 4))); auto variance = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4)));
p.add_instruction(migraph::op::batch_norm_inference{}, x, scale, bias, mean, variance); p.add_instruction(migraphx::op::batch_norm_inference{}, x, scale, bias, mean, variance);
return p; return p;
} }
}; };
...@@ -629,200 +636,201 @@ struct test_batchnorm_inference ...@@ -629,200 +636,201 @@ struct test_batchnorm_inference
const size_t channels = 3; const size_t channels = 3;
const size_t batches = 4; const size_t batches = 4;
migraph::program create_program() const migraphx::program create_program() const
{ {
migraph::program p; migraphx::program p;
migraph::shape s{migraph::shape::float_type, {batches, channels, height, width}}; migraphx::shape s{migraphx::shape::float_type, {batches, channels, height, width}};
migraph::shape vars{migraph::shape::float_type, {channels}}; migraphx::shape vars{migraphx::shape::float_type, {channels}};
auto x = p.add_parameter("x", s); auto x = p.add_parameter("x", s);
auto scale = p.add_literal(migraph::abs(migraph::generate_literal(vars, 1))); auto scale = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1)));
auto bias = p.add_literal(migraph::abs(migraph::generate_literal(vars, 2))); auto bias = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2)));
auto mean = p.add_literal(migraph::abs(migraph::generate_literal(vars, 3))); auto mean = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3)));
auto variance = p.add_literal(migraph::abs(migraph::generate_literal(vars, 4))); auto variance = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4)));
p.add_instruction(migraph::op::batch_norm_inference{}, x, scale, bias, mean, variance); p.add_instruction(migraphx::op::batch_norm_inference{}, x, scale, bias, mean, variance);
return p; return p;
} }
}; };
struct test_conv_bn struct test_conv_bn
{ {
migraph::program create_program() const migraphx::program create_program() const
{ {
migraph::program p; migraphx::program p;
migraph::shape xs{migraph::shape::float_type, {1, 3, 224, 224}}; migraphx::shape xs{migraphx::shape::float_type, {1, 3, 224, 224}};
migraph::shape ws{migraph::shape::float_type, {64, 3, 7, 7}}; migraphx::shape ws{migraphx::shape::float_type, {64, 3, 7, 7}};
migraph::shape vars{migraph::shape::float_type, {64}}; migraphx::shape vars{migraphx::shape::float_type, {64}};
auto x = p.add_parameter("x", xs); auto x = p.add_parameter("x", xs);
auto w = p.add_parameter("w", ws); auto w = p.add_parameter("w", ws);
auto conv = p.add_instruction(migraph::op::convolution{{3, 3}, {2, 2}, {1, 1}}, x, w); auto conv = p.add_instruction(migraphx::op::convolution{{3, 3}, {2, 2}, {1, 1}}, x, w);
auto scale = p.add_literal(migraph::abs(migraph::generate_literal(vars, 1))); auto scale = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1)));
auto bias = p.add_literal(migraph::abs(migraph::generate_literal(vars, 2))); auto bias = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2)));
auto mean = p.add_literal(migraph::abs(migraph::generate_literal(vars, 3))); auto mean = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3)));
auto variance = p.add_literal(migraph::abs(migraph::generate_literal(vars, 4))); auto variance = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4)));
p.add_instruction(migraph::op::batch_norm_inference{}, conv, scale, bias, mean, variance); p.add_instruction(migraphx::op::batch_norm_inference{}, conv, scale, bias, mean, variance);
return p; return p;
} }
}; };
struct test_conv_bn_relu_pooling struct test_conv_bn_relu_pooling
{ {
migraph::program create_program() const migraphx::program create_program() const
{ {
migraph::program p; migraphx::program p;
migraph::shape xs{migraph::shape::float_type, {1, 3, 224, 224}}; migraphx::shape xs{migraphx::shape::float_type, {1, 3, 224, 224}};
migraph::shape ws{migraph::shape::float_type, {64, 3, 7, 7}}; migraphx::shape ws{migraphx::shape::float_type, {64, 3, 7, 7}};
migraph::shape vars{migraph::shape::float_type, {64}}; migraphx::shape vars{migraphx::shape::float_type, {64}};
auto x = p.add_parameter("x", xs); auto x = p.add_parameter("x", xs);
auto w = p.add_parameter("w", ws); auto w = p.add_parameter("w", ws);
auto conv = p.add_instruction(migraph::op::convolution{{3, 3}, {2, 2}, {1, 1}}, x, w); auto conv = p.add_instruction(migraphx::op::convolution{{3, 3}, {2, 2}, {1, 1}}, x, w);
auto scale = p.add_literal(migraph::abs(migraph::generate_literal(vars, 1))); auto scale = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1)));
auto bias = p.add_literal(migraph::abs(migraph::generate_literal(vars, 2))); auto bias = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2)));
auto mean = p.add_literal(migraph::abs(migraph::generate_literal(vars, 3))); auto mean = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3)));
auto variance = p.add_literal(migraph::abs(migraph::generate_literal(vars, 4))); auto variance = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4)));
auto bn = p.add_instruction( auto bn = p.add_instruction(
migraph::op::batch_norm_inference{}, conv, scale, bias, mean, variance); migraphx::op::batch_norm_inference{}, conv, scale, bias, mean, variance);
auto relu = p.add_instruction(migraph::op::relu{}, bn); auto relu = p.add_instruction(migraphx::op::relu{}, bn);
p.add_instruction(migraph::op::pooling{"average", {1, 1}, {2, 2}, {3, 3}}, relu); p.add_instruction(migraphx::op::pooling{"average", {1, 1}, {2, 2}, {3, 3}}, relu);
return p; return p;
} }
}; };
struct test_concat struct test_concat
{ {
migraph::program create_program() const migraphx::program create_program() const
{ {
migraph::program p; migraphx::program p;
std::size_t axis = 1; std::size_t axis = 1;
migraph::shape s0{migraph::shape::int32_type, {2, 2}}; migraphx::shape s0{migraphx::shape::int32_type, {2, 2}};
migraph::shape s1{migraph::shape::int32_type, {2, 3}}; migraphx::shape s1{migraphx::shape::int32_type, {2, 3}};
migraph::shape s2{migraph::shape::int32_type, {2, 1}}; migraphx::shape s2{migraphx::shape::int32_type, {2, 1}};
auto l0 = p.add_parameter("x", s0); auto l0 = p.add_parameter("x", s0);
auto l1 = p.add_parameter("y", s1); auto l1 = p.add_parameter("y", s1);
auto l2 = p.add_parameter("z", s2); auto l2 = p.add_parameter("z", s2);
p.add_instruction(migraph::op::concat{axis}, l0, l1, l2); p.add_instruction(migraphx::op::concat{axis}, l0, l1, l2);
return p; return p;
} }
}; };
struct test_concat2 struct test_concat2
{ {
migraph::program create_program() const migraphx::program create_program() const
{ {
migraph::program p; migraphx::program p;
std::size_t axis = 0; std::size_t axis = 0;
migraph::shape s0{migraph::shape::int32_type, {2, 2}}; migraphx::shape s0{migraphx::shape::int32_type, {2, 2}};
migraph::shape s1{migraph::shape::int32_type, {3, 2}}; migraphx::shape s1{migraphx::shape::int32_type, {3, 2}};
migraph::shape s2{migraph::shape::int32_type, {1, 2}}; migraphx::shape s2{migraphx::shape::int32_type, {1, 2}};
auto l0 = p.add_parameter("x", s0); auto l0 = p.add_parameter("x", s0);
auto l1 = p.add_parameter("y", s1); auto l1 = p.add_parameter("y", s1);
auto l2 = p.add_parameter("z", s2); auto l2 = p.add_parameter("z", s2);
p.add_instruction(migraph::op::concat{axis}, l0, l1, l2); p.add_instruction(migraphx::op::concat{axis}, l0, l1, l2);
return p; return p;
} }
}; };
struct test_concat_relu struct test_concat_relu
{ {
migraph::program create_program() const migraphx::program create_program() const
{ {
migraph::program p; migraphx::program p;
std::size_t axis = 0; std::size_t axis = 0;
migraph::shape s0{migraph::shape::float_type, {2, 2}}; migraphx::shape s0{migraphx::shape::float_type, {2, 2}};
migraph::shape s1{migraph::shape::float_type, {3, 2}}; migraphx::shape s1{migraphx::shape::float_type, {3, 2}};
migraph::shape s2{migraph::shape::float_type, {1, 2}}; migraphx::shape s2{migraphx::shape::float_type, {1, 2}};
auto l0 = p.add_parameter("x", s0); auto l0 = p.add_parameter("x", s0);
auto l1 = p.add_parameter("y", s1); auto l1 = p.add_parameter("y", s1);
auto l2 = p.add_parameter("z", s2); auto l2 = p.add_parameter("z", s2);
auto r0 = p.add_instruction(migraph::op::relu{}, l0); auto r0 = p.add_instruction(migraphx::op::relu{}, l0);
auto r1 = p.add_instruction(migraph::op::relu{}, l1); auto r1 = p.add_instruction(migraphx::op::relu{}, l1);
auto r2 = p.add_instruction(migraph::op::relu{}, l2); auto r2 = p.add_instruction(migraphx::op::relu{}, l2);
auto c0 = p.add_instruction(migraph::op::concat{axis}, r0, r1, r2); auto c0 = p.add_instruction(migraphx::op::concat{axis}, r0, r1, r2);
p.add_instruction(migraph::op::relu{}, c0); p.add_instruction(migraphx::op::relu{}, c0);
return p; return p;
} }
}; };
void manual_identity() void manual_identity()
{ {
migraph::program p; migraphx::program p;
std::vector<float> data0 = {0, 1, 2, 3}; std::vector<float> data0 = {0, 1, 2, 3};
migraph::shape s0{migraph::shape::float_type, {2, 2}}; migraphx::shape s0{migraphx::shape::float_type, {2, 2}};
auto l0 = p.add_literal(migraph::literal{s0, data0}); auto l0 = p.add_literal(migraphx::literal{s0, data0});
p.add_instruction(migraph::op::identity{}, l0); p.add_instruction(migraphx::op::identity{}, l0);
p.compile(migraph::gpu::target{}); p.compile(migraphx::gpu::target{});
migraph::program::parameter_map m; migraphx::program::parameter_map m;
for(auto&& x : p.get_parameter_shapes()) for(auto&& x : p.get_parameter_shapes())
{ {
m[x.first] = migraph::gpu::to_gpu(migraph::generate_argument(x.second)); m[x.first] = migraphx::gpu::to_gpu(migraphx::generate_argument(x.second));
} }
auto result = migraph::gpu::from_gpu(p.eval(m)); auto result = migraphx::gpu::from_gpu(p.eval(m));
std::cout << result << std::endl; std::cout << result << std::endl;
} }
void manual_test_concat_relu() void manual_test_concat_relu()
{ {
migraph::program p; migraphx::program p;
std::size_t axis = 0; std::size_t axis = 0;
std::vector<float> data0 = {0, 1, 2, 3}; std::vector<float> data0 = {0, 1, 2, 3};
std::vector<float> data1 = {4, 5, 6, 7, 8, 9}; std::vector<float> data1 = {4, 5, 6, 7, 8, 9};
std::vector<float> data2 = {10, 11}; std::vector<float> data2 = {10, 11};
migraph::shape s0{migraph::shape::float_type, {2, 2}}; migraphx::shape s0{migraphx::shape::float_type, {2, 2}};
migraph::shape s1{migraph::shape::float_type, {3, 2}}; migraphx::shape s1{migraphx::shape::float_type, {3, 2}};
migraph::shape s2{migraph::shape::float_type, {1, 2}}; migraphx::shape s2{migraphx::shape::float_type, {1, 2}};
auto l0 = p.add_literal(migraph::literal{s0, data0}); auto l0 = p.add_literal(migraphx::literal{s0, data0});
auto l1 = p.add_literal(migraph::literal{s1, data1}); auto l1 = p.add_literal(migraphx::literal{s1, data1});
auto l2 = p.add_literal(migraph::literal{s2, data2}); auto l2 = p.add_literal(migraphx::literal{s2, data2});
auto r0 = p.add_instruction(migraph::op::relu{}, l0); auto r0 = p.add_instruction(migraphx::op::relu{}, l0);
auto r1 = p.add_instruction(migraph::op::relu{}, l1); auto r1 = p.add_instruction(migraphx::op::relu{}, l1);
auto r2 = p.add_instruction(migraph::op::relu{}, l2); auto r2 = p.add_instruction(migraphx::op::relu{}, l2);
auto c0 = p.add_instruction(migraph::op::concat{axis}, r0, r1, r2); auto c0 = p.add_instruction(migraphx::op::concat{axis}, r0, r1, r2);
p.add_instruction(migraph::op::relu{}, c0); p.add_instruction(migraphx::op::relu{}, c0);
p.compile(migraph::gpu::target{}); p.compile(migraphx::gpu::target{});
migraph::program::parameter_map m; migraphx::program::parameter_map m;
for(auto&& x : p.get_parameter_shapes()) for(auto&& x : p.get_parameter_shapes())
{ {
m[x.first] = migraph::gpu::to_gpu(migraph::generate_argument(x.second)); m[x.first] = migraphx::gpu::to_gpu(migraphx::generate_argument(x.second));
} }
auto result = migraph::gpu::from_gpu(p.eval(m)); auto result = migraphx::gpu::from_gpu(p.eval(m));
std::cout << result << std::endl; std::cout << result << std::endl;
} }
struct test_conv_bn_relu_pooling2 struct test_conv_bn_relu_pooling2
{ {
static migraph::instruction_ref static migraphx::instruction_ref
add_bn(migraph::program& p, migraph::instruction_ref x, std::size_t channels) add_bn(migraphx::program& p, migraphx::instruction_ref x, std::size_t channels)
{ {
migraph::shape vars{migraph::shape::float_type, {channels}}; migraphx::shape vars{migraphx::shape::float_type, {channels}};
auto scale = p.add_literal(migraph::abs(migraph::generate_literal(vars, 1 + channels))); auto scale = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1 + channels)));
auto bias = p.add_literal(migraph::abs(migraph::generate_literal(vars, 2 + channels))); auto bias = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2 + channels)));
auto mean = p.add_literal(migraph::abs(migraph::generate_literal(vars, 3 + channels))); auto mean = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3 + channels)));
auto variance = p.add_literal(migraph::abs(migraph::generate_literal(vars, 4 + channels))); auto variance =
p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4 + channels)));
return p.add_instruction( return p.add_instruction(
migraph::op::batch_norm_inference{}, x, scale, bias, mean, variance); migraphx::op::batch_norm_inference{}, x, scale, bias, mean, variance);
} }
migraph::program create_program() const migraphx::program create_program() const
{ {
migraph::program p; migraphx::program p;
migraph::shape xs1{migraph::shape::float_type, {1, 512, 7, 7}}; migraphx::shape xs1{migraphx::shape::float_type, {1, 512, 7, 7}};
migraph::shape xs2{migraph::shape::float_type, {1, 1024, 14, 14}}; migraphx::shape xs2{migraphx::shape::float_type, {1, 1024, 14, 14}};
migraph::shape ws1{migraph::shape::float_type, {2048, 512, 1, 1}}; migraphx::shape ws1{migraphx::shape::float_type, {2048, 512, 1, 1}};
migraph::shape ws2{migraph::shape::float_type, {2048, 1024, 1, 1}}; migraphx::shape ws2{migraphx::shape::float_type, {2048, 1024, 1, 1}};
auto x1 = p.add_parameter("x1", xs1); auto x1 = p.add_parameter("x1", xs1);
auto w1 = p.add_parameter("w1", ws1); auto w1 = p.add_parameter("w1", ws1);
auto conv1 = p.add_instruction(migraph::op::convolution{{0, 0}, {1, 1}, {1, 1}}, x1, w1); auto conv1 = p.add_instruction(migraphx::op::convolution{{0, 0}, {1, 1}, {1, 1}}, x1, w1);
auto bn1 = add_bn(p, conv1, 2048); auto bn1 = add_bn(p, conv1, 2048);
auto x2 = p.add_parameter("x2", xs2); auto x2 = p.add_parameter("x2", xs2);
auto w2 = p.add_parameter("w2", ws2); auto w2 = p.add_parameter("w2", ws2);
auto conv2 = p.add_instruction(migraph::op::convolution{{0, 0}, {2, 2}, {1, 1}}, x2, w2); auto conv2 = p.add_instruction(migraphx::op::convolution{{0, 0}, {2, 2}, {1, 1}}, x2, w2);
auto bn2 = add_bn(p, conv2, 2048); auto bn2 = add_bn(p, conv2, 2048);
auto add = p.add_instruction(migraph::op::add{}, bn1, bn2); auto add = p.add_instruction(migraphx::op::add{}, bn1, bn2);
auto relu = p.add_instruction(migraph::op::relu{}, add); auto relu = p.add_instruction(migraphx::op::relu{}, add);
p.add_instruction(migraph::op::pooling{"average", {1, 1}, {2, 2}, {3, 3}}, relu); p.add_instruction(migraphx::op::pooling{"average", {1, 1}, {2, 2}, {3, 3}}, relu);
return p; return p;
} }
}; };
......
#include <migraph/program.hpp> #include <migraphx/program.hpp>
#include <migraph/argument.hpp> #include <migraphx/argument.hpp>
#include <migraph/shape.hpp> #include <migraphx/shape.hpp>
struct sum_op struct sum_op
{ {
std::string name() const { return "sum"; } std::string name() const { return "sum"; }
migraph::argument migraphx::argument
compute(migraph::context&, const migraph::shape&, std::vector<migraph::argument> args) const compute(migraphx::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const
{ {
migraph::argument result; migraphx::argument result;
if(args.size() != 2) if(args.size() != 2)
MIGRAPH_THROW("Wrong args"); MIGRAPH_THROW("Wrong args");
if(args[0].get_shape() != args[1].get_shape()) if(args[0].get_shape() != args[1].get_shape())
...@@ -19,12 +19,12 @@ struct sum_op ...@@ -19,12 +19,12 @@ struct sum_op
MIGRAPH_THROW("Wrong args"); MIGRAPH_THROW("Wrong args");
args[0].visit_at([&](auto x) { args[0].visit_at([&](auto x) {
args[1].visit_at([&](auto y) { result = migraph::literal{x + y}.get_argument(); }); args[1].visit_at([&](auto y) { result = migraphx::literal{x + y}.get_argument(); });
}); });
return result; return result;
} }
migraph::shape compute_shape(std::vector<migraph::shape> inputs) const migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
{ {
if(inputs.size() != 2) if(inputs.size() != 2)
MIGRAPH_THROW("Wrong inputs"); MIGRAPH_THROW("Wrong inputs");
...@@ -35,10 +35,10 @@ struct sum_op ...@@ -35,10 +35,10 @@ struct sum_op
struct minus_op struct minus_op
{ {
std::string name() const { return "minus"; } std::string name() const { return "minus"; }
migraph::argument migraphx::argument
compute(migraph::context&, const migraph::shape&, std::vector<migraph::argument> args) const compute(migraphx::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const
{ {
migraph::argument result; migraphx::argument result;
if(args.size() != 2) if(args.size() != 2)
MIGRAPH_THROW("Wrong args"); MIGRAPH_THROW("Wrong args");
if(args[0].get_shape() != args[1].get_shape()) if(args[0].get_shape() != args[1].get_shape())
...@@ -49,12 +49,12 @@ struct minus_op ...@@ -49,12 +49,12 @@ struct minus_op
MIGRAPH_THROW("Wrong args"); MIGRAPH_THROW("Wrong args");
args[0].visit_at([&](auto x) { args[0].visit_at([&](auto x) {
args[1].visit_at([&](auto y) { result = migraph::literal{x - y}.get_argument(); }); args[1].visit_at([&](auto y) { result = migraphx::literal{x - y}.get_argument(); });
}); });
return result; return result;
} }
migraph::shape compute_shape(std::vector<migraph::shape> inputs) const migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
{ {
if(inputs.size() != 2) if(inputs.size() != 2)
MIGRAPH_THROW("Wrong inputs"); MIGRAPH_THROW("Wrong inputs");
...@@ -65,35 +65,35 @@ struct minus_op ...@@ -65,35 +65,35 @@ struct minus_op
struct pass_op struct pass_op
{ {
std::string name() const { return "pass"; } std::string name() const { return "pass"; }
migraph::argument migraphx::argument
compute(migraph::context&, const migraph::shape&, std::vector<migraph::argument> args) const compute(migraphx::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const
{ {
if(args.empty()) if(args.empty())
return {}; return {};
return args.front(); return args.front();
} }
migraph::shape compute_shape(std::vector<migraph::shape> inputs) const migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
{ {
if(inputs.empty()) if(inputs.empty())
return {}; return {};
return inputs.front(); return inputs.front();
} }
int output_alias(const std::vector<migraph::shape>&) const { return 0; } int output_alias(const std::vector<migraphx::shape>&) const { return 0; }
}; };
struct pass_standard_op struct pass_standard_op
{ {
std::string name() const { return "pass"; } std::string name() const { return "pass"; }
migraph::argument migraphx::argument
compute(migraph::context&, const migraph::shape&, std::vector<migraph::argument> args) const compute(migraphx::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const
{ {
if(args.empty()) if(args.empty())
return {}; return {};
return args.front(); return args.front();
} }
migraph::shape compute_shape(std::vector<migraph::shape> inputs) const migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
{ {
for(auto&& input : inputs) for(auto&& input : inputs)
{ {
...@@ -104,37 +104,38 @@ struct pass_standard_op ...@@ -104,37 +104,38 @@ struct pass_standard_op
return {}; return {};
return inputs.front(); return inputs.front();
} }
int output_alias(const std::vector<migraph::shape>&) const { return 0; } int output_alias(const std::vector<migraphx::shape>&) const { return 0; }
}; };
struct nop struct nop
{ {
std::string name() const { return "nop"; } std::string name() const { return "nop"; }
migraph::argument migraphx::argument compute(migraphx::context&,
compute(migraph::context&, const migraph::shape&, const std::vector<migraph::argument>&) const const migraphx::shape&,
const std::vector<migraphx::argument>&) const
{ {
return {}; return {};
} }
migraph::shape compute_shape(const std::vector<migraph::shape>&) const { return {}; } migraphx::shape compute_shape(const std::vector<migraphx::shape>&) const { return {}; }
}; };
inline migraph::literal get_2x2() inline migraphx::literal get_2x2()
{ {
return migraph::literal{{migraph::shape::float_type, {2, 2}}, {1, 2, 3, 4}}; return migraphx::literal{{migraphx::shape::float_type, {2, 2}}, {1, 2, 3, 4}};
} }
inline migraph::literal get_2x2_transposed() inline migraphx::literal get_2x2_transposed()
{ {
return migraph::literal{{migraph::shape::float_type, {2, 2}, {1, 2}}, {1, 2, 3, 4}}; return migraphx::literal{{migraphx::shape::float_type, {2, 2}, {1, 2}}, {1, 2, 3, 4}};
} }
inline migraph::literal get_2() inline migraphx::literal get_2()
{ {
return migraph::literal{{migraph::shape::float_type, {2}}, {1, 2}}; return migraphx::literal{{migraphx::shape::float_type, {2}}, {1, 2}};
} }
inline migraph::literal get_2_broadcasted() inline migraphx::literal get_2_broadcasted()
{ {
return migraph::literal{{migraph::shape::float_type, {2, 1}, {1, 0}}, {1, 2}}; return migraphx::literal{{migraphx::shape::float_type, {2, 1}, {1, 0}}, {1, 2}};
} }
#include <migraph/literal.hpp> #include <migraphx/literal.hpp>
#include <sstream> #include <sstream>
#include <string> #include <string>
#include "test.hpp" #include "test.hpp"
TEST_CASE(literal_test) TEST_CASE(literal_test)
{ {
EXPECT(migraph::literal{1} == migraph::literal{1}); EXPECT(migraphx::literal{1} == migraphx::literal{1});
EXPECT(migraph::literal{1} != migraph::literal{2}); EXPECT(migraphx::literal{1} != migraphx::literal{2});
EXPECT(migraph::literal{} == migraph::literal{}); EXPECT(migraphx::literal{} == migraphx::literal{});
EXPECT(migraph::literal{} != migraph::literal{2}); EXPECT(migraphx::literal{} != migraphx::literal{2});
migraph::literal l1{1}; migraphx::literal l1{1};
migraph::literal l2 = l1; // NOLINT migraphx::literal l2 = l1; // NOLINT
EXPECT(l1 == l2); EXPECT(l1 == l2);
EXPECT(l1.at<int>(0) == 1); EXPECT(l1.at<int>(0) == 1);
EXPECT(!l1.empty()); EXPECT(!l1.empty());
EXPECT(!l2.empty()); EXPECT(!l2.empty());
migraph::literal l3{}; migraphx::literal l3{};
migraph::literal l4{}; migraphx::literal l4{};
EXPECT(l3 == l4); EXPECT(l3 == l4);
EXPECT(l3.empty()); EXPECT(l3.empty());
EXPECT(l4.empty()); EXPECT(l4.empty());
...@@ -27,7 +27,7 @@ TEST_CASE(literal_test) ...@@ -27,7 +27,7 @@ TEST_CASE(literal_test)
TEST_CASE(literal_os1) TEST_CASE(literal_os1)
{ {
migraph::literal l{1}; migraphx::literal l{1};
std::stringstream ss; std::stringstream ss;
ss << l; ss << l;
EXPECT(ss.str() == "1"); EXPECT(ss.str() == "1");
...@@ -35,7 +35,7 @@ TEST_CASE(literal_os1) ...@@ -35,7 +35,7 @@ TEST_CASE(literal_os1)
TEST_CASE(literal_os2) TEST_CASE(literal_os2)
{ {
migraph::literal l{}; migraphx::literal l{};
std::stringstream ss; std::stringstream ss;
ss << l; ss << l;
EXPECT(ss.str().empty()); EXPECT(ss.str().empty());
...@@ -43,8 +43,8 @@ TEST_CASE(literal_os2) ...@@ -43,8 +43,8 @@ TEST_CASE(literal_os2)
TEST_CASE(literal_os3) TEST_CASE(literal_os3)
{ {
migraph::shape s{migraph::shape::int64_type, {3}}; migraphx::shape s{migraphx::shape::int64_type, {3}};
migraph::literal l{s, {1, 2, 3}}; migraphx::literal l{s, {1, 2, 3}};
std::stringstream ss; std::stringstream ss;
ss << l; ss << l;
EXPECT(ss.str() == "1, 2, 3"); EXPECT(ss.str() == "1, 2, 3");
......
#include <migraph/matcher.hpp> #include <migraphx/matcher.hpp>
#include <migraph/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <test.hpp> #include <test.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
namespace match = migraph::match; namespace match = migraphx::match;
template <class M> template <class M>
migraph::match::matcher_result find_match(migraph::program& p, M&& m) migraphx::match::matcher_result find_match(migraphx::program& p, M&& m)
{ {
migraph::match::matcher_result result; migraphx::match::matcher_result result;
for(auto ins : migraph::iterator_for(p)) for(auto ins : migraphx::iterator_for(p))
{ {
result = migraph::match::match_instruction(p, ins, m); result = migraphx::match::match_instruction(p, ins, m);
if(result.result != p.end()) if(result.result != p.end())
return result; return result;
} }
...@@ -20,7 +20,7 @@ migraph::match::matcher_result find_match(migraph::program& p, M&& m) ...@@ -20,7 +20,7 @@ migraph::match::matcher_result find_match(migraph::program& p, M&& m)
void match1() void match1()
{ {
migraph::program p; migraphx::program p;
auto l = p.add_literal(1); auto l = p.add_literal(1);
auto m = match::standard_shape(); auto m = match::standard_shape();
auto r = find_match(p, m); auto r = find_match(p, m);
...@@ -29,7 +29,7 @@ void match1() ...@@ -29,7 +29,7 @@ void match1()
TEST_CASE(match_name1) TEST_CASE(match_name1)
{ {
migraph::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
...@@ -41,7 +41,7 @@ TEST_CASE(match_name1) ...@@ -41,7 +41,7 @@ TEST_CASE(match_name1)
TEST_CASE(match_name2) TEST_CASE(match_name2)
{ {
migraph::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
...@@ -53,7 +53,7 @@ TEST_CASE(match_name2) ...@@ -53,7 +53,7 @@ TEST_CASE(match_name2)
TEST_CASE(match_name3) TEST_CASE(match_name3)
{ {
migraph::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
...@@ -65,7 +65,7 @@ TEST_CASE(match_name3) ...@@ -65,7 +65,7 @@ TEST_CASE(match_name3)
TEST_CASE(match_arg1) TEST_CASE(match_arg1)
{ {
migraph::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
...@@ -77,7 +77,7 @@ TEST_CASE(match_arg1) ...@@ -77,7 +77,7 @@ TEST_CASE(match_arg1)
TEST_CASE(match_arg2) TEST_CASE(match_arg2)
{ {
migraph::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
...@@ -89,7 +89,7 @@ TEST_CASE(match_arg2) ...@@ -89,7 +89,7 @@ TEST_CASE(match_arg2)
TEST_CASE(match_arg3) TEST_CASE(match_arg3)
{ {
migraph::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
...@@ -101,7 +101,7 @@ TEST_CASE(match_arg3) ...@@ -101,7 +101,7 @@ TEST_CASE(match_arg3)
TEST_CASE(match_arg4) TEST_CASE(match_arg4)
{ {
migraph::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
...@@ -113,7 +113,7 @@ TEST_CASE(match_arg4) ...@@ -113,7 +113,7 @@ TEST_CASE(match_arg4)
TEST_CASE(match_arg5) TEST_CASE(match_arg5)
{ {
migraph::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
...@@ -125,7 +125,7 @@ TEST_CASE(match_arg5) ...@@ -125,7 +125,7 @@ TEST_CASE(match_arg5)
TEST_CASE(match_arg6) TEST_CASE(match_arg6)
{ {
migraph::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
...@@ -137,7 +137,7 @@ TEST_CASE(match_arg6) ...@@ -137,7 +137,7 @@ TEST_CASE(match_arg6)
TEST_CASE(match_arg7) TEST_CASE(match_arg7)
{ {
migraph::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
...@@ -150,7 +150,7 @@ TEST_CASE(match_arg7) ...@@ -150,7 +150,7 @@ TEST_CASE(match_arg7)
TEST_CASE(match_args1) TEST_CASE(match_args1)
{ {
migraph::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
...@@ -163,7 +163,7 @@ TEST_CASE(match_args1) ...@@ -163,7 +163,7 @@ TEST_CASE(match_args1)
TEST_CASE(match_args2) TEST_CASE(match_args2)
{ {
migraph::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
...@@ -176,7 +176,7 @@ TEST_CASE(match_args2) ...@@ -176,7 +176,7 @@ TEST_CASE(match_args2)
TEST_CASE(match_args3) TEST_CASE(match_args3)
{ {
migraph::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
...@@ -188,7 +188,7 @@ TEST_CASE(match_args3) ...@@ -188,7 +188,7 @@ TEST_CASE(match_args3)
TEST_CASE(match_args4) TEST_CASE(match_args4)
{ {
migraph::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum1 = p.add_instruction(sum_op{}, one, two); auto sum1 = p.add_instruction(sum_op{}, one, two);
...@@ -202,7 +202,7 @@ TEST_CASE(match_args4) ...@@ -202,7 +202,7 @@ TEST_CASE(match_args4)
TEST_CASE(match_args5) TEST_CASE(match_args5)
{ {
migraph::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
...@@ -215,7 +215,7 @@ TEST_CASE(match_args5) ...@@ -215,7 +215,7 @@ TEST_CASE(match_args5)
TEST_CASE(match_args6) TEST_CASE(match_args6)
{ {
migraph::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
...@@ -227,7 +227,7 @@ TEST_CASE(match_args6) ...@@ -227,7 +227,7 @@ TEST_CASE(match_args6)
TEST_CASE(match_args7) TEST_CASE(match_args7)
{ {
migraph::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
...@@ -241,7 +241,7 @@ TEST_CASE(match_args7) ...@@ -241,7 +241,7 @@ TEST_CASE(match_args7)
TEST_CASE(match_either_args1) TEST_CASE(match_either_args1)
{ {
migraph::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum1 = p.add_instruction(sum_op{}, one, two); auto sum1 = p.add_instruction(sum_op{}, one, two);
...@@ -255,7 +255,7 @@ TEST_CASE(match_either_args1) ...@@ -255,7 +255,7 @@ TEST_CASE(match_either_args1)
TEST_CASE(match_either_args2) TEST_CASE(match_either_args2)
{ {
migraph::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum1 = p.add_instruction(sum_op{}, one, two); auto sum1 = p.add_instruction(sum_op{}, one, two);
...@@ -269,7 +269,7 @@ TEST_CASE(match_either_args2) ...@@ -269,7 +269,7 @@ TEST_CASE(match_either_args2)
TEST_CASE(match_either_args3) TEST_CASE(match_either_args3)
{ {
migraph::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum1 = p.add_instruction(sum_op{}, one, two); auto sum1 = p.add_instruction(sum_op{}, one, two);
...@@ -283,7 +283,7 @@ TEST_CASE(match_either_args3) ...@@ -283,7 +283,7 @@ TEST_CASE(match_either_args3)
TEST_CASE(match_all_of1) TEST_CASE(match_all_of1)
{ {
migraph::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
...@@ -296,7 +296,7 @@ TEST_CASE(match_all_of1) ...@@ -296,7 +296,7 @@ TEST_CASE(match_all_of1)
TEST_CASE(match_all_of2) TEST_CASE(match_all_of2)
{ {
migraph::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
...@@ -309,7 +309,7 @@ TEST_CASE(match_all_of2) ...@@ -309,7 +309,7 @@ TEST_CASE(match_all_of2)
TEST_CASE(match_any_of1) TEST_CASE(match_any_of1)
{ {
migraph::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
...@@ -322,7 +322,7 @@ TEST_CASE(match_any_of1) ...@@ -322,7 +322,7 @@ TEST_CASE(match_any_of1)
TEST_CASE(match_any_of2) TEST_CASE(match_any_of2)
{ {
migraph::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
...@@ -335,7 +335,7 @@ TEST_CASE(match_any_of2) ...@@ -335,7 +335,7 @@ TEST_CASE(match_any_of2)
TEST_CASE(match_none_of1) TEST_CASE(match_none_of1)
{ {
migraph::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
...@@ -348,7 +348,7 @@ TEST_CASE(match_none_of1) ...@@ -348,7 +348,7 @@ TEST_CASE(match_none_of1)
TEST_CASE(match_none_of2) TEST_CASE(match_none_of2)
{ {
migraph::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
...@@ -361,7 +361,7 @@ TEST_CASE(match_none_of2) ...@@ -361,7 +361,7 @@ TEST_CASE(match_none_of2)
TEST_CASE(match_bind1) TEST_CASE(match_bind1)
{ {
migraph::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
...@@ -382,18 +382,18 @@ TEST_CASE(match_bind1) ...@@ -382,18 +382,18 @@ TEST_CASE(match_bind1)
struct match_find_sum struct match_find_sum
{ {
migraph::instruction_ref ins; migraphx::instruction_ref ins;
auto matcher() const { return match::name("sum"); } auto matcher() const { return match::name("sum"); }
void apply(migraph::program&, match::matcher_result r) const { EXPECT(bool{r.result == ins}); } void apply(migraphx::program&, match::matcher_result r) const { EXPECT(bool{r.result == ins}); }
}; };
struct match_find_literal struct match_find_literal
{ {
migraph::instruction_ref ins; migraphx::instruction_ref ins;
auto matcher() const { return match::name("@literal"); } auto matcher() const { return match::name("@literal"); }
void apply(migraph::program&, match::matcher_result r) const void apply(migraphx::program&, match::matcher_result r) const
{ {
EXPECT(bool{r.result != ins}); EXPECT(bool{r.result != ins});
EXPECT(r.result->name() == "@literal"); EXPECT(r.result->name() == "@literal");
...@@ -402,7 +402,7 @@ struct match_find_literal ...@@ -402,7 +402,7 @@ struct match_find_literal
TEST_CASE(match_finder) TEST_CASE(match_finder)
{ {
migraph::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
......
#include <migraph/memory_coloring.hpp> #include <migraphx/memory_coloring.hpp>
#include <migraph/operators.hpp> #include <migraphx/operators.hpp>
#include <migraph/generate.hpp> #include <migraphx/generate.hpp>
#include <migraph/instruction.hpp> #include <migraphx/instruction.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
#include <test.hpp> #include <test.hpp>
struct memory_coloring_target struct memory_coloring_target
{ {
std::string name() const { return "memory_coloring"; } std::string name() const { return "memory_coloring"; }
std::vector<migraph::pass> get_passes(migraph::context&) const std::vector<migraphx::pass> get_passes(migraphx::context&) const
{ {
return {migraph::memory_coloring{"allocate", true}}; return {migraphx::memory_coloring{"allocate", true}};
} }
migraph::context get_context() const { return {}; } migraphx::context get_context() const { return {}; }
}; };
struct allocate struct allocate
{ {
migraph::shape s{}; migraphx::shape s{};
std::string name() const { return "allocate"; } std::string name() const { return "allocate"; }
migraph::shape compute_shape(const std::vector<migraph::shape>& inputs) const migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const
{ {
migraph::check_shapes{inputs, *this}.has(1); migraphx::check_shapes{inputs, *this}.has(1);
return inputs.front(); return inputs.front();
} }
migraph::argument compute(migraph::context&, migraphx::argument compute(migraphx::context&,
const migraph::shape& output_shape, const migraphx::shape& output_shape,
const std::vector<migraph::argument>&) const const std::vector<migraphx::argument>&) const
{ {
return {output_shape}; return {output_shape};
} }
}; };
migraph::instruction_ref add_alloc(migraph::program& p, const migraph::shape& s) migraphx::instruction_ref add_alloc(migraphx::program& p, const migraphx::shape& s)
{ {
auto a0 = p.add_outline(s); auto a0 = p.add_outline(s);
return p.add_instruction(allocate{}, a0); return p.add_instruction(allocate{}, a0);
} }
bool no_allocate(const migraph::program& p) bool no_allocate(const migraphx::program& p)
{ {
return std::none_of(p.begin(), p.end(), [](auto&& ins) { return ins.name() == "allocate"; }); return std::none_of(p.begin(), p.end(), [](auto&& ins) { return ins.name() == "allocate"; });
} }
TEST_CASE(test1) TEST_CASE(test1)
{ {
migraph::program p; migraphx::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a1 = add_alloc(p, {migraphx::shape::float_type, {8}});
auto p1 = p.add_instruction(pass_op{}, a1); auto p1 = p.add_instruction(pass_op{}, a1);
auto a2 = add_alloc(p, {migraph::shape::float_type, {40}}); auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}});
p.add_instruction(pass_op{}, a2, p1); p.add_instruction(pass_op{}, a2, p1);
p.compile(memory_coloring_target{}); p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 192); CHECK(p.get_parameter_shape("scratch").bytes() == 192);
...@@ -57,12 +57,12 @@ TEST_CASE(test1) ...@@ -57,12 +57,12 @@ TEST_CASE(test1)
TEST_CASE(test2) TEST_CASE(test2)
{ {
migraph::program p; migraphx::program p;
auto input = p.add_parameter("input", migraph::shape{migraph::shape::float_type, {16}}); auto input = p.add_parameter("input", migraphx::shape{migraphx::shape::float_type, {16}});
auto a1 = add_alloc(p, {migraph::shape::float_type, {128}}); auto a1 = add_alloc(p, {migraphx::shape::float_type, {128}});
auto p1 = p.add_instruction(pass_op{}, a1, input); auto p1 = p.add_instruction(pass_op{}, a1, input);
auto p2 = add_alloc(p, {migraph::shape::float_type, {40}}); auto p2 = add_alloc(p, {migraphx::shape::float_type, {40}});
p.add_instruction(pass_op{}, p2, p1); p.add_instruction(pass_op{}, p2, p1);
p.compile(memory_coloring_target{}); p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 672); CHECK(p.get_parameter_shape("scratch").bytes() == 672);
...@@ -71,11 +71,11 @@ TEST_CASE(test2) ...@@ -71,11 +71,11 @@ TEST_CASE(test2)
TEST_CASE(test3) TEST_CASE(test3)
{ {
migraph::program p; migraphx::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a1 = add_alloc(p, {migraphx::shape::float_type, {8}});
auto p2 = add_alloc(p, {migraph::shape::float_type, {128}}); auto p2 = add_alloc(p, {migraphx::shape::float_type, {128}});
auto p1 = p.add_instruction(pass_op{}, p2, a1); auto p1 = p.add_instruction(pass_op{}, p2, a1);
auto p3 = add_alloc(p, {migraph::shape::float_type, {40}}); auto p3 = add_alloc(p, {migraphx::shape::float_type, {40}});
p.add_instruction(pass_op{}, p3, p1); p.add_instruction(pass_op{}, p3, p1);
p.compile(memory_coloring_target{}); p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 704); // The optimal solution is actually 672 CHECK(p.get_parameter_shape("scratch").bytes() == 704); // The optimal solution is actually 672
...@@ -84,11 +84,11 @@ TEST_CASE(test3) ...@@ -84,11 +84,11 @@ TEST_CASE(test3)
TEST_CASE(test4) TEST_CASE(test4)
{ {
migraph::program p; migraphx::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {0}}); auto a1 = add_alloc(p, {migraphx::shape::float_type, {0}});
auto p2 = add_alloc(p, {migraph::shape::float_type, {128}}); auto p2 = add_alloc(p, {migraphx::shape::float_type, {128}});
auto p1 = p.add_instruction(pass_op{}, p2, a1); auto p1 = p.add_instruction(pass_op{}, p2, a1);
auto p3 = add_alloc(p, {migraph::shape::float_type, {40}}); auto p3 = add_alloc(p, {migraphx::shape::float_type, {40}});
p.add_instruction(pass_op{}, p3, p1); p.add_instruction(pass_op{}, p3, p1);
p.compile(memory_coloring_target{}); p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 672); CHECK(p.get_parameter_shape("scratch").bytes() == 672);
...@@ -97,10 +97,10 @@ TEST_CASE(test4) ...@@ -97,10 +97,10 @@ TEST_CASE(test4)
TEST_CASE(test5) TEST_CASE(test5)
{ {
migraph::program p; migraphx::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {40}}); auto a1 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto p1 = p.add_instruction(pass_op{}, a1); auto p1 = p.add_instruction(pass_op{}, a1);
auto p2 = add_alloc(p, {migraph::shape::float_type, {8}}); auto p2 = add_alloc(p, {migraphx::shape::float_type, {8}});
p.add_instruction(pass_op{}, p2, p1); p.add_instruction(pass_op{}, p2, p1);
p.compile(memory_coloring_target{}); p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 192); CHECK(p.get_parameter_shape("scratch").bytes() == 192);
...@@ -109,11 +109,11 @@ TEST_CASE(test5) ...@@ -109,11 +109,11 @@ TEST_CASE(test5)
TEST_CASE(test6) TEST_CASE(test6)
{ {
migraph::program p; migraphx::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a1 = add_alloc(p, {migraphx::shape::float_type, {8}});
auto p1 = p.add_instruction(pass_op{}, a1); auto p1 = p.add_instruction(pass_op{}, a1);
auto p2 = add_alloc(p, {migraph::shape::float_type, {40}}); auto p2 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto p3 = add_alloc(p, {migraph::shape::float_type, {40}}); auto p3 = add_alloc(p, {migraphx::shape::float_type, {40}});
p.add_instruction(pass_op{}, p3, p2, p1); p.add_instruction(pass_op{}, p3, p2, p1);
p.compile(memory_coloring_target{}); p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 352); CHECK(p.get_parameter_shape("scratch").bytes() == 352);
...@@ -122,11 +122,11 @@ TEST_CASE(test6) ...@@ -122,11 +122,11 @@ TEST_CASE(test6)
TEST_CASE(test7) TEST_CASE(test7)
{ {
migraph::program p; migraphx::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a1 = add_alloc(p, {migraphx::shape::float_type, {8}});
auto p1 = p.add_instruction(pass_op{}, a1); auto p1 = p.add_instruction(pass_op{}, a1);
auto p2 = add_alloc(p, {migraph::shape::float_type, {40}}); auto p2 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto p3 = add_alloc(p, {migraph::shape::float_type, {8}}); auto p3 = add_alloc(p, {migraphx::shape::float_type, {8}});
p.add_instruction(pass_op{}, p3, p2, p1); p.add_instruction(pass_op{}, p3, p2, p1);
p.compile(memory_coloring_target{}); p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 224); CHECK(p.get_parameter_shape("scratch").bytes() == 224);
...@@ -135,11 +135,11 @@ TEST_CASE(test7) ...@@ -135,11 +135,11 @@ TEST_CASE(test7)
TEST_CASE(test8) TEST_CASE(test8)
{ {
migraph::program p; migraphx::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a1 = add_alloc(p, {migraphx::shape::float_type, {8}});
auto p1 = p.add_instruction(pass_op{}, a1); auto p1 = p.add_instruction(pass_op{}, a1);
auto p2 = add_alloc(p, {migraph::shape::float_type, {40}}); auto p2 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto p3 = add_alloc(p, {migraph::shape::float_type, {192}}); auto p3 = add_alloc(p, {migraphx::shape::float_type, {192}});
p.add_instruction(pass_op{}, p3, p2, p1); p.add_instruction(pass_op{}, p3, p2, p1);
p.compile(memory_coloring_target{}); p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 960); CHECK(p.get_parameter_shape("scratch").bytes() == 960);
...@@ -148,11 +148,11 @@ TEST_CASE(test8) ...@@ -148,11 +148,11 @@ TEST_CASE(test8)
TEST_CASE(test9) TEST_CASE(test9)
{ {
migraph::program p; migraphx::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a1 = add_alloc(p, {migraphx::shape::float_type, {8}});
auto p1 = p.add_instruction(pass_op{}, a1); auto p1 = p.add_instruction(pass_op{}, a1);
auto p2 = add_alloc(p, {migraph::shape::float_type, {8}}); auto p2 = add_alloc(p, {migraphx::shape::float_type, {8}});
auto p3 = add_alloc(p, {migraph::shape::float_type, {8}}); auto p3 = add_alloc(p, {migraphx::shape::float_type, {8}});
p.add_instruction(pass_op{}, p3, p2, p1); p.add_instruction(pass_op{}, p3, p2, p1);
p.compile(memory_coloring_target{}); p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 96); CHECK(p.get_parameter_shape("scratch").bytes() == 96);
...@@ -161,8 +161,8 @@ TEST_CASE(test9) ...@@ -161,8 +161,8 @@ TEST_CASE(test9)
TEST_CASE(test10) TEST_CASE(test10)
{ {
migraph::program p; migraphx::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a1 = add_alloc(p, {migraphx::shape::float_type, {8}});
p.add_instruction(pass_op{}, a1); p.add_instruction(pass_op{}, a1);
p.compile(memory_coloring_target{}); p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 32); CHECK(p.get_parameter_shape("scratch").bytes() == 32);
...@@ -171,11 +171,11 @@ TEST_CASE(test10) ...@@ -171,11 +171,11 @@ TEST_CASE(test10)
TEST_CASE(test11) TEST_CASE(test11)
{ {
migraph::program p; migraphx::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a1 = add_alloc(p, {migraphx::shape::float_type, {8}});
auto p1 = p.add_instruction(pass_op{}, a1); auto p1 = p.add_instruction(pass_op{}, a1);
auto a2 = add_alloc(p, {migraph::shape::float_type, {40}}); auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto a3 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a3 = add_alloc(p, {migraphx::shape::float_type, {8}});
auto p2 = p.add_instruction(pass_op{}, a2, p1); auto p2 = p.add_instruction(pass_op{}, a2, p1);
p.add_instruction(pass_op{}, a3, p2); p.add_instruction(pass_op{}, a3, p2);
p.compile(memory_coloring_target{}); p.compile(memory_coloring_target{});
...@@ -185,11 +185,11 @@ TEST_CASE(test11) ...@@ -185,11 +185,11 @@ TEST_CASE(test11)
TEST_CASE(test12) TEST_CASE(test12)
{ {
migraph::program p; migraphx::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {40}}); auto a1 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto p1 = p.add_instruction(pass_op{}, a1); auto p1 = p.add_instruction(pass_op{}, a1);
auto a2 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a2 = add_alloc(p, {migraphx::shape::float_type, {8}});
auto a3 = add_alloc(p, {migraph::shape::float_type, {40}}); auto a3 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto p2 = p.add_instruction(pass_op{}, a2, p1); auto p2 = p.add_instruction(pass_op{}, a2, p1);
p.add_instruction(pass_op{}, a3, p2); p.add_instruction(pass_op{}, a3, p2);
p.compile(memory_coloring_target{}); p.compile(memory_coloring_target{});
...@@ -199,11 +199,11 @@ TEST_CASE(test12) ...@@ -199,11 +199,11 @@ TEST_CASE(test12)
TEST_CASE(test13) TEST_CASE(test13)
{ {
migraph::program p; migraphx::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a1 = add_alloc(p, {migraphx::shape::float_type, {8}});
auto a3 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a3 = add_alloc(p, {migraphx::shape::float_type, {8}});
auto p1 = p.add_instruction(pass_op{}, a1); auto p1 = p.add_instruction(pass_op{}, a1);
auto a2 = add_alloc(p, {migraph::shape::float_type, {40}}); auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto p2 = p.add_instruction(pass_op{}, a2, p1); auto p2 = p.add_instruction(pass_op{}, a2, p1);
p.add_instruction(pass_op{}, a3, p2); p.add_instruction(pass_op{}, a3, p2);
p.compile(memory_coloring_target{}); p.compile(memory_coloring_target{});
...@@ -213,10 +213,10 @@ TEST_CASE(test13) ...@@ -213,10 +213,10 @@ TEST_CASE(test13)
TEST_CASE(test14) TEST_CASE(test14)
{ {
migraph::program p; migraphx::program p;
auto a3 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a3 = add_alloc(p, {migraphx::shape::float_type, {8}});
auto a2 = add_alloc(p, {migraph::shape::float_type, {40}}); auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a1 = add_alloc(p, {migraphx::shape::float_type, {8}});
auto p1 = p.add_instruction(pass_op{}, a1); auto p1 = p.add_instruction(pass_op{}, a1);
auto p2 = p.add_instruction(pass_op{}, a2, p1); auto p2 = p.add_instruction(pass_op{}, a2, p1);
p.add_instruction(pass_op{}, a3, p2); p.add_instruction(pass_op{}, a3, p2);
...@@ -227,12 +227,12 @@ TEST_CASE(test14) ...@@ -227,12 +227,12 @@ TEST_CASE(test14)
TEST_CASE(test15) TEST_CASE(test15)
{ {
migraph::program p; migraphx::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a1 = add_alloc(p, {migraphx::shape::float_type, {8}});
auto p1 = p.add_instruction(pass_op{}, a1); auto p1 = p.add_instruction(pass_op{}, a1);
auto a2 = add_alloc(p, {migraph::shape::float_type, {40}}); auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto p2 = p.add_instruction(pass_op{}, a2); auto p2 = p.add_instruction(pass_op{}, a2);
auto a3 = add_alloc(p, {migraph::shape::float_type, {40}}); auto a3 = add_alloc(p, {migraphx::shape::float_type, {40}});
p.add_instruction(pass_op{}, a3, p1, p2); p.add_instruction(pass_op{}, a3, p1, p2);
p.compile(memory_coloring_target{}); p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 352); CHECK(p.get_parameter_shape("scratch").bytes() == 352);
...@@ -241,12 +241,12 @@ TEST_CASE(test15) ...@@ -241,12 +241,12 @@ TEST_CASE(test15)
TEST_CASE(test16) TEST_CASE(test16)
{ {
migraph::program p; migraphx::program p;
auto a1 = p.add_literal(migraph::generate_literal({migraph::shape::float_type, {8}})); auto a1 = p.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {8}}));
auto p1 = p.add_instruction(pass_op{}, a1); auto p1 = p.add_instruction(pass_op{}, a1);
auto a2 = p.add_literal(migraph::generate_literal({migraph::shape::float_type, {40}})); auto a2 = p.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {40}}));
auto p2 = p.add_instruction(pass_op{}, a2); auto p2 = p.add_instruction(pass_op{}, a2);
auto a3 = add_alloc(p, {migraph::shape::float_type, {40}}); auto a3 = add_alloc(p, {migraphx::shape::float_type, {40}});
p.add_instruction(pass_op{}, a3, p1, p2); p.add_instruction(pass_op{}, a3, p1, p2);
p.compile(memory_coloring_target{}); p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 160); CHECK(p.get_parameter_shape("scratch").bytes() == 160);
...@@ -255,11 +255,11 @@ TEST_CASE(test16) ...@@ -255,11 +255,11 @@ TEST_CASE(test16)
TEST_CASE(test17) TEST_CASE(test17)
{ {
migraph::program p; migraphx::program p;
auto a3 = add_alloc(p, {migraph::shape::float_type, {40}}); auto a3 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto a1 = p.add_literal(migraph::generate_literal({migraph::shape::float_type, {8}})); auto a1 = p.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {8}}));
auto p1 = p.add_instruction(pass_op{}, a1); auto p1 = p.add_instruction(pass_op{}, a1);
auto a2 = p.add_literal(migraph::generate_literal({migraph::shape::float_type, {40}})); auto a2 = p.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {40}}));
auto p2 = p.add_instruction(pass_op{}, a2); auto p2 = p.add_instruction(pass_op{}, a2);
p.add_instruction(pass_op{}, a3, p1, p2); p.add_instruction(pass_op{}, a3, p1, p2);
p.compile(memory_coloring_target{}); p.compile(memory_coloring_target{});
...@@ -269,12 +269,12 @@ TEST_CASE(test17) ...@@ -269,12 +269,12 @@ TEST_CASE(test17)
TEST_CASE(test18) TEST_CASE(test18)
{ {
migraph::program p; migraphx::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a1 = add_alloc(p, {migraphx::shape::float_type, {8}});
auto p1 = p.add_instruction(pass_op{}, a1); auto p1 = p.add_instruction(pass_op{}, a1);
auto p2 = p.add_instruction(pass_op{}, a1, p1); auto p2 = p.add_instruction(pass_op{}, a1, p1);
auto p3 = p.add_instruction(pass_op{}, p2, p1); auto p3 = p.add_instruction(pass_op{}, p2, p1);
auto a2 = add_alloc(p, {migraph::shape::float_type, {40}}); auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}});
p.add_instruction(pass_op{}, a2, p1, p2, p3); p.add_instruction(pass_op{}, a2, p1, p2, p3);
p.compile(memory_coloring_target{}); p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 192); CHECK(p.get_parameter_shape("scratch").bytes() == 192);
...@@ -283,12 +283,12 @@ TEST_CASE(test18) ...@@ -283,12 +283,12 @@ TEST_CASE(test18)
TEST_CASE(test19) TEST_CASE(test19)
{ {
migraph::program p; migraphx::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a1 = add_alloc(p, {migraphx::shape::float_type, {8}});
auto p1 = p.add_instruction(pass_op{}, a1); auto p1 = p.add_instruction(pass_op{}, a1);
auto a2 = add_alloc(p, {migraph::shape::float_type, {40}}); auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto p2 = p.add_instruction(pass_op{}, a2, p1); auto p2 = p.add_instruction(pass_op{}, a2, p1);
auto a3 = add_alloc(p, {migraph::shape::float_type, {40}}); auto a3 = add_alloc(p, {migraphx::shape::float_type, {40}});
p.add_instruction(pass_op{}, a3, p2, p1); p.add_instruction(pass_op{}, a3, p2, p1);
p.compile(memory_coloring_target{}); p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 352); CHECK(p.get_parameter_shape("scratch").bytes() == 352);
...@@ -297,12 +297,12 @@ TEST_CASE(test19) ...@@ -297,12 +297,12 @@ TEST_CASE(test19)
TEST_CASE(test20) TEST_CASE(test20)
{ {
migraph::program p; migraphx::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {32}}); auto a1 = add_alloc(p, {migraphx::shape::float_type, {32}});
auto a2 = add_alloc(p, {migraph::shape::float_type, {32}}); auto a2 = add_alloc(p, {migraphx::shape::float_type, {32}});
auto a3 = add_alloc(p, {migraph::shape::float_type, {32}}); auto a3 = add_alloc(p, {migraphx::shape::float_type, {32}});
auto p1 = p.add_instruction(pass_op{}, a1, a2, a3); auto p1 = p.add_instruction(pass_op{}, a1, a2, a3);
auto a4 = add_alloc(p, {migraph::shape::float_type, {32}}); auto a4 = add_alloc(p, {migraphx::shape::float_type, {32}});
p.add_instruction(pass_op{}, a4, p1); p.add_instruction(pass_op{}, a4, p1);
p.compile(memory_coloring_target{}); p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 384); CHECK(p.get_parameter_shape("scratch").bytes() == 384);
...@@ -311,12 +311,12 @@ TEST_CASE(test20) ...@@ -311,12 +311,12 @@ TEST_CASE(test20)
TEST_CASE(test21) TEST_CASE(test21)
{ {
migraph::program p; migraphx::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {32}}); auto a1 = add_alloc(p, {migraphx::shape::float_type, {32}});
auto a2 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a2 = add_alloc(p, {migraphx::shape::float_type, {8}});
auto a3 = add_alloc(p, {migraph::shape::float_type, {32}}); auto a3 = add_alloc(p, {migraphx::shape::float_type, {32}});
auto p1 = p.add_instruction(pass_op{}, a1, a2, a3); auto p1 = p.add_instruction(pass_op{}, a1, a2, a3);
auto a4 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a4 = add_alloc(p, {migraphx::shape::float_type, {8}});
p.add_instruction(pass_op{}, a4, p1); p.add_instruction(pass_op{}, a4, p1);
p.compile(memory_coloring_target{}); p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 288); CHECK(p.get_parameter_shape("scratch").bytes() == 288);
...@@ -325,12 +325,12 @@ TEST_CASE(test21) ...@@ -325,12 +325,12 @@ TEST_CASE(test21)
TEST_CASE(test22) TEST_CASE(test22)
{ {
migraph::program p; migraphx::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {32}}); auto a1 = add_alloc(p, {migraphx::shape::float_type, {32}});
auto a2 = add_alloc(p, {migraph::shape::float_type, {32}}); auto a2 = add_alloc(p, {migraphx::shape::float_type, {32}});
auto a3 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a3 = add_alloc(p, {migraphx::shape::float_type, {8}});
auto p1 = p.add_instruction(pass_op{}, a1, a2, a3); auto p1 = p.add_instruction(pass_op{}, a1, a2, a3);
auto a4 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a4 = add_alloc(p, {migraphx::shape::float_type, {8}});
p.add_instruction(pass_op{}, a4, p1); p.add_instruction(pass_op{}, a4, p1);
p.compile(memory_coloring_target{}); p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 288); CHECK(p.get_parameter_shape("scratch").bytes() == 288);
...@@ -339,12 +339,12 @@ TEST_CASE(test22) ...@@ -339,12 +339,12 @@ TEST_CASE(test22)
TEST_CASE(test23) TEST_CASE(test23)
{ {
migraph::program p; migraphx::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a1 = add_alloc(p, {migraphx::shape::float_type, {8}});
auto a2 = add_alloc(p, {migraph::shape::float_type, {32}}); auto a2 = add_alloc(p, {migraphx::shape::float_type, {32}});
auto a3 = add_alloc(p, {migraph::shape::float_type, {32}}); auto a3 = add_alloc(p, {migraphx::shape::float_type, {32}});
auto p1 = p.add_instruction(pass_op{}, a1, a2, a3); auto p1 = p.add_instruction(pass_op{}, a1, a2, a3);
auto a4 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a4 = add_alloc(p, {migraphx::shape::float_type, {8}});
p.add_instruction(pass_op{}, a4, p1); p.add_instruction(pass_op{}, a4, p1);
p.compile(memory_coloring_target{}); p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 288); CHECK(p.get_parameter_shape("scratch").bytes() == 288);
...@@ -353,12 +353,12 @@ TEST_CASE(test23) ...@@ -353,12 +353,12 @@ TEST_CASE(test23)
TEST_CASE(test24) TEST_CASE(test24)
{ {
migraph::program p; migraphx::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {32}}); auto a1 = add_alloc(p, {migraphx::shape::float_type, {32}});
auto a2 = add_alloc(p, {migraph::shape::float_type, {32}}); auto a2 = add_alloc(p, {migraphx::shape::float_type, {32}});
auto a3 = add_alloc(p, {migraph::shape::float_type, {32}}); auto a3 = add_alloc(p, {migraphx::shape::float_type, {32}});
auto p1 = p.add_instruction(pass_op{}, a1, a2, a3); auto p1 = p.add_instruction(pass_op{}, a1, a2, a3);
auto a4 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a4 = add_alloc(p, {migraphx::shape::float_type, {8}});
p.add_instruction(pass_op{}, a4, p1); p.add_instruction(pass_op{}, a4, p1);
p.compile(memory_coloring_target{}); p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 384); CHECK(p.get_parameter_shape("scratch").bytes() == 384);
...@@ -367,12 +367,12 @@ TEST_CASE(test24) ...@@ -367,12 +367,12 @@ TEST_CASE(test24)
TEST_CASE(test25) TEST_CASE(test25)
{ {
migraph::program p; migraphx::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a1 = add_alloc(p, {migraphx::shape::float_type, {8}});
p.add_instruction(nop{}); p.add_instruction(nop{});
auto p1 = p.add_instruction(pass_op{}, a1); auto p1 = p.add_instruction(pass_op{}, a1);
p.add_instruction(nop{}); p.add_instruction(nop{});
auto a2 = add_alloc(p, {migraph::shape::float_type, {40}}); auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}});
p.add_instruction(pass_op{}, a2, p1); p.add_instruction(pass_op{}, a2, p1);
p.compile(memory_coloring_target{}); p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 192); CHECK(p.get_parameter_shape("scratch").bytes() == 192);
...@@ -381,12 +381,12 @@ TEST_CASE(test25) ...@@ -381,12 +381,12 @@ TEST_CASE(test25)
TEST_CASE(test26) TEST_CASE(test26)
{ {
migraph::program p; migraphx::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a1 = add_alloc(p, {migraphx::shape::float_type, {8}});
p.add_instruction(nop{}, a1); p.add_instruction(nop{}, a1);
auto p1 = p.add_instruction(pass_op{}, a1); auto p1 = p.add_instruction(pass_op{}, a1);
p.add_instruction(nop{}, a1, p1); p.add_instruction(nop{}, a1, p1);
auto a2 = add_alloc(p, {migraph::shape::float_type, {40}}); auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}});
p.add_instruction(pass_op{}, a2, p1); p.add_instruction(pass_op{}, a2, p1);
p.compile(memory_coloring_target{}); p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 192); CHECK(p.get_parameter_shape("scratch").bytes() == 192);
...@@ -395,10 +395,10 @@ TEST_CASE(test26) ...@@ -395,10 +395,10 @@ TEST_CASE(test26)
TEST_CASE(test27) TEST_CASE(test27)
{ {
migraph::program p; migraphx::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a1 = add_alloc(p, {migraphx::shape::float_type, {8}});
auto p1 = p.add_instruction(pass_op{}, a1); auto p1 = p.add_instruction(pass_op{}, a1);
auto a2 = add_alloc(p, {migraph::shape::float_type, {40}}); auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}});
p.add_instruction(nop{}, a2, p1); p.add_instruction(nop{}, a2, p1);
p.compile(memory_coloring_target{}); p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 192); CHECK(p.get_parameter_shape("scratch").bytes() == 192);
...@@ -407,11 +407,11 @@ TEST_CASE(test27) ...@@ -407,11 +407,11 @@ TEST_CASE(test27)
TEST_CASE(test28) TEST_CASE(test28)
{ {
migraph::program p; migraphx::program p;
auto output = p.add_parameter("output", {migraph::shape::float_type, {8}}); auto output = p.add_parameter("output", {migraphx::shape::float_type, {8}});
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a1 = add_alloc(p, {migraphx::shape::float_type, {8}});
auto p1 = p.add_instruction(pass_op{}, a1); auto p1 = p.add_instruction(pass_op{}, a1);
auto a2 = add_alloc(p, {migraph::shape::float_type, {40}}); auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto p2 = p.add_instruction(pass_op{}, a2, p1); auto p2 = p.add_instruction(pass_op{}, a2, p1);
p.add_instruction(pass_op{}, p2, output); p.add_instruction(pass_op{}, p2, output);
p.compile(memory_coloring_target{}); p.compile(memory_coloring_target{});
...@@ -421,11 +421,11 @@ TEST_CASE(test28) ...@@ -421,11 +421,11 @@ TEST_CASE(test28)
TEST_CASE(test29) TEST_CASE(test29)
{ {
migraph::program p; migraphx::program p;
auto output = p.add_parameter("output", {migraph::shape::float_type, {8}}); auto output = p.add_parameter("output", {migraphx::shape::float_type, {8}});
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a1 = add_alloc(p, {migraphx::shape::float_type, {8}});
auto p1 = p.add_instruction(pass_op{}, a1); auto p1 = p.add_instruction(pass_op{}, a1);
auto a2 = add_alloc(p, {migraph::shape::float_type, {40}}); auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto p2 = p.add_instruction(pass_op{}, a2, p1); auto p2 = p.add_instruction(pass_op{}, a2, p1);
p.move_instruction(output, p2); p.move_instruction(output, p2);
p.add_instruction(pass_op{}, p2, output); p.add_instruction(pass_op{}, p2, output);
...@@ -436,11 +436,11 @@ TEST_CASE(test29) ...@@ -436,11 +436,11 @@ TEST_CASE(test29)
TEST_CASE(test30) TEST_CASE(test30)
{ {
migraph::program p; migraphx::program p;
auto output = p.add_parameter("x", {migraph::shape::float_type, {8}}); auto output = p.add_parameter("x", {migraphx::shape::float_type, {8}});
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a1 = add_alloc(p, {migraphx::shape::float_type, {8}});
auto p1 = p.add_instruction(pass_op{}, a1); auto p1 = p.add_instruction(pass_op{}, a1);
auto a2 = add_alloc(p, {migraph::shape::float_type, {40}}); auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto p2 = p.add_instruction(pass_op{}, a2, p1); auto p2 = p.add_instruction(pass_op{}, a2, p1);
p.move_instruction(output, p2); p.move_instruction(output, p2);
p.add_instruction(pass_op{}, p2, output); p.add_instruction(pass_op{}, p2, output);
...@@ -451,11 +451,11 @@ TEST_CASE(test30) ...@@ -451,11 +451,11 @@ TEST_CASE(test30)
TEST_CASE(test31) TEST_CASE(test31)
{ {
migraph::program p; migraphx::program p;
auto output = p.add_parameter("output", {migraph::shape::float_type, {8}}); auto output = p.add_parameter("output", {migraphx::shape::float_type, {8}});
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a1 = add_alloc(p, {migraphx::shape::float_type, {8}});
auto p1 = p.add_instruction(pass_op{}, a1); auto p1 = p.add_instruction(pass_op{}, a1);
auto a2 = add_alloc(p, {migraph::shape::float_type, {40}}); auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}});
p.move_instruction(output, a2); p.move_instruction(output, a2);
p.add_instruction(pass_op{}, a2, p1); p.add_instruction(pass_op{}, a2, p1);
p.compile(memory_coloring_target{}); p.compile(memory_coloring_target{});
...@@ -465,12 +465,12 @@ TEST_CASE(test31) ...@@ -465,12 +465,12 @@ TEST_CASE(test31)
TEST_CASE(test32) TEST_CASE(test32)
{ {
migraph::program p; migraphx::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a1 = add_alloc(p, {migraphx::shape::float_type, {8}});
auto a2 = add_alloc(p, {migraph::shape::float_type, {40}}); auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto a3 = add_alloc(p, {migraph::shape::float_type, {40}}); auto a3 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto p1 = p.add_instruction(pass_op{}, a2, a1, a3); auto p1 = p.add_instruction(pass_op{}, a2, a1, a3);
auto a5 = add_alloc(p, {migraph::shape::float_type, {40}}); auto a5 = add_alloc(p, {migraphx::shape::float_type, {40}});
p.add_instruction(pass_op{}, a5, p1); p.add_instruction(pass_op{}, a5, p1);
p.compile(memory_coloring_target{}); p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 352); CHECK(p.get_parameter_shape("scratch").bytes() == 352);
...@@ -479,12 +479,12 @@ TEST_CASE(test32) ...@@ -479,12 +479,12 @@ TEST_CASE(test32)
TEST_CASE(test33) TEST_CASE(test33)
{ {
migraph::program p; migraphx::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a1 = add_alloc(p, {migraphx::shape::float_type, {8}});
auto a2 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a2 = add_alloc(p, {migraphx::shape::float_type, {8}});
auto a3 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a3 = add_alloc(p, {migraphx::shape::float_type, {8}});
auto p1 = p.add_instruction(pass_op{}, a2, a1, a3); auto p1 = p.add_instruction(pass_op{}, a2, a1, a3);
auto a5 = add_alloc(p, {migraph::shape::float_type, {40}}); auto a5 = add_alloc(p, {migraphx::shape::float_type, {40}});
p.add_instruction(pass_op{}, a5, p1); p.add_instruction(pass_op{}, a5, p1);
p.compile(memory_coloring_target{}); p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 224); CHECK(p.get_parameter_shape("scratch").bytes() == 224);
...@@ -493,12 +493,12 @@ TEST_CASE(test33) ...@@ -493,12 +493,12 @@ TEST_CASE(test33)
TEST_CASE(test34) TEST_CASE(test34)
{ {
migraph::program p; migraphx::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {40}}); auto a1 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto a2 = add_alloc(p, {migraph::shape::float_type, {40}}); auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto a3 = add_alloc(p, {migraph::shape::float_type, {40}}); auto a3 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto p1 = p.add_instruction(pass_op{}, a2, a1, a3); auto p1 = p.add_instruction(pass_op{}, a2, a1, a3);
auto a5 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a5 = add_alloc(p, {migraphx::shape::float_type, {8}});
p.add_instruction(pass_op{}, a5, p1); p.add_instruction(pass_op{}, a5, p1);
p.compile(memory_coloring_target{}); p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 480); CHECK(p.get_parameter_shape("scratch").bytes() == 480);
...@@ -507,12 +507,12 @@ TEST_CASE(test34) ...@@ -507,12 +507,12 @@ TEST_CASE(test34)
TEST_CASE(test35) TEST_CASE(test35)
{ {
migraph::program p; migraphx::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {40}}); auto a1 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto a2 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a2 = add_alloc(p, {migraphx::shape::float_type, {8}});
auto a3 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a3 = add_alloc(p, {migraphx::shape::float_type, {8}});
auto p1 = p.add_instruction(pass_op{}, a2, a1, a3); auto p1 = p.add_instruction(pass_op{}, a2, a1, a3);
auto a5 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a5 = add_alloc(p, {migraphx::shape::float_type, {8}});
p.add_instruction(pass_op{}, a5, p1); p.add_instruction(pass_op{}, a5, p1);
p.compile(memory_coloring_target{}); p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 224); CHECK(p.get_parameter_shape("scratch").bytes() == 224);
...@@ -521,14 +521,14 @@ TEST_CASE(test35) ...@@ -521,14 +521,14 @@ TEST_CASE(test35)
TEST_CASE(test36) TEST_CASE(test36)
{ {
migraph::program p; migraphx::program p;
auto output = p.add_parameter("output", {migraph::shape::float_type, {20}}); auto output = p.add_parameter("output", {migraphx::shape::float_type, {20}});
auto a1 = add_alloc(p, {migraph::shape::float_type, {0}}); auto a1 = add_alloc(p, {migraphx::shape::float_type, {0}});
auto a2 = add_alloc(p, {migraph::shape::float_type, {40}}); auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto p1 = p.add_instruction(pass_op{}, a2, a1); auto p1 = p.add_instruction(pass_op{}, a2, a1);
auto a3 = add_alloc(p, {migraph::shape::float_type, {40}}); auto a3 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto p2 = p.add_instruction(pass_op{}, a3, p1); auto p2 = p.add_instruction(pass_op{}, a3, p1);
auto a4 = add_alloc(p, {migraph::shape::float_type, {40}}); auto a4 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto p3 = p.add_instruction(pass_op{}, a4, p2); auto p3 = p.add_instruction(pass_op{}, a4, p2);
p.add_instruction(pass_op{}, output, p3); p.add_instruction(pass_op{}, output, p3);
p.compile(memory_coloring_target{}); p.compile(memory_coloring_target{});
...@@ -538,14 +538,14 @@ TEST_CASE(test36) ...@@ -538,14 +538,14 @@ TEST_CASE(test36)
TEST_CASE(test37) TEST_CASE(test37)
{ {
migraph::program p; migraphx::program p;
auto output = p.add_parameter("output", {migraph::shape::float_type, {20}}); auto output = p.add_parameter("output", {migraphx::shape::float_type, {20}});
auto a1 = add_alloc(p, {migraph::shape::float_type, {4}}); auto a1 = add_alloc(p, {migraphx::shape::float_type, {4}});
auto a2 = add_alloc(p, {migraph::shape::float_type, {40}}); auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto p1 = p.add_instruction(pass_op{}, a2, a1); auto p1 = p.add_instruction(pass_op{}, a2, a1);
auto a3 = add_alloc(p, {migraph::shape::float_type, {40}}); auto a3 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto p2 = p.add_instruction(pass_op{}, a3, p1); auto p2 = p.add_instruction(pass_op{}, a3, p1);
auto a4 = add_alloc(p, {migraph::shape::float_type, {40}}); auto a4 = add_alloc(p, {migraphx::shape::float_type, {40}});
auto p3 = p.add_instruction(pass_op{}, a4, p2); auto p3 = p.add_instruction(pass_op{}, a4, p2);
p.add_instruction(pass_op{}, output, p3); p.add_instruction(pass_op{}, output, p3);
p.compile(memory_coloring_target{}); p.compile(memory_coloring_target{});
...@@ -555,42 +555,42 @@ TEST_CASE(test37) ...@@ -555,42 +555,42 @@ TEST_CASE(test37)
TEST_CASE(test38) TEST_CASE(test38)
{ {
migraph::program p; migraphx::program p;
auto output = p.add_parameter("output", {migraph::shape::float_type, {1, 64, 56, 56}}); auto output = p.add_parameter("output", {migraphx::shape::float_type, {1, 64, 56, 56}});
auto p29 = add_alloc(p, {migraph::shape::float_type, {0}}); auto p29 = add_alloc(p, {migraphx::shape::float_type, {0}});
auto p30 = add_alloc(p, {migraph::shape::float_type, {1, 64, 112, 112}}); auto p30 = add_alloc(p, {migraphx::shape::float_type, {1, 64, 112, 112}});
auto p31 = p.add_instruction(pass_op{}, p30, p29); auto p31 = p.add_instruction(pass_op{}, p30, p29);
auto p32 = add_alloc(p, {migraph::shape::float_type, {1, 64, 112, 112}}); auto p32 = add_alloc(p, {migraphx::shape::float_type, {1, 64, 112, 112}});
auto p37 = p.add_instruction(pass_op{}, p32, p31); auto p37 = p.add_instruction(pass_op{}, p32, p31);
auto p38 = add_alloc(p, {migraph::shape::float_type, {1, 64, 112, 112}}); auto p38 = add_alloc(p, {migraphx::shape::float_type, {1, 64, 112, 112}});
auto p39 = p.add_instruction(pass_op{}, p38, p37); auto p39 = p.add_instruction(pass_op{}, p38, p37);
auto p40 = add_alloc(p, {migraph::shape::float_type, {1, 64, 56, 56}}); auto p40 = add_alloc(p, {migraphx::shape::float_type, {1, 64, 56, 56}});
auto p41 = p.add_instruction(pass_op{}, p40, p39); auto p41 = p.add_instruction(pass_op{}, p40, p39);
auto p42 = add_alloc(p, {migraph::shape::float_type, {0}}); auto p42 = add_alloc(p, {migraphx::shape::float_type, {0}});
auto p43 = add_alloc(p, {migraph::shape::float_type, {1, 64, 56, 56}}); auto p43 = add_alloc(p, {migraphx::shape::float_type, {1, 64, 56, 56}});
auto p44 = p.add_instruction(pass_op{}, p43, p41, p42); auto p44 = p.add_instruction(pass_op{}, p43, p41, p42);
auto p45 = add_alloc(p, {migraph::shape::float_type, {1, 64, 56, 56}}); auto p45 = add_alloc(p, {migraphx::shape::float_type, {1, 64, 56, 56}});
auto p50 = p.add_instruction(pass_op{}, p45, p44); auto p50 = p.add_instruction(pass_op{}, p45, p44);
auto p51 = add_alloc(p, {migraph::shape::float_type, {1, 64, 56, 56}}); auto p51 = add_alloc(p, {migraphx::shape::float_type, {1, 64, 56, 56}});
auto p52 = p.add_instruction(pass_op{}, p51, p50); auto p52 = p.add_instruction(pass_op{}, p51, p50);
auto p53 = add_alloc(p, {migraph::shape::float_type, {0}}); auto p53 = add_alloc(p, {migraphx::shape::float_type, {0}});
auto p54 = add_alloc(p, {migraph::shape::float_type, {1, 64, 56, 56}}); auto p54 = add_alloc(p, {migraphx::shape::float_type, {1, 64, 56, 56}});
auto p55 = p.add_instruction(pass_op{}, p54, p52, p53); auto p55 = p.add_instruction(pass_op{}, p54, p52, p53);
auto p56 = add_alloc(p, {migraph::shape::float_type, {1, 64, 56, 56}}); auto p56 = add_alloc(p, {migraphx::shape::float_type, {1, 64, 56, 56}});
auto p61 = p.add_instruction(pass_op{}, p56, p55); auto p61 = p.add_instruction(pass_op{}, p56, p55);
auto p62 = add_alloc(p, {migraph::shape::float_type, {1, 64, 56, 56}}); auto p62 = add_alloc(p, {migraphx::shape::float_type, {1, 64, 56, 56}});
auto p63 = p.add_instruction(pass_op{}, p62, p61, p41); auto p63 = p.add_instruction(pass_op{}, p62, p61, p41);
auto p64 = add_alloc(p, {migraph::shape::float_type, {0}}); auto p64 = add_alloc(p, {migraphx::shape::float_type, {0}});
auto p65 = add_alloc(p, {migraph::shape::float_type, {1, 64, 56, 56}}); auto p65 = add_alloc(p, {migraphx::shape::float_type, {1, 64, 56, 56}});
auto p66 = p.add_instruction(pass_op{}, p65, p63, p64); auto p66 = p.add_instruction(pass_op{}, p65, p63, p64);
auto p67 = add_alloc(p, {migraph::shape::float_type, {1, 64, 56, 56}}); auto p67 = add_alloc(p, {migraphx::shape::float_type, {1, 64, 56, 56}});
auto p72 = p.add_instruction(pass_op{}, p67, p66); auto p72 = p.add_instruction(pass_op{}, p67, p66);
auto p73 = add_alloc(p, {migraph::shape::float_type, {1, 64, 56, 56}}); auto p73 = add_alloc(p, {migraphx::shape::float_type, {1, 64, 56, 56}});
auto p74 = p.add_instruction(pass_op{}, p73, p72); auto p74 = p.add_instruction(pass_op{}, p73, p72);
auto p75 = add_alloc(p, {migraph::shape::float_type, {0}}); auto p75 = add_alloc(p, {migraphx::shape::float_type, {0}});
auto p76 = add_alloc(p, {migraph::shape::float_type, {1, 64, 56, 56}}); auto p76 = add_alloc(p, {migraphx::shape::float_type, {1, 64, 56, 56}});
auto p77 = p.add_instruction(pass_op{}, p76, p74, p75); auto p77 = p.add_instruction(pass_op{}, p76, p74, p75);
auto p78 = add_alloc(p, {migraph::shape::float_type, {1, 64, 56, 56}}); auto p78 = add_alloc(p, {migraphx::shape::float_type, {1, 64, 56, 56}});
auto p83 = p.add_instruction(pass_op{}, p78, p77); auto p83 = p.add_instruction(pass_op{}, p78, p77);
p.add_instruction(pass_op{}, output, p83, p63); p.add_instruction(pass_op{}, output, p83, p63);
p.compile(memory_coloring_target{}); p.compile(memory_coloring_target{});
...@@ -600,8 +600,8 @@ TEST_CASE(test38) ...@@ -600,8 +600,8 @@ TEST_CASE(test38)
TEST_CASE(literal_test) TEST_CASE(literal_test)
{ {
migraph::program p; migraphx::program p;
auto lit = generate_literal(migraph::shape{migraph::shape::float_type, {4, 3, 3, 3}}); auto lit = generate_literal(migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
p.add_literal(lit); p.add_literal(lit);
p.compile(memory_coloring_target{}); p.compile(memory_coloring_target{});
auto result = p.eval({}); auto result = p.eval({});
......
#include <iostream> #include <iostream>
#include <vector> #include <vector>
#include <migraph/literal.hpp> #include <migraphx/literal.hpp>
#include <migraph/operators.hpp> #include <migraphx/operators.hpp>
#include <migraph/program.hpp> #include <migraphx/program.hpp>
#include <migraph/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraph/onnx.hpp> #include <migraphx/onnx.hpp>
#include "test.hpp" #include "test.hpp"
void pytorch_conv_bias_test() void pytorch_conv_bias_test()
{ {
migraph::program p; migraphx::program p;
auto l0 = p.add_parameter("0", {migraph::shape::float_type, {1, 3, 32, 32}}); auto l0 = p.add_parameter("0", {migraphx::shape::float_type, {1, 3, 32, 32}});
auto l1 = p.add_parameter("1", {migraph::shape::float_type, {1, 3, 5, 5}}); auto l1 = p.add_parameter("1", {migraphx::shape::float_type, {1, 3, 5, 5}});
auto l2 = p.add_parameter("2", {migraph::shape::float_type, {1}}); auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {1}});
uint64_t axis = 1; uint64_t axis = 1;
auto l3 = p.add_instruction(migraph::op::convolution{}, l0, l1); auto l3 = p.add_instruction(migraphx::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraph::op::broadcast{axis, l3->get_shape()}, l2); auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape()}, l2);
p.add_instruction(migraph::op::add{}, l3, l4); p.add_instruction(migraphx::op::add{}, l3, l4);
auto prog = migraph::parse_onnx("conv.onnx"); auto prog = migraphx::parse_onnx("conv.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
void pytorch_conv_relu_maxpool() void pytorch_conv_relu_maxpool()
{ {
migraph::program p; migraphx::program p;
auto l0 = p.add_parameter("0", {migraph::shape::float_type, {1, 3, 32, 32}}); auto l0 = p.add_parameter("0", {migraphx::shape::float_type, {1, 3, 32, 32}});
auto l1 = p.add_parameter("1", {migraph::shape::float_type, {1, 3, 5, 5}}); auto l1 = p.add_parameter("1", {migraphx::shape::float_type, {1, 3, 5, 5}});
auto l2 = p.add_parameter("2", {migraph::shape::float_type, {1}}); auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {1}});
uint64_t axis = 1; uint64_t axis = 1;
auto l3 = p.add_instruction(migraph::op::convolution{}, l0, l1); auto l3 = p.add_instruction(migraphx::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraph::op::broadcast{axis, l3->get_shape()}, l2); auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape()}, l2);
auto l5 = p.add_instruction(migraph::op::add{}, l3, l4); auto l5 = p.add_instruction(migraphx::op::add{}, l3, l4);
auto l6 = p.add_instruction(migraph::op::relu{}, l5); auto l6 = p.add_instruction(migraphx::op::relu{}, l5);
p.add_instruction(migraph::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l6); p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l6);
auto prog = migraph::parse_onnx("conv_relu_maxpool.onnx"); auto prog = migraphx::parse_onnx("conv_relu_maxpool.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
void pytorch_conv_bn_relu_maxpool() void pytorch_conv_bn_relu_maxpool()
{ {
migraph::program p; migraphx::program p;
auto l0 = p.add_parameter("0", {migraph::shape::float_type, {1, 3, 32, 32}}); auto l0 = p.add_parameter("0", {migraphx::shape::float_type, {1, 3, 32, 32}});
auto l1 = p.add_parameter("1", {migraph::shape::float_type, {1, 3, 5, 5}}); auto l1 = p.add_parameter("1", {migraphx::shape::float_type, {1, 3, 5, 5}});
auto l2 = p.add_parameter("2", {migraph::shape::float_type, {1}}); auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {1}});
auto p3 = p.add_parameter("3", {migraph::shape::float_type, {1}}); auto p3 = p.add_parameter("3", {migraphx::shape::float_type, {1}});
auto p4 = p.add_parameter("4", {migraph::shape::float_type, {1}}); auto p4 = p.add_parameter("4", {migraphx::shape::float_type, {1}});
auto p5 = p.add_parameter("5", {migraph::shape::float_type, {1}}); auto p5 = p.add_parameter("5", {migraphx::shape::float_type, {1}});
auto p6 = p.add_parameter("6", {migraph::shape::float_type, {1}}); auto p6 = p.add_parameter("6", {migraphx::shape::float_type, {1}});
uint64_t axis = 1; uint64_t axis = 1;
auto l3 = p.add_instruction(migraph::op::convolution{}, l0, l1); auto l3 = p.add_instruction(migraphx::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraph::op::broadcast{axis, l3->get_shape()}, l2); auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape()}, l2);
auto l5 = p.add_instruction(migraph::op::add{}, l3, l4); auto l5 = p.add_instruction(migraphx::op::add{}, l3, l4);
auto l6 = p.add_instruction(migraph::op::batch_norm_inference{1.0e-5f}, l5, p3, p4, p5, p6); auto l6 = p.add_instruction(migraphx::op::batch_norm_inference{1.0e-5f}, l5, p3, p4, p5, p6);
auto l7 = p.add_instruction(migraph::op::relu{}, l6); auto l7 = p.add_instruction(migraphx::op::relu{}, l6);
p.add_instruction(migraph::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l7); p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l7);
auto prog = migraph::parse_onnx("conv_bn_relu_maxpool.onnx"); auto prog = migraphx::parse_onnx("conv_bn_relu_maxpool.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
void pytorch_conv_relu_maxpool_x2() void pytorch_conv_relu_maxpool_x2()
{ {
migraph::program p; migraphx::program p;
auto l0 = p.add_parameter("0", {migraph::shape::float_type, {1, 3, 32, 32}}); auto l0 = p.add_parameter("0", {migraphx::shape::float_type, {1, 3, 32, 32}});
auto l1 = p.add_parameter("1", {migraph::shape::float_type, {5, 3, 5, 5}}); auto l1 = p.add_parameter("1", {migraphx::shape::float_type, {5, 3, 5, 5}});
auto l2 = p.add_parameter("2", {migraph::shape::float_type, {5}}); auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {5}});
uint64_t axis = 1; uint64_t axis = 1;
auto l3 = p.add_instruction(migraph::op::convolution{}, l0, l1); auto l3 = p.add_instruction(migraphx::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraph::op::broadcast{axis, l3->get_shape()}, l2); auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape()}, l2);
auto l5 = p.add_instruction(migraph::op::add{}, l3, l4); auto l5 = p.add_instruction(migraphx::op::add{}, l3, l4);
auto l6 = p.add_instruction(migraph::op::relu{}, l5); auto l6 = p.add_instruction(migraphx::op::relu{}, l5);
auto l7 = p.add_instruction(migraph::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l6); auto l7 = p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l6);
auto l8 = p.add_parameter("3", {migraph::shape::float_type, {1, 5, 5, 5}}); auto l8 = p.add_parameter("3", {migraphx::shape::float_type, {1, 5, 5, 5}});
auto l9 = p.add_parameter("4", {migraph::shape::float_type, {1}}); auto l9 = p.add_parameter("4", {migraphx::shape::float_type, {1}});
auto l10 = p.add_instruction(migraph::op::convolution{}, l7, l8); auto l10 = p.add_instruction(migraphx::op::convolution{}, l7, l8);
auto l11 = p.add_instruction(migraph::op::broadcast{axis, l10->get_shape()}, l9); auto l11 = p.add_instruction(migraphx::op::broadcast{axis, l10->get_shape()}, l9);
auto l12 = p.add_instruction(migraph::op::add{}, l10, l11); auto l12 = p.add_instruction(migraphx::op::add{}, l10, l11);
auto l13 = p.add_instruction(migraph::op::relu{}, l12); auto l13 = p.add_instruction(migraphx::op::relu{}, l12);
p.add_instruction(migraph::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l13); p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l13);
auto prog = migraph::parse_onnx("conv_relu_maxpoolX2.onnx"); auto prog = migraphx::parse_onnx("conv_relu_maxpoolX2.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
void leaky_relu_test() void leaky_relu_test()
{ {
migraph::program p; migraphx::program p;
float alpha = 0.01f; float alpha = 0.01f;
auto l0 = p.add_parameter("0", {migraph::shape::float_type, {3}}); auto l0 = p.add_parameter("0", {migraphx::shape::float_type, {3}});
p.add_instruction(migraph::op::leaky_relu{alpha}, l0); p.add_instruction(migraphx::op::leaky_relu{alpha}, l0);
auto prog = migraph::parse_onnx("leaky_relu.onnx"); auto prog = migraphx::parse_onnx("leaky_relu.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
void imagescaler_test() void imagescaler_test()
{ {
migraph::program p; migraphx::program p;
migraph::shape s{migraph::shape::float_type, {1, 3, 16, 16}}; migraphx::shape s{migraphx::shape::float_type, {1, 3, 16, 16}};
auto l0 = p.add_parameter("0", s); auto l0 = p.add_parameter("0", s);
auto scale_val = p.add_literal(0.5f); auto scale_val = p.add_literal(0.5f);
auto bias_vals = p.add_literal( auto bias_vals = p.add_literal(
migraph::literal{migraph::shape{migraph::shape::float_type, {3}}, {0.01, 0.02, 0.03}}); migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3}}, {0.01, 0.02, 0.03}});
auto scaled_tensor = p.add_instruction(migraph::op::scalar{s}, scale_val); auto scaled_tensor = p.add_instruction(migraphx::op::scalar{s}, scale_val);
auto img_scaled = p.add_instruction(migraph::op::mul{}, l0, scaled_tensor); auto img_scaled = p.add_instruction(migraphx::op::mul{}, l0, scaled_tensor);
auto bias_bcast = p.add_instruction(migraph::op::broadcast{1, s}, bias_vals); auto bias_bcast = p.add_instruction(migraphx::op::broadcast{1, s}, bias_vals);
p.add_instruction(migraph::op::add{}, img_scaled, bias_bcast); p.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast);
auto prog = migraph::parse_onnx("imagescaler_test.onnx"); auto prog = migraphx::parse_onnx("imagescaler_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
void globalavgpool_test() void globalavgpool_test()
{ {
migraph::program p; migraphx::program p;
auto input = p.add_parameter("0", migraph::shape{migraph::shape::float_type, {1, 3, 16, 16}}); auto input = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
auto op = migraph::op::pooling{"average"}; auto op = migraphx::op::pooling{"average"};
auto lens = input->get_shape().lens(); auto lens = input->get_shape().lens();
op.lengths = {lens[2], lens[3]}; op.lengths = {lens[2], lens[3]};
p.add_instruction(op, input); p.add_instruction(op, input);
auto prog = migraph::parse_onnx("globalavgpool_test.onnx"); auto prog = migraphx::parse_onnx("globalavgpool_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
void globalmaxpool_test() void globalmaxpool_test()
{ {
migraph::program p; migraphx::program p;
auto input = p.add_parameter("0", migraph::shape{migraph::shape::float_type, {1, 3, 16, 16}}); auto input = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
auto op = migraph::op::pooling{"max"}; auto op = migraphx::op::pooling{"max"};
auto lens = input->get_shape().lens(); auto lens = input->get_shape().lens();
op.lengths = {lens[2], lens[3]}; op.lengths = {lens[2], lens[3]};
p.add_instruction(op, input); p.add_instruction(op, input);
auto prog = migraph::parse_onnx("globalmaxpool_test.onnx"); auto prog = migraphx::parse_onnx("globalmaxpool_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
void transpose_test() void transpose_test()
{ {
migraph::program p; migraphx::program p;
auto input = p.add_parameter("0", migraph::shape{migraph::shape::float_type, {1, 2, 2, 3}}); auto input = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
std::vector<int64_t> perm{0, 3, 1, 2}; std::vector<int64_t> perm{0, 3, 1, 2};
p.add_instruction(migraph::op::transpose{perm}, input); p.add_instruction(migraphx::op::transpose{perm}, input);
auto prog = migraph::parse_onnx("transpose_test.onnx"); auto prog = migraphx::parse_onnx("transpose_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
void dropout_test() void dropout_test()
{ {
migraph::program p; migraphx::program p;
auto input = p.add_parameter("0", migraph::shape{migraph::shape::float_type, {1, 3, 2, 2}}); auto input = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 2, 2}});
p.add_instruction(migraph::op::identity{}, input); p.add_instruction(migraphx::op::identity{}, input);
auto prog = migraph::parse_onnx("dropout_test.onnx"); auto prog = migraphx::parse_onnx("dropout_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
......
#include <migraph/program.hpp> #include <migraphx/program.hpp>
#include <migraph/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraph/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraph/operators.hpp> #include <migraphx/operators.hpp>
#include <sstream> #include <sstream>
#include "test.hpp" #include "test.hpp"
template <class... Ts> template <class... Ts>
void expect_shape(const migraph::shape& expected, const migraph::operation& op, Ts... xs) void expect_shape(const migraphx::shape& expected, const migraphx::operation& op, Ts... xs)
{ {
migraph::program p; migraphx::program p;
std::vector<migraph::shape> shapes{xs...}; std::vector<migraphx::shape> shapes{xs...};
std::vector<migraph::instruction_ref> args(shapes.size()); std::vector<migraphx::instruction_ref> args(shapes.size());
std::transform( std::transform(
shapes.begin(), shapes.end(), args.begin(), [&](auto&& s) { return p.add_outline(s); }); shapes.begin(), shapes.end(), args.begin(), [&](auto&& s) { return p.add_outline(s); });
p.add_instruction(op, args); p.add_instruction(op, args);
...@@ -24,11 +24,11 @@ void expect_shape(const migraph::shape& expected, const migraph::operation& op, ...@@ -24,11 +24,11 @@ void expect_shape(const migraph::shape& expected, const migraph::operation& op,
} }
template <class... Ts> template <class... Ts>
void throws_shape(const migraph::operation& op, Ts... xs) void throws_shape(const migraphx::operation& op, Ts... xs)
{ {
migraph::program p; migraphx::program p;
std::vector<migraph::shape> shapes{xs...}; std::vector<migraphx::shape> shapes{xs...};
std::vector<migraph::instruction_ref> args(shapes.size()); std::vector<migraphx::instruction_ref> args(shapes.size());
std::transform( std::transform(
shapes.begin(), shapes.end(), args.begin(), [&](auto&& s) { return p.add_outline(s); }); shapes.begin(), shapes.end(), args.begin(), [&](auto&& s) { return p.add_outline(s); });
bool thrown = test::throws([&] { p.add_instruction(op, args); }); bool thrown = test::throws([&] { p.add_instruction(op, args); });
...@@ -46,7 +46,7 @@ struct always_false : std::false_type ...@@ -46,7 +46,7 @@ struct always_false : std::false_type
}; };
template <class... Ts> template <class... Ts>
void throws_shape(const migraph::shape&, Ts...) void throws_shape(const migraphx::shape&, Ts...)
{ {
static_assert(always_false<Ts...>{}, static_assert(always_false<Ts...>{},
"An expected shape should not be passed to throws_shape function"); "An expected shape should not be passed to throws_shape function");
...@@ -55,94 +55,97 @@ void throws_shape(const migraph::shape&, Ts...) ...@@ -55,94 +55,97 @@ void throws_shape(const migraph::shape&, Ts...)
TEST_CASE(batch_norm_inference_shape) TEST_CASE(batch_norm_inference_shape)
{ {
const size_t channels = 3; const size_t channels = 3;
migraph::shape s{migraph::shape::float_type, {4, channels, 3, 3}}; migraphx::shape s{migraphx::shape::float_type, {4, channels, 3, 3}};
migraph::shape vars{migraph::shape::float_type, {channels}}; migraphx::shape vars{migraphx::shape::float_type, {channels}};
expect_shape(s, migraph::op::batch_norm_inference{}, s, vars, vars, vars, vars); expect_shape(s, migraphx::op::batch_norm_inference{}, s, vars, vars, vars, vars);
throws_shape(migraph::op::batch_norm_inference{}, s); throws_shape(migraphx::op::batch_norm_inference{}, s);
throws_shape(migraph::op::batch_norm_inference{}, s, vars, vars, vars, vars, vars); throws_shape(migraphx::op::batch_norm_inference{}, s, vars, vars, vars, vars, vars);
} }
TEST_CASE(convolution_shape) TEST_CASE(convolution_shape)
{ {
migraph::shape output{migraph::shape::float_type, {4, 4, 1, 1}}; migraphx::shape output{migraphx::shape::float_type, {4, 4, 1, 1}};
migraph::shape input{migraph::shape::float_type, {4, 3, 3, 3}}; migraphx::shape input{migraphx::shape::float_type, {4, 3, 3, 3}};
migraph::shape weights{migraph::shape::float_type, {4, 3, 3, 3}}; migraphx::shape weights{migraphx::shape::float_type, {4, 3, 3, 3}};
expect_shape(output, migraph::op::convolution{}, input, weights); expect_shape(output, migraphx::op::convolution{}, input, weights);
throws_shape(migraph::op::convolution{}, input); throws_shape(migraphx::op::convolution{}, input);
migraph::shape input2{migraph::shape::float_type, {3, 3}}; migraphx::shape input2{migraphx::shape::float_type, {3, 3}};
migraph::shape weights2{migraph::shape::float_type, {3, 3}}; migraphx::shape weights2{migraphx::shape::float_type, {3, 3}};
throws_shape(migraph::op::convolution{}, input2, weights2); throws_shape(migraphx::op::convolution{}, input2, weights2);
throws_shape(migraph::op::convolution{}, input2, weights); throws_shape(migraphx::op::convolution{}, input2, weights);
} }
TEST_CASE(transpose_shape) TEST_CASE(transpose_shape)
{ {
migraph::shape input{migraph::shape::float_type, {2, 2}}; migraphx::shape input{migraphx::shape::float_type, {2, 2}};
migraph::shape output{migraph::shape::float_type, {2, 2}, {1, 2}}; migraphx::shape output{migraphx::shape::float_type, {2, 2}, {1, 2}};
expect_shape(input, migraph::op::transpose{{0, 1}}, input); expect_shape(input, migraphx::op::transpose{{0, 1}}, input);
expect_shape(output, migraph::op::transpose{{1, 0}}, input); expect_shape(output, migraphx::op::transpose{{1, 0}}, input);
throws_shape(migraph::op::transpose{{1, 2}}, input); throws_shape(migraphx::op::transpose{{1, 2}}, input);
} }
TEST_CASE(contiguous_shape) TEST_CASE(contiguous_shape)
{ {
migraph::shape output{migraph::shape::float_type, {2, 2}}; migraphx::shape output{migraphx::shape::float_type, {2, 2}};
migraph::shape input{migraph::shape::float_type, {2, 2}, {1, 2}}; migraphx::shape input{migraphx::shape::float_type, {2, 2}, {1, 2}};
expect_shape(output, migraph::op::contiguous{}, input); expect_shape(output, migraphx::op::contiguous{}, input);
throws_shape(migraph::op::contiguous{}, input, input); throws_shape(migraphx::op::contiguous{}, input, input);
migraph::shape single{migraph::shape::float_type, {2}}; migraphx::shape single{migraphx::shape::float_type, {2}};
expect_shape(single, migraph::op::contiguous{}, single); expect_shape(single, migraphx::op::contiguous{}, single);
} }
TEST_CASE(reshape_shape) TEST_CASE(reshape_shape)
{ {
migraph::shape input{migraph::shape::float_type, {24, 1, 1, 1}}; migraphx::shape input{migraphx::shape::float_type, {24, 1, 1, 1}};
for(auto&& new_shape : for(auto&& new_shape :
std::vector<std::vector<int64_t>>{{8, 3, 1, 1}, {1, 3, 4, 2}, {1, 3, 4, 2}}) std::vector<std::vector<int64_t>>{{8, 3, 1, 1}, {1, 3, 4, 2}, {1, 3, 4, 2}})
{ {
std::vector<std::size_t> lens(new_shape.size()); std::vector<std::size_t> lens(new_shape.size());
std::copy(new_shape.begin(), new_shape.end(), lens.begin()); std::copy(new_shape.begin(), new_shape.end(), lens.begin());
migraph::shape output{migraph::shape::float_type, lens}; migraphx::shape output{migraphx::shape::float_type, lens};
expect_shape(output, migraph::op::reshape{new_shape}, input); expect_shape(output, migraphx::op::reshape{new_shape}, input);
} }
for(auto&& new_shape : std::vector<std::vector<int64_t>>{{8, 3, 2, 2}, {1, 3, -1, -1}}) for(auto&& new_shape : std::vector<std::vector<int64_t>>{{8, 3, 2, 2}, {1, 3, -1, -1}})
{ {
throws_shape(migraph::op::reshape{new_shape}, input); throws_shape(migraphx::op::reshape{new_shape}, input);
} }
} }
TEST_CASE(flatten_shape) TEST_CASE(flatten_shape)
{ {
migraph::shape input{migraph::shape::float_type, {2, 4, 6, 8}}; migraphx::shape input{migraphx::shape::float_type, {2, 4, 6, 8}};
expect_shape(migraph::shape{migraph::shape::float_type, {1, 2 * 4 * 6 * 8}}, expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 2 * 4 * 6 * 8}},
migraph::op::flatten{0}, migraphx::op::flatten{0},
input); input);
expect_shape( expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 4 * 6 * 8}},
migraph::shape{migraph::shape::float_type, {2, 4 * 6 * 8}}, migraph::op::flatten{1}, input); migraphx::op::flatten{1},
expect_shape(
migraph::shape{migraph::shape::float_type, {2 * 4, 6 * 8}}, migraph::op::flatten{2}, input);
expect_shape(
migraph::shape{migraph::shape::float_type, {2 * 4 * 6, 8}}, migraph::op::flatten{3}, input);
expect_shape(migraph::shape{migraph::shape::float_type, {2 * 4 * 6 * 8, 1}},
migraph::op::flatten{4},
input); input);
throws_shape(migraph::op::flatten{5}, input); expect_shape(migraphx::shape{migraphx::shape::float_type, {2 * 4, 6 * 8}},
migraphx::op::flatten{2},
input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {2 * 4 * 6, 8}},
migraphx::op::flatten{3},
input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {2 * 4 * 6 * 8, 1}},
migraphx::op::flatten{4},
input);
throws_shape(migraphx::op::flatten{5}, input);
} }
TEST_CASE(slice_shape) TEST_CASE(slice_shape)
{ {
migraph::shape input{migraph::shape::int32_type, {2, 2, 3}}; migraphx::shape input{migraphx::shape::int32_type, {2, 2, 3}};
expect_shape(migraph::shape{migraph::shape::int32_type, {2, 2, 2}, {6, 3, 1}}, expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 2, 2}, {6, 3, 1}},
migraph::op::slice{{2}, {1}, {3}}, migraphx::op::slice{{2}, {1}, {3}},
input); input);
expect_shape(migraph::shape{migraph::shape::int32_type, {2, 2, 2}, {6, 3, 1}}, expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 2, 2}, {6, 3, 1}},
migraph::op::slice{{0, 1, 2}, {0, 0, 1}, {2, 2, 3}}, migraphx::op::slice{{0, 1, 2}, {0, 0, 1}, {2, 2, 3}},
input); input);
expect_shape(migraph::shape{migraph::shape::int32_type, {2, 2, 1}, {6, 3, 1}}, expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 2, 1}, {6, 3, 1}},
migraph::op::slice{{2}, {2}, {10}}, migraphx::op::slice{{2}, {2}, {10}},
input); input);
} }
...@@ -150,62 +153,62 @@ TEST_CASE(multibroadcast) ...@@ -150,62 +153,62 @@ TEST_CASE(multibroadcast)
{ {
{ {
std::vector<std::size_t> lens{4, 2, 5, 3}; std::vector<std::size_t> lens{4, 2, 5, 3};
migraph::shape input{migraph::shape::float_type, {2, 1, 3}}; migraphx::shape input{migraphx::shape::float_type, {2, 1, 3}};
expect_shape(migraph::shape{migraph::shape::float_type, lens, {0, 3, 0, 1}}, expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 3, 0, 1}},
migraph::op::multibroadcast{lens}, migraphx::op::multibroadcast{lens},
input); input);
} }
{ {
std::vector<std::size_t> lens{4, 2, 5, 3}; std::vector<std::size_t> lens{4, 2, 5, 3};
migraph::shape input{migraph::shape::float_type, {2, 1, 1}}; migraphx::shape input{migraphx::shape::float_type, {2, 1, 1}};
expect_shape(migraph::shape{migraph::shape::float_type, lens, {0, 1, 0, 0}}, expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 1, 0, 0}},
migraph::op::multibroadcast{lens}, migraphx::op::multibroadcast{lens},
input); input);
} }
{ {
std::vector<std::size_t> lens{4, 2, 5, 3}; std::vector<std::size_t> lens{4, 2, 5, 3};
migraph::shape input{migraph::shape::float_type, {5, 1}}; migraphx::shape input{migraphx::shape::float_type, {5, 1}};
expect_shape(migraph::shape{migraph::shape::float_type, lens, {0, 0, 1, 0}}, expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 0, 1, 0}},
migraph::op::multibroadcast{lens}, migraphx::op::multibroadcast{lens},
input); input);
} }
{ {
std::vector<std::size_t> lens{4, 2, 5, 3}; std::vector<std::size_t> lens{4, 2, 5, 3};
migraph::shape input{migraph::shape::float_type, {4, 1, 1, 1}}; migraphx::shape input{migraphx::shape::float_type, {4, 1, 1, 1}};
expect_shape(migraph::shape{migraph::shape::float_type, lens, {1, 0, 0, 0}}, expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {1, 0, 0, 0}},
migraph::op::multibroadcast{lens}, migraphx::op::multibroadcast{lens},
input); input);
} }
{ {
std::vector<std::size_t> lens{4, 2, 5, 3}; std::vector<std::size_t> lens{4, 2, 5, 3};
migraph::shape input{migraph::shape::float_type, {3}}; migraphx::shape input{migraphx::shape::float_type, {3}};
expect_shape(migraph::shape{migraph::shape::float_type, lens, {0, 0, 0, 1}}, expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 0, 0, 1}},
migraph::op::multibroadcast{lens}, migraphx::op::multibroadcast{lens},
input); input);
} }
{ {
std::vector<std::size_t> lens{4, 4, 1, 3}; std::vector<std::size_t> lens{4, 4, 1, 3};
migraph::shape input{migraph::shape::float_type, {4, 1, 3}}; migraphx::shape input{migraphx::shape::float_type, {4, 1, 3}};
expect_shape(migraph::shape{migraph::shape::float_type, lens, {0, 3, 3, 1}}, expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 3, 3, 1}},
migraph::op::multibroadcast{lens}, migraphx::op::multibroadcast{lens},
input); input);
} }
{ {
std::vector<std::size_t> lens{4, 1, 1, 3}; std::vector<std::size_t> lens{4, 1, 1, 3};
migraph::shape input{migraph::shape::float_type, {4, 1, 1, 1}}; migraphx::shape input{migraphx::shape::float_type, {4, 1, 1, 1}};
expect_shape(migraph::shape{migraph::shape::float_type, lens, {1, 1, 1, 0}}, expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {1, 1, 1, 0}},
migraph::op::multibroadcast{lens}, migraphx::op::multibroadcast{lens},
input); input);
} }
{ {
std::vector<std::size_t> lens{4, 1, 3}; std::vector<std::size_t> lens{4, 1, 3};
migraph::shape input{migraph::shape::float_type, {4, 1, 1, 1}}; migraphx::shape input{migraphx::shape::float_type, {4, 1, 1, 1}};
throws_shape(migraph::op::multibroadcast{lens}, input); throws_shape(migraphx::op::multibroadcast{lens}, input);
} }
{ {
std::vector<std::size_t> lens{4, 1, 3}; std::vector<std::size_t> lens{4, 1, 3};
migraph::shape input{migraph::shape::float_type, {}}; migraphx::shape input{migraphx::shape::float_type, {}};
throws_shape(migraph::op::multibroadcast{lens}, input); throws_shape(migraphx::op::multibroadcast{lens}, input);
} }
} }
......
#include <migraph/operation.hpp> #include <migraphx/operation.hpp>
#include <sstream> #include <sstream>
#include <string> #include <string>
#include "test.hpp" #include "test.hpp"
...@@ -9,16 +9,17 @@ struct simple_operation ...@@ -9,16 +9,17 @@ struct simple_operation
template <class T, class F> template <class T, class F>
static auto reflect(T& x, F f) static auto reflect(T& x, F f)
{ {
return migraph::pack(f(x.data, "data")); return migraphx::pack(f(x.data, "data"));
} }
int data = 1; int data = 1;
std::string name() const { return "simple"; } std::string name() const { return "simple"; }
migraph::shape compute_shape(const std::vector<migraph::shape>&) const migraphx::shape compute_shape(const std::vector<migraphx::shape>&) const
{ {
MIGRAPH_THROW("not computable"); MIGRAPH_THROW("not computable");
} }
migraph::argument migraphx::argument compute(migraphx::context&,
compute(migraph::context&, const migraph::shape&, const std::vector<migraph::argument>&) const const migraphx::shape&,
const std::vector<migraphx::argument>&) const
{ {
MIGRAPH_THROW("not computable"); MIGRAPH_THROW("not computable");
} }
...@@ -32,12 +33,13 @@ struct simple_operation ...@@ -32,12 +33,13 @@ struct simple_operation
struct simple_operation_no_print struct simple_operation_no_print
{ {
std::string name() const { return "simple"; } std::string name() const { return "simple"; }
migraph::shape compute_shape(const std::vector<migraph::shape>&) const migraphx::shape compute_shape(const std::vector<migraphx::shape>&) const
{ {
MIGRAPH_THROW("not computable"); MIGRAPH_THROW("not computable");
} }
migraph::argument migraphx::argument compute(migraphx::context&,
compute(migraph::context&, const migraph::shape&, const std::vector<migraph::argument>&) const const migraphx::shape&,
const std::vector<migraphx::argument>&) const
{ {
MIGRAPH_THROW("not computable"); MIGRAPH_THROW("not computable");
} }
...@@ -46,8 +48,8 @@ struct simple_operation_no_print ...@@ -46,8 +48,8 @@ struct simple_operation_no_print
TEST_CASE(operation_copy_test) TEST_CASE(operation_copy_test)
{ {
simple_operation s{}; simple_operation s{};
migraph::operation op1 = s; // NOLINT migraphx::operation op1 = s; // NOLINT
migraph::operation op2 = op1; // NOLINT migraphx::operation op2 = op1; // NOLINT
// cppcheck-suppress duplicateExpression // cppcheck-suppress duplicateExpression
EXPECT(s == op1); EXPECT(s == op1);
// cppcheck-suppress duplicateExpression // cppcheck-suppress duplicateExpression
...@@ -57,10 +59,10 @@ TEST_CASE(operation_copy_test) ...@@ -57,10 +59,10 @@ TEST_CASE(operation_copy_test)
TEST_CASE(operation_equal_test) TEST_CASE(operation_equal_test)
{ {
simple_operation s{}; simple_operation s{};
migraph::operation op1 = s; migraphx::operation op1 = s;
s.data = 2; s.data = 2;
migraph::operation op2 = op1; // NOLINT migraphx::operation op2 = op1; // NOLINT
migraph::operation op3 = s; // NOLINT migraphx::operation op3 = s; // NOLINT
EXPECT(s != op1); EXPECT(s != op1);
EXPECT(op2 == op1); EXPECT(op2 == op1);
...@@ -74,18 +76,18 @@ struct not_operation ...@@ -74,18 +76,18 @@ struct not_operation
TEST_CASE(operation_any_cast) TEST_CASE(operation_any_cast)
{ {
migraph::operation op1 = simple_operation{}; migraphx::operation op1 = simple_operation{};
EXPECT(migraph::any_cast<simple_operation>(op1).data == 1); EXPECT(migraphx::any_cast<simple_operation>(op1).data == 1);
EXPECT(migraph::any_cast<not_operation*>(&op1) == nullptr); EXPECT(migraphx::any_cast<not_operation*>(&op1) == nullptr);
EXPECT(test::throws([&] { migraph::any_cast<not_operation&>(op1); })); EXPECT(test::throws([&] { migraphx::any_cast<not_operation&>(op1); }));
migraph::operation op2 = simple_operation{2}; migraphx::operation op2 = simple_operation{2};
EXPECT(migraph::any_cast<simple_operation>(op2).data == 2); EXPECT(migraphx::any_cast<simple_operation>(op2).data == 2);
EXPECT(migraph::any_cast<not_operation*>(&op2) == nullptr); EXPECT(migraphx::any_cast<not_operation*>(&op2) == nullptr);
} }
TEST_CASE(operation_print) TEST_CASE(operation_print)
{ {
migraph::operation op = simple_operation{}; migraphx::operation op = simple_operation{};
std::stringstream ss; std::stringstream ss;
ss << op; ss << op;
std::string s = ss.str(); std::string s = ss.str();
...@@ -94,7 +96,7 @@ TEST_CASE(operation_print) ...@@ -94,7 +96,7 @@ TEST_CASE(operation_print)
TEST_CASE(operation_default_print) TEST_CASE(operation_default_print)
{ {
migraph::operation op = simple_operation_no_print{}; migraphx::operation op = simple_operation_no_print{};
std::stringstream ss; std::stringstream ss;
ss << op; ss << op;
std::string s = ss.str(); std::string s = ss.str();
......
#include <migraph/program.hpp> #include <migraphx/program.hpp>
#include <migraph/instruction.hpp> #include <migraphx/instruction.hpp>
#include <test.hpp> #include <test.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
TEST_CASE(simple_alias) TEST_CASE(simple_alias)
{ {
migraph::program p; migraphx::program p;
auto l = p.add_literal(1); auto l = p.add_literal(1);
auto p1 = p.add_instruction(pass_op{}, l); auto p1 = p.add_instruction(pass_op{}, l);
EXPECT(bool{migraph::instruction::get_output_alias(l) == l}); EXPECT(bool{migraphx::instruction::get_output_alias(l) == l});
EXPECT(bool{migraph::instruction::get_output_alias(p1) == l}); EXPECT(bool{migraphx::instruction::get_output_alias(p1) == l});
} }
TEST_CASE(cascade_alias) TEST_CASE(cascade_alias)
{ {
migraph::program p; migraphx::program p;
auto l = p.add_literal(1); auto l = p.add_literal(1);
auto p1 = p.add_instruction(pass_op{}, l); auto p1 = p.add_instruction(pass_op{}, l);
auto p2 = p.add_instruction(pass_op{}, p1); auto p2 = p.add_instruction(pass_op{}, p1);
auto p3 = p.add_instruction(pass_op{}, p2); auto p3 = p.add_instruction(pass_op{}, p2);
EXPECT(bool{migraph::instruction::get_output_alias(l) == l}); EXPECT(bool{migraphx::instruction::get_output_alias(l) == l});
EXPECT(bool{migraph::instruction::get_output_alias(p1) == l}); EXPECT(bool{migraphx::instruction::get_output_alias(p1) == l});
EXPECT(bool{migraph::instruction::get_output_alias(p2) == l}); EXPECT(bool{migraphx::instruction::get_output_alias(p2) == l});
EXPECT(bool{migraph::instruction::get_output_alias(p3) == l}); EXPECT(bool{migraphx::instruction::get_output_alias(p3) == l});
} }
TEST_CASE(no_alias) TEST_CASE(no_alias)
{ {
migraph::program p; migraphx::program p;
auto x = p.add_literal(1); auto x = p.add_literal(1);
auto y = p.add_literal(2); auto y = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, x, y); auto sum = p.add_instruction(sum_op{}, x, y);
EXPECT(bool{migraph::instruction::get_output_alias(sum) == sum}); EXPECT(bool{migraphx::instruction::get_output_alias(sum) == sum});
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment