Commit 3a848f0d authored by Paul's avatar Paul
Browse files

Merge branch 'develop' into doc2

parents 64e8e30a d1e945da
...@@ -112,4 +112,16 @@ TEST_CASE(no_packed_unary_op) ...@@ -112,4 +112,16 @@ TEST_CASE(no_packed_unary_op)
EXPECT(std::distance(p.begin(), p.end()) == count - 1); EXPECT(std::distance(p.begin(), p.end()) == count - 1);
} }
TEST_CASE(non_standard_return_input)
{
migraphx::program p;
auto l = p.add_literal(get_2x2());
auto tl = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = p.add_instruction(migraphx::op::contiguous{}, tl);
p.add_return({c});
auto count = std::distance(p.begin(), p.end());
run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == count);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -21,7 +21,7 @@ TEST_CASE(simple_test) ...@@ -21,7 +21,7 @@ TEST_CASE(simple_test)
EXPECT(std::none_of(p.begin(), p.end(), [](const migraphx::instruction& ins) { EXPECT(std::none_of(p.begin(), p.end(), [](const migraphx::instruction& ins) {
return ins.name() == "identity"; return ins.name() == "identity";
})); }));
auto result = p.eval({}); auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{3}); EXPECT(result == migraphx::literal{3});
} }
...@@ -37,7 +37,7 @@ TEST_CASE(simple_test_end) ...@@ -37,7 +37,7 @@ TEST_CASE(simple_test_end)
EXPECT(std::none_of(p.begin(), p.end(), [](const migraphx::instruction& ins) { EXPECT(std::none_of(p.begin(), p.end(), [](const migraphx::instruction& ins) {
return ins.name() == "identity"; return ins.name() == "identity";
})); }));
auto result = p.eval({}); auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{3}); EXPECT(result == migraphx::literal{3});
} }
...@@ -55,7 +55,7 @@ TEST_CASE(simple_test_end_dependency) ...@@ -55,7 +55,7 @@ TEST_CASE(simple_test_end_dependency)
EXPECT(std::any_of(p.begin(), p.end(), [](const migraphx::instruction& ins) { EXPECT(std::any_of(p.begin(), p.end(), [](const migraphx::instruction& ins) {
return ins.name() == "identity"; return ins.name() == "identity";
})); }));
auto result = p.eval({}); auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{3.0}); EXPECT(result == migraphx::literal{3.0});
} }
......
...@@ -134,7 +134,7 @@ TEST_CASE(literal_test1) ...@@ -134,7 +134,7 @@ TEST_CASE(literal_test1)
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
p.add_instruction(sum_op{}, one, two); p.add_instruction(sum_op{}, one, two);
auto result = p.eval({}); auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{3}); EXPECT(result == migraphx::literal{3});
EXPECT(result != migraphx::literal{4}); EXPECT(result != migraphx::literal{4});
} }
...@@ -148,7 +148,7 @@ TEST_CASE(literal_test2) ...@@ -148,7 +148,7 @@ TEST_CASE(literal_test2)
auto sum1 = p.add_instruction(sum_op{}, one, two); auto sum1 = p.add_instruction(sum_op{}, one, two);
p.add_instruction(sum_op{}, sum1, two); p.add_instruction(sum_op{}, sum1, two);
auto result = p.eval({}); auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{5}); EXPECT(result == migraphx::literal{5});
EXPECT(result != migraphx::literal{3}); EXPECT(result != migraphx::literal{3});
} }
...@@ -175,8 +175,9 @@ TEST_CASE(param_test) ...@@ -175,8 +175,9 @@ TEST_CASE(param_test)
auto y = p.add_parameter("y", {migraphx::shape::int32_type}); auto y = p.add_parameter("y", {migraphx::shape::int32_type});
p.add_instruction(sum_op{}, x, y); p.add_instruction(sum_op{}, x, y);
auto result = p.eval( auto result = p.eval({{"x", migraphx::literal{1}.get_argument()},
{{"x", migraphx::literal{1}.get_argument()}, {"y", migraphx::literal{2}.get_argument()}}); {"y", migraphx::literal{2}.get_argument()}})
.back();
EXPECT(result == migraphx::literal{3}); EXPECT(result == migraphx::literal{3});
EXPECT(result != migraphx::literal{4}); EXPECT(result != migraphx::literal{4});
} }
...@@ -258,7 +259,7 @@ TEST_CASE(replace_test) ...@@ -258,7 +259,7 @@ TEST_CASE(replace_test)
p.replace_instruction(sum, minus_op{}, two, one); p.replace_instruction(sum, minus_op{}, two, one);
EXPECT(bool{p.validate() == p.end()}); EXPECT(bool{p.validate() == p.end()});
auto result = p.eval({}); auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{1}); EXPECT(result == migraphx::literal{1});
EXPECT(result != migraphx::literal{3}); EXPECT(result != migraphx::literal{3});
} }
...@@ -274,7 +275,7 @@ TEST_CASE(replace_ins_test) ...@@ -274,7 +275,7 @@ TEST_CASE(replace_ins_test)
p.replace_instruction(sum, minus); p.replace_instruction(sum, minus);
EXPECT(bool{p.validate() == p.end()}); EXPECT(bool{p.validate() == p.end()});
auto result = p.eval({}); auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{1}); EXPECT(result == migraphx::literal{1});
EXPECT(result != migraphx::literal{3}); EXPECT(result != migraphx::literal{3});
} }
...@@ -291,7 +292,7 @@ TEST_CASE(replace_ins_test2) ...@@ -291,7 +292,7 @@ TEST_CASE(replace_ins_test2)
p.replace_instruction(two, sum); p.replace_instruction(two, sum);
EXPECT(bool{p.validate() == p.end()}); EXPECT(bool{p.validate() == p.end()});
auto result = p.eval({}); auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{2}); EXPECT(result == migraphx::literal{2});
EXPECT(result != migraphx::literal{3}); EXPECT(result != migraphx::literal{3});
} }
...@@ -306,7 +307,7 @@ TEST_CASE(replace_op_test) ...@@ -306,7 +307,7 @@ TEST_CASE(replace_op_test)
sum->replace(minus_op{}); sum->replace(minus_op{});
EXPECT(bool{p.validate() == p.end()}); EXPECT(bool{p.validate() == p.end()});
auto result = p.eval({}); auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{1}); EXPECT(result == migraphx::literal{1});
EXPECT(result != migraphx::literal{3}); EXPECT(result != migraphx::literal{3});
} }
...@@ -334,7 +335,7 @@ TEST_CASE(insert_replace_test) ...@@ -334,7 +335,7 @@ TEST_CASE(insert_replace_test)
p.replace_instruction(sum1, minus_op{}, sum0, two); p.replace_instruction(sum1, minus_op{}, sum0, two);
EXPECT(bool{p.validate() == p.end()}); EXPECT(bool{p.validate() == p.end()});
auto result = p.eval({}); auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{4}); EXPECT(result == migraphx::literal{4});
EXPECT(result != migraphx::literal{5}); EXPECT(result != migraphx::literal{5});
} }
...@@ -350,7 +351,7 @@ TEST_CASE(remove_test1) ...@@ -350,7 +351,7 @@ TEST_CASE(remove_test1)
p.remove_instruction(removed); p.remove_instruction(removed);
EXPECT(bool{p.validate() == p.end()}); EXPECT(bool{p.validate() == p.end()});
auto result = p.eval({}); auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{3}); EXPECT(result == migraphx::literal{3});
EXPECT(result != migraphx::literal{1}); EXPECT(result != migraphx::literal{1});
} }
...@@ -366,7 +367,7 @@ TEST_CASE(remove_test2) ...@@ -366,7 +367,7 @@ TEST_CASE(remove_test2)
p.remove_instruction(removed); p.remove_instruction(removed);
EXPECT(bool{p.validate() == p.end()}); EXPECT(bool{p.validate() == p.end()});
auto result = p.eval({}); auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{3}); EXPECT(result == migraphx::literal{3});
EXPECT(result != migraphx::literal{1}); EXPECT(result != migraphx::literal{1});
} }
...@@ -379,7 +380,7 @@ TEST_CASE(target_test) ...@@ -379,7 +380,7 @@ TEST_CASE(target_test)
auto two = p.add_literal(2); auto two = p.add_literal(2);
p.add_instruction(sum_op{}, one, two); p.add_instruction(sum_op{}, one, two);
p.compile(id_target{}); p.compile(id_target{});
auto result = p.eval({}); auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{3}); EXPECT(result == migraphx::literal{3});
EXPECT(result != migraphx::literal{4}); EXPECT(result != migraphx::literal{4});
} }
...@@ -392,7 +393,7 @@ TEST_CASE(invert_target_test) ...@@ -392,7 +393,7 @@ TEST_CASE(invert_target_test)
auto two = p.add_literal(2); auto two = p.add_literal(2);
p.add_instruction(sum_op{}, two, one); p.add_instruction(sum_op{}, two, one);
p.compile(invert_target{}); p.compile(invert_target{});
auto result = p.eval({}); auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{1}); EXPECT(result == migraphx::literal{1});
EXPECT(result != migraphx::literal{4}); EXPECT(result != migraphx::literal{4});
} }
...@@ -405,7 +406,7 @@ TEST_CASE(double_invert_target_test) ...@@ -405,7 +406,7 @@ TEST_CASE(double_invert_target_test)
auto two = p.add_literal(2); auto two = p.add_literal(2);
p.add_instruction(sum_op{}, two, one); p.add_instruction(sum_op{}, two, one);
p.compile(double_invert_target{}); p.compile(double_invert_target{});
auto result = p.eval({}); auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{3}); EXPECT(result == migraphx::literal{3});
EXPECT(result != migraphx::literal{4}); EXPECT(result != migraphx::literal{4});
} }
...@@ -432,7 +433,7 @@ TEST_CASE(eval_context1) ...@@ -432,7 +433,7 @@ TEST_CASE(eval_context1)
p.add_instruction(sum_op{}, one, two); p.add_instruction(sum_op{}, one, two);
p.compile(t); p.compile(t);
EXPECT(is_shared(t.ctx, p.get_context())); EXPECT(is_shared(t.ctx, p.get_context()));
p.eval({}); p.eval({}).back();
EXPECT(is_shared(t.ctx, p.get_context())); EXPECT(is_shared(t.ctx, p.get_context()));
} }
...@@ -446,7 +447,7 @@ TEST_CASE(eval_context2) ...@@ -446,7 +447,7 @@ TEST_CASE(eval_context2)
p.add_instruction(id_ctx_op{}, one, two); p.add_instruction(id_ctx_op{}, one, two);
p.compile(t); p.compile(t);
EXPECT(is_shared(t.ctx, p.get_context())); EXPECT(is_shared(t.ctx, p.get_context()));
p.eval({}); p.eval({}).back();
// id_ctx_op will modify the context // id_ctx_op will modify the context
EXPECT(not is_shared(t.ctx, p.get_context())); EXPECT(not is_shared(t.ctx, p.get_context()));
} }
...@@ -463,7 +464,7 @@ TEST_CASE(eval_context3) ...@@ -463,7 +464,7 @@ TEST_CASE(eval_context3)
// Finalizer will modify the context // Finalizer will modify the context
EXPECT(not is_shared(t.ctx, p.get_context())); EXPECT(not is_shared(t.ctx, p.get_context()));
auto ctx = p.get_context(); auto ctx = p.get_context();
p.eval({}); p.eval({}).back();
EXPECT(is_shared(ctx, p.get_context())); EXPECT(is_shared(ctx, p.get_context()));
EXPECT(not is_shared(t.ctx, p.get_context())); EXPECT(not is_shared(t.ctx, p.get_context()));
} }
......
...@@ -15,7 +15,7 @@ void gpu_literal_test() ...@@ -15,7 +15,7 @@ void gpu_literal_test()
auto scratch = p.get_parameter("scratch"); auto scratch = p.get_parameter("scratch");
if(scratch == p.end()) if(scratch == p.end())
{ {
auto result = p.eval({}); auto result = p.eval({}).back();
EXPECT(lit == migraphx::gpu::from_gpu(result)); EXPECT(lit == migraphx::gpu::from_gpu(result));
} }
else else
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/cpu/target.hpp> #include <migraphx/cpu/target.hpp>
#include <migraphx/gpu/target.hpp> #include <migraphx/gpu/target.hpp>
#include <migraphx/gpu/miopen.hpp> #include <migraphx/gpu/miopen.hpp>
...@@ -87,17 +88,29 @@ auto get_hash(const T& x) ...@@ -87,17 +88,29 @@ auto get_hash(const T& x)
void compile_check(migraphx::program& p, const migraphx::target& t, bool show_trace = false) void compile_check(migraphx::program& p, const migraphx::target& t, bool show_trace = false)
{ {
auto name = t.name(); auto name = t.name();
auto s = p.get_shape(); auto shapes = p.get_output_shapes();
std::stringstream ss; std::stringstream ss;
migraphx::compile_options options; migraphx::compile_options options;
options.trace = migraphx::tracer{ss}; options.trace = migraphx::tracer{ss};
p.compile(t, options); p.compile(t, options);
if(p.get_shape() != s) if(shapes.size() != p.get_output_shapes().size())
{ {
std::cout << ss.str() << std::endl; std::cout << ss.str() << std::endl;
throw std::runtime_error("Compiling program with " + name + " alters its shape"); throw std::runtime_error("Compiling program with " + name +
" alters its number of outputs");
} }
auto num = shapes.size();
for(std::size_t i = 0; i < num; ++i)
{
if(p.get_output_shapes()[i].lens() != shapes[i].lens())
{
std::cout << ss.str() << std::endl;
throw std::runtime_error("Compiling program with " + name + " alters its shape");
}
}
if(show_trace) if(show_trace)
{ {
std::cout << ss.str() << std::endl; std::cout << ss.str() << std::endl;
...@@ -105,7 +118,7 @@ void compile_check(migraphx::program& p, const migraphx::target& t, bool show_tr ...@@ -105,7 +118,7 @@ void compile_check(migraphx::program& p, const migraphx::target& t, bool show_tr
} }
template <class V> template <class V>
migraphx::argument run_cpu(migraphx::program& p) std::vector<migraphx::argument> run_cpu(migraphx::program& p)
{ {
V v; V v;
p = v.create_program(); p = v.create_program();
...@@ -120,7 +133,7 @@ migraphx::argument run_cpu(migraphx::program& p) ...@@ -120,7 +133,7 @@ migraphx::argument run_cpu(migraphx::program& p)
} }
template <class V> template <class V>
migraphx::argument run_gpu(migraphx::program& p) std::vector<migraphx::argument> run_gpu(migraphx::program& p)
{ {
V v; V v;
p = v.create_program(); p = v.create_program();
...@@ -133,7 +146,9 @@ migraphx::argument run_gpu(migraphx::program& p) ...@@ -133,7 +146,9 @@ migraphx::argument run_gpu(migraphx::program& p)
migraphx::gpu::to_gpu(migraphx::generate_argument(x.second, get_hash(x.first))); migraphx::gpu::to_gpu(migraphx::generate_argument(x.second, get_hash(x.first)));
} }
// Program should have an output parameter // Program should have an output parameter
EXPECT(bool{m.find("output") != m.end()}); EXPECT(std::any_of(
m.begin(), m.end(), [](auto& x) { return migraphx::contains(x.first, "output"); }));
// Ensure the program doesn't modify the context in a dry run // Ensure the program doesn't modify the context in a dry run
auto ctx = p.get_context(); auto ctx = p.get_context();
assert(&ctx != &p.get_context()); assert(&ctx != &p.get_context());
...@@ -141,7 +156,14 @@ migraphx::argument run_gpu(migraphx::program& p) ...@@ -141,7 +156,14 @@ migraphx::argument run_gpu(migraphx::program& p)
p.dry_run(m); p.dry_run(m);
EXPECT(is_shared(ctx, p.get_context())); EXPECT(is_shared(ctx, p.get_context()));
p.eval(m); p.eval(m);
return migraphx::gpu::from_gpu(p.eval(m));
auto gpu_res = p.eval(m);
std::vector<migraphx::argument> res(gpu_res.size());
std::transform(gpu_res.begin(), gpu_res.end(), res.begin(), [&](auto& argu) {
return migraphx::gpu::from_gpu(argu);
});
return res;
} }
template <class V> template <class V>
...@@ -154,7 +176,15 @@ void run_verify_program() ...@@ -154,7 +176,15 @@ void run_verify_program()
auto cpu_arg_f = detach_async([&] { return run_cpu<V>(cpu_prog); }); auto cpu_arg_f = detach_async([&] { return run_cpu<V>(cpu_prog); });
auto gpu_arg = run_gpu<V>(gpu_prog); auto gpu_arg = run_gpu<V>(gpu_prog);
auto cpu_arg = cpu_arg_f.get(); auto cpu_arg = cpu_arg_f.get();
bool passed = verify_args(migraphx::get_type_name<V>(), cpu_arg, gpu_arg);
bool passed = true;
passed &= (cpu_arg.size() == gpu_arg.size());
std::size_t num = cpu_arg.size();
for(std::size_t i = 0; ((i < num) and passed); ++i)
{
passed &= verify_args(migraphx::get_type_name<V>(), cpu_arg[i], gpu_arg[i]);
}
if(not passed) if(not passed)
{ {
V v; V v;
...@@ -316,6 +346,21 @@ struct test_pow : verify_program<test_pow> ...@@ -316,6 +346,21 @@ struct test_pow : verify_program<test_pow>
} }
}; };
struct test_prelu_brcst : verify_program<test_prelu_brcst>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {6}};
auto x = p.add_parameter("x", s);
auto slp = p.add_parameter("slp", s);
auto r = p.add_instruction(migraphx::op::prelu{}, x, slp);
p.add_return({r});
return p;
}
};
struct test_sin : verify_program<test_sin> struct test_sin : verify_program<test_sin>
{ {
migraphx::program create_program() const migraphx::program create_program() const
...@@ -402,6 +447,21 @@ struct test_trans_tanh : verify_program<test_trans_tanh> ...@@ -402,6 +447,21 @@ struct test_trans_tanh : verify_program<test_trans_tanh>
} }
}; };
struct test_trans_tanh1 : verify_program<test_trans_tanh1>
{
migraphx::program create_program() const
{
migraphx::program p;
auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto tx = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, x);
auto tanhx = p.add_instruction(migraphx::op::tanh{}, tx);
auto r = p.add_instruction(migraphx::op::add{}, tanhx, tanhx);
p.add_return({tx, r});
return p;
}
};
struct test_slice_sin : verify_program<test_slice_sin> struct test_slice_sin : verify_program<test_slice_sin>
{ {
migraphx::program create_program() const migraphx::program create_program() const
...@@ -451,6 +511,44 @@ struct test_atan : verify_program<test_atan> ...@@ -451,6 +511,44 @@ struct test_atan : verify_program<test_atan>
} }
}; };
struct test_asinh : verify_program<test_asinh>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::double_type, {16}};
auto x = p.add_parameter("x", s);
p.add_instruction(migraphx::op::asinh{}, x);
return p;
}
};
struct test_acosh : verify_program<test_acosh>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {16}};
auto x = p.add_parameter("x", s);
auto cx = p.add_instruction(migraphx::op::clip{100.0f, 1.1f}, x);
p.add_instruction(migraphx::op::acosh{}, cx);
return p;
}
};
struct test_atanh : verify_program<test_atanh>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::double_type, {16}};
auto x = p.add_parameter("x", s);
auto cx = p.add_instruction(migraphx::op::clip{0.95f, -0.95f}, x);
p.add_instruction(migraphx::op::atanh{}, cx);
return p;
}
};
struct test_scale : verify_program<test_scale> struct test_scale : verify_program<test_scale>
{ {
migraphx::program create_program() const migraphx::program create_program() const
...@@ -744,11 +842,15 @@ template struct test_arg_ops<migraphx::op::argmax, 0>; ...@@ -744,11 +842,15 @@ template struct test_arg_ops<migraphx::op::argmax, 0>;
template struct test_arg_ops<migraphx::op::argmax, 1>; template struct test_arg_ops<migraphx::op::argmax, 1>;
template struct test_arg_ops<migraphx::op::argmax, 2>; template struct test_arg_ops<migraphx::op::argmax, 2>;
template struct test_arg_ops<migraphx::op::argmax, 3>; template struct test_arg_ops<migraphx::op::argmax, 3>;
template struct test_arg_ops<migraphx::op::argmax, -1>;
template struct test_arg_ops<migraphx::op::argmax, -2>;
template struct test_arg_ops<migraphx::op::argmin, 0>; template struct test_arg_ops<migraphx::op::argmin, 0>;
template struct test_arg_ops<migraphx::op::argmin, 1>; template struct test_arg_ops<migraphx::op::argmin, 1>;
template struct test_arg_ops<migraphx::op::argmin, 2>; template struct test_arg_ops<migraphx::op::argmin, 2>;
template struct test_arg_ops<migraphx::op::argmin, 3>; template struct test_arg_ops<migraphx::op::argmin, 3>;
template struct test_arg_ops<migraphx::op::argmin, -3>;
template struct test_arg_ops<migraphx::op::argmin, -4>;
struct test_conv : verify_program<test_conv> struct test_conv : verify_program<test_conv>
{ {
...@@ -1717,7 +1819,7 @@ struct test_contiguous : verify_program<test_contiguous> ...@@ -1717,7 +1819,7 @@ struct test_contiguous : verify_program<test_contiguous>
migraphx::shape s{migraphx::shape::float_type, {4, 4, 4, 3}, {48, 4, 1, 16}}; migraphx::shape s{migraphx::shape::float_type, {4, 4, 4, 3}, {48, 4, 1, 16}};
auto x = p.add_parameter("x", s); auto x = p.add_parameter("x", s);
p.add_instruction(migraphx::op::contiguous{}, x); p.add_instruction(migraphx::op::contiguous{}, x);
EXPECT(p.get_shape().standard()); EXPECT(p.get_output_shapes().back().standard());
return p; return p;
} }
}; };
...@@ -1730,7 +1832,7 @@ struct test_contiguous_broadcast : verify_program<test_contiguous_broadcast> ...@@ -1730,7 +1832,7 @@ struct test_contiguous_broadcast : verify_program<test_contiguous_broadcast>
migraphx::shape s{migraphx::shape::float_type, {1, 2}, {0, 1}}; migraphx::shape s{migraphx::shape::float_type, {1, 2}, {0, 1}};
auto x = p.add_parameter("x", s); auto x = p.add_parameter("x", s);
p.add_instruction(migraphx::op::contiguous{}, x); p.add_instruction(migraphx::op::contiguous{}, x);
EXPECT(p.get_shape().standard()); EXPECT(p.get_output_shapes().back().standard());
return p; return p;
} }
}; };
...@@ -1743,7 +1845,7 @@ struct test_contiguous_broadcast_transpose : verify_program<test_contiguous_broa ...@@ -1743,7 +1845,7 @@ struct test_contiguous_broadcast_transpose : verify_program<test_contiguous_broa
migraphx::shape s{migraphx::shape::float_type, {1, 3072, 768}, {0, 1, 3072}}; migraphx::shape s{migraphx::shape::float_type, {1, 3072, 768}, {0, 1, 3072}};
auto x = p.add_parameter("x", s); auto x = p.add_parameter("x", s);
p.add_instruction(migraphx::op::contiguous{}, x); p.add_instruction(migraphx::op::contiguous{}, x);
EXPECT(p.get_shape().standard()); EXPECT(p.get_output_shapes().back().standard());
return p; return p;
} }
}; };
...@@ -1762,6 +1864,19 @@ struct test_transpose : verify_program<test_transpose> ...@@ -1762,6 +1864,19 @@ struct test_transpose : verify_program<test_transpose>
} }
}; };
struct test_trans_ret : verify_program<test_trans_ret>
{
migraphx::program create_program() const
{
migraphx::program p;
auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto tx = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, x);
p.add_return({tx});
return p;
}
};
struct test_batchnorm_inference_2 : verify_program<test_batchnorm_inference_2> struct test_batchnorm_inference_2 : verify_program<test_batchnorm_inference_2>
{ {
const size_t width = 14; const size_t width = 14;
...@@ -2228,7 +2343,7 @@ void manual_identity() ...@@ -2228,7 +2343,7 @@ void manual_identity()
{ {
m[x.first] = migraphx::gpu::to_gpu(migraphx::generate_argument(x.second)); m[x.first] = migraphx::gpu::to_gpu(migraphx::generate_argument(x.second));
} }
auto result = migraphx::gpu::from_gpu(p.eval(m)); auto result = migraphx::gpu::from_gpu(p.eval(m).back());
std::cout << result << std::endl; std::cout << result << std::endl;
} }
...@@ -2257,7 +2372,7 @@ void manual_test_concat_relu() ...@@ -2257,7 +2372,7 @@ void manual_test_concat_relu()
{ {
m[x.first] = migraphx::gpu::to_gpu(migraphx::generate_argument(x.second)); m[x.first] = migraphx::gpu::to_gpu(migraphx::generate_argument(x.second));
} }
auto result = migraphx::gpu::from_gpu(p.eval(m)); auto result = migraphx::gpu::from_gpu(p.eval(m).back());
std::cout << result << std::endl; std::cout << result << std::endl;
} }
...@@ -2382,6 +2497,48 @@ struct test_rnn_forward10 : verify_program<test_rnn_forward10> ...@@ -2382,6 +2497,48 @@ struct test_rnn_forward10 : verify_program<test_rnn_forward10>
} }
}; };
struct test_rnn_two_outputs : verify_program<test_rnn_two_outputs>
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 10;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto hs = p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
bias,
und,
ih);
auto last_hs = p.add_instruction(migraphx::op::rnn_last_output{}, hs);
p.add_return({hs, last_hs});
return p;
}
};
struct test_rnn_reverse : verify_program<test_rnn_reverse> struct test_rnn_reverse : verify_program<test_rnn_reverse>
{ {
migraphx::program create_program() const migraphx::program create_program() const
...@@ -2902,6 +3059,38 @@ struct test_gru_forward_default_actv : verify_program<test_gru_forward_default_a ...@@ -2902,6 +3059,38 @@ struct test_gru_forward_default_actv : verify_program<test_gru_forward_default_a
} }
}; };
struct test_gru_two_outputs : verify_program<test_gru_two_outputs>
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 1;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto hs = p.add_instruction(
migraphx::op::gru{hidden_size, {}, migraphx::op::rnn_direction::forward, clip},
seq,
w,
r);
auto last_hs = p.add_instruction(migraphx::op::rnn_last_output{}, hs);
p.add_return({hs, last_hs});
return p;
}
};
struct test_gru_forward_default_actv1 : verify_program<test_gru_forward_default_actv1> struct test_gru_forward_default_actv1 : verify_program<test_gru_forward_default_actv1>
{ {
migraphx::program create_program() const migraphx::program create_program() const
...@@ -3452,6 +3641,79 @@ struct test_lstm_forward_3args : verify_program<test_lstm_forward_3args> ...@@ -3452,6 +3641,79 @@ struct test_lstm_forward_3args : verify_program<test_lstm_forward_3args>
} }
}; };
struct test_lstm_two_outputs : verify_program<test_lstm_two_outputs>
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto hs = p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r);
auto last_hs = p.add_instruction(migraphx::op::rnn_last_output{}, hs);
p.add_return({hs, last_hs});
return p;
}
};
struct test_lstm_three_outputs : verify_program<test_lstm_three_outputs>
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto hs = p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r);
auto last_hs = p.add_instruction(migraphx::op::rnn_last_output{}, hs);
auto last_cell = p.add_instruction(migraphx::op::lstm_last_cell_output{}, hs);
p.add_return({hs, last_hs, last_cell});
return p;
}
};
struct test_lstm_forward_seq1 : verify_program<test_lstm_forward_seq1> struct test_lstm_forward_seq1 : verify_program<test_lstm_forward_seq1>
{ {
migraphx::program create_program() const migraphx::program create_program() const
...@@ -4091,10 +4353,11 @@ struct test_reduce_op_large : verify_program<test_reduce_op_large<Op, Axis, T>> ...@@ -4091,10 +4353,11 @@ struct test_reduce_op_large : verify_program<test_reduce_op_large<Op, Axis, T>>
}; };
}; };
template struct test_reduce_op_large<migraphx::op::reduce_sum, 1, migraphx::shape::float_type>;
template struct test_reduce_op_large<migraphx::op::reduce_mean, 1, migraphx::shape::float_type>;
template struct test_reduce_op_large<migraphx::op::reduce_max, 1, migraphx::shape::float_type>; template struct test_reduce_op_large<migraphx::op::reduce_max, 1, migraphx::shape::float_type>;
template struct test_reduce_op_large<migraphx::op::reduce_mean, 1, migraphx::shape::float_type>;
template struct test_reduce_op_large<migraphx::op::reduce_min, 1, migraphx::shape::float_type>; template struct test_reduce_op_large<migraphx::op::reduce_min, 1, migraphx::shape::float_type>;
template struct test_reduce_op_large<migraphx::op::reduce_prod, 2, migraphx::shape::float_type>;
template struct test_reduce_op_large<migraphx::op::reduce_sum, 1, migraphx::shape::float_type>;
template <class Op, int Axis, migraphx::shape::type_t T> template <class Op, int Axis, migraphx::shape::type_t T>
struct test_reduce_op_small : verify_program<test_reduce_op_small<Op, Axis, T>> struct test_reduce_op_small : verify_program<test_reduce_op_small<Op, Axis, T>>
...@@ -4117,6 +4380,7 @@ template struct test_reduce_op_small<migraphx::op::reduce_sum, 2, migraphx::shap ...@@ -4117,6 +4380,7 @@ template struct test_reduce_op_small<migraphx::op::reduce_sum, 2, migraphx::shap
template struct test_reduce_op_small<migraphx::op::reduce_mean, 2, migraphx::shape::half_type>; template struct test_reduce_op_small<migraphx::op::reduce_mean, 2, migraphx::shape::half_type>;
template struct test_reduce_op_small<migraphx::op::reduce_max, 2, migraphx::shape::half_type>; template struct test_reduce_op_small<migraphx::op::reduce_max, 2, migraphx::shape::half_type>;
template struct test_reduce_op_small<migraphx::op::reduce_min, 2, migraphx::shape::half_type>; template struct test_reduce_op_small<migraphx::op::reduce_min, 2, migraphx::shape::half_type>;
template struct test_reduce_op_small<migraphx::op::reduce_prod, -2, migraphx::shape::half_type>;
struct test_rsqrt : verify_program<test_rsqrt> struct test_rsqrt : verify_program<test_rsqrt>
{ {
......
...@@ -56,7 +56,7 @@ TEST_CASE(int8_quantization) ...@@ -56,7 +56,7 @@ TEST_CASE(int8_quantization)
} }
} }
auto result = t.copy_from(p.eval(m)); auto result = t.copy_from(p.eval(m).back());
result.visit([&](auto v) { res.assign(v.begin(), v.end()); }); result.visit([&](auto v) { res.assign(v.begin(), v.end()); });
}; };
......
...@@ -606,7 +606,7 @@ TEST_CASE(literal_test) ...@@ -606,7 +606,7 @@ TEST_CASE(literal_test)
auto lit = generate_literal(migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); auto lit = generate_literal(migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
p.add_literal(lit); p.add_literal(lit);
run_pass(p); run_pass(p);
auto result = p.eval({}); auto result = p.eval({}).back();
CHECK(lit == result); CHECK(lit == result);
} }
......

acosh_test:=
xy"Acosh
acosh_testZ
x


b
y


B
\ No newline at end of file

asinh_test:=
xy"Asinh
asinh_testZ
x


b
y


B
\ No newline at end of file

atanh_test:=
xy"Atanh
atanh_testZ
x


b
y


B
\ No newline at end of file
averagepool_same_lower_test:
E
xy" AveragePool*
auto_pad"
SAME_LOWER*
kernel_shape@@averagepool_same_lower_testZ
x




b
y




B
\ No newline at end of file
averagepool_same_upper_test:
E
xy" AveragePool*
auto_pad"
SAME_UPPER*
kernel_shape@@averagepool_same_upper_testZ
x




b
y




B
\ No newline at end of file
constant-scalar-example:R constant_scalar_test:Y
00"Constant*! 00"Constant*!
value**B const_tensor  test-constantb value**B const_tensor constant_scalar_testb
0 0
 
......
conv_autopad_same_test:»
J
0
12"Conv*
auto_pad"SAME *
dilations@@ *
strides@@ conv_autopad_same_testZ
0




 Z
1




b
2




 B
\ No newline at end of file
 conv-example:­ conv_bias_test:²
8 8
0 0
1 1
23"Conv* 23"Conv*
dilations@@ * dilations@@ *
strides@@  test_convZ strides@@ conv_bias_testZ
0 0
 
 
......
convinteger_bias_test:À
?
0
1
23" ConvInteger*
dilations@@ *
strides@@ convinteger_bias_testZ
0




 Z
1




Z
2

b
3




B
\ No newline at end of file
deconv_bias_test:ž
"
x
w
byconv1" ConvTransposedeconv_bias_testZ
x




Z
w




Z
b

b
y




B
deconv_input_pads_strides_test:¶
=
x
wy" ConvTranspose*
pads@@@@ *
strides@@ deconv_input_pads_strides_testZ
x




Z
w




b
y




B
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment