Unverified Commit 6554639b authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

add get_main_module api (#665)



* add an api get_main_module

* clang format

* modify onnx unit test for module

* clang format

* refactor ops unit test with the get_main_module

* clang format

* code backup

* clang format

* refine module c api

* add python api for module

* clang format

* fix a python api issue

* clang format

* fix cppcheck error

* clang format

* refine unit tests changes

* clang format

* defer some changes to later PRs

* change return of get_main_module from ref to pointer

* clang format

* add unit tests for the get_main_module_api

* clang format

* fix cppcheck error

* clang format

* add more unit tests for more code change coverage

* clang format

* fix unit test

* clang format

* code change for more code coverage

* Add option to no generate a destroy method

* Formatting

* fix some review comments

* clang format

* fix review comments

* clang format

* clang format
Co-authored-by: default avatarPaul <pfultz2@yahoo.com>
parent 500d9441
......@@ -167,7 +167,9 @@ std::vector<argument> run(program& p, const program::parameter_map& params)
std::vector<shape> get_output_shapes(program& p) { return p.get_output_shapes(); }
void print(const program& p) { std::cout << p << std::endl; }
void print_program(const program& p) { std::cout << p << std::endl; }
void print_module(const module& m) { std::cout << m << std::endl; }
} // namespace migraphx
......@@ -264,6 +266,16 @@ struct migraphx_shapes
std::vector<migraphx::shape> object;
};
extern "C" struct migraphx_module;
struct migraphx_module
{
template <class... Ts>
migraphx_module(Ts&&... xs) : object(std::forward<Ts>(xs)...)
{
}
migraphx::module object;
};
extern "C" struct migraphx_program;
struct migraphx_program
{
......@@ -616,11 +628,30 @@ migraphx_shapes_get(const_migraphx_shape_t* out, migraphx_shapes_t shapes, size_
});
}
extern "C" migraphx_status migraphx_module_print(const_migraphx_module_t module)
{
return migraphx::try_([&] {
if(module == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter module: Null pointer");
migraphx::print_module((module->object));
});
}
extern "C" migraphx_status migraphx_program_destroy(migraphx_program_t program)
{
return migraphx::try_([&] { destroy((program)); });
}
extern "C" migraphx_status migraphx_program_get_main_module(migraphx_module_t* out,
migraphx_program_t program)
{
return migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
*out = object_cast<migraphx_module_t>((program->object).get_main_module());
});
}
extern "C" migraphx_status migraphx_program_compile(migraphx_program_t program,
migraphx_target_t target,
migraphx_compile_options* options)
......@@ -664,7 +695,7 @@ extern "C" migraphx_status migraphx_program_print(const_migraphx_program_t progr
return migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
migraphx::print((program->object));
migraphx::print_program((program->object));
});
}
......
......@@ -71,6 +71,9 @@ typedef const struct migraphx_arguments* const_migraphx_arguments_t;
typedef struct migraphx_shapes* migraphx_shapes_t;
typedef const struct migraphx_shapes* const_migraphx_shapes_t;
typedef struct migraphx_module* migraphx_module_t;
typedef const struct migraphx_module* const_migraphx_module_t;
typedef struct migraphx_program* migraphx_program_t;
typedef const struct migraphx_program* const_migraphx_program_t;
......@@ -174,8 +177,13 @@ migraphx_status migraphx_shapes_size(size_t* out, migraphx_shapes_t shapes);
migraphx_status
migraphx_shapes_get(const_migraphx_shape_t* out, migraphx_shapes_t shapes, size_t idx);
migraphx_status migraphx_module_print(const_migraphx_module_t module);
migraphx_status migraphx_program_destroy(migraphx_program_t program);
migraphx_status migraphx_program_get_main_module(migraphx_module_t* out,
migraphx_program_t program);
migraphx_status migraphx_program_compile(migraphx_program_t program,
migraphx_target_t target,
migraphx_compile_options* options);
......
......@@ -6,6 +6,7 @@
#include <exception>
#include <vector>
#include <cassert>
#include <iostream>
namespace migraphx {
inline namespace api { // NOLINT
......@@ -422,6 +423,7 @@ struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments>
{
const_migraphx_argument_t pout;
call(&migraphx_arguments_get, &pout, self, pidx);
return argument(pout);
}
};
......@@ -459,6 +461,14 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
};
};
struct module
{
migraphx_module_t mm;
module(const migraphx_module_t& m) : mm(m) {}
void print() const { call(&migraphx_module_print, mm); }
};
struct program : MIGRAPHX_HANDLE_BASE(program)
{
program() {}
......@@ -514,6 +524,13 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
return pout;
}
module get_main_module()
{
migraphx_module_t p_modu;
call(&migraphx_program_get_main_module, &p_modu, this->get_handle_ptr());
return module{p_modu};
}
friend bool operator!=(const program& px, const program& py) { return !(px == py); }
};
......
......@@ -167,8 +167,14 @@ def shapes(h):
returns='const migraphx::shape&')
@auto_handle(ref=True)
def module(h):
h.method('print', invoke='migraphx::print_module($@)', const=True)
@auto_handle()
def program(h):
h.method('get_main_module', returns='migraphx::module*')
h.method(
'compile',
api.params(target='migraphx::target',
......@@ -178,7 +184,7 @@ def program(h):
h.method('get_output_shapes',
invoke='migraphx::get_output_shapes($@)',
returns='std::vector<migraphx::shape>')
h.method('print', invoke='migraphx::print($@)', const=True)
h.method('print', invoke='migraphx::print_program($@)', const=True)
h.method('sort')
h.method('run',
api.params(
......
......@@ -17,6 +17,9 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
using module = program;
using module_ref = module*;
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_COMPILE)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_EVAL)
......@@ -139,6 +142,8 @@ struct program
friend bool operator==(const program& x, const program& y);
friend bool operator!=(const program& x, const program& y) { return !(x == y); }
module* get_main_module() { return this; }
private:
void assign(const program& p);
std::unique_ptr<program_impl> impl;
......
......@@ -208,6 +208,15 @@ migraphx::shape to_shape(const py::buffer_info& info)
}
}
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module_wrap
{
migraphx::program* prog;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
MIGRAPHX_PYBIND11_MODULE(migraphx, m)
{
py::class_<migraphx::shape>(m, "shape")
......@@ -247,6 +256,9 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py::class_<migraphx::target>(m, "target");
py::class_<migraphx::module_wrap>(m, "module")
.def("print", [](const migraphx::module_wrap& mm) { std::cout << *mm.prog << std::endl; });
py::class_<migraphx::program>(m, "program")
.def("clone", [](migraphx::program& p) { return *(new migraphx::program(p)); })
.def("get_parameter_names", &migraphx::program::get_parameter_names)
......@@ -263,6 +275,11 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py::arg("t"),
py::arg("offload_copy") = true,
py::arg("fast_math") = true)
.def("get_main_module",
[](migraphx::program& p) {
auto mm = p.get_main_module();
return migraphx::module_wrap{mm};
})
.def("run",
[](migraphx::program& p, py::dict params) {
migraphx::program::parameter_map pm;
......
......@@ -155,4 +155,12 @@ TEST_CASE(strided_shape)
EXPECT(s.strides() == strides);
}
TEST_CASE(get_main_module)
{
auto p = migraphx::parse_onnx("constant_fill_test.onnx");
migraphx::module mm = p.get_main_module();
mm.print();
p.print();
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -42,26 +42,27 @@ TEST_CASE(rnn_test_bidirectional)
migraphx::shape ih_shape{migraphx::shape::float_type, {nd, bs, hs}};
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", bias_shape);
auto seq_len = p.add_parameter("seq_len", sl_shape);
auto ih = p.add_parameter("h0", ih_shape);
auto seq = mm->add_parameter("seq", seq_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", bias_shape);
auto seq_len = mm->add_parameter("seq_len", sl_shape);
auto ih = mm->add_parameter("h0", ih_shape);
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
mm->add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_rnn_bi.onnx");
EXPECT(p == prog);
......@@ -85,25 +86,26 @@ TEST_CASE(rnn_test_one_direction)
// forward
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", bias_shape);
auto seq_len = p.add_parameter("seq_len", sl_shape);
auto ih = p.add_parameter("h0", ih_shape);
auto* mm = p.get_main_module();
auto seq = mm->add_parameter("seq", seq_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", bias_shape);
auto seq_len = mm->add_parameter("seq_len", sl_shape);
auto ih = mm->add_parameter("h0", ih_shape);
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
mm->add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_rnn_forward.onnx");
EXPECT(p == prog);
......@@ -112,24 +114,25 @@ TEST_CASE(rnn_test_one_direction)
// reverse
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", bias_shape);
auto seq_len = p.add_parameter("seq_len", sl_shape);
auto ih = p.add_parameter("h0", ih_shape);
auto* mm = p.get_main_module();
auto seq = mm->add_parameter("seq", seq_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", bias_shape);
auto seq_len = mm->add_parameter("seq_len", sl_shape);
auto ih = mm->add_parameter("h0", ih_shape);
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
mm->add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_rnn_reverse.onnx");
EXPECT(p == prog);
......@@ -138,22 +141,23 @@ TEST_CASE(rnn_test_one_direction)
// 3 argumments
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto* mm = p.get_main_module();
auto seq = mm->add_parameter("seq", seq_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto und = mm->add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
r,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
mm->add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
r,
und,
und,
und);
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_rnn_3args.onnx");
EXPECT(p == prog);
......@@ -162,26 +166,27 @@ TEST_CASE(rnn_test_one_direction)
// 5 argumments
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", bias_shape);
auto seq_len = p.add_parameter("seq_len", sl_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto seq = mm->add_parameter("seq", seq_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", bias_shape);
auto seq_len = mm->add_parameter("seq_len", sl_shape);
auto und = mm->add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
bias,
seq_len,
und);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
mm->add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
bias,
seq_len,
und);
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_rnn_5args.onnx");
EXPECT(p == prog);
......@@ -200,32 +205,34 @@ TEST_CASE(gru_test)
{
nd = 1;
migraphx::program p;
auto* mm = p.get_main_module();
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
mm->add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
mm->add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
mm->add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
mm->add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih =
mm->add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::forward,
clip,
1},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
mm->add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::forward,
clip,
1},
seq,
w,
r,
bias,
seq_len,
ih);
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_gru_forward.onnx");
EXPECT(p == prog);
......@@ -235,31 +242,33 @@ TEST_CASE(gru_test)
{
nd = 1;
migraphx::program p;
auto* mm = p.get_main_module();
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
mm->add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
mm->add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
mm->add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
mm->add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih =
mm->add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
mm->add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_gru_reverse.onnx");
EXPECT(p == prog);
......@@ -269,34 +278,36 @@ TEST_CASE(gru_test)
{
nd = 2;
migraphx::program p;
auto* mm = p.get_main_module();
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
mm->add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
mm->add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
mm->add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
mm->add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih =
mm->add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::relu{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
mm->add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::relu{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_gru_bi.onnx");
EXPECT(p == prog);
......@@ -316,26 +327,27 @@ TEST_CASE(gru_test_args)
{
nd = 1;
migraphx::program p;
auto* mm = p.get_main_module();
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
mm->add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto und = p.add_instruction(migraphx::op::undefined{});
mm->add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto und = mm->add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
mm->add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
und,
und,
und);
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_gru_3arg.onnx");
EXPECT(p == prog);
......@@ -345,29 +357,30 @@ TEST_CASE(gru_test_args)
{
nd = 1;
migraphx::program p;
auto* mm = p.get_main_module();
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
mm->add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
mm->add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto und = p.add_instruction(migraphx::op::undefined{});
mm->add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto und = mm->add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::relu{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
r,
bias,
und,
und);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
mm->add_instruction(migraphx::op::gru{hs,
{migraphx::op::relu{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
r,
bias,
und,
und);
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_gru_4arg.onnx");
EXPECT(p == prog);
......@@ -377,34 +390,35 @@ TEST_CASE(gru_test_args)
{
nd = 2;
migraphx::program p;
auto* mm = p.get_main_module();
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
mm->add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
mm->add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
mm->add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto und = p.add_instruction(migraphx::op::undefined{});
mm->add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto und = mm->add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::relu{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
und);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
mm->add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::relu{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
und);
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_gru_5arg.onnx");
EXPECT(p == prog);
......@@ -423,34 +437,36 @@ TEST_CASE(gru_test_actv_funcs)
{
nd = 2;
migraphx::program p;
auto* mm = p.get_main_module();
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
mm->add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
mm->add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
mm->add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
mm->add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih =
mm->add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
mm->add_instruction(migraphx::op::gru{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_gru_bi_0.onnx");
EXPECT(p == prog);
......@@ -460,34 +476,36 @@ TEST_CASE(gru_test_actv_funcs)
{
nd = 2;
migraphx::program p;
auto* mm = p.get_main_module();
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
mm->add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
mm->add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
mm->add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
mm->add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih =
mm->add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::sigmoid{},
migraphx::op::sigmoid{},
migraphx::op::sigmoid{},
migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
mm->add_instruction(migraphx::op::gru{hs,
{migraphx::op::sigmoid{},
migraphx::op::sigmoid{},
migraphx::op::sigmoid{},
migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_gru_bi_1.onnx");
EXPECT(p == prog);
......@@ -497,34 +515,36 @@ TEST_CASE(gru_test_actv_funcs)
{
nd = 2;
migraphx::program p;
auto* mm = p.get_main_module();
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
mm->add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
mm->add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
mm->add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
mm->add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih =
mm->add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
mm->add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_gru_bi_2.onnx");
EXPECT(p == prog);
......@@ -534,34 +554,36 @@ TEST_CASE(gru_test_actv_funcs)
{
nd = 2;
migraphx::program p;
auto* mm = p.get_main_module();
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
mm->add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
mm->add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
mm->add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
mm->add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih =
mm->add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
mm->add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_gru_bi_3.onnx");
EXPECT(p == prog);
......@@ -571,31 +593,33 @@ TEST_CASE(gru_test_actv_funcs)
{
nd = 1;
migraphx::program p;
auto* mm = p.get_main_module();
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
mm->add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
mm->add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
mm->add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
mm->add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih =
mm->add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
mm->add_instruction(migraphx::op::gru{hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_gru_forward_0.onnx");
EXPECT(p == prog);
......@@ -605,31 +629,33 @@ TEST_CASE(gru_test_actv_funcs)
{
nd = 1;
migraphx::program p;
auto* mm = p.get_main_module();
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
mm->add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
mm->add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
mm->add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
mm->add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih =
mm->add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::relu{}, migraphx::op::relu{}},
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
mm->add_instruction(migraphx::op::gru{hs,
{migraphx::op::relu{}, migraphx::op::relu{}},
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_gru_reverse_1.onnx");
EXPECT(p == prog);
......@@ -654,16 +680,17 @@ TEST_CASE(lstm_forward)
migraphx::shape pph_shape{migraphx::shape::float_type, {nd, 3 * hs}};
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", bias_shape);
auto seq_len = p.add_parameter("seq_len", sl_shape);
auto ih = p.add_parameter("h0", ih_shape);
auto ic = p.add_parameter("c0", ih_shape);
auto pph = p.add_parameter("pph", pph_shape);
auto out_hs = p.add_instruction(
auto* mm = p.get_main_module();
auto seq = mm->add_parameter("seq", seq_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", bias_shape);
auto seq_len = mm->add_parameter("seq_len", sl_shape);
auto ih = mm->add_parameter("h0", ih_shape);
auto ic = mm->add_parameter("c0", ih_shape);
auto pph = mm->add_parameter("pph", pph_shape);
auto out_hs = mm->add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
......@@ -678,7 +705,7 @@ TEST_CASE(lstm_forward)
ih,
ic,
pph);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_forward.onnx");
EXPECT(p == prog);
......@@ -687,12 +714,13 @@ TEST_CASE(lstm_forward)
// 3 args
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto* mm = p.get_main_module();
auto seq = mm->add_parameter("seq", seq_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto und = mm->add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
auto out_hs = mm->add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
......@@ -707,7 +735,7 @@ TEST_CASE(lstm_forward)
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_f3args.onnx");
EXPECT(p == prog);
......@@ -716,12 +744,13 @@ TEST_CASE(lstm_forward)
// 3 args, hs output
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto* mm = p.get_main_module();
auto seq = mm->add_parameter("seq", seq_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto und = mm->add_instruction(migraphx::op::undefined{});
p.add_instruction(
mm->add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
......@@ -744,12 +773,13 @@ TEST_CASE(lstm_forward)
// 3 args, last output
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto* mm = p.get_main_module();
auto seq = mm->add_parameter("seq", seq_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto und = mm->add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
auto out_hs = mm->add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
......@@ -764,7 +794,7 @@ TEST_CASE(lstm_forward)
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_last.onnx");
EXPECT(p == prog);
......@@ -773,12 +803,13 @@ TEST_CASE(lstm_forward)
// 3 args, cell output
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto* mm = p.get_main_module();
auto seq = mm->add_parameter("seq", seq_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto und = mm->add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
auto out_hs = mm->add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
......@@ -793,7 +824,7 @@ TEST_CASE(lstm_forward)
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_cell_output{}, out_hs);
mm->add_instruction(migraphx::op::rnn_last_cell_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_cell.onnx");
EXPECT(p == prog);
......@@ -802,13 +833,14 @@ TEST_CASE(lstm_forward)
// 4 args
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", bias_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
auto* mm = p.get_main_module();
auto seq = mm->add_parameter("seq", seq_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", bias_shape);
auto und = mm->add_instruction(migraphx::op::undefined{});
auto out_hs = mm->add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
......@@ -823,7 +855,7 @@ TEST_CASE(lstm_forward)
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_f4args.onnx");
EXPECT(p == prog);
......@@ -832,14 +864,15 @@ TEST_CASE(lstm_forward)
// 5 args
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", bias_shape);
auto seq_len = p.add_parameter("seq_len", sl_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
auto* mm = p.get_main_module();
auto seq = mm->add_parameter("seq", seq_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", bias_shape);
auto seq_len = mm->add_parameter("seq_len", sl_shape);
auto und = mm->add_instruction(migraphx::op::undefined{});
auto out_hs = mm->add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
......@@ -854,8 +887,8 @@ TEST_CASE(lstm_forward)
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_cell_output{}, out_hs);
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
mm->add_instruction(migraphx::op::rnn_last_cell_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_f5args.onnx");
EXPECT(p == prog);
......@@ -864,15 +897,16 @@ TEST_CASE(lstm_forward)
// 6 args
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", bias_shape);
auto seq_len = p.add_parameter("seq_len", sl_shape);
auto ih = p.add_parameter("h0", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
auto* mm = p.get_main_module();
auto seq = mm->add_parameter("seq", seq_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", bias_shape);
auto seq_len = mm->add_parameter("seq_len", sl_shape);
auto ih = mm->add_parameter("h0", ih_shape);
auto und = mm->add_instruction(migraphx::op::undefined{});
auto out_hs = mm->add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
......@@ -887,8 +921,8 @@ TEST_CASE(lstm_forward)
ih,
und,
und);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_cell_output{}, out_hs);
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
mm->add_instruction(migraphx::op::rnn_last_cell_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_f6args.onnx");
EXPECT(p == prog);
......@@ -897,16 +931,17 @@ TEST_CASE(lstm_forward)
// 7 args
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", bias_shape);
auto seq_len = p.add_parameter("seq_len", sl_shape);
auto ih = p.add_parameter("h0", ih_shape);
auto ic = p.add_parameter("c0", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
auto* mm = p.get_main_module();
auto seq = mm->add_parameter("seq", seq_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", bias_shape);
auto seq_len = mm->add_parameter("seq_len", sl_shape);
auto ih = mm->add_parameter("h0", ih_shape);
auto ic = mm->add_parameter("c0", ih_shape);
auto und = mm->add_instruction(migraphx::op::undefined{});
auto out_hs = mm->add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
......@@ -921,8 +956,8 @@ TEST_CASE(lstm_forward)
ih,
ic,
und);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_cell_output{}, out_hs);
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
mm->add_instruction(migraphx::op::rnn_last_cell_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_f7args.onnx");
EXPECT(p == prog);
......@@ -947,13 +982,14 @@ TEST_CASE(lstm_forward_actv_func)
// no activation function specified
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
// auto seq_len = p.add_parameter("seq_len", sl_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
auto* mm = p.get_main_module();
auto seq = mm->add_parameter("seq", seq_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
// auto seq_len = mm->add_parameter("seq_len", sl_shape);
auto und = mm->add_instruction(migraphx::op::undefined{});
auto out_hs = mm->add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
......@@ -968,7 +1004,7 @@ TEST_CASE(lstm_forward_actv_func)
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_f0af.onnx");
EXPECT(p == prog);
......@@ -977,13 +1013,14 @@ TEST_CASE(lstm_forward_actv_func)
// 1 activation function specified
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", bias_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
auto* mm = p.get_main_module();
auto seq = mm->add_parameter("seq", seq_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", bias_shape);
auto und = mm->add_instruction(migraphx::op::undefined{});
auto out_hs = mm->add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::sigmoid{}, migraphx::op::sigmoid{}},
......@@ -998,7 +1035,7 @@ TEST_CASE(lstm_forward_actv_func)
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_f1af.onnx");
EXPECT(p == prog);
......@@ -1007,14 +1044,15 @@ TEST_CASE(lstm_forward_actv_func)
// 2 activation function specified
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", bias_shape);
auto seq_len = p.add_parameter("seq_len", sl_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
auto* mm = p.get_main_module();
auto seq = mm->add_parameter("seq", seq_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", bias_shape);
auto seq_len = mm->add_parameter("seq_len", sl_shape);
auto und = mm->add_instruction(migraphx::op::undefined{});
auto out_hs = mm->add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}, migraphx::op::sigmoid{}},
......@@ -1029,8 +1067,8 @@ TEST_CASE(lstm_forward_actv_func)
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_cell_output{}, out_hs);
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
mm->add_instruction(migraphx::op::rnn_last_cell_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_f2af.onnx");
EXPECT(p == prog);
......@@ -1055,16 +1093,17 @@ TEST_CASE(lstm_reverse)
migraphx::shape pph_shape{migraphx::shape::float_type, {nd, 3 * hs}};
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", bias_shape);
auto seq_len = p.add_parameter("seq_len", sl_shape);
auto ih = p.add_parameter("h0", ih_shape);
auto ic = p.add_parameter("c0", ih_shape);
auto pph = p.add_parameter("pph", pph_shape);
auto out_hs = p.add_instruction(
auto* mm = p.get_main_module();
auto seq = mm->add_parameter("seq", seq_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", bias_shape);
auto seq_len = mm->add_parameter("seq_len", sl_shape);
auto ih = mm->add_parameter("h0", ih_shape);
auto ic = mm->add_parameter("c0", ih_shape);
auto pph = mm->add_parameter("pph", pph_shape);
auto out_hs = mm->add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
......@@ -1079,7 +1118,7 @@ TEST_CASE(lstm_reverse)
ih,
ic,
pph);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_reverse.onnx");
EXPECT(p == prog);
......@@ -1088,14 +1127,15 @@ TEST_CASE(lstm_reverse)
// 5 args
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", bias_shape);
auto seq_len = p.add_parameter("seq_len", sl_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
auto* mm = p.get_main_module();
auto seq = mm->add_parameter("seq", seq_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", bias_shape);
auto seq_len = mm->add_parameter("seq_len", sl_shape);
auto und = mm->add_instruction(migraphx::op::undefined{});
auto out_hs = mm->add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
......@@ -1110,8 +1150,8 @@ TEST_CASE(lstm_reverse)
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_cell_output{}, out_hs);
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
mm->add_instruction(migraphx::op::rnn_last_cell_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_r5args.onnx");
EXPECT(p == prog);
......@@ -1120,12 +1160,13 @@ TEST_CASE(lstm_reverse)
// no activation function specified
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto* mm = p.get_main_module();
auto seq = mm->add_parameter("seq", seq_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto und = mm->add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
auto out_hs = mm->add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
......@@ -1140,7 +1181,7 @@ TEST_CASE(lstm_reverse)
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_r0af.onnx");
EXPECT(p == prog);
......@@ -1165,35 +1206,36 @@ TEST_CASE(lstm_bidirectional)
migraphx::shape pph_shape{migraphx::shape::float_type, {nd, 3 * hs}};
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", bias_shape);
auto seq_len = p.add_parameter("seq_len", sl_shape);
auto ih = p.add_parameter("h0", ih_shape);
auto ic = p.add_parameter("c0", ih_shape);
auto pph = p.add_parameter("pph", pph_shape);
auto* mm = p.get_main_module();
auto seq = mm->add_parameter("seq", seq_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", bias_shape);
auto seq_len = mm->add_parameter("seq_len", sl_shape);
auto ih = mm->add_parameter("h0", ih_shape);
auto ic = mm->add_parameter("c0", ih_shape);
auto pph = mm->add_parameter("pph", pph_shape);
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
bias,
seq_len,
ih,
ic,
pph);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
mm->add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
bias,
seq_len,
ih,
ic,
pph);
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_bi.onnx");
EXPECT(p == prog);
......@@ -1202,31 +1244,32 @@ TEST_CASE(lstm_bidirectional)
// 3 args
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto* mm = p.get_main_module();
auto seq = mm->add_parameter("seq", seq_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto und = mm->add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
und,
und,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
mm->add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
und,
und,
und,
und,
und);
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_bi3args.onnx");
EXPECT(p == prog);
......@@ -1235,32 +1278,33 @@ TEST_CASE(lstm_bidirectional)
// 4 args
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", bias_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto* mm = p.get_main_module();
auto seq = mm->add_parameter("seq", seq_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", bias_shape);
auto und = mm->add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
bias,
und,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
mm->add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
bias,
und,
und,
und,
und);
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_bi4args.onnx");
EXPECT(p == prog);
......@@ -1269,33 +1313,34 @@ TEST_CASE(lstm_bidirectional)
// 5 args
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", bias_shape);
auto seq_len = p.add_parameter("seq_len", sl_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto* mm = p.get_main_module();
auto seq = mm->add_parameter("seq", seq_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", bias_shape);
auto seq_len = mm->add_parameter("seq_len", sl_shape);
auto und = mm->add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
bias,
seq_len,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
mm->add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
bias,
seq_len,
und,
und,
und);
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_bi5args.onnx");
EXPECT(p == prog);
......@@ -1304,34 +1349,35 @@ TEST_CASE(lstm_bidirectional)
// 6 args
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", bias_shape);
auto seq_len = p.add_parameter("seq_len", sl_shape);
auto ih = p.add_parameter("h0", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto* mm = p.get_main_module();
auto seq = mm->add_parameter("seq", seq_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", bias_shape);
auto seq_len = mm->add_parameter("seq_len", sl_shape);
auto ih = mm->add_parameter("h0", ih_shape);
auto und = mm->add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
bias,
seq_len,
ih,
und,
und);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
mm->add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
bias,
seq_len,
ih,
und,
und);
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_bi6args.onnx");
EXPECT(p == prog);
......@@ -1340,35 +1386,36 @@ TEST_CASE(lstm_bidirectional)
// 7 args
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", bias_shape);
auto seq_len = p.add_parameter("seq_len", sl_shape);
auto ih = p.add_parameter("h0", ih_shape);
auto ic = p.add_parameter("c0", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto* mm = p.get_main_module();
auto seq = mm->add_parameter("seq", seq_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", bias_shape);
auto seq_len = mm->add_parameter("seq_len", sl_shape);
auto ih = mm->add_parameter("h0", ih_shape);
auto ic = mm->add_parameter("c0", ih_shape);
auto und = mm->add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
bias,
seq_len,
ih,
ic,
und);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
mm->add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
bias,
seq_len,
ih,
ic,
und);
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_bi7args.onnx");
EXPECT(p == prog);
......@@ -1394,31 +1441,32 @@ TEST_CASE(lstm_bi_actv_funcs)
// 0 activation function
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto* mm = p.get_main_module();
auto seq = mm->add_parameter("seq", seq_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto und = mm->add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
und,
und,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
mm->add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
und,
und,
und,
und,
und);
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_bi0af.onnx");
EXPECT(p == prog);
......@@ -1427,32 +1475,33 @@ TEST_CASE(lstm_bi_actv_funcs)
// 1 activation function
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", bias_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto* mm = p.get_main_module();
auto seq = mm->add_parameter("seq", seq_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", bias_shape);
auto und = mm->add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::sigmoid{},
migraphx::op::sigmoid{},
migraphx::op::sigmoid{},
migraphx::op::sigmoid{},
migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
bias,
und,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
mm->add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::sigmoid{},
migraphx::op::sigmoid{},
migraphx::op::sigmoid{},
migraphx::op::sigmoid{},
migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
bias,
und,
und,
und,
und);
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_bi1af.onnx");
EXPECT(p == prog);
......@@ -1461,33 +1510,34 @@ TEST_CASE(lstm_bi_actv_funcs)
// 2 activation functions
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", bias_shape);
auto seq_len = p.add_parameter("seq_len", sl_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto* mm = p.get_main_module();
auto seq = mm->add_parameter("seq", seq_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", bias_shape);
auto seq_len = mm->add_parameter("seq_len", sl_shape);
auto und = mm->add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
bias,
seq_len,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
mm->add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
bias,
seq_len,
und,
und,
und);
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_bi2af.onnx");
EXPECT(p == prog);
......@@ -1496,34 +1546,35 @@ TEST_CASE(lstm_bi_actv_funcs)
// 4 activation functions
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", bias_shape);
auto seq_len = p.add_parameter("seq_len", sl_shape);
auto ih = p.add_parameter("h0", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto* mm = p.get_main_module();
auto seq = mm->add_parameter("seq", seq_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", bias_shape);
auto seq_len = mm->add_parameter("seq_len", sl_shape);
auto ih = mm->add_parameter("h0", ih_shape);
auto und = mm->add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
bias,
seq_len,
ih,
und,
und);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
mm->add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
bias,
seq_len,
ih,
und,
und);
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_bi4af.onnx");
EXPECT(p == prog);
......@@ -1532,35 +1583,36 @@ TEST_CASE(lstm_bi_actv_funcs)
// 5 activation functions
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", bias_shape);
auto seq_len = p.add_parameter("seq_len", sl_shape);
auto ih = p.add_parameter("h0", ih_shape);
auto ic = p.add_parameter("c0", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto* mm = p.get_main_module();
auto seq = mm->add_parameter("seq", seq_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", bias_shape);
auto seq_len = mm->add_parameter("seq_len", sl_shape);
auto ih = mm->add_parameter("h0", ih_shape);
auto ic = mm->add_parameter("c0", ih_shape);
auto und = mm->add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
bias,
seq_len,
ih,
ic,
und);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
mm->add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
bias,
seq_len,
ih,
ic,
und);
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_bi5af.onnx");
EXPECT(p == prog);
......@@ -1569,31 +1621,32 @@ TEST_CASE(lstm_bi_actv_funcs)
// 6 activation functions
{
migraphx::program p;
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto* mm = p.get_main_module();
auto seq = mm->add_parameter("seq", seq_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto und = mm->add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
und,
und,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
mm->add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
und,
und,
und,
und,
und);
mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_bi6af.onnx");
EXPECT(p == prog);
......
......@@ -32,8 +32,9 @@ migraphx::program optimize_onnx(const std::string& name, bool eliminate_deadcode
TEST_CASE(acos_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
p.add_instruction(migraphx::op::acos{}, input);
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
mm->add_instruction(migraphx::op::acos{}, input);
auto prog = optimize_onnx("acos_test.onnx");
......@@ -43,8 +44,9 @@ TEST_CASE(acos_test)
TEST_CASE(acosh_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
p.add_instruction(migraphx::op::acosh{}, input);
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
mm->add_instruction(migraphx::op::acosh{}, input);
auto prog = optimize_onnx("acosh_test.onnx");
......@@ -54,10 +56,11 @@ TEST_CASE(acosh_test)
TEST_CASE(add_bcast_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}});
auto l2 = p.add_instruction(migraphx::op::broadcast{1, l0->get_shape().lens()}, l1);
p.add_instruction(migraphx::op::add{}, l0, l2);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}});
auto l2 = mm->add_instruction(migraphx::op::broadcast{1, l0->get_shape().lens()}, l1);
mm->add_instruction(migraphx::op::add{}, l0, l2);
auto prog = optimize_onnx("add_bcast_test.onnx");
......@@ -67,11 +70,12 @@ TEST_CASE(add_bcast_test)
TEST_CASE(add_fp16_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 =
p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::half_type, {1}}, {1.5}});
mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::half_type, {1}}, {1.5}});
auto l1 =
p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::half_type, {1}}, {2.5}});
p.add_instruction(migraphx::op::add{}, l0, l1);
mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::half_type, {1}}, {2.5}});
mm->add_instruction(migraphx::op::add{}, l0, l1);
auto prog = optimize_onnx("add_fp16_test.onnx");
EXPECT(p == prog);
......@@ -80,11 +84,12 @@ TEST_CASE(add_fp16_test)
TEST_CASE(add_scalar_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::uint8_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::uint8_type});
auto m1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
auto r = p.add_instruction(migraphx::op::add{}, l0, m1);
p.add_return({r});
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::uint8_type, {2, 3, 4, 5}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::uint8_type});
auto m1 = mm->add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
auto r = mm->add_instruction(migraphx::op::add{}, l0, m1);
mm->add_return({r});
auto prog = migraphx::parse_onnx("add_scalar_test.onnx");
EXPECT(p == prog);
......@@ -93,9 +98,10 @@ TEST_CASE(add_scalar_test)
TEST_CASE(argmax_test)
{
migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto ins = p.add_instruction(migraphx::op::argmax{2}, l0);
p.add_instruction(migraphx::op::squeeze{{2}}, ins);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto ins = mm->add_instruction(migraphx::op::argmax{2}, l0);
mm->add_instruction(migraphx::op::squeeze{{2}}, ins);
auto prog = optimize_onnx("argmax_test.onnx");
EXPECT(p == prog);
......@@ -104,9 +110,10 @@ TEST_CASE(argmax_test)
TEST_CASE(argmin_test)
{
migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto ins = p.add_instruction(migraphx::op::argmin{3}, l0);
p.add_instruction(migraphx::op::squeeze{{3}}, ins);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto ins = mm->add_instruction(migraphx::op::argmin{3}, l0);
mm->add_instruction(migraphx::op::squeeze{{3}}, ins);
auto prog = optimize_onnx("argmin_test.onnx");
EXPECT(p == prog);
......@@ -115,8 +122,9 @@ TEST_CASE(argmin_test)
TEST_CASE(asin_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
p.add_instruction(migraphx::op::asin{}, input);
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
mm->add_instruction(migraphx::op::asin{}, input);
auto prog = optimize_onnx("asin_test.onnx");
......@@ -126,8 +134,9 @@ TEST_CASE(asin_test)
TEST_CASE(asinh_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
p.add_instruction(migraphx::op::asinh{}, input);
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
mm->add_instruction(migraphx::op::asinh{}, input);
auto prog = optimize_onnx("asinh_test.onnx");
......@@ -137,8 +146,9 @@ TEST_CASE(asinh_test)
TEST_CASE(atan_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
p.add_instruction(migraphx::op::atan{}, input);
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
mm->add_instruction(migraphx::op::atan{}, input);
auto prog = optimize_onnx("atan_test.onnx");
......@@ -148,8 +158,9 @@ TEST_CASE(atan_test)
TEST_CASE(atanh_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
p.add_instruction(migraphx::op::atanh{}, input);
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
mm->add_instruction(migraphx::op::atanh{}, input);
auto prog = optimize_onnx("atanh_test.onnx");
......@@ -159,8 +170,9 @@ TEST_CASE(atanh_test)
TEST_CASE(averagepool_1d_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", {migraphx::shape::float_type, {1, 3, 5}});
p.add_instruction(migraphx::op::pooling{"average", {0}, {1}, {3}}, l0);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {1, 3, 5}});
mm->add_instruction(migraphx::op::pooling{"average", {0}, {1}, {3}}, l0);
auto prog = optimize_onnx("averagepool_1d_test.onnx");
EXPECT(p == prog);
......@@ -169,8 +181,9 @@ TEST_CASE(averagepool_1d_test)
TEST_CASE(averagepool_3d_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", {migraphx::shape::float_type, {1, 3, 5, 5, 5}});
p.add_instruction(migraphx::op::pooling{"average", {0, 0, 0}, {1, 1, 1}, {3, 3, 3}}, l0);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {1, 3, 5, 5, 5}});
mm->add_instruction(migraphx::op::pooling{"average", {0, 0, 0}, {1, 1, 1}, {3, 3, 3}}, l0);
auto prog = optimize_onnx("averagepool_3d_test.onnx");
EXPECT(p == prog);
......@@ -179,10 +192,11 @@ TEST_CASE(averagepool_3d_test)
TEST_CASE(averagepool_notset_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}});
auto ins = p.add_instruction(migraphx::op::pooling{"average", {2, 2}, {2, 2}, {6, 6}}, input);
auto ret = p.add_instruction(migraphx::op::slice{{2, 3}, {1, 1}, {2, 2}}, ins);
p.add_return({ret});
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}});
auto ins = mm->add_instruction(migraphx::op::pooling{"average", {2, 2}, {2, 2}, {6, 6}}, input);
auto ret = mm->add_instruction(migraphx::op::slice{{2, 3}, {1, 1}, {2, 2}}, ins);
mm->add_return({ret});
auto prog = migraphx::parse_onnx("averagepool_notset_test.onnx");
EXPECT(p == prog);
......@@ -191,11 +205,13 @@ TEST_CASE(averagepool_notset_test)
TEST_CASE(averagepool_nt_cip_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}});
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}});
std::vector<int64_t> pads = {0, 0, 0, 0, 0, 0, 1, 1};
auto ins_pad = p.add_instruction(migraphx::op::pad{pads}, input);
auto ret = p.add_instruction(migraphx::op::pooling{"average", {0, 0}, {2, 2}, {6, 6}}, ins_pad);
p.add_return({ret});
auto ins_pad = mm->add_instruction(migraphx::op::pad{pads}, input);
auto ret =
mm->add_instruction(migraphx::op::pooling{"average", {0, 0}, {2, 2}, {6, 6}}, ins_pad);
mm->add_return({ret});
auto prog = migraphx::parse_onnx("averagepool_nt_cip_test.onnx");
EXPECT(p == prog);
......@@ -204,10 +220,11 @@ TEST_CASE(averagepool_nt_cip_test)
TEST_CASE(averagepool_same_lower_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}});
auto ins = p.add_instruction(migraphx::op::pooling{"average", {1, 1}, {1, 1}, {2, 2}}, input);
auto ret = p.add_instruction(migraphx::op::slice{{2, 3}, {0, 0}, {5, 5}}, ins);
p.add_return({ret});
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}});
auto ins = mm->add_instruction(migraphx::op::pooling{"average", {1, 1}, {1, 1}, {2, 2}}, input);
auto ret = mm->add_instruction(migraphx::op::slice{{2, 3}, {0, 0}, {5, 5}}, ins);
mm->add_return({ret});
auto prog = migraphx::parse_onnx("averagepool_same_lower_test.onnx");
EXPECT(p == prog);
......@@ -216,11 +233,13 @@ TEST_CASE(averagepool_same_lower_test)
TEST_CASE(averagepool_sl_cip_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}});
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}});
std::vector<int64_t> pads = {0, 0, 1, 1, 0, 0, 0, 0};
auto ins_pad = p.add_instruction(migraphx::op::pad{pads}, input);
auto ret = p.add_instruction(migraphx::op::pooling{"average", {0, 0}, {1, 1}, {2, 2}}, ins_pad);
p.add_return({ret});
auto ins_pad = mm->add_instruction(migraphx::op::pad{pads}, input);
auto ret =
mm->add_instruction(migraphx::op::pooling{"average", {0, 0}, {1, 1}, {2, 2}}, ins_pad);
mm->add_return({ret});
auto prog = migraphx::parse_onnx("averagepool_sl_cip_test.onnx");
EXPECT(p == prog);
......@@ -229,10 +248,11 @@ TEST_CASE(averagepool_sl_cip_test)
TEST_CASE(averagepool_same_upper_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}});
auto ins = p.add_instruction(migraphx::op::pooling{"average", {1, 1}, {1, 1}, {2, 2}}, input);
auto ret = p.add_instruction(migraphx::op::slice{{2, 3}, {1, 1}, {6, 6}}, ins);
p.add_return({ret});
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}});
auto ins = mm->add_instruction(migraphx::op::pooling{"average", {1, 1}, {1, 1}, {2, 2}}, input);
auto ret = mm->add_instruction(migraphx::op::slice{{2, 3}, {1, 1}, {6, 6}}, ins);
mm->add_return({ret});
auto prog = migraphx::parse_onnx("averagepool_same_upper_test.onnx");
EXPECT(p == prog);
......@@ -241,12 +261,13 @@ TEST_CASE(averagepool_same_upper_test)
TEST_CASE(batchnorm_1d_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", {migraphx::shape::float_type, {1, 3, 5}});
auto l1 = p.add_parameter("1", {migraphx::shape::float_type, {3}});
auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {3}});
auto l3 = p.add_parameter("3", {migraphx::shape::float_type, {3}});
auto l4 = p.add_parameter("4", {migraphx::shape::float_type, {3}});
p.add_instruction(migraphx::op::batch_norm_inference{}, l0, l1, l2, l3, l4);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {1, 3, 5}});
auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {3}});
auto l2 = mm->add_parameter("2", {migraphx::shape::float_type, {3}});
auto l3 = mm->add_parameter("3", {migraphx::shape::float_type, {3}});
auto l4 = mm->add_parameter("4", {migraphx::shape::float_type, {3}});
mm->add_instruction(migraphx::op::batch_norm_inference{}, l0, l1, l2, l3, l4);
auto prog = optimize_onnx("batchnorm_1d_test.onnx");
EXPECT(p == prog);
......@@ -255,12 +276,13 @@ TEST_CASE(batchnorm_1d_test)
TEST_CASE(batchnorm_3d_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", {migraphx::shape::float_type, {1, 3, 5, 5, 5}});
auto l1 = p.add_parameter("1", {migraphx::shape::float_type, {3}});
auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {3}});
auto l3 = p.add_parameter("3", {migraphx::shape::float_type, {3}});
auto l4 = p.add_parameter("4", {migraphx::shape::float_type, {3}});
p.add_instruction(migraphx::op::batch_norm_inference{}, l0, l1, l2, l3, l4);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {1, 3, 5, 5, 5}});
auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {3}});
auto l2 = mm->add_parameter("2", {migraphx::shape::float_type, {3}});
auto l3 = mm->add_parameter("3", {migraphx::shape::float_type, {3}});
auto l4 = mm->add_parameter("4", {migraphx::shape::float_type, {3}});
mm->add_instruction(migraphx::op::batch_norm_inference{}, l0, l1, l2, l3, l4);
auto prog = optimize_onnx("batchnorm_3d_test.onnx");
EXPECT(p == prog);
......@@ -269,8 +291,9 @@ TEST_CASE(batchnorm_3d_test)
TEST_CASE(cast_test)
{
migraphx::program p;
auto l = p.add_parameter("x", migraphx::shape{migraphx::shape::half_type, {10}});
p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, l);
auto* mm = p.get_main_module();
auto l = mm->add_parameter("x", migraphx::shape{migraphx::shape::half_type, {10}});
mm->add_instruction(migraphx::op::convert{migraphx::shape::float_type}, l);
auto prog = optimize_onnx("cast_test.onnx");
EXPECT(p == prog);
......@@ -279,8 +302,9 @@ TEST_CASE(cast_test)
TEST_CASE(ceil_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
p.add_instruction(migraphx::op::ceil{}, input);
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
mm->add_instruction(migraphx::op::ceil{}, input);
auto prog = optimize_onnx("ceil_test.onnx");
......@@ -290,12 +314,13 @@ TEST_CASE(ceil_test)
TEST_CASE(clip_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}});
auto min_val = p.add_literal(0.0f);
auto max_val = p.add_literal(6.0f);
min_val = p.add_instruction(migraphx::op::multibroadcast{{3}}, min_val);
max_val = p.add_instruction(migraphx::op::multibroadcast{{3}}, max_val);
p.add_instruction(migraphx::op::clip{}, l0, min_val, max_val);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}});
auto min_val = mm->add_literal(0.0f);
auto max_val = mm->add_literal(6.0f);
min_val = mm->add_instruction(migraphx::op::multibroadcast{{3}}, min_val);
max_val = mm->add_instruction(migraphx::op::multibroadcast{{3}}, max_val);
mm->add_instruction(migraphx::op::clip{}, l0, min_val, max_val);
auto prog = optimize_onnx("clip_test.onnx");
EXPECT(p == prog);
......@@ -304,12 +329,13 @@ TEST_CASE(clip_test)
TEST_CASE(clip_test_op11_max_only)
{
migraphx::program p;
auto max_val = p.add_literal(0.0f);
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}});
p.add_instruction(migraphx::op::undefined{});
max_val = p.add_instruction(migraphx::op::multibroadcast{{3}}, max_val);
auto r = p.add_instruction(migraphx::op::min{}, l0, max_val);
p.add_return({r});
auto* mm = p.get_main_module();
auto max_val = mm->add_literal(0.0f);
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}});
mm->add_instruction(migraphx::op::undefined{});
max_val = mm->add_instruction(migraphx::op::multibroadcast{{3}}, max_val);
auto r = mm->add_instruction(migraphx::op::min{}, l0, max_val);
mm->add_return({r});
auto prog = migraphx::parse_onnx("clip_test_op11_max_only.onnx");
......@@ -319,12 +345,13 @@ TEST_CASE(clip_test_op11_max_only)
TEST_CASE(clip_test_op11)
{
migraphx::program p;
auto min_val = p.add_literal(0.0f);
auto max_val = p.add_literal(6.0f);
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}});
min_val = p.add_instruction(migraphx::op::multibroadcast{{3}}, min_val);
max_val = p.add_instruction(migraphx::op::multibroadcast{{3}}, max_val);
p.add_instruction(migraphx::op::clip{}, l0, min_val, max_val);
auto* mm = p.get_main_module();
auto min_val = mm->add_literal(0.0f);
auto max_val = mm->add_literal(6.0f);
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}});
min_val = mm->add_instruction(migraphx::op::multibroadcast{{3}}, min_val);
max_val = mm->add_instruction(migraphx::op::multibroadcast{{3}}, max_val);
mm->add_instruction(migraphx::op::clip{}, l0, min_val, max_val);
auto prog = optimize_onnx("clip_test_op11.onnx");
EXPECT(p == prog);
......@@ -333,10 +360,11 @@ TEST_CASE(clip_test_op11)
TEST_CASE(clip_test_op11_min_only)
{
migraphx::program p;
auto min_val = p.add_literal(0.0f);
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}});
min_val = p.add_instruction(migraphx::op::multibroadcast{{3}}, min_val);
p.add_instruction(migraphx::op::max{}, l0, min_val);
auto* mm = p.get_main_module();
auto min_val = mm->add_literal(0.0f);
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}});
min_val = mm->add_instruction(migraphx::op::multibroadcast{{3}}, min_val);
mm->add_instruction(migraphx::op::max{}, l0, min_val);
auto prog = optimize_onnx("clip_test_op11_min_only.onnx");
EXPECT(p == prog);
......@@ -345,8 +373,9 @@ TEST_CASE(clip_test_op11_min_only)
TEST_CASE(clip_test_op11_no_args)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}});
p.add_instruction(migraphx::op::identity{}, l0);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}});
mm->add_instruction(migraphx::op::identity{}, l0);
auto prog = optimize_onnx("clip_test_op11_no_args.onnx");
EXPECT(p == prog);
......@@ -355,11 +384,12 @@ TEST_CASE(clip_test_op11_no_args)
TEST_CASE(clip_test_op11_no_args1)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}});
p.add_instruction(migraphx::op::undefined{});
auto r = p.add_instruction(migraphx::op::identity{}, l0);
p.add_return({r});
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}});
mm->add_instruction(migraphx::op::undefined{});
auto r = mm->add_instruction(migraphx::op::identity{}, l0);
mm->add_return({r});
auto prog = migraphx::parse_onnx("clip_test_op11_no_args1.onnx");
EXPECT(p == prog);
......@@ -368,9 +398,10 @@ TEST_CASE(clip_test_op11_no_args1)
TEST_CASE(concat_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 4, 3}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {7, 4, 3}});
p.add_instruction(migraphx::op::concat{0}, l0, l1);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 4, 3}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {7, 4, 3}});
mm->add_instruction(migraphx::op::concat{0}, l0, l1);
auto prog = optimize_onnx("concat_test.onnx");
EXPECT(p == prog);
......@@ -379,7 +410,9 @@ TEST_CASE(concat_test)
TEST_CASE(constant_test)
{
migraphx::program p;
p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3}}, {0, 1, 2}});
auto* mm = p.get_main_module();
mm->add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3}}, {0, 1, 2}});
auto prog = optimize_onnx("constant_test.onnx");
EXPECT(p == prog);
......@@ -389,9 +422,10 @@ TEST_CASE(constant_fill_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
std::vector<float> value(s.elements(), 1.0);
p.add_literal(migraphx::literal{s, value});
mm->add_literal(migraphx::literal{s, value});
auto prog = optimize_onnx("constant_fill_test.onnx");
EXPECT(p == prog);
......@@ -400,13 +434,14 @@ TEST_CASE(constant_fill_test)
TEST_CASE(constant_fill_input_as_shape_test)
{
migraphx::program p;
auto l0 = p.add_literal(migraphx::literal{{migraphx::shape::int32_type, {2}}, {2, 3}});
auto* mm = p.get_main_module();
auto l0 = mm->add_literal(migraphx::literal{{migraphx::shape::int32_type, {2}}, {2, 3}});
std::vector<std::size_t> dims(l0->get_shape().elements());
migraphx::literal ls = l0->get_literal();
ls.visit([&](auto s) { dims.assign(s.begin(), s.end()); });
migraphx::shape s{migraphx::shape::float_type, dims};
std::vector<float> value(s.elements(), 1.0);
p.add_literal(migraphx::literal{s, value});
mm->add_literal(migraphx::literal{s, value});
auto prog = optimize_onnx("constant_fill_input_as_shape_test.onnx");
EXPECT(p == prog);
......@@ -415,7 +450,8 @@ TEST_CASE(constant_fill_input_as_shape_test)
TEST_CASE(constant_scalar_test)
{
migraphx::program p;
p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {1}}, {1}});
auto* mm = p.get_main_module();
mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {1}}, {1}});
auto prog = optimize_onnx("constant_scalar_test.onnx");
EXPECT(p == prog);
......@@ -424,10 +460,11 @@ TEST_CASE(constant_scalar_test)
TEST_CASE(const_of_shape_empty_input_test)
{
migraphx::program p;
p.add_literal(migraphx::literal());
auto* mm = p.get_main_module();
mm->add_literal(migraphx::literal());
migraphx::shape s(migraphx::shape::int64_type, {1}, {0});
std::vector<int64_t> vec(s.elements(), 10);
p.add_literal(migraphx::literal(s, vec));
mm->add_literal(migraphx::literal(s, vec));
auto prog = optimize_onnx("const_of_shape_empty_input_test.onnx");
EXPECT(p == prog);
......@@ -436,11 +473,12 @@ TEST_CASE(const_of_shape_empty_input_test)
TEST_CASE(const_of_shape_float_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape ss(migraphx::shape::int32_type, {3});
p.add_literal(migraphx::literal(ss, {2, 3, 4}));
mm->add_literal(migraphx::literal(ss, {2, 3, 4}));
migraphx::shape s(migraphx::shape::float_type, {2, 3, 4});
std::vector<float> vec(s.elements(), 10.0f);
p.add_literal(migraphx::literal(s, vec));
mm->add_literal(migraphx::literal(s, vec));
auto prog = optimize_onnx("const_of_shape_float_test.onnx");
EXPECT(p == prog);
......@@ -449,11 +487,12 @@ TEST_CASE(const_of_shape_float_test)
TEST_CASE(const_of_shape_int64_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape ss(migraphx::shape::int32_type, {3});
p.add_literal(migraphx::literal(ss, {2, 3, 4}));
mm->add_literal(migraphx::literal(ss, {2, 3, 4}));
migraphx::shape s(migraphx::shape::int64_type, {2, 3, 4});
std::vector<int64_t> vec(s.elements(), 10);
p.add_literal(migraphx::literal(s, vec));
mm->add_literal(migraphx::literal(s, vec));
auto prog = optimize_onnx("const_of_shape_int64_test.onnx");
EXPECT(p == prog);
......@@ -462,11 +501,12 @@ TEST_CASE(const_of_shape_int64_test)
TEST_CASE(const_of_shape_no_value_attr_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape ss(migraphx::shape::int32_type, {3});
p.add_literal(migraphx::literal(ss, {2, 3, 4}));
mm->add_literal(migraphx::literal(ss, {2, 3, 4}));
migraphx::shape s(migraphx::shape::float_type, {2, 3, 4});
std::vector<float> vec(s.elements(), 0.0f);
p.add_literal(migraphx::literal(s, vec));
mm->add_literal(migraphx::literal(s, vec));
auto prog = optimize_onnx("const_of_shape_no_value_attr_test.onnx");
EXPECT(p == prog);
......@@ -480,9 +520,10 @@ TEST_CASE(conv_autopad_fail_test)
TEST_CASE(conv_1d_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", {migraphx::shape::float_type, {1, 3, 5}});
auto l1 = p.add_parameter("1", {migraphx::shape::float_type, {1, 3, 3}});
p.add_instruction(migraphx::op::convolution{{0}, {1}, {1}}, l0, l1);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {1, 3, 5}});
auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1, 3, 3}});
mm->add_instruction(migraphx::op::convolution{{0}, {1}, {1}}, l0, l1);
auto prog = optimize_onnx("conv_1d_test.onnx");
EXPECT(p == prog);
......@@ -491,9 +532,10 @@ TEST_CASE(conv_1d_test)
TEST_CASE(conv_3d_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", {migraphx::shape::float_type, {1, 3, 5, 5, 5}});
auto l1 = p.add_parameter("1", {migraphx::shape::float_type, {1, 3, 3, 3, 3}});
p.add_instruction(migraphx::op::convolution{{0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, l0, l1);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {1, 3, 5, 5, 5}});
auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1, 3, 3, 3, 3}});
mm->add_instruction(migraphx::op::convolution{{0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, l0, l1);
auto prog = optimize_onnx("conv_3d_test.onnx");
EXPECT(p == prog);
......@@ -507,12 +549,13 @@ TEST_CASE(conv_attr_fail_test)
TEST_CASE(conv_autopad_same_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", {migraphx::shape::float_type, {1, 3, 32, 32}});
auto l1 = p.add_parameter("1", {migraphx::shape::float_type, {1, 3, 3, 3}});
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {1, 3, 32, 32}});
auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1, 3, 3, 3}});
migraphx::op::convolution op;
op.padding = {1, 1};
op.padding_mode = migraphx::op::padding_mode_t::same;
p.add_instruction(op, l0, l1);
mm->add_instruction(op, l0, l1);
auto prog = optimize_onnx("conv_autopad_same_test.onnx");
EXPECT(p == prog);
......@@ -521,13 +564,14 @@ TEST_CASE(conv_autopad_same_test)
TEST_CASE(conv_bias_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", {migraphx::shape::float_type, {1, 3, 32, 32}});
auto l1 = p.add_parameter("1", {migraphx::shape::float_type, {1, 3, 5, 5}});
auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {1}});
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {1, 3, 32, 32}});
auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1, 3, 5, 5}});
auto l2 = mm->add_parameter("2", {migraphx::shape::float_type, {1}});
uint64_t axis = 1;
auto l3 = p.add_instruction(migraphx::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape().lens()}, l2);
p.add_instruction(migraphx::op::add{}, l3, l4);
auto l3 = mm->add_instruction(migraphx::op::convolution{}, l0, l1);
auto l4 = mm->add_instruction(migraphx::op::broadcast{axis, l3->get_shape().lens()}, l2);
mm->add_instruction(migraphx::op::add{}, l3, l4);
auto prog = optimize_onnx("conv_bias_test.onnx");
EXPECT(p == prog);
......@@ -536,21 +580,22 @@ TEST_CASE(conv_bias_test)
TEST_CASE(conv_bn_relu_maxpool_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", {migraphx::shape::float_type, {1, 3, 32, 32}});
auto l1 = p.add_parameter("1", {migraphx::shape::float_type, {1, 3, 5, 5}});
auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {1}});
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {1, 3, 32, 32}});
auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1, 3, 5, 5}});
auto l2 = mm->add_parameter("2", {migraphx::shape::float_type, {1}});
auto p3 = p.add_parameter("3", {migraphx::shape::float_type, {1}});
auto p4 = p.add_parameter("4", {migraphx::shape::float_type, {1}});
auto p5 = p.add_parameter("5", {migraphx::shape::float_type, {1}});
auto p6 = p.add_parameter("6", {migraphx::shape::float_type, {1}});
auto p3 = mm->add_parameter("3", {migraphx::shape::float_type, {1}});
auto p4 = mm->add_parameter("4", {migraphx::shape::float_type, {1}});
auto p5 = mm->add_parameter("5", {migraphx::shape::float_type, {1}});
auto p6 = mm->add_parameter("6", {migraphx::shape::float_type, {1}});
uint64_t axis = 1;
auto l3 = p.add_instruction(migraphx::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape().lens()}, l2);
auto l5 = p.add_instruction(migraphx::op::add{}, l3, l4);
auto l6 = p.add_instruction(migraphx::op::batch_norm_inference{1.0e-5f}, l5, p3, p4, p5, p6);
auto l7 = p.add_instruction(migraphx::op::relu{}, l6);
p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l7);
auto l3 = mm->add_instruction(migraphx::op::convolution{}, l0, l1);
auto l4 = mm->add_instruction(migraphx::op::broadcast{axis, l3->get_shape().lens()}, l2);
auto l5 = mm->add_instruction(migraphx::op::add{}, l3, l4);
auto l6 = mm->add_instruction(migraphx::op::batch_norm_inference{1.0e-5f}, l5, p3, p4, p5, p6);
auto l7 = mm->add_instruction(migraphx::op::relu{}, l6);
mm->add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l7);
auto prog = optimize_onnx("conv_bn_relu_maxpool_test.onnx");
EXPECT(p == prog);
......@@ -559,15 +604,16 @@ TEST_CASE(conv_bn_relu_maxpool_test)
TEST_CASE(conv_relu_maxpool_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", {migraphx::shape::float_type, {1, 3, 32, 32}});
auto l1 = p.add_parameter("1", {migraphx::shape::float_type, {1, 3, 5, 5}});
auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {1}});
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {1, 3, 32, 32}});
auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1, 3, 5, 5}});
auto l2 = mm->add_parameter("2", {migraphx::shape::float_type, {1}});
uint64_t axis = 1;
auto l3 = p.add_instruction(migraphx::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape().lens()}, l2);
auto l5 = p.add_instruction(migraphx::op::add{}, l3, l4);
auto l6 = p.add_instruction(migraphx::op::relu{}, l5);
p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l6);
auto l3 = mm->add_instruction(migraphx::op::convolution{}, l0, l1);
auto l4 = mm->add_instruction(migraphx::op::broadcast{axis, l3->get_shape().lens()}, l2);
auto l5 = mm->add_instruction(migraphx::op::add{}, l3, l4);
auto l6 = mm->add_instruction(migraphx::op::relu{}, l5);
mm->add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l6);
auto prog = optimize_onnx("conv_relu_maxpool_test.onnx");
EXPECT(p == prog);
......@@ -576,23 +622,24 @@ TEST_CASE(conv_relu_maxpool_test)
TEST_CASE(conv_relu_maxpool_x2_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", {migraphx::shape::float_type, {1, 3, 32, 32}});
auto l1 = p.add_parameter("1", {migraphx::shape::float_type, {5, 3, 5, 5}});
auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {5}});
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {1, 3, 32, 32}});
auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {5, 3, 5, 5}});
auto l2 = mm->add_parameter("2", {migraphx::shape::float_type, {5}});
uint64_t axis = 1;
auto l3 = p.add_instruction(migraphx::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape().lens()}, l2);
auto l5 = p.add_instruction(migraphx::op::add{}, l3, l4);
auto l6 = p.add_instruction(migraphx::op::relu{}, l5);
auto l7 = p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l6);
auto l8 = p.add_parameter("3", {migraphx::shape::float_type, {1, 5, 5, 5}});
auto l9 = p.add_parameter("4", {migraphx::shape::float_type, {1}});
auto l10 = p.add_instruction(migraphx::op::convolution{}, l7, l8);
auto l11 = p.add_instruction(migraphx::op::broadcast{axis, l10->get_shape().lens()}, l9);
auto l12 = p.add_instruction(migraphx::op::add{}, l10, l11);
auto l13 = p.add_instruction(migraphx::op::relu{}, l12);
p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l13);
auto l3 = mm->add_instruction(migraphx::op::convolution{}, l0, l1);
auto l4 = mm->add_instruction(migraphx::op::broadcast{axis, l3->get_shape().lens()}, l2);
auto l5 = mm->add_instruction(migraphx::op::add{}, l3, l4);
auto l6 = mm->add_instruction(migraphx::op::relu{}, l5);
auto l7 = mm->add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l6);
auto l8 = mm->add_parameter("3", {migraphx::shape::float_type, {1, 5, 5, 5}});
auto l9 = mm->add_parameter("4", {migraphx::shape::float_type, {1}});
auto l10 = mm->add_instruction(migraphx::op::convolution{}, l7, l8);
auto l11 = mm->add_instruction(migraphx::op::broadcast{axis, l10->get_shape().lens()}, l9);
auto l12 = mm->add_instruction(migraphx::op::add{}, l10, l11);
auto l13 = mm->add_instruction(migraphx::op::relu{}, l12);
mm->add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l13);
auto prog = optimize_onnx("conv_relu_maxpool_x2_test.onnx");
......@@ -602,13 +649,14 @@ TEST_CASE(conv_relu_maxpool_x2_test)
TEST_CASE(convinteger_bias_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", {migraphx::shape::int8_type, {1, 3, 32, 32}});
auto l1 = p.add_parameter("1", {migraphx::shape::int8_type, {1, 3, 5, 5}});
auto l2 = p.add_parameter("2", {migraphx::shape::int32_type, {1}});
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", {migraphx::shape::int8_type, {1, 3, 32, 32}});
auto l1 = mm->add_parameter("1", {migraphx::shape::int8_type, {1, 3, 5, 5}});
auto l2 = mm->add_parameter("2", {migraphx::shape::int32_type, {1}});
uint64_t axis = 1;
auto l3 = p.add_instruction(migraphx::op::quant_convolution{}, l0, l1);
auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape().lens()}, l2);
p.add_instruction(migraphx::op::add{}, l3, l4);
auto l3 = mm->add_instruction(migraphx::op::quant_convolution{}, l0, l1);
auto l4 = mm->add_instruction(migraphx::op::broadcast{axis, l3->get_shape().lens()}, l2);
mm->add_instruction(migraphx::op::add{}, l3, l4);
auto prog = optimize_onnx("convinteger_bias_test.onnx");
EXPECT(p == prog);
......@@ -617,8 +665,9 @@ TEST_CASE(convinteger_bias_test)
TEST_CASE(cos_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
p.add_instruction(migraphx::op::cos{}, input);
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
mm->add_instruction(migraphx::op::cos{}, input);
auto prog = optimize_onnx("cos_test.onnx");
EXPECT(p == prog);
......@@ -627,8 +676,9 @@ TEST_CASE(cos_test)
TEST_CASE(cosh_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1}});
p.add_instruction(migraphx::op::cosh{}, input);
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1}});
mm->add_instruction(migraphx::op::cosh{}, input);
auto prog = optimize_onnx("cosh_test.onnx");
......@@ -638,9 +688,10 @@ TEST_CASE(cosh_test)
TEST_CASE(deconv_test)
{
migraphx::program p;
auto l0 = p.add_parameter("x", {migraphx::shape::float_type, {1, 1, 3, 3}});
auto l1 = p.add_parameter("w", {migraphx::shape::float_type, {1, 1, 3, 3}});
p.add_instruction(migraphx::op::deconvolution{}, l0, l1);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("x", {migraphx::shape::float_type, {1, 1, 3, 3}});
auto l1 = mm->add_parameter("w", {migraphx::shape::float_type, {1, 1, 3, 3}});
mm->add_instruction(migraphx::op::deconvolution{}, l0, l1);
auto prog = optimize_onnx("deconv_test.onnx");
EXPECT(p == prog);
......@@ -649,13 +700,14 @@ TEST_CASE(deconv_test)
TEST_CASE(deconv_bias_test)
{
migraphx::program p;
auto l0 = p.add_parameter("x", {migraphx::shape::float_type, {1, 1, 3, 3}});
auto l1 = p.add_parameter("w", {migraphx::shape::float_type, {1, 1, 3, 3}});
auto l2 = p.add_parameter("b", {migraphx::shape::float_type, {1}});
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("x", {migraphx::shape::float_type, {1, 1, 3, 3}});
auto l1 = mm->add_parameter("w", {migraphx::shape::float_type, {1, 1, 3, 3}});
auto l2 = mm->add_parameter("b", {migraphx::shape::float_type, {1}});
uint64_t axis = 1;
auto l3 = p.add_instruction(migraphx::op::deconvolution{}, l0, l1);
auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape().lens()}, l2);
p.add_instruction(migraphx::op::add{}, l3, l4);
auto l3 = mm->add_instruction(migraphx::op::deconvolution{}, l0, l1);
auto l4 = mm->add_instruction(migraphx::op::broadcast{axis, l3->get_shape().lens()}, l2);
mm->add_instruction(migraphx::op::add{}, l3, l4);
auto prog = optimize_onnx("deconv_bias_test.onnx");
EXPECT(p == prog);
......@@ -664,9 +716,10 @@ TEST_CASE(deconv_bias_test)
TEST_CASE(deconv_input_pads_strides_test)
{
migraphx::program p;
auto l0 = p.add_parameter("x", {migraphx::shape::float_type, {1, 1, 3, 3}});
auto l1 = p.add_parameter("w", {migraphx::shape::float_type, {1, 2, 3, 3}});
p.add_instruction(migraphx::op::deconvolution{{1, 1}, {3, 2}}, l0, l1);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("x", {migraphx::shape::float_type, {1, 1, 3, 3}});
auto l1 = mm->add_parameter("w", {migraphx::shape::float_type, {1, 2, 3, 3}});
mm->add_instruction(migraphx::op::deconvolution{{1, 1}, {3, 2}}, l0, l1);
auto prog = optimize_onnx("deconv_input_pads_strides_test.onnx");
EXPECT(p == prog);
......@@ -675,10 +728,11 @@ TEST_CASE(deconv_input_pads_strides_test)
TEST_CASE(deconv_input_pads_asymm_test)
{
migraphx::program p;
auto l0 = p.add_parameter("x", {migraphx::shape::float_type, {1, 1, 3, 3}});
auto l1 = p.add_parameter("w", {migraphx::shape::float_type, {1, 2, 3, 3}});
auto l2 = p.add_instruction(migraphx::op::deconvolution{{0, 0}, {3, 2}}, l0, l1);
p.add_instruction(migraphx::op::slice{{2, 3}, {0, 0}, {8, 6}}, l2);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("x", {migraphx::shape::float_type, {1, 1, 3, 3}});
auto l1 = mm->add_parameter("w", {migraphx::shape::float_type, {1, 2, 3, 3}});
auto l2 = mm->add_instruction(migraphx::op::deconvolution{{0, 0}, {3, 2}}, l0, l1);
mm->add_instruction(migraphx::op::slice{{2, 3}, {0, 0}, {8, 6}}, l2);
auto prog = optimize_onnx("deconv_input_pads_asymm_test.onnx");
EXPECT(p == prog);
......@@ -687,10 +741,11 @@ TEST_CASE(deconv_input_pads_asymm_test)
TEST_CASE(deconv_input_pads_asymm_1d_test)
{
migraphx::program p;
auto l0 = p.add_parameter("x", {migraphx::shape::float_type, {1, 1, 3}});
auto l1 = p.add_parameter("w", {migraphx::shape::float_type, {1, 2, 3}});
auto l2 = p.add_instruction(migraphx::op::deconvolution{{0}, {2}, {1}}, l0, l1);
p.add_instruction(migraphx::op::slice{{2}, {0}, {6}}, l2);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("x", {migraphx::shape::float_type, {1, 1, 3}});
auto l1 = mm->add_parameter("w", {migraphx::shape::float_type, {1, 2, 3}});
auto l2 = mm->add_instruction(migraphx::op::deconvolution{{0}, {2}, {1}}, l0, l1);
mm->add_instruction(migraphx::op::slice{{2}, {0}, {6}}, l2);
auto prog = optimize_onnx("deconv_input_pads_asymm_1d_test.onnx");
EXPECT(p == prog);
......@@ -699,10 +754,11 @@ TEST_CASE(deconv_input_pads_asymm_1d_test)
TEST_CASE(deconv_output_padding_test)
{
migraphx::program p;
auto l0 = p.add_parameter("x", {migraphx::shape::float_type, {1, 1, 3, 3}});
auto l1 = p.add_parameter("w", {migraphx::shape::float_type, {1, 2, 3, 3}});
auto l2 = p.add_instruction(migraphx::op::deconvolution{{0, 0}, {3, 2}}, l0, l1);
p.add_instruction(migraphx::op::pad{{0, 0, 0, 0, 0, 0, 1, 1}}, l2);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("x", {migraphx::shape::float_type, {1, 1, 3, 3}});
auto l1 = mm->add_parameter("w", {migraphx::shape::float_type, {1, 2, 3, 3}});
auto l2 = mm->add_instruction(migraphx::op::deconvolution{{0, 0}, {3, 2}}, l0, l1);
mm->add_instruction(migraphx::op::pad{{0, 0, 0, 0, 0, 0, 1, 1}}, l2);
auto prog = optimize_onnx("deconv_output_padding_test.onnx");
EXPECT(p == prog);
......@@ -711,11 +767,12 @@ TEST_CASE(deconv_output_padding_test)
TEST_CASE(deconv_output_padding_3d_test)
{
migraphx::program p;
auto l0 = p.add_parameter("x", {migraphx::shape::float_type, {1, 1, 3, 3, 3}});
auto l1 = p.add_parameter("w", {migraphx::shape::float_type, {1, 2, 3, 3, 3}});
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("x", {migraphx::shape::float_type, {1, 1, 3, 3, 3}});
auto l1 = mm->add_parameter("w", {migraphx::shape::float_type, {1, 2, 3, 3, 3}});
auto l2 =
p.add_instruction(migraphx::op::deconvolution{{0, 0, 0}, {3, 2, 2}, {1, 1, 1}}, l0, l1);
p.add_instruction(migraphx::op::pad{{0, 0, 0, 0, 0, 0, 0, 1, 1, 1}}, l2);
mm->add_instruction(migraphx::op::deconvolution{{0, 0, 0}, {3, 2, 2}, {1, 1, 1}}, l0, l1);
mm->add_instruction(migraphx::op::pad{{0, 0, 0, 0, 0, 0, 0, 1, 1, 1}}, l2);
auto prog = optimize_onnx("deconv_output_padding_3d_test.onnx");
EXPECT(p == prog);
......@@ -724,10 +781,11 @@ TEST_CASE(deconv_output_padding_3d_test)
TEST_CASE(deconv_output_shape_test)
{
migraphx::program p;
auto l0 = p.add_parameter("x", {migraphx::shape::float_type, {1, 1, 3, 3}});
auto l1 = p.add_parameter("w", {migraphx::shape::float_type, {1, 2, 3, 3}});
auto l2 = p.add_instruction(migraphx::op::deconvolution{{0, 0}, {3, 2}}, l0, l1);
p.add_instruction(migraphx::op::pad{{0, 0, 0, 0, 0, 0, 1, 1}}, l2);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("x", {migraphx::shape::float_type, {1, 1, 3, 3}});
auto l1 = mm->add_parameter("w", {migraphx::shape::float_type, {1, 2, 3, 3}});
auto l2 = mm->add_instruction(migraphx::op::deconvolution{{0, 0}, {3, 2}}, l0, l1);
mm->add_instruction(migraphx::op::pad{{0, 0, 0, 0, 0, 0, 1, 1}}, l2);
auto prog = optimize_onnx("deconv_output_shape_test.onnx");
EXPECT(p == prog);
......@@ -736,11 +794,12 @@ TEST_CASE(deconv_output_shape_test)
TEST_CASE(deconv_output_shape_3d_test)
{
migraphx::program p;
auto l0 = p.add_parameter("x", {migraphx::shape::float_type, {1, 1, 3, 3, 3}});
auto l1 = p.add_parameter("w", {migraphx::shape::float_type, {1, 2, 3, 3, 3}});
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("x", {migraphx::shape::float_type, {1, 1, 3, 3, 3}});
auto l1 = mm->add_parameter("w", {migraphx::shape::float_type, {1, 2, 3, 3, 3}});
auto l2 =
p.add_instruction(migraphx::op::deconvolution{{0, 0, 0}, {3, 2, 2}, {1, 1, 1}}, l0, l1);
p.add_instruction(migraphx::op::pad{{0, 0, 0, 0, 0, 0, 0, 1, 1, 1}}, l2);
mm->add_instruction(migraphx::op::deconvolution{{0, 0, 0}, {3, 2, 2}, {1, 1, 1}}, l0, l1);
mm->add_instruction(migraphx::op::pad{{0, 0, 0, 0, 0, 0, 0, 1, 1, 1}}, l2);
auto prog = optimize_onnx("deconv_output_shape_3d_test.onnx");
EXPECT(p == prog);
......@@ -749,12 +808,13 @@ TEST_CASE(deconv_output_shape_3d_test)
TEST_CASE(dropout_test)
{
migraphx::program p;
auto input = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 2, 2}});
auto out = p.add_instruction(migraphx::op::identity{}, input);
auto* mm = p.get_main_module();
auto input = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 2, 2}});
auto out = mm->add_instruction(migraphx::op::identity{}, input);
migraphx::shape s{migraphx::shape::bool_type, {1, 3, 2, 2}};
std::vector<int8_t> vec(s.elements(), 1);
p.add_literal(migraphx::literal(s, vec));
p.add_return({out});
mm->add_literal(migraphx::literal(s, vec));
mm->add_return({out});
auto prog = migraphx::parse_onnx("dropout_test.onnx");
EXPECT(p == prog);
......@@ -763,8 +823,9 @@ TEST_CASE(dropout_test)
TEST_CASE(elu_test)
{
migraphx::program p;
auto input = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}});
p.add_instruction(migraphx::op::elu{0.01}, input);
auto* mm = p.get_main_module();
auto input = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}});
mm->add_instruction(migraphx::op::elu{0.01}, input);
auto prog = optimize_onnx("elu_test.onnx");
......@@ -774,17 +835,18 @@ TEST_CASE(elu_test)
TEST_CASE(embedding_bag_test)
{
migraphx::program p;
auto l0 = p.add_parameter("weight", migraphx::shape{migraphx::shape::float_type, {4, 2}});
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("weight", migraphx::shape{migraphx::shape::float_type, {4, 2}});
migraphx::literal l{migraphx::shape{migraphx::shape::int32_type, {3}}, {1, 0, 2}};
auto l1 = p.add_literal(l);
p.add_literal(0);
auto l4 = p.add_instruction(migraphx::op::gather{}, l0, l1);
auto r1 = p.add_instruction(migraphx::op::reduce_sum{{0}}, l4);
auto l5 = p.add_instruction(migraphx::op::gather{}, l0, l1);
auto r2 = p.add_instruction(migraphx::op::reduce_mean{{0}}, l5);
auto l6 = p.add_instruction(migraphx::op::gather{}, l0, l1);
auto r3 = p.add_instruction(migraphx::op::reduce_max{{0}}, l6);
p.add_return({r1, r2, r3});
auto l1 = mm->add_literal(l);
mm->add_literal(0);
auto l4 = mm->add_instruction(migraphx::op::gather{}, l0, l1);
auto r1 = mm->add_instruction(migraphx::op::reduce_sum{{0}}, l4);
auto l5 = mm->add_instruction(migraphx::op::gather{}, l0, l1);
auto r2 = mm->add_instruction(migraphx::op::reduce_mean{{0}}, l5);
auto l6 = mm->add_instruction(migraphx::op::gather{}, l0, l1);
auto r3 = mm->add_instruction(migraphx::op::reduce_max{{0}}, l6);
mm->add_return({r1, r2, r3});
auto prog = migraphx::parse_onnx("embedding_bag_test.onnx");
......@@ -799,14 +861,15 @@ TEST_CASE(embedding_bag_offset_test)
TEST_CASE(equal_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
std::vector<float> data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
auto input1 = p.add_literal(migraphx::literal(s, data));
auto input2 = p.add_parameter("x2", migraphx::shape{migraphx::shape::float_type, {2, 3}});
auto eq = p.add_instruction(migraphx::op::equal{}, input1, input2);
auto ret = p.add_instruction(migraphx::op::convert{migraphx::shape::bool_type}, eq);
p.add_return({ret});
auto input1 = mm->add_literal(migraphx::literal(s, data));
auto input2 = mm->add_parameter("x2", migraphx::shape{migraphx::shape::float_type, {2, 3}});
auto eq = mm->add_instruction(migraphx::op::equal{}, input1, input2);
auto ret = mm->add_instruction(migraphx::op::convert{migraphx::shape::bool_type}, eq);
mm->add_return({ret});
auto prog = migraphx::parse_onnx("equal_test.onnx");
......@@ -816,14 +879,15 @@ TEST_CASE(equal_test)
TEST_CASE(equal_bool_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sf{migraphx::shape::float_type, {2, 3}};
migraphx::shape sb{migraphx::shape::bool_type, {2, 3}};
auto input1 = p.add_parameter("x1", sf);
auto input2 = p.add_parameter("x2", sb);
auto cin1 = p.add_instruction(migraphx::op::convert{migraphx::shape::bool_type}, input1);
auto ret = p.add_instruction(migraphx::op::equal{}, cin1, input2);
p.add_return({ret});
auto input1 = mm->add_parameter("x1", sf);
auto input2 = mm->add_parameter("x2", sb);
auto cin1 = mm->add_instruction(migraphx::op::convert{migraphx::shape::bool_type}, input1);
auto ret = mm->add_instruction(migraphx::op::equal{}, cin1, input2);
mm->add_return({ret});
auto prog = migraphx::parse_onnx("equal_bool_test.onnx");
......@@ -833,8 +897,9 @@ TEST_CASE(equal_bool_test)
TEST_CASE(erf_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10, 15}});
p.add_instruction(migraphx::op::erf{}, input);
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10, 15}});
mm->add_instruction(migraphx::op::erf{}, input);
auto prog = optimize_onnx("erf_test.onnx");
EXPECT(p == prog);
......@@ -843,8 +908,9 @@ TEST_CASE(erf_test)
TEST_CASE(exp_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
p.add_instruction(migraphx::op::exp{}, input);
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
mm->add_instruction(migraphx::op::exp{}, input);
auto prog = optimize_onnx("exp_test.onnx");
EXPECT(p == prog);
......@@ -853,11 +919,12 @@ TEST_CASE(exp_test)
TEST_CASE(expand_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s(migraphx::shape::float_type, {3, 1, 1});
auto param = p.add_parameter("x", s);
auto param = mm->add_parameter("x", s);
migraphx::shape ss(migraphx::shape::int32_type, {4});
p.add_literal(migraphx::literal(ss, {2, 3, 4, 5}));
p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, param);
mm->add_literal(migraphx::literal(ss, {2, 3, 4, 5}));
mm->add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, param);
auto prog = optimize_onnx("expand_test.onnx");
EXPECT(p == prog);
......@@ -866,9 +933,10 @@ TEST_CASE(expand_test)
TEST_CASE(flatten_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
p.add_instruction(migraphx::op::flatten{2}, l0);
p.add_instruction(migraphx::op::flatten{1}, l0);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
mm->add_instruction(migraphx::op::flatten{2}, l0);
mm->add_instruction(migraphx::op::flatten{1}, l0);
auto prog = optimize_onnx("flatten_test.onnx");
EXPECT(p == prog);
......@@ -877,8 +945,9 @@ TEST_CASE(flatten_test)
TEST_CASE(floor_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
p.add_instruction(migraphx::op::floor{}, input);
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
mm->add_instruction(migraphx::op::floor{}, input);
auto prog = optimize_onnx("floor_test.onnx");
......@@ -888,10 +957,11 @@ TEST_CASE(floor_test)
TEST_CASE(gather_test)
{
migraphx::program p;
auto l0 = p.add_parameter("data", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto l1 = p.add_parameter("indices", migraphx::shape{migraphx::shape::int32_type, {2, 3}});
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto l1 = mm->add_parameter("indices", migraphx::shape{migraphx::shape::int32_type, {2, 3}});
int axis = 1;
p.add_instruction(migraphx::op::gather{axis}, l0, l1);
mm->add_instruction(migraphx::op::gather{axis}, l0, l1);
auto prog = optimize_onnx("gather_test.onnx");
EXPECT(p == prog);
......@@ -900,24 +970,25 @@ TEST_CASE(gather_test)
TEST_CASE(gather_elements_axis0_test)
{
migraphx::program p;
auto data = p.add_parameter("data", {migraphx::shape::float_type, {3, 4}});
auto indices = p.add_parameter("indices", {migraphx::shape::int32_type, {2, 3}});
auto* mm = p.get_main_module();
auto data = mm->add_parameter("data", {migraphx::shape::float_type, {3, 4}});
auto indices = mm->add_parameter("indices", {migraphx::shape::int32_type, {2, 3}});
std::vector<int> ind_indices{0, 1, 2, 4, 5, 6};
std::vector<int> ind_axis_indices{0, 0, 0, 1, 1, 1};
migraphx::shape ind_s{migraphx::shape::int32_type, {2, 3}};
auto l_data_indices =
p.add_literal(migraphx::literal{ind_s, ind_indices.begin(), ind_indices.end()});
mm->add_literal(migraphx::literal{ind_s, ind_indices.begin(), ind_indices.end()});
auto l_ind_axis_indices =
p.add_literal(migraphx::literal{ind_s, ind_axis_indices.begin(), ind_axis_indices.end()});
auto l_stride = p.add_literal(migraphx::literal{{migraphx::shape::int32_type, {1}}, {4}});
mm->add_literal(migraphx::literal{ind_s, ind_axis_indices.begin(), ind_axis_indices.end()});
auto l_stride = mm->add_literal(migraphx::literal{{migraphx::shape::int32_type, {1}}, {4}});
auto rsp_data = p.add_instruction(migraphx::op::reshape{{12}}, data);
auto lbst_stride = p.add_instruction(migraphx::op::multibroadcast{ind_s.lens()}, l_stride);
auto axis_delta = p.add_instruction(migraphx::op::sub{}, indices, l_ind_axis_indices);
auto mul_delta = p.add_instruction(migraphx::op::mul{}, axis_delta, lbst_stride);
auto ind = p.add_instruction(migraphx::op::add{}, l_data_indices, mul_delta);
auto ret = p.add_instruction(migraphx::op::gather{0}, rsp_data, ind);
p.add_return({ret});
auto rsp_data = mm->add_instruction(migraphx::op::reshape{{12}}, data);
auto lbst_stride = mm->add_instruction(migraphx::op::multibroadcast{ind_s.lens()}, l_stride);
auto axis_delta = mm->add_instruction(migraphx::op::sub{}, indices, l_ind_axis_indices);
auto mul_delta = mm->add_instruction(migraphx::op::mul{}, axis_delta, lbst_stride);
auto ind = mm->add_instruction(migraphx::op::add{}, l_data_indices, mul_delta);
auto ret = mm->add_instruction(migraphx::op::gather{0}, rsp_data, ind);
mm->add_return({ret});
auto prog = migraphx::parse_onnx("gather_elements_axis0_test.onnx");
......@@ -927,24 +998,25 @@ TEST_CASE(gather_elements_axis0_test)
TEST_CASE(gather_elements_axis1_test)
{
migraphx::program p;
auto data = p.add_parameter("data", {migraphx::shape::float_type, {3, 4}});
auto indices = p.add_parameter("indices", {migraphx::shape::int32_type, {2, 3}});
auto* mm = p.get_main_module();
auto data = mm->add_parameter("data", {migraphx::shape::float_type, {3, 4}});
auto indices = mm->add_parameter("indices", {migraphx::shape::int32_type, {2, 3}});
std::vector<int> ind_indices{0, 1, 2, 4, 5, 6};
std::vector<int> ind_axis_indices{0, 1, 2, 0, 1, 2};
migraphx::shape ind_s{migraphx::shape::int32_type, {2, 3}};
auto l_data_indices =
p.add_literal(migraphx::literal{ind_s, ind_indices.begin(), ind_indices.end()});
mm->add_literal(migraphx::literal{ind_s, ind_indices.begin(), ind_indices.end()});
auto l_ind_axis_indices =
p.add_literal(migraphx::literal{ind_s, ind_axis_indices.begin(), ind_axis_indices.end()});
auto l_stride = p.add_literal(migraphx::literal{{migraphx::shape::int32_type, {1}}, {1}});
mm->add_literal(migraphx::literal{ind_s, ind_axis_indices.begin(), ind_axis_indices.end()});
auto l_stride = mm->add_literal(migraphx::literal{{migraphx::shape::int32_type, {1}}, {1}});
auto rsp_data = p.add_instruction(migraphx::op::reshape{{12}}, data);
auto lbst_stride = p.add_instruction(migraphx::op::multibroadcast{ind_s.lens()}, l_stride);
auto axis_delta = p.add_instruction(migraphx::op::sub{}, indices, l_ind_axis_indices);
auto mul_delta = p.add_instruction(migraphx::op::mul{}, axis_delta, lbst_stride);
auto ind = p.add_instruction(migraphx::op::add{}, l_data_indices, mul_delta);
auto ret = p.add_instruction(migraphx::op::gather{0}, rsp_data, ind);
p.add_return({ret});
auto rsp_data = mm->add_instruction(migraphx::op::reshape{{12}}, data);
auto lbst_stride = mm->add_instruction(migraphx::op::multibroadcast{ind_s.lens()}, l_stride);
auto axis_delta = mm->add_instruction(migraphx::op::sub{}, indices, l_ind_axis_indices);
auto mul_delta = mm->add_instruction(migraphx::op::mul{}, axis_delta, lbst_stride);
auto ind = mm->add_instruction(migraphx::op::add{}, l_data_indices, mul_delta);
auto ret = mm->add_instruction(migraphx::op::gather{0}, rsp_data, ind);
mm->add_return({ret});
auto prog = migraphx::parse_onnx("gather_elements_axis1_test.onnx");
......@@ -954,15 +1026,16 @@ TEST_CASE(gather_elements_axis1_test)
TEST_CASE(gemm_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 7}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {11, 5}});
auto l2 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type});
auto t0 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l0);
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1);
auto bl2 = p.add_instruction(migraphx::op::multibroadcast{{7, 11}}, l2);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 7}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {11, 5}});
auto l2 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type});
auto t0 = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l0);
auto t1 = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l1);
auto bl2 = mm->add_instruction(migraphx::op::multibroadcast{{7, 11}}, l2);
auto alpha = 2.f;
auto beta = 2.0f;
p.add_instruction(migraphx::op::dot{alpha, beta}, t0, t1, bl2);
mm->add_instruction(migraphx::op::dot{alpha, beta}, t0, t1, bl2);
auto prog = optimize_onnx("gemm_test.onnx");
EXPECT(p == prog);
......@@ -971,13 +1044,14 @@ TEST_CASE(gemm_test)
TEST_CASE(gemm_ex_test)
{
migraphx::program p;
auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 8, 6}});
auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 1, 8, 7}});
auto l2 = p.add_parameter("3", migraphx::shape{migraphx::shape::float_type, {1, 1, 6, 7}});
auto t0 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l0);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 8, 6}});
auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 1, 8, 7}});
auto l2 = mm->add_parameter("3", migraphx::shape{migraphx::shape::float_type, {1, 1, 6, 7}});
auto t0 = mm->add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l0);
auto alpha = 0.5f;
auto beta = 0.8f;
p.add_instruction(migraphx::op::dot{alpha, beta}, t0, l1, l2);
mm->add_instruction(migraphx::op::dot{alpha, beta}, t0, l1, l2);
auto prog = optimize_onnx("gemm_ex_test.onnx");
EXPECT(p == prog);
......@@ -986,15 +1060,16 @@ TEST_CASE(gemm_ex_test)
TEST_CASE(gemm_ex_brcst_test)
{
migraphx::program p;
auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 6}});
auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 7}});
auto l2 = p.add_parameter("3", migraphx::shape{migraphx::shape::float_type, {1, 1, 6, 1}});
auto t0 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l0);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 6}});
auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 7}});
auto l2 = mm->add_parameter("3", migraphx::shape{migraphx::shape::float_type, {1, 1, 6, 1}});
auto t0 = mm->add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l0);
std::vector<std::size_t> out_lens{1, 1, 6, 7};
auto t2 = p.add_instruction(migraphx::op::multibroadcast{out_lens}, l2);
auto t2 = mm->add_instruction(migraphx::op::multibroadcast{out_lens}, l2);
auto alpha = 0.5f;
auto beta = 0.8f;
p.add_instruction(migraphx::op::dot{alpha, beta}, t0, l1, t2);
mm->add_instruction(migraphx::op::dot{alpha, beta}, t0, l1, t2);
auto prog = optimize_onnx("gemm_ex_brcst_test.onnx");
EXPECT(p == prog);
......@@ -1003,11 +1078,13 @@ TEST_CASE(gemm_ex_brcst_test)
TEST_CASE(globalavgpool_test)
{
migraphx::program p;
auto input = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
auto* mm = p.get_main_module();
auto input =
mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
auto op = migraphx::op::pooling{"average"};
auto lens = input->get_shape().lens();
op.lengths = {lens[2], lens[3]};
p.add_instruction(op, input);
mm->add_instruction(op, input);
auto prog = optimize_onnx("globalavgpool_test.onnx");
......@@ -1017,11 +1094,13 @@ TEST_CASE(globalavgpool_test)
TEST_CASE(globalmaxpool_test)
{
migraphx::program p;
auto input = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
auto* mm = p.get_main_module();
auto input =
mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
auto op = migraphx::op::pooling{"max"};
auto lens = input->get_shape().lens();
op.lengths = {lens[2], lens[3]};
p.add_instruction(op, input);
mm->add_instruction(op, input);
auto prog = optimize_onnx("globalmaxpool_test.onnx");
......@@ -1031,14 +1110,15 @@ TEST_CASE(globalmaxpool_test)
TEST_CASE(greater_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
std::vector<float> data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
auto input1 = p.add_literal(migraphx::literal(s, data));
auto input2 = p.add_parameter("x2", migraphx::shape{migraphx::shape::float_type, {2, 3}});
auto gr = p.add_instruction(migraphx::op::greater{}, input1, input2);
auto ret = p.add_instruction(migraphx::op::convert{migraphx::shape::bool_type}, gr);
p.add_return({ret});
auto input1 = mm->add_literal(migraphx::literal(s, data));
auto input2 = mm->add_parameter("x2", migraphx::shape{migraphx::shape::float_type, {2, 3}});
auto gr = mm->add_instruction(migraphx::op::greater{}, input1, input2);
auto ret = mm->add_instruction(migraphx::op::convert{migraphx::shape::bool_type}, gr);
mm->add_return({ret});
auto prog = migraphx::parse_onnx("greater_test.onnx");
EXPECT(p == prog);
......@@ -1047,14 +1127,15 @@ TEST_CASE(greater_test)
TEST_CASE(greater_bool_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sf{migraphx::shape::float_type, {2, 3}};
migraphx::shape sb{migraphx::shape::bool_type, {2, 3}};
auto input1 = p.add_parameter("x1", sf);
auto input2 = p.add_parameter("x2", sb);
auto cin1 = p.add_instruction(migraphx::op::convert{migraphx::shape::bool_type}, input1);
auto ret = p.add_instruction(migraphx::op::greater{}, cin1, input2);
p.add_return({ret});
auto input1 = mm->add_parameter("x1", sf);
auto input2 = mm->add_parameter("x2", sb);
auto cin1 = mm->add_instruction(migraphx::op::convert{migraphx::shape::bool_type}, input1);
auto ret = mm->add_instruction(migraphx::op::greater{}, cin1, input2);
mm->add_return({ret});
auto prog = migraphx::parse_onnx("greater_bool_test.onnx");
EXPECT(p == prog);
......@@ -1063,11 +1144,12 @@ TEST_CASE(greater_bool_test)
TEST_CASE(group_conv_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 4, 16, 16}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 1, 3, 3}});
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 4, 16, 16}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 1, 3, 3}});
migraphx::op::convolution op;
op.group = 4;
p.add_instruction(op, l0, l1);
mm->add_instruction(op, l0, l1);
auto prog = optimize_onnx("group_conv_test.onnx");
EXPECT(p == prog);
......@@ -1076,15 +1158,16 @@ TEST_CASE(group_conv_test)
TEST_CASE(imagescaler_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {1, 3, 16, 16}};
auto l0 = p.add_parameter("0", s);
auto scale_val = p.add_literal(0.5f);
auto bias_vals = p.add_literal(
auto l0 = mm->add_parameter("0", s);
auto scale_val = mm->add_literal(0.5f);
auto bias_vals = mm->add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3}}, {0.01, 0.02, 0.03}});
auto scaled_tensor = p.add_instruction(migraphx::op::scalar{s.lens()}, scale_val);
auto img_scaled = p.add_instruction(migraphx::op::mul{}, l0, scaled_tensor);
auto bias_bcast = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, bias_vals);
p.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast);
auto scaled_tensor = mm->add_instruction(migraphx::op::scalar{s.lens()}, scale_val);
auto img_scaled = mm->add_instruction(migraphx::op::mul{}, l0, scaled_tensor);
auto bias_bcast = mm->add_instruction(migraphx::op::broadcast{1, s.lens()}, bias_vals);
mm->add_instruction(migraphx::op::add{}, img_scaled, bias_bcast);
auto prog = optimize_onnx("imagescaler_test.onnx");
......@@ -1094,16 +1177,17 @@ TEST_CASE(imagescaler_test)
TEST_CASE(imagescaler_half_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::half_type, {1, 3, 16, 16}};
auto l0 = p.add_parameter("0", s);
auto l0 = mm->add_parameter("0", s);
auto scale_val =
p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::half_type}, {0.5f}});
auto bias_vals = p.add_literal(
mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::half_type}, {0.5f}});
auto bias_vals = mm->add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::half_type, {3}}, {0.01, 0.02, 0.03}});
auto scaled_tensor = p.add_instruction(migraphx::op::scalar{s.lens()}, scale_val);
auto img_scaled = p.add_instruction(migraphx::op::mul{}, l0, scaled_tensor);
auto bias_bcast = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, bias_vals);
p.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast);
auto scaled_tensor = mm->add_instruction(migraphx::op::scalar{s.lens()}, scale_val);
auto img_scaled = mm->add_instruction(migraphx::op::mul{}, l0, scaled_tensor);
auto bias_bcast = mm->add_instruction(migraphx::op::broadcast{1, s.lens()}, bias_vals);
mm->add_instruction(migraphx::op::add{}, img_scaled, bias_bcast);
auto prog = optimize_onnx("imagescaler_half_test.onnx");
......@@ -1113,10 +1197,11 @@ TEST_CASE(imagescaler_half_test)
TEST_CASE(implicit_add_bcast_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4, 1}});
auto l3 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::add{}, l0, l3);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4, 1}});
auto l3 = mm->add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
mm->add_instruction(migraphx::op::add{}, l0, l3);
auto prog = optimize_onnx("implicit_add_bcast_test.onnx");
......@@ -1126,11 +1211,12 @@ TEST_CASE(implicit_add_bcast_test)
TEST_CASE(implicit_add_bcast_user_input_shape_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 5, 1}});
auto l3 = p.add_instruction(migraphx::op::multibroadcast{{3, 4, 5, 6}}, l1);
auto r = p.add_instruction(migraphx::op::add{}, l0, l3);
p.add_return({r});
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 5, 1}});
auto l3 = mm->add_instruction(migraphx::op::multibroadcast{{3, 4, 5, 6}}, l1);
auto r = mm->add_instruction(migraphx::op::add{}, l0, l3);
mm->add_return({r});
migraphx::onnx_options options;
options.map_input_dims["0"] = {3, 4, 5, 6};
......@@ -1143,10 +1229,11 @@ TEST_CASE(implicit_add_bcast_user_input_shape_test)
TEST_CASE(implicit_pow_bcast_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4, 1}});
auto l3 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::pow{}, l0, l3);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4, 1}});
auto l3 = mm->add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
mm->add_instruction(migraphx::op::pow{}, l0, l3);
auto prog = optimize_onnx("implicit_pow_bcast_test.onnx");
......@@ -1156,10 +1243,11 @@ TEST_CASE(implicit_pow_bcast_test)
TEST_CASE(implicit_sub_bcast_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::uint64_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::uint64_type, {4, 5}});
auto l3 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::sub{}, l0, l3);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::uint64_type, {2, 3, 4, 5}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::uint64_type, {4, 5}});
auto l3 = mm->add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
mm->add_instruction(migraphx::op::sub{}, l0, l3);
auto prog = optimize_onnx("implicit_sub_bcast_test.onnx");
......@@ -1169,10 +1257,11 @@ TEST_CASE(implicit_sub_bcast_test)
TEST_CASE(initializer_not_an_input)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<float> w = {1, 2, 3, 4, 5, 6, 7, 8};
auto l1 = p.add_literal(migraphx::literal({migraphx::shape::float_type, {2, 4}}, w));
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {5, 2}});
p.add_instruction(migraphx::op::dot{}, l0, l1);
auto l1 = mm->add_literal(migraphx::literal({migraphx::shape::float_type, {2, 4}}, w));
auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {5, 2}});
mm->add_instruction(migraphx::op::dot{}, l0, l1);
auto prog = optimize_onnx("initializer_not_an_input.onnx");
......@@ -1186,25 +1275,26 @@ TEST_CASE(instance_norm_test)
migraphx::shape s2{migraphx::shape::float_type, {2}};
migraphx::program p;
auto x = p.add_parameter("0", s1);
auto scale = p.add_parameter("1", s2);
auto bias = p.add_parameter("2", s2);
auto mean = p.add_instruction(migraphx::op::reduce_mean{{2, 3}}, x);
auto mean_bcast = p.add_instruction(migraphx::op::multibroadcast{dims}, mean);
auto l0 = p.add_instruction(migraphx::op::sqdiff{}, x, mean_bcast);
auto variance = p.add_instruction(migraphx::op::reduce_mean{{2, 3}}, l0);
auto l1 = p.add_instruction(migraphx::op::sub{}, x, mean_bcast);
auto epsilon_literal = p.add_literal(1e-5f);
auto epsilon_bcast = p.add_instruction(migraphx::op::multibroadcast{dims}, epsilon_literal);
auto variance_bcast = p.add_instruction(migraphx::op::multibroadcast{dims}, variance);
auto l2 = p.add_instruction(migraphx::op::add{}, variance_bcast, epsilon_bcast);
auto l3 = p.add_instruction(migraphx::op::rsqrt{}, l2);
auto l4 = p.add_instruction(migraphx::op::mul{}, l1, l3);
auto scale_bcast = p.add_instruction(migraphx::op::broadcast{1, dims}, scale);
auto bias_bcast = p.add_instruction(migraphx::op::broadcast{1, dims}, bias);
auto l5 = p.add_instruction(migraphx::op::mul{}, l4, scale_bcast);
p.add_instruction(migraphx::op::add{}, l5, bias_bcast);
auto* mm = p.get_main_module();
auto x = mm->add_parameter("0", s1);
auto scale = mm->add_parameter("1", s2);
auto bias = mm->add_parameter("2", s2);
auto mean = mm->add_instruction(migraphx::op::reduce_mean{{2, 3}}, x);
auto mean_bcast = mm->add_instruction(migraphx::op::multibroadcast{dims}, mean);
auto l0 = mm->add_instruction(migraphx::op::sqdiff{}, x, mean_bcast);
auto variance = mm->add_instruction(migraphx::op::reduce_mean{{2, 3}}, l0);
auto l1 = mm->add_instruction(migraphx::op::sub{}, x, mean_bcast);
auto epsilon_literal = mm->add_literal(1e-5f);
auto epsilon_bcast = mm->add_instruction(migraphx::op::multibroadcast{dims}, epsilon_literal);
auto variance_bcast = mm->add_instruction(migraphx::op::multibroadcast{dims}, variance);
auto l2 = mm->add_instruction(migraphx::op::add{}, variance_bcast, epsilon_bcast);
auto l3 = mm->add_instruction(migraphx::op::rsqrt{}, l2);
auto l4 = mm->add_instruction(migraphx::op::mul{}, l1, l3);
auto scale_bcast = mm->add_instruction(migraphx::op::broadcast{1, dims}, scale);
auto bias_bcast = mm->add_instruction(migraphx::op::broadcast{1, dims}, bias);
auto l5 = mm->add_instruction(migraphx::op::mul{}, l4, scale_bcast);
mm->add_instruction(migraphx::op::add{}, l5, bias_bcast);
auto prog = optimize_onnx("instance_norm_test.onnx");
......@@ -1214,9 +1304,10 @@ TEST_CASE(instance_norm_test)
TEST_CASE(leaky_relu_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
float alpha = 0.01f;
auto l0 = p.add_parameter("0", {migraphx::shape::float_type, {3}});
p.add_instruction(migraphx::op::leaky_relu{alpha}, l0);
auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {3}});
mm->add_instruction(migraphx::op::leaky_relu{alpha}, l0);
auto prog = optimize_onnx("leaky_relu_test.onnx");
......@@ -1226,14 +1317,15 @@ TEST_CASE(leaky_relu_test)
TEST_CASE(less_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
std::vector<float> data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
auto input1 = p.add_literal(migraphx::literal(s, data));
auto input2 = p.add_parameter("x2", migraphx::shape{migraphx::shape::float_type, {2, 3}});
auto le = p.add_instruction(migraphx::op::less{}, input1, input2);
auto ret = p.add_instruction(migraphx::op::convert{migraphx::shape::bool_type}, le);
p.add_return({ret});
auto input1 = mm->add_literal(migraphx::literal(s, data));
auto input2 = mm->add_parameter("x2", migraphx::shape{migraphx::shape::float_type, {2, 3}});
auto le = mm->add_instruction(migraphx::op::less{}, input1, input2);
auto ret = mm->add_instruction(migraphx::op::convert{migraphx::shape::bool_type}, le);
mm->add_return({ret});
auto prog = migraphx::parse_onnx("less_test.onnx");
EXPECT(p == prog);
......@@ -1242,14 +1334,15 @@ TEST_CASE(less_test)
TEST_CASE(less_bool_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sf{migraphx::shape::float_type, {2, 3}};
migraphx::shape sb{migraphx::shape::bool_type, {2, 3}};
auto input1 = p.add_parameter("x1", sf);
auto input2 = p.add_parameter("x2", sb);
auto cin1 = p.add_instruction(migraphx::op::convert{migraphx::shape::bool_type}, input1);
auto ret = p.add_instruction(migraphx::op::less{}, cin1, input2);
p.add_return({ret});
auto input1 = mm->add_parameter("x1", sf);
auto input2 = mm->add_parameter("x2", sb);
auto cin1 = mm->add_instruction(migraphx::op::convert{migraphx::shape::bool_type}, input1);
auto ret = mm->add_instruction(migraphx::op::less{}, cin1, input2);
mm->add_return({ret});
auto prog = migraphx::parse_onnx("less_bool_test.onnx");
EXPECT(p == prog);
......@@ -1258,8 +1351,9 @@ TEST_CASE(less_bool_test)
TEST_CASE(log_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
p.add_instruction(migraphx::op::log{}, input);
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
mm->add_instruction(migraphx::op::log{}, input);
auto prog = optimize_onnx("log_test.onnx");
EXPECT(p == prog);
......@@ -1268,9 +1362,10 @@ TEST_CASE(log_test)
TEST_CASE(logsoftmax_test)
{
migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
int axis = 1;
p.add_instruction(migraphx::op::logsoftmax{axis}, l0);
mm->add_instruction(migraphx::op::logsoftmax{axis}, l0);
auto prog = optimize_onnx("logsoftmax_test.onnx");
EXPECT(p == prog);
......@@ -1279,13 +1374,14 @@ TEST_CASE(logsoftmax_test)
TEST_CASE(lrn_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 28, 24, 24}});
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 28, 24, 24}});
migraphx::op::lrn op;
op.size = 5;
op.alpha = 0.0001;
op.beta = 0.75;
op.bias = 1.0;
p.add_instruction(op, l0);
mm->add_instruction(op, l0);
auto prog = optimize_onnx("lrn_test.onnx");
EXPECT(p == prog);
......@@ -1294,11 +1390,12 @@ TEST_CASE(lrn_test)
TEST_CASE(matmul_bmbm_test)
{
migraphx::program p;
auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 6, 7}});
auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {5, 2, 1, 7, 8}});
auto bl0 = p.add_instruction(migraphx::op::multibroadcast{{5, 2, 3, 6, 7}}, l0);
auto bl1 = p.add_instruction(migraphx::op::multibroadcast{{5, 2, 3, 7, 8}}, l1);
p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, bl0, bl1);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 6, 7}});
auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {5, 2, 1, 7, 8}});
auto bl0 = mm->add_instruction(migraphx::op::multibroadcast{{5, 2, 3, 6, 7}}, l0);
auto bl1 = mm->add_instruction(migraphx::op::multibroadcast{{5, 2, 3, 7, 8}}, l1);
mm->add_instruction(migraphx::op::dot{1.0f, 0.0f}, bl0, bl1);
auto prog = optimize_onnx("matmul_bmbm_test.onnx");
......@@ -1308,12 +1405,13 @@ TEST_CASE(matmul_bmbm_test)
TEST_CASE(matmul_bmv_test)
{
migraphx::program p;
auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 6, 7}});
auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {7}});
auto sl1 = p.add_instruction(migraphx::op::unsqueeze{{1}}, l1);
auto bsl1 = p.add_instruction(migraphx::op::multibroadcast{{3, 7, 1}}, sl1);
auto res = p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, l0, bsl1);
p.add_instruction(migraphx::op::squeeze{{2}}, res);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 6, 7}});
auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {7}});
auto sl1 = mm->add_instruction(migraphx::op::unsqueeze{{1}}, l1);
auto bsl1 = mm->add_instruction(migraphx::op::multibroadcast{{3, 7, 1}}, sl1);
auto res = mm->add_instruction(migraphx::op::dot{1.0f, 0.0f}, l0, bsl1);
mm->add_instruction(migraphx::op::squeeze{{2}}, res);
auto prog = optimize_onnx("matmul_bmv_test.onnx");
......@@ -1323,11 +1421,12 @@ TEST_CASE(matmul_bmv_test)
TEST_CASE(matmul_mv_test)
{
migraphx::program p;
auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {6, 7}});
auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {7}});
auto sl1 = p.add_instruction(migraphx::op::unsqueeze{{1}}, l1);
auto res = p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, l0, sl1);
p.add_instruction(migraphx::op::squeeze{{1}}, res);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {6, 7}});
auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {7}});
auto sl1 = mm->add_instruction(migraphx::op::unsqueeze{{1}}, l1);
auto res = mm->add_instruction(migraphx::op::dot{1.0f, 0.0f}, l0, sl1);
mm->add_instruction(migraphx::op::squeeze{{1}}, res);
auto prog = optimize_onnx("matmul_mv_test.onnx");
......@@ -1337,12 +1436,13 @@ TEST_CASE(matmul_mv_test)
TEST_CASE(matmul_vbm_test)
{
migraphx::program p;
auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {7}});
auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {5, 7, 8}});
auto sl0 = p.add_instruction(migraphx::op::unsqueeze{{0}}, l0);
auto bsl0 = p.add_instruction(migraphx::op::multibroadcast{{5, 1, 7}}, sl0);
auto res = p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, bsl0, l1);
p.add_instruction(migraphx::op::squeeze{{1}}, res);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {7}});
auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {5, 7, 8}});
auto sl0 = mm->add_instruction(migraphx::op::unsqueeze{{0}}, l0);
auto bsl0 = mm->add_instruction(migraphx::op::multibroadcast{{5, 1, 7}}, sl0);
auto res = mm->add_instruction(migraphx::op::dot{1.0f, 0.0f}, bsl0, l1);
mm->add_instruction(migraphx::op::squeeze{{1}}, res);
auto prog = optimize_onnx("matmul_vbm_test.onnx");
......@@ -1352,11 +1452,12 @@ TEST_CASE(matmul_vbm_test)
TEST_CASE(matmul_vm_test)
{
migraphx::program p;
auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {7}});
auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {7, 8}});
auto sl0 = p.add_instruction(migraphx::op::unsqueeze{{0}}, l0);
auto res = p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, sl0, l1);
p.add_instruction(migraphx::op::squeeze{{0}}, res);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {7}});
auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {7, 8}});
auto sl0 = mm->add_instruction(migraphx::op::unsqueeze{{0}}, l0);
auto res = mm->add_instruction(migraphx::op::dot{1.0f, 0.0f}, sl0, l1);
mm->add_instruction(migraphx::op::squeeze{{0}}, res);
auto prog = optimize_onnx("matmul_vm_test.onnx");
......@@ -1366,13 +1467,14 @@ TEST_CASE(matmul_vm_test)
TEST_CASE(matmul_vv_test)
{
migraphx::program p;
auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {7}});
auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {7}});
auto sl0 = p.add_instruction(migraphx::op::unsqueeze{{0}}, l0);
auto sl1 = p.add_instruction(migraphx::op::unsqueeze{{1}}, l1);
auto res = p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, sl0, sl1);
auto sr0 = p.add_instruction(migraphx::op::squeeze{{0}}, res);
p.add_instruction(migraphx::op::squeeze{{0}}, sr0);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {7}});
auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {7}});
auto sl0 = mm->add_instruction(migraphx::op::unsqueeze{{0}}, l0);
auto sl1 = mm->add_instruction(migraphx::op::unsqueeze{{1}}, l1);
auto res = mm->add_instruction(migraphx::op::dot{1.0f, 0.0f}, sl0, sl1);
auto sr0 = mm->add_instruction(migraphx::op::squeeze{{0}}, res);
mm->add_instruction(migraphx::op::squeeze{{0}}, sr0);
auto prog = optimize_onnx("matmul_vv_test.onnx");
......@@ -1382,9 +1484,10 @@ TEST_CASE(matmul_vv_test)
TEST_CASE(matmulinteger_test)
{
migraphx::program p;
auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::int8_type, {3, 6, 16}});
auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::int8_type, {3, 16, 8}});
p.add_instruction(migraphx::op::quant_dot{1, 0}, l0, l1);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::int8_type, {3, 6, 16}});
auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::int8_type, {3, 16, 8}});
mm->add_instruction(migraphx::op::quant_dot{1, 0}, l0, l1);
auto prog = optimize_onnx("matmulinteger_test.onnx");
......@@ -1394,11 +1497,12 @@ TEST_CASE(matmulinteger_test)
TEST_CASE(max_test)
{
migraphx::program p;
auto input0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}});
auto input1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3}});
auto input2 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {3}});
auto l0 = p.add_instruction(migraphx::op::max{}, input0, input1);
p.add_instruction(migraphx::op::max{}, l0, input2);
auto* mm = p.get_main_module();
auto input0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}});
auto input1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3}});
auto input2 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {3}});
auto l0 = mm->add_instruction(migraphx::op::max{}, input0, input1);
mm->add_instruction(migraphx::op::max{}, l0, input2);
optimize_onnx("max_test.onnx");
}
......@@ -1406,11 +1510,12 @@ TEST_CASE(max_test)
TEST_CASE(maxpool_notset_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}});
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}});
std::vector<int64_t> pads = {0, 0, 0, 0, 0, 0, 1, 1};
float val = std::numeric_limits<float>::lowest();
auto ins_pad = p.add_instruction(migraphx::op::pad{pads, val}, input);
p.add_instruction(migraphx::op::pooling{"max", {0, 0}, {2, 2}, {6, 6}}, ins_pad);
auto ins_pad = mm->add_instruction(migraphx::op::pad{pads, val}, input);
mm->add_instruction(migraphx::op::pooling{"max", {0, 0}, {2, 2}, {6, 6}}, ins_pad);
auto prog = optimize_onnx("maxpool_notset_test.onnx");
......@@ -1420,11 +1525,12 @@ TEST_CASE(maxpool_notset_test)
TEST_CASE(maxpool_same_upper_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}});
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}});
std::vector<int64_t> pads = {0, 0, 0, 0, 0, 0, 1, 1};
float val = std::numeric_limits<float>::lowest();
auto ins_pad = p.add_instruction(migraphx::op::pad{pads, val}, input);
p.add_instruction(migraphx::op::pooling{"max", {0, 0}, {1, 1}, {2, 2}}, ins_pad);
auto ins_pad = mm->add_instruction(migraphx::op::pad{pads, val}, input);
mm->add_instruction(migraphx::op::pooling{"max", {0, 0}, {1, 1}, {2, 2}}, ins_pad);
auto prog = optimize_onnx("maxpool_same_upper_test.onnx");
......@@ -1434,11 +1540,12 @@ TEST_CASE(maxpool_same_upper_test)
TEST_CASE(min_test)
{
migraphx::program p;
auto input0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}});
auto input1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3}});
auto input2 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {3}});
auto l0 = p.add_instruction(migraphx::op::min{}, input0, input1);
p.add_instruction(migraphx::op::min{}, l0, input2);
auto* mm = p.get_main_module();
auto input0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}});
auto input1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3}});
auto input2 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {3}});
auto l0 = mm->add_instruction(migraphx::op::min{}, input0, input1);
mm->add_instruction(migraphx::op::min{}, l0, input2);
optimize_onnx("min_test.onnx");
}
......@@ -1446,8 +1553,9 @@ TEST_CASE(min_test)
TEST_CASE(no_pad_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 2}});
p.add_instruction(migraphx::op::identity{}, l0);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 2}});
mm->add_instruction(migraphx::op::identity{}, l0);
auto prog = optimize_onnx("no_pad_test.onnx");
EXPECT(p == prog);
......@@ -1456,10 +1564,11 @@ TEST_CASE(no_pad_test)
TEST_CASE(neg_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int64_type, {2, 3}};
auto input = p.add_parameter("0", s);
auto ret = p.add_instruction(migraphx::op::neg{}, input);
p.add_return({ret});
auto input = mm->add_parameter("0", s);
auto ret = mm->add_instruction(migraphx::op::neg{}, input);
mm->add_return({ret});
auto prog = migraphx::parse_onnx("neg_test.onnx");
......@@ -1469,14 +1578,15 @@ TEST_CASE(neg_test)
TEST_CASE(nonzero_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 2}};
std::vector<float> data = {1, 0, 1, 1};
p.add_literal(migraphx::literal(s, data));
mm->add_literal(migraphx::literal(s, data));
migraphx::shape si{migraphx::shape::int64_type, {2, 3}};
std::vector<int64_t> indices = {0, 1, 1, 0, 0, 1};
auto r = p.add_literal(migraphx::literal(si, indices));
p.add_return({r});
auto r = mm->add_literal(migraphx::literal(si, indices));
mm->add_return({r});
auto prog = migraphx::parse_onnx("nonzero_test.onnx");
EXPECT(p == prog);
......@@ -1485,14 +1595,15 @@ TEST_CASE(nonzero_test)
TEST_CASE(nonzero_int_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int16_type, {2, 3}};
std::vector<int> data = {1, 1, 0, 1, 0, 1};
p.add_literal(migraphx::literal(s, data.begin(), data.end()));
mm->add_literal(migraphx::literal(s, data.begin(), data.end()));
migraphx::shape si{migraphx::shape::int64_type, {2, 4}};
std::vector<int64_t> indices = {0, 0, 1, 1, 0, 1, 0, 2};
auto r = p.add_literal(migraphx::literal(si, indices));
p.add_return({r});
auto r = mm->add_literal(migraphx::literal(si, indices));
mm->add_return({r});
auto prog = migraphx::parse_onnx("nonzero_int_test.onnx");
EXPECT(p == prog);
......@@ -1501,24 +1612,25 @@ TEST_CASE(nonzero_int_test)
TEST_CASE(onehot_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s_ind{migraphx::shape::int32_type, {5, 2}};
migraphx::shape s_val{migraphx::shape::half_type, {2}};
p.add_literal(3);
auto l_ind = p.add_parameter("indices", s_ind);
auto l_val = p.add_parameter("values", s_val);
mm->add_literal(3);
auto l_ind = mm->add_parameter("indices", s_ind);
auto l_val = mm->add_parameter("values", s_val);
migraphx::shape s_dep{migraphx::shape::half_type, {3, 3}};
std::vector<float> data_dep{1, 0, 0, 0, 1, 0, 0, 0, 1};
auto l_dep = p.add_literal(migraphx::literal(s_dep, data_dep));
auto gather_out = p.add_instruction(migraphx::op::gather{0}, l_dep, l_ind);
auto tr_out = p.add_instruction(migraphx::op::transpose{{2, 0, 1}}, gather_out);
auto off_val = p.add_instruction(migraphx::op::slice{{0}, {0}, {1}}, l_val);
auto on_val = p.add_instruction(migraphx::op::slice{{0}, {1}, {2}}, l_val);
auto diff = p.add_instruction(migraphx::op::sub{}, on_val, off_val);
auto mb_off_val = p.add_instruction(migraphx::op::multibroadcast{{3, 5, 2}}, off_val);
auto mb_diff = p.add_instruction(migraphx::op::multibroadcast{{3, 5, 2}}, diff);
auto mul = p.add_instruction(migraphx::op::mul{}, tr_out, mb_diff);
auto r = p.add_instruction(migraphx::op::add{}, mul, mb_off_val);
p.add_return({r});
auto l_dep = mm->add_literal(migraphx::literal(s_dep, data_dep));
auto gather_out = mm->add_instruction(migraphx::op::gather{0}, l_dep, l_ind);
auto tr_out = mm->add_instruction(migraphx::op::transpose{{2, 0, 1}}, gather_out);
auto off_val = mm->add_instruction(migraphx::op::slice{{0}, {0}, {1}}, l_val);
auto on_val = mm->add_instruction(migraphx::op::slice{{0}, {1}, {2}}, l_val);
auto diff = mm->add_instruction(migraphx::op::sub{}, on_val, off_val);
auto mb_off_val = mm->add_instruction(migraphx::op::multibroadcast{{3, 5, 2}}, off_val);
auto mb_diff = mm->add_instruction(migraphx::op::multibroadcast{{3, 5, 2}}, diff);
auto mul = mm->add_instruction(migraphx::op::mul{}, tr_out, mb_diff);
auto r = mm->add_instruction(migraphx::op::add{}, mul, mb_off_val);
mm->add_return({r});
auto prog = migraphx::parse_onnx("onehot_test.onnx");
......@@ -1528,8 +1640,9 @@ TEST_CASE(onehot_test)
TEST_CASE(pad_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 2}});
p.add_instruction(migraphx::op::pad{{1, 1, 1, 1}}, l0);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 2}});
mm->add_instruction(migraphx::op::pad{{1, 1, 1, 1}}, l0);
auto prog = optimize_onnx("pad_test.onnx");
EXPECT(p == prog);
......@@ -1538,11 +1651,12 @@ TEST_CASE(pad_test)
TEST_CASE(pad_3arg_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 2}});
p.add_literal({migraphx::shape{migraphx::shape::float_type}, {1.0f}});
p.add_literal({migraphx::shape{migraphx::shape::int32_type, {4}}, {1, 1, 2, 2}});
auto r = p.add_instruction(migraphx::op::pad{{1, 1, 2, 2}, 1.0f}, l0);
p.add_return({r});
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 2}});
mm->add_literal({migraphx::shape{migraphx::shape::float_type}, {1.0f}});
mm->add_literal({migraphx::shape{migraphx::shape::int32_type, {4}}, {1, 1, 2, 2}});
auto r = mm->add_instruction(migraphx::op::pad{{1, 1, 2, 2}, 1.0f}, l0);
mm->add_return({r});
auto prog = migraphx::parse_onnx("pad_3arg_test.onnx");
......@@ -1552,13 +1666,14 @@ TEST_CASE(pad_3arg_test)
TEST_CASE(pad_reflect_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 2}});
p.add_literal({migraphx::shape{migraphx::shape::int32_type, {4}}, {0, 2, 0, 1}});
auto l1 = p.add_instruction(migraphx::op::slice{{0, 1}, {0, 1}, {2, 2}}, l0);
auto l2 = p.add_instruction(migraphx::op::slice{{0, 1}, {0, 0}, {2, 1}}, l0);
auto l3 = p.add_instruction(migraphx::op::slice{{0, 1}, {0, 0}, {2, 1}}, l0);
auto r = p.add_instruction(migraphx::op::concat{1}, l2, l1, l0, l3);
p.add_return({r});
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 2}});
mm->add_literal({migraphx::shape{migraphx::shape::int32_type, {4}}, {0, 2, 0, 1}});
auto l1 = mm->add_instruction(migraphx::op::slice{{0, 1}, {0, 1}, {2, 2}}, l0);
auto l2 = mm->add_instruction(migraphx::op::slice{{0, 1}, {0, 0}, {2, 1}}, l0);
auto l3 = mm->add_instruction(migraphx::op::slice{{0, 1}, {0, 0}, {2, 1}}, l0);
auto r = mm->add_instruction(migraphx::op::concat{1}, l2, l1, l0, l3);
mm->add_return({r});
auto prog = migraphx::parse_onnx("pad_reflect_test.onnx");
......@@ -1568,15 +1683,16 @@ TEST_CASE(pad_reflect_test)
TEST_CASE(pad_reflect_multiaxis_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3}});
p.add_literal({migraphx::shape{migraphx::shape::int32_type, {4}}, {0, 2, 2, 0}});
auto l1 = p.add_instruction(migraphx::op::slice{{0, 1}, {0, 1}, {2, 2}}, l0);
auto l2 = p.add_instruction(migraphx::op::slice{{0, 1}, {0, 2}, {2, 3}}, l0);
auto l3 = p.add_instruction(migraphx::op::concat{1}, l2, l1, l0);
auto l4 = p.add_instruction(migraphx::op::slice{{0, 1}, {0, 0}, {1, 5}}, l3);
auto l5 = p.add_instruction(migraphx::op::slice{{0, 1}, {1, 0}, {2, 5}}, l3);
auto r = p.add_instruction(migraphx::op::concat{0}, l3, l4, l5);
p.add_return({r});
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3}});
mm->add_literal({migraphx::shape{migraphx::shape::int32_type, {4}}, {0, 2, 2, 0}});
auto l1 = mm->add_instruction(migraphx::op::slice{{0, 1}, {0, 1}, {2, 2}}, l0);
auto l2 = mm->add_instruction(migraphx::op::slice{{0, 1}, {0, 2}, {2, 3}}, l0);
auto l3 = mm->add_instruction(migraphx::op::concat{1}, l2, l1, l0);
auto l4 = mm->add_instruction(migraphx::op::slice{{0, 1}, {0, 0}, {1, 5}}, l3);
auto l5 = mm->add_instruction(migraphx::op::slice{{0, 1}, {1, 0}, {2, 5}}, l3);
auto r = mm->add_instruction(migraphx::op::concat{0}, l3, l4, l5);
mm->add_return({r});
auto prog = migraphx::parse_onnx("pad_reflect_multiaxis_test.onnx");
......@@ -1586,9 +1702,10 @@ TEST_CASE(pad_reflect_multiaxis_test)
TEST_CASE(pow_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
p.add_instruction(migraphx::op::pow{}, l0, l1);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
mm->add_instruction(migraphx::op::pow{}, l0, l1);
auto prog = optimize_onnx("pow_test.onnx");
......@@ -1598,11 +1715,12 @@ TEST_CASE(pow_test)
TEST_CASE(prelu_brcst_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 5}});
auto bl1 = p.add_instruction(migraphx::op::multibroadcast{l0->get_shape().lens()}, l1);
auto ret = p.add_instruction(migraphx::op::prelu{}, l0, bl1);
p.add_return({ret});
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 5}});
auto bl1 = mm->add_instruction(migraphx::op::multibroadcast{l0->get_shape().lens()}, l1);
auto ret = mm->add_instruction(migraphx::op::prelu{}, l0, bl1);
mm->add_return({ret});
auto prog = migraphx::parse_onnx("prelu_brcst_test.onnx");
......@@ -1612,10 +1730,11 @@ TEST_CASE(prelu_brcst_test)
TEST_CASE(range_test)
{
migraphx::program p;
p.add_literal(int64_t{10});
p.add_literal(int64_t{6});
p.add_literal(int64_t{-3});
p.add_literal(migraphx::literal{{migraphx::shape::int64_type, {2}}, {10, 7}});
auto* mm = p.get_main_module();
mm->add_literal(int64_t{10});
mm->add_literal(int64_t{6});
mm->add_literal(int64_t{-3});
mm->add_literal(migraphx::literal{{migraphx::shape::int64_type, {2}}, {10, 7}});
auto prog = optimize_onnx("range_test.onnx");
......@@ -1625,10 +1744,11 @@ TEST_CASE(range_test)
TEST_CASE(range_float_test)
{
migraphx::program p;
p.add_literal(float{2});
p.add_literal(float{11});
p.add_literal(float{2});
p.add_literal(migraphx::literal{{migraphx::shape::float_type, {5}}, {2, 4, 6, 8, 10}});
auto* mm = p.get_main_module();
mm->add_literal(float{2});
mm->add_literal(float{11});
mm->add_literal(float{2});
mm->add_literal(migraphx::literal{{migraphx::shape::float_type, {5}}, {2, 4, 6, 8, 10}});
auto prog = optimize_onnx("range_float_test.onnx");
......@@ -1638,8 +1758,9 @@ TEST_CASE(range_float_test)
TEST_CASE(recip_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3}});
p.add_instruction(migraphx::op::recip{}, input);
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3}});
mm->add_instruction(migraphx::op::recip{}, input);
auto prog = optimize_onnx("recip_test.onnx");
......@@ -1649,10 +1770,11 @@ TEST_CASE(recip_test)
TEST_CASE(reducel1_test)
{
migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto abs_l0 = p.add_instruction(migraphx::op::abs{}, l0);
auto sum_l0 = p.add_instruction(migraphx::op::reduce_sum{{-2}}, abs_l0);
p.add_instruction(migraphx::op::squeeze{{-2}}, sum_l0);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto abs_l0 = mm->add_instruction(migraphx::op::abs{}, l0);
auto sum_l0 = mm->add_instruction(migraphx::op::reduce_sum{{-2}}, abs_l0);
mm->add_instruction(migraphx::op::squeeze{{-2}}, sum_l0);
auto prog = optimize_onnx("reducel1_test.onnx");
EXPECT(p == prog);
......@@ -1661,11 +1783,12 @@ TEST_CASE(reducel1_test)
TEST_CASE(reducel2_test)
{
migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto square_l0 = p.add_instruction(migraphx::op::mul{}, l0, l0);
auto sum_l0 = p.add_instruction(migraphx::op::reduce_sum{{-1}}, square_l0);
auto squ_l0 = p.add_instruction(migraphx::op::squeeze{{-1}}, sum_l0);
p.add_instruction(migraphx::op::sqrt{}, squ_l0);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto square_l0 = mm->add_instruction(migraphx::op::mul{}, l0, l0);
auto sum_l0 = mm->add_instruction(migraphx::op::reduce_sum{{-1}}, square_l0);
auto squ_l0 = mm->add_instruction(migraphx::op::squeeze{{-1}}, sum_l0);
mm->add_instruction(migraphx::op::sqrt{}, squ_l0);
auto prog = optimize_onnx("reducel2_test.onnx");
EXPECT(p == prog);
......@@ -1674,9 +1797,10 @@ TEST_CASE(reducel2_test)
TEST_CASE(reduce_log_sum_test)
{
migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto sum_l0 = p.add_instruction(migraphx::op::reduce_sum{{-3}}, l0);
p.add_instruction(migraphx::op::log{}, sum_l0);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto sum_l0 = mm->add_instruction(migraphx::op::reduce_sum{{-3}}, l0);
mm->add_instruction(migraphx::op::log{}, sum_l0);
auto prog = optimize_onnx("reduce_log_sum_test.onnx");
EXPECT(p == prog);
......@@ -1685,10 +1809,11 @@ TEST_CASE(reduce_log_sum_test)
TEST_CASE(reduce_log_sum_exp_test)
{
migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto exp_l0 = p.add_instruction(migraphx::op::exp{}, l0);
auto sum_l0 = p.add_instruction(migraphx::op::reduce_sum{{-4}}, exp_l0);
p.add_instruction(migraphx::op::log{}, sum_l0);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto exp_l0 = mm->add_instruction(migraphx::op::exp{}, l0);
auto sum_l0 = mm->add_instruction(migraphx::op::reduce_sum{{-4}}, exp_l0);
mm->add_instruction(migraphx::op::log{}, sum_l0);
auto prog = optimize_onnx("reduce_log_sum_exp_test.onnx");
EXPECT(p == prog);
......@@ -1697,8 +1822,9 @@ TEST_CASE(reduce_log_sum_exp_test)
TEST_CASE(reducemax_test)
{
migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
p.add_instruction(migraphx::op::reduce_max{{2}}, l0);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
mm->add_instruction(migraphx::op::reduce_max{{2}}, l0);
auto prog = optimize_onnx("reducemax_test.onnx");
EXPECT(p == prog);
......@@ -1707,9 +1833,10 @@ TEST_CASE(reducemax_test)
TEST_CASE(reducemean_test)
{
migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto l1 = p.add_instruction(migraphx::op::reduce_mean{{2, 3}}, l0);
p.add_instruction(migraphx::op::squeeze{{2, 3}}, l1);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto l1 = mm->add_instruction(migraphx::op::reduce_mean{{2, 3}}, l0);
mm->add_instruction(migraphx::op::squeeze{{2, 3}}, l1);
auto prog = optimize_onnx("reducemean_test.onnx");
EXPECT(p == prog);
......@@ -1718,8 +1845,9 @@ TEST_CASE(reducemean_test)
TEST_CASE(reducemean_keepdims_test)
{
migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
p.add_instruction(migraphx::op::reduce_mean{{2}}, l0);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
mm->add_instruction(migraphx::op::reduce_mean{{2}}, l0);
auto prog = optimize_onnx("reducemean_keepdims_test.onnx");
EXPECT(p == prog);
......@@ -1728,9 +1856,10 @@ TEST_CASE(reducemean_keepdims_test)
TEST_CASE(reducemin_test)
{
migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto l1 = p.add_instruction(migraphx::op::reduce_min{{2, 3}}, l0);
p.add_instruction(migraphx::op::squeeze{{2, 3}}, l1);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto l1 = mm->add_instruction(migraphx::op::reduce_min{{2, 3}}, l0);
mm->add_instruction(migraphx::op::squeeze{{2, 3}}, l1);
auto prog = optimize_onnx("reducemin_test.onnx");
EXPECT(p == prog);
......@@ -1739,8 +1868,9 @@ TEST_CASE(reducemin_test)
TEST_CASE(reduceprod_test)
{
migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
p.add_instruction(migraphx::op::reduce_prod{{2}}, l0);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
mm->add_instruction(migraphx::op::reduce_prod{{2}}, l0);
auto prog = optimize_onnx("reduceprod_test.onnx");
EXPECT(p == prog);
......@@ -1749,9 +1879,10 @@ TEST_CASE(reduceprod_test)
TEST_CASE(reducesum_test)
{
migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto l1 = p.add_instruction(migraphx::op::reduce_sum{{2}}, l0);
p.add_instruction(migraphx::op::squeeze{{2}}, l1);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto l1 = mm->add_instruction(migraphx::op::reduce_sum{{2}}, l0);
mm->add_instruction(migraphx::op::squeeze{{2}}, l1);
auto prog = optimize_onnx("reducesum_test.onnx");
EXPECT(p == prog);
......@@ -1760,9 +1891,10 @@ TEST_CASE(reducesum_test)
TEST_CASE(reducesum_multiaxis_test)
{
migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto l1 = p.add_instruction(migraphx::op::reduce_sum{{2, 3}}, l0);
p.add_instruction(migraphx::op::squeeze{{2, 3}}, l1);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto l1 = mm->add_instruction(migraphx::op::reduce_sum{{2, 3}}, l0);
mm->add_instruction(migraphx::op::squeeze{{2, 3}}, l1);
auto prog = optimize_onnx("reducesum_multiaxis_test.onnx");
EXPECT(p == prog);
......@@ -1771,8 +1903,9 @@ TEST_CASE(reducesum_multiaxis_test)
TEST_CASE(reducesum_keepdims_test)
{
migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
p.add_instruction(migraphx::op::reduce_sum{{2, 3}}, l0);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
mm->add_instruction(migraphx::op::reduce_sum{{2, 3}}, l0);
auto prog = optimize_onnx("reducesum_keepdims_test.onnx");
EXPECT(p == prog);
......@@ -1781,10 +1914,11 @@ TEST_CASE(reducesum_keepdims_test)
TEST_CASE(reducesum_square_test)
{
migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto squ_l0 = p.add_instruction(migraphx::op::mul{}, l0, l0);
auto sum_l0 = p.add_instruction(migraphx::op::reduce_sum{{-2}}, squ_l0);
p.add_instruction(migraphx::op::squeeze{{-2}}, sum_l0);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto squ_l0 = mm->add_instruction(migraphx::op::mul{}, l0, l0);
auto sum_l0 = mm->add_instruction(migraphx::op::reduce_sum{{-2}}, squ_l0);
mm->add_instruction(migraphx::op::squeeze{{-2}}, sum_l0);
auto prog = optimize_onnx("reducesum_square_test.onnx");
EXPECT(p == prog);
......@@ -1793,14 +1927,15 @@ TEST_CASE(reducesum_square_test)
TEST_CASE(reshape_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::op::reshape op;
std::vector<int64_t> reshape_dims{3, 8};
p.add_literal(
mm->add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::int64_type, {2}}, reshape_dims});
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 2, 3}});
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 2, 3}});
op.dims = reshape_dims;
p.add_instruction(op, l0);
p.add_instruction(op, l0);
mm->add_instruction(op, l0);
mm->add_instruction(op, l0);
auto prog = optimize_onnx("reshape_test.onnx");
EXPECT(p == prog);
......@@ -1809,13 +1944,14 @@ TEST_CASE(reshape_test)
TEST_CASE(reshape_non_standard_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::op::reshape op;
std::vector<int64_t> reshape_dims{4, 3, 2};
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4}};
auto x = p.add_parameter("x", s);
auto tran_x = p.add_instruction(migraphx::op::transpose{{0, 2, 1}}, x);
auto cont_x = p.add_instruction(migraphx::op::contiguous{}, tran_x);
p.add_instruction(migraphx::op::reshape{{4, 3, 2}}, cont_x);
auto x = mm->add_parameter("x", s);
auto tran_x = mm->add_instruction(migraphx::op::transpose{{0, 2, 1}}, x);
auto cont_x = mm->add_instruction(migraphx::op::contiguous{}, tran_x);
mm->add_instruction(migraphx::op::reshape{{4, 3, 2}}, cont_x);
auto prog = optimize_onnx("reshape_non_standard_test.onnx");
EXPECT(p == prog);
......@@ -1954,8 +2090,9 @@ TEST_CASE(resize_upsample_pf_test)
TEST_CASE(round_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::double_type, {10, 5}});
p.add_instruction(migraphx::op::round{}, input);
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::double_type, {10, 5}});
mm->add_instruction(migraphx::op::round{}, input);
auto prog = optimize_onnx("round_test.onnx");
EXPECT(p == prog);
......@@ -1964,29 +2101,30 @@ TEST_CASE(round_test)
TEST_CASE(selu_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<std::size_t> lens = {2, 3};
migraphx::shape s{migraphx::shape::double_type, lens};
auto x = p.add_parameter("x", s);
auto x = mm->add_parameter("x", s);
migraphx::shape ls{migraphx::shape::double_type, {1}};
auto la = p.add_literal({ls, {0.3}});
auto lg = p.add_literal({ls, {0.25}});
auto mbla = p.add_instruction(migraphx::op::multibroadcast{lens}, la);
auto mblg = p.add_instruction(migraphx::op::multibroadcast{lens}, lg);
auto la = mm->add_literal({ls, {0.3}});
auto lg = mm->add_literal({ls, {0.25}});
auto mbla = mm->add_instruction(migraphx::op::multibroadcast{lens}, la);
auto mblg = mm->add_instruction(migraphx::op::multibroadcast{lens}, lg);
auto sign_x = p.add_instruction(migraphx::op::sign{}, x);
auto exp_x = p.add_instruction(migraphx::op::exp{}, x);
auto sign_x = mm->add_instruction(migraphx::op::sign{}, x);
auto exp_x = mm->add_instruction(migraphx::op::exp{}, x);
auto mlax = p.add_instruction(migraphx::op::mul{}, mbla, exp_x);
auto smlax = p.add_instruction(migraphx::op::sub{}, mlax, mbla);
auto mlax = mm->add_instruction(migraphx::op::mul{}, mbla, exp_x);
auto smlax = mm->add_instruction(migraphx::op::sub{}, mlax, mbla);
auto item1 = p.add_instruction(migraphx::op::add{}, smlax, x);
auto item2 = p.add_instruction(migraphx::op::sub{}, smlax, x);
auto item1 = mm->add_instruction(migraphx::op::add{}, smlax, x);
auto item2 = mm->add_instruction(migraphx::op::sub{}, smlax, x);
auto sitem2 = p.add_instruction(migraphx::op::mul{}, sign_x, item2);
auto item12 = p.add_instruction(migraphx::op::sub{}, item1, sitem2);
auto r = p.add_instruction(migraphx::op::mul{}, item12, mblg);
p.add_return({r});
auto sitem2 = mm->add_instruction(migraphx::op::mul{}, sign_x, item2);
auto item12 = mm->add_instruction(migraphx::op::sub{}, item1, sitem2);
auto r = mm->add_instruction(migraphx::op::mul{}, item12, mblg);
mm->add_return({r});
auto prog = migraphx::parse_onnx("selu_test.onnx");
......@@ -1996,10 +2134,11 @@ TEST_CASE(selu_test)
TEST_CASE(shape_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3, 4, 5, 6}};
auto l0 = p.add_parameter("x", s);
auto l0 = mm->add_parameter("x", s);
migraphx::shape s_shape{migraphx::shape::int64_type, {4}};
p.add_literal(s_shape, l0->get_shape().lens());
mm->add_literal(s_shape, l0->get_shape().lens());
auto prog = optimize_onnx("shape_test.onnx");
EXPECT(p == prog);
......@@ -2008,13 +2147,14 @@ TEST_CASE(shape_test)
TEST_CASE(shape_gather_test)
{
migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {7, 3, 10}});
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {7, 3, 10}});
migraphx::shape const_shape{migraphx::shape::int32_type, {1}};
auto l2 = p.add_literal(migraphx::literal{const_shape, {1}});
auto l2 = mm->add_literal(migraphx::literal{const_shape, {1}});
auto l1 =
p.add_literal(migraphx::shape{migraphx::shape::int64_type, {3}}, l0->get_shape().lens());
mm->add_literal(migraphx::shape{migraphx::shape::int64_type, {3}}, l0->get_shape().lens());
int axis = 0;
p.add_instruction(migraphx::op::gather{axis}, l1, l2);
mm->add_instruction(migraphx::op::gather{axis}, l1, l2);
auto prog = optimize_onnx("shape_gather_test.onnx");
EXPECT(p == prog);
......@@ -2023,8 +2163,9 @@ TEST_CASE(shape_gather_test)
TEST_CASE(sign_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::double_type, {10, 5}});
p.add_instruction(migraphx::op::sign{}, input);
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::double_type, {10, 5}});
mm->add_instruction(migraphx::op::sign{}, input);
auto prog = optimize_onnx("sign_test.onnx");
EXPECT(p == prog);
......@@ -2033,8 +2174,9 @@ TEST_CASE(sign_test)
TEST_CASE(sin_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
p.add_instruction(migraphx::op::sin{}, input);
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
mm->add_instruction(migraphx::op::sin{}, input);
auto prog = optimize_onnx("sin_test.onnx");
EXPECT(p == prog);
......@@ -2043,8 +2185,9 @@ TEST_CASE(sin_test)
TEST_CASE(sinh_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
p.add_instruction(migraphx::op::sinh{}, input);
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
mm->add_instruction(migraphx::op::sinh{}, input);
auto prog = optimize_onnx("sinh_test.onnx");
......@@ -2054,8 +2197,9 @@ TEST_CASE(sinh_test)
TEST_CASE(slice_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3, 2}});
p.add_instruction(migraphx::op::slice{{0, 1}, {1, 0}, {2, 2}}, l0);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3, 2}});
mm->add_instruction(migraphx::op::slice{{0, 1}, {1, 0}, {2, 2}}, l0);
auto prog = optimize_onnx("slice_test.onnx");
EXPECT(p == prog);
......@@ -2064,11 +2208,12 @@ TEST_CASE(slice_test)
TEST_CASE(slice_3arg_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 5}});
p.add_literal({{migraphx::shape::int32_type, {2}}, {0, 0}});
p.add_literal({{migraphx::shape::int32_type, {2}}, {2, 5}});
auto ret = p.add_instruction(migraphx::op::slice{{0, 1}, {0, 0}, {2, 5}}, l0);
p.add_return({ret});
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 5}});
mm->add_literal({{migraphx::shape::int32_type, {2}}, {0, 0}});
mm->add_literal({{migraphx::shape::int32_type, {2}}, {2, 5}});
auto ret = mm->add_instruction(migraphx::op::slice{{0, 1}, {0, 0}, {2, 5}}, l0);
mm->add_return({ret});
auto prog = migraphx::parse_onnx("slice_3arg_test.onnx");
......@@ -2078,13 +2223,14 @@ TEST_CASE(slice_3arg_test)
TEST_CASE(slice_5arg_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 5}});
p.add_literal({{migraphx::shape::int32_type, {2}}, {1, 1}});
p.add_literal({{migraphx::shape::int32_type, {2}}, {-1, -2}});
p.add_literal({{migraphx::shape::int32_type, {2}}, {-1, -1}});
p.add_literal({{migraphx::shape::int32_type, {2}}, {-5, -3}});
auto ret = p.add_instruction(migraphx::op::slice{{-1, -2}, {-5, -3}, {-1, -1}}, l0);
p.add_return({ret});
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 5}});
mm->add_literal({{migraphx::shape::int32_type, {2}}, {1, 1}});
mm->add_literal({{migraphx::shape::int32_type, {2}}, {-1, -2}});
mm->add_literal({{migraphx::shape::int32_type, {2}}, {-1, -1}});
mm->add_literal({{migraphx::shape::int32_type, {2}}, {-5, -3}});
auto ret = mm->add_instruction(migraphx::op::slice{{-1, -2}, {-5, -3}, {-1, -1}}, l0);
mm->add_return({ret});
auto prog = migraphx::parse_onnx("slice_5arg_test.onnx");
......@@ -2094,8 +2240,9 @@ TEST_CASE(slice_5arg_test)
TEST_CASE(slice_max_end_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {10, 20}});
p.add_instruction(migraphx::op::slice{{0, 1}, {1, 2}, {3000000000, -1}}, l0);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {10, 20}});
mm->add_instruction(migraphx::op::slice{{0, 1}, {1, 2}, {3000000000, -1}}, l0);
auto prog = optimize_onnx("slice_max_end_test.onnx");
EXPECT(p == prog);
......@@ -2104,8 +2251,9 @@ TEST_CASE(slice_max_end_test)
TEST_CASE(softmax_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3}});
p.add_instruction(migraphx::op::softmax{1}, l0);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3}});
mm->add_instruction(migraphx::op::softmax{1}, l0);
auto prog = optimize_onnx("softmax_test.onnx");
EXPECT(p == prog);
......@@ -2114,11 +2262,12 @@ TEST_CASE(softmax_test)
TEST_CASE(split_minus_axis_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10, 15}});
auto r1 = p.add_instruction(migraphx::op::slice{{-1}, {0}, {5}}, input);
auto r2 = p.add_instruction(migraphx::op::slice{{-1}, {5}, {10}}, input);
auto r3 = p.add_instruction(migraphx::op::slice{{-1}, {10}, {15}}, input);
p.add_return({r1, r2, r3});
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10, 15}});
auto r1 = mm->add_instruction(migraphx::op::slice{{-1}, {0}, {5}}, input);
auto r2 = mm->add_instruction(migraphx::op::slice{{-1}, {5}, {10}}, input);
auto r3 = mm->add_instruction(migraphx::op::slice{{-1}, {10}, {15}}, input);
mm->add_return({r1, r2, r3});
auto prog = migraphx::parse_onnx("split_minus_axis_test.onnx");
......@@ -2128,11 +2277,12 @@ TEST_CASE(split_minus_axis_test)
TEST_CASE(split_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10, 15}});
auto r1 = p.add_instruction(migraphx::op::slice{{1}, {0}, {7}}, input);
auto r2 = p.add_instruction(migraphx::op::slice{{1}, {7}, {11}}, input);
auto r3 = p.add_instruction(migraphx::op::slice{{1}, {11}, {15}}, input);
p.add_return({r1, r2, r3});
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10, 15}});
auto r1 = mm->add_instruction(migraphx::op::slice{{1}, {0}, {7}}, input);
auto r2 = mm->add_instruction(migraphx::op::slice{{1}, {7}, {11}}, input);
auto r3 = mm->add_instruction(migraphx::op::slice{{1}, {11}, {15}}, input);
mm->add_return({r1, r2, r3});
auto prog = migraphx::parse_onnx("split_test.onnx");
EXPECT(p == prog);
......@@ -2141,10 +2291,11 @@ TEST_CASE(split_test)
TEST_CASE(split_test_default)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10, 15}});
auto r1 = p.add_instruction(migraphx::op::slice{{0}, {0}, {5}}, input);
auto r2 = p.add_instruction(migraphx::op::slice{{0}, {5}, {10}}, input);
p.add_return({r1, r2});
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10, 15}});
auto r1 = mm->add_instruction(migraphx::op::slice{{0}, {0}, {5}}, input);
auto r2 = mm->add_instruction(migraphx::op::slice{{0}, {5}, {10}}, input);
mm->add_return({r1, r2});
auto prog = migraphx::parse_onnx("split_test_default.onnx");
EXPECT(p == prog);
......@@ -2153,8 +2304,9 @@ TEST_CASE(split_test_default)
TEST_CASE(sqrt_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10, 15}});
p.add_instruction(migraphx::op::sqrt{}, input);
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10, 15}});
mm->add_instruction(migraphx::op::sqrt{}, input);
auto prog = optimize_onnx("sqrt_test.onnx");
EXPECT(p == prog);
......@@ -2163,12 +2315,13 @@ TEST_CASE(sqrt_test)
TEST_CASE(squeeze_unsqueeze_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<int64_t> squeeze_axes{0, 2, 3, 5};
std::vector<int64_t> unsqueeze_axes{0, 1, 3, 5};
auto l0 =
p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 1, 1, 2, 1}});
auto l1 = p.add_instruction(migraphx::op::squeeze{squeeze_axes}, l0);
p.add_instruction(migraphx::op::unsqueeze{unsqueeze_axes}, l1);
mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 1, 1, 2, 1}});
auto l1 = mm->add_instruction(migraphx::op::squeeze{squeeze_axes}, l0);
mm->add_instruction(migraphx::op::unsqueeze{unsqueeze_axes}, l1);
auto prog = optimize_onnx("squeeze_unsqueeze_test.onnx");
EXPECT(p == prog);
......@@ -2177,10 +2330,11 @@ TEST_CASE(squeeze_unsqueeze_test)
TEST_CASE(sub_bcast_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}});
auto l2 = p.add_instruction(migraphx::op::broadcast{1, l0->get_shape().lens()}, l1);
p.add_instruction(migraphx::op::sub{}, l0, l2);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}});
auto l2 = mm->add_instruction(migraphx::op::broadcast{1, l0->get_shape().lens()}, l1);
mm->add_instruction(migraphx::op::sub{}, l0, l2);
auto prog = optimize_onnx("sub_bcast_test.onnx");
......@@ -2190,10 +2344,11 @@ TEST_CASE(sub_bcast_test)
TEST_CASE(sub_scalar_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {1}});
auto m1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::sub{}, l0, m1);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {1}});
auto m1 = mm->add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
mm->add_instruction(migraphx::op::sub{}, l0, m1);
auto prog = optimize_onnx("sub_scalar_test.onnx");
EXPECT(p == prog);
......@@ -2202,13 +2357,14 @@ TEST_CASE(sub_scalar_test)
TEST_CASE(sum_int_test)
{
migraphx::program p;
auto input0 = p.add_parameter("0", migraphx::shape{migraphx::shape::int16_type, {3}});
auto input1 = p.add_parameter("1", migraphx::shape{migraphx::shape::uint16_type, {3}});
auto input2 = p.add_parameter("2", migraphx::shape{migraphx::shape::uint32_type, {3}});
auto cin0 = p.add_instruction(migraphx::op::convert{migraphx::shape::uint32_type}, input0);
auto cin1 = p.add_instruction(migraphx::op::convert{migraphx::shape::uint32_type}, input1);
auto l0 = p.add_instruction(migraphx::op::add{}, cin0, cin1);
p.add_instruction(migraphx::op::add{}, l0, input2);
auto* mm = p.get_main_module();
auto input0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::int16_type, {3}});
auto input1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::uint16_type, {3}});
auto input2 = mm->add_parameter("2", migraphx::shape{migraphx::shape::uint32_type, {3}});
auto cin0 = mm->add_instruction(migraphx::op::convert{migraphx::shape::uint32_type}, input0);
auto cin1 = mm->add_instruction(migraphx::op::convert{migraphx::shape::uint32_type}, input1);
auto l0 = mm->add_instruction(migraphx::op::add{}, cin0, cin1);
mm->add_instruction(migraphx::op::add{}, l0, input2);
auto prog = optimize_onnx("sum_int_test.onnx");
EXPECT(p == prog);
......@@ -2217,11 +2373,12 @@ TEST_CASE(sum_int_test)
TEST_CASE(sum_test)
{
migraphx::program p;
auto input0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}});
auto input1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3}});
auto input2 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {3}});
auto l0 = p.add_instruction(migraphx::op::add{}, input0, input1);
p.add_instruction(migraphx::op::add{}, l0, input2);
auto* mm = p.get_main_module();
auto input0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}});
auto input1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3}});
auto input2 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {3}});
auto l0 = mm->add_instruction(migraphx::op::add{}, input0, input1);
mm->add_instruction(migraphx::op::add{}, l0, input2);
auto prog = optimize_onnx("sum_test.onnx");
EXPECT(p == prog);
......@@ -2230,31 +2387,33 @@ TEST_CASE(sum_test)
TEST_CASE(sum_type_test)
{
migraphx::program p;
auto l_bool = p.add_literal({migraphx::shape{migraphx::shape::bool_type, {2}}, {1, 0}});
auto l_int8 = p.add_literal({migraphx::shape{migraphx::shape::int8_type, {2}}, {1, 1}});
auto l_uint8 = p.add_literal({migraphx::shape{migraphx::shape::uint8_type, {2}}, {1, 1}});
auto l_uint16 = p.add_literal({migraphx::shape{migraphx::shape::uint16_type, {2}}, {1, 1}});
auto l_uint32 = p.add_literal({migraphx::shape{migraphx::shape::uint32_type, {2}}, {1, 1}});
auto l_uint64 = p.add_literal({migraphx::shape{migraphx::shape::uint64_type, {2}}, {1, 1}});
auto l_double = p.add_literal({migraphx::shape{migraphx::shape::double_type, {2}}, {1, 1}});
auto l_raw = p.add_literal({migraphx::shape{migraphx::shape::double_type, {2}}, {1.5, 2.0}});
auto o_bool = p.add_instruction(migraphx::op::convert{migraphx::shape::double_type}, l_bool);
auto o_int8 = p.add_instruction(migraphx::op::convert{migraphx::shape::double_type}, l_int8);
auto o_uint8 = p.add_instruction(migraphx::op::convert{migraphx::shape::double_type}, l_uint8);
auto* mm = p.get_main_module();
auto l_bool = mm->add_literal({migraphx::shape{migraphx::shape::bool_type, {2}}, {1, 0}});
auto l_int8 = mm->add_literal({migraphx::shape{migraphx::shape::int8_type, {2}}, {1, 1}});
auto l_uint8 = mm->add_literal({migraphx::shape{migraphx::shape::uint8_type, {2}}, {1, 1}});
auto l_uint16 = mm->add_literal({migraphx::shape{migraphx::shape::uint16_type, {2}}, {1, 1}});
auto l_uint32 = mm->add_literal({migraphx::shape{migraphx::shape::uint32_type, {2}}, {1, 1}});
auto l_uint64 = mm->add_literal({migraphx::shape{migraphx::shape::uint64_type, {2}}, {1, 1}});
auto l_double = mm->add_literal({migraphx::shape{migraphx::shape::double_type, {2}}, {1, 1}});
auto l_raw = mm->add_literal({migraphx::shape{migraphx::shape::double_type, {2}}, {1.5, 2.0}});
auto o_bool = mm->add_instruction(migraphx::op::convert{migraphx::shape::double_type}, l_bool);
auto o_int8 = mm->add_instruction(migraphx::op::convert{migraphx::shape::double_type}, l_int8);
auto o_uint8 =
mm->add_instruction(migraphx::op::convert{migraphx::shape::double_type}, l_uint8);
auto o_uint16 =
p.add_instruction(migraphx::op::convert{migraphx::shape::double_type}, l_uint16);
mm->add_instruction(migraphx::op::convert{migraphx::shape::double_type}, l_uint16);
auto o_uint32 =
p.add_instruction(migraphx::op::convert{migraphx::shape::double_type}, l_uint32);
mm->add_instruction(migraphx::op::convert{migraphx::shape::double_type}, l_uint32);
auto o_uint64 =
p.add_instruction(migraphx::op::convert{migraphx::shape::double_type}, l_uint64);
auto s0 = p.add_instruction(migraphx::op::add{}, o_bool, o_int8);
auto s1 = p.add_instruction(migraphx::op::add{}, s0, o_uint8);
auto s2 = p.add_instruction(migraphx::op::add{}, s1, o_uint16);
auto s3 = p.add_instruction(migraphx::op::add{}, s2, o_uint32);
auto s4 = p.add_instruction(migraphx::op::add{}, s3, o_uint64);
auto s5 = p.add_instruction(migraphx::op::add{}, s4, l_double);
auto s6 = p.add_instruction(migraphx::op::add{}, s5, l_raw);
p.add_return({s6});
mm->add_instruction(migraphx::op::convert{migraphx::shape::double_type}, l_uint64);
auto s0 = mm->add_instruction(migraphx::op::add{}, o_bool, o_int8);
auto s1 = mm->add_instruction(migraphx::op::add{}, s0, o_uint8);
auto s2 = mm->add_instruction(migraphx::op::add{}, s1, o_uint16);
auto s3 = mm->add_instruction(migraphx::op::add{}, s2, o_uint32);
auto s4 = mm->add_instruction(migraphx::op::add{}, s3, o_uint64);
auto s5 = mm->add_instruction(migraphx::op::add{}, s4, l_double);
auto s6 = mm->add_instruction(migraphx::op::add{}, s5, l_raw);
mm->add_return({s6});
auto prog = migraphx::parse_onnx("sum_type_test.onnx");
......@@ -2264,8 +2423,9 @@ TEST_CASE(sum_type_test)
TEST_CASE(tan_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
p.add_instruction(migraphx::op::tan{}, input);
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
mm->add_instruction(migraphx::op::tan{}, input);
auto prog = optimize_onnx("tan_test.onnx");
EXPECT(p == prog);
......@@ -2274,8 +2434,9 @@ TEST_CASE(tan_test)
TEST_CASE(tanh_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1}});
p.add_instruction(migraphx::op::tanh{}, input);
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1}});
mm->add_instruction(migraphx::op::tanh{}, input);
auto prog = optimize_onnx("tanh_test.onnx");
......@@ -2285,9 +2446,10 @@ TEST_CASE(tanh_test)
TEST_CASE(tile_test)
{
migraphx::program p;
p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int64_type, {2}}, {1, 2}});
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
p.add_instruction(migraphx::op::concat{1}, input, input);
auto* mm = p.get_main_module();
mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int64_type, {2}}, {1, 2}});
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
mm->add_instruction(migraphx::op::concat{1}, input, input);
auto prog = optimize_onnx("tile_test.onnx");
......@@ -2297,11 +2459,12 @@ TEST_CASE(tile_test)
TEST_CASE(tile_test_3x2)
{
migraphx::program p;
p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int64_type, {2}}, {3, 2}});
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto l0 = p.add_instruction(migraphx::op::concat{0}, input, input);
auto l1 = p.add_instruction(migraphx::op::concat{0}, l0, input);
p.add_instruction(migraphx::op::concat{1}, l1, l1);
auto* mm = p.get_main_module();
mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int64_type, {2}}, {3, 2}});
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto l0 = mm->add_instruction(migraphx::op::concat{0}, input, input);
auto l1 = mm->add_instruction(migraphx::op::concat{0}, l0, input);
mm->add_instruction(migraphx::op::concat{1}, l1, l1);
auto prog = optimize_onnx("tile_test_3x2.onnx");
......@@ -2311,9 +2474,10 @@ TEST_CASE(tile_test_3x2)
TEST_CASE(transpose_test)
{
migraphx::program p;
auto input = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
auto* mm = p.get_main_module();
auto input = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
std::vector<int64_t> perm{0, 3, 1, 2};
p.add_instruction(migraphx::op::transpose{perm}, input);
mm->add_instruction(migraphx::op::transpose{perm}, input);
auto prog = optimize_onnx("transpose_test.onnx");
......@@ -2323,36 +2487,39 @@ TEST_CASE(transpose_test)
TEST_CASE(transpose_gather_test)
{
migraphx::program p;
auto make_contiguous = [&p](migraphx::instruction_ref ins) {
auto* mm = p.get_main_module();
auto make_contiguous = [&mm](migraphx::instruction_ref ins) {
if(ins->get_shape().standard())
{
return ins;
}
return p.add_instruction(migraphx::op::contiguous{}, ins);
return mm->add_instruction(migraphx::op::contiguous{}, ins);
};
auto data = p.add_parameter("data", migraphx::shape{migraphx::shape::float_type, {3, 5, 4, 6}});
auto data =
mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {3, 5, 4, 6}});
auto ind =
p.add_parameter("indices", migraphx::shape{migraphx::shape::int32_type, {2, 4, 3, 5}});
auto tr_data = p.add_instruction(migraphx::op::transpose{{0, 2, 1, 3}}, data);
auto tr_ind = p.add_instruction(migraphx::op::transpose{{0, 2, 1, 3}}, ind);
mm->add_parameter("indices", migraphx::shape{migraphx::shape::int32_type, {2, 4, 3, 5}});
auto tr_data = mm->add_instruction(migraphx::op::transpose{{0, 2, 1, 3}}, data);
auto tr_ind = mm->add_instruction(migraphx::op::transpose{{0, 2, 1, 3}}, ind);
int axis = 1;
p.add_instruction(
mm->add_instruction(
migraphx::op::gather{axis}, make_contiguous(tr_data), make_contiguous(tr_ind));
auto prog = optimize_onnx("transpose_gather_test.onnx");
EXPECT(p.sort() == prog.sort());
EXPECT(mm->sort() == prog.sort());
}
TEST_CASE(undefined_test)
{
migraphx::program p;
p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_instruction(migraphx::op::undefined{});
auto l2 = p.add_instruction(migraphx::op::identity{}, l1);
p.add_return({l2});
auto* mm = p.get_main_module();
mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = mm->add_instruction(migraphx::op::undefined{});
auto l2 = mm->add_instruction(migraphx::op::identity{}, l1);
mm->add_return({l2});
auto prog = migraphx::parse_onnx("undefined_test.onnx");
......@@ -2362,10 +2529,11 @@ TEST_CASE(undefined_test)
TEST_CASE(unknown_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}});
auto l2 = p.add_instruction(migraphx::op::unknown{"Unknown"}, l0, l1);
p.add_instruction(migraphx::op::unknown{"Unknown"}, l2);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}});
auto l2 = mm->add_instruction(migraphx::op::unknown{"Unknown"}, l0, l1);
mm->add_instruction(migraphx::op::unknown{"Unknown"}, l2);
auto prog = optimize_onnx("unknown_test.onnx");
EXPECT(p == prog);
......@@ -2384,19 +2552,20 @@ TEST_CASE(unknown_test_throw)
TEST_CASE(upsample_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape ss{migraphx::shape::float_type, {4}};
p.add_literal(migraphx::literal(ss, {1.0f, 1.0f, 2.0f, 3.0f}));
mm->add_literal(migraphx::literal(ss, {1.0f, 1.0f, 2.0f, 3.0f}));
migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 2}};
auto ix = p.add_parameter("X", sx);
auto ix = mm->add_parameter("X", sx);
migraphx::shape si{migraphx::shape::int32_type, {1, 1, 4, 6}};
std::vector<int> ind = {0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3};
auto li = p.add_literal(migraphx::literal(si, ind));
auto rsp = p.add_instruction(migraphx::op::reshape{{4}}, ix);
auto r = p.add_instruction(migraphx::op::gather{0}, rsp, li);
p.add_return({r});
auto li = mm->add_literal(migraphx::literal(si, ind));
auto rsp = mm->add_instruction(migraphx::op::reshape{{4}}, ix);
auto r = mm->add_instruction(migraphx::op::gather{0}, rsp, li);
mm->add_return({r});
auto prog = migraphx::parse_onnx("upsample_test.onnx");
......@@ -2413,8 +2582,9 @@ TEST_CASE(unknown_test_throw_print_error)
TEST_CASE(variable_batch_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
p.add_instruction(migraphx::op::identity{}, l0);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
mm->add_instruction(migraphx::op::identity{}, l0);
auto prog = optimize_onnx("variable_batch_test.onnx");
EXPECT(p == prog);
......@@ -2423,9 +2593,10 @@ TEST_CASE(variable_batch_test)
TEST_CASE(variable_batch_user_input_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 16, 16}});
auto r = p.add_instruction(migraphx::op::identity{}, l0);
p.add_return({r});
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 16, 16}});
auto r = mm->add_instruction(migraphx::op::identity{}, l0);
mm->add_return({r});
migraphx::onnx_options options;
options.default_dim_value = 2;
......@@ -2438,9 +2609,10 @@ TEST_CASE(variable_batch_user_input_test)
TEST_CASE(variable_batch_leq_zero_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
p.add_instruction(migraphx::op::add{}, l0, l1);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
mm->add_instruction(migraphx::op::add{}, l0, l1);
auto prog = optimize_onnx("variable_batch_leq_zero_test.onnx");
EXPECT(p == prog);
......@@ -2449,18 +2621,19 @@ TEST_CASE(variable_batch_leq_zero_test)
TEST_CASE(where_test)
{
migraphx::program p;
auto lc = p.add_parameter("c", migraphx::shape{migraphx::shape::bool_type, {2}});
auto lx = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2, 2}});
auto ly = p.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 1, 2, 2}});
auto lcc = p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, lc);
auto lxm = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 2, 2}}, lx);
auto lym = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 2, 2}}, ly);
auto lxy = p.add_instruction(migraphx::op::sub{}, lxm, lym);
auto lccm = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 2, 2}}, lcc);
auto lm = p.add_instruction(migraphx::op::mul{}, lxy, lccm);
auto lym1 = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 2, 2}}, ly);
auto r = p.add_instruction(migraphx::op::add{}, lm, lym1);
p.add_return({r});
auto* mm = p.get_main_module();
auto lc = mm->add_parameter("c", migraphx::shape{migraphx::shape::bool_type, {2}});
auto lx = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2, 2}});
auto ly = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 1, 2, 2}});
auto lcc = mm->add_instruction(migraphx::op::convert{migraphx::shape::float_type}, lc);
auto lxm = mm->add_instruction(migraphx::op::multibroadcast{{2, 2, 2, 2}}, lx);
auto lym = mm->add_instruction(migraphx::op::multibroadcast{{2, 2, 2, 2}}, ly);
auto lxy = mm->add_instruction(migraphx::op::sub{}, lxm, lym);
auto lccm = mm->add_instruction(migraphx::op::multibroadcast{{2, 2, 2, 2}}, lcc);
auto lm = mm->add_instruction(migraphx::op::mul{}, lxy, lccm);
auto lym1 = mm->add_instruction(migraphx::op::multibroadcast{{2, 2, 2, 2}}, ly);
auto r = mm->add_instruction(migraphx::op::add{}, lm, lym1);
mm->add_return({r});
auto prog = migraphx::parse_onnx("where_test.onnx");
......
......@@ -53,6 +53,14 @@ def test_add_scalar():
print(r)
def test_module():
p = migraphx.parse_onnx("add_scalar_test.onnx")
mm = p.get_main_module()
p.print()
mm.print()
test_conv_relu()
test_module()
if sys.version_info >= (3, 0):
test_add_scalar()
......@@ -167,7 +167,9 @@ std::vector<argument> run(program& p, const program::parameter_map& params)
std::vector<shape> get_output_shapes(program& p) { return p.get_output_shapes(); }
void print(const program& p) { std::cout << p << std::endl; }
void print_program(const program& p) { std::cout << p << std::endl; }
void print_module(const module& m) { std::cout << m << std::endl; }
} // namespace migraphx
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment