"docs/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "1bd3637f4532c4e2b44c00dc415a5b465f5b54d3"
Commit d2fa123c authored by charlie's avatar charlie
Browse files

Add ref ops test

parent 5120f911
...@@ -1948,6 +1948,35 @@ TEST_CASE(equal_test) ...@@ -1948,6 +1948,35 @@ TEST_CASE(equal_test)
EXPECT(results_vector == gold); EXPECT(results_vector == gold);
} }
TEST_CASE(equal_dynamic_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{6, 12, 9}};
migraphx::shape s{migraphx::shape::float_type, dd};
auto p0 = mm->add_parameter("l", s);
auto p1 = mm->add_parameter("r", s);
auto eq = mm->add_instruction(migraphx::make_op("equal"), p0, p1);
auto r = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}),
eq);
mm->add_return({r});
p.compile(migraphx::ref::target{});
std::vector<float> l_data{1.1, 1.5, 0.1, -1.1, -1.5, -0.6, 0.0, 2.0, -2.0};
std::vector<float> r_data{1.1, 1.6, -0.1, -1.2, -1.5, -0.7, 0.0, 2.3, -2.1};
migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {9}};
params0["x"] = migraphx::argument(input_fixed_shape0, l_data.data());
params0["slope"] = migraphx::argument(input_fixed_shape0, r_data.data());
auto result = p.eval(params0).back();
std::vector<bool> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<bool> gold = {true, false, false, false, true, false, true, false, false};
EXPECT(results_vector == gold);
}
TEST_CASE(erf_test) TEST_CASE(erf_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -4768,15 +4797,15 @@ TEST_CASE(prelu_dynamic_test) ...@@ -4768,15 +4797,15 @@ TEST_CASE(prelu_dynamic_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{3, 3, 0}}; std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6, 0}};
migraphx::shape s{migraphx::shape::float_type, dd}; migraphx::shape s{migraphx::shape::float_type, dd};
std::vector<float> x_data{-1, 0, 2};
std::vector<float> slope_data{2, 1, 2};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto slope = mm->add_parameter("slope", s); auto slope = mm->add_parameter("slope", s);
mm->add_instruction(migraphx::make_op("prelu"), x, slope); mm->add_instruction(migraphx::make_op("prelu"), x, slope);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
std::vector<float> x_data{-1, 0, 2};
std::vector<float> slope_data{2, 1, 2};
migraphx::parameter_map params0; migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {3}}; migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {3}};
params0["x"] = migraphx::argument(input_fixed_shape0, x_data.data()); params0["x"] = migraphx::argument(input_fixed_shape0, x_data.data());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment