Commit 4ea39116 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

manual merge

parents 20128cae d8011adf
...@@ -30,39 +30,71 @@ ...@@ -30,39 +30,71 @@
#include <test.hpp> #include <test.hpp>
TEST_CASE(round_test) TEST_CASE(nearbyint_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {9}}; migraphx::shape s{migraphx::shape::float_type, {4, 4}};
auto l = auto l = mm->add_literal(migraphx::literal{s,
mm->add_literal(migraphx::literal{s, {1.1, 1.5, 1.6, -1.1, -1.5, -1.6, 0.0, 2.0, -2.0}}); {-3.51,
mm->add_instruction(migraphx::make_op("round"), l); -3.5,
-3.49,
-2.51,
-2.50,
-2.49,
-1.6,
-1.5,
-0.51,
-0.5,
0.5,
0.6,
2.4,
2.5,
3.5,
4.5}});
mm->add_instruction(migraphx::make_op("nearbyint"), l);
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.0, 2.0, 2.0, -1.0, -2.0, -2.0, 0.0, 2.0, -2.0}; std::vector<float> gold = {
-4.0, -4.0, -3.0, -3.0, -2.0, -2.0, -2.0, -2.0, -1.0, 0.0, 0.0, 1.0, 2.0, 2.0, 4.0, 4.0};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(round_dyn_test) TEST_CASE(nearbyint_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{4, 10}; migraphx::shape::dynamic_dimension dd{4, 10};
migraphx::shape s{migraphx::shape::float_type, {dd}}; migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("round"), input); mm->add_instruction(migraphx::make_op("nearbyint"), input);
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
std::vector<float> input_data{1.1, 1.5, 1.6, -1.1, -1.5, -1.6, 0.0, 2.0, -2.0}; std::vector<float> input_data{-3.51,
-3.5,
-3.49,
-2.51,
-2.50,
-2.49,
-1.6,
-1.5,
-0.51,
-0.5,
0.5,
0.6,
2.4,
2.5,
3.5,
4.5};
migraphx::parameter_map params0; migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {9}}; migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {16}};
params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data()); params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data());
auto result = p.eval(params0).back(); auto result = p.eval(params0).back();
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.0, 2.0, 2.0, -1.0, -2.0, -2.0, 0.0, 2.0, -2.0}; std::vector<float> gold = {
-4.0, -4.0, -3.0, -3.0, -2.0, -2.0, -2.0, -2.0, -1.0, 0.0, 0.0, 1.0, 2.0, 2.0, 4.0, 4.0};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -55,7 +55,7 @@ TEST_CASE(quantizelinear_1) ...@@ -55,7 +55,7 @@ TEST_CASE(quantizelinear_1)
std::vector<float> results_vector(18); std::vector<float> results_vector(18);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{ std::vector<float> gold{
-128, 127, 65, -128, 1, 1, -1, 100, 92, -128, 127, 65, -128, 1, 1, -1, 100, 92}; -128, 127, 64, -128, 1, 1, -1, 100, 92, -128, 127, 64, -128, 1, 1, -1, 100, 92};
EXPECT(results_vector == gold); EXPECT(results_vector == gold);
} }
...@@ -80,6 +80,6 @@ TEST_CASE(quantizelinear_2) ...@@ -80,6 +80,6 @@ TEST_CASE(quantizelinear_2)
auto result = p1.eval({}).back(); auto result = p1.eval({}).back();
std::vector<float> results_vector(18); std::vector<float> results_vector(18);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0, 255, 65, 0, 2, 2, 0, 255, 255, 0, 255, 65, 0, 2, 2, 0, 255, 255}; std::vector<float> gold{0, 255, 64, 0, 2, 2, 0, 255, 255, 0, 255, 64, 0, 2, 2, 0, 255, 255};
EXPECT(results_vector == gold); EXPECT(results_vector == gold);
} }
...@@ -153,7 +153,7 @@ TEST_CASE(reshape_test2) ...@@ -153,7 +153,7 @@ TEST_CASE(reshape_test2)
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(reshape_dyn_test) TEST_CASE(reshape_dyn_1in_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
...@@ -173,3 +173,78 @@ TEST_CASE(reshape_dyn_test) ...@@ -173,3 +173,78 @@ TEST_CASE(reshape_dyn_test)
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(reshape_2in_test0)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s_in{migraphx::shape::float_type, {{1, 4}, {24, 24}, {1, 1}, {1, 1}}};
migraphx::shape s_out{migraphx::shape::float_type, {{1, 4}, {6, 6}, {4, 4}, {1, 1}}};
auto input = mm->add_parameter("X", s_in);
auto output_buffer = mm->add_parameter("Y", s_out);
mm->add_instruction(migraphx::make_op("reshape"), input, output_buffer);
p.compile(migraphx::make_target("ref"));
std::vector<float> gold(48);
std::iota(gold.begin(), gold.end(), -3.);
std::vector<float> buffer(48);
std::iota(buffer.begin(), buffer.end(), 0.);
migraphx::parameter_map params;
migraphx::shape input_fixed_shape{migraphx::shape::float_type, {2, 24, 1, 1}};
migraphx::shape output_fixed_shape{migraphx::shape::float_type, {2, 6, 4, 1}};
params["X"] = migraphx::argument(input_fixed_shape, gold.data());
params["Y"] = migraphx::argument(output_fixed_shape, buffer.data());
auto result = p.eval(params).back();
EXPECT(result.get_shape() == output_fixed_shape);
std::vector<float> results_vector{};
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(reshape_2in_test1)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s_in{migraphx::shape::float_type, {2, 24, 1, 1}};
migraphx::shape s_out{migraphx::shape::float_type, {{2, 4}, {6, 6}, {2, 4}, {1, 1}}};
auto input = mm->add_parameter("X", s_in);
auto output_buffer = mm->add_parameter("Y", s_out);
mm->add_instruction(migraphx::make_op("reshape"), input, output_buffer);
p.compile(migraphx::make_target("ref"));
std::vector<float> gold(48);
std::iota(gold.begin(), gold.end(), -3.);
std::vector<float> buffer(48);
std::iota(buffer.begin(), buffer.end(), 0.);
migraphx::parameter_map params;
migraphx::shape output_fixed_shape{migraphx::shape::float_type, {2, 6, 4, 1}};
params["X"] = migraphx::argument(s_in, gold.data());
params["Y"] = migraphx::argument(output_fixed_shape, buffer.data());
auto result = p.eval(params).back();
EXPECT(result.get_shape() == output_fixed_shape);
std::vector<float> results_vector{};
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(reshape_2in_elements_runtime_error)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s_in{migraphx::shape::float_type, {2, 24, 1, 1}};
migraphx::shape s_out{migraphx::shape::float_type, {{2, 4}, {6, 6}, {2, 4}, {1, 1}}};
auto input = mm->add_parameter("X", s_in);
auto output_buffer = mm->add_parameter("Y", s_out);
mm->add_instruction(migraphx::make_op("reshape"), input, output_buffer);
p.compile(migraphx::make_target("ref"));
std::vector<float> gold(48);
std::iota(gold.begin(), gold.end(), -3.);
std::vector<float> buffer(48);
std::iota(buffer.begin(), buffer.end(), 0.);
migraphx::parameter_map params;
// elements do not match up
migraphx::shape output_fixed_shape{migraphx::shape::float_type, {2, 6, 2, 1}};
params["X"] = migraphx::argument(s_in, gold.data());
params["Y"] = migraphx::argument(output_fixed_shape, buffer.data());
EXPECT(test::throws([&] { std::ignore = p.eval(params).back(); }));
}
...@@ -157,7 +157,169 @@ TEST_CASE(slice_var_inputs_static2) ...@@ -157,7 +157,169 @@ TEST_CASE(slice_var_inputs_static2)
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(slice_var_inputs_dyn) TEST_CASE(slice_var_inputs_dyn0)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s0{migraphx::shape::int32_type, {{2, 4, {2, 4}}, {2, 4, {2, 4}}, {3, 8}}};
auto input = mm->add_parameter("input", s0);
migraphx::shape s1{migraphx::shape::int32_type, {1}};
auto starts = mm->add_parameter("starts", s1);
mm->add_instruction(migraphx::make_op("slice", {{"axes", {2}}, {"ends", {10}}}), input, starts);
p.compile(migraphx::make_target("ref"));
migraphx::parameter_map params;
migraphx::shape s2{migraphx::shape::int32_type, {2, 2, 3}};
std::vector<int> input_data(2 * 2 * 3);
std::iota(input_data.begin(), input_data.end(), 0);
std::vector<int> start_data = {1};
params["input"] = migraphx::argument(s2, input_data.data());
params["starts"] = migraphx::argument(s1, start_data.data());
auto result = p.eval(params).back();
std::vector<int> gold = {1, 2, 4, 5, 7, 8, 10, 11};
std::vector<int> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(slice_var_inputs_dyn1)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s0{migraphx::shape::int32_type, {{2, 4, {2, 4}}, {2, 4, {2, 4}}, {3, 8}}};
auto input = mm->add_parameter("input", s0);
migraphx::shape s1{migraphx::shape::int32_type, {1}};
auto ends = mm->add_parameter("ends", s1);
mm->add_instruction(migraphx::make_op("slice", {{"axes", {2}}, {"starts", {-5}}}), input, ends);
p.compile(migraphx::make_target("ref"));
migraphx::parameter_map params;
migraphx::shape s2{migraphx::shape::int32_type, {2, 2, 3}};
std::vector<int> input_data(2 * 2 * 3);
std::iota(input_data.begin(), input_data.end(), 0);
std::vector<int> ends_data = {3};
params["input"] = migraphx::argument(s2, input_data.data());
params["ends"] = migraphx::argument(s1, ends_data.data());
auto result = p.eval(params).back();
std::vector<int> gold = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
std::vector<int> results_vector(2 * 2 * 3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(slice_var_inputs_dyn2)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s0{migraphx::shape::int32_type, {{2, 4, {2, 4}}, {2, 4, {2, 4}}, {3, 8}}};
auto input = mm->add_parameter("input", s0);
migraphx::shape s1{migraphx::shape::int32_type, {1}};
auto axes = mm->add_parameter("axes", s1);
mm->add_instruction(migraphx::make_op("slice", {{"starts", {1}}, {"ends", {-1}}}), input, axes);
p.compile(migraphx::make_target("ref"));
migraphx::parameter_map params;
migraphx::shape s2{migraphx::shape::int32_type, {2, 2, 3}};
std::vector<int> input_data(2 * 2 * 3);
std::iota(input_data.begin(), input_data.end(), 0);
std::vector<int> axes_data = {2};
params["input"] = migraphx::argument(s2, input_data.data());
params["axes"] = migraphx::argument(s1, axes_data.data());
auto result = p.eval(params).back();
std::vector<int> gold = {1, 4, 7, 10};
std::vector<int> results_vector(2 * 2 * 1);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(slice_var_inputs_dyn3)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s0{migraphx::shape::int32_type, {{2, 4, {2, 4}}, {2, 4, {2, 4}}, {3, 8}}};
auto input = mm->add_parameter("input", s0);
migraphx::shape s1{migraphx::shape::int32_type, {1}};
auto starts = mm->add_parameter("starts", s1);
auto ends = mm->add_parameter("ends", s1);
mm->add_instruction(migraphx::make_op("slice", {{"axes", {2}}}), input, starts, ends);
p.compile(migraphx::make_target("ref"));
migraphx::parameter_map params;
migraphx::shape s2{migraphx::shape::int32_type, {2, 2, 3}};
std::vector<int> input_data(2 * 2 * 3);
std::iota(input_data.begin(), input_data.end(), 0);
std::vector<int> starts_data = {1};
std::vector<int> ends_data = {std::numeric_limits<int>::max()};
params["input"] = migraphx::argument(s2, input_data.data());
params["starts"] = migraphx::argument(s1, starts_data.data());
params["ends"] = migraphx::argument(s1, ends_data.data());
auto result = p.eval(params).back();
std::vector<int> gold = {1, 2, 4, 5, 7, 8, 10, 11};
std::vector<int> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(slice_var_inputs_dyn4)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s0{migraphx::shape::int32_type, {{2, 4, {2, 4}}, {2, 4, {2, 4}}, {3, 8}}};
auto input = mm->add_parameter("input", s0);
migraphx::shape s1{migraphx::shape::int32_type, {1}};
auto starts = mm->add_parameter("starts", s1);
auto axes = mm->add_parameter("axes", s1);
mm->add_instruction(migraphx::make_op("slice", {{"ends", {std::numeric_limits<int>::max()}}}),
input,
starts,
axes);
p.compile(migraphx::make_target("ref"));
migraphx::parameter_map params;
migraphx::shape s2{migraphx::shape::int32_type, {2, 2, 3}};
std::vector<int> input_data(2 * 2 * 3);
std::iota(input_data.begin(), input_data.end(), 0);
std::vector<int> starts_data = {1};
std::vector<int> axes_data = {2};
params["input"] = migraphx::argument(s2, input_data.data());
params["starts"] = migraphx::argument(s1, starts_data.data());
params["axes"] = migraphx::argument(s1, axes_data.data());
auto result = p.eval(params).back();
std::vector<int> gold = {1, 2, 4, 5, 7, 8, 10, 11};
std::vector<int> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(slice_var_inputs_dyn5)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s0{migraphx::shape::int32_type, {{2, 4, {2, 4}}, {2, 4, {2, 4}}, {3, 8}}};
auto input = mm->add_parameter("input", s0);
migraphx::shape s1{migraphx::shape::int32_type, {1}};
auto ends = mm->add_parameter("ends", s1);
auto axes = mm->add_parameter("axes", s1);
mm->add_instruction(migraphx::make_op("slice", {{"starts", {-4}}}), input, ends, axes);
p.compile(migraphx::make_target("ref"));
migraphx::parameter_map params;
migraphx::shape s2{migraphx::shape::int32_type, {2, 2, 3}};
std::vector<int> input_data(2 * 2 * 3);
std::iota(input_data.begin(), input_data.end(), 0);
std::vector<int> ends_data = {2};
std::vector<int> axes_data = {2};
params["input"] = migraphx::argument(s2, input_data.data());
params["ends"] = migraphx::argument(s1, ends_data.data());
params["axes"] = migraphx::argument(s1, axes_data.data());
auto result = p.eval(params).back();
std::vector<int> gold = {0, 1, 3, 4, 6, 7, 9, 10};
std::vector<int> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(slice_var_inputs_dyn6)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
......
...@@ -54,7 +54,7 @@ struct allocate_no_out : migraphx::auto_register_op<allocate_no_out> ...@@ -54,7 +54,7 @@ struct allocate_no_out : migraphx::auto_register_op<allocate_no_out>
const migraphx::shape& output_shape, const migraphx::shape& output_shape,
const std::vector<migraphx::argument>&) const const std::vector<migraphx::argument>&) const
{ {
return {output_shape}; return migraphx::argument{output_shape};
} }
}; };
...@@ -78,7 +78,7 @@ struct allocate_with_out : migraphx::auto_register_op<allocate_with_out> ...@@ -78,7 +78,7 @@ struct allocate_with_out : migraphx::auto_register_op<allocate_with_out>
const migraphx::shape& output_shape, const migraphx::shape& output_shape,
const std::vector<migraphx::argument>&) const const std::vector<migraphx::argument>&) const
{ {
return {output_shape}; return migraphx::argument{output_shape};
} }
}; };
......
...@@ -31,10 +31,13 @@ ...@@ -31,10 +31,13 @@
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <test.hpp> #include <test.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/env.hpp>
#include <migraphx/serialize.hpp> #include <migraphx/serialize.hpp>
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK_WORKAROUNDS);
bool is_quantizelinear(migraphx::instruction& ins) { return ins.name() == "quantizelinear"; } bool is_quantizelinear(migraphx::instruction& ins) { return ins.name() == "quantizelinear"; }
bool is_dequantizelinear(migraphx::instruction& ins) { return ins.name() == "dequantizelinear"; } bool is_dequantizelinear(migraphx::instruction& ins) { return ins.name() == "dequantizelinear"; }
bool is_clip_scalar(migraphx::instruction& ins) bool is_clip_scalar(migraphx::instruction& ins)
...@@ -82,7 +85,11 @@ TEST_CASE(quantizelinear) ...@@ -82,7 +85,11 @@ TEST_CASE(quantizelinear)
EXPECT(any_of(*p1.get_main_module(), &is_quantizelinear)); EXPECT(any_of(*p1.get_main_module(), &is_quantizelinear));
EXPECT(none_of(*p2.get_main_module(), &is_quantizelinear)); EXPECT(none_of(*p2.get_main_module(), &is_quantizelinear));
// ensure clip literals created in quantized program are scalar // ensure clip literals created in quantized program are scalar
EXPECT(any_of(*p2.get_main_module(), &is_clip_scalar)); // unless CK workarounds are enabled
if(migraphx::enabled(MIGRAPHX_ENABLE_CK_WORKAROUNDS{}))
EXPECT(none_of(*p2.get_main_module(), &is_clip_scalar));
else
EXPECT(any_of(*p2.get_main_module(), &is_clip_scalar));
} }
TEST_CASE(dequantizelinear) TEST_CASE(dequantizelinear)
......
...@@ -237,4 +237,86 @@ TEST_CASE(const_slice_4input) ...@@ -237,4 +237,86 @@ TEST_CASE(const_slice_4input)
EXPECT(m0 == m1); EXPECT(m0 == m1);
} }
TEST_CASE(static_dimensions_of0)
{
// dead_code_elimination will get rid of atan
migraphx::module m0;
{
migraphx::shape s{migraphx::shape::float_type, {2, 4, 4}};
auto input = m0.add_parameter("data", s);
auto atan_ins = m0.add_instruction(migraphx::make_op("atan"), input);
auto dimensions_of_ins =
m0.add_instruction(migraphx::make_op("dimensions_of", {{"end", 3}}), atan_ins);
m0.add_return({dimensions_of_ins});
}
run_pass(m0);
migraphx::module m1;
{
migraphx::shape s{migraphx::shape::float_type, {2, 4, 4}};
m1.add_parameter("data", s);
migraphx::shape lit_shape{migraphx::shape::int64_type, {3}};
std::vector<int64_t> lit_data = {2, 4, 4};
auto lit_ins = m1.add_literal(migraphx::literal{lit_shape, lit_data});
m1.add_return({lit_ins});
}
EXPECT(m0 == m1);
}
TEST_CASE(static_dimensions_of1)
{
// dead_code_elimination will get rid of atan
migraphx::module m0;
{
migraphx::shape s{migraphx::shape::float_type, {{2, 4, {2, 4}}, {4, 4}, {4, 4}}};
auto input = m0.add_parameter("data", s);
auto atan_ins = m0.add_instruction(migraphx::make_op("atan"), input);
auto dimensions_of_ins = m0.add_instruction(
migraphx::make_op("dimensions_of", {{"start", 1}, {"end", 3}}), atan_ins);
m0.add_return({dimensions_of_ins});
}
run_pass(m0);
migraphx::module m1;
{
migraphx::shape s{migraphx::shape::float_type, {{2, 4, {2, 4}}, {4, 4}, {4, 4}}};
m1.add_parameter("data", s);
migraphx::shape lit_shape{migraphx::shape::int64_type, {2}};
std::vector<int64_t> lit_data = {4, 4};
auto lit_ins = m1.add_literal(migraphx::literal{lit_shape, lit_data});
m1.add_return({lit_ins});
}
EXPECT(m0 == m1);
}
// Does nothing because the dynamic_dimensions from start to end
// are not all fixed
TEST_CASE(static_dimensions_of_nonfixed)
{
// dead_code_elimination will get rid of atan
migraphx::module m0;
{
migraphx::shape s{migraphx::shape::float_type, {{2, 4, {2, 4}}, {4, 8}, {4, 8}}};
auto input = m0.add_parameter("data", s);
auto atan_ins = m0.add_instruction(migraphx::make_op("atan"), input);
auto dimensions_of_ins = m0.add_instruction(
migraphx::make_op("dimensions_of", {{"start", 1}, {"end", 3}}), atan_ins);
m0.add_return({dimensions_of_ins});
}
run_pass(m0);
migraphx::module m1;
{
migraphx::shape s{migraphx::shape::float_type, {{2, 4, {2, 4}}, {4, 8}, {4, 8}}};
auto input = m1.add_parameter("data", s);
auto atan_ins = m1.add_instruction(migraphx::make_op("atan"), input);
auto dimensions_of_ins = m1.add_instruction(
migraphx::make_op("dimensions_of", {{"start", 1}, {"end", 3}}), atan_ins);
m1.add_return({dimensions_of_ins});
}
EXPECT(m0 == m1);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -1345,7 +1345,7 @@ TEST_CASE(transpose_contiguous_unsqueeze_unary) ...@@ -1345,7 +1345,7 @@ TEST_CASE(transpose_contiguous_unsqueeze_unary)
auto cont_ins = m1.add_instruction(migraphx::make_op("contiguous"), transpose_ins); auto cont_ins = m1.add_instruction(migraphx::make_op("contiguous"), transpose_ins);
auto unsq_ins = auto unsq_ins =
m1.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), cont_ins); m1.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), cont_ins);
auto round = m1.add_instruction(migraphx::make_op("round"), unsq_ins); auto round = m1.add_instruction(migraphx::make_op("nearbyint"), unsq_ins);
m1.add_instruction(pass_op{}, round); m1.add_instruction(pass_op{}, round);
} }
run_pass(m1); run_pass(m1);
...@@ -1354,7 +1354,7 @@ TEST_CASE(transpose_contiguous_unsqueeze_unary) ...@@ -1354,7 +1354,7 @@ TEST_CASE(transpose_contiguous_unsqueeze_unary)
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 8, 5, 5}}); auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 8, 5, 5}});
auto transpose_ins = auto transpose_ins =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x); m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x);
auto round = m2.add_instruction(migraphx::make_op("round"), transpose_ins); auto round = m2.add_instruction(migraphx::make_op("nearbyint"), transpose_ins);
auto cont_ins = m2.add_instruction(migraphx::make_op("contiguous"), round); auto cont_ins = m2.add_instruction(migraphx::make_op("contiguous"), round);
auto unsq_ins = auto unsq_ins =
m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), cont_ins); m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), cont_ins);
......
...@@ -41,11 +41,7 @@ TEST_CASE(make_invalid_target) ...@@ -41,11 +41,7 @@ TEST_CASE(make_invalid_target)
TEST_CASE(targets) TEST_CASE(targets)
{ {
// GCC doesn't load libmigraphx_ref unless necesssary even though it is linked to the test.
// Force it to load by making ref target
#if defined(__GNUC__) && !defined(__clang__)
auto ref_target = migraphx::make_target("ref"); auto ref_target = migraphx::make_target("ref");
#endif
auto ts = migraphx::get_targets(); auto ts = migraphx::get_targets();
EXPECT(ts.size() >= 1); EXPECT(ts.size() >= 1);
} }
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct ck_gemm_softmax_gemm : verify_program<ck_gemm_softmax_gemm>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::half_type, {1, 12, 256, 256}};
migraphx::shape m2_shape{migraphx::shape::half_type, {1, 12, 256, 256}};
auto m2_elements = m2_shape.elements();
auto a = mm->add_parameter("1", m1_shape);
auto b = mm->add_parameter("2", m1_shape);
auto b1 = mm->add_parameter("3", m1_shape);
std::vector<float> eights(m2_elements, 0.125);
auto eight = mm->add_literal(migraphx::literal{m2_shape, eights});
std::vector<float> zeros(m2_elements, 0);
auto zero = mm->add_literal(migraphx::literal{m2_shape, zeros});
b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b);
auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b);
auto scale = mm->add_instruction(migraphx::make_op("mul"), gemm1, eight);
auto bias = mm->add_instruction(migraphx::make_op("add"), scale, zero);
auto softmax = mm->add_instruction(migraphx::make_op("softmax", {{"axis", -1}}), bias);
mm->add_instruction(migraphx::make_op("dot"), softmax, b1);
return p;
}
};
...@@ -27,16 +27,21 @@ ...@@ -27,16 +27,21 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_round : verify_program<test_round> struct gemm_2args_mm_8 : verify_program<gemm_2args_mm_8>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::float_type, {2, 128, 32}, {4096, 1, 128}};
migraphx::shape b_shape{migraphx::shape::float_type, {32, 32}};
auto a = mm->add_parameter("a", a_shape);
auto b = mm->add_parameter("b", b_shape);
auto bb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 32, 32}}}), b);
mm->add_instruction(migraphx::make_op("dot"), a, bb);
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 6}};
auto param = mm->add_parameter("x", s);
mm->add_instruction(migraphx::make_op("round"), param);
return p; return p;
}; }
}; };
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/op/quant_convolution.hpp> #include <migraphx/op/quant_convolution.hpp>
struct quant_conv_default_mode : verify_program<quant_conv_default_mode> struct quant_conv_1 : verify_program<quant_conv_1>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
......
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/op/quant_convolution.hpp> #include <migraphx/op/quant_convolution.hpp>
struct quant_conv_int8x4_default : verify_program<quant_conv_int8x4_default> struct quant_conv_2 : verify_program<quant_conv_2>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
......
...@@ -44,8 +44,7 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DUMP_TEST) ...@@ -44,8 +44,7 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DUMP_TEST)
// An improved async, that doesn't block // An improved async, that doesn't block
template <class Function> template <class Function>
std::future<typename std::result_of<Function()>::type> detach_async(Function&& f, std::future<std::invoke_result_t<Function>> detach_async(Function&& f, bool parallel = true)
bool parallel = true)
{ {
if(parallel) if(parallel)
{ {
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -29,8 +29,8 @@ ...@@ -29,8 +29,8 @@
#include <migraphx/op/argmax.hpp> #include <migraphx/op/argmax.hpp>
#include <migraphx/op/argmin.hpp> #include <migraphx/op/argmin.hpp>
template <class T, int Axis, int NonStdShape> template <class T, int Axis, bool LastIndex, int NonStdShape>
struct test_arg_ops : verify_program<test_arg_ops<T, Axis, NonStdShape>> struct test_arg_ops : verify_program<test_arg_ops<T, Axis, LastIndex, NonStdShape>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -54,63 +54,111 @@ struct test_arg_ops : verify_program<test_arg_ops<T, Axis, NonStdShape>> ...@@ -54,63 +54,111 @@ struct test_arg_ops : verify_program<test_arg_ops<T, Axis, NonStdShape>>
break; break;
default: break; default: break;
} }
mm->add_instruction(T{Axis}, param); mm->add_instruction(T{Axis, LastIndex}, param);
return p; return p;
} }
}; };
// transpose argmax tests // transpose argmax tests
template struct test_arg_ops<migraphx::op::argmax, 0, 0>; template struct test_arg_ops<migraphx::op::argmax, 0, true, 0>;
template struct test_arg_ops<migraphx::op::argmax, 1, 0>; template struct test_arg_ops<migraphx::op::argmax, 0, false, 0>;
template struct test_arg_ops<migraphx::op::argmax, 2, 0>; template struct test_arg_ops<migraphx::op::argmax, 1, true, 0>;
template struct test_arg_ops<migraphx::op::argmax, 3, 0>; template struct test_arg_ops<migraphx::op::argmax, 1, false, 0>;
template struct test_arg_ops<migraphx::op::argmax, -1, 0>; template struct test_arg_ops<migraphx::op::argmax, 2, true, 0>;
template struct test_arg_ops<migraphx::op::argmax, -2, 0>; template struct test_arg_ops<migraphx::op::argmax, 2, false, 0>;
template struct test_arg_ops<migraphx::op::argmax, 3, true, 0>;
template struct test_arg_ops<migraphx::op::argmax, 3, false, 0>;
template struct test_arg_ops<migraphx::op::argmax, -1, true, 0>;
template struct test_arg_ops<migraphx::op::argmax, -1, false, 0>;
template struct test_arg_ops<migraphx::op::argmax, -2, true, 0>;
template struct test_arg_ops<migraphx::op::argmax, -2, false, 0>;
// transpose argmin tests // transpose argmin tests
template struct test_arg_ops<migraphx::op::argmin, 0, 0>; template struct test_arg_ops<migraphx::op::argmin, 0, true, 0>;
template struct test_arg_ops<migraphx::op::argmin, 1, 0>; template struct test_arg_ops<migraphx::op::argmin, 0, false, 0>;
template struct test_arg_ops<migraphx::op::argmin, 2, 0>; template struct test_arg_ops<migraphx::op::argmin, 1, true, 0>;
template struct test_arg_ops<migraphx::op::argmin, 3, 0>; template struct test_arg_ops<migraphx::op::argmin, 1, false, 0>;
template struct test_arg_ops<migraphx::op::argmin, -3, 0>; template struct test_arg_ops<migraphx::op::argmin, 2, true, 0>;
template struct test_arg_ops<migraphx::op::argmin, -4, 0>; template struct test_arg_ops<migraphx::op::argmin, 2, false, 0>;
template struct test_arg_ops<migraphx::op::argmin, 3, true, 0>;
template struct test_arg_ops<migraphx::op::argmin, 3, false, 0>;
template struct test_arg_ops<migraphx::op::argmin, -3, true, 0>;
template struct test_arg_ops<migraphx::op::argmin, -3, false, 0>;
template struct test_arg_ops<migraphx::op::argmin, -4, true, 0>;
template struct test_arg_ops<migraphx::op::argmin, -4, false, 0>;
// broadcast argmax tests // broadcast argmax tests
template struct test_arg_ops<migraphx::op::argmax, 0, 1>; template struct test_arg_ops<migraphx::op::argmax, 0, true, 1>;
template struct test_arg_ops<migraphx::op::argmax, 1, 1>; template struct test_arg_ops<migraphx::op::argmax, 0, false, 1>;
template struct test_arg_ops<migraphx::op::argmax, 2, 1>; template struct test_arg_ops<migraphx::op::argmax, 1, true, 1>;
template struct test_arg_ops<migraphx::op::argmax, 3, 1>; template struct test_arg_ops<migraphx::op::argmax, 1, false, 1>;
template struct test_arg_ops<migraphx::op::argmax, -1, 1>; template struct test_arg_ops<migraphx::op::argmax, 2, true, 1>;
template struct test_arg_ops<migraphx::op::argmax, -2, 1>; template struct test_arg_ops<migraphx::op::argmax, 2, false, 1>;
template struct test_arg_ops<migraphx::op::argmax, 3, true, 1>;
template struct test_arg_ops<migraphx::op::argmax, 3, false, 1>;
template struct test_arg_ops<migraphx::op::argmax, -1, true, 1>;
template struct test_arg_ops<migraphx::op::argmax, -1, false, 1>;
template struct test_arg_ops<migraphx::op::argmax, -2, true, 1>;
template struct test_arg_ops<migraphx::op::argmax, -2, false, 1>;
// broadcast argmin tests // broadcast argmin tests
template struct test_arg_ops<migraphx::op::argmin, 0, 1>; template struct test_arg_ops<migraphx::op::argmin, 0, true, 1>;
template struct test_arg_ops<migraphx::op::argmin, 1, 1>; template struct test_arg_ops<migraphx::op::argmin, 0, false, 1>;
template struct test_arg_ops<migraphx::op::argmin, 2, 1>; template struct test_arg_ops<migraphx::op::argmin, 1, true, 1>;
template struct test_arg_ops<migraphx::op::argmin, 3, 1>; template struct test_arg_ops<migraphx::op::argmin, 1, false, 1>;
template struct test_arg_ops<migraphx::op::argmin, -3, 1>; template struct test_arg_ops<migraphx::op::argmin, 2, true, 1>;
template struct test_arg_ops<migraphx::op::argmin, -4, 1>; template struct test_arg_ops<migraphx::op::argmin, 2, false, 1>;
template struct test_arg_ops<migraphx::op::argmin, 3, true, 1>;
template struct test_arg_ops<migraphx::op::argmin, 3, false, 1>;
template struct test_arg_ops<migraphx::op::argmin, -3, true, 1>;
template struct test_arg_ops<migraphx::op::argmin, -3, false, 1>;
template struct test_arg_ops<migraphx::op::argmin, -4, true, 1>;
template struct test_arg_ops<migraphx::op::argmin, -4, false, 1>;
// slice argmax tests // slice argmax tests
template struct test_arg_ops<migraphx::op::argmax, 0, 2>; template struct test_arg_ops<migraphx::op::argmax, 0, true, 2>;
template struct test_arg_ops<migraphx::op::argmax, 1, 2>; template struct test_arg_ops<migraphx::op::argmax, 0, false, 2>;
template struct test_arg_ops<migraphx::op::argmax, 2, 2>; template struct test_arg_ops<migraphx::op::argmax, 1, true, 2>;
template struct test_arg_ops<migraphx::op::argmax, 3, 2>; template struct test_arg_ops<migraphx::op::argmax, 1, false, 2>;
template struct test_arg_ops<migraphx::op::argmax, -1, 2>; template struct test_arg_ops<migraphx::op::argmax, 2, true, 2>;
template struct test_arg_ops<migraphx::op::argmax, -2, 2>; template struct test_arg_ops<migraphx::op::argmax, 2, false, 2>;
template struct test_arg_ops<migraphx::op::argmax, 3, true, 2>;
template struct test_arg_ops<migraphx::op::argmax, 3, false, 2>;
template struct test_arg_ops<migraphx::op::argmax, -1, true, 2>;
template struct test_arg_ops<migraphx::op::argmax, -1, false, 2>;
template struct test_arg_ops<migraphx::op::argmax, -2, true, 2>;
template struct test_arg_ops<migraphx::op::argmax, -2, false, 2>;
// slice argmin tests // slice argmin tests
template struct test_arg_ops<migraphx::op::argmin, 0, 2>; template struct test_arg_ops<migraphx::op::argmin, 0, true, 2>;
template struct test_arg_ops<migraphx::op::argmin, 1, 2>; template struct test_arg_ops<migraphx::op::argmin, 0, false, 2>;
template struct test_arg_ops<migraphx::op::argmin, 2, 2>; template struct test_arg_ops<migraphx::op::argmin, 1, true, 2>;
template struct test_arg_ops<migraphx::op::argmin, 3, 2>; template struct test_arg_ops<migraphx::op::argmin, 1, false, 2>;
template struct test_arg_ops<migraphx::op::argmin, -3, 2>; template struct test_arg_ops<migraphx::op::argmin, 2, true, 2>;
template struct test_arg_ops<migraphx::op::argmin, -4, 2>; template struct test_arg_ops<migraphx::op::argmin, 2, false, 2>;
template struct test_arg_ops<migraphx::op::argmin, 3, true, 2>;
template struct test_arg_ops<migraphx::op::argmin, 3, false, 2>;
template struct test_arg_ops<migraphx::op::argmin, -3, true, 2>;
template struct test_arg_ops<migraphx::op::argmin, -3, false, 2>;
template struct test_arg_ops<migraphx::op::argmin, -4, true, 2>;
template struct test_arg_ops<migraphx::op::argmin, -4, false, 2>;
// default case, standard shape argmax tests // default case, standard shape argmax tests
template struct test_arg_ops<migraphx::op::argmax, 0, 3>; template struct test_arg_ops<migraphx::op::argmax, 0, true, 3>;
template struct test_arg_ops<migraphx::op::argmax, 1, 3>; template struct test_arg_ops<migraphx::op::argmax, 0, false, 3>;
template struct test_arg_ops<migraphx::op::argmax, 2, 3>; template struct test_arg_ops<migraphx::op::argmax, 1, true, 3>;
template struct test_arg_ops<migraphx::op::argmax, 3, 3>; template struct test_arg_ops<migraphx::op::argmax, 1, false, 3>;
template struct test_arg_ops<migraphx::op::argmax, -1, 3>; template struct test_arg_ops<migraphx::op::argmax, 2, true, 3>;
template struct test_arg_ops<migraphx::op::argmax, -2, 3>; template struct test_arg_ops<migraphx::op::argmax, 2, false, 3>;
template struct test_arg_ops<migraphx::op::argmax, 3, true, 3>;
template struct test_arg_ops<migraphx::op::argmax, 3, false, 3>;
template struct test_arg_ops<migraphx::op::argmax, -1, true, 3>;
template struct test_arg_ops<migraphx::op::argmax, -1, false, 3>;
template struct test_arg_ops<migraphx::op::argmax, -2, true, 3>;
template struct test_arg_ops<migraphx::op::argmax, -2, false, 3>;
// default case, standard shape argmin tests // default case, standard shape argmin tests
template struct test_arg_ops<migraphx::op::argmin, 0, 3>; template struct test_arg_ops<migraphx::op::argmin, 0, true, 3>;
template struct test_arg_ops<migraphx::op::argmin, 1, 3>; template struct test_arg_ops<migraphx::op::argmin, 0, false, 3>;
template struct test_arg_ops<migraphx::op::argmin, 2, 3>; template struct test_arg_ops<migraphx::op::argmin, 1, true, 3>;
template struct test_arg_ops<migraphx::op::argmin, 3, 3>; template struct test_arg_ops<migraphx::op::argmin, 1, false, 3>;
template struct test_arg_ops<migraphx::op::argmin, -3, 3>; template struct test_arg_ops<migraphx::op::argmin, 2, true, 3>;
template struct test_arg_ops<migraphx::op::argmin, -4, 3>; template struct test_arg_ops<migraphx::op::argmin, 2, false, 3>;
template struct test_arg_ops<migraphx::op::argmin, 3, true, 3>;
template struct test_arg_ops<migraphx::op::argmin, 3, false, 3>;
template struct test_arg_ops<migraphx::op::argmin, -3, true, 3>;
template struct test_arg_ops<migraphx::op::argmin, -3, false, 3>;
template struct test_arg_ops<migraphx::op::argmin, -4, true, 3>;
template struct test_arg_ops<migraphx::op::argmin, -4, false, 3>;
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_flatten_dot_relu : verify_program<test_flatten_dot_relu>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
auto a =
mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 3, 5}});
a = mm->add_instruction(migraphx::make_op("flatten", {{"axis", 3}}), a);
auto b =
mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {1, 5, 3, 3, 1}});
b = mm->add_instruction(migraphx::make_op("flatten", {{"axis", 3}}), b);
auto dot = mm->add_instruction(migraphx::make_op("dot"), a, b);
mm->add_instruction(migraphx::make_op("relu"), dot);
return p;
}
};
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <limits>
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
template <class T>
struct test_isinf : verify_program<test_isinf<T>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
auto max = std::numeric_limits<T>::max();
auto min = std::numeric_limits<T>::min();
auto inf = std::numeric_limits<T>::infinity();
auto nan = std::numeric_limits<T>::quiet_NaN();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::get_type<T>(), {5}});
std::vector<T> data0{inf, -inf, max, min, nan};
migraphx::shape s1{migraphx::shape::get_type<T>(), {5}};
auto l0 = mm->add_literal(migraphx::literal{s1, data0});
x = mm->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), x, l0);
mm->add_instruction(migraphx::make_op("isinf"), x);
return p;
}
};
template struct test_isinf<migraphx::half>;
template struct test_isinf<float>;
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <limits>
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_isinf_broadcast : verify_program<test_isinf_broadcast>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2}});
auto s0 = migraphx::shape{migraphx::shape::float_type, {2, 2}};
x = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", s0.lens()}}), x);
auto inf = std::numeric_limits<float>::infinity();
std::vector<float> data0{-inf, inf};
migraphx::shape s1{migraphx::shape::float_type, {1, 2}};
auto l0 = mm->add_literal(migraphx::literal{s1, data0});
x = mm->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), x, l0);
mm->add_instruction(migraphx::make_op("isinf"), x);
return p;
}
};
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
template <class T>
struct test_nearbyint : verify_program<test_nearbyint<T>>
{
migraphx::program create_program() const
{
migraphx::program p;
std::vector<float> tmp{-4.5, -3.5, 0.5, 2.5, 3.5};
std::vector<T> data{tmp.cbegin(), tmp.cend()};
migraphx::shape s1{migraphx::shape::get_type<T>(), {5}};
auto* mm = p.get_main_module();
auto l0 = mm->add_literal(migraphx::literal{s1, data});
mm->add_instruction(migraphx::make_op("isinf"), l0);
return p;
};
};
template struct test_nearbyint<migraphx::half>;
template struct test_nearbyint<float>;
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_scatter_nonstandard_shape : verify_program<test_scatter_nonstandard_shape>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sd{migraphx::shape::float_type, {3, 1, 3}, {1, 3, 2}};
migraphx::shape si{migraphx::shape::int32_type, {2, 1, 3}, {1, 3, 2}};
std::vector<int> vi = {1, 0, 2, 0, 2, 1};
migraphx::shape su{migraphx::shape::float_type, {2, 1, 3}, {1, 2, 3}};
auto pd = mm->add_parameter("data", sd);
auto li = mm->add_literal(migraphx::literal{si, vi});
auto pu = mm->add_parameter("update", su);
auto r = mm->add_instruction(migraphx::make_op("scatter_none", {{"axis", -1}}), pd, li, pu);
mm->add_return({r});
return p;
}
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment