Unverified Commit 32b08891 authored by Charlie Lin's avatar Charlie Lin Committed by GitHub
Browse files

Dynamic ref transpose (#1438)

Extends ref transpose operator for dynamic shapes
Make dynamic tests more consistent naming
parent eb094e57
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp> #include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/dyn_output.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -45,17 +46,15 @@ struct transpose ...@@ -45,17 +46,15 @@ struct transpose
} }
std::string name() const { return "transpose"; } std::string name() const { return "transpose"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this, true}.has(1);
auto input = inputs.at(0); auto input = inputs.at(0);
auto input_lens = input.lens();
auto input_strides = input.strides();
auto t = input.type();
if(dims.size() != input_lens.size()) if(dims.size() != input.ndim())
{ {
MIGRAPHX_THROW("Permutation has wrong number of axes"); MIGRAPHX_THROW("TRANSPOSE: Permutation has wrong number of axes");
} }
std::vector<int64_t> axes(dims.size()); std::vector<int64_t> axes(dims.size());
std::iota(axes.begin(), axes.end(), 0); std::iota(axes.begin(), axes.end(), 0);
...@@ -63,19 +62,36 @@ struct transpose ...@@ -63,19 +62,36 @@ struct transpose
{ {
MIGRAPHX_THROW("TRANSPOSE: Invalid permutation"); MIGRAPHX_THROW("TRANSPOSE: Invalid permutation");
} }
std::vector<size_t> output_lens(input_lens.size());
std::vector<size_t> output_strides(input_lens.size()); if(input.dynamic())
for(std::size_t i = 0; i < output_lens.size(); i++)
{ {
output_lens[i] = input_lens[dims[i]]; std::vector<shape::dynamic_dimension> output_dyn_dims(input.ndim());
output_strides[i] = input_strides[dims[i]]; std::transform(dims.cbegin(), dims.cend(), output_dyn_dims.begin(), [&](auto dim) {
return input.dyn_dims()[dim];
});
return {input.type(), output_dyn_dims};
}
else
{
auto input_lens = input.lens();
auto input_strides = input.strides();
std::vector<size_t> output_lens(input.ndim());
std::vector<size_t> output_strides(input.ndim());
for(std::size_t i = 0; i < input.ndim(); i++)
{
output_lens[i] = input_lens[dims[i]];
output_strides[i] = input_strides[dims[i]];
}
return {input.type(), output_lens, output_strides};
} }
return {t, output_lens, output_strides};
} }
argument compute(shape output_shape, std::vector<argument> args) const
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ {
return args[0].reshape(output_shape); return args[0].reshape(dyn_out.computed_shape);
} }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -47,7 +47,7 @@ struct parse_transpose : op_parser<parse_transpose> ...@@ -47,7 +47,7 @@ struct parse_transpose : op_parser<parse_transpose>
} }
// if perm is empty, use the default value // if perm is empty, use the default value
auto n_dim = args.front()->get_shape().lens().size(); auto n_dim = args.front()->get_shape().ndim();
if(perm.empty()) if(perm.empty())
{ {
perm.resize(n_dim); perm.resize(n_dim);
......
...@@ -6277,6 +6277,21 @@ def transpose_test(): ...@@ -6277,6 +6277,21 @@ def transpose_test():
return ([node], [x], [y]) return ([node], [x], [y])
@onnx_test
def transpose_dyn_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [None, 2, 2, 3])
y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [None, 3, 2, 2])
node = onnx.helper.make_node(
'Transpose',
perm=[0, 3, 1, 2],
inputs=['0'],
outputs=['1'],
)
return ([node], [x], [y])
@onnx_test @onnx_test
def transpose_gather_test(): def transpose_gather_test():
x = helper.make_tensor_value_info('data', TensorProto.FLOAT, [3, 5, 4, 6]) x = helper.make_tensor_value_info('data', TensorProto.FLOAT, [3, 5, 4, 6])
......
...@@ -5973,6 +5973,24 @@ TEST_CASE(transpose_test) ...@@ -5973,6 +5973,24 @@ TEST_CASE(transpose_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(transpose_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input = mm->add_parameter(
"0",
migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {2, 2, 0}, {2, 2, 0}, {3, 3, 0}}});
std::vector<int64_t> perm{0, 3, 1, 2};
auto t0 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), input);
mm->add_return({t0});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
auto prog = migraphx::parse_onnx("transpose_dyn_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(topk_attrk_test) TEST_CASE(topk_attrk_test)
{ {
migraphx::program p; migraphx::program p;
......
...@@ -2226,6 +2226,28 @@ TEST_CASE(transpose_shape) ...@@ -2226,6 +2226,28 @@ TEST_CASE(transpose_shape)
throws_shape(migraphx::make_op("transpose", {{"permutation", {1, 2}}}), input); throws_shape(migraphx::make_op("transpose", {{"permutation", {1, 2}}}), input);
} }
TEST_CASE(transpose_dyn_shape0)
{
migraphx::shape input{migraphx::shape::float_type, {{1, 4, 0}, {2, 2, 0}}};
migraphx::shape output{migraphx::shape::float_type, {{2, 2, 0}, {1, 4, 0}}};
expect_shape(input, migraphx::make_op("transpose", {{"permutation", {0, 1}}}), input);
expect_shape(output, migraphx::make_op("transpose", {{"permutation", {1, 0}}}), input);
}
TEST_CASE(transpose_dyn_shape1)
{
migraphx::shape input{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {4, 4, 0}}};
migraphx::shape output{migraphx::shape::float_type, {{4, 4, 0}, {4, 4, 0}, {1, 4, 0}}};
expect_shape(input, migraphx::make_op("transpose", {{"permutation", {0, 1, 2}}}), input);
expect_shape(output, migraphx::make_op("transpose", {{"permutation", {2, 1, 0}}}), input);
}
TEST_CASE(transpose_axes_error)
{
migraphx::shape input{migraphx::shape::float_type, {2, 2}};
throws_shape(migraphx::make_op("transpose", {{"permutation", {1}}}), input);
}
TEST_CASE(step_test) TEST_CASE(step_test)
{ {
migraphx::shape s1{migraphx::shape::float_type, {1, 2, 4}}; migraphx::shape s1{migraphx::shape::float_type, {1, 2, 4}};
......
...@@ -60,15 +60,16 @@ TEST_CASE(abs_test) ...@@ -60,15 +60,16 @@ TEST_CASE(abs_test)
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(abs_dynamic_test) TEST_CASE(abs_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
std::vector<float> a = {-1, 2, -3, 4};
migraphx::shape s{migraphx::shape::float_type, {{2, 8, 0}, {2, 2, 0}}}; migraphx::shape s{migraphx::shape::float_type, {{2, 8, 0}, {2, 2, 0}}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("abs"), input); mm->add_instruction(migraphx::make_op("abs"), input);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
std::vector<float> a = {-1, 2, -3, 4};
migraphx::parameter_map params0; migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {2, 2}}; migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {2, 2}};
params0["X"] = migraphx::argument(input_fixed_shape0, a.data()); params0["X"] = migraphx::argument(input_fixed_shape0, a.data());
...@@ -97,17 +98,17 @@ TEST_CASE(acos_test) ...@@ -97,17 +98,17 @@ TEST_CASE(acos_test)
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(acos_dynamic_test) TEST_CASE(acos_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0}; migraphx::shape::dynamic_dimension dd{3, 8, 0};
migraphx::shape s{migraphx::shape::float_type, {dd}}; migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
std::vector<float> input_data{-0.8f, 0.0f, 1.0f};
mm->add_instruction(migraphx::make_op("acos"), input); mm->add_instruction(migraphx::make_op("acos"), input);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
std::vector<float> input_data{-0.8f, 0.0f, 1.0f};
migraphx::parameter_map params0; migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {3}}; migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {3}};
params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data()); params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data());
...@@ -138,7 +139,7 @@ TEST_CASE(acosh_test) ...@@ -138,7 +139,7 @@ TEST_CASE(acosh_test)
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(acosh_dynamic_test) TEST_CASE(acosh_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
...@@ -419,17 +420,17 @@ TEST_CASE(asin_test) ...@@ -419,17 +420,17 @@ TEST_CASE(asin_test)
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(asin_dynamic_test) TEST_CASE(asin_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0}; migraphx::shape::dynamic_dimension dd{3, 8, 0};
migraphx::shape s{migraphx::shape::float_type, {dd}}; migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
std::vector<float> input_data{-0.5f, 0.0f, 0.9f};
mm->add_instruction(migraphx::make_op("asin"), input); mm->add_instruction(migraphx::make_op("asin"), input);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
std::vector<float> input_data{-0.5f, 0.0f, 0.9f};
migraphx::parameter_map params0; migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {3}}; migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {3}};
params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data()); params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data());
...@@ -460,17 +461,17 @@ TEST_CASE(asinh_test) ...@@ -460,17 +461,17 @@ TEST_CASE(asinh_test)
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(asinh_dynamic_test) TEST_CASE(asinh_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0}; migraphx::shape::dynamic_dimension dd{3, 8, 0};
migraphx::shape s{migraphx::shape::float_type, {dd}}; migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
std::vector<float> input_data{-0.5f, 0.0f, 0.9f};
mm->add_instruction(migraphx::make_op("asinh"), input); mm->add_instruction(migraphx::make_op("asinh"), input);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
std::vector<float> input_data{-0.5f, 0.0f, 0.9f};
migraphx::parameter_map params0; migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {3}}; migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {3}};
params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data()); params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data());
...@@ -501,17 +502,17 @@ TEST_CASE(atan_test) ...@@ -501,17 +502,17 @@ TEST_CASE(atan_test)
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(atan_dynamic_test) TEST_CASE(atan_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0}; migraphx::shape::dynamic_dimension dd{3, 8, 0};
migraphx::shape s{migraphx::shape::float_type, {dd}}; migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
std::vector<float> input_data{-1.0f, 0.0f, 1.0f};
mm->add_instruction(migraphx::make_op("atan"), input); mm->add_instruction(migraphx::make_op("atan"), input);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
std::vector<float> input_data{-1.0f, 0.0f, 1.0f};
migraphx::parameter_map params0; migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {3}}; migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {3}};
params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data()); params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data());
...@@ -542,17 +543,17 @@ TEST_CASE(atanh_test) ...@@ -542,17 +543,17 @@ TEST_CASE(atanh_test)
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(atanh_dynamic_test) TEST_CASE(atanh_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0}; migraphx::shape::dynamic_dimension dd{3, 8, 0};
migraphx::shape s{migraphx::shape::float_type, {dd}}; migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
std::vector<float> input_data{0.4435683f, 0.6223626f, 0.316958f};
mm->add_instruction(migraphx::make_op("atanh"), input); mm->add_instruction(migraphx::make_op("atanh"), input);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
std::vector<float> input_data{0.4435683f, 0.6223626f, 0.316958f};
migraphx::parameter_map params0; migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {3}}; migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {3}};
params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data()); params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data());
...@@ -758,17 +759,17 @@ TEST_CASE(ceil_test) ...@@ -758,17 +759,17 @@ TEST_CASE(ceil_test)
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(ceil_dynamic_test) TEST_CASE(ceil_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{4, 12, 0}; migraphx::shape::dynamic_dimension dd{4, 12, 0};
migraphx::shape s{migraphx::shape::float_type, {dd}}; migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
std::vector<float> input_data = {1.1, 1.5, 1.6, -1.1, -1.5, -1.6, 0.0, 2.0, -2.0};
mm->add_instruction(migraphx::make_op("ceil"), input); mm->add_instruction(migraphx::make_op("ceil"), input);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
std::vector<float> input_data = {1.1, 1.5, 1.6, -1.1, -1.5, -1.6, 0.0, 2.0, -2.0};
migraphx::parameter_map params0; migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {9}}; migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {9}};
params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data()); params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data());
...@@ -1069,7 +1070,7 @@ TEST_CASE(conv_dynamic_batch_test) ...@@ -1069,7 +1070,7 @@ TEST_CASE(conv_dynamic_batch_test)
EXPECT(migraphx::verify_range(results_vector, sol)); EXPECT(migraphx::verify_range(results_vector, sol));
} }
TEST_CASE(conv_dynamic_img_shape_test) TEST_CASE(conv_dyn_img_shape_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
...@@ -1158,7 +1159,7 @@ TEST_CASE(conv_dynamic_img_shape_test) ...@@ -1158,7 +1159,7 @@ TEST_CASE(conv_dynamic_img_shape_test)
EXPECT(migraphx::verify_range(results_vector, sol)); EXPECT(migraphx::verify_range(results_vector, sol));
} }
TEST_CASE(conv_dynamic_weights_shape_test) TEST_CASE(conv_dyn_weights_shape_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
...@@ -1235,7 +1236,7 @@ TEST_CASE(conv_dynamic_weights_shape_test) ...@@ -1235,7 +1236,7 @@ TEST_CASE(conv_dynamic_weights_shape_test)
EXPECT(migraphx::verify_range(results_vector, sol)); EXPECT(migraphx::verify_range(results_vector, sol));
} }
TEST_CASE(conv_dynamic_img_same_upper_test) TEST_CASE(conv_dyn_img_same_upper_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
...@@ -1306,7 +1307,7 @@ TEST_CASE(conv_dynamic_img_same_upper_test) ...@@ -1306,7 +1307,7 @@ TEST_CASE(conv_dynamic_img_same_upper_test)
EXPECT(migraphx::verify_range(results_vector, sol)); EXPECT(migraphx::verify_range(results_vector, sol));
} }
TEST_CASE(conv_dynamic_kernel_same_upper_test) TEST_CASE(conv_dyn_kernel_same_upper_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
...@@ -1380,7 +1381,7 @@ TEST_CASE(conv_dynamic_kernel_same_upper_test) ...@@ -1380,7 +1381,7 @@ TEST_CASE(conv_dynamic_kernel_same_upper_test)
EXPECT(migraphx::verify_range(results_vector, sol)); EXPECT(migraphx::verify_range(results_vector, sol));
} }
TEST_CASE(conv_dynamic_kernel_same_lower_test) TEST_CASE(conv_dyn_kernel_same_lower_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
...@@ -1725,17 +1726,17 @@ TEST_CASE(cos_test) ...@@ -1725,17 +1726,17 @@ TEST_CASE(cos_test)
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(cos_dynamic_test) TEST_CASE(cos_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0}; migraphx::shape::dynamic_dimension dd{3, 8, 0};
migraphx::shape s{migraphx::shape::float_type, {dd}}; migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
std::vector<float> input_data{-1, 0, 1};
mm->add_instruction(migraphx::make_op("cos"), input); mm->add_instruction(migraphx::make_op("cos"), input);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
std::vector<float> input_data{-1, 0, 1};
migraphx::parameter_map params0; migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {3}}; migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {3}};
params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data()); params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data());
...@@ -1766,17 +1767,17 @@ TEST_CASE(cosh_test) ...@@ -1766,17 +1767,17 @@ TEST_CASE(cosh_test)
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(cosh_dynamic_test) TEST_CASE(cosh_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0}; migraphx::shape::dynamic_dimension dd{3, 8, 0};
migraphx::shape s{migraphx::shape::float_type, {dd}}; migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
std::vector<float> input_data = {-1.0, 2.0, -3.0, 4.0};
mm->add_instruction(migraphx::make_op("cosh"), input); mm->add_instruction(migraphx::make_op("cosh"), input);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
std::vector<float> input_data = {-1.0, 2.0, -3.0, 4.0};
migraphx::parameter_map params0; migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {4}}; migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {4}};
params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data()); params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data());
...@@ -1999,18 +2000,18 @@ TEST_CASE(elu_test) ...@@ -1999,18 +2000,18 @@ TEST_CASE(elu_test)
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(elu_dynamic_test) TEST_CASE(elu_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0}; migraphx::shape::dynamic_dimension dd{3, 8, 0};
migraphx::shape s{migraphx::shape::float_type, {dd}}; migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
std::vector<float> input_data{-1.0, 2.0, -3.0, 4.0};
float alpha = 0.5; float alpha = 0.5;
mm->add_instruction(migraphx::make_op("elu", {{"alpha", alpha}}), input); mm->add_instruction(migraphx::make_op("elu", {{"alpha", alpha}}), input);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
std::vector<float> input_data{-1.0, 2.0, -3.0, 4.0};
migraphx::parameter_map params0; migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {4}}; migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {4}};
params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data()); params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data());
...@@ -2117,17 +2118,17 @@ TEST_CASE(erf_test) ...@@ -2117,17 +2118,17 @@ TEST_CASE(erf_test)
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(erf_dynamic_test) TEST_CASE(erf_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0}; migraphx::shape::dynamic_dimension dd{3, 8, 0};
migraphx::shape s{migraphx::shape::float_type, {dd}}; migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
std::vector<float> input_data = {0.73785057, 1.58165966, -0.43597795, -0.01677432};
mm->add_instruction(migraphx::make_op("erf"), input); mm->add_instruction(migraphx::make_op("erf"), input);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
std::vector<float> input_data = {0.73785057, 1.58165966, -0.43597795, -0.01677432};
migraphx::parameter_map params0; migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {4}}; migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {4}};
params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data()); params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data());
...@@ -2158,17 +2159,17 @@ TEST_CASE(exp_test) ...@@ -2158,17 +2159,17 @@ TEST_CASE(exp_test)
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(exp_dynamic_test) TEST_CASE(exp_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0}; migraphx::shape::dynamic_dimension dd{3, 8, 0};
migraphx::shape s{migraphx::shape::float_type, {dd}}; migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
std::vector<float> input_data{-1, 0, 1};
mm->add_instruction(migraphx::make_op("exp"), input); mm->add_instruction(migraphx::make_op("exp"), input);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
std::vector<float> input_data{-1, 0, 1};
migraphx::parameter_map params0; migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {3}}; migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {3}};
params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data()); params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data());
...@@ -2199,17 +2200,17 @@ TEST_CASE(floor_test) ...@@ -2199,17 +2200,17 @@ TEST_CASE(floor_test)
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(floor_dynamic_test) TEST_CASE(floor_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{5, 12, 0}; migraphx::shape::dynamic_dimension dd{5, 12, 0};
migraphx::shape s{migraphx::shape::float_type, {dd}}; migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
std::vector<float> input_data = {1.1, 1.5, 0.6, -1.1, -1.5, -0.6, 0.0, 2.0, -2.0};
mm->add_instruction(migraphx::make_op("floor"), input); mm->add_instruction(migraphx::make_op("floor"), input);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
std::vector<float> input_data = {1.1, 1.5, 0.6, -1.1, -1.5, -0.6, 0.0, 2.0, -2.0};
migraphx::parameter_map params0; migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {9}}; migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {9}};
params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data()); params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data());
...@@ -2803,16 +2804,16 @@ TEST_CASE(identity_test) ...@@ -2803,16 +2804,16 @@ TEST_CASE(identity_test)
EXPECT(std::equal(data.begin(), data.end(), results_vector.begin())); EXPECT(std::equal(data.begin(), data.end(), results_vector.begin()));
} }
TEST_CASE(identity_dynamic_test) TEST_CASE(identity_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{2, 4, 0}, {2, 4, 0}}}; migraphx::shape s{migraphx::shape::float_type, {{2, 4, 0}, {2, 4, 0}}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
std::vector<int> input_data{1, 2, 3, 4};
mm->add_instruction(migraphx::make_op("identity"), input); mm->add_instruction(migraphx::make_op("identity"), input);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
std::vector<int> input_data{1, 2, 3, 4};
migraphx::parameter_map params0; migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::int32_type, {2, 2}}; migraphx::shape input_fixed_shape0{migraphx::shape::int32_type, {2, 2}};
params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data()); params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data());
...@@ -3049,17 +3050,17 @@ TEST_CASE(isnan_test) ...@@ -3049,17 +3050,17 @@ TEST_CASE(isnan_test)
} }
} }
TEST_CASE(isnan_dynamic_test) TEST_CASE(isnan_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{2, 2, 0}, {3, 8, 0}}}; migraphx::shape s{migraphx::shape::float_type, {{2, 2, 0}, {3, 8, 0}}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
auto nan_val = std::numeric_limits<float>::quiet_NaN(); auto nan_val = std::numeric_limits<float>::quiet_NaN();
std::vector<float> input_data = {1.2, 5.2, nan_val, nan_val, 0., 100.};
mm->add_instruction(migraphx::make_op("isnan"), input); mm->add_instruction(migraphx::make_op("isnan"), input);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
std::vector<float> input_data = {1.2, 5.2, nan_val, nan_val, 0., 100.};
migraphx::parameter_map params0; migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {2, 3}}; migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {2, 3}};
params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data()); params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data());
...@@ -3420,17 +3421,17 @@ TEST_CASE(log_test) ...@@ -3420,17 +3421,17 @@ TEST_CASE(log_test)
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(log_dynamic_test) TEST_CASE(log_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0}; migraphx::shape::dynamic_dimension dd{3, 8, 0};
migraphx::shape s{migraphx::shape::float_type, {dd}}; migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
std::vector<float> input_data = {1, 2, 3};
mm->add_instruction(migraphx::make_op("log"), input); mm->add_instruction(migraphx::make_op("log"), input);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
std::vector<float> input_data = {1, 2, 3};
migraphx::parameter_map params0; migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {3}}; migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {3}};
params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data()); params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data());
...@@ -4090,7 +4091,7 @@ TEST_CASE(fmod_test) ...@@ -4090,7 +4091,7 @@ TEST_CASE(fmod_test)
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(fmod_dynamic_test) TEST_CASE(fmod_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
...@@ -4386,16 +4387,17 @@ TEST_CASE(neg_test) ...@@ -4386,16 +4387,17 @@ TEST_CASE(neg_test)
EXPECT(migraphx::verify_range(result_vector, gold)); EXPECT(migraphx::verify_range(result_vector, gold));
} }
TEST_CASE(neg_dynamic_test) TEST_CASE(neg_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{2, 4, 0}, {3, 3, 0}}}; migraphx::shape s{migraphx::shape::float_type, {{2, 4, 0}, {3, 3, 0}}};
std::vector<float> a = {1.0f, 1.3f, -1.2f, 0.0f, -100.f, 200.f}; auto input = mm->add_parameter("X", s);
auto input = mm->add_parameter("X", s); auto ret = mm->add_instruction(migraphx::make_op("neg"), input);
auto ret = mm->add_instruction(migraphx::make_op("neg"), input);
mm->add_return({ret}); mm->add_return({ret});
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
std::vector<float> a = {1.0f, 1.3f, -1.2f, 0.0f, -100.f, 200.f};
migraphx::parameter_map params0; migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {2, 3}}; migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {2, 3}};
params0["X"] = migraphx::argument(input_fixed_shape0, a.data()); params0["X"] = migraphx::argument(input_fixed_shape0, a.data());
...@@ -4407,7 +4409,7 @@ TEST_CASE(neg_dynamic_test) ...@@ -4407,7 +4409,7 @@ TEST_CASE(neg_dynamic_test)
EXPECT(migraphx::verify_range(result_vector, gold)); EXPECT(migraphx::verify_range(result_vector, gold));
} }
TEST_CASE(nms_dynamic_out_test) TEST_CASE(nms_dyn_out_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
...@@ -4442,7 +4444,7 @@ TEST_CASE(nms_dynamic_out_test) ...@@ -4442,7 +4444,7 @@ TEST_CASE(nms_dynamic_out_test)
EXPECT(migraphx::verify_range(result, gold)); EXPECT(migraphx::verify_range(result, gold));
} }
TEST_CASE(nms_dynamic_batch_test) TEST_CASE(nms_dyn_batch_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
...@@ -4488,7 +4490,7 @@ TEST_CASE(nms_dynamic_batch_test) ...@@ -4488,7 +4490,7 @@ TEST_CASE(nms_dynamic_batch_test)
EXPECT(migraphx::verify_range(result, gold)); EXPECT(migraphx::verify_range(result, gold));
} }
TEST_CASE(nms_dynamic_boxes_test) TEST_CASE(nms_dyn_boxes_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
...@@ -4531,7 +4533,7 @@ TEST_CASE(nms_dynamic_boxes_test) ...@@ -4531,7 +4533,7 @@ TEST_CASE(nms_dynamic_boxes_test)
EXPECT(migraphx::verify_range(result, gold)); EXPECT(migraphx::verify_range(result, gold));
} }
TEST_CASE(nms_dynamic_classes_test) TEST_CASE(nms_dyn_classes_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
...@@ -4777,17 +4779,17 @@ TEST_CASE(not_test) ...@@ -4777,17 +4779,17 @@ TEST_CASE(not_test)
} }
} }
TEST_CASE(not_dynamic_test) TEST_CASE(not_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0}; migraphx::shape::dynamic_dimension dd{3, 8, 0};
migraphx::shape s{migraphx::shape::float_type, {dd}}; migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
std::vector<float> input_data{0, 8, 1, -32};
mm->add_instruction(migraphx::make_op("not"), input); mm->add_instruction(migraphx::make_op("not"), input);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
std::vector<float> input_data{0, 8, 1, -32};
migraphx::parameter_map params0; migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {4}}; migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {4}};
params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data()); params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data());
...@@ -5511,17 +5513,17 @@ TEST_CASE(recip_test) ...@@ -5511,17 +5513,17 @@ TEST_CASE(recip_test)
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(recip_dynamic_test) TEST_CASE(recip_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0}; migraphx::shape::dynamic_dimension dd{3, 8, 0};
migraphx::shape s{migraphx::shape::float_type, {dd}}; migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
std::vector<float> input_data{-0.5f, 0.1f, 0.5f};
mm->add_instruction(migraphx::make_op("recip"), input); mm->add_instruction(migraphx::make_op("recip"), input);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
std::vector<float> input_data{-0.5f, 0.1f, 0.5f};
migraphx::parameter_map params0; migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {3}}; migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {3}};
params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data()); params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data());
...@@ -5819,17 +5821,17 @@ TEST_CASE(relu_test) ...@@ -5819,17 +5821,17 @@ TEST_CASE(relu_test)
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(relu_dynamic_test) TEST_CASE(relu_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0}; migraphx::shape::dynamic_dimension dd{3, 8, 0};
migraphx::shape s{migraphx::shape::float_type, {dd}}; migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
std::vector<float> input_data{-1.f, 0.f, 1.f};
mm->add_instruction(migraphx::make_op("relu"), input); mm->add_instruction(migraphx::make_op("relu"), input);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
std::vector<float> input_data{-1.f, 0.f, 1.f};
migraphx::parameter_map params0; migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {3}}; migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {3}};
params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data()); params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data());
...@@ -6124,17 +6126,17 @@ TEST_CASE(round_test) ...@@ -6124,17 +6126,17 @@ TEST_CASE(round_test)
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(round_dynamic_test) TEST_CASE(round_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{4, 10, 0}; migraphx::shape::dynamic_dimension dd{4, 10, 0};
migraphx::shape s{migraphx::shape::float_type, {dd}}; migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
std::vector<float> input_data{1.1, 1.5, 1.6, -1.1, -1.5, -1.6, 0.0, 2.0, -2.0};
mm->add_instruction(migraphx::make_op("round"), input); mm->add_instruction(migraphx::make_op("round"), input);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
std::vector<float> input_data{1.1, 1.5, 1.6, -1.1, -1.5, -1.6, 0.0, 2.0, -2.0};
migraphx::parameter_map params0; migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {9}}; migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {9}};
params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data()); params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data());
...@@ -6160,17 +6162,17 @@ TEST_CASE(rsqrt_test) ...@@ -6160,17 +6162,17 @@ TEST_CASE(rsqrt_test)
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(rsqrt_dynamic_test) TEST_CASE(rsqrt_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0}; migraphx::shape::dynamic_dimension dd{3, 8, 0};
migraphx::shape s{migraphx::shape::float_type, {dd}}; migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
std::vector<float> input_data{4.0, 16.0, 64.0};
mm->add_instruction(migraphx::make_op("rsqrt"), input); mm->add_instruction(migraphx::make_op("rsqrt"), input);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
std::vector<float> input_data{4.0, 16.0, 64.0};
migraphx::parameter_map params0; migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {3}}; migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {3}};
params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data()); params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data());
...@@ -6729,16 +6731,16 @@ TEST_CASE(sigmoid_test) ...@@ -6729,16 +6731,16 @@ TEST_CASE(sigmoid_test)
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(sigmoid_dynamic_test) TEST_CASE(sigmoid_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{2, 4, 0}, {2, 2, 0}}}; migraphx::shape s{migraphx::shape::float_type, {{2, 4, 0}, {2, 2, 0}}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
std::vector<float> input_data{-1, 2, -3, 4};
mm->add_instruction(migraphx::make_op("sigmoid"), input); mm->add_instruction(migraphx::make_op("sigmoid"), input);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
std::vector<float> input_data{-1, 2, -3, 4};
migraphx::parameter_map params0; migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {2, 2}}; migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {2, 2}};
params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data()); params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data());
...@@ -6765,17 +6767,17 @@ TEST_CASE(sign_test) ...@@ -6765,17 +6767,17 @@ TEST_CASE(sign_test)
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(sign_dynamic_test) TEST_CASE(sign_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0}; migraphx::shape::dynamic_dimension dd{3, 8, 0};
migraphx::shape s{migraphx::shape::float_type, {dd}}; migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
std::vector<float> input_data{1.02481645, 0.85643062, -0.03404123, -0.92791926, 0.0};
mm->add_instruction(migraphx::make_op("sign"), input); mm->add_instruction(migraphx::make_op("sign"), input);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
std::vector<float> input_data{1.02481645, 0.85643062, -0.03404123, -0.92791926, 0.0};
migraphx::parameter_map params0; migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {5}}; migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {5}};
params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data()); params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data());
...@@ -6804,17 +6806,17 @@ TEST_CASE(sin_test) ...@@ -6804,17 +6806,17 @@ TEST_CASE(sin_test)
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(sin_dynamic_test) TEST_CASE(sin_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0}; migraphx::shape::dynamic_dimension dd{3, 8, 0};
migraphx::shape s{migraphx::shape::float_type, {dd}}; migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
std::vector<float> input_data = {-1, 0, 1};
mm->add_instruction(migraphx::make_op("sin"), input); mm->add_instruction(migraphx::make_op("sin"), input);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
std::vector<float> input_data = {-1, 0, 1};
migraphx::parameter_map params0; migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {3}}; migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {3}};
params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data()); params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data());
...@@ -7336,11 +7338,6 @@ TEST_CASE(transpose_test) ...@@ -7336,11 +7338,6 @@ TEST_CASE(transpose_test)
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), l); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), l);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
result.visit([&](auto output) {
std::vector<size_t> new_lens = {1, 3, 2, 2};
EXPECT(bool{output.get_shape().lens() == new_lens});
});
} }
{ {
migraphx::program p; migraphx::program p;
...@@ -7360,6 +7357,32 @@ TEST_CASE(transpose_test) ...@@ -7360,6 +7357,32 @@ TEST_CASE(transpose_test)
} }
} }
TEST_CASE(transpose_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{1, 4, 0}, {2, 2, 0}, {2, 2, 0}, {3, 3, 0}}};
auto l = mm->add_parameter("X", s);
std::vector<int64_t> perm = {0, 3, 1, 2};
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), l);
p.compile(migraphx::ref::target{});
std::vector<float> data(12);
std::iota(data.begin(), data.end(), 0);
migraphx::parameter_map params;
migraphx::shape input_fixed_shape{migraphx::shape::float_type, {1, 2, 2, 3}};
params["X"] = migraphx::argument(input_fixed_shape, data.data());
auto result = p.eval(params).back();
std::vector<size_t> new_lens = {1, 3, 2, 2};
EXPECT(result.get_shape().lens() == new_lens);
std::vector<float> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11};
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(unsqueeze_test) TEST_CASE(unsqueeze_test)
{ {
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment