Commit 7e297b13 authored by Paul's avatar Paul
Browse files

Merge

parents 86ea5e91 aa7ff911
......@@ -14,10 +14,10 @@ struct test_acosh : verify_program<test_acosh>
auto x = mm->add_parameter("x", s);
auto min_val = mm->add_literal(1.1f);
auto max_val = mm->add_literal(100.0f);
min_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {16}}}),
min_val);
max_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {16}}}),
max_val);
min_val =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {16}}}), min_val);
max_val =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {16}}}), max_val);
auto cx = mm->add_instruction(migraphx::make_op("clip"), x, min_val, max_val);
mm->add_instruction(migraphx::make_op("acosh"), cx);
return p;
......
......@@ -16,7 +16,7 @@ struct test_add_broadcast : verify_program<test_add_broadcast>
auto x = mm->add_parameter("x", {migraphx::shape::float_type, {2, 2, 3}});
auto y = mm->add_parameter("y", {migraphx::shape::float_type, {2, 2}});
auto by = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 0}, {"dims", x->get_shape().lens()}}), y);
migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", x->get_shape().lens()}}), y);
mm->add_instruction(migraphx::make_op("add"), x, by);
return p;
}
......
......@@ -16,7 +16,7 @@ struct test_add_broadcast2 : verify_program<test_add_broadcast2>
auto x = mm->add_parameter("x", {migraphx::shape::float_type, {2, 3, 4}});
auto y = mm->add_parameter("y", {migraphx::shape::float_type, {3}});
auto by = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", x->get_shape().lens()}}), y);
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", x->get_shape().lens()}}), y);
mm->add_instruction(migraphx::make_op("add"), x, by);
return p;
}
......
......@@ -16,7 +16,7 @@ struct test_add_broadcast3 : verify_program<test_add_broadcast3>
auto x = mm->add_parameter("x", {migraphx::shape::float_type, {2, 4, 5}});
auto y = mm->add_parameter("y", {migraphx::shape::float_type, {4}});
auto by = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", x->get_shape().lens()}}), y);
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", x->get_shape().lens()}}), y);
mm->add_instruction(migraphx::make_op("add"), x, by);
return p;
}
......
......@@ -16,7 +16,7 @@ struct test_add_broadcast4 : verify_program<test_add_broadcast4>
auto x = mm->add_parameter("x", {migraphx::shape::float_type, {2, 3, 5}});
auto y = mm->add_parameter("y", {migraphx::shape::float_type, {3}});
auto by = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", x->get_shape().lens()}}), y);
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", x->get_shape().lens()}}), y);
mm->add_instruction(migraphx::make_op("add"), x, by);
return p;
}
......
......@@ -16,7 +16,7 @@ struct test_add_broadcast5 : verify_program<test_add_broadcast5>
auto x = mm->add_parameter("x", {migraphx::shape::float_type, {2, 4, 8}});
auto y = mm->add_parameter("y", {migraphx::shape::float_type, {4}});
auto by = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", x->get_shape().lens()}}), y);
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", x->get_shape().lens()}}), y);
mm->add_instruction(migraphx::make_op("add"), x, by);
return p;
}
......
......@@ -15,7 +15,7 @@ struct test_add_broadcast6 : verify_program<test_add_broadcast6>
auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1, 64, 568, 1328}});
auto y = mm->add_parameter("y", {migraphx::shape::float_type, {64}});
auto by = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 64, 568, 1328}}}), y);
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 64, 568, 1328}}}), y);
mm->add_instruction(migraphx::make_op("add"), x, by);
return p;
}
......
......@@ -18,14 +18,14 @@ struct test_add_gelu : verify_program<test_add_gelu>
auto sqrt2 = mm->add_literal(static_cast<float>(M_SQRT2));
auto add = mm->add_instruction(migraphx::make_op("add"), x, y);
auto half_mbcast = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", input_lens}}), half);
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), half);
auto mul_half = mm->add_instruction(migraphx::make_op("mul"), add, half_mbcast);
auto sqrt2_mbcast = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", input_lens}}), sqrt2);
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), sqrt2);
auto div = mm->add_instruction(migraphx::make_op("div"), add, sqrt2_mbcast);
auto erf = mm->add_instruction(migraphx::make_op("erf"), div);
auto one_mbcast = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", input_lens}}), one);
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), one);
auto add_one = mm->add_instruction(migraphx::make_op("add"), erf, one_mbcast);
mm->add_instruction(migraphx::make_op("mul"), mul_half, add_one);
return p;
......
......@@ -2,34 +2,92 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/argmax.hpp>
#include <migraphx/op/argmin.hpp>
template <class T, int Axis>
struct test_arg_ops : verify_program<test_arg_ops<T, Axis>>
template <class T, int Axis, int NonStdShape>
struct test_arg_ops : verify_program<test_arg_ops<T, Axis, NonStdShape>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 1025}};
migraphx::shape s{migraphx::shape::float_type, {2, 1, 4, 1025}};
auto param = mm->add_parameter("data", s);
switch(NonStdShape)
{
case 0:
param = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), param);
break;
case 1:
param = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 4, 1025}}}), param);
break;
case 2:
param = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1}}, {"ends", {3}}}), param);
break;
default: break;
}
mm->add_instruction(T{Axis}, param);
return p;
}
};
template struct test_arg_ops<migraphx::op::argmax, 0>;
template struct test_arg_ops<migraphx::op::argmax, 1>;
template struct test_arg_ops<migraphx::op::argmax, 2>;
template struct test_arg_ops<migraphx::op::argmax, 3>;
template struct test_arg_ops<migraphx::op::argmax, -1>;
template struct test_arg_ops<migraphx::op::argmax, -2>;
template struct test_arg_ops<migraphx::op::argmin, 0>;
template struct test_arg_ops<migraphx::op::argmin, 1>;
template struct test_arg_ops<migraphx::op::argmin, 2>;
template struct test_arg_ops<migraphx::op::argmin, 3>;
template struct test_arg_ops<migraphx::op::argmin, -3>;
template struct test_arg_ops<migraphx::op::argmin, -4>;
// transpose argmax tests
template struct test_arg_ops<migraphx::op::argmax, 0, 0>;
template struct test_arg_ops<migraphx::op::argmax, 1, 0>;
template struct test_arg_ops<migraphx::op::argmax, 2, 0>;
template struct test_arg_ops<migraphx::op::argmax, 3, 0>;
template struct test_arg_ops<migraphx::op::argmax, -1, 0>;
template struct test_arg_ops<migraphx::op::argmax, -2, 0>;
// transpose argmin tests
template struct test_arg_ops<migraphx::op::argmin, 0, 0>;
template struct test_arg_ops<migraphx::op::argmin, 1, 0>;
template struct test_arg_ops<migraphx::op::argmin, 2, 0>;
template struct test_arg_ops<migraphx::op::argmin, 3, 0>;
template struct test_arg_ops<migraphx::op::argmin, -3, 0>;
template struct test_arg_ops<migraphx::op::argmin, -4, 0>;
// broadcast argmax tests
template struct test_arg_ops<migraphx::op::argmax, 0, 1>;
template struct test_arg_ops<migraphx::op::argmax, 1, 1>;
template struct test_arg_ops<migraphx::op::argmax, 2, 1>;
template struct test_arg_ops<migraphx::op::argmax, 3, 1>;
template struct test_arg_ops<migraphx::op::argmax, -1, 1>;
template struct test_arg_ops<migraphx::op::argmax, -2, 1>;
// broadcast argmin tests
template struct test_arg_ops<migraphx::op::argmin, 0, 1>;
template struct test_arg_ops<migraphx::op::argmin, 1, 1>;
template struct test_arg_ops<migraphx::op::argmin, 2, 1>;
template struct test_arg_ops<migraphx::op::argmin, 3, 1>;
template struct test_arg_ops<migraphx::op::argmin, -3, 1>;
template struct test_arg_ops<migraphx::op::argmin, -4, 1>;
// slice argmax tests
template struct test_arg_ops<migraphx::op::argmax, 0, 2>;
template struct test_arg_ops<migraphx::op::argmax, 1, 2>;
template struct test_arg_ops<migraphx::op::argmax, 2, 2>;
template struct test_arg_ops<migraphx::op::argmax, 3, 2>;
template struct test_arg_ops<migraphx::op::argmax, -1, 2>;
template struct test_arg_ops<migraphx::op::argmax, -2, 2>;
// slice argmin tests
template struct test_arg_ops<migraphx::op::argmin, 0, 2>;
template struct test_arg_ops<migraphx::op::argmin, 1, 2>;
template struct test_arg_ops<migraphx::op::argmin, 2, 2>;
template struct test_arg_ops<migraphx::op::argmin, 3, 2>;
template struct test_arg_ops<migraphx::op::argmin, -3, 2>;
template struct test_arg_ops<migraphx::op::argmin, -4, 2>;
// default case, standard shape argmax tests
template struct test_arg_ops<migraphx::op::argmax, 0, 3>;
template struct test_arg_ops<migraphx::op::argmax, 1, 3>;
template struct test_arg_ops<migraphx::op::argmax, 2, 3>;
template struct test_arg_ops<migraphx::op::argmax, 3, 3>;
template struct test_arg_ops<migraphx::op::argmax, -1, 3>;
template struct test_arg_ops<migraphx::op::argmax, -2, 3>;
// default case, standard shape argmin tests
template struct test_arg_ops<migraphx::op::argmin, 0, 3>;
template struct test_arg_ops<migraphx::op::argmin, 1, 3>;
template struct test_arg_ops<migraphx::op::argmin, 2, 3>;
template struct test_arg_ops<migraphx::op::argmin, 3, 3>;
template struct test_arg_ops<migraphx::op::argmin, -3, 3>;
template struct test_arg_ops<migraphx::op::argmin, -4, 3>;
......@@ -14,10 +14,10 @@ struct test_atanh : verify_program<test_atanh>
auto x = mm->add_parameter("x", s);
auto min_val = mm->add_literal(-0.95f);
auto max_val = mm->add_literal(0.95f);
min_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {16}}}),
min_val);
max_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {16}}}),
max_val);
min_val =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {16}}}), min_val);
max_val =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {16}}}), max_val);
auto cx = mm->add_instruction(migraphx::make_op("clip"), x, min_val, max_val);
mm->add_instruction(migraphx::make_op("atanh"), cx);
return p;
......
......@@ -12,7 +12,7 @@ struct test_avg_pooling_1d : verify_program<test_avg_pooling_1d>
auto* mm = p.get_main_module();
auto input =
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3, 5}});
auto op = migraphx::op::pooling{"average", {0}, {1}, {3}};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average, {0}, {1}, {3}};
mm->add_instruction(op, input);
return p;
}
......
......@@ -12,7 +12,8 @@ struct test_avg_pooling_3d : verify_program<test_avg_pooling_3d>
auto* mm = p.get_main_module();
auto input =
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3, 5, 5, 5}});
auto op = migraphx::op::pooling{"average", {1, 1, 1}, {3, 3, 3}, {3, 3, 3}};
auto op = migraphx::op::pooling{
migraphx::op::pooling_mode::average, {1, 1, 1}, {3, 3, 3}, {3, 3, 3}};
mm->add_instruction(op, input);
return p;
}
......
......@@ -12,7 +12,8 @@ struct test_avg_pooling_3d_opt : verify_program<test_avg_pooling_3d_opt>
auto* mm = p.get_main_module();
auto input =
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 2, 3, 3, 3}});
auto op = migraphx::op::pooling{"average", {0, 0, 0}, {1, 1, 1}, {3, 3, 3}};
auto op = migraphx::op::pooling{
migraphx::op::pooling_mode::average, {0, 0, 0}, {1, 1, 1}, {3, 3, 3}};
mm->add_instruction(op, input);
return p;
}
......
......@@ -10,9 +10,11 @@ struct test_avg_pooling_ceil_3d : verify_program<test_avg_pooling_ceil_3d>
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input =
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3, 5, 5, 5}});
auto op = migraphx::op::pooling{"average", {1, 1, 1}, {3, 3, 3}, {3, 3, 3}, true};
auto op = migraphx::op::pooling{
migraphx::op::pooling_mode::average, {1, 1, 1}, {3, 3, 3}, {3, 3, 3}, true};
mm->add_instruction(op, input);
return p;
}
......
......@@ -13,10 +13,10 @@ struct test_clip : verify_program<test_clip>
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3}});
auto min_val = mm->add_literal(0.0f);
auto max_val = mm->add_literal(6.0f);
min_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {3}}}),
min_val);
max_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {3}}}),
max_val);
min_val =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3}}}), min_val);
max_val =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3}}}), max_val);
mm->add_instruction(migraphx::make_op("clip"), x, min_val, max_val);
return p;
}
......
......@@ -3,6 +3,7 @@
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/pooling.hpp>
struct test_concat_pooling : verify_program<test_concat_pooling>
{
......@@ -12,18 +13,19 @@ struct test_concat_pooling : verify_program<test_concat_pooling>
auto* mm = p.get_main_module();
auto input =
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 256, 8, 8}});
auto transpose =
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), input);
auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 3}}), transpose);
auto concat_t =
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 3, 1, 2}}}), concat);
auto transpose = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), input);
auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 3}}), transpose);
auto concat_t = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), concat);
auto pooling = mm->add_instruction(migraphx::make_op("pooling",
{{"mode", "average"},
{"padding", {0, 0}},
{"stride", {1, 1}},
{"lengths", {8, 8}}}),
concat_t);
auto pooling =
mm->add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::average},
{"padding", {0, 0}},
{"stride", {1, 1}},
{"lengths", {8, 8}}}),
concat_t);
mm->add_instruction(migraphx::make_op("relu"), pooling);
return p;
}
......
......@@ -16,8 +16,9 @@ struct test_concat_transpose : verify_program<test_concat_transpose>
migraphx::shape s2{migraphx::shape::int32_type, {2, 4}};
auto l0 = mm->add_parameter("x", s0);
auto lp1 = mm->add_parameter("y", s1);
auto l1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), lp1);
auto l2 = mm->add_parameter("z", s2);
auto l1 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), lp1);
auto l2 = mm->add_parameter("z", s2);
mm->add_instruction(migraphx::make_op("concat", {{"axis", axis}}), l0, l1, l2);
return p;
}
......
......@@ -17,7 +17,8 @@ struct test_concat_transpose2 : verify_program<test_concat_transpose2>
auto l0 = mm->add_parameter("x", s0);
auto l1 = mm->add_parameter("y", s1);
auto lp2 = mm->add_parameter("z", s2);
auto l2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), lp2);
auto l2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), lp2);
mm->add_instruction(migraphx::make_op("concat", {{"axis", axis}}), l0, l1, l2);
return p;
}
......
......@@ -16,9 +16,11 @@ struct test_concat_transpose3 : verify_program<test_concat_transpose3>
migraphx::shape s2{migraphx::shape::int32_type, {5, 2}};
auto l0 = mm->add_parameter("x", s0);
auto lp1 = mm->add_parameter("y", s1);
auto l1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), lp1);
auto l1 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), lp1);
auto lp2 = mm->add_parameter("z", s2);
auto l2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), lp2);
auto l2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), lp2);
mm->add_instruction(migraphx::make_op("concat", {{"axis", axis}}), l0, l1, l2);
return p;
}
......
......@@ -12,7 +12,6 @@ struct test_conv_bias_clipped_relu : verify_program<test_conv_bias_clipped_relu>
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<size_t> input_lens{4, 3, 3, 3};
auto input =
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto weights =
......@@ -22,15 +21,15 @@ struct test_conv_bias_clipped_relu : verify_program<test_conv_bias_clipped_relu>
auto bias = mm->add_literal(l0);
auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights);
auto bcast_add = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", conv->get_shape().lens()}}),
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", conv->get_shape().lens()}}),
bias);
auto bias_add = mm->add_instruction(migraphx::make_op("add"), conv, bcast_add);
auto min_val = mm->add_literal(0.0f);
auto max_val = mm->add_literal(6.0f);
min_val = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", input_lens}}), min_val);
migraphx::make_op("multibroadcast", {{"out_lens", conv->get_shape().lens()}}), min_val);
max_val = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", input_lens}}), max_val);
migraphx::make_op("multibroadcast", {{"out_lens", conv->get_shape().lens()}}), max_val);
mm->add_instruction(migraphx::make_op("clip"), bias_add, min_val, max_val);
return p;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment