Commit bc5d7f75 authored by Paul's avatar Paul
Browse files

Merge from develop

parents 47c0854d a5b0afa0
 sinh-example:;
xy"Sinh test_sinhZ
x


b
y


B
\ No newline at end of file
#include <iostream>
#include <vector>
#include <migraph/literal.hpp>
#include <migraph/operators.hpp>
#include <migraph/program.hpp>
#include <migraph/instruction.hpp>
#include <migraph/onnx.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/onnx.hpp>
#include "test.hpp"
void pytorch_conv_bias_test()
TEST_CASE(pytorch_conv_bias_test)
{
migraph::program p;
auto l0 = p.add_parameter("0", {migraph::shape::float_type, {1, 3, 32, 32}});
auto l1 = p.add_parameter("1", {migraph::shape::float_type, {1, 3, 5, 5}});
auto l2 = p.add_parameter("2", {migraph::shape::float_type, {1}});
migraphx::program p;
auto l0 = p.add_parameter("0", {migraphx::shape::float_type, {1, 3, 32, 32}});
auto l1 = p.add_parameter("1", {migraphx::shape::float_type, {1, 3, 5, 5}});
auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {1}});
uint64_t axis = 1;
auto l3 = p.add_instruction(migraph::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraph::op::broadcast{axis, l3->get_shape()}, l2);
p.add_instruction(migraph::op::add{}, l3, l4);
auto l3 = p.add_instruction(migraphx::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape()}, l2);
p.add_instruction(migraphx::op::add{}, l3, l4);
auto prog = migraph::parse_onnx("conv.onnx");
auto prog = migraphx::parse_onnx("conv.onnx");
EXPECT(p == prog);
}
void pytorch_conv_relu_maxpool()
TEST_CASE(pytorch_conv_relu_maxpool)
{
migraph::program p;
auto l0 = p.add_parameter("0", {migraph::shape::float_type, {1, 3, 32, 32}});
auto l1 = p.add_parameter("1", {migraph::shape::float_type, {1, 3, 5, 5}});
auto l2 = p.add_parameter("2", {migraph::shape::float_type, {1}});
migraphx::program p;
auto l0 = p.add_parameter("0", {migraphx::shape::float_type, {1, 3, 32, 32}});
auto l1 = p.add_parameter("1", {migraphx::shape::float_type, {1, 3, 5, 5}});
auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {1}});
uint64_t axis = 1;
auto l3 = p.add_instruction(migraph::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraph::op::broadcast{axis, l3->get_shape()}, l2);
auto l5 = p.add_instruction(migraph::op::add{}, l3, l4);
auto l6 = p.add_instruction(migraph::op::relu{}, l5);
p.add_instruction(migraph::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l6);
auto l3 = p.add_instruction(migraphx::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape()}, l2);
auto l5 = p.add_instruction(migraphx::op::add{}, l3, l4);
auto l6 = p.add_instruction(migraphx::op::relu{}, l5);
p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l6);
auto prog = migraph::parse_onnx("conv_relu_maxpool.onnx");
auto prog = migraphx::parse_onnx("conv_relu_maxpool.onnx");
EXPECT(p == prog);
}
void pytorch_conv_bn_relu_maxpool()
TEST_CASE(pytorch_conv_bn_relu_maxpool)
{
migraph::program p;
auto l0 = p.add_parameter("0", {migraph::shape::float_type, {1, 3, 32, 32}});
auto l1 = p.add_parameter("1", {migraph::shape::float_type, {1, 3, 5, 5}});
auto l2 = p.add_parameter("2", {migraph::shape::float_type, {1}});
migraphx::program p;
auto l0 = p.add_parameter("0", {migraphx::shape::float_type, {1, 3, 32, 32}});
auto l1 = p.add_parameter("1", {migraphx::shape::float_type, {1, 3, 5, 5}});
auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {1}});
auto p3 = p.add_parameter("3", {migraph::shape::float_type, {1}});
auto p4 = p.add_parameter("4", {migraph::shape::float_type, {1}});
auto p5 = p.add_parameter("5", {migraph::shape::float_type, {1}});
auto p6 = p.add_parameter("6", {migraph::shape::float_type, {1}});
auto p3 = p.add_parameter("3", {migraphx::shape::float_type, {1}});
auto p4 = p.add_parameter("4", {migraphx::shape::float_type, {1}});
auto p5 = p.add_parameter("5", {migraphx::shape::float_type, {1}});
auto p6 = p.add_parameter("6", {migraphx::shape::float_type, {1}});
uint64_t axis = 1;
auto l3 = p.add_instruction(migraph::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraph::op::broadcast{axis, l3->get_shape()}, l2);
auto l5 = p.add_instruction(migraph::op::add{}, l3, l4);
auto l6 = p.add_instruction(migraph::op::batch_norm_inference{1.0e-5f}, l5, p3, p4, p5, p6);
auto l7 = p.add_instruction(migraph::op::relu{}, l6);
p.add_instruction(migraph::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l7);
auto l3 = p.add_instruction(migraphx::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape()}, l2);
auto l5 = p.add_instruction(migraphx::op::add{}, l3, l4);
auto l6 = p.add_instruction(migraphx::op::batch_norm_inference{1.0e-5f}, l5, p3, p4, p5, p6);
auto l7 = p.add_instruction(migraphx::op::relu{}, l6);
p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l7);
auto prog = migraph::parse_onnx("conv_bn_relu_maxpool.onnx");
auto prog = migraphx::parse_onnx("conv_bn_relu_maxpool.onnx");
EXPECT(p == prog);
}
void pytorch_conv_relu_maxpool_x2()
TEST_CASE(pytorch_conv_relu_maxpool_x2)
{
migraph::program p;
auto l0 = p.add_parameter("0", {migraph::shape::float_type, {1, 3, 32, 32}});
auto l1 = p.add_parameter("1", {migraph::shape::float_type, {5, 3, 5, 5}});
auto l2 = p.add_parameter("2", {migraph::shape::float_type, {5}});
migraphx::program p;
auto l0 = p.add_parameter("0", {migraphx::shape::float_type, {1, 3, 32, 32}});
auto l1 = p.add_parameter("1", {migraphx::shape::float_type, {5, 3, 5, 5}});
auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {5}});
uint64_t axis = 1;
auto l3 = p.add_instruction(migraph::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraph::op::broadcast{axis, l3->get_shape()}, l2);
auto l5 = p.add_instruction(migraph::op::add{}, l3, l4);
auto l6 = p.add_instruction(migraph::op::relu{}, l5);
auto l7 = p.add_instruction(migraph::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l6);
auto l3 = p.add_instruction(migraphx::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape()}, l2);
auto l5 = p.add_instruction(migraphx::op::add{}, l3, l4);
auto l6 = p.add_instruction(migraphx::op::relu{}, l5);
auto l7 = p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l6);
auto l8 = p.add_parameter("3", {migraph::shape::float_type, {1, 5, 5, 5}});
auto l9 = p.add_parameter("4", {migraph::shape::float_type, {1}});
auto l10 = p.add_instruction(migraph::op::convolution{}, l7, l8);
auto l11 = p.add_instruction(migraph::op::broadcast{axis, l10->get_shape()}, l9);
auto l12 = p.add_instruction(migraph::op::add{}, l10, l11);
auto l13 = p.add_instruction(migraph::op::relu{}, l12);
p.add_instruction(migraph::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l13);
auto l8 = p.add_parameter("3", {migraphx::shape::float_type, {1, 5, 5, 5}});
auto l9 = p.add_parameter("4", {migraphx::shape::float_type, {1}});
auto l10 = p.add_instruction(migraphx::op::convolution{}, l7, l8);
auto l11 = p.add_instruction(migraphx::op::broadcast{axis, l10->get_shape()}, l9);
auto l12 = p.add_instruction(migraphx::op::add{}, l10, l11);
auto l13 = p.add_instruction(migraphx::op::relu{}, l12);
p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l13);
auto prog = migraph::parse_onnx("conv_relu_maxpoolX2.onnx");
auto prog = migraphx::parse_onnx("conv_relu_maxpoolX2.onnx");
EXPECT(p == prog);
}
void leaky_relu_test()
TEST_CASE(leaky_relu_test)
{
migraph::program p;
migraphx::program p;
float alpha = 0.01f;
auto l0 = p.add_parameter("0", {migraph::shape::float_type, {3}});
p.add_instruction(migraph::op::leaky_relu{alpha}, l0);
auto l0 = p.add_parameter("0", {migraphx::shape::float_type, {3}});
p.add_instruction(migraphx::op::leaky_relu{alpha}, l0);
auto prog = migraph::parse_onnx("leaky_relu.onnx");
auto prog = migraphx::parse_onnx("leaky_relu.onnx");
EXPECT(p == prog);
}
void imagescaler_test()
TEST_CASE(imagescaler_test)
{
migraph::program p;
migraph::shape s{migraph::shape::float_type, {1, 3, 16, 16}};
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {1, 3, 16, 16}};
auto l0 = p.add_parameter("0", s);
auto scale_val = p.add_literal(0.5f);
auto bias_vals = p.add_literal(
migraph::literal{migraph::shape{migraph::shape::float_type, {3}}, {0.01, 0.02, 0.03}});
auto scaled_tensor = p.add_instruction(migraph::op::scalar{s}, scale_val);
auto img_scaled = p.add_instruction(migraph::op::mul{}, l0, scaled_tensor);
auto bias_bcast = p.add_instruction(migraph::op::broadcast{1, s}, bias_vals);
p.add_instruction(migraph::op::add{}, img_scaled, bias_bcast);
migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3}}, {0.01, 0.02, 0.03}});
auto scaled_tensor = p.add_instruction(migraphx::op::scalar{s}, scale_val);
auto img_scaled = p.add_instruction(migraphx::op::mul{}, l0, scaled_tensor);
auto bias_bcast = p.add_instruction(migraphx::op::broadcast{1, s}, bias_vals);
p.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast);
auto prog = migraph::parse_onnx("imagescaler_test.onnx");
auto prog = migraphx::parse_onnx("imagescaler_test.onnx");
EXPECT(p == prog);
}
void globalavgpool_test()
TEST_CASE(globalavgpool_test)
{
migraph::program p;
auto input = p.add_parameter("0", migraph::shape{migraph::shape::float_type, {1, 3, 16, 16}});
auto op = migraph::op::pooling{"average"};
migraphx::program p;
auto input = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
auto op = migraphx::op::pooling{"average"};
auto lens = input->get_shape().lens();
op.lengths = {lens[2], lens[3]};
p.add_instruction(op, input);
auto prog = migraph::parse_onnx("globalavgpool_test.onnx");
auto prog = migraphx::parse_onnx("globalavgpool_test.onnx");
EXPECT(p == prog);
}
void globalmaxpool_test()
TEST_CASE(globalmaxpool_test)
{
migraph::program p;
auto input = p.add_parameter("0", migraph::shape{migraph::shape::float_type, {1, 3, 16, 16}});
auto op = migraph::op::pooling{"max"};
migraphx::program p;
auto input = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
auto op = migraphx::op::pooling{"max"};
auto lens = input->get_shape().lens();
op.lengths = {lens[2], lens[3]};
p.add_instruction(op, input);
auto prog = migraph::parse_onnx("globalmaxpool_test.onnx");
auto prog = migraphx::parse_onnx("globalmaxpool_test.onnx");
EXPECT(p == prog);
}
void transpose_test()
TEST_CASE(transpose_test)
{
migraph::program p;
auto input = p.add_parameter("0", migraph::shape{migraph::shape::float_type, {1, 2, 2, 3}});
migraphx::program p;
auto input = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
std::vector<int64_t> perm{0, 3, 1, 2};
p.add_instruction(migraph::op::transpose{perm}, input);
p.add_instruction(migraphx::op::transpose{perm}, input);
auto prog = migraph::parse_onnx("transpose_test.onnx");
auto prog = migraphx::parse_onnx("transpose_test.onnx");
EXPECT(p == prog);
}
void dropout_test()
TEST_CASE(dropout_test)
{
migraph::program p;
auto input = p.add_parameter("0", migraph::shape{migraph::shape::float_type, {1, 3, 2, 2}});
p.add_instruction(migraph::op::identity{}, input);
migraphx::program p;
auto input = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 2, 2}});
p.add_instruction(migraphx::op::identity{}, input);
auto prog = migraph::parse_onnx("dropout_test.onnx");
auto prog = migraphx::parse_onnx("dropout_test.onnx");
EXPECT(p == prog);
}
int main()
TEST_CASE(sum_test)
{
pytorch_conv_bias_test();
pytorch_conv_relu_maxpool();
pytorch_conv_bn_relu_maxpool();
pytorch_conv_relu_maxpool_x2();
leaky_relu_test();
imagescaler_test();
globalavgpool_test();
globalmaxpool_test();
transpose_test();
dropout_test();
migraphx::program p;
auto input0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}});
auto input1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3}});
auto input2 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {3}});
auto l0 = p.add_instruction(migraphx::op::add{}, input0, input1);
p.add_instruction(migraphx::op::add{}, l0, input2);
auto prog = migraphx::parse_onnx("sum_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(exp_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
p.add_instruction(migraphx::op::exp{}, input);
auto prog = migraphx::parse_onnx("exp_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(log_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
p.add_instruction(migraphx::op::log{}, input);
auto prog = migraphx::parse_onnx("log_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(sin_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
p.add_instruction(migraphx::op::sin{}, input);
auto prog = migraphx::parse_onnx("sin_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(cos_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
p.add_instruction(migraphx::op::cos{}, input);
auto prog = migraphx::parse_onnx("cos_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(tan_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
p.add_instruction(migraphx::op::tan{}, input);
auto prog = migraphx::parse_onnx("tan_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(sinh_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
p.add_instruction(migraphx::op::sinh{}, input);
auto prog = migraphx::parse_onnx("sinh_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(cosh_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1}});
p.add_instruction(migraphx::op::cosh{}, input);
auto prog = migraphx::parse_onnx("cosh_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(tanh_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1}});
p.add_instruction(migraphx::op::tanh{}, input);
auto prog = migraphx::parse_onnx("tanh_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(elu_test)
{
migraphx::program p;
auto input = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}});
p.add_instruction(migraphx::op::elu{0.01}, input);
auto prog = migraphx::parse_onnx("elu_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(asin_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
p.add_instruction(migraphx::op::asin{}, input);
auto prog = migraphx::parse_onnx("asin_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(max_test)
{
migraphx::program p;
auto input0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}});
auto input1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3}});
auto input2 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {3}});
auto l0 = p.add_instruction(migraphx::op::max{}, input0, input1);
p.add_instruction(migraphx::op::max{}, l0, input2);
migraphx::parse_onnx("max_test.onnx");
}
TEST_CASE(acos_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
p.add_instruction(migraphx::op::acos{}, input);
auto prog = migraphx::parse_onnx("acos_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(min_test)
{
migraphx::program p;
auto input0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}});
auto input1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3}});
auto input2 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {3}});
auto l0 = p.add_instruction(migraphx::op::min{}, input0, input1);
p.add_instruction(migraphx::op::min{}, l0, input2);
migraphx::parse_onnx("min_test.onnx");
}
TEST_CASE(atan_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
p.add_instruction(migraphx::op::atan{}, input);
auto prog = migraphx::parse_onnx("atan_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(add_bcast_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}});
auto l2 = p.add_instruction(migraphx::op::broadcast{1, l0->get_shape()}, l1);
p.add_instruction(migraphx::op::add{}, l0, l2);
auto prog = migraphx::parse_onnx("add_bcast_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(implicit_add_bcast_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}});
auto l2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0);
auto l3 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::add{}, l2, l3);
auto prog = migraphx::parse_onnx("implicit_bcast_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(sub_bcast_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}});
auto l2 = p.add_instruction(migraphx::op::broadcast{1, l0->get_shape()}, l1);
p.add_instruction(migraphx::op::sub{}, l0, l2);
auto prog = migraphx::parse_onnx("sub_bcast_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(implicit_sub_bcast_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}});
auto l2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0);
auto l3 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::sub{}, l2, l3);
auto prog = migraphx::parse_onnx("implicit_sub_bcast_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(unknown_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}});
auto l2 = p.add_instruction(migraphx::unknown{"Unknown"}, l0, l1);
p.add_instruction(migraphx::unknown{"Unknown"}, l2);
auto prog = migraphx::parse_onnx("unknown_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(softmax_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3}});
auto r = p.add_instruction(migraphx::op::reshape{{1, 3, 1, 1}}, l0);
auto s = p.add_instruction(migraphx::op::softmax{}, r);
p.add_instruction(migraphx::op::reshape{{1, 3}}, s);
auto prog = migraphx::parse_onnx("softmax_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(reshape_test)
{
migraphx::program p;
migraphx::op::reshape op;
std::vector<int64_t> reshape_dims{3, 8};
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 2, 3}});
p.add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::int64_type, {2}}, reshape_dims});
op.dims = reshape_dims;
p.add_instruction(op, l0);
p.add_instruction(op, l0);
auto prog = migraphx::parse_onnx("reshape_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(shape_test)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 4, 5, 6}};
auto l0 = p.add_parameter("x", s);
migraphx::shape s_shape{migraphx::shape::int64_type, {4}};
p.add_literal(s_shape, l0->get_shape().lens());
auto prog = migraphx::parse_onnx("shape_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(gather_test)
{
migraphx::program p;
auto l0 = p.add_parameter("data", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto l1 = p.add_parameter("indices", migraphx::shape{migraphx::shape::int32_type, {2, 3}});
int axis = 1;
p.add_instruction(migraphx::op::gather{axis}, l0, l1);
auto prog = migraphx::parse_onnx("gather_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(shape_gather_test)
{
migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {7, 3, 10}});
auto l1 =
p.add_literal(migraphx::shape{migraphx::shape::int64_type, {3}}, l0->get_shape().lens());
migraphx::shape const_shape{migraphx::shape::int32_type, {1}};
auto l2 = p.add_literal(migraphx::literal{const_shape, {1}});
int axis = 0;
p.add_instruction(migraphx::op::gather{axis}, l1, l2);
auto prog = migraphx::parse_onnx("shape_gather.onnx");
EXPECT(p == prog);
}
TEST_CASE(rnn_test)
{
std::size_t sl = 5; // sequence len
std::size_t bs = 3; // batch size
std::size_t hs = 20; // hidden size
std::size_t is = 10; // input size
std::size_t nd = 2; // num directions
float clip = 0.0f;
// bidirectional
{
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, hs, is}});
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 2 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_rnn_bi.onnx");
EXPECT(p == prog);
}
// forward
{
nd = 1;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, hs, is}});
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 2 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_rnn_forward.onnx");
EXPECT(p == prog);
}
// reverse
{
nd = 1;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, hs, is}});
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 2 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_rnn_reverse.onnx");
EXPECT(p == prog);
}
// 3 argumments
{
nd = 1;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, hs, is}});
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, hs, hs}});
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
r,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_rnn_3args.onnx");
EXPECT(p == prog);
}
// 5 argumments
{
nd = 1;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, hs, is}});
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 2 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
r,
bias,
seq_len,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_rnn_5args.onnx");
EXPECT(p == prog);
}
}
TEST_CASE(gru_test)
{
std::size_t sl = 5; // sequence len
std::size_t bs = 3; // batch size
std::size_t hs = 20; // hidden size
std::size_t is = 10; // input size
std::size_t nd = 2; // num directions
float clip = 0.0f;
// forward
{
nd = 1;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::forward,
clip,
1},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_forward.onnx");
EXPECT(p == prog);
}
// reverse
{
nd = 1;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_reverse.onnx");
EXPECT(p == prog);
}
// bidirectional
{
nd = 2;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::relu{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi.onnx");
EXPECT(p == prog);
}
}
TEST_CASE(gru_test_args)
{
std::size_t sl = 5; // sequence len
std::size_t bs = 3; // batch size
std::size_t hs = 20; // hidden size
std::size_t is = 10; // input size
std::size_t nd = 2; // num directions
float clip = 0.0f;
// 3 arguments
{
nd = 1;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_3arg.onnx");
EXPECT(p == prog);
}
// 4 arguments
{
nd = 1;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
r,
bias,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_4arg.onnx");
EXPECT(p == prog);
}
// 5 arguments
{
nd = 2;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_5arg.onnx");
EXPECT(p == prog);
}
}
TEST_CASE(gru_test_actv_funcs)
{
std::size_t sl = 5; // sequence len
std::size_t bs = 3; // batch size
std::size_t hs = 20; // hidden size
std::size_t is = 10; // input size
std::size_t nd = 2; // num directions
float clip = 0.0f;
// bidirection, 0 actv function
{
nd = 2;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction(
migraphx::op::gru{hs, {}, migraphx::op::rnn_direction::bidirectional, clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_0.onnx");
EXPECT(p == prog);
}
// bidirection, 1 actv function
{
nd = 2;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction(
migraphx::op::gru{
hs, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_1.onnx");
EXPECT(p == prog);
}
// bidirection, 2 actv functions
{
nd = 2;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_2.onnx");
EXPECT(p == prog);
}
// bidirection, 3 actv functions
{
nd = 2;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction(
migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_3.onnx");
EXPECT(p == prog);
}
// forward, 0 actv function
{
nd = 1;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs, {}, migraphx::op::rnn_direction::forward, clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_forward_0.onnx");
EXPECT(p == prog);
}
// reverse, 1 actv function
{
nd = 1;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction(
migraphx::op::gru{
hs, {migraphx::op::relu{}}, migraphx::op::rnn_direction::reverse, clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_reverse_1.onnx");
EXPECT(p == prog);
}
}
TEST_CASE(flatten_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
p.add_instruction(migraphx::op::flatten{1}, l0);
p.add_instruction(migraphx::op::flatten{2}, l0);
auto prog = migraphx::parse_onnx("flatten_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(squeeze_unsqueeze_test)
{
migraphx::program p;
std::vector<int64_t> squeeze_axes{0, 2, 3, 5};
std::vector<int64_t> unsqueeze_axes{0, 1, 3, 5};
auto l0 =
p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 1, 1, 2, 1}});
auto l1 = p.add_instruction(migraphx::op::squeeze{squeeze_axes}, l0);
p.add_instruction(migraphx::op::unsqueeze{unsqueeze_axes}, l1);
auto prog = migraphx::parse_onnx("squeeze_unsqueeze_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(concat_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 4, 3}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {7, 4, 3}});
p.add_instruction(migraphx::op::concat{0}, l0, l1);
auto prog = migraphx::parse_onnx("concat_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(slice_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3, 2}});
p.add_instruction(migraphx::op::slice{{0, 1}, {1, 0}, {2, 2}}, l0);
auto prog = migraphx::parse_onnx("slice_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(constant_test)
{
migraphx::program p;
p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3}}, {0, 1, 2}});
auto prog = migraphx::parse_onnx("constant_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(constant_fill_test)
{
{
migraphx::program p;
auto l0 = p.add_literal(migraphx::literal{{migraphx::shape::int32_type, {2}}, {2, 3}});
std::vector<std::size_t> dims(l0->get_shape().elements());
migraphx::literal ls = l0->get_literal();
ls.visit([&](auto s) { dims.assign(s.begin(), s.end()); });
migraphx::shape s{migraphx::shape::float_type, dims};
std::vector<float> value(s.elements(), 1.0);
p.add_literal(migraphx::literal{s, value});
auto prog = migraphx::parse_onnx("const_fill1.onnx");
EXPECT(p == prog);
}
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
std::vector<float> value(s.elements(), 1.0);
p.add_literal(migraphx::literal{s, value});
auto prog = migraphx::parse_onnx("const_fill2.onnx");
EXPECT(p == prog);
}
}
TEST_CASE(gemm_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 7}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {11, 5}});
p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {}});
auto t0 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l0);
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1);
auto alpha = 2.f;
p.add_instruction(migraphx::op::dot{alpha}, t0, t1);
auto prog = migraphx::parse_onnx("gemm_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(add_scalar_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 =
p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type, {1}}, {1}});
auto m0 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0);
auto m1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::add{}, m0, m1);
auto prog = migraphx::parse_onnx("add_scalar_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(sub_scalar_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 =
p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type, {1}}, {1}});
auto m0 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0);
auto m1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::sub{}, m0, m1);
auto prog = migraphx::parse_onnx("sub_scalar_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(group_conv_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 4, 16, 16}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 1, 3, 3}});
migraphx::op::convolution op;
op.group = 4;
p.add_instruction(op, l0, l1);
migraphx::parse_onnx("group_conv_test.onnx");
}
TEST_CASE(pad_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 2}});
p.add_instruction(migraphx::op::pad{{1, 1, 1, 1}}, l0);
migraphx::parse_onnx("pad_test.onnx");
}
TEST_CASE(lrn_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 28, 24, 24}});
migraphx::op::lrn op;
op.size = 5;
op.alpha = 0.0001;
op.beta = 0.75;
op.bias = 1.0;
p.add_instruction(op, l0);
migraphx::parse_onnx("lrn_test.onnx");
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
 pad-example:T

01"Pad*
pads@@@@test-padZ
0


b
1


B
\ No newline at end of file
reshape-example:

0
12"Reshape

03"Reshape*
shape@@ test-reshape* :B1Z
0



Z
1

b
2


b
3


B
\ No newline at end of file
 shape-example:I
xy"Shape
test_shapeZ
x




b
y

B
\ No newline at end of file
 sin-example:9
xy"Sintest_sinZ
x


b
y


B
\ No newline at end of file
 sinh-example:;
xy"Sinh test_sinhZ
x


b
y


B
\ No newline at end of file
softmax-example:I

01"Softmax test-softmaxZ
0


b
1


B
\ No newline at end of file
 subtraction2:
/
0
1out"Sub*
axis*
broadcast subtraction2Z
0




Z
1


b
out




B
\ No newline at end of file
 sum-example:e

0
1
23"Sum test-dropoutZ
0

Z
1

Z
2

b
2

B
\ No newline at end of file
 tan-example:9
xy"Tantest_tanZ
x


b
y


B
\ No newline at end of file
 tanh-example:;
xy"Tanh test_tanhZ
x

b
y

B
\ No newline at end of file
unknown-example:

0
12"Unknown
2"Unknown test-unknownZ
0




Z
1


b
2




B
\ No newline at end of file
#include <migraph/program.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/instruction.hpp>
#include <migraph/operators.hpp>
#include <migraphx/program.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <sstream>
#include "test.hpp"
template <class... Ts>
void expect_shape(const migraph::shape& expected, const migraph::operation& op, Ts... xs)
void expect_shape(const migraphx::shape& expected, const migraphx::operation& op, Ts... xs)
{
migraph::program p;
std::vector<migraph::shape> shapes{xs...};
std::vector<migraph::instruction_ref> args(shapes.size());
migraphx::program p;
std::vector<migraphx::shape> shapes{xs...};
std::vector<migraphx::instruction_ref> args(shapes.size());
std::transform(
shapes.begin(), shapes.end(), args.begin(), [&](auto&& s) { return p.add_outline(s); });
p.add_instruction(op, args);
......@@ -24,11 +24,11 @@ void expect_shape(const migraph::shape& expected, const migraph::operation& op,
}
template <class... Ts>
void throws_shape(const migraph::operation& op, Ts... xs)
void throws_shape(const migraphx::operation& op, Ts... xs)
{
migraph::program p;
std::vector<migraph::shape> shapes{xs...};
std::vector<migraph::instruction_ref> args(shapes.size());
migraphx::program p;
std::vector<migraphx::shape> shapes{xs...};
std::vector<migraphx::instruction_ref> args(shapes.size());
std::transform(
shapes.begin(), shapes.end(), args.begin(), [&](auto&& s) { return p.add_outline(s); });
bool thrown = test::throws([&] { p.add_instruction(op, args); });
......@@ -46,7 +46,7 @@ struct always_false : std::false_type
};
template <class... Ts>
void throws_shape(const migraph::shape&, Ts...)
void throws_shape(const migraphx::shape&, Ts...)
{
static_assert(always_false<Ts...>{},
"An expected shape should not be passed to throws_shape function");
......@@ -55,94 +55,114 @@ void throws_shape(const migraph::shape&, Ts...)
TEST_CASE(batch_norm_inference_shape)
{
const size_t channels = 3;
migraph::shape s{migraph::shape::float_type, {4, channels, 3, 3}};
migraph::shape vars{migraph::shape::float_type, {channels}};
expect_shape(s, migraph::op::batch_norm_inference{}, s, vars, vars, vars, vars);
throws_shape(migraph::op::batch_norm_inference{}, s);
throws_shape(migraph::op::batch_norm_inference{}, s, vars, vars, vars, vars, vars);
migraphx::shape s{migraphx::shape::float_type, {4, channels, 3, 3}};
migraphx::shape vars{migraphx::shape::float_type, {channels}};
expect_shape(s, migraphx::op::batch_norm_inference{}, s, vars, vars, vars, vars);
throws_shape(migraphx::op::batch_norm_inference{}, s);
throws_shape(migraphx::op::batch_norm_inference{}, s, vars, vars, vars, vars, vars);
}
TEST_CASE(convolution_shape)
{
migraph::shape output{migraph::shape::float_type, {4, 4, 1, 1}};
migraph::shape input{migraph::shape::float_type, {4, 3, 3, 3}};
migraph::shape weights{migraph::shape::float_type, {4, 3, 3, 3}};
expect_shape(output, migraph::op::convolution{}, input, weights);
throws_shape(migraph::op::convolution{}, input);
migraph::shape input2{migraph::shape::float_type, {3, 3}};
migraph::shape weights2{migraph::shape::float_type, {3, 3}};
throws_shape(migraph::op::convolution{}, input2, weights2);
throws_shape(migraph::op::convolution{}, input2, weights);
migraphx::shape output{migraphx::shape::float_type, {4, 4, 1, 1}};
migraphx::shape input{migraphx::shape::float_type, {4, 3, 3, 3}};
migraphx::shape weights{migraphx::shape::float_type, {4, 3, 3, 3}};
expect_shape(output, migraphx::op::convolution{}, input, weights);
throws_shape(migraphx::op::convolution{}, input);
migraphx::shape input2{migraphx::shape::float_type, {3, 3}};
migraphx::shape weights2{migraphx::shape::float_type, {3, 3}};
throws_shape(migraphx::op::convolution{}, input2, weights2);
throws_shape(migraphx::op::convolution{}, input2, weights);
}
TEST_CASE(transpose_shape)
{
migraph::shape input{migraph::shape::float_type, {2, 2}};
migraph::shape output{migraph::shape::float_type, {2, 2}, {1, 2}};
expect_shape(input, migraph::op::transpose{{0, 1}}, input);
expect_shape(output, migraph::op::transpose{{1, 0}}, input);
throws_shape(migraph::op::transpose{{1, 2}}, input);
migraphx::shape input{migraphx::shape::float_type, {2, 2}};
migraphx::shape output{migraphx::shape::float_type, {2, 2}, {1, 2}};
expect_shape(input, migraphx::op::transpose{{0, 1}}, input);
expect_shape(output, migraphx::op::transpose{{1, 0}}, input);
throws_shape(migraphx::op::transpose{{1, 2}}, input);
}
TEST_CASE(contiguous_shape)
{
migraph::shape output{migraph::shape::float_type, {2, 2}};
migraph::shape input{migraph::shape::float_type, {2, 2}, {1, 2}};
expect_shape(output, migraph::op::contiguous{}, input);
throws_shape(migraph::op::contiguous{}, input, input);
migraphx::shape output{migraphx::shape::float_type, {2, 2}};
migraphx::shape input{migraphx::shape::float_type, {2, 2}, {1, 2}};
expect_shape(output, migraphx::op::contiguous{}, input);
throws_shape(migraphx::op::contiguous{}, input, input);
migraph::shape single{migraph::shape::float_type, {2}};
expect_shape(single, migraph::op::contiguous{}, single);
migraphx::shape single{migraphx::shape::float_type, {2}};
expect_shape(single, migraphx::op::contiguous{}, single);
}
TEST_CASE(reshape_shape)
{
migraph::shape input{migraph::shape::float_type, {24, 1, 1, 1}};
migraphx::shape input{migraphx::shape::float_type, {24, 1, 1, 1}};
for(auto&& new_shape :
std::vector<std::vector<int64_t>>{{8, 3, 1, 1}, {1, 3, 4, 2}, {1, 3, 4, 2}})
{
std::vector<std::size_t> lens(new_shape.size());
std::copy(new_shape.begin(), new_shape.end(), lens.begin());
migraph::shape output{migraph::shape::float_type, lens};
expect_shape(output, migraph::op::reshape{new_shape}, input);
migraphx::shape output{migraphx::shape::float_type, lens};
expect_shape(output, migraphx::op::reshape{new_shape}, input);
}
for(auto&& new_shape :
std::vector<std::vector<int64_t>>{{8, 3, 2, 2}, {1, 3, -1, -1}, {3, 0, 0}, {3, 2, 0}})
{
throws_shape(migraphx::op::reshape{new_shape}, input);
}
for(auto&& new_shape : std::vector<std::vector<int64_t>>{{8, 3, 2, 2}, {1, 3, -1, -1}})
std::vector<std::pair<std::vector<int64_t>, migraphx::shape>> minus1_tests{
{{2, -1, 3}, {migraphx::shape::float_type, {2, 4, 3}}},
{{0, -1, 0}, {migraphx::shape::float_type, {24, 1, 1}}},
{{2, -1, 0}, {migraphx::shape::float_type, {2, 12, 1}}},
{{0, 0, -1}, {migraphx::shape::float_type, {24, 1, 1}}},
{{2, 0, -1}, {migraphx::shape::float_type, {2, 1, 12}}},
{{-1, 2, 3}, {migraphx::shape::float_type, {4, 2, 3}}},
{{-1, 0, 3}, {migraphx::shape::float_type, {8, 1, 3}}},
{{-1, 0, 0}, {migraphx::shape::float_type, {24, 1, 1}}},
{{-1, 3, 0}, {migraphx::shape::float_type, {8, 3, 1}}}};
for(auto& it : minus1_tests)
{
throws_shape(migraph::op::reshape{new_shape}, input);
expect_shape(it.second, migraphx::op::reshape{it.first}, input);
}
}
TEST_CASE(flatten_shape)
{
migraph::shape input{migraph::shape::float_type, {2, 4, 6, 8}};
expect_shape(migraph::shape{migraph::shape::float_type, {1, 2 * 4 * 6 * 8}},
migraph::op::flatten{0},
migraphx::shape input{migraphx::shape::float_type, {2, 4, 6, 8}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 2 * 4 * 6 * 8}},
migraphx::op::flatten{0},
input);
expect_shape(
migraph::shape{migraph::shape::float_type, {2, 4 * 6 * 8}}, migraph::op::flatten{1}, input);
expect_shape(
migraph::shape{migraph::shape::float_type, {2 * 4, 6 * 8}}, migraph::op::flatten{2}, input);
expect_shape(
migraph::shape{migraph::shape::float_type, {2 * 4 * 6, 8}}, migraph::op::flatten{3}, input);
expect_shape(migraph::shape{migraph::shape::float_type, {2 * 4 * 6 * 8, 1}},
migraph::op::flatten{4},
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 4 * 6 * 8}},
migraphx::op::flatten{1},
input);
throws_shape(migraph::op::flatten{5}, input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {2 * 4, 6 * 8}},
migraphx::op::flatten{2},
input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {2 * 4 * 6, 8}},
migraphx::op::flatten{3},
input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {2 * 4 * 6 * 8, 1}},
migraphx::op::flatten{4},
input);
throws_shape(migraphx::op::flatten{5}, input);
}
TEST_CASE(slice_shape)
{
migraph::shape input{migraph::shape::int32_type, {2, 2, 3}};
expect_shape(migraph::shape{migraph::shape::int32_type, {2, 2, 2}, {6, 3, 1}},
migraph::op::slice{{2}, {1}, {3}},
migraphx::shape input{migraphx::shape::int32_type, {2, 2, 3}};
expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 2, 2}, {6, 3, 1}},
migraphx::op::slice{{2}, {1}, {3}},
input);
expect_shape(migraph::shape{migraph::shape::int32_type, {2, 2, 2}, {6, 3, 1}},
migraph::op::slice{{0, 1, 2}, {0, 0, 1}, {2, 2, 3}},
expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 2, 2}, {6, 3, 1}},
migraphx::op::slice{{0, 1, 2}, {0, 0, 1}, {2, 2, 3}},
input);
expect_shape(migraph::shape{migraph::shape::int32_type, {2, 2, 1}, {6, 3, 1}},
migraph::op::slice{{2}, {2}, {10}},
expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 2, 1}, {6, 3, 1}},
migraphx::op::slice{{2}, {2}, {10}},
input);
}
......@@ -150,62 +170,423 @@ TEST_CASE(multibroadcast)
{
{
std::vector<std::size_t> lens{4, 2, 5, 3};
migraph::shape input{migraph::shape::float_type, {2, 1, 3}};
expect_shape(migraph::shape{migraph::shape::float_type, lens, {0, 3, 0, 1}},
migraph::op::multibroadcast{lens},
migraphx::shape input{migraphx::shape::float_type, {2, 1, 3}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 3, 0, 1}},
migraphx::op::multibroadcast{lens},
input);
}
{
std::vector<std::size_t> lens{4, 2, 5, 3};
migraph::shape input{migraph::shape::float_type, {2, 1, 1}};
expect_shape(migraph::shape{migraph::shape::float_type, lens, {0, 1, 0, 0}},
migraph::op::multibroadcast{lens},
migraphx::shape input{migraphx::shape::float_type, {2, 1, 1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 1, 0, 0}},
migraphx::op::multibroadcast{lens},
input);
}
{
std::vector<std::size_t> lens{4, 2, 5, 3};
migraph::shape input{migraph::shape::float_type, {5, 1}};
expect_shape(migraph::shape{migraph::shape::float_type, lens, {0, 0, 1, 0}},
migraph::op::multibroadcast{lens},
migraphx::shape input{migraphx::shape::float_type, {5, 1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 0, 1, 0}},
migraphx::op::multibroadcast{lens},
input);
}
{
std::vector<std::size_t> lens{4, 2, 5, 3};
migraph::shape input{migraph::shape::float_type, {4, 1, 1, 1}};
expect_shape(migraph::shape{migraph::shape::float_type, lens, {1, 0, 0, 0}},
migraph::op::multibroadcast{lens},
migraphx::shape input{migraphx::shape::float_type, {4, 1, 1, 1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {1, 0, 0, 0}},
migraphx::op::multibroadcast{lens},
input);
}
{
std::vector<std::size_t> lens{4, 2, 5, 3};
migraph::shape input{migraph::shape::float_type, {3}};
expect_shape(migraph::shape{migraph::shape::float_type, lens, {0, 0, 0, 1}},
migraph::op::multibroadcast{lens},
migraphx::shape input{migraphx::shape::float_type, {3}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 0, 0, 1}},
migraphx::op::multibroadcast{lens},
input);
}
{
std::vector<std::size_t> lens{4, 4, 1, 3};
migraph::shape input{migraph::shape::float_type, {4, 1, 3}};
expect_shape(migraph::shape{migraph::shape::float_type, lens, {0, 3, 3, 1}},
migraph::op::multibroadcast{lens},
migraphx::shape input{migraphx::shape::float_type, {4, 1, 3}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 3, 3, 1}},
migraphx::op::multibroadcast{lens},
input);
}
{
std::vector<std::size_t> lens{4, 1, 1, 3};
migraph::shape input{migraph::shape::float_type, {4, 1, 1, 1}};
expect_shape(migraph::shape{migraph::shape::float_type, lens, {1, 1, 1, 0}},
migraph::op::multibroadcast{lens},
migraphx::shape input{migraphx::shape::float_type, {4, 1, 1, 1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {1, 1, 1, 0}},
migraphx::op::multibroadcast{lens},
input);
}
{
std::vector<std::size_t> lens{4, 1, 3};
migraph::shape input{migraph::shape::float_type, {4, 1, 1, 1}};
throws_shape(migraph::op::multibroadcast{lens}, input);
migraphx::shape input{migraphx::shape::float_type, {4, 1, 1, 1}};
throws_shape(migraphx::op::multibroadcast{lens}, input);
}
{
std::vector<std::size_t> lens{4, 1, 3};
migraph::shape input{migraph::shape::float_type, {}};
throws_shape(migraph::op::multibroadcast{lens}, input);
migraphx::shape input{migraphx::shape::float_type, {}};
throws_shape(migraphx::op::multibroadcast{lens}, input);
}
}
TEST_CASE(gather)
{
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
int axis = 1;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 6, 4, 5}},
migraphx::op::gather{axis},
input,
indices);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
int axis = -4;
expect_shape(migraphx::shape{migraphx::shape::float_type, {6, 3, 4, 5}},
migraphx::op::gather{axis},
input,
indices);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
int axis = 4;
throws_shape(migraphx::op::gather{axis}, input, indices);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
int axis = -5;
throws_shape(migraphx::op::gather{axis}, input, indices);
}
}
TEST_CASE(rnn)
{
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
expect_shape(
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::op::rnn{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip},
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
expect_shape(
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::op::rnn{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::reverse, clip},
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
expect_shape(migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
throws_shape(migraphx::op::rnn{hidden_size + 1,
{migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
throws_shape(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
throws_shape(
migraphx::op::rnn{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip},
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
}
TEST_CASE(gru)
{
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
expect_shape(
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::op::gru{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip},
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
expect_shape(
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::op::gru{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::reverse, clip},
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
expect_shape(migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::op::gru{hidden_size,
{migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
throws_shape(migraphx::op::gru{hidden_size + 1,
{migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
throws_shape(migraphx::op::gru{hidden_size,
{migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
throws_shape(
migraphx::op::gru{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip},
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
}
......
#include <migraph/operation.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/context.hpp>
#include <sstream>
#include <string>
#include "test.hpp"
......@@ -9,18 +10,19 @@ struct simple_operation
template <class T, class F>
static auto reflect(T& x, F f)
{
return migraph::pack(f(x.data, "data"));
return migraphx::pack(f(x.data, "data"));
}
int data = 1;
std::string name() const { return "simple"; }
migraph::shape compute_shape(const std::vector<migraph::shape>&) const
migraphx::shape compute_shape(const std::vector<migraphx::shape>&) const
{
MIGRAPH_THROW("not computable");
MIGRAPHX_THROW("not computable");
}
migraph::argument
compute(migraph::context&, const migraph::shape&, const std::vector<migraph::argument>&) const
migraphx::argument compute(migraphx::context&,
const migraphx::shape&,
const std::vector<migraphx::argument>&) const
{
MIGRAPH_THROW("not computable");
MIGRAPHX_THROW("not computable");
}
friend std::ostream& operator<<(std::ostream& os, const simple_operation& op)
{
......@@ -32,22 +34,23 @@ struct simple_operation
struct simple_operation_no_print
{
std::string name() const { return "simple"; }
migraph::shape compute_shape(const std::vector<migraph::shape>&) const
migraphx::shape compute_shape(const std::vector<migraphx::shape>&) const
{
MIGRAPH_THROW("not computable");
MIGRAPHX_THROW("not computable");
}
migraph::argument
compute(migraph::context&, const migraph::shape&, const std::vector<migraph::argument>&) const
migraphx::argument compute(migraphx::context&,
const migraphx::shape&,
const std::vector<migraphx::argument>&) const
{
MIGRAPH_THROW("not computable");
MIGRAPHX_THROW("not computable");
}
};
TEST_CASE(operation_copy_test)
{
simple_operation s{};
migraph::operation op1 = s; // NOLINT
migraph::operation op2 = op1; // NOLINT
migraphx::operation op1 = s; // NOLINT
migraphx::operation op2 = op1; // NOLINT
// cppcheck-suppress duplicateExpression
EXPECT(s == op1);
// cppcheck-suppress duplicateExpression
......@@ -57,10 +60,10 @@ TEST_CASE(operation_copy_test)
TEST_CASE(operation_equal_test)
{
simple_operation s{};
migraph::operation op1 = s;
s.data = 2;
migraph::operation op2 = op1; // NOLINT
migraph::operation op3 = s; // NOLINT
migraphx::operation op1 = s;
s.data = 2;
migraphx::operation op2 = op1; // NOLINT
migraphx::operation op3 = s; // NOLINT
EXPECT(s != op1);
EXPECT(op2 == op1);
......@@ -74,18 +77,18 @@ struct not_operation
TEST_CASE(operation_any_cast)
{
migraph::operation op1 = simple_operation{};
EXPECT(migraph::any_cast<simple_operation>(op1).data == 1);
EXPECT(migraph::any_cast<not_operation*>(&op1) == nullptr);
EXPECT(test::throws([&] { migraph::any_cast<not_operation&>(op1); }));
migraph::operation op2 = simple_operation{2};
EXPECT(migraph::any_cast<simple_operation>(op2).data == 2);
EXPECT(migraph::any_cast<not_operation*>(&op2) == nullptr);
migraphx::operation op1 = simple_operation{};
EXPECT(migraphx::any_cast<simple_operation>(op1).data == 1);
EXPECT(migraphx::any_cast<not_operation*>(&op1) == nullptr);
EXPECT(test::throws([&] { migraphx::any_cast<not_operation&>(op1); }));
migraphx::operation op2 = simple_operation{2};
EXPECT(migraphx::any_cast<simple_operation>(op2).data == 2);
EXPECT(migraphx::any_cast<not_operation*>(&op2) == nullptr);
}
TEST_CASE(operation_print)
{
migraph::operation op = simple_operation{};
migraphx::operation op = simple_operation{};
std::stringstream ss;
ss << op;
std::string s = ss.str();
......@@ -94,11 +97,71 @@ TEST_CASE(operation_print)
TEST_CASE(operation_default_print)
{
migraph::operation op = simple_operation_no_print{};
migraphx::operation op = simple_operation_no_print{};
std::stringstream ss;
ss << op;
std::string s = ss.str();
EXPECT(s == "simple");
}
struct final_operation
{
std::string name() const { return "final"; }
migraphx::shape compute_shape(const std::vector<migraphx::shape>&) const
{
MIGRAPHX_THROW("not computable");
}
void
finalize(migraphx::context&, const migraphx::shape&, const std::vector<migraphx::shape>&) const
{
}
};
struct final_operation_throw
{
std::string name() const { return "final"; }
migraphx::shape compute_shape(const std::vector<migraphx::shape>&) const
{
MIGRAPHX_THROW("not computable");
}
[[gnu::noreturn]] void
finalize(migraphx::context&, const migraphx::shape&, const std::vector<migraphx::shape>&) const
{
MIGRAPHX_THROW("finalize");
}
};
TEST_CASE(check_has_finalize_simple)
{
migraphx::operation op = simple_operation{};
EXPECT(not migraphx::has_finalize(op));
}
TEST_CASE(check_has_finalize)
{
migraphx::operation op = final_operation{};
EXPECT(migraphx::has_finalize(op));
}
TEST_CASE(check_run_finalize)
{
migraphx::operation op = final_operation{};
migraphx::context ctx{};
op.finalize(ctx, {}, {});
}
TEST_CASE(check_run_finalize_simple)
{
migraphx::operation op = simple_operation{};
migraphx::context ctx{};
op.finalize(ctx, {}, {});
}
TEST_CASE(check_run_finalize_throw)
{
migraphx::operation op = final_operation_throw{};
migraphx::context ctx{};
EXPECT(test::throws([&] { op.finalize(ctx, {}, {}); }));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment