Commit 00d5d880 authored by Khalique Ahmed's avatar Khalique Ahmed

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into mi100_opts

parents 00d90ca8 f60c3815
[binary ONNX test fixtures, shown here as unreadable protobuf text: slice_5arg_reverse_test.onnx and slice_5arg_step_test.onnx are added and slice_5arg_test.onnx is updated; each graph feeds Constant arg_start/arg_end/arg_axis/arg_step inputs into a Slice node]
@@ -76,6 +76,7 @@ TEST_CASE(if_else_test)
std::vector<float> data = {0.0625, 0.75, -0.0625, 0.125, -0.125, -0.5625};
migraphx::parameter_map pp;
pp["x"] = migraphx::argument(s_data, data.data());
pp["y"] = migraphx::argument(s_data, data.data());
auto result = p.eval(pp).back();
@@ -160,6 +161,55 @@ TEST_CASE(if_pl_test)
}
}
TEST_CASE(if_tuple_test)
{
auto run_prog = [](bool cond) {
migraphx::program p = migraphx::parse_onnx("if_tuple_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape xs{migraphx::shape::float_type, {1, 4}};
migraphx::shape ys{migraphx::shape::float_type, {3, 4}};
migraphx::shape cond_s{migraphx::shape::bool_type};
std::vector<float> x_data(xs.elements(), 1.0f);
std::vector<float> y_data(ys.elements(), 2.0f);
std::vector<char> cond_data{static_cast<char>(cond)};
migraphx::parameter_map pp;
pp["x"] = migraphx::argument(xs, x_data.data());
pp["y"] = migraphx::argument(ys, y_data.data());
pp["cond"] = migraphx::argument(cond_s, cond_data.data());
auto results = p.eval(pp);
std::vector<std::vector<float>> rets;
for(const auto& arg : results)
{
std::vector<float> vec;
arg.visit([&](auto output) { vec.assign(output.begin(), output.end()); });
rets.push_back(vec);
}
return rets;
};
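// x is filled with 1.0f and y with 2.0f; the golds below are consistent with the then branch
// computing {x + 1, y + 2} and the else branch {x + 2, y + 3} (inferred from the expected values).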
// then branch
{
auto results = run_prog(true);
std::vector<float> gold0(4, 2.0f);
std::vector<float> gold1(12, 4.0f);
EXPECT(migraphx::verify_range(results.at(0), gold0));
EXPECT(migraphx::verify_range(results.at(1), gold1));
}
// else branch
{
auto results = run_prog(false);
std::vector<float> gold0(4, 3.0f);
std::vector<float> gold1(12, 5.0f);
EXPECT(migraphx::verify_range(results.at(0), gold0));
EXPECT(migraphx::verify_range(results.at(1), gold1));
}
}
TEST_CASE(instance_norm_test)
{
migraphx::program p = migraphx::parse_onnx("instance_norm_val_test.onnx");
@@ -240,7 +290,84 @@ TEST_CASE(lessorequal_test)
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(resize_test)
TEST_CASE(resize_downsample_f_test)
{
migraphx::program p = migraphx::parse_onnx("resize_downsample_f_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 4}};
std::vector<float> dx(sx.elements());
std::iota(dx.begin(), dx.end(), 0.0f);
migraphx::parameter_map pp;
pp["X"] = migraphx::argument(sx, dx.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
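// The 1x1x2x4 iota input is downsampled to 1x1x1x2; the gold values 0 and 3 are the first and
// fourth pixels of the top row, consistent with a floor-based nearest mode (the "_f" suffix).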
std::vector<float> gold = {0.0f, 3.0f};
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(resize_upsample_linear_ac_test)
{
migraphx::program p = migraphx::parse_onnx("resize_upsample_linear_ac_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 2}};
std::vector<float> dx = {1.0f, 2.0f, 3.0f, 4.0f};
migraphx::parameter_map pp;
pp["X"] = migraphx::argument(sx, dx.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
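// Bilinear upsampling of the 2x2 input {1, 2, 3, 4} to 4x4; the thirds in the gold values are
// consistent with the align_corners coordinate transform suggested by the "_ac" suffix.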
std::vector<float> gold = {1,
4.0f / 3,
5.0f / 3,
2,
5.0f / 3,
2,
7.0f / 3,
8.0f / 3,
7.0f / 3,
8.0f / 3,
3,
10.0f / 3,
3,
10.0f / 3,
11.0f / 3,
4};
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(resize_upsample_linear_test)
{
migraphx::program p = migraphx::parse_onnx("resize_upsample_linear_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 2}};
std::vector<float> dx = {1.0f, 2.0f, 3.0f, 4.0f};
migraphx::parameter_map pp;
pp["X"] = migraphx::argument(sx, dx.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
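// Bilinear upsampling of the same 2x2 input to 4x4; the quarter-step gold values are consistent
// with the default half_pixel coordinate transform.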
std::vector<float> gold = {
1, 1.25, 1.75, 2, 1.5, 1.75, 2.25, 2.5, 2.5, 2.75, 3.25, 3.5, 3, 3.25, 3.75, 4};
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(resize_upsample_pf_test)
{
migraphx::program p = migraphx::parse_onnx("resize_upsample_pf_test.onnx");
p.compile(migraphx::ref::target{});
@@ -281,6 +408,85 @@ TEST_CASE(selu_test)
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(slice_test)
{
migraphx::program p = migraphx::parse_onnx("slice_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape sh_data{migraphx::shape::float_type, {3, 2}};
std::vector<float> data = {0, 1, 2, 3, 4, 5};
migraphx::parameter_map pp;
pp["0"] = migraphx::argument(sh_data, data.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {2, 3};
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(slice_5arg_test)
{
migraphx::program p = migraphx::parse_onnx("slice_5arg_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape sh_data{migraphx::shape::float_type, {5, 5}}; // start
std::vector<float> data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24};
migraphx::parameter_map pp;
pp["0"] = migraphx::argument(sh_data, data.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
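// The input is a 5x5 iota matrix, so the gold values correspond to slicing rows 2:4 and columns 0:4.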
std::vector<float> gold = {10, 11, 12, 13, 15, 16, 17, 18};
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(slice_reverse_test)
{
migraphx::program p = migraphx::parse_onnx("slice_5arg_reverse_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape sh_data{migraphx::shape::float_type, {5, 5}}; // start
std::vector<float> data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24};
migraphx::parameter_map pp;
pp["0"] = migraphx::argument(sh_data, data.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
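// Same 5x5 iota input; the gold values correspond to rows 2:4 with the columns traversed in
// reverse (step -1) from column 4 down to column 1.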
std::vector<float> gold = {14, 13, 12, 11, 19, 18, 17, 16};
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(slice_step_test)
{
migraphx::program p = migraphx::parse_onnx("slice_5arg_step_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape sh_data{migraphx::shape::float_type, {5, 5}}; // start
std::vector<float> data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24};
migraphx::parameter_map pp;
pp["0"] = migraphx::argument(sh_data, data.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
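// Same 5x5 iota input; the gold values 14 and 12 are row 2 traversed with a step of -2,
// i.e. columns 4 and 2.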
std::vector<float> gold = {14, 12};
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(upsample_test)
{
migraphx::program p = migraphx::parse_onnx("upsample_test.onnx");
@@ -68,6 +68,43 @@ TEST_CASE(batch_norm_inference_shape)
throws_shape(migraphx::make_op("batch_norm_inference"), s, vars, vars, vars, vars, vars);
}
TEST_CASE(broadcast)
{
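// "broadcast" places the input dimensions at the given axis of the requested dims and marks
// the broadcast dimensions with stride 0 in the resulting shape.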
{
std::vector<std::size_t> lens{1, 1};
migraphx::shape input{migraphx::shape::float_type, {1}, {0}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1}, {0, 0}},
migraphx::make_op("broadcast", {{"axis", 0}, {"dims", lens}}),
input);
}
{
std::vector<std::size_t> lens{1, 1};
migraphx::shape input{migraphx::shape::float_type, {2}};
throws_shape(migraphx::op::broadcast{1, lens}, input);
}
{
std::vector<std::size_t> lens{2, 2};
migraphx::shape input{migraphx::shape::float_type, {1, 2}};
throws_shape(migraphx::op::broadcast{1, lens}, input);
}
{
std::vector<std::size_t> lens{3, 2, 4, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 3}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 2, 4, 3}, {0, 0, 3, 1}},
migraphx::make_op("broadcast", {{"axis", 2}, {"dims", lens}}),
input);
}
{
std::vector<std::size_t> lens{3, 2, 4, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 4}};
throws_shape(migraphx::make_op("broadcast", {{"axis", 2}, {"dims", lens}}), input);
}
}
TEST_CASE(convolution_shape)
{
migraphx::shape output{migraphx::shape::float_type, {4, 4, 1, 1}};
@@ -106,6 +143,24 @@ TEST_CASE(convolution_shape)
throws_shape(migraphx::make_op("convolution"), input_3d, weights_3d);
}
TEST_CASE(contiguous_shape)
{
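// "contiguous" keeps the lengths but switches to a standard row-major layout, so the transposed
// {1, 2} strides of the input become the default {2, 1} strides of the output.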
migraphx::shape output{migraphx::shape::float_type, {2, 2}};
migraphx::shape input{migraphx::shape::float_type, {2, 2}, {1, 2}};
expect_shape(output, migraphx::make_op("contiguous"), input);
throws_shape(migraphx::make_op("contiguous"), input, input);
migraphx::shape single{migraphx::shape::float_type, {2}};
expect_shape(single, migraphx::make_op("contiguous"), single);
}
TEST_CASE(contiguous_shape_scalar)
{
migraphx::shape output{migraphx::shape::float_type};
migraphx::shape input{migraphx::shape::float_type};
expect_shape(output, migraphx::make_op("contiguous"), input);
}
TEST_CASE(deconvolution_shape)
{
migraphx::shape input{migraphx::shape::float_type, {4, 4, 1, 1}};
@@ -137,141 +192,6 @@ TEST_CASE(deconvolution_shape)
weights_3d);
}
TEST_CASE(quant_convolution_shape)
{
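// quant_convolution expects int8 input and weights and produces an int32 output; mismatched
// attribute ranks or non-int8 argument types are rejected below.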
migraphx::shape output{migraphx::shape::int32_type, {4, 4, 1, 1}};
migraphx::shape input{migraphx::shape::int8_type, {4, 3, 3, 3}};
migraphx::shape weights{migraphx::shape::int8_type, {4, 3, 3, 3}};
expect_shape(output, migraphx::make_op("quant_convolution"), input, weights);
throws_shape(migraphx::make_op("quant_convolution"), input);
throws_shape(migraphx::make_op("quant_convolution",
{{"padding", {0}}, {"stride", {1, 1}}, {"dilation", {1, 1}}}),
input,
weights);
throws_shape(migraphx::make_op("quant_convolution",
{{"padding", {0}}, {"stride", {1}}, {"dilation", {1}}}),
input,
weights);
migraphx::shape input2{migraphx::shape::int32_type, {3, 3}};
migraphx::shape weights2{migraphx::shape::float_type, {3, 3}};
throws_shape(migraphx::make_op("quant_convolution"), input2, weights2);
throws_shape(migraphx::make_op("quant_convolution"), input2, weights);
migraphx::shape input3{migraphx::shape::int32_type, {4, 3, 3, 3}};
migraphx::shape weight3{migraphx::shape::float_type, {4, 3, 3, 3}};
throws_shape(migraphx::make_op("quant_convolution"), input3, weights);
throws_shape(migraphx::make_op("quant_convolution"), input, weight3);
throws_shape(migraphx::make_op("quant_convolution"), input3, weight3);
}
TEST_CASE(pooling_shape)
{
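// A zero stride is rejected; enabling ceil_mode rounds the output spatial dims up, turning the
// 1x1 output into 2x2 for the same pooling parameters.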
migraphx::shape output{migraphx::shape::float_type, {4, 3, 1, 1}};
migraphx::shape input{migraphx::shape::float_type, {4, 3, 3, 3}};
throws_shape(
migraphx::make_op("pooling",
{{"mode", "max"}, {"padding", {1}}, {"stride", {0}}, {"lengths", {1}}}),
input);
expect_shape(
output,
migraphx::make_op(
"pooling",
{{"mode", "max"}, {"padding", {0, 0}}, {"stride", {3, 3}}, {"lengths", {1, 1}}}),
input);
migraphx::shape output1{migraphx::shape::float_type, {4, 3, 2, 2}};
expect_shape(output1,
migraphx::make_op("pooling",
{{"mode", "max"},
{"padding", {0, 0}},
{"stride", {3, 3}},
{"lengths", {1, 1}},
{"ceil_mode", true}}),
input);
}
TEST_CASE(inconsistent_attr_shape)
{
migraphx::shape input{migraphx::shape::float_type, {4, 3, 3, 3}};
migraphx::shape weights{migraphx::shape::float_type, {4, 3, 3, 3}};
throws_shape(migraphx::make_op("convolution",
{{"padding", {1, 1}}, {"stride", {2}}, {"dilation", {3, 3, 3}}}),
input,
weights);
throws_shape(migraphx::make_op("deconvolution",
{{"padding", {1, 1}}, {"stride", {2}}, {"dilation", {3, 3, 3}}}),
input,
weights);
throws_shape(
migraphx::make_op(
"pooling", {{"mode", "max"}, {"padding", {1}}, {"stride", {0}}, {"lengths", {1, 1}}}),
input);
}
TEST_CASE(transpose_shape)
{
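// transpose permutes strides rather than copying data ({2, 1} -> {1, 2}); omitting "dims" here
// gives the same reversed layout, and a permutation naming an invalid axis throws.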
migraphx::shape input{migraphx::shape::float_type, {2, 2}};
migraphx::shape output{migraphx::shape::float_type, {2, 2}, {1, 2}};
expect_shape(input, migraphx::make_op("transpose", {{"dims", {0, 1}}}), input);
expect_shape(output, migraphx::make_op("transpose", {{"dims", {1, 0}}}), input);
expect_shape(output, migraphx::make_op("transpose"), input);
throws_shape(migraphx::make_op("transpose", {{"dims", {1, 2}}}), input);
}
TEST_CASE(contiguous_shape)
{
migraphx::shape output{migraphx::shape::float_type, {2, 2}};
migraphx::shape input{migraphx::shape::float_type, {2, 2}, {1, 2}};
expect_shape(output, migraphx::make_op("contiguous"), input);
throws_shape(migraphx::make_op("contiguous"), input, input);
migraphx::shape single{migraphx::shape::float_type, {2}};
expect_shape(single, migraphx::make_op("contiguous"), single);
}
TEST_CASE(contiguous_shape_scalar)
{
migraphx::shape output{migraphx::shape::float_type};
migraphx::shape input{migraphx::shape::float_type};
expect_shape(output, migraphx::make_op("contiguous"), input);
}
TEST_CASE(reshape_shape)
{
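// Following the ONNX reshape conventions, a 0 keeps the corresponding input dimension and a
// single -1 is inferred from the remaining element count; shapes that change the total number
// of elements throw.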
migraphx::shape input{migraphx::shape::float_type, {24, 1, 1, 1}};
for(auto&& new_shape :
std::vector<std::vector<int64_t>>{{8, 3, 1, 1}, {1, 3, 4, 2}, {1, 3, 4, 2}})
{
std::vector<std::size_t> lens(new_shape.size());
std::copy(new_shape.begin(), new_shape.end(), lens.begin());
migraphx::shape output{migraphx::shape::float_type, lens};
expect_shape(output, migraphx::make_op("reshape", {{"dims", new_shape}}), input);
}
for(auto&& new_shape :
std::vector<std::vector<int64_t>>{{8, 3, 2, 2}, {1, 3, -1, -1}, {3, 0, 0}, {3, 2, 0}})
{
throws_shape(migraphx::make_op("reshape", {{"dims", new_shape}}), input);
}
std::vector<std::pair<std::vector<int64_t>, migraphx::shape>> minus1_tests{
{{2, -1, 3}, {migraphx::shape::float_type, {2, 4, 3}}},
{{0, -1, 0}, {migraphx::shape::float_type, {24, 1, 1}}},
{{2, -1, 0}, {migraphx::shape::float_type, {2, 12, 1}}},
{{0, 0, -1}, {migraphx::shape::float_type, {24, 1, 1}}},
{{2, 0, -1}, {migraphx::shape::float_type, {2, 1, 12}}},
{{-1, 2, 3}, {migraphx::shape::float_type, {4, 2, 3}}},
{{-1, 0, 3}, {migraphx::shape::float_type, {8, 1, 3}}},
{{-1, 0, 0}, {migraphx::shape::float_type, {24, 1, 1}}},
{{-1, 3, 0}, {migraphx::shape::float_type, {8, 3, 1}}}};
for(auto& it : minus1_tests)
{
expect_shape(it.second, migraphx::make_op("reshape", {{"dims", it.first}}), input);
}
}
TEST_CASE(flatten_shape)
{
migraphx::shape input{migraphx::shape::float_type, {2, 4, 6, 8}};
@@ -300,173 +220,48 @@ TEST_CASE(flatten_shape)
throws_shape(migraphx::make_op("flatten", {{"axis", -5}}), input);
}
TEST_CASE(slice_shape)
TEST_CASE(gather)
{
migraphx::shape input{migraphx::shape::int32_type, {2, 2, 3}};
expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 2, 2}, {6, 3, 1}},
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1}}, {"ends", {3}}}),
input);
expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 2, 2}, {6, 3, 1}},
migraphx::make_op(
"slice", {{"axes", {0, 1, 2}}, {"starts", {0, 0, 1}}, {"ends", {2, 2, 3}}}),
input);
expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 2, 1}, {6, 3, 1}},
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {2}}, {"ends", {10}}}),
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
int axis = 1;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 2, 3, 4, 5}},
migraphx::make_op("gather", {{"axis", axis}}),
input,
indices);
}
TEST_CASE(multibroadcast)
{
{
std::vector<std::size_t> lens{4, 2, 5, 3};
migraphx::shape input{migraphx::shape::float_type, {2, 1, 3}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 3, 0, 1}},
migraphx::make_op("multibroadcast", {{"output_lens", lens}}),
input);
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
int axis = -4;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 3, 4, 5}},
migraphx::make_op("gather", {{"axis", axis}}),
input,
indices);
}
{
std::vector<std::size_t> lens{4, 2, 5, 3};
migraphx::shape input{migraphx::shape::float_type, {2, 1, 1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 1, 0, 0}},
migraphx::make_op("multibroadcast", {{"output_lens", lens}}),
input);
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {1}};
int axis = -4;
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 3, 4, 5}},
migraphx::make_op("gather", {{"axis", axis}}),
input,
indices);
}
{
std::vector<std::size_t> lens{4, 2, 5, 3};
migraphx::shape input{migraphx::shape::float_type, {5, 1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 0, 1, 0}},
migraphx::make_op("multibroadcast", {{"output_lens", lens}}),
input);
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type};
int axis = -4;
expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 4, 5}},
migraphx::make_op("gather", {{"axis", axis}}),
input,
indices);
}
{
std::vector<std::size_t> lens{4, 2, 5, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 1, 1, 1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {1, 0, 0, 0}},
migraphx::make_op("multibroadcast", {{"output_lens", lens}}),
input);
}
{
std::vector<std::size_t> lens{4, 2, 5, 3};
migraphx::shape input{migraphx::shape::float_type, {3}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 0, 0, 1}},
migraphx::make_op("multibroadcast", {{"output_lens", lens}}),
input);
}
{
std::vector<std::size_t> lens{4, 4, 1, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 1, 3}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 3, 3, 1}},
migraphx::make_op("multibroadcast", {{"output_lens", lens}}),
input);
}
{
std::vector<std::size_t> lens{4, 1, 1, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 1, 1, 1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {1, 1, 1, 0}},
migraphx::make_op("multibroadcast", {{"output_lens", lens}}),
input);
}
{
std::vector<std::size_t> lens{4, 1, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 1, 1, 1}};
throws_shape(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), input);
}
{
std::vector<std::size_t> lens{4, 1, 3};
migraphx::shape input{migraphx::shape::float_type, {}};
throws_shape(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), input);
}
{
std::vector<std::size_t> lens{2, 3, 4, 5};
migraphx::shape input{migraphx::shape::float_type, {3, 4}};
throws_shape(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), input);
}
{
std::vector<std::size_t> lens{2, 3, 4, 5};
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4}};
throws_shape(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), input);
}
}
TEST_CASE(broadcast)
{
{
std::vector<std::size_t> lens{1, 1};
migraphx::shape input{migraphx::shape::float_type, {1}, {0}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1}, {0, 0}},
migraphx::make_op("broadcast", {{"axis", 0}, {"dims", lens}}),
input);
}
{
std::vector<std::size_t> lens{1, 1};
migraphx::shape input{migraphx::shape::float_type, {2}};
throws_shape(migraphx::op::broadcast{1, lens}, input);
}
{
std::vector<std::size_t> lens{2, 2};
migraphx::shape input{migraphx::shape::float_type, {1, 2}};
throws_shape(migraphx::op::broadcast{1, lens}, input);
}
{
std::vector<std::size_t> lens{3, 2, 4, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 3}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 2, 4, 3}, {0, 0, 3, 1}},
migraphx::make_op("broadcast", {{"axis", 2}, {"dims", lens}}),
input);
}
{
std::vector<std::size_t> lens{3, 2, 4, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 4}};
throws_shape(migraphx::make_op("broadcast", {{"axis", 2}, {"dims", lens}}), input);
}
}
TEST_CASE(gather)
{
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
int axis = 1;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 2, 3, 4, 5}},
migraphx::make_op("gather", {{"axis", axis}}),
input,
indices);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
int axis = -4;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 3, 4, 5}},
migraphx::make_op("gather", {{"axis", axis}}),
input,
indices);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {1}};
int axis = -4;
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 3, 4, 5}},
migraphx::make_op("gather", {{"axis", axis}}),
input,
indices);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type};
int axis = -4;
expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 4, 5}},
migraphx::make_op("gather", {{"axis", axis}}),
input,
indices);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type};
@@ -512,219 +307,546 @@ TEST_CASE(gather)
}
}
template <class T>
void test_softmax_variations()
// 3 input arguments
TEST_CASE(gemm)
{
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{0}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{1}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{2}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{3}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 4;
throws_shape(T{axis}, input);
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {1}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
}
}
TEST_CASE(softmax) { test_softmax_variations<migraphx::op::softmax>(); }
TEST_CASE(logsoftmax) { test_softmax_variations<migraphx::op::logsoftmax>(); }
TEST_CASE(test_argmax)
{
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {1, 3, 4, 5}},
migraphx::make_op("argmax", {{"axis", 0}}),
input);
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {1, 1}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 1, 4, 5}},
migraphx::make_op("argmax", {{"axis", 1}}),
input);
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {8}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 1, 5}},
migraphx::make_op("argmax", {{"axis", 2}}),
input);
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {4, 1}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 4, 1}},
migraphx::make_op("argmax", {{"axis", 3}}),
input);
migraphx::shape s_m1{migraphx::shape::float_type, {4, 6}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {4, 8}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
throws_shape(migraphx::make_op("argmax", {{"axis", 4}}), input);
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {4}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
}
}
TEST_CASE(test_argmin)
{
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {1, 3, 4, 5}},
migraphx::make_op("argmin", {{"axis", 0}}),
input);
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {4, 8}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 8}},
migraphx::make_op("dot"),
s_m1,
s_m2,
s_m3);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 1, 4, 5}},
migraphx::make_op("argmin", {{"axis", 1}}),
input);
migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {1, 4, 8}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 4, 8}},
migraphx::make_op("dot"),
s_m1,
s_m2,
s_m3);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 1, 5}},
migraphx::make_op("argmin", {{"axis", 2}}),
input);
migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 6}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {1, 4, 8}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 4, 1}},
migraphx::make_op("argmin", {{"axis", 3}}),
input);
migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {4, 8}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
throws_shape(migraphx::make_op("argmin", {{"axis", 4}}), input);
migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
}
}
TEST_CASE(test_scalar)
TEST_CASE(get_tuple_elem_test)
{
migraphx::shape s1{migraphx::shape::float_type, {1}, {1}};
migraphx::shape s2{migraphx::shape::float_type, {2, 3, 4, 5}, {0, 0, 0, 0}};
expect_shape(s2, migraphx::make_op("scalar", {{"scalar_bcst_dims", {2, 3, 4, 5}}}), s1);
migraphx::shape s0{migraphx::shape::bool_type, {1, 1}};
migraphx::shape s1{migraphx::shape::float_type, {2, 3}};
migraphx::shape s2{migraphx::shape::int32_type, {5, 6}};
migraphx::shape s_tuple({s0, s1, s2});
expect_shape(s0, migraphx::make_op("get_tuple_elem", {{"index", 0}}), s_tuple);
expect_shape(s1, migraphx::make_op("get_tuple_elem", {{"index", 1}}), s_tuple);
expect_shape(s2, migraphx::make_op("get_tuple_elem", {{"index", 2}}), s_tuple);
throws_shape(migraphx::make_op("get_tuple_elem", {{"index", 3}}), s_tuple);
throws_shape(migraphx::make_op("get_tuple_elem", {{"index", 0}}), s0);
throws_shape(migraphx::make_op("get_tuple_elem", {{"index", 1}}), s1);
throws_shape(migraphx::make_op("get_tuple_elem", {{"index", 0}}), s2);
}
TEST_CASE(test_scalar_nelemnts)
TEST_CASE(gru)
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
throws_shape(migraphx::make_op("scalar", {{"scalar_bcst_dims", {2, 3, 4, 5}}}), input);
}
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
TEST_CASE(test_squeeze)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 1, 3, 1, 3}};
migraphx::shape s2{migraphx::shape::float_type, {4, 1, 3, 3}};
expect_shape(s2, migraphx::make_op("squeeze", {{"axes", {3}}}), s1);
}
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
TEST_CASE(test_squeeze_negative_axis)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 1, 3, 1, 3}};
migraphx::shape s2{migraphx::shape::float_type, {4, 1, 3, 3}};
expect_shape(s2, migraphx::make_op("squeeze", {{"axes", {-2}}}), s1);
}
expect_shape(
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::make_op(
"gru",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
TEST_CASE(test_squeeze_wrong_axis)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 1, 3, 1, 3}};
throws_shape(migraphx::make_op("squeeze", {{"axes", {0}}}), s1);
}
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
TEST_CASE(test_squeeze_all)
{
migraphx::shape s1{migraphx::shape::float_type, {1}};
migraphx::shape s2{migraphx::shape::float_type};
expect_shape(s2, migraphx::make_op("squeeze", {{"axes", {0}}}), s1);
}
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
TEST_CASE(test_unsqueeze_scalar)
{
migraphx::shape s1{migraphx::shape::float_type, {1}, {0}};
migraphx::shape s2{migraphx::shape::float_type, {1}, {1}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {0}}}), s1);
}
expect_shape(
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::make_op(
"gru",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
TEST_CASE(test_unsqueeze_scalar_tensor1)
{
migraphx::shape s{migraphx::shape::float_type, {4, 3, 3}, {0, 0, 0}};
throws_shape(migraphx::make_op("unsqueeze", {{"axes", {-2}}}), s);
}
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 2;
float clip = 0.0f;
TEST_CASE(test_unsqueeze_scalar_tensor2)
{
migraphx::shape s{migraphx::shape::float_type, {1, 1, 1}, {0, 0, 0}};
throws_shape(migraphx::make_op("unsqueeze", {{"axes", {-2}}}), s);
}
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
TEST_CASE(test_unsqueeze)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 3, 3}};
migraphx::shape s2{migraphx::shape::float_type, {4, 3, 1, 3}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1);
expect_shape(
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::make_op(
"gru",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
throws_shape(
migraphx::make_op(
"gru",
{{"hidden_size", hidden_size + 1},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
throws_shape(
migraphx::make_op(
"gru",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
throws_shape(
migraphx::make_op(
"gru",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
}
TEST_CASE(test_unsqueeze_negative_axis)
TEST_CASE(inconsistent_attr_shape)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 3, 3}};
migraphx::shape s2{migraphx::shape::float_type, {4, 3, 1, 3}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {-2}}}), s1);
migraphx::shape input{migraphx::shape::float_type, {4, 3, 3, 3}};
migraphx::shape weights{migraphx::shape::float_type, {4, 3, 3, 3}};
throws_shape(migraphx::make_op("convolution",
{{"padding", {1, 1}}, {"stride", {2}}, {"dilation", {3, 3, 3}}}),
input,
weights);
throws_shape(migraphx::make_op("deconvolution",
{{"padding", {1, 1}}, {"stride", {2}}, {"dilation", {3, 3, 3}}}),
input,
weights);
throws_shape(
migraphx::make_op(
"pooling", {{"mode", "max"}, {"padding", {1}}, {"stride", {0}}, {"lengths", {1, 1}}}),
input);
}
template <class T>
void test_reduce_ops()
void test_softmax_variations()
{
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1}}, T{}, input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{0}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{1}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{2}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{3}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 4;
throws_shape(T{axis}, input);
}
}
TEST_CASE(logsoftmax) { test_softmax_variations<migraphx::op::logsoftmax>(); }
TEST_CASE(lstm)
{
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
expect_shape(
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape);
}
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
expect_shape(
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1}}, T{{0, 1, 2, 3}}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 1, 1}}, T{{2, 3}}, input);
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 3, 4, 5}}, T{{0}}, input);
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
throws_shape(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size + 1},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 1}}, T{{-1}}, input);
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
throws_shape(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
throws_shape(T{{4}}, input);
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
throws_shape(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
}
TEST_CASE(reduce_sum) { test_reduce_ops<migraphx::op::reduce_sum>(); }
TEST_CASE(reduce_mean) { test_reduce_ops<migraphx::op::reduce_mean>(); }
// 2 inputs arguments
TEST_CASE(matmul)
{
@@ -825,523 +947,269 @@ TEST_CASE(matmul)
}
}
// 3 input arguments
TEST_CASE(gemm)
TEST_CASE(multibroadcast)
{
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {1}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {1, 1}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {8}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {4, 1}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 6}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {4, 8}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {4}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {4, 8}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 8}},
migraphx::make_op("dot"),
s_m1,
s_m2,
s_m3);
std::vector<std::size_t> lens{4, 2, 5, 3};
migraphx::shape input{migraphx::shape::float_type, {2, 1, 3}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 3, 0, 1}},
migraphx::make_op("multibroadcast", {{"output_lens", lens}}),
input);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {1, 4, 8}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 4, 8}},
migraphx::make_op("dot"),
s_m1,
s_m2,
s_m3);
std::vector<std::size_t> lens{4, 2, 5, 3};
migraphx::shape input{migraphx::shape::float_type, {2, 1, 1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 1, 0, 0}},
migraphx::make_op("multibroadcast", {{"output_lens", lens}}),
input);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 6}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {1, 4, 8}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
std::vector<std::size_t> lens{4, 2, 5, 3};
migraphx::shape input{migraphx::shape::float_type, {5, 1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 0, 1, 0}},
migraphx::make_op("multibroadcast", {{"output_lens", lens}}),
input);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {4, 8}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
std::vector<std::size_t> lens{4, 2, 5, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 1, 1, 1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {1, 0, 0, 0}},
migraphx::make_op("multibroadcast", {{"output_lens", lens}}),
input);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
std::vector<std::size_t> lens{4, 2, 5, 3};
migraphx::shape input{migraphx::shape::float_type, {3}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 0, 0, 1}},
migraphx::make_op("multibroadcast", {{"output_lens", lens}}),
input);
}
}
// quant_dot
TEST_CASE(quant_dot_2args)
{
{
migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}};
migraphx::shape s_m2{migraphx::shape::int8_type, {4, 8}};
expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 8}},
migraphx::make_op("quant_dot"),
s_m1,
s_m2);
std::vector<std::size_t> lens{4, 4, 1, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 1, 3}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 3, 3, 1}},
migraphx::make_op("multibroadcast", {{"output_lens", lens}}),
input);
}
{
migraphx::shape s_m1{migraphx::shape::int8_type, {3, 8}};
migraphx::shape s_m2{migraphx::shape::int8_type, {8, 7}};
expect_shape(migraphx::shape{migraphx::shape::int32_type, {3, 7}},
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}),
s_m1,
s_m2);
std::vector<std::size_t> lens{4, 1, 1, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 1, 1, 1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {1, 1, 1, 0}},
migraphx::make_op("multibroadcast", {{"output_lens", lens}}),
input);
}
{
migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}};
migraphx::shape s_m2{migraphx::shape::int8_type, {8, 8}};
throws_shape(migraphx::make_op("quant_dot"), s_m1, s_m2);
std::vector<std::size_t> lens{4, 1, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 1, 1, 1}};
throws_shape(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), input);
}
}
TEST_CASE(quant_dot_3args)
{
{
migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}};
migraphx::shape s_m2{migraphx::shape::int8_type, {4, 8}};
migraphx::shape s_m3{migraphx::shape::int32_type, {2, 8}};
expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 8}},
migraphx::make_op("quant_dot"),
s_m1,
s_m2,
s_m3);
std::vector<std::size_t> lens{4, 1, 3};
migraphx::shape input{migraphx::shape::float_type, {}};
throws_shape(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), input);
}
{
migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}};
migraphx::shape s_m2{migraphx::shape::int8_type, {4, 8}};
migraphx::shape s_m3{migraphx::shape::int8_type, {2, 8}};
throws_shape(migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 2}}), s_m1, s_m2, s_m3);
std::vector<std::size_t> lens{2, 3, 4, 5};
migraphx::shape input{migraphx::shape::float_type, {3, 4}};
throws_shape(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), input);
}
{
std::vector<std::size_t> lens{2, 3, 4, 5};
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4}};
throws_shape(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), input);
}
}
TEST_CASE(rnn)
TEST_CASE(pooling_shape)
{
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
migraphx::shape output{migraphx::shape::float_type, {4, 3, 1, 1}};
migraphx::shape input{migraphx::shape::float_type, {4, 3, 3, 3}};
throws_shape(
migraphx::make_op("pooling",
{{"mode", "max"}, {"padding", {1}}, {"stride", {0}}, {"lengths", {1}}}),
input);
expect_shape(
output,
migraphx::make_op(
"pooling",
{{"mode", "max"}, {"padding", {0, 0}}, {"stride", {3, 3}}, {"lengths", {1, 1}}}),
input);
expect_shape(
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::make_op(
"rnn",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
migraphx::shape output1{migraphx::shape::float_type, {4, 3, 2, 2}};
expect_shape(output1,
migraphx::make_op("pooling",
{{"mode", "max"},
{"padding", {0, 0}},
{"stride", {3, 3}},
{"lengths", {1, 1}},
{"ceil_mode", true}}),
input);
}
TEST_CASE(prefix_scan_sum)
{
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
expect_shape(
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::make_op(
"rnn",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
migraphx::shape s{migraphx::shape::float_type, {1, 2, 3}};
throws_shape(
migraphx::make_op("prefix_scan_sum", {{"axis", 3}, {"exclusive", 0}, {"reverse", 0}}),
s);
}
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
expect_shape(
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::make_op(
"rnn",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
migraphx::shape s{migraphx::shape::float_type, {1, 2}};
throws_shape(
migraphx::make_op("prefix_scan_sum", {{"axis", -3}, {"exclusive", 0}, {"reverse", 0}}),
s);
}
}
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
throws_shape(
migraphx::make_op(
"rnn",
{{"hidden_size", hidden_size + 1},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
throws_shape(
migraphx::make_op(
"rnn",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
throws_shape(
migraphx::make_op(
"rnn",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
TEST_CASE(quant_convolution_shape)
{
migraphx::shape output{migraphx::shape::int32_type, {4, 4, 1, 1}};
migraphx::shape input{migraphx::shape::int8_type, {4, 3, 3, 3}};
migraphx::shape weights{migraphx::shape::int8_type, {4, 3, 3, 3}};
expect_shape(output, migraphx::make_op("quant_convolution"), input, weights);
throws_shape(migraphx::make_op("quant_convolution"), input);
throws_shape(migraphx::make_op("quant_convolution",
{{"padding", {0}}, {"stride", {1, 1}}, {"dilation", {1, 1}}}),
input,
weights);
throws_shape(migraphx::make_op("quant_convolution",
{{"padding", {0}}, {"stride", {1}}, {"dilation", {1}}}),
input,
weights);
migraphx::shape input2{migraphx::shape::int32_type, {3, 3}};
migraphx::shape weights2{migraphx::shape::float_type, {3, 3}};
throws_shape(migraphx::make_op("quant_convolution"), input2, weights2);
throws_shape(migraphx::make_op("quant_convolution"), input2, weights);
migraphx::shape input3{migraphx::shape::int32_type, {4, 3, 3, 3}};
migraphx::shape weight3{migraphx::shape::float_type, {4, 3, 3, 3}};
throws_shape(migraphx::make_op("quant_convolution"), input3, weights);
throws_shape(migraphx::make_op("quant_convolution"), input, weight3);
throws_shape(migraphx::make_op("quant_convolution"), input3, weight3);
}
// quant_dot
TEST_CASE(quant_dot_2args)
{
{
migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}};
migraphx::shape s_m2{migraphx::shape::int8_type, {4, 8}};
expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 8}},
migraphx::make_op("quant_dot"),
s_m1,
s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::int8_type, {3, 8}};
migraphx::shape s_m2{migraphx::shape::int8_type, {8, 7}};
expect_shape(migraphx::shape{migraphx::shape::int32_type, {3, 7}},
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}),
s_m1,
s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}};
migraphx::shape s_m2{migraphx::shape::int8_type, {8, 8}};
throws_shape(migraphx::make_op("quant_dot"), s_m1, s_m2);
}
}
TEST_CASE(gru)
{
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
expect_shape(
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::make_op(
"gru",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
expect_shape(
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::make_op(
"gru",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
expect_shape(
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::make_op(
"gru",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
throws_shape(
migraphx::make_op(
"gru",
{{"hidden_size", hidden_size + 1},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
throws_shape(
migraphx::make_op(
"gru",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
throws_shape(
migraphx::make_op(
"gru",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip}}),
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
}
TEST_CASE(quant_dot_3args)
{
{
migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}};
migraphx::shape s_m2{migraphx::shape::int8_type, {4, 8}};
migraphx::shape s_m3{migraphx::shape::int32_type, {2, 8}};
expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 8}},
migraphx::make_op("quant_dot"),
s_m1,
s_m2,
s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}};
migraphx::shape s_m2{migraphx::shape::int8_type, {4, 8}};
migraphx::shape s_m3{migraphx::shape::int8_type, {2, 8}};
throws_shape(migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 2}}), s_m1, s_m2, s_m3);
}
}
template <class T>
void test_reduce_ops()
{
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1}}, T{}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1}}, T{{0, 1, 2, 3}}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 1, 1}}, T{{2, 3}}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 3, 4, 5}}, T{{0}}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 1}}, T{{-1}}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
throws_shape(T{{4}}, input);
}
}
TEST_CASE(reduce_mean) { test_reduce_ops<migraphx::op::reduce_mean>(); }
TEST_CASE(reduce_sum) { test_reduce_ops<migraphx::op::reduce_sum>(); }
TEST_CASE(reshape_shape)
{
migraphx::shape input{migraphx::shape::float_type, {24, 1, 1, 1}};
for(auto&& new_shape :
std::vector<std::vector<int64_t>>{{8, 3, 1, 1}, {1, 3, 4, 2}, {1, 3, 4, 2}})
{
std::vector<std::size_t> lens(new_shape.size());
std::copy(new_shape.begin(), new_shape.end(), lens.begin());
migraphx::shape output{migraphx::shape::float_type, lens};
expect_shape(output, migraphx::make_op("reshape", {{"dims", new_shape}}), input);
}
for(auto&& new_shape :
std::vector<std::vector<int64_t>>{{8, 3, 2, 2}, {1, 3, -1, -1}, {3, 0, 0}, {3, 2, 0}})
{
throws_shape(migraphx::make_op("reshape", {{"dims", new_shape}}), input);
}
std::vector<std::pair<std::vector<int64_t>, migraphx::shape>> minus1_tests{
{{2, -1, 3}, {migraphx::shape::float_type, {2, 4, 3}}},
{{0, -1, 0}, {migraphx::shape::float_type, {24, 1, 1}}},
{{2, -1, 0}, {migraphx::shape::float_type, {2, 12, 1}}},
{{0, 0, -1}, {migraphx::shape::float_type, {24, 1, 1}}},
{{2, 0, -1}, {migraphx::shape::float_type, {2, 1, 12}}},
{{-1, 2, 3}, {migraphx::shape::float_type, {4, 2, 3}}},
{{-1, 0, 3}, {migraphx::shape::float_type, {8, 1, 3}}},
{{-1, 0, 0}, {migraphx::shape::float_type, {24, 1, 1}}},
{{-1, 3, 0}, {migraphx::shape::float_type, {8, 3, 1}}}};
for(auto& it : minus1_tests)
{
expect_shape(it.second, migraphx::make_op("reshape", {{"dims", it.first}}), input);
}
}
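// Note (added for readability, inferred from the minus1_tests pairs above): in reshape's
// "dims" attribute a 0 keeps the corresponding input dimension and a single -1 is inferred
// from the remaining number of elements.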
TEST_CASE(lstm)
TEST_CASE(rnn)
{
{
std::size_t batch_size = 2;
......@@ -1352,16 +1220,16 @@ TEST_CASE(lstm)
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
expect_shape(
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::make_op(
"lstm",
"rnn",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
......@@ -1369,7 +1237,9 @@ TEST_CASE(lstm)
{"clip", clip}}),
in_shape,
w_shape,
r_shape);
r_shape,
b_shape,
ih_shape);
}
{
......@@ -1381,18 +1251,16 @@ TEST_CASE(lstm)
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
expect_shape(
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::make_op(
"lstm",
"rnn",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
......@@ -1414,18 +1282,16 @@ TEST_CASE(lstm)
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
expect_shape(
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::make_op(
"lstm",
"rnn",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
......@@ -1447,16 +1313,14 @@ TEST_CASE(lstm)
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
throws_shape(
migraphx::make_op(
"lstm",
"rnn",
{{"hidden_size", hidden_size + 1},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
......@@ -1478,16 +1342,14 @@ TEST_CASE(lstm)
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
throws_shape(
migraphx::make_op(
"lstm",
"rnn",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
......@@ -1509,16 +1371,14 @@ TEST_CASE(lstm)
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
throws_shape(
migraphx::make_op(
"lstm",
"rnn",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("tanh")})},
......@@ -1532,21 +1392,176 @@ TEST_CASE(lstm)
}
}
TEST_CASE(prefix_scan_sum)
{
{
migraphx::shape s{migraphx::shape::float_type, {1, 2, 3}};
throws_shape(
migraphx::make_op("prefix_scan_sum", {{"axis", 3}, {"exclusive", 0}, {"reverse", 0}}),
s);
}
{
migraphx::shape s{migraphx::shape::float_type, {1, 2}};
throws_shape(
migraphx::make_op("prefix_scan_sum", {{"axis", -3}, {"exclusive", 0}, {"reverse", 0}}),
s);
}
}
TEST_CASE(slice_shape)
{
migraphx::shape input{migraphx::shape::int32_type, {2, 2, 3}};
expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 2, 2}, {6, 3, 1}},
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1}}, {"ends", {3}}}),
input);
expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 2, 2}, {6, 3, 1}},
migraphx::make_op(
"slice", {{"axes", {0, 1, 2}}, {"starts", {0, 0, 1}}, {"ends", {2, 2, 3}}}),
input);
expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 2, 1}, {6, 3, 1}},
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {2}}, {"ends", {10}}}),
input);
}
TEST_CASE(softmax) { test_softmax_variations<migraphx::op::softmax>(); }
TEST_CASE(test_argmax)
{
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {1, 3, 4, 5}},
migraphx::make_op("argmax", {{"axis", 0}}),
input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 1, 4, 5}},
migraphx::make_op("argmax", {{"axis", 1}}),
input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 1, 5}},
migraphx::make_op("argmax", {{"axis", 2}}),
input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 4, 1}},
migraphx::make_op("argmax", {{"axis", 3}}),
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
throws_shape(migraphx::make_op("argmax", {{"axis", 4}}), input);
}
}
TEST_CASE(test_argmin)
{
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {1, 3, 4, 5}},
migraphx::make_op("argmin", {{"axis", 0}}),
input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 1, 4, 5}},
migraphx::make_op("argmin", {{"axis", 1}}),
input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 1, 5}},
migraphx::make_op("argmin", {{"axis", 2}}),
input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 4, 1}},
migraphx::make_op("argmin", {{"axis", 3}}),
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
throws_shape(migraphx::make_op("argmin", {{"axis", 4}}), input);
}
}
TEST_CASE(test_scalar)
{
migraphx::shape s1{migraphx::shape::float_type, {1}, {1}};
migraphx::shape s2{migraphx::shape::float_type, {2, 3, 4, 5}, {0, 0, 0, 0}};
expect_shape(s2, migraphx::make_op("scalar", {{"scalar_bcst_dims", {2, 3, 4, 5}}}), s1);
}
TEST_CASE(test_scalar_nelemnts)
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
throws_shape(migraphx::make_op("scalar", {{"scalar_bcst_dims", {2, 3, 4, 5}}}), input);
}
TEST_CASE(test_squeeze)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 1, 3, 1, 3}};
migraphx::shape s2{migraphx::shape::float_type, {4, 1, 3, 3}};
expect_shape(s2, migraphx::make_op("squeeze", {{"axes", {3}}}), s1);
}
TEST_CASE(test_squeeze_all)
{
migraphx::shape s1{migraphx::shape::float_type, {1}};
migraphx::shape s2{migraphx::shape::float_type};
expect_shape(s2, migraphx::make_op("squeeze", {{"axes", {0}}}), s1);
}
TEST_CASE(test_squeeze_negative_axis)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 1, 3, 1, 3}};
migraphx::shape s2{migraphx::shape::float_type, {4, 1, 3, 3}};
expect_shape(s2, migraphx::make_op("squeeze", {{"axes", {-2}}}), s1);
}
TEST_CASE(test_squeeze_wrong_axis)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 1, 3, 1, 3}};
throws_shape(migraphx::make_op("squeeze", {{"axes", {0}}}), s1);
}
TEST_CASE(test_unsqueeze)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 3, 3}};
migraphx::shape s2{migraphx::shape::float_type, {4, 3, 1, 3}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1);
}
TEST_CASE(test_unsqueeze_negative_axis)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 3, 3}};
migraphx::shape s2{migraphx::shape::float_type, {4, 3, 1, 3}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {-2}}}), s1);
}
TEST_CASE(test_unsqueeze_scalar)
{
migraphx::shape s1{migraphx::shape::float_type, {1}, {0}};
migraphx::shape s2{migraphx::shape::float_type, {1}, {1}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {0}}}), s1);
}
TEST_CASE(test_unsqueeze_scalar_tensor1)
{
migraphx::shape s{migraphx::shape::float_type, {4, 3, 3}, {0, 0, 0}};
throws_shape(migraphx::make_op("unsqueeze", {{"axes", {-2}}}), s);
}
TEST_CASE(test_unsqueeze_scalar_tensor2)
{
migraphx::shape s{migraphx::shape::float_type, {1, 1, 1}, {0, 0, 0}};
throws_shape(migraphx::make_op("unsqueeze", {{"axes", {-2}}}), s);
}
TEST_CASE(transpose_shape)
{
migraphx::shape input{migraphx::shape::float_type, {2, 2}};
migraphx::shape output{migraphx::shape::float_type, {2, 2}, {1, 2}};
expect_shape(input, migraphx::make_op("transpose", {{"dims", {0, 1}}}), input);
expect_shape(output, migraphx::make_op("transpose", {{"dims", {1, 0}}}), input);
expect_shape(output, migraphx::make_op("transpose"), input);
throws_shape(migraphx::make_op("transpose", {{"dims", {1, 2}}}), input);
}
TEST_CASE(step_test)
......
......@@ -3,6 +3,7 @@
#include <migraphx/make_op.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/rnn_variable_seq_lens.hpp>
#include <migraphx/module.hpp>
#include <sstream>
#include <string>
#include <migraphx/make_op.hpp>
......@@ -115,4 +116,36 @@ TEST_CASE(ops)
EXPECT(names.size() > 1);
}
TEST_CASE(rnn)
{
migraphx::shape s{migraphx::shape::float_type, {2, 1}};
std::vector<float> data1(2, 2.0f);
std::vector<float> data2(2, 3.0f);
migraphx::argument a1(s, data1.data());
migraphx::argument a2(s, data2.data());
auto op = migraphx::make_op("rnn");
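// "rnn" is normally rewritten into primitive ops before evaluation, so invoking compute on
// the raw operation is expected to throw.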
EXPECT(test::throws([&] { op.compute(s, {a1, a2}); }));
}
TEST_CASE(if_op)
{
migraphx::shape s{migraphx::shape::bool_type, {1}};
std::vector<char> data = {1};
migraphx::argument cond(s, data.data());
migraphx::shape sd{migraphx::shape::float_type, {2, 1}};
std::vector<float> data1(2, 2.0f);
std::vector<float> data2(2, 3.0f);
migraphx::argument a1(sd, data1.data());
migraphx::argument a2(sd, data2.data());
migraphx::module m("name");
auto l = m.add_literal(migraphx::literal(sd, data1));
m.add_return({l});
auto op = migraphx::make_op("add");
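// "add" is not a control-flow operator, so passing it module arguments (as "if" would
// receive) must be rejected by compute.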
EXPECT(test::throws([&] { op.compute(s, {cond, a1, a2}, {&m, &m}, {}); }));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <iostream>
#include <vector>
#include <cmath>
#include <migraphx/literal.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/batch_norm_inference.hpp>
......@@ -48,7 +49,9 @@ TEST_CASE(acos_test)
auto result = p.eval({}).back();
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {2.4980915448f, 1.5707963268f, 0.0f};
std::vector<float> gold = data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return acosf(n); });
EXPECT(migraphx::verify_range(results_vector, gold));
}
......@@ -64,7 +67,9 @@ TEST_CASE(acosh_test)
auto result = p.eval({}).back();
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.4435683, 0.6223626, 1.316958};
std::vector<float> gold = data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return acoshf(n); });
EXPECT(migraphx::verify_range(results_vector, gold));
}
......@@ -295,7 +300,9 @@ TEST_CASE(asin_test)
auto result = p.eval({}).back();
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-0.5235987756f, 0.f, 1.119769515};
std::vector<float> gold = data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return asinf(n); });
EXPECT(migraphx::verify_range(results_vector, gold));
}
......@@ -311,7 +318,9 @@ TEST_CASE(asinh_test)
auto result = p.eval({}).back();
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-0.481211841, 0, 0.808866858};
std::vector<float> gold = data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return asinhf(n); });
EXPECT(migraphx::verify_range(results_vector, gold));
}
......@@ -320,13 +329,16 @@ TEST_CASE(atan_test)
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::double_type, {3}};
auto l = mm->add_literal(migraphx::literal{s, {-1, 0, 1}});
std::vector<float> data{-1.0f, 0.0f, 1.0f};
auto l = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(migraphx::make_op("atan"), l);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-0.7853981634f, 0.0f, 0.7853981634f};
std::vector<float> gold = data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return atanf(n); });
EXPECT(migraphx::verify_range(results_vector, gold));
}
......@@ -335,13 +347,16 @@ TEST_CASE(atanh_test)
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::double_type, {3}};
auto l = mm->add_literal(migraphx::literal{s, {0.4435683, 0.6223626, 0.316958}});
std::vector<float> data{0.4435683f, 0.6223626f, 0.316958f};
auto l = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(migraphx::make_op("atanh"), l);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.476664424, 0.728852153, 0.328261733};
std::vector<float> gold = data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return atanhf(n); });
EXPECT(migraphx::verify_range(results_vector, gold));
}
......@@ -675,14 +690,16 @@ TEST_CASE(ceil_test)
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {9}};
auto l =
mm->add_literal(migraphx::literal{s, {1.1, 1.5, 1.6, -1.1, -1.5, -1.6, 0.0, 2.0, -2.0}});
std::vector<float> data = {1.1, 1.5, 1.6, -1.1, -1.5, -1.6, 0.0, 2.0, -2.0};
auto l = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(migraphx::make_op("ceil"), l);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {2.0, 2.0, 2.0, -1.0, -1.0, -1.0, 0.0, 2.0, -2.0};
std::vector<float> gold = data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return std::ceil(n); });
EXPECT(migraphx::verify_range(results_vector, gold));
}
......@@ -1088,13 +1105,16 @@ TEST_CASE(cos_test)
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3}};
auto l = mm->add_literal(migraphx::literal{s, {-1, 0, 1}});
std::vector<float> data{-1, 0, 1};
auto l = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(migraphx::make_op("cos"), l);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.54030231f, 1.f, 0.54030231f};
std::vector<float> gold = data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return cosf(n); });
EXPECT(migraphx::verify_range(results_vector, gold));
}
......@@ -1103,13 +1123,16 @@ TEST_CASE(cosh_test)
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 2}};
auto l = mm->add_literal(migraphx::literal{s, {-1.0, 2.0, -3.0, 4.0}});
std::vector<float> data = {-1.0, 2.0, -3.0, 4.0};
auto l = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(migraphx::make_op("cosh"), l);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector(4);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{coshf(-1), coshf(2), coshf(-3), coshf(4)};
std::vector<float> gold = data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return coshf(n); });
EXPECT(migraphx::verify_range(results_vector, gold));
}
......@@ -1215,14 +1238,17 @@ TEST_CASE(div_test)
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3}};
auto l1 = mm->add_literal(migraphx::literal{s, {-1.0f, 0.5f, 1.0f}});
auto l2 = mm->add_literal(migraphx::literal{s, {1.0f, 2.0f, 4.0f}});
std::vector<float> data1 = {-1.0f, 0.5f, 1.0f};
std::vector<float> data2 = {1.0f, 2.0f, 4.0f};
auto l1 = mm->add_literal(migraphx::literal{s, data1});
auto l2 = mm->add_literal(migraphx::literal{s, data2});
mm->add_instruction(migraphx::make_op("div"), l1, l2);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-1.f, 0.25f, 0.25f};
std::vector<float> gold(data1.size());
std::transform(data1.begin(), data1.end(), data2.begin(), gold.begin(), std::divides<float>());
EXPECT(migraphx::verify_range(results_vector, gold));
}
......@@ -1297,14 +1323,16 @@ TEST_CASE(erf_test)
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {4}};
auto l =
mm->add_literal(migraphx::literal{s, {0.73785057, 1.58165966, -0.43597795, -0.01677432}});
std::vector<float> data = {0.73785057, 1.58165966, -0.43597795, -0.01677432};
auto l = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(migraphx::make_op("erf"), l);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.70327317, 0.97470088, -0.46247893, -0.01892602};
std::vector<float> gold = data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return erff(n); });
EXPECT(migraphx::verify_range(results_vector, gold));
}
......@@ -1312,14 +1340,17 @@ TEST_CASE(exp_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<float> data{-1, 0, 1};
migraphx::shape s{migraphx::shape::float_type, {3}};
auto l = mm->add_literal(migraphx::literal{s, {-1, 0, 1}});
auto l = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(migraphx::make_op("exp"), l);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.36787944f, 1.f, 2.71828183f};
std::vector<float> gold = data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return expf(n); });
EXPECT(migraphx::verify_range(results_vector, gold));
}
......@@ -1328,14 +1359,16 @@ TEST_CASE(floor_test)
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {9}};
auto l =
mm->add_literal(migraphx::literal{s, {1.1, 1.5, 0.6, -1.1, -1.5, -0.6, 0.0, 2.0, -2.0}});
std::vector<float> data = {1.1, 1.5, 0.6, -1.1, -1.5, -0.6, 0.0, 2.0, -2.0};
auto l = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(migraphx::make_op("floor"), l);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.0, 1.0, 0.0, -2.0, -2.0, -1.0, -0.0, 2.0, -2.0};
std::vector<float> gold = data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return floor(n); });
EXPECT(migraphx::verify_range(results_vector, gold));
}
......@@ -1668,7 +1701,8 @@ TEST_CASE(if_literal_test)
else_mod->add_return({l2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
mm->add_return({ret});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({r});
return p;
};
......@@ -1730,7 +1764,8 @@ TEST_CASE(if_param_test)
else_mod->add_return({a2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond, x, y}, {then_mod, else_mod});
mm->add_return({ret});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({r});
return p;
};
......@@ -1796,7 +1831,8 @@ TEST_CASE(if_pl_test)
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
auto outline = mm->add_outline(s);
mm->add_return({outline, ret});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({outline, r});
return p;
};
......@@ -2110,12 +2146,12 @@ TEST_CASE(less_test)
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {9}};
auto l0 =
mm->add_literal(migraphx::literal{s, {1.1, 1.5, 0.1, -1.1, -1.5, -0.6, 0.0, 2.0, -2.0}});
auto l1 =
mm->add_literal(migraphx::literal{s, {1.1, 1.6, -0.1, -1.2, -1.5, -0.7, 0.0, 2.3, -2.1}});
auto le = mm->add_instruction(migraphx::make_op("less"), l0, l1);
auto r = mm->add_instruction(
std::vector<float> data1 = {1.1, 1.5, 0.1, -1.1, -1.5, -0.6, 0.0, 2.0, -2.0};
std::vector<float> data2 = {1.1, 1.6, -0.1, -1.2, -1.5, -0.7, 0.0, 2.3, -2.1};
auto l0 = mm->add_literal(migraphx::literal{s, data1});
auto l1 = mm->add_literal(migraphx::literal{s, data2});
auto le = mm->add_instruction(migraphx::make_op("less"), l0, l1);
auto r = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}),
le);
......@@ -2125,7 +2161,11 @@ TEST_CASE(less_test)
auto result = p.eval({}).back();
std::vector<bool> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<bool> gold = {false, true, false, false, false, false, false, true, false};
std::vector<bool> gold(data1.size());
std::transform(
data1.begin(), data1.end(), data2.begin(), gold.begin(), [](float n1, float n2) -> bool {
return n1 < n2;
});
EXPECT(results_vector == gold);
}
......@@ -2134,13 +2174,16 @@ TEST_CASE(log_test)
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3}};
auto l = mm->add_literal(migraphx::literal{s, {1, 2, 3}});
std::vector<float> data = {1, 2, 3};
auto l = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(migraphx::make_op("log"), l);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.0f, 0.6931471806f, 1.0986122887f};
std::vector<float> gold = data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return logf(n); });
EXPECT(migraphx::verify_range(results_vector, gold));
}
......@@ -2149,14 +2192,20 @@ TEST_CASE(logical_and_test)
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::bool_type, {4}};
auto l1 = mm->add_literal(migraphx::literal{s, {1, 0, 1, 0}});
auto l2 = mm->add_literal(migraphx::literal{s, {1, 1, 0, 0}});
std::vector<bool> data1{true, false, true, false};
std::vector<bool> data2{true, true, false, false};
auto l1 = mm->add_literal(migraphx::literal{s, data1});
auto l2 = mm->add_literal(migraphx::literal{s, data2});
mm->add_instruction(migraphx::make_op("logical_and"), l1, l2);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<char> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<char> gold = {1, 0, 0, 0};
std::vector<bool> gold(data2.size());
std::transform(
data1.begin(), data1.end(), data2.begin(), gold.begin(), [](bool n1, bool n2) -> bool {
return n1 and n2;
});
EXPECT(migraphx::verify_range(results_vector, gold));
}
......@@ -2165,14 +2214,20 @@ TEST_CASE(logical_or_test)
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::bool_type, {4}};
auto l1 = mm->add_literal(migraphx::literal{s, {1, 0, 1, 0}});
auto l2 = mm->add_literal(migraphx::literal{s, {1, 1, 0, 0}});
std::vector<bool> data1{true, false, true, false};
std::vector<bool> data2{true, true, false, false};
auto l1 = mm->add_literal(migraphx::literal{s, data1});
auto l2 = mm->add_literal(migraphx::literal{s, data2});
mm->add_instruction(migraphx::make_op("logical_or"), l1, l2);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<char> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<char> gold = {1, 1, 1, 0};
std::vector<bool> gold(data1.size());
std::transform(
data1.begin(), data1.end(), data2.begin(), gold.begin(), [](bool n1, bool n2) -> bool {
return n1 or n2;
});
EXPECT(migraphx::verify_range(results_vector, gold));
}
......@@ -2181,14 +2236,20 @@ TEST_CASE(logical_xor_test)
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::bool_type, {4}};
auto l1 = mm->add_literal(migraphx::literal{s, {1, 0, 1, 0}});
auto l2 = mm->add_literal(migraphx::literal{s, {1, 1, 0, 0}});
std::vector<bool> data1{true, false, true, false};
std::vector<bool> data2{true, true, false, false};
auto l1 = mm->add_literal(migraphx::literal{s, data1});
auto l2 = mm->add_literal(migraphx::literal{s, data2});
mm->add_instruction(migraphx::make_op("logical_xor"), l1, l2);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<char> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<char> gold = {0, 1, 1, 0};
std::vector<bool> gold = {false, true, true, false};
std::transform(
data1.begin(), data1.end(), data2.begin(), gold.begin(), [](bool n1, bool n2) -> bool {
return n1 ^ n2;
});
EXPECT(migraphx::verify_range(results_vector, gold));
}
......@@ -2559,6 +2620,8 @@ TEST_CASE(mul_test)
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3}};
std::vector<float> data1{-1, 0, 1};
std::vector<float> data2{1, 2, 3};
auto l1 = mm->add_literal(migraphx::literal{s, {-1, 0, 1}});
auto l2 = mm->add_literal(migraphx::literal{s, {1, 2, 3}});
mm->add_instruction(migraphx::make_op("mul"), l1, l2);
......@@ -2566,7 +2629,11 @@ TEST_CASE(mul_test)
auto result = p.eval({}).back();
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-1, 0, 3};
std::vector<float> gold(data1.size());
std::transform(
data1.begin(), data1.end(), data2.begin(), gold.begin(), [](float n1, float n2) -> float {
return n1 * n2;
});
EXPECT(migraphx::verify_range(results_vector, gold));
}
......@@ -2583,8 +2650,8 @@ TEST_CASE(neg_test)
auto result = p.eval({}).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-1.0f, -1.3f, 1.2f, 0.0f, 100.f, -200.f};
std::vector<float> gold = data;
std::transform(gold.begin(), gold.end(), gold.begin(), std::negate<float>());
EXPECT(migraphx::verify_range(result_vector, gold));
}
......@@ -2595,13 +2662,14 @@ TEST_CASE(not_test)
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int32_type, {4}};
auto l1 = mm->add_literal(migraphx::literal{s, {0, 8, 1, -32}});
std::vector<float> data{0, 8, 1, -32};
auto l1 = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(migraphx::make_op("not"), l1);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<char> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<char> gold = {1, 0, 0, 0};
std::vector<char> gold{1, 0, 0, 0};
EXPECT(migraphx::verify_range(results_vector, gold));
}
......@@ -2610,13 +2678,15 @@ TEST_CASE(not_test)
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::bool_type, {4}};
std::vector<bool> data{false, false, true, true};
auto l1 = mm->add_literal(migraphx::literal{s, {0, 0, 1, 1}});
mm->add_instruction(migraphx::make_op("not"), l1);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<char> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<char> gold = {1, 1, 0, 0};
std::vector<bool> gold(data.size());
std::transform(data.begin(), data.end(), gold.begin(), [](bool n) -> bool { return !n; });
EXPECT(migraphx::verify_range(results_vector, gold));
}
}
......@@ -2716,14 +2786,17 @@ TEST_CASE(pow_test)
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3}};
auto b = mm->add_literal(migraphx::literal{s, {1, 2, 3}});
auto e = mm->add_literal(migraphx::literal{s, {1, 2, 3}});
std::vector<float> data = {1, 2, 3};
auto b = mm->add_literal(migraphx::literal{s, data});
auto e = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(migraphx::make_op("pow"), b, e);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.0f, 4.0f, 27.0f};
std::vector<float> gold = data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return std::pow(n, n); });
EXPECT(migraphx::verify_range(results_vector, gold));
}
......@@ -3560,6 +3633,66 @@ TEST_CASE(reshape_test)
}
}
TEST_CASE(reverse_test_axis0)
{
migraphx::shape in_shape{migraphx::shape::float_type, {2, 16}};
std::vector<float> data(32);
std::iota(data.begin(), data.end(), 1);
migraphx::program p;
auto* mm = p.get_main_module();
auto l = mm->add_literal(migraphx::literal{in_shape, data});
std::vector<int> axes = {0};
mm->add_instruction(migraphx::make_op("reverse", {{"axes", axes}}), l);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> target_data = data;
std::swap_ranges(target_data.begin(), target_data.begin() + 16, target_data.begin() + 16);
EXPECT(migraphx::verify_range(results_vector, target_data));
}
TEST_CASE(reverse_test_axis1)
{
migraphx::shape in_shape{migraphx::shape::float_type, {2, 16}};
std::vector<float> data(32);
std::iota(data.begin(), data.end(), 1);
migraphx::program p;
auto* mm = p.get_main_module();
auto l = mm->add_literal(migraphx::literal{in_shape, data});
std::vector<int> axes = {1};
mm->add_instruction(migraphx::make_op("reverse", {{"axes", axes}}), l);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> target_data = data;
std::reverse(target_data.begin(), target_data.begin() + 16);
std::reverse(target_data.end() - 16, target_data.end());
EXPECT(migraphx::verify_range(results_vector, target_data));
}
TEST_CASE(reverse_test_axis10)
{
migraphx::shape in_shape{migraphx::shape::float_type, {2, 16}};
std::vector<float> data(32);
std::iota(data.begin(), data.end(), 1);
migraphx::program p;
auto* mm = p.get_main_module();
auto l = mm->add_literal(migraphx::literal{in_shape, data});
std::vector<int> axes = {1, 0};
mm->add_instruction(migraphx::make_op("reverse", {{"axes", axes}}), l);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> target_data = data;
std::reverse(target_data.begin(), target_data.begin() + 16);
std::reverse(target_data.end() - 16, target_data.end());
std::swap_ranges(target_data.begin(), target_data.begin() + 16, target_data.begin() + 16);
EXPECT(migraphx::verify_range(results_vector, target_data));
}
TEST_CASE(round_test)
{
migraphx::program p;
......@@ -3627,13 +3760,16 @@ TEST_CASE(sin_test)
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3}};
auto l = mm->add_literal(migraphx::literal{s, {-1, 0, 1}});
std::vector<float> data = {-1, 0, 1};
auto l = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(migraphx::make_op("sin"), l);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-0.84147098f, 0.f, 0.84147098f};
std::vector<float> gold = data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return sinf(n); });
EXPECT(migraphx::verify_range(results_vector, gold));
}
......@@ -3642,13 +3778,16 @@ TEST_CASE(sinh_test)
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 2}};
auto l = mm->add_literal(migraphx::literal{s, {-1.0, 2.0, -3.0, 4.0}});
std::vector<float> data{-1.0, 2.0, -3.0, 4.0};
auto l = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(migraphx::make_op("sinh"), l);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector(4);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{sinhf(-1), sinhf(2), sinhf(-3), sinhf(4)};
std::vector<float> gold = data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return sinhf(n); });
EXPECT(migraphx::verify_range(results_vector, gold));
}
......@@ -3795,14 +3934,16 @@ TEST_CASE(sqrt_test)
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {5}};
auto l = mm->add_literal(
migraphx::literal{s, {1.02481645, 0.85643062, 0.03404123, 0.92791926, 0.10569184}});
std::vector<float> data{1.02481645, 0.85643062, 0.03404123, 0.92791926, 0.10569184};
auto l = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(migraphx::make_op("sqrt"), l);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.01233218, 0.92543537, 0.18450265, 0.96328566, 0.32510282};
std::vector<float> gold = data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return sqrtf(n); });
EXPECT(migraphx::verify_range(results_vector, gold));
}
......@@ -3904,13 +4045,16 @@ TEST_CASE(tan_test)
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3}};
auto l = mm->add_literal(migraphx::literal{s, {-1, 0, 1}});
std::vector<float> data{-1, 0, 1};
auto l = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(migraphx::make_op("tan"), l);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-1.55740772f, 0.0f, 1.55740772f};
std::vector<float> gold = data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return tanf(n); });
EXPECT(migraphx::verify_range(results_vector, gold));
}
......@@ -3919,13 +4063,16 @@ TEST_CASE(tanh_test)
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 2}};
auto l = mm->add_literal(migraphx::literal{s, {-1.0, 2.0, -3.0, 4.0}});
std::vector<float> data{-1.0, 2.0, -3.0, 4.0};
auto l = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(migraphx::make_op("tanh"), l);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector(4);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{tanhf(-1), tanhf(2), tanhf(-3), tanhf(4)};
std::vector<float> gold = data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return tanhf(n); });
EXPECT(migraphx::verify_range(results_vector, gold));
}
......
......@@ -301,7 +301,7 @@ migraphx::program create_conv()
migraphx::op::convolution op;
op.padding_mode = migraphx::op::padding_mode_t::same;
op.padding = {1, 1};
op.padding = {1, 1, 1, 1};
op.stride = {1, 1};
op.dilation = {1, 1};
auto l2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {3, 2, 0, 1}}}), l1);
......
......@@ -44,3 +44,23 @@ void auto_print::set_terminate_handler(const std::string& name)
get_handler(tname)();
});
}
static bool in_exception()
{
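// std::uncaught_exception() is deprecated in C++17 and removed in C++20, so use the plural
// form whenever it is available.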
#if __cplusplus >= 201703L
return std::uncaught_exceptions() > 0;
#else
return std::uncaught_exception();
#endif
}
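// If the object is destroyed while an exception is unwinding the stack, dump the pending
// output for every registered target before clearing this handler.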
auto_print::~auto_print()
{
if(in_exception())
{
std::cout << std::endl;
for(const auto& tname : migraphx::get_targets())
get_handler(tname)();
}
get_handler(name) = [] {};
}
......@@ -15,10 +15,7 @@ struct auto_print
get_handler(name) = [&x] { std::cout << x << std::endl; };
}
~auto_print()
{
get_handler(name) = [] {};
}
~auto_print();
};
#endif
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct batch_quant_dot_3 : verify_program<batch_quant_dot_3>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 2, 6}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 6, 7}};
auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape);
mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 3}}), l1, l2);
return p;
}
};
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct batch_quant_dot_4 : verify_program<batch_quant_dot_4>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 4, 6, 3}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 2, 6, 3}};
auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape);
auto tl1 =
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {3, 0, 1, 2}}}), l1);
auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {3, 1, 2, 0}}}), l2);
mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 3}}), tl1, tl2);
return p;
}
};
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct batch_quant_dot_5 : verify_program<batch_quant_dot_5>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 7, 2}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 5, 7}};
auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape);
auto tl1 =
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l1);
auto sl1 = mm->add_instruction(migraphx::make_op("add"), tl1, tl1);
auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l2);
auto sl2 = mm->add_instruction(migraphx::make_op("add"), tl2, tl2);
mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}}), sl1, sl2);
return p;
}
};
......@@ -45,5 +45,14 @@ int main(int argc, const char* argv[])
run_verify rv;
rv.add_validation_for("gpu", &validate_gpu);
rv.disable_test_for("cpu", {"test_if_lp", "test_if_param", "test_if_literal"});
rv.disable_test_for("gpu",
{"batch_quant_dot_2",
"batch_quant_dot_3",
"batch_quant_dot_5",
"quant_dot_3args_1",
"quant_dot_3args_2",
"quant_dot_3args_3",
"quant_dot_3args_4",
"quant_dot_3args_5"});
rv.run(argc, argv);
}
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct quant_dot_3args_5 : verify_program<quant_dot_3args_5>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {6, 2}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 6}};
auto l1 = mm->add_parameter("a", m1_shape);
auto tl1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l1);
auto l2 = mm->add_parameter("b", m2_shape);
auto tl2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l2);
mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 3}, {"beta", 2}}), tl1, tl2);
return p;
}
};
#include "run_verify.hpp"
#include "auto_print.hpp"
#include "verify_program.hpp"
#include "test.hpp"
#include <migraphx/env.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/ranges.hpp>
......@@ -121,7 +122,6 @@ void run_verify::verify(const std::string& name, const migraphx::program& p) con
{
using result_future =
std::future<std::pair<migraphx::program, std::vector<migraphx::argument>>>;
std::cout << "[ RUN ] " << name << std::endl;
auto_print::set_terminate_handler(name);
std::vector<std::pair<std::string, result_future>> results;
std::vector<std::string> target_names;
......@@ -180,25 +180,27 @@ void run_verify::verify(const std::string& name, const migraphx::program& p) con
std::cout << tname << ":\n" << cp << std::endl;
std::cout << std::endl;
}
EXPECT(passed);
}
}
std::set_terminate(nullptr);
std::cout << "[ COMPLETE ] " << name << std::endl;
}
void run_verify::run(int argc, const char* argv[]) const
{
std::set<std::string> args(argv + 1, argv + argc);
const auto& ps = get_programs();
for(auto&& p : ps)
std::unordered_map<std::string, std::vector<std::string>> labels;
for(auto&& p : get_programs())
{
if(not args.empty())
{
if(args.count(p.name) == 0 and args.count(p.section) == 0)
continue;
}
verify(p.name, p.get_program());
labels[p.section].push_back(p.name);
test::add_test_case(p.name, [=] { verify(p.name, p.get_program()); });
}
test::driver d{};
d.get_case_names = [&](const std::string& name) -> std::vector<std::string> {
if(labels.count(name) > 0)
return labels.at(name);
return {name};
};
d.run(argc, argv);
}
void run_verify::disable_parallel_for(const std::string& name) { info[name].parallel = false; }
......
......@@ -26,7 +26,8 @@ struct test_if_literal : verify_program<test_if_literal>
else_mod->add_return({l2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
mm->add_return({ret});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({r});
return p;
}
......
......@@ -27,7 +27,9 @@ struct test_if_lp : verify_program<test_if_lp>
else_mod->add_return({s2, l2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
mm->add_return({ret});
auto r0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
auto r1 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), ret);
mm->add_return({r0, r1});
return p;
}
......
......@@ -29,7 +29,8 @@ struct test_if_param : verify_program<test_if_param>
else_mod->add_return({a2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
mm->add_return({ret});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({r});
return p;
}
......