Commit 00d5d880 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into mi100_opts

parents 00d90ca8 f60c3815
slice_5arg_reverse_test:
9arg_step"Constant*#
value** Bstep
Barg_axis"Constant*,
value* *Baxis
@arg_end"Constant*+
value**Bend
D arg_start"Constant*-
value*!*Bstart
5
0
arg_start
arg_end
arg_axis
arg_step1"Sliceslice_5arg_reverse_testZ
0


b
1


B
\ No newline at end of file
slice_5arg_step_test:
9arg_step"Constant*#
value** Bstep
Barg_axis"Constant*,
value* *Baxis
@arg_end"Constant*+
value**Bend
D arg_start"Constant*-
value*!*Bstart
5
0
arg_start
arg_end
arg_axis
arg_step1"Sliceslice_5arg_step_testZ
0


b
1


B
\ No newline at end of file
slice_5arg_test: slice_5arg_test:
0arg_step"Constant* 0arg_step"Constant*
value**Bstep value**Bstep
Barg_axis"Constant*, Barg_axis"Constant*,
...@@ -20,4 +20,4 @@ D arg_start"Constant*- ...@@ -20,4 +20,4 @@ D arg_start"Constant*-
1 1
 
 
B B
\ No newline at end of file \ No newline at end of file
...@@ -76,6 +76,7 @@ TEST_CASE(if_else_test) ...@@ -76,6 +76,7 @@ TEST_CASE(if_else_test)
std::vector<float> data = {0.0625, 0.75, -0.0625, 0.125, -0.125, -0.5625}; std::vector<float> data = {0.0625, 0.75, -0.0625, 0.125, -0.125, -0.5625};
migraphx::parameter_map pp; migraphx::parameter_map pp;
pp["x"] = migraphx::argument(s_data, data.data());
pp["y"] = migraphx::argument(s_data, data.data()); pp["y"] = migraphx::argument(s_data, data.data());
auto result = p.eval(pp).back(); auto result = p.eval(pp).back();
...@@ -160,6 +161,55 @@ TEST_CASE(if_pl_test) ...@@ -160,6 +161,55 @@ TEST_CASE(if_pl_test)
} }
} }
TEST_CASE(if_tuple_test)
{
auto run_prog = [](bool cond) {
migraphx::program p = migraphx::parse_onnx("if_tuple_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape xs{migraphx::shape::float_type, {1, 4}};
migraphx::shape ys{migraphx::shape::float_type, {3, 4}};
migraphx::shape cond_s{migraphx::shape::bool_type};
std::vector<float> x_data(xs.elements(), 1.0f);
std::vector<float> y_data(ys.elements(), 2.0f);
std::vector<char> cond_data{static_cast<char>(cond)};
migraphx::parameter_map pp;
pp["x"] = migraphx::argument(xs, x_data.data());
pp["y"] = migraphx::argument(ys, y_data.data());
pp["cond"] = migraphx::argument(cond_s, cond_data.data());
auto results = p.eval(pp);
std::vector<std::vector<float>> rets;
for(const auto& arg : results)
{
std::vector<float> vec;
arg.visit([&](auto output) { vec.assign(output.begin(), output.end()); });
rets.push_back(vec);
}
return rets;
};
// then branch
{
auto results = run_prog(true);
std::vector<float> gold0(4, 2.0f);
std::vector<float> gold1(12, 4.0f);
EXPECT(migraphx::verify_range(results.at(0), gold0));
EXPECT(migraphx::verify_range(results.at(1), gold1));
}
// else branch
{
auto results = run_prog(false);
std::vector<float> gold0(4, 3.0f);
std::vector<float> gold1(12, 5.0f);
EXPECT(migraphx::verify_range(results.at(0), gold0));
EXPECT(migraphx::verify_range(results.at(1), gold1));
}
}
TEST_CASE(instance_norm_test) TEST_CASE(instance_norm_test)
{ {
migraphx::program p = migraphx::parse_onnx("instance_norm_val_test.onnx"); migraphx::program p = migraphx::parse_onnx("instance_norm_val_test.onnx");
...@@ -240,7 +290,84 @@ TEST_CASE(lessorequal_test) ...@@ -240,7 +290,84 @@ TEST_CASE(lessorequal_test)
EXPECT(migraphx::verify_range(result_vector, gold)); EXPECT(migraphx::verify_range(result_vector, gold));
} }
TEST_CASE(resize_test) TEST_CASE(resize_downsample_f_test)
{
migraphx::program p = migraphx::parse_onnx("resize_downsample_f_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 4}};
std::vector<float> dx(sx.elements());
std::iota(dx.begin(), dx.end(), 0.0f);
migraphx::parameter_map pp;
pp["X"] = migraphx::argument(sx, dx.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.0f, 3.0f};
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(resize_upsample_linear_ac_test)
{
migraphx::program p = migraphx::parse_onnx("resize_upsample_linear_ac_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 2}};
std::vector<float> dx = {1.0f, 2.0f, 3.0f, 4.0f};
migraphx::parameter_map pp;
pp["X"] = migraphx::argument(sx, dx.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1,
4.0f / 3,
5.0f / 3,
2,
5.0f / 3,
2,
7.0f / 3,
8.0f / 3,
7.0f / 3,
8.0f / 3,
3,
10.0f / 3,
3,
10.0f / 3,
11.0f / 3,
4};
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(resize_upsample_linear_test)
{
migraphx::program p = migraphx::parse_onnx("resize_upsample_linear_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 2}};
std::vector<float> dx = {1.0f, 2.0f, 3.0f, 4.0f};
migraphx::parameter_map pp;
pp["X"] = migraphx::argument(sx, dx.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {
1, 1.25, 1.75, 2, 1.5, 1.75, 2.25, 2.5, 2.5, 2.75, 3.25, 3.5, 3, 3.25, 3.75, 4};
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(resize_upsample_pf_test)
{ {
migraphx::program p = migraphx::parse_onnx("resize_upsample_pf_test.onnx"); migraphx::program p = migraphx::parse_onnx("resize_upsample_pf_test.onnx");
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
...@@ -281,6 +408,85 @@ TEST_CASE(selu_test) ...@@ -281,6 +408,85 @@ TEST_CASE(selu_test)
EXPECT(migraphx::verify_range(result_vector, gold)); EXPECT(migraphx::verify_range(result_vector, gold));
} }
TEST_CASE(slice_test)
{
migraphx::program p = migraphx::parse_onnx("slice_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape sh_data{migraphx::shape::float_type, {3, 2}};
std::vector<float> data = {0, 1, 2, 3, 4, 5};
migraphx::parameter_map pp;
pp["0"] = migraphx::argument(sh_data, data.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {2, 3};
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(slice_5arg_test)
{
migraphx::program p = migraphx::parse_onnx("slice_5arg_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape sh_data{migraphx::shape::float_type, {5, 5}}; // start
std::vector<float> data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24};
migraphx::parameter_map pp;
pp["0"] = migraphx::argument(sh_data, data.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {10, 11, 12, 13, 15, 16, 17, 18};
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(slice_reverse_test)
{
migraphx::program p = migraphx::parse_onnx("slice_5arg_reverse_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape sh_data{migraphx::shape::float_type, {5, 5}}; // start
std::vector<float> data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24};
migraphx::parameter_map pp;
pp["0"] = migraphx::argument(sh_data, data.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {14, 13, 12, 11, 19, 18, 17, 16};
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(slice_step_test)
{
migraphx::program p = migraphx::parse_onnx("slice_5arg_step_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape sh_data{migraphx::shape::float_type, {5, 5}}; // start
std::vector<float> data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24};
migraphx::parameter_map pp;
pp["0"] = migraphx::argument(sh_data, data.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {14, 12};
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(upsample_test) TEST_CASE(upsample_test)
{ {
migraphx::program p = migraphx::parse_onnx("upsample_test.onnx"); migraphx::program p = migraphx::parse_onnx("upsample_test.onnx");
......
...@@ -68,6 +68,43 @@ TEST_CASE(batch_norm_inference_shape) ...@@ -68,6 +68,43 @@ TEST_CASE(batch_norm_inference_shape)
throws_shape(migraphx::make_op("batch_norm_inference"), s, vars, vars, vars, vars, vars); throws_shape(migraphx::make_op("batch_norm_inference"), s, vars, vars, vars, vars, vars);
} }
TEST_CASE(broadcast)
{
{
std::vector<std::size_t> lens{1, 1};
migraphx::shape input{migraphx::shape::float_type, {1}, {0}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1}, {0, 0}},
migraphx::make_op("broadcast", {{"axis", 0}, {"dims", lens}}),
input);
}
{
std::vector<std::size_t> lens{1, 1};
migraphx::shape input{migraphx::shape::float_type, {2}};
throws_shape(migraphx::op::broadcast{1, lens}, input);
}
{
std::vector<std::size_t> lens{2, 2};
migraphx::shape input{migraphx::shape::float_type, {1, 2}};
throws_shape(migraphx::op::broadcast{1, lens}, input);
}
{
std::vector<std::size_t> lens{3, 2, 4, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 3}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 2, 4, 3}, {0, 0, 3, 1}},
migraphx::make_op("broadcast", {{"axis", 2}, {"dims", lens}}),
input);
}
{
std::vector<std::size_t> lens{3, 2, 4, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 4}};
throws_shape(migraphx::make_op("broadcast", {{"axis", 2}, {"dims", lens}}), input);
}
}
TEST_CASE(convolution_shape) TEST_CASE(convolution_shape)
{ {
migraphx::shape output{migraphx::shape::float_type, {4, 4, 1, 1}}; migraphx::shape output{migraphx::shape::float_type, {4, 4, 1, 1}};
...@@ -106,6 +143,24 @@ TEST_CASE(convolution_shape) ...@@ -106,6 +143,24 @@ TEST_CASE(convolution_shape)
throws_shape(migraphx::make_op("convolution"), input_3d, weights_3d); throws_shape(migraphx::make_op("convolution"), input_3d, weights_3d);
} }
TEST_CASE(contiguous_shape)
{
migraphx::shape output{migraphx::shape::float_type, {2, 2}};
migraphx::shape input{migraphx::shape::float_type, {2, 2}, {1, 2}};
expect_shape(output, migraphx::make_op("contiguous"), input);
throws_shape(migraphx::make_op("contiguous"), input, input);
migraphx::shape single{migraphx::shape::float_type, {2}};
expect_shape(single, migraphx::make_op("contiguous"), single);
}
TEST_CASE(contiguous_shape_scalar)
{
migraphx::shape output{migraphx::shape::float_type};
migraphx::shape input{migraphx::shape::float_type};
expect_shape(output, migraphx::make_op("contiguous"), input);
}
TEST_CASE(deconvolution_shape) TEST_CASE(deconvolution_shape)
{ {
migraphx::shape input{migraphx::shape::float_type, {4, 4, 1, 1}}; migraphx::shape input{migraphx::shape::float_type, {4, 4, 1, 1}};
...@@ -137,141 +192,6 @@ TEST_CASE(deconvolution_shape) ...@@ -137,141 +192,6 @@ TEST_CASE(deconvolution_shape)
weights_3d); weights_3d);
} }
TEST_CASE(quant_convolution_shape)
{
migraphx::shape output{migraphx::shape::int32_type, {4, 4, 1, 1}};
migraphx::shape input{migraphx::shape::int8_type, {4, 3, 3, 3}};
migraphx::shape weights{migraphx::shape::int8_type, {4, 3, 3, 3}};
expect_shape(output, migraphx::make_op("quant_convolution"), input, weights);
throws_shape(migraphx::make_op("quant_convolution"), input);
throws_shape(migraphx::make_op("quant_convolution",
{{"padding", {0}}, {"stride", {1, 1}}, {"dilation", {1, 1}}}),
input,
weights);
throws_shape(migraphx::make_op("quant_convolution",
{{"padding", {0}}, {"stride", {1}}, {"dilation", {1}}}),
input,
weights);
migraphx::shape input2{migraphx::shape::int32_type, {3, 3}};
migraphx::shape weights2{migraphx::shape::float_type, {3, 3}};
throws_shape(migraphx::make_op("quant_convolution"), input2, weights2);
throws_shape(migraphx::make_op("quant_convolution"), input2, weights);
migraphx::shape input3{migraphx::shape::int32_type, {4, 3, 3, 3}};
migraphx::shape weight3{migraphx::shape::float_type, {4, 3, 3, 3}};
throws_shape(migraphx::make_op("quant_convolution"), input3, weights);
throws_shape(migraphx::make_op("quant_convolution"), input, weight3);
throws_shape(migraphx::make_op("quant_convolution"), input3, weight3);
}
TEST_CASE(pooling_shape)
{
migraphx::shape output{migraphx::shape::float_type, {4, 3, 1, 1}};
migraphx::shape input{migraphx::shape::float_type, {4, 3, 3, 3}};
throws_shape(
migraphx::make_op("pooling",
{{"mode", "max"}, {"padding", {1}}, {"stride", {0}}, {"lengths", {1}}}),
input);
expect_shape(
output,
migraphx::make_op(
"pooling",
{{"mode", "max"}, {"padding", {0, 0}}, {"stride", {3, 3}}, {"lengths", {1, 1}}}),
input);
migraphx::shape output1{migraphx::shape::float_type, {4, 3, 2, 2}};
expect_shape(output1,
migraphx::make_op("pooling",
{{"mode", "max"},
{"padding", {0, 0}},
{"stride", {3, 3}},
{"lengths", {1, 1}},
{"ceil_mode", true}}),
input);
}
TEST_CASE(inconsistent_attr_shape)
{
migraphx::shape input{migraphx::shape::float_type, {4, 3, 3, 3}};
migraphx::shape weights{migraphx::shape::float_type, {4, 3, 3, 3}};
throws_shape(migraphx::make_op("convolution",
{{"padding", {1, 1}}, {"stride", {2}}, {"dilation", {3, 3, 3}}}),
input,
weights);
throws_shape(migraphx::make_op("deconvolution",
{{"padding", {1, 1}}, {"stride", {2}}, {"dilation", {3, 3, 3}}}),
input,
weights);
throws_shape(
migraphx::make_op(
"pooling", {{"mode", "max"}, {"padding", {1}}, {"stride", {0}}, {"lengths", {1, 1}}}),
input);
}
TEST_CASE(transpose_shape)
{
migraphx::shape input{migraphx::shape::float_type, {2, 2}};
migraphx::shape output{migraphx::shape::float_type, {2, 2}, {1, 2}};
expect_shape(input, migraphx::make_op("transpose", {{"dims", {0, 1}}}), input);
expect_shape(output, migraphx::make_op("transpose", {{"dims", {1, 0}}}), input);
expect_shape(output, migraphx::make_op("transpose"), input);
throws_shape(migraphx::make_op("transpose", {{"dims", {1, 2}}}), input);
}
TEST_CASE(contiguous_shape)
{
migraphx::shape output{migraphx::shape::float_type, {2, 2}};
migraphx::shape input{migraphx::shape::float_type, {2, 2}, {1, 2}};
expect_shape(output, migraphx::make_op("contiguous"), input);
throws_shape(migraphx::make_op("contiguous"), input, input);
migraphx::shape single{migraphx::shape::float_type, {2}};
expect_shape(single, migraphx::make_op("contiguous"), single);
}
TEST_CASE(contiguous_shape_scalar)
{
migraphx::shape output{migraphx::shape::float_type};
migraphx::shape input{migraphx::shape::float_type};
expect_shape(output, migraphx::make_op("contiguous"), input);
}
TEST_CASE(reshape_shape)
{
migraphx::shape input{migraphx::shape::float_type, {24, 1, 1, 1}};
for(auto&& new_shape :
std::vector<std::vector<int64_t>>{{8, 3, 1, 1}, {1, 3, 4, 2}, {1, 3, 4, 2}})
{
std::vector<std::size_t> lens(new_shape.size());
std::copy(new_shape.begin(), new_shape.end(), lens.begin());
migraphx::shape output{migraphx::shape::float_type, lens};
expect_shape(output, migraphx::make_op("reshape", {{"dims", new_shape}}), input);
}
for(auto&& new_shape :
std::vector<std::vector<int64_t>>{{8, 3, 2, 2}, {1, 3, -1, -1}, {3, 0, 0}, {3, 2, 0}})
{
throws_shape(migraphx::make_op("reshape", {{"dims", new_shape}}), input);
}
std::vector<std::pair<std::vector<int64_t>, migraphx::shape>> minus1_tests{
{{2, -1, 3}, {migraphx::shape::float_type, {2, 4, 3}}},
{{0, -1, 0}, {migraphx::shape::float_type, {24, 1, 1}}},
{{2, -1, 0}, {migraphx::shape::float_type, {2, 12, 1}}},
{{0, 0, -1}, {migraphx::shape::float_type, {24, 1, 1}}},
{{2, 0, -1}, {migraphx::shape::float_type, {2, 1, 12}}},
{{-1, 2, 3}, {migraphx::shape::float_type, {4, 2, 3}}},
{{-1, 0, 3}, {migraphx::shape::float_type, {8, 1, 3}}},
{{-1, 0, 0}, {migraphx::shape::float_type, {24, 1, 1}}},
{{-1, 3, 0}, {migraphx::shape::float_type, {8, 3, 1}}}};
for(auto& it : minus1_tests)
{
expect_shape(it.second, migraphx::make_op("reshape", {{"dims", it.first}}), input);
}
}
TEST_CASE(flatten_shape) TEST_CASE(flatten_shape)
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 4, 6, 8}}; migraphx::shape input{migraphx::shape::float_type, {2, 4, 6, 8}};
...@@ -300,173 +220,48 @@ TEST_CASE(flatten_shape) ...@@ -300,173 +220,48 @@ TEST_CASE(flatten_shape)
throws_shape(migraphx::make_op("flatten", {{"axis", -5}}), input); throws_shape(migraphx::make_op("flatten", {{"axis", -5}}), input);
} }
TEST_CASE(slice_shape) TEST_CASE(gather)
{ {
migraphx::shape input{migraphx::shape::int32_type, {2, 2, 3}}; {
expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 2, 2}, {6, 3, 1}}, migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1}}, {"ends", {3}}}), migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
input); int axis = 1;
expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 2, 2}, {6, 3, 1}}, expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 2, 3, 4, 5}},
migraphx::make_op( migraphx::make_op("gather", {{"axis", axis}}),
"slice", {{"axes", {0, 1, 2}}, {"starts", {0, 0, 1}}, {"ends", {2, 2, 3}}}), input,
input); indices);
expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 2, 1}, {6, 3, 1}}, }
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {2}}, {"ends", {10}}}),
input);
}
TEST_CASE(multibroadcast)
{
{ {
std::vector<std::size_t> lens{4, 2, 5, 3}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape input{migraphx::shape::float_type, {2, 1, 3}}; migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 3, 0, 1}}, int axis = -4;
migraphx::make_op("multibroadcast", {{"output_lens", lens}}), expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 3, 4, 5}},
input); migraphx::make_op("gather", {{"axis", axis}}),
input,
indices);
} }
{ {
std::vector<std::size_t> lens{4, 2, 5, 3}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape input{migraphx::shape::float_type, {2, 1, 1}}; migraphx::shape indices{migraphx::shape::int32_type, {1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 1, 0, 0}}, int axis = -4;
migraphx::make_op("multibroadcast", {{"output_lens", lens}}), expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 3, 4, 5}},
input); migraphx::make_op("gather", {{"axis", axis}}),
input,
indices);
} }
{ {
std::vector<std::size_t> lens{4, 2, 5, 3}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape input{migraphx::shape::float_type, {5, 1}}; migraphx::shape indices{migraphx::shape::int32_type};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 0, 1, 0}}, int axis = -4;
migraphx::make_op("multibroadcast", {{"output_lens", lens}}), expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 4, 5}},
input); migraphx::make_op("gather", {{"axis", axis}}),
input,
indices);
} }
{
std::vector<std::size_t> lens{4, 2, 5, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 1, 1, 1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {1, 0, 0, 0}},
migraphx::make_op("multibroadcast", {{"output_lens", lens}}),
input);
}
{
std::vector<std::size_t> lens{4, 2, 5, 3};
migraphx::shape input{migraphx::shape::float_type, {3}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 0, 0, 1}},
migraphx::make_op("multibroadcast", {{"output_lens", lens}}),
input);
}
{
std::vector<std::size_t> lens{4, 4, 1, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 1, 3}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 3, 3, 1}},
migraphx::make_op("multibroadcast", {{"output_lens", lens}}),
input);
}
{
std::vector<std::size_t> lens{4, 1, 1, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 1, 1, 1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {1, 1, 1, 0}},
migraphx::make_op("multibroadcast", {{"output_lens", lens}}),
input);
}
{
std::vector<std::size_t> lens{4, 1, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 1, 1, 1}};
throws_shape(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), input);
}
{
std::vector<std::size_t> lens{4, 1, 3};
migraphx::shape input{migraphx::shape::float_type, {}};
throws_shape(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), input);
}
{
std::vector<std::size_t> lens{2, 3, 4, 5};
migraphx::shape input{migraphx::shape::float_type, {3, 4}};
throws_shape(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), input);
}
{
std::vector<std::size_t> lens{2, 3, 4, 5};
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4}};
throws_shape(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), input);
}
}
TEST_CASE(broadcast)
{
{
std::vector<std::size_t> lens{1, 1};
migraphx::shape input{migraphx::shape::float_type, {1}, {0}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1}, {0, 0}},
migraphx::make_op("broadcast", {{"axis", 0}, {"dims", lens}}),
input);
}
{
std::vector<std::size_t> lens{1, 1};
migraphx::shape input{migraphx::shape::float_type, {2}};
throws_shape(migraphx::op::broadcast{1, lens}, input);
}
{
std::vector<std::size_t> lens{2, 2};
migraphx::shape input{migraphx::shape::float_type, {1, 2}};
throws_shape(migraphx::op::broadcast{1, lens}, input);
}
{
std::vector<std::size_t> lens{3, 2, 4, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 3}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 2, 4, 3}, {0, 0, 3, 1}},
migraphx::make_op("broadcast", {{"axis", 2}, {"dims", lens}}),
input);
}
{
std::vector<std::size_t> lens{3, 2, 4, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 4}};
throws_shape(migraphx::make_op("broadcast", {{"axis", 2}, {"dims", lens}}), input);
}
}
TEST_CASE(gather)
{
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
int axis = 1;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 2, 3, 4, 5}},
migraphx::make_op("gather", {{"axis", axis}}),
input,
indices);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
int axis = -4;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 3, 4, 5}},
migraphx::make_op("gather", {{"axis", axis}}),
input,
indices);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {1}};
int axis = -4;
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 3, 4, 5}},
migraphx::make_op("gather", {{"axis", axis}}),
input,
indices);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type};
int axis = -4;
expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 4, 5}},
migraphx::make_op("gather", {{"axis", axis}}),
input,
indices);
}
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type}; migraphx::shape indices{migraphx::shape::int32_type};
...@@ -512,219 +307,546 @@ TEST_CASE(gather) ...@@ -512,219 +307,546 @@ TEST_CASE(gather)
} }
} }
template <class T> // 3 input arguments
void test_softmax_variations() TEST_CASE(gemm)
{ {
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{0}, input); migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
} migraphx::shape s_m3{migraphx::shape::float_type, {1}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{1}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{2}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{3}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 4;
throws_shape(T{axis}, input);
} }
}
TEST_CASE(softmax) { test_softmax_variations<migraphx::op::softmax>(); }
TEST_CASE(logsoftmax) { test_softmax_variations<migraphx::op::logsoftmax>(); }
TEST_CASE(test_argmax)
{
{ {
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}}; migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {1, 3, 4, 5}}, migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::make_op("argmax", {{"axis", 0}}), migraphx::shape s_m3{migraphx::shape::float_type, {1, 1}};
input); throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
} }
{ {
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}}; migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 1, 4, 5}}, migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::make_op("argmax", {{"axis", 1}}), migraphx::shape s_m3{migraphx::shape::float_type, {8}};
input); throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
} }
{ {
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}}; migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 1, 5}}, migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::make_op("argmax", {{"axis", 2}}), migraphx::shape s_m3{migraphx::shape::float_type, {4, 1}};
input); throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
} }
{ {
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}}; migraphx::shape s_m1{migraphx::shape::float_type, {4, 6}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 4, 1}}, migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::make_op("argmax", {{"axis", 3}}), migraphx::shape s_m3{migraphx::shape::float_type, {4, 8}};
input); throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
} }
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
throws_shape(migraphx::make_op("argmax", {{"axis", 4}}), input); migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {4}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
} }
}
TEST_CASE(test_argmin)
{
{ {
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}}; migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {1, 3, 4, 5}}, migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::make_op("argmin", {{"axis", 0}}), migraphx::shape s_m3{migraphx::shape::float_type, {4, 8}};
input); expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 8}},
migraphx::make_op("dot"),
s_m1,
s_m2,
s_m3);
} }
{ {
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}}; migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 1, 4, 5}}, migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}};
migraphx::make_op("argmin", {{"axis", 1}}), migraphx::shape s_m3{migraphx::shape::float_type, {1, 4, 8}};
input); expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 4, 8}},
migraphx::make_op("dot"),
s_m1,
s_m2,
s_m3);
} }
{ {
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}}; migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 6}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 1, 5}}, migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}};
migraphx::make_op("argmin", {{"axis", 2}}), migraphx::shape s_m3{migraphx::shape::float_type, {1, 4, 8}};
input); throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
} }
{ {
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}}; migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 4, 1}}, migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}};
migraphx::make_op("argmin", {{"axis", 3}}), migraphx::shape s_m3{migraphx::shape::float_type, {4, 8}};
input); throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
} }
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}};
throws_shape(migraphx::make_op("argmin", {{"axis", 4}}), input); migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
} }
} }
TEST_CASE(test_scalar) TEST_CASE(get_tuple_elem_test)
{ {
migraphx::shape s1{migraphx::shape::float_type, {1}, {1}}; migraphx::shape s0{migraphx::shape::bool_type, {1, 1}};
migraphx::shape s2{migraphx::shape::float_type, {2, 3, 4, 5}, {0, 0, 0, 0}}; migraphx::shape s1{migraphx::shape::float_type, {2, 3}};
expect_shape(s2, migraphx::make_op("scalar", {{"scalar_bcst_dims", {2, 3, 4, 5}}}), s1); migraphx::shape s2{migraphx::shape::int32_type, {5, 6}};
migraphx::shape s_tuple({s0, s1, s2});
expect_shape(s0, migraphx::make_op("get_tuple_elem", {{"index", 0}}), s_tuple);
expect_shape(s1, migraphx::make_op("get_tuple_elem", {{"index", 1}}), s_tuple);
expect_shape(s2, migraphx::make_op("get_tuple_elem", {{"index", 2}}), s_tuple);
throws_shape(migraphx::make_op("get_tuple_elem", {{"index", 3}}), s_tuple);
throws_shape(migraphx::make_op("get_tuple_elem", {{"index", 0}}), s0);
throws_shape(migraphx::make_op("get_tuple_elem", {{"index", 1}}), s1);
throws_shape(migraphx::make_op("get_tuple_elem", {{"index", 0}}), s2);
} }
TEST_CASE(test_scalar_nelemnts) TEST_CASE(gru)
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; {
throws_shape(migraphx::make_op("scalar", {{"scalar_bcst_dims", {2, 3, 4, 5}}}), input); std::size_t batch_size = 2;
} std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
TEST_CASE(test_squeeze) migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
{ migraphx::shape w_shape{migraphx::shape::float_type,
migraphx::shape s1{migraphx::shape::float_type, {4, 1, 3, 1, 3}}; {num_dirct, 3 * hidden_size, input_size}};
migraphx::shape s2{migraphx::shape::float_type, {4, 1, 3, 3}}; migraphx::shape r_shape{migraphx::shape::float_type,
expect_shape(s2, migraphx::make_op("squeeze", {{"axes", {3}}}), s1); {num_dirct, 3 * hidden_size, hidden_size}};
} migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
TEST_CASE(test_squeeze_negative_axis) expect_shape(
{ migraphx::shape{migraphx::shape::float_type,
migraphx::shape s1{migraphx::shape::float_type, {4, 1, 3, 1, 3}}; {seq_len, num_dirct, batch_size, hidden_size}},
migraphx::shape s2{migraphx::shape::float_type, {4, 1, 3, 3}}; migraphx::make_op(
expect_shape(s2, migraphx::make_op("squeeze", {{"axes", {-2}}}), s1); "gru",
} {{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
TEST_CASE(test_squeeze_wrong_axis) {
{ std::size_t batch_size = 2;
migraphx::shape s1{migraphx::shape::float_type, {4, 1, 3, 1, 3}}; std::size_t seq_len = 2;
throws_shape(migraphx::make_op("squeeze", {{"axes", {0}}}), s1); std::size_t hidden_size = 4;
} std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
TEST_CASE(test_squeeze_all) migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
{ migraphx::shape w_shape{migraphx::shape::float_type,
migraphx::shape s1{migraphx::shape::float_type, {1}}; {num_dirct, 3 * hidden_size, input_size}};
migraphx::shape s2{migraphx::shape::float_type}; migraphx::shape r_shape{migraphx::shape::float_type,
expect_shape(s2, migraphx::make_op("squeeze", {{"axes", {0}}}), s1); {num_dirct, 3 * hidden_size, hidden_size}};
} migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
TEST_CASE(test_unsqueeze_scalar) expect_shape(
{ migraphx::shape{migraphx::shape::float_type,
migraphx::shape s1{migraphx::shape::float_type, {1}, {0}}; {seq_len, num_dirct, batch_size, hidden_size}},
migraphx::shape s2{migraphx::shape::float_type, {1}, {1}}; migraphx::make_op(
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {0}}}), s1); "gru",
} {{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
TEST_CASE(test_unsqueeze_scalar_tensor1) {
{ std::size_t batch_size = 2;
migraphx::shape s{migraphx::shape::float_type, {4, 3, 3}, {0, 0, 0}}; std::size_t seq_len = 2;
throws_shape(migraphx::make_op("unsqueeze", {{"axes", {-2}}}), s); std::size_t hidden_size = 4;
} std::size_t input_size = 3;
std::size_t num_dirct = 2;
float clip = 0.0f;
TEST_CASE(test_unsqueeze_scalar_tensor2) migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
{ migraphx::shape w_shape{migraphx::shape::float_type,
migraphx::shape s{migraphx::shape::float_type, {1, 1, 1}, {0, 0, 0}}; {num_dirct, 3 * hidden_size, input_size}};
throws_shape(migraphx::make_op("unsqueeze", {{"axes", {-2}}}), s); migraphx::shape r_shape{migraphx::shape::float_type,
} {num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
TEST_CASE(test_unsqueeze) expect_shape(
{ migraphx::shape{migraphx::shape::float_type,
migraphx::shape s1{migraphx::shape::float_type, {4, 3, 3}}; {seq_len, num_dirct, batch_size, hidden_size}},
migraphx::shape s2{migraphx::shape::float_type, {4, 3, 1, 3}}; migraphx::make_op(
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1); "gru",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
throws_shape(
migraphx::make_op(
"gru",
{{"hidden_size", hidden_size + 1},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
throws_shape(
migraphx::make_op(
"gru",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
throws_shape(
migraphx::make_op(
"gru",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
} }
TEST_CASE(test_unsqueeze_negative_axis) TEST_CASE(inconsistent_attr_shape)
{ {
migraphx::shape s1{migraphx::shape::float_type, {4, 3, 3}}; migraphx::shape input{migraphx::shape::float_type, {4, 3, 3, 3}};
migraphx::shape s2{migraphx::shape::float_type, {4, 3, 1, 3}}; migraphx::shape weights{migraphx::shape::float_type, {4, 3, 3, 3}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {-2}}}), s1); throws_shape(migraphx::make_op("convolution",
{{"padding", {1, 1}}, {"stride", {2}}, {"dilation", {3, 3, 3}}}),
input,
weights);
throws_shape(migraphx::make_op("deconvolution",
{{"padding", {1, 1}}, {"stride", {2}}, {"dilation", {3, 3, 3}}}),
input,
weights);
throws_shape(
migraphx::make_op(
"pooling", {{"mode", "max"}, {"padding", {1}}, {"stride", {0}}, {"lengths", {1, 1}}}),
input);
} }
template <class T> template <class T>
void test_reduce_ops() void test_softmax_variations()
{ {
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1}}, T{}, input); expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{0}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{1}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{2}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{3}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 4;
throws_shape(T{axis}, input);
}
}
TEST_CASE(logsoftmax) { test_softmax_variations<migraphx::op::logsoftmax>(); }
TEST_CASE(lstm)
{
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
expect_shape(
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape);
}
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
expect_shape(
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
} }
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
expect_shape( expect_shape(
migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1}}, T{{0, 1, 2, 3}}, input); migraphx::shape{migraphx::shape::float_type,
} {seq_len, num_dirct, batch_size, hidden_size}},
{ migraphx::make_op(
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; "lstm",
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 1, 1}}, T{{2, 3}}, input); {{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
} }
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; std::size_t batch_size = 2;
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 3, 4, 5}}, T{{0}}, input); std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
throws_shape(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size + 1},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
} }
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; std::size_t batch_size = 2;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 1}}, T{{-1}}, input); std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
throws_shape(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
} }
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; std::size_t batch_size = 2;
throws_shape(T{{4}}, input); std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
throws_shape(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
} }
} }
TEST_CASE(reduce_sum) { test_reduce_ops<migraphx::op::reduce_sum>(); }
TEST_CASE(reduce_mean) { test_reduce_ops<migraphx::op::reduce_mean>(); }
// 2 inputs arguments // 2 inputs arguments
TEST_CASE(matmul) TEST_CASE(matmul)
{ {
...@@ -825,523 +947,269 @@ TEST_CASE(matmul) ...@@ -825,523 +947,269 @@ TEST_CASE(matmul)
} }
} }
// 3 input arguments TEST_CASE(multibroadcast)
TEST_CASE(gemm)
{ {
{ {
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}}; std::vector<std::size_t> lens{4, 2, 5, 3};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}}; migraphx::shape input{migraphx::shape::float_type, {2, 1, 3}};
migraphx::shape s_m3{migraphx::shape::float_type, {1}}; expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 3, 0, 1}},
throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3); migraphx::make_op("multibroadcast", {{"output_lens", lens}}),
} input);
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {1, 1}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {8}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {4, 1}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 6}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {4, 8}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {4}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {4, 8}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 8}},
migraphx::make_op("dot"),
s_m1,
s_m2,
s_m3);
} }
{ {
migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}}; std::vector<std::size_t> lens{4, 2, 5, 3};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}}; migraphx::shape input{migraphx::shape::float_type, {2, 1, 1}};
migraphx::shape s_m3{migraphx::shape::float_type, {1, 4, 8}}; expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 1, 0, 0}},
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 4, 8}}, migraphx::make_op("multibroadcast", {{"output_lens", lens}}),
migraphx::make_op("dot"), input);
s_m1,
s_m2,
s_m3);
} }
{ {
migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 6}}; std::vector<std::size_t> lens{4, 2, 5, 3};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}}; migraphx::shape input{migraphx::shape::float_type, {5, 1}};
migraphx::shape s_m3{migraphx::shape::float_type, {1, 4, 8}}; expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 0, 1, 0}},
throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3); migraphx::make_op("multibroadcast", {{"output_lens", lens}}),
input);
} }
{ {
migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}}; std::vector<std::size_t> lens{4, 2, 5, 3};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}}; migraphx::shape input{migraphx::shape::float_type, {4, 1, 1, 1}};
migraphx::shape s_m3{migraphx::shape::float_type, {4, 8}}; expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {1, 0, 0, 0}},
throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3); migraphx::make_op("multibroadcast", {{"output_lens", lens}}),
input);
} }
{ {
migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}}; std::vector<std::size_t> lens{4, 2, 5, 3};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}}; migraphx::shape input{migraphx::shape::float_type, {3}};
migraphx::shape s_m3{migraphx::shape::float_type}; expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 0, 0, 1}},
throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3); migraphx::make_op("multibroadcast", {{"output_lens", lens}}),
input);
} }
}
// quant_dot
TEST_CASE(quant_dot_2args)
{
{ {
migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}}; std::vector<std::size_t> lens{4, 4, 1, 3};
migraphx::shape s_m2{migraphx::shape::int8_type, {4, 8}}; migraphx::shape input{migraphx::shape::float_type, {4, 1, 3}};
expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 8}}, expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 3, 3, 1}},
migraphx::make_op("quant_dot"), migraphx::make_op("multibroadcast", {{"output_lens", lens}}),
s_m1, input);
s_m2);
} }
{ {
migraphx::shape s_m1{migraphx::shape::int8_type, {3, 8}}; std::vector<std::size_t> lens{4, 1, 1, 3};
migraphx::shape s_m2{migraphx::shape::int8_type, {8, 7}}; migraphx::shape input{migraphx::shape::float_type, {4, 1, 1, 1}};
expect_shape(migraphx::shape{migraphx::shape::int32_type, {3, 7}}, expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {1, 1, 1, 0}},
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), migraphx::make_op("multibroadcast", {{"output_lens", lens}}),
s_m1, input);
s_m2);
} }
{ {
migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}}; std::vector<std::size_t> lens{4, 1, 3};
migraphx::shape s_m2{migraphx::shape::int8_type, {8, 8}}; migraphx::shape input{migraphx::shape::float_type, {4, 1, 1, 1}};
throws_shape(migraphx::make_op("quant_dot"), s_m1, s_m2); throws_shape(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), input);
} }
}
TEST_CASE(quant_dot_3args)
{
{ {
migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}}; std::vector<std::size_t> lens{4, 1, 3};
migraphx::shape s_m2{migraphx::shape::int8_type, {4, 8}}; migraphx::shape input{migraphx::shape::float_type, {}};
migraphx::shape s_m3{migraphx::shape::int32_type, {2, 8}}; throws_shape(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), input);
expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 8}},
migraphx::make_op("quant_dot"),
s_m1,
s_m2,
s_m3);
} }
{ {
migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}}; std::vector<std::size_t> lens{2, 3, 4, 5};
migraphx::shape s_m2{migraphx::shape::int8_type, {4, 8}}; migraphx::shape input{migraphx::shape::float_type, {3, 4}};
migraphx::shape s_m3{migraphx::shape::int8_type, {2, 8}}; throws_shape(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), input);
throws_shape(migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 2}}), s_m1, s_m2, s_m3); }
{
std::vector<std::size_t> lens{2, 3, 4, 5};
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4}};
throws_shape(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), input);
} }
} }
TEST_CASE(rnn) TEST_CASE(pooling_shape)
{ {
{ migraphx::shape output{migraphx::shape::float_type, {4, 3, 1, 1}};
std::size_t batch_size = 2; migraphx::shape input{migraphx::shape::float_type, {4, 3, 3, 3}};
std::size_t seq_len = 2; throws_shape(
std::size_t hidden_size = 4; migraphx::make_op("pooling",
std::size_t input_size = 3; {{"mode", "max"}, {"padding", {1}}, {"stride", {0}}, {"lengths", {1}}}),
std::size_t num_dirct = 1; input);
float clip = 0.0f; expect_shape(
output,
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::make_op(
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; "pooling",
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; {{"mode", "max"}, {"padding", {0, 0}}, {"stride", {3, 3}}, {"lengths", {1, 1}}}),
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; input);
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
expect_shape( migraphx::shape output1{migraphx::shape::float_type, {4, 3, 2, 2}};
migraphx::shape{migraphx::shape::float_type, expect_shape(output1,
{seq_len, num_dirct, batch_size, hidden_size}}, migraphx::make_op("pooling",
migraphx::make_op( {{"mode", "max"},
"rnn", {"padding", {0, 0}},
{{"hidden_size", hidden_size}, {"stride", {3, 3}},
{"actv_func", {"lengths", {1, 1}},
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})}, {"ceil_mode", true}}),
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, input);
{"clip", clip}}), }
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
TEST_CASE(prefix_scan_sum)
{
{ {
std::size_t batch_size = 2; migraphx::shape s{migraphx::shape::float_type, {1, 2, 3}};
std::size_t seq_len = 2; throws_shape(
std::size_t hidden_size = 4; migraphx::make_op("prefix_scan_sum", {{"axis", 3}, {"exclusive", 0}, {"reverse", 0}}),
std::size_t input_size = 3; s);
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
expect_shape(
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::make_op(
"rnn",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
} }
{ {
std::size_t batch_size = 2; migraphx::shape s{migraphx::shape::float_type, {1, 2}};
std::size_t seq_len = 2; throws_shape(
std::size_t hidden_size = 4; migraphx::make_op("prefix_scan_sum", {{"axis", -3}, {"exclusive", 0}, {"reverse", 0}}),
std::size_t input_size = 3; s);
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
expect_shape(
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::make_op(
"rnn",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
} }
}
{ TEST_CASE(quant_convolution_shape)
std::size_t batch_size = 2; {
std::size_t seq_len = 2; migraphx::shape output{migraphx::shape::int32_type, {4, 4, 1, 1}};
std::size_t hidden_size = 4; migraphx::shape input{migraphx::shape::int8_type, {4, 3, 3, 3}};
std::size_t input_size = 3; migraphx::shape weights{migraphx::shape::int8_type, {4, 3, 3, 3}};
std::size_t num_dirct = 1; expect_shape(output, migraphx::make_op("quant_convolution"), input, weights);
float clip = 0.0f; throws_shape(migraphx::make_op("quant_convolution"), input);
throws_shape(migraphx::make_op("quant_convolution",
{{"padding", {0}}, {"stride", {1, 1}}, {"dilation", {1, 1}}}),
input,
weights);
throws_shape(migraphx::make_op("quant_convolution",
{{"padding", {0}}, {"stride", {1}}, {"dilation", {1}}}),
input,
weights);
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape input2{migraphx::shape::int32_type, {3, 3}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; migraphx::shape weights2{migraphx::shape::float_type, {3, 3}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; throws_shape(migraphx::make_op("quant_convolution"), input2, weights2);
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; throws_shape(migraphx::make_op("quant_convolution"), input2, weights);
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
throws_shape( migraphx::shape input3{migraphx::shape::int32_type, {4, 3, 3, 3}};
migraphx::make_op( migraphx::shape weight3{migraphx::shape::float_type, {4, 3, 3, 3}};
"rnn", throws_shape(migraphx::make_op("quant_convolution"), input3, weights);
{{"hidden_size", hidden_size + 1}, throws_shape(migraphx::make_op("quant_convolution"), input, weight3);
{"actv_func", throws_shape(migraphx::make_op("quant_convolution"), input3, weight3);
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})}, }
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
// quant_dot
TEST_CASE(quant_dot_2args)
{
{ {
std::size_t batch_size = 2; migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}};
std::size_t seq_len = 2; migraphx::shape s_m2{migraphx::shape::int8_type, {4, 8}};
std::size_t hidden_size = 4; expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 8}},
std::size_t input_size = 3; migraphx::make_op("quant_dot"),
std::size_t num_dirct = 1; s_m1,
float clip = 0.0f; s_m2);
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
throws_shape(
migraphx::make_op(
"rnn",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
} }
{ {
std::size_t batch_size = 2; migraphx::shape s_m1{migraphx::shape::int8_type, {3, 8}};
std::size_t seq_len = 2; migraphx::shape s_m2{migraphx::shape::int8_type, {8, 7}};
std::size_t hidden_size = 4; expect_shape(migraphx::shape{migraphx::shape::int32_type, {3, 7}},
std::size_t input_size = 3; migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}),
std::size_t num_dirct = 2; s_m1,
float clip = 0.0f; s_m2);
}
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
throws_shape( {
migraphx::make_op( migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}};
"rnn", migraphx::shape s_m2{migraphx::shape::int8_type, {8, 8}};
{{"hidden_size", hidden_size}, throws_shape(migraphx::make_op("quant_dot"), s_m1, s_m2);
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
} }
} }
TEST_CASE(gru) TEST_CASE(quant_dot_3args)
{ {
{ {
std::size_t batch_size = 2; migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}};
std::size_t seq_len = 2; migraphx::shape s_m2{migraphx::shape::int8_type, {4, 8}};
std::size_t hidden_size = 4; migraphx::shape s_m3{migraphx::shape::int32_type, {2, 8}};
std::size_t input_size = 3; expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 8}},
std::size_t num_dirct = 1; migraphx::make_op("quant_dot"),
float clip = 0.0f; s_m1,
s_m2,
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; s_m3);
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
expect_shape(
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::make_op(
"gru",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
} }
{ {
std::size_t batch_size = 2; migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}};
std::size_t seq_len = 2; migraphx::shape s_m2{migraphx::shape::int8_type, {4, 8}};
std::size_t hidden_size = 4; migraphx::shape s_m3{migraphx::shape::int8_type, {2, 8}};
std::size_t input_size = 3; throws_shape(migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 2}}), s_m1, s_m2, s_m3);
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
expect_shape(
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::make_op(
"gru",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
} }
}
template <class T>
void test_reduce_ops()
{
{ {
std::size_t batch_size = 2; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
std::size_t seq_len = 2; expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1}}, T{}, input);
std::size_t hidden_size = 4; }
std::size_t input_size = 3;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape( expect_shape(
migraphx::shape{migraphx::shape::float_type, migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1}}, T{{0, 1, 2, 3}}, input);
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::make_op(
"gru",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
} }
{ {
std::size_t batch_size = 2; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
std::size_t seq_len = 2; expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 1, 1}}, T{{2, 3}}, input);
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
throws_shape(
migraphx::make_op(
"gru",
{{"hidden_size", hidden_size + 1},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
} }
{ {
std::size_t batch_size = 2; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
std::size_t seq_len = 2; expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 3, 4, 5}}, T{{0}}, input);
std::size_t hidden_size = 4; }
std::size_t input_size = 3; {
std::size_t num_dirct = 1; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
float clip = 0.0f; expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 1}}, T{{-1}}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
throws_shape(T{{4}}, input);
}
}
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; TEST_CASE(reduce_mean) { test_reduce_ops<migraphx::op::reduce_mean>(); }
migraphx::shape w_shape{migraphx::shape::float_type, TEST_CASE(reduce_sum) { test_reduce_ops<migraphx::op::reduce_sum>(); }
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
throws_shape( TEST_CASE(reshape_shape)
migraphx::make_op( {
"gru", migraphx::shape input{migraphx::shape::float_type, {24, 1, 1, 1}};
{{"hidden_size", hidden_size}, for(auto&& new_shape :
{"actv_func", std::vector<std::vector<int64_t>>{{8, 3, 1, 1}, {1, 3, 4, 2}, {1, 3, 4, 2}})
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})}, {
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, std::vector<std::size_t> lens(new_shape.size());
{"clip", clip}}), std::copy(new_shape.begin(), new_shape.end(), lens.begin());
in_shape, migraphx::shape output{migraphx::shape::float_type, lens};
w_shape, expect_shape(output, migraphx::make_op("reshape", {{"dims", new_shape}}), input);
r_shape,
b_shape,
ih_shape);
} }
for(auto&& new_shape :
std::vector<std::vector<int64_t>>{{8, 3, 2, 2}, {1, 3, -1, -1}, {3, 0, 0}, {3, 2, 0}})
{ {
std::size_t batch_size = 2; throws_shape(migraphx::make_op("reshape", {{"dims", new_shape}}), input);
std::size_t seq_len = 2; }
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; std::vector<std::pair<std::vector<int64_t>, migraphx::shape>> minus1_tests{
migraphx::shape w_shape{migraphx::shape::float_type, {{2, -1, 3}, {migraphx::shape::float_type, {2, 4, 3}}},
{num_dirct, 3 * hidden_size, input_size}}; {{0, -1, 0}, {migraphx::shape::float_type, {24, 1, 1}}},
migraphx::shape r_shape{migraphx::shape::float_type, {{2, -1, 0}, {migraphx::shape::float_type, {2, 12, 1}}},
{num_dirct, 3 * hidden_size, hidden_size}}; {{0, 0, -1}, {migraphx::shape::float_type, {24, 1, 1}}},
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; {{2, 0, -1}, {migraphx::shape::float_type, {2, 1, 12}}},
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; {{-1, 2, 3}, {migraphx::shape::float_type, {4, 2, 3}}},
{{-1, 0, 3}, {migraphx::shape::float_type, {8, 1, 3}}},
{{-1, 0, 0}, {migraphx::shape::float_type, {24, 1, 1}}},
{{-1, 3, 0}, {migraphx::shape::float_type, {8, 3, 1}}}};
throws_shape( for(auto& it : minus1_tests)
migraphx::make_op( {
"gru", expect_shape(it.second, migraphx::make_op("reshape", {{"dims", it.first}}), input);
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
} }
} }
TEST_CASE(lstm) TEST_CASE(rnn)
{ {
{ {
std::size_t batch_size = 2; std::size_t batch_size = 2;
...@@ -1352,16 +1220,16 @@ TEST_CASE(lstm) ...@@ -1352,16 +1220,16 @@ TEST_CASE(lstm)
float clip = 0.0f; float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
{num_dirct, 3 * hidden_size, input_size}}; migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
{num_dirct, 3 * hidden_size, hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
expect_shape( expect_shape(
migraphx::shape{migraphx::shape::float_type, migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}}, {seq_len, num_dirct, batch_size, hidden_size}},
migraphx::make_op( migraphx::make_op(
"lstm", "rnn",
{{"hidden_size", hidden_size}, {{"hidden_size", hidden_size},
{"actv_func", {"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})}, migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
...@@ -1369,7 +1237,9 @@ TEST_CASE(lstm) ...@@ -1369,7 +1237,9 @@ TEST_CASE(lstm)
{"clip", clip}}), {"clip", clip}}),
in_shape, in_shape,
w_shape, w_shape,
r_shape); r_shape,
b_shape,
ih_shape);
} }
{ {
...@@ -1381,18 +1251,16 @@ TEST_CASE(lstm) ...@@ -1381,18 +1251,16 @@ TEST_CASE(lstm)
float clip = 0.0f; float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
expect_shape( expect_shape(
migraphx::shape{migraphx::shape::float_type, migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}}, {seq_len, num_dirct, batch_size, hidden_size}},
migraphx::make_op( migraphx::make_op(
"lstm", "rnn",
{{"hidden_size", hidden_size}, {{"hidden_size", hidden_size},
{"actv_func", {"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})}, migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
...@@ -1414,18 +1282,16 @@ TEST_CASE(lstm) ...@@ -1414,18 +1282,16 @@ TEST_CASE(lstm)
float clip = 0.0f; float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
expect_shape( expect_shape(
migraphx::shape{migraphx::shape::float_type, migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}}, {seq_len, num_dirct, batch_size, hidden_size}},
migraphx::make_op( migraphx::make_op(
"lstm", "rnn",
{{"hidden_size", hidden_size}, {{"hidden_size", hidden_size},
{"actv_func", {"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})}, migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
...@@ -1447,16 +1313,14 @@ TEST_CASE(lstm) ...@@ -1447,16 +1313,14 @@ TEST_CASE(lstm)
float clip = 0.0f; float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
throws_shape( throws_shape(
migraphx::make_op( migraphx::make_op(
"lstm", "rnn",
{{"hidden_size", hidden_size + 1}, {{"hidden_size", hidden_size + 1},
{"actv_func", {"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})}, migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
...@@ -1478,16 +1342,14 @@ TEST_CASE(lstm) ...@@ -1478,16 +1342,14 @@ TEST_CASE(lstm)
float clip = 0.0f; float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
throws_shape( throws_shape(
migraphx::make_op( migraphx::make_op(
"lstm", "rnn",
{{"hidden_size", hidden_size}, {{"hidden_size", hidden_size},
{"actv_func", {"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})}, migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
...@@ -1509,16 +1371,14 @@ TEST_CASE(lstm) ...@@ -1509,16 +1371,14 @@ TEST_CASE(lstm)
float clip = 0.0f; float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
throws_shape( throws_shape(
migraphx::make_op( migraphx::make_op(
"lstm", "rnn",
{{"hidden_size", hidden_size}, {{"hidden_size", hidden_size},
{"actv_func", {"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})}, migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
...@@ -1532,21 +1392,176 @@ TEST_CASE(lstm) ...@@ -1532,21 +1392,176 @@ TEST_CASE(lstm)
} }
} }
TEST_CASE(prefix_scan_sum) TEST_CASE(slice_shape)
{
migraphx::shape input{migraphx::shape::int32_type, {2, 2, 3}};
expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 2, 2}, {6, 3, 1}},
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1}}, {"ends", {3}}}),
input);
expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 2, 2}, {6, 3, 1}},
migraphx::make_op(
"slice", {{"axes", {0, 1, 2}}, {"starts", {0, 0, 1}}, {"ends", {2, 2, 3}}}),
input);
expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 2, 1}, {6, 3, 1}},
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {2}}, {"ends", {10}}}),
input);
}
TEST_CASE(softmax) { test_softmax_variations<migraphx::op::softmax>(); }
TEST_CASE(test_argmax)
{ {
{ {
migraphx::shape s{migraphx::shape::float_type, {1, 2, 3}}; migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
throws_shape( expect_shape(migraphx::shape{migraphx::shape::int64_type, {1, 3, 4, 5}},
migraphx::make_op("prefix_scan_sum", {{"axis", 3}, {"exclusive", 0}, {"reverse", 0}}), migraphx::make_op("argmax", {{"axis", 0}}),
s); input);
} }
{ {
migraphx::shape s{migraphx::shape::float_type, {1, 2}}; migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
throws_shape( expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 1, 4, 5}},
migraphx::make_op("prefix_scan_sum", {{"axis", -3}, {"exclusive", 0}, {"reverse", 0}}), migraphx::make_op("argmax", {{"axis", 1}}),
s); input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 1, 5}},
migraphx::make_op("argmax", {{"axis", 2}}),
input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 4, 1}},
migraphx::make_op("argmax", {{"axis", 3}}),
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
throws_shape(migraphx::make_op("argmax", {{"axis", 4}}), input);
}
}
TEST_CASE(test_argmin)
{
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {1, 3, 4, 5}},
migraphx::make_op("argmin", {{"axis", 0}}),
input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 1, 4, 5}},
migraphx::make_op("argmin", {{"axis", 1}}),
input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 1, 5}},
migraphx::make_op("argmin", {{"axis", 2}}),
input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 4, 1}},
migraphx::make_op("argmin", {{"axis", 3}}),
input);
} }
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
throws_shape(migraphx::make_op("argmin", {{"axis", 4}}), input);
}
}
TEST_CASE(test_scalar)
{
migraphx::shape s1{migraphx::shape::float_type, {1}, {1}};
migraphx::shape s2{migraphx::shape::float_type, {2, 3, 4, 5}, {0, 0, 0, 0}};
expect_shape(s2, migraphx::make_op("scalar", {{"scalar_bcst_dims", {2, 3, 4, 5}}}), s1);
}
TEST_CASE(test_scalar_nelemnts)
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
throws_shape(migraphx::make_op("scalar", {{"scalar_bcst_dims", {2, 3, 4, 5}}}), input);
}
TEST_CASE(test_squeeze)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 1, 3, 1, 3}};
migraphx::shape s2{migraphx::shape::float_type, {4, 1, 3, 3}};
expect_shape(s2, migraphx::make_op("squeeze", {{"axes", {3}}}), s1);
}
TEST_CASE(test_squeeze_all)
{
migraphx::shape s1{migraphx::shape::float_type, {1}};
migraphx::shape s2{migraphx::shape::float_type};
expect_shape(s2, migraphx::make_op("squeeze", {{"axes", {0}}}), s1);
}
TEST_CASE(test_squeeze_negative_axis)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 1, 3, 1, 3}};
migraphx::shape s2{migraphx::shape::float_type, {4, 1, 3, 3}};
expect_shape(s2, migraphx::make_op("squeeze", {{"axes", {-2}}}), s1);
}
TEST_CASE(test_squeeze_wrong_axis)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 1, 3, 1, 3}};
throws_shape(migraphx::make_op("squeeze", {{"axes", {0}}}), s1);
}
TEST_CASE(test_unsqueeze)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 3, 3}};
migraphx::shape s2{migraphx::shape::float_type, {4, 3, 1, 3}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1);
}
TEST_CASE(test_unsqueeze_negative_axis)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 3, 3}};
migraphx::shape s2{migraphx::shape::float_type, {4, 3, 1, 3}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {-2}}}), s1);
}
TEST_CASE(test_unsqueeze_scalar)
{
migraphx::shape s1{migraphx::shape::float_type, {1}, {0}};
migraphx::shape s2{migraphx::shape::float_type, {1}, {1}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {0}}}), s1);
}
TEST_CASE(test_unsqueeze_scalar_tensor1)
{
migraphx::shape s{migraphx::shape::float_type, {4, 3, 3}, {0, 0, 0}};
throws_shape(migraphx::make_op("unsqueeze", {{"axes", {-2}}}), s);
}
TEST_CASE(test_unsqueeze_scalar_tensor2)
{
migraphx::shape s{migraphx::shape::float_type, {1, 1, 1}, {0, 0, 0}};
throws_shape(migraphx::make_op("unsqueeze", {{"axes", {-2}}}), s);
}
TEST_CASE(transpose_shape)
{
migraphx::shape input{migraphx::shape::float_type, {2, 2}};
migraphx::shape output{migraphx::shape::float_type, {2, 2}, {1, 2}};
expect_shape(input, migraphx::make_op("transpose", {{"dims", {0, 1}}}), input);
expect_shape(output, migraphx::make_op("transpose", {{"dims", {1, 0}}}), input);
expect_shape(output, migraphx::make_op("transpose"), input);
throws_shape(migraphx::make_op("transpose", {{"dims", {1, 2}}}), input);
} }
TEST_CASE(step_test) TEST_CASE(step_test)
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/op/convolution.hpp> #include <migraphx/op/convolution.hpp>
#include <migraphx/op/rnn_variable_seq_lens.hpp> #include <migraphx/op/rnn_variable_seq_lens.hpp>
#include <migraphx/module.hpp>
#include <sstream> #include <sstream>
#include <string> #include <string>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
...@@ -115,4 +116,36 @@ TEST_CASE(ops) ...@@ -115,4 +116,36 @@ TEST_CASE(ops)
EXPECT(names.size() > 1); EXPECT(names.size() > 1);
} }
TEST_CASE(rnn)
{
migraphx::shape s{migraphx::shape::float_type, {2, 1}};
std::vector<float> data1(2, 2.0f);
std::vector<float> data2(2, 3.0f);
migraphx::argument a1(s, data1.data());
migraphx::argument a2(s, data2.data());
auto op = migraphx::make_op("rnn");
EXPECT(test::throws([&] { op.compute(s, {a1, a2}); }));
}
TEST_CASE(if_op)
{
migraphx::shape s{migraphx::shape::bool_type, {1}};
std::vector<char> data = {1};
migraphx::argument cond(s, data.data());
migraphx::shape sd{migraphx::shape::float_type, {2, 1}};
std::vector<float> data1(2, 2.0f);
std::vector<float> data2(2, 3.0f);
migraphx::argument a1(sd, data1.data());
migraphx::argument a2(sd, data2.data());
migraphx::module m("name");
auto l = m.add_literal(migraphx::literal(sd, data1));
m.add_return({l});
auto op = migraphx::make_op("add");
EXPECT(test::throws([&] { op.compute(s, {cond, a1, a2}, {&m, &m}, {}); }));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <iostream> #include <iostream>
#include <vector> #include <vector>
#include <cmath>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/op/pooling.hpp> #include <migraphx/op/pooling.hpp>
#include <migraphx/op/batch_norm_inference.hpp> #include <migraphx/op/batch_norm_inference.hpp>
...@@ -48,7 +49,9 @@ TEST_CASE(acos_test) ...@@ -48,7 +49,9 @@ TEST_CASE(acos_test)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {2.4980915448f, 1.5707963268f, 0.0f}; std::vector<float> gold = data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return acosf(n); });
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
...@@ -64,7 +67,9 @@ TEST_CASE(acosh_test) ...@@ -64,7 +67,9 @@ TEST_CASE(acosh_test)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.4435683, 0.6223626, 1.316958}; std::vector<float> gold = data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return acoshf(n); });
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
...@@ -295,7 +300,9 @@ TEST_CASE(asin_test) ...@@ -295,7 +300,9 @@ TEST_CASE(asin_test)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-0.5235987756f, 0.f, 1.119769515}; std::vector<float> gold = data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return asinf(n); });
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
...@@ -311,7 +318,9 @@ TEST_CASE(asinh_test) ...@@ -311,7 +318,9 @@ TEST_CASE(asinh_test)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-0.481211841, 0, 0.808866858}; std::vector<float> gold = data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return asinhf(n); });
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
...@@ -320,13 +329,16 @@ TEST_CASE(atan_test) ...@@ -320,13 +329,16 @@ TEST_CASE(atan_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::double_type, {3}}; migraphx::shape s{migraphx::shape::double_type, {3}};
auto l = mm->add_literal(migraphx::literal{s, {-1, 0, 1}}); std::vector<float> data{-1.0f, 0.0f, 1.0f};
auto l = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(migraphx::make_op("atan"), l); mm->add_instruction(migraphx::make_op("atan"), l);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-0.7853981634f, 0.0f, 0.7853981634f}; std::vector<float> gold = data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return atanf(n); });
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
...@@ -335,13 +347,16 @@ TEST_CASE(atanh_test) ...@@ -335,13 +347,16 @@ TEST_CASE(atanh_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::double_type, {3}}; migraphx::shape s{migraphx::shape::double_type, {3}};
auto l = mm->add_literal(migraphx::literal{s, {0.4435683, 0.6223626, 0.316958}}); std::vector<float> data{0.4435683f, 0.6223626f, 0.316958f};
auto l = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(migraphx::make_op("atanh"), l); mm->add_instruction(migraphx::make_op("atanh"), l);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.476664424, 0.728852153, 0.328261733}; std::vector<float> gold = data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return atanhf(n); });
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
...@@ -675,14 +690,16 @@ TEST_CASE(ceil_test) ...@@ -675,14 +690,16 @@ TEST_CASE(ceil_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {9}}; migraphx::shape s{migraphx::shape::float_type, {9}};
auto l = std::vector<float> data = {1.1, 1.5, 1.6, -1.1, -1.5, -1.6, 0.0, 2.0, -2.0};
mm->add_literal(migraphx::literal{s, {1.1, 1.5, 1.6, -1.1, -1.5, -1.6, 0.0, 2.0, -2.0}}); auto l = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(migraphx::make_op("ceil"), l); mm->add_instruction(migraphx::make_op("ceil"), l);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {2.0, 2.0, 2.0, -1.0, -1.0, -1.0, 0.0, 2.0, -2.0}; std::vector<float> gold = data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return std::ceil(n); });
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
...@@ -1088,13 +1105,16 @@ TEST_CASE(cos_test) ...@@ -1088,13 +1105,16 @@ TEST_CASE(cos_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto l = mm->add_literal(migraphx::literal{s, {-1, 0, 1}}); std::vector<float> data{-1, 0, 1};
auto l = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(migraphx::make_op("cos"), l); mm->add_instruction(migraphx::make_op("cos"), l);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.54030231f, 1.f, 0.54030231f}; std::vector<float> gold = data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return cosf(n); });
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
...@@ -1103,13 +1123,16 @@ TEST_CASE(cosh_test) ...@@ -1103,13 +1123,16 @@ TEST_CASE(cosh_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 2}}; migraphx::shape s{migraphx::shape::float_type, {2, 2}};
auto l = mm->add_literal(migraphx::literal{s, {-1.0, 2.0, -3.0, 4.0}}); std::vector<float> data = {-1.0, 2.0, -3.0, 4.0};
auto l = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(migraphx::make_op("cosh"), l); mm->add_instruction(migraphx::make_op("cosh"), l);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector(4); std::vector<float> results_vector(4);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{coshf(-1), coshf(2), coshf(-3), coshf(4)}; std::vector<float> gold = data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return coshf(n); });
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
...@@ -1215,14 +1238,17 @@ TEST_CASE(div_test) ...@@ -1215,14 +1238,17 @@ TEST_CASE(div_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto l1 = mm->add_literal(migraphx::literal{s, {-1.0f, 0.5f, 1.0f}}); std::vector<float> data1 = {-1.0f, 0.5f, 1.0f};
auto l2 = mm->add_literal(migraphx::literal{s, {1.0f, 2.0f, 4.0f}}); std::vector<float> data2 = {1.0f, 2.0f, 4.0f};
auto l1 = mm->add_literal(migraphx::literal{s, data1});
auto l2 = mm->add_literal(migraphx::literal{s, data2});
mm->add_instruction(migraphx::make_op("div"), l1, l2); mm->add_instruction(migraphx::make_op("div"), l1, l2);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-1.f, 0.25f, 0.25f}; std::vector<float> gold(data1.size());
std::transform(data1.begin(), data1.end(), data2.begin(), gold.begin(), std::divides<float>());
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
...@@ -1297,14 +1323,16 @@ TEST_CASE(erf_test) ...@@ -1297,14 +1323,16 @@ TEST_CASE(erf_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {4}}; migraphx::shape s{migraphx::shape::float_type, {4}};
auto l = std::vector<float> data = {0.73785057, 1.58165966, -0.43597795, -0.01677432};
mm->add_literal(migraphx::literal{s, {0.73785057, 1.58165966, -0.43597795, -0.01677432}}); auto l = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(migraphx::make_op("erf"), l); mm->add_instruction(migraphx::make_op("erf"), l);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.70327317, 0.97470088, -0.46247893, -0.01892602}; std::vector<float> gold = data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return erff(n); });
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
...@@ -1312,14 +1340,17 @@ TEST_CASE(exp_test) ...@@ -1312,14 +1340,17 @@ TEST_CASE(exp_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
std::vector<float> data{-1, 0, 1};
migraphx::shape s{migraphx::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto l = mm->add_literal(migraphx::literal{s, {-1, 0, 1}}); auto l = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(migraphx::make_op("exp"), l); mm->add_instruction(migraphx::make_op("exp"), l);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.36787944f, 1.f, 2.71828183f}; std::vector<float> gold = data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return expf(n); });
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
...@@ -1328,14 +1359,16 @@ TEST_CASE(floor_test) ...@@ -1328,14 +1359,16 @@ TEST_CASE(floor_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {9}}; migraphx::shape s{migraphx::shape::float_type, {9}};
auto l = std::vector<float> data = {1.1, 1.5, 0.6, -1.1, -1.5, -0.6, 0.0, 2.0, -2.0};
mm->add_literal(migraphx::literal{s, {1.1, 1.5, 0.6, -1.1, -1.5, -0.6, 0.0, 2.0, -2.0}}); auto l = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(migraphx::make_op("floor"), l); mm->add_instruction(migraphx::make_op("floor"), l);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.0, 1.0, 0.0, -2.0, -2.0, -1.0, -0.0, 2.0, -2.0}; std::vector<float> gold = data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return floor(n); });
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
...@@ -1668,7 +1701,8 @@ TEST_CASE(if_literal_test) ...@@ -1668,7 +1701,8 @@ TEST_CASE(if_literal_test)
else_mod->add_return({l2}); else_mod->add_return({l2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
mm->add_return({ret}); auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({r});
return p; return p;
}; };
...@@ -1730,7 +1764,8 @@ TEST_CASE(if_param_test) ...@@ -1730,7 +1764,8 @@ TEST_CASE(if_param_test)
else_mod->add_return({a2}); else_mod->add_return({a2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond, x, y}, {then_mod, else_mod}); auto ret = mm->add_instruction(migraphx::make_op("if"), {cond, x, y}, {then_mod, else_mod});
mm->add_return({ret}); auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({r});
return p; return p;
}; };
...@@ -1796,7 +1831,8 @@ TEST_CASE(if_pl_test) ...@@ -1796,7 +1831,8 @@ TEST_CASE(if_pl_test)
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
auto outline = mm->add_outline(s); auto outline = mm->add_outline(s);
mm->add_return({outline, ret}); auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({outline, r});
return p; return p;
}; };
...@@ -2110,12 +2146,12 @@ TEST_CASE(less_test) ...@@ -2110,12 +2146,12 @@ TEST_CASE(less_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {9}}; migraphx::shape s{migraphx::shape::float_type, {9}};
auto l0 = std::vector<float> data1 = {1.1, 1.5, 0.1, -1.1, -1.5, -0.6, 0.0, 2.0, -2.0};
mm->add_literal(migraphx::literal{s, {1.1, 1.5, 0.1, -1.1, -1.5, -0.6, 0.0, 2.0, -2.0}}); std::vector<float> data2 = {1.1, 1.6, -0.1, -1.2, -1.5, -0.7, 0.0, 2.3, -2.1};
auto l1 = auto l0 = mm->add_literal(migraphx::literal{s, data1});
mm->add_literal(migraphx::literal{s, {1.1, 1.6, -0.1, -1.2, -1.5, -0.7, 0.0, 2.3, -2.1}}); auto l1 = mm->add_literal(migraphx::literal{s, data2});
auto le = mm->add_instruction(migraphx::make_op("less"), l0, l1); auto le = mm->add_instruction(migraphx::make_op("less"), l0, l1);
auto r = mm->add_instruction( auto r = mm->add_instruction(
migraphx::make_op("convert", migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}), {{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}),
le); le);
...@@ -2125,7 +2161,11 @@ TEST_CASE(less_test) ...@@ -2125,7 +2161,11 @@ TEST_CASE(less_test)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<bool> results_vector; std::vector<bool> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<bool> gold = {false, true, false, false, false, false, false, true, false}; std::vector<bool> gold(data1.size());
std::transform(
data1.begin(), data1.end(), data2.begin(), gold.begin(), [](float n1, float n2) -> bool {
return n1 < n2;
});
EXPECT(results_vector == gold); EXPECT(results_vector == gold);
} }
...@@ -2134,13 +2174,16 @@ TEST_CASE(log_test) ...@@ -2134,13 +2174,16 @@ TEST_CASE(log_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto l = mm->add_literal(migraphx::literal{s, {1, 2, 3}}); std::vector<float> data = {1, 2, 3};
auto l = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(migraphx::make_op("log"), l); mm->add_instruction(migraphx::make_op("log"), l);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.0f, 0.6931471806f, 1.0986122887f}; std::vector<float> gold = data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return logf(n); });
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
...@@ -2149,14 +2192,20 @@ TEST_CASE(logical_and_test) ...@@ -2149,14 +2192,20 @@ TEST_CASE(logical_and_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::bool_type, {4}}; migraphx::shape s{migraphx::shape::bool_type, {4}};
auto l1 = mm->add_literal(migraphx::literal{s, {1, 0, 1, 0}}); std::vector<bool> data1{true, false, true, false};
auto l2 = mm->add_literal(migraphx::literal{s, {1, 1, 0, 0}}); std::vector<bool> data2{true, true, false, false};
auto l1 = mm->add_literal(migraphx::literal{s, data1});
auto l2 = mm->add_literal(migraphx::literal{s, data2});
mm->add_instruction(migraphx::make_op("logical_and"), l1, l2); mm->add_instruction(migraphx::make_op("logical_and"), l1, l2);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<char> results_vector; std::vector<char> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<char> gold = {1, 0, 0, 0}; std::vector<bool> gold(data2.size());
std::transform(
data1.begin(), data1.end(), data2.begin(), gold.begin(), [](bool n1, bool n2) -> bool {
return n1 and n2;
});
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
...@@ -2165,14 +2214,20 @@ TEST_CASE(logical_or_test) ...@@ -2165,14 +2214,20 @@ TEST_CASE(logical_or_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::bool_type, {4}}; migraphx::shape s{migraphx::shape::bool_type, {4}};
auto l1 = mm->add_literal(migraphx::literal{s, {1, 0, 1, 0}}); std::vector<bool> data1{true, false, true, false};
auto l2 = mm->add_literal(migraphx::literal{s, {1, 1, 0, 0}}); std::vector<bool> data2{true, true, false, false};
auto l1 = mm->add_literal(migraphx::literal{s, data1});
auto l2 = mm->add_literal(migraphx::literal{s, data2});
mm->add_instruction(migraphx::make_op("logical_or"), l1, l2); mm->add_instruction(migraphx::make_op("logical_or"), l1, l2);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<char> results_vector; std::vector<char> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<char> gold = {1, 1, 1, 0}; std::vector<bool> gold(data1.size());
std::transform(
data1.begin(), data1.end(), data2.begin(), gold.begin(), [](bool n1, bool n2) -> bool {
return n1 or n2;
});
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
...@@ -2181,14 +2236,20 @@ TEST_CASE(logical_xor_test) ...@@ -2181,14 +2236,20 @@ TEST_CASE(logical_xor_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::bool_type, {4}}; migraphx::shape s{migraphx::shape::bool_type, {4}};
auto l1 = mm->add_literal(migraphx::literal{s, {1, 0, 1, 0}}); std::vector<bool> data1{true, false, true, false};
auto l2 = mm->add_literal(migraphx::literal{s, {1, 1, 0, 0}}); std::vector<bool> data2{true, true, false, false};
auto l1 = mm->add_literal(migraphx::literal{s, data1});
auto l2 = mm->add_literal(migraphx::literal{s, data2});
mm->add_instruction(migraphx::make_op("logical_xor"), l1, l2); mm->add_instruction(migraphx::make_op("logical_xor"), l1, l2);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<char> results_vector; std::vector<char> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<char> gold = {0, 1, 1, 0}; std::vector<bool> gold = {false, true, true, false};
std::transform(
data1.begin(), data1.end(), data2.begin(), gold.begin(), [](bool n1, bool n2) -> bool {
return n1 ^ n2;
});
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
...@@ -2559,6 +2620,8 @@ TEST_CASE(mul_test) ...@@ -2559,6 +2620,8 @@ TEST_CASE(mul_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
std::vector<float> data1{-1, 0, 1};
std::vector<float> data2{1, 2, 3};
auto l1 = mm->add_literal(migraphx::literal{s, {-1, 0, 1}}); auto l1 = mm->add_literal(migraphx::literal{s, {-1, 0, 1}});
auto l2 = mm->add_literal(migraphx::literal{s, {1, 2, 3}}); auto l2 = mm->add_literal(migraphx::literal{s, {1, 2, 3}});
mm->add_instruction(migraphx::make_op("mul"), l1, l2); mm->add_instruction(migraphx::make_op("mul"), l1, l2);
...@@ -2566,7 +2629,11 @@ TEST_CASE(mul_test) ...@@ -2566,7 +2629,11 @@ TEST_CASE(mul_test)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-1, 0, 3}; std::vector<float> gold(data1.size());
std::transform(
data1.begin(), data1.end(), data2.begin(), gold.begin(), [](float n1, float n2) -> float {
return n1 * n2;
});
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
...@@ -2583,8 +2650,8 @@ TEST_CASE(neg_test) ...@@ -2583,8 +2650,8 @@ TEST_CASE(neg_test)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> result_vector; std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = data;
std::vector<float> gold = {-1.0f, -1.3f, 1.2f, 0.0f, 100.f, -200.f}; std::transform(gold.begin(), gold.end(), gold.begin(), std::negate<float>());
EXPECT(migraphx::verify_range(result_vector, gold)); EXPECT(migraphx::verify_range(result_vector, gold));
} }
...@@ -2595,13 +2662,14 @@ TEST_CASE(not_test) ...@@ -2595,13 +2662,14 @@ TEST_CASE(not_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int32_type, {4}}; migraphx::shape s{migraphx::shape::int32_type, {4}};
auto l1 = mm->add_literal(migraphx::literal{s, {0, 8, 1, -32}}); std::vector<float> data{0, 8, 1, -32};
auto l1 = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(migraphx::make_op("not"), l1); mm->add_instruction(migraphx::make_op("not"), l1);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<char> results_vector; std::vector<char> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<char> gold = {1, 0, 0, 0}; std::vector<char> gold{1, 0, 0, 0};
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
...@@ -2610,13 +2678,15 @@ TEST_CASE(not_test) ...@@ -2610,13 +2678,15 @@ TEST_CASE(not_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::bool_type, {4}}; migraphx::shape s{migraphx::shape::bool_type, {4}};
std::vector<bool> data{false, false, true, true};
auto l1 = mm->add_literal(migraphx::literal{s, {0, 0, 1, 1}}); auto l1 = mm->add_literal(migraphx::literal{s, {0, 0, 1, 1}});
mm->add_instruction(migraphx::make_op("not"), l1); mm->add_instruction(migraphx::make_op("not"), l1);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<char> results_vector; std::vector<char> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<char> gold = {1, 1, 0, 0}; std::vector<bool> gold(data.size());
std::transform(data.begin(), data.end(), gold.begin(), [](bool n) -> bool { return !n; });
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
} }
...@@ -2716,14 +2786,17 @@ TEST_CASE(pow_test) ...@@ -2716,14 +2786,17 @@ TEST_CASE(pow_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto b = mm->add_literal(migraphx::literal{s, {1, 2, 3}}); std::vector<float> data = {1, 2, 3};
auto e = mm->add_literal(migraphx::literal{s, {1, 2, 3}}); auto b = mm->add_literal(migraphx::literal{s, data});
auto e = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(migraphx::make_op("pow"), b, e); mm->add_instruction(migraphx::make_op("pow"), b, e);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.0f, 4.0f, 27.0f}; std::vector<float> gold = data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return std::pow(n, n); });
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
...@@ -3560,6 +3633,66 @@ TEST_CASE(reshape_test) ...@@ -3560,6 +3633,66 @@ TEST_CASE(reshape_test)
} }
} }
TEST_CASE(reverse_test_axis0)
{
migraphx::shape in_shape{migraphx::shape::float_type, {2, 16}};
std::vector<float> data(32);
std::iota(data.begin(), data.end(), 1);
migraphx::program p;
auto* mm = p.get_main_module();
auto l = mm->add_literal(migraphx::literal{in_shape, data});
std::vector<int> axes = {0};
mm->add_instruction(migraphx::make_op("reverse", {{"axes", axes}}), l);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> target_data = data;
std::swap_ranges(target_data.begin(), target_data.begin() + 16, target_data.begin() + 16);
EXPECT(migraphx::verify_range(results_vector, target_data));
}
TEST_CASE(reverse_test_axis1)
{
migraphx::shape in_shape{migraphx::shape::float_type, {2, 16}};
std::vector<float> data(32);
std::iota(data.begin(), data.end(), 1);
migraphx::program p;
auto* mm = p.get_main_module();
auto l = mm->add_literal(migraphx::literal{in_shape, data});
std::vector<int> axes = {1};
mm->add_instruction(migraphx::make_op("reverse", {{"axes", axes}}), l);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> target_data = data;
std::reverse(target_data.begin(), target_data.begin() + 16);
std::reverse(target_data.end() - 16, target_data.end());
EXPECT(migraphx::verify_range(results_vector, target_data));
}
TEST_CASE(reverse_test_axis10)
{
migraphx::shape in_shape{migraphx::shape::float_type, {2, 16}};
std::vector<float> data(32);
std::iota(data.begin(), data.end(), 1);
migraphx::program p;
auto* mm = p.get_main_module();
auto l = mm->add_literal(migraphx::literal{in_shape, data});
std::vector<int> axes = {1, 0};
mm->add_instruction(migraphx::make_op("reverse", {{"axes", axes}}), l);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> target_data = data;
std::reverse(target_data.begin(), target_data.begin() + 16);
std::reverse(target_data.end() - 16, target_data.end());
std::swap_ranges(target_data.begin(), target_data.begin() + 16, target_data.begin() + 16);
EXPECT(migraphx::verify_range(results_vector, target_data));
}
TEST_CASE(round_test) TEST_CASE(round_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -3627,13 +3760,16 @@ TEST_CASE(sin_test) ...@@ -3627,13 +3760,16 @@ TEST_CASE(sin_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto l = mm->add_literal(migraphx::literal{s, {-1, 0, 1}}); std::vector<float> data = {-1, 0, 1};
auto l = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(migraphx::make_op("sin"), l); mm->add_instruction(migraphx::make_op("sin"), l);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-0.84147098f, 0.f, 0.84147098f}; std::vector<float> gold = data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return sinf(n); });
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
...@@ -3642,13 +3778,16 @@ TEST_CASE(sinh_test) ...@@ -3642,13 +3778,16 @@ TEST_CASE(sinh_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 2}}; migraphx::shape s{migraphx::shape::float_type, {2, 2}};
auto l = mm->add_literal(migraphx::literal{s, {-1.0, 2.0, -3.0, 4.0}}); std::vector<float> data{-1.0, 2.0, -3.0, 4.0};
auto l = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(migraphx::make_op("sinh"), l); mm->add_instruction(migraphx::make_op("sinh"), l);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector(4); std::vector<float> results_vector(4);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{sinhf(-1), sinhf(2), sinhf(-3), sinhf(4)}; std::vector<float> gold = data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return sinhf(n); });
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
...@@ -3795,14 +3934,16 @@ TEST_CASE(sqrt_test) ...@@ -3795,14 +3934,16 @@ TEST_CASE(sqrt_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {5}}; migraphx::shape s{migraphx::shape::float_type, {5}};
auto l = mm->add_literal( std::vector<float> data{1.02481645, 0.85643062, 0.03404123, 0.92791926, 0.10569184};
migraphx::literal{s, {1.02481645, 0.85643062, 0.03404123, 0.92791926, 0.10569184}}); auto l = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(migraphx::make_op("sqrt"), l); mm->add_instruction(migraphx::make_op("sqrt"), l);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.01233218, 0.92543537, 0.18450265, 0.96328566, 0.32510282}; std::vector<float> gold = data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return sqrtf(n); });
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
...@@ -3904,13 +4045,16 @@ TEST_CASE(tan_test) ...@@ -3904,13 +4045,16 @@ TEST_CASE(tan_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto l = mm->add_literal(migraphx::literal{s, {-1, 0, 1}}); std::vector<float> data{-1, 0, 1};
auto l = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(migraphx::make_op("tan"), l); mm->add_instruction(migraphx::make_op("tan"), l);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-1.55740772f, 0.0f, 1.55740772f}; std::vector<float> gold = data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return tanf(n); });
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
...@@ -3919,13 +4063,16 @@ TEST_CASE(tanh_test) ...@@ -3919,13 +4063,16 @@ TEST_CASE(tanh_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 2}}; migraphx::shape s{migraphx::shape::float_type, {2, 2}};
auto l = mm->add_literal(migraphx::literal{s, {-1.0, 2.0, -3.0, 4.0}}); std::vector<float> data{-1.0, 2.0, -3.0, 4.0};
auto l = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(migraphx::make_op("tanh"), l); mm->add_instruction(migraphx::make_op("tanh"), l);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector(4); std::vector<float> results_vector(4);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{tanhf(-1), tanhf(2), tanhf(-3), tanhf(4)}; std::vector<float> gold = data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return tanhf(n); });
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
......
...@@ -301,7 +301,7 @@ migraphx::program create_conv() ...@@ -301,7 +301,7 @@ migraphx::program create_conv()
migraphx::op::convolution op; migraphx::op::convolution op;
op.padding_mode = migraphx::op::padding_mode_t::same; op.padding_mode = migraphx::op::padding_mode_t::same;
op.padding = {1, 1}; op.padding = {1, 1, 1, 1};
op.stride = {1, 1}; op.stride = {1, 1};
op.dilation = {1, 1}; op.dilation = {1, 1};
auto l2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {3, 2, 0, 1}}}), l1); auto l2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {3, 2, 0, 1}}}), l1);
......
...@@ -44,3 +44,23 @@ void auto_print::set_terminate_handler(const std::string& name) ...@@ -44,3 +44,23 @@ void auto_print::set_terminate_handler(const std::string& name)
get_handler(tname)(); get_handler(tname)();
}); });
} }
static bool in_exception()
{
#if __cplusplus >= 201703L
return std::uncaught_exceptions() > 0;
#else
return std::uncaught_exception();
#endif
}
auto_print::~auto_print()
{
if(in_exception())
{
std::cout << std::endl;
for(const auto& tname : migraphx::get_targets())
get_handler(tname)();
}
get_handler(name) = [] {};
}
...@@ -15,10 +15,7 @@ struct auto_print ...@@ -15,10 +15,7 @@ struct auto_print
get_handler(name) = [&x] { std::cout << x << std::endl; }; get_handler(name) = [&x] { std::cout << x << std::endl; };
} }
~auto_print() ~auto_print();
{
get_handler(name) = [] {};
}
}; };
#endif #endif
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct batch_quant_dot_3 : verify_program<batch_quant_dot_3>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 2, 6}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 6, 7}};
auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape);
mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 3}}), l1, l2);
return p;
}
};
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct batch_quant_dot_4 : verify_program<batch_quant_dot_4>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 4, 6, 3}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 2, 6, 3}};
auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape);
auto tl1 =
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {3, 0, 1, 2}}}), l1);
auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {3, 1, 2, 0}}}), l2);
mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 3}}), tl1, tl2);
return p;
}
};
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct batch_quant_dot_5 : verify_program<batch_quant_dot_5>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 7, 2}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 5, 7}};
auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape);
auto tl1 =
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l1);
auto sl1 = mm->add_instruction(migraphx::make_op("add"), tl1, tl1);
auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l2);
auto sl2 = mm->add_instruction(migraphx::make_op("add"), tl2, tl2);
mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}}), sl1, sl2);
return p;
}
};
...@@ -45,5 +45,14 @@ int main(int argc, const char* argv[]) ...@@ -45,5 +45,14 @@ int main(int argc, const char* argv[])
run_verify rv; run_verify rv;
rv.add_validation_for("gpu", &validate_gpu); rv.add_validation_for("gpu", &validate_gpu);
rv.disable_test_for("cpu", {"test_if_lp", "test_if_param", "test_if_literal"}); rv.disable_test_for("cpu", {"test_if_lp", "test_if_param", "test_if_literal"});
rv.disable_test_for("gpu",
{"batch_quant_dot_2",
"batch_quant_dot_3",
"batch_quant_dot_5",
"quant_dot_3args_1",
"quant_dot_3args_2",
"quant_dot_3args_3",
"quant_dot_3args_4",
"quant_dot_3args_5"});
rv.run(argc, argv); rv.run(argc, argv);
} }
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct quant_dot_3args_5 : verify_program<quant_dot_3args_5>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {6, 2}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 6}};
auto l1 = mm->add_parameter("a", m1_shape);
auto tl1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l1);
auto l2 = mm->add_parameter("b", m2_shape);
auto tl2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l2);
mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 3}, {"beta", 2}}), tl1, tl2);
return p;
}
};
#include "run_verify.hpp" #include "run_verify.hpp"
#include "auto_print.hpp" #include "auto_print.hpp"
#include "verify_program.hpp" #include "verify_program.hpp"
#include "test.hpp"
#include <migraphx/env.hpp> #include <migraphx/env.hpp>
#include <migraphx/ref/target.hpp> #include <migraphx/ref/target.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
...@@ -121,7 +122,6 @@ void run_verify::verify(const std::string& name, const migraphx::program& p) con ...@@ -121,7 +122,6 @@ void run_verify::verify(const std::string& name, const migraphx::program& p) con
{ {
using result_future = using result_future =
std::future<std::pair<migraphx::program, std::vector<migraphx::argument>>>; std::future<std::pair<migraphx::program, std::vector<migraphx::argument>>>;
std::cout << "[ RUN ] " << name << std::endl;
auto_print::set_terminate_handler(name); auto_print::set_terminate_handler(name);
std::vector<std::pair<std::string, result_future>> results; std::vector<std::pair<std::string, result_future>> results;
std::vector<std::string> target_names; std::vector<std::string> target_names;
...@@ -180,25 +180,27 @@ void run_verify::verify(const std::string& name, const migraphx::program& p) con ...@@ -180,25 +180,27 @@ void run_verify::verify(const std::string& name, const migraphx::program& p) con
std::cout << tname << ":\n" << cp << std::endl; std::cout << tname << ":\n" << cp << std::endl;
std::cout << std::endl; std::cout << std::endl;
} }
EXPECT(passed);
} }
} }
std::set_terminate(nullptr); std::set_terminate(nullptr);
std::cout << "[ COMPLETE ] " << name << std::endl;
} }
void run_verify::run(int argc, const char* argv[]) const void run_verify::run(int argc, const char* argv[]) const
{ {
std::set<std::string> args(argv + 1, argv + argc); std::unordered_map<std::string, std::vector<std::string>> labels;
const auto& ps = get_programs(); for(auto&& p : get_programs())
for(auto&& p : ps)
{ {
if(not args.empty()) labels[p.section].push_back(p.name);
{ test::add_test_case(p.name, [=] { verify(p.name, p.get_program()); });
if(args.count(p.name) == 0 and args.count(p.section) == 0)
continue;
}
verify(p.name, p.get_program());
} }
test::driver d{};
d.get_case_names = [&](const std::string& name) -> std::vector<std::string> {
if(labels.count(name) > 0)
return labels.at(name);
return {name};
};
d.run(argc, argv);
} }
void run_verify::disable_parallel_for(const std::string& name) { info[name].parallel = false; } void run_verify::disable_parallel_for(const std::string& name) { info[name].parallel = false; }
......
...@@ -26,7 +26,8 @@ struct test_if_literal : verify_program<test_if_literal> ...@@ -26,7 +26,8 @@ struct test_if_literal : verify_program<test_if_literal>
else_mod->add_return({l2}); else_mod->add_return({l2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
mm->add_return({ret}); auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({r});
return p; return p;
} }
......
...@@ -27,7 +27,9 @@ struct test_if_lp : verify_program<test_if_lp> ...@@ -27,7 +27,9 @@ struct test_if_lp : verify_program<test_if_lp>
else_mod->add_return({s2, l2}); else_mod->add_return({s2, l2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
mm->add_return({ret}); auto r0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
auto r1 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), ret);
mm->add_return({r0, r1});
return p; return p;
} }
......
...@@ -29,7 +29,8 @@ struct test_if_param : verify_program<test_if_param> ...@@ -29,7 +29,8 @@ struct test_if_param : verify_program<test_if_param>
else_mod->add_return({a2}); else_mod->add_return({a2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
mm->add_return({ret}); auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({r});
return p; return p;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment