Commit baac1dab authored by Alan Turner's avatar Alan Turner
Browse files

Merge remote-tracking branch 'origin/develop' into ck-host-lib

parents 830dff7a 77042e30
......@@ -23,7 +23,7 @@
*/
#include <migraphx/rewrite_quantization.hpp>
#include <migraphx/program.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/reshape.hpp>
#include <migraphx/instruction.hpp>
......@@ -33,12 +33,20 @@
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/pass_manager.hpp>
bool is_quantizelinear(migraphx::instruction& ins) { return ins.name() == "quantizelinear"; }
bool is_dequantizelinear(migraphx::instruction& ins) { return ins.name() == "dequantizelinear"; }
void run_pass(migraphx::module& m) { migraphx::run_passes(m, {migraphx::rewrite_quantization{}}); }
migraphx::argument eval(const migraphx::program& p)
{
auto r = p.eval({});
EXPECT(r.size() == 1);
return r.front();
}
TEST_CASE(quantizelinear)
{
......@@ -58,8 +66,8 @@ TEST_CASE(quantizelinear)
migraphx::program p1 = create_program();
migraphx::program p2 = create_program();
migraphx::rewrite_quantization opt;
opt.apply(*p2.get_main_module());
run_pass(*p2.get_main_module());
EXPECT(eval(p1) == eval(p2));
EXPECT(any_of(*p1.get_main_module(), &is_quantizelinear));
EXPECT(none_of(*p2.get_main_module(), &is_quantizelinear));
}
......@@ -71,8 +79,8 @@ TEST_CASE(dequantizelinear)
std::vector<float> xv = {0, 1, 2, 5, 10, 50, 100, 150, 250};
migraphx::shape ss{migraphx::shape::float_type, {1, 3, 3}};
std::vector<float> sv = {2, 2, 2, 2, 2, 2, 2, 2, 2};
migraphx::shape zs{migraphx::shape::uint8_type, {1, 3, 3}};
std::vector<uint8_t> zv = {0, 0, 0, 0, 0, 0, 0, 0, 0};
migraphx::shape zs{migraphx::shape::float_type, {1, 3, 3}};
std::vector<float> zv = {0, 0, 0, 0, 0, 0, 0, 0, 0};
auto create_program = [&]() {
migraphx::program p;
auto* mm = p.get_main_module();
......@@ -86,8 +94,8 @@ TEST_CASE(dequantizelinear)
migraphx::program p1 = create_program();
migraphx::program p2 = create_program();
migraphx::rewrite_quantization opt;
opt.apply(*p2.get_main_module());
run_pass(*p2.get_main_module());
EXPECT(eval(p1) == eval(p2));
EXPECT(any_of(*p1.get_main_module(), &is_dequantizelinear));
EXPECT(none_of(*p2.get_main_module(), &is_dequantizelinear));
}
......
......@@ -27,7 +27,7 @@
#include <migraphx/literal.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/make_op.hpp>
......@@ -207,7 +207,7 @@ static auto run_prog(migraphx::program p, int64_t iter_num, bool cond, int64_t i
migraphx::shape s{migraphx::shape::int64_type, {1}};
migraphx::shape sc{migraphx::shape::bool_type};
p.compile(migraphx::ref::target{});
p.compile(migraphx::make_target("ref"));
migraphx::parameter_map pp;
pp["iter_num"] = migraphx::argument(si, &iter_num);
pp["ccond"] = migraphx::argument(sc, &cond);
......
......@@ -22,7 +22,7 @@
* THE SOFTWARE.
*/
#include <migraphx/program.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/load_save.hpp>
#include "test.hpp"
#include <migraphx/make_op.hpp>
......@@ -82,7 +82,7 @@ TEST_CASE(as_file)
TEST_CASE(compiled)
{
migraphx::program p1 = create_program();
p1.compile(migraphx::ref::target{});
p1.compile(migraphx::make_target("ref"));
std::vector<char> buffer = migraphx::save_buffer(p1);
migraphx::program p2 = migraphx::load_buffer(buffer);
EXPECT(p1.sort() == p2.sort());
......
......@@ -61,6 +61,8 @@ struct reflectable_type
}
};
std::vector<nested_type> nested_types = {};
std::tuple<int, nested_type, std::string> tuple_items = std::make_tuple(0, nested_type{0}, "");
migraphx::optional<int> opt_value = migraphx::nullopt;
template <class Self, class F>
static auto reflect(Self& self, F f)
......@@ -71,7 +73,8 @@ struct reflectable_type
f(self.et, "et"),
f(self.se, "se"),
f(self.ce, "ce"),
f(self.nested_types, "nested_types"));
f(self.nested_types, "nested_types"),
f(self.tuple_items, "tuple_items"));
}
};
......@@ -83,7 +86,9 @@ TEST_CASE(serialize_reflectable_type)
{},
reflectable_type::simple1,
reflectable_type::class_enum::class2,
{{1}, {2}}};
{{1}, {2}},
{5, {4}, "hello"},
{migraphx::nullopt}};
migraphx::value v1 = migraphx::to_value(t1);
reflectable_type t2 = migraphx::from_value<reflectable_type>(v1);
migraphx::value v2 = migraphx::to_value(t2);
......@@ -125,6 +130,21 @@ TEST_CASE(serialize_empty_struct)
EXPECT(v.at("a").to<int>() == 1);
}
TEST_CASE(serialize_empty_optional)
{
migraphx::optional<int> x{};
migraphx::value v = migraphx::to_value(x);
EXPECT(v.is_null());
}
TEST_CASE(serialize_optional)
{
migraphx::optional<int> x{2};
migraphx::value v = migraphx::to_value(x);
EXPECT(v.is_int64());
EXPECT(v.to<int>() == 2);
}
TEST_CASE(from_value_binary)
{
std::vector<std::uint8_t> data(10);
......
......@@ -30,6 +30,7 @@
#include <array>
#include <algorithm>
#include <numeric>
#include <migraphx/verify.hpp>
#include "test.hpp"
TEST_CASE(test_shape_default)
......@@ -41,22 +42,13 @@ TEST_CASE(test_shape_default)
TEST_CASE(test_dyn_4arg_constructor)
{
migraphx::shape s{migraphx::shape::float_type,
{
1,
4,
4,
},
{
4,
4,
4,
},
{0, 0, 0}};
std::vector<migraphx::shape::dynamic_dimension> expected_dyn_dims = {
{1, 4, 0}, {4, 4, 0}, {4, 4, 0}};
EXPECT(s.dynamic());
EXPECT(s.dyn_dims() == expected_dyn_dims);
migraphx::shape s0{migraphx::shape::float_type, {1, 4, 4}, {4, 4, 4}, {{}, {}, {}}};
migraphx::shape s1{migraphx::shape::float_type, {1, 4, 4}, {4, 4, 4}, {}};
std::vector<migraphx::shape::dynamic_dimension> expected_dyn_dims = {{1, 4}, {4, 4}, {4, 4}};
EXPECT(s0.dynamic());
EXPECT(s0.dyn_dims() == expected_dyn_dims);
EXPECT(s1.dynamic());
EXPECT(s1.dyn_dims() == expected_dyn_dims);
}
TEST_CASE(test_shape_assign)
......@@ -85,17 +77,26 @@ TEST_CASE(test_shape_standard)
EXPECT(not s.broadcasted());
}
TEST_CASE(test_shape_standard_singleton_dim)
{
migraphx::shape s{migraphx::shape::float_type, {5, 1, 8}, {8, 4, 1}};
EXPECT(s.standard());
EXPECT(s.packed());
EXPECT(not s.transposed());
EXPECT(not s.broadcasted());
}
TEST_CASE(test_shape_min_max_opt)
{
migraphx::shape s{migraphx::shape::float_type, {2, 2, 3}, {6, 3, 1}};
EXPECT(s.min_lens() == s.lens());
EXPECT(s.max_lens() == s.lens());
EXPECT(s.opt_lens() == s.lens());
EXPECT(s.opt_lens().empty());
}
TEST_CASE(test_shape_dynamic_fixed)
{
migraphx::shape s{migraphx::shape::float_type, {{2, 2, 0}, {2, 2, 0}, {3, 3, 0}}};
migraphx::shape s{migraphx::shape::float_type, {{2, 2}, {2, 2}, {3, 3}}};
EXPECT(not s.standard());
EXPECT(not s.packed());
EXPECT(not s.transposed());
......@@ -106,7 +107,8 @@ TEST_CASE(test_shape_dynamic_fixed)
EXPECT(not s.dyn_dims().at(0).has_optimal());
EXPECT(s.min_lens() == std::vector<std::size_t>{2, 2, 3});
EXPECT(s.max_lens() == std::vector<std::size_t>{2, 2, 3});
EXPECT(s.opt_lens() == std::vector<std::size_t>{0, 0, 0});
std::vector<std::set<std::size_t>> e_opt_lens = {{}, {}, {}};
EXPECT(s.opt_lens() == e_opt_lens);
EXPECT(s.bytes() == 2 * 2 * 3 * sizeof(float));
}
......@@ -114,8 +116,8 @@ TEST_CASE(test_shape_dynamic_not_fixed)
{
using migraphx::shape;
std::vector<shape::dynamic_dimension> dims = {};
dims.push_back(shape::dynamic_dimension{2, 5, 2});
dims.push_back(shape::dynamic_dimension{2, 8, 0});
dims.push_back(shape::dynamic_dimension{2, 5, {2}});
dims.push_back(shape::dynamic_dimension{2, 8});
migraphx::shape s{migraphx::shape::float_type, dims};
EXPECT(not s.standard());
EXPECT(not s.packed());
......@@ -127,18 +129,16 @@ TEST_CASE(test_shape_dynamic_not_fixed)
EXPECT(s.dyn_dims().at(0).has_optimal());
EXPECT(s.min_lens() == std::vector<std::size_t>{2, 2});
EXPECT(s.max_lens() == std::vector<std::size_t>{5, 8});
EXPECT(s.opt_lens() == std::vector<std::size_t>{2, 0});
EXPECT(s.opt_lens() == std::vector<std::set<std::size_t>>{{2}, {}});
EXPECT(s.bytes() == 5 * 8 * sizeof(float));
}
TEST_CASE(test_shape_dynamic_compares)
{
using migraphx::shape;
auto a = shape::dynamic_dimension{2, 5, 2};
auto b = a;
auto c = shape::dynamic_dimension{2, 5, 2};
auto d = shape::dynamic_dimension{3, 8, 4};
EXPECT(a == b);
auto a = shape::dynamic_dimension{2, 5, {2}};
auto c = shape::dynamic_dimension{2, 5, {2}};
auto d = shape::dynamic_dimension{3, 8};
EXPECT(a == c);
EXPECT(a != d);
......@@ -163,13 +163,13 @@ TEST_CASE(test_shape_dynamic_compares)
TEST_CASE(dynamic_dimension_size_t_compares)
{
using migraphx::shape;
auto a = shape::dynamic_dimension{2, 2, 2};
auto a = shape::dynamic_dimension{2, 2, {2}};
EXPECT(a == 2);
EXPECT(a != 3);
EXPECT(static_cast<std::size_t>(2) == a);
EXPECT(static_cast<std::size_t>(3) != a);
auto b = shape::dynamic_dimension{2, 4, 0};
auto b = shape::dynamic_dimension{2, 4};
EXPECT(b != 2);
EXPECT(static_cast<std::size_t>(2) != b);
}
......@@ -177,36 +177,50 @@ TEST_CASE(dynamic_dimension_size_t_compares)
TEST_CASE(dynamic_dimension_add_sub_fixed)
{
using migraphx::shape;
auto a = shape::dynamic_dimension{2, 5, 2};
auto a = shape::dynamic_dimension{2, 5, {2}};
a += 3;
EXPECT(a == shape::dynamic_dimension{5, 8, 5});
EXPECT(a == shape::dynamic_dimension{5, 8, {5}});
a -= 3;
EXPECT(a == shape::dynamic_dimension{2, 5, 2});
EXPECT(a == shape::dynamic_dimension{2, 5, {2}});
auto b = shape::dynamic_dimension{3, 6, 3};
auto b = shape::dynamic_dimension{3, 6, {3}};
EXPECT((a + 1) == b);
EXPECT((1 + a) == b);
EXPECT((b - 1) == a);
auto c = shape::dynamic_dimension{4, 7, 4};
auto c = shape::dynamic_dimension{4, 7, {4}};
EXPECT((a + 2) == c);
EXPECT((2 + a) == c);
EXPECT((c - 2) == a);
auto d = shape::dynamic_dimension{4, 8, 0};
auto e = shape::dynamic_dimension{2, 6, 0};
auto d = shape::dynamic_dimension{4, 8};
auto e = shape::dynamic_dimension{2, 6};
EXPECT((d - 2) == e);
EXPECT((e + 2) == d);
EXPECT((2 + e) == d);
}
TEST_CASE(dynamic_dimension_serialize)
{
using migraphx::shape;
auto a = shape::dynamic_dimension{2, 5, {2, 3}};
auto b = shape::dynamic_dimension{3, 6, {3}};
auto v1 = migraphx::to_value(a);
auto v2 = migraphx::to_value(b);
EXPECT(v1 != v2);
auto c = migraphx::from_value<shape::dynamic_dimension>(v1);
EXPECT(a == c);
auto d = migraphx::from_value<shape::dynamic_dimension>(v2);
EXPECT(b == d);
}
TEST_CASE(test_shape_dynamic_errors)
{
using migraphx::shape;
std::vector<shape::dynamic_dimension> dims = {};
dims.push_back(shape::dynamic_dimension{2, 5, 2});
dims.push_back(shape::dynamic_dimension{2, 8, 0});
dims.push_back(shape::dynamic_dimension{2, 5, {2}});
dims.push_back(shape::dynamic_dimension{2, 8});
migraphx::shape s{shape::float_type, dims};
EXPECT(test::throws([&] { s.elements(); }));
EXPECT(test::throws([&] { s.index({0, 1}); }));
......@@ -220,13 +234,13 @@ TEST_CASE(test_shape_dynamic_serialize)
{
using migraphx::shape;
std::vector<shape::dynamic_dimension> dims1 = {};
dims1.push_back(shape::dynamic_dimension{2, 5, 2});
dims1.push_back(shape::dynamic_dimension{2, 8, 0});
dims1.push_back(shape::dynamic_dimension{2, 5, {2}});
dims1.push_back(shape::dynamic_dimension{2, 8});
migraphx::shape s1{shape::float_type, dims1};
auto v1 = migraphx::to_value(s1);
std::vector<shape::dynamic_dimension> dims2 = {};
dims2.push_back(shape::dynamic_dimension{2, 5, 2});
dims2.push_back(shape::dynamic_dimension{2, 5, {2}});
migraphx::shape s2{shape::uint64_type, dims2};
auto v2 = migraphx::to_value(s2);
EXPECT(v1 != v2);
......@@ -238,6 +252,30 @@ TEST_CASE(test_shape_dynamic_serialize)
EXPECT(s3 != s4);
}
TEST_CASE(any_of_dynamic_true)
{
std::vector<migraphx::shape> sub_shapes = {};
sub_shapes.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 4}}});
sub_shapes.push_back(migraphx::shape{migraphx::shape::float_type, {3, 4, 5}});
migraphx::shape s0{sub_shapes};
EXPECT(s0.any_of_dynamic());
sub_shapes = {};
sub_shapes.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 1}, {4, 4}}});
sub_shapes.push_back(migraphx::shape{migraphx::shape::float_type, {3, 4, 5}});
migraphx::shape s1{sub_shapes};
EXPECT(s1.any_of_dynamic());
}
TEST_CASE(any_of_dynamic_false)
{
std::vector<migraphx::shape> sub_shapes = {};
sub_shapes.push_back(migraphx::shape{migraphx::shape::float_type, {1, 4}});
sub_shapes.push_back(migraphx::shape{migraphx::shape::float_type, {3, 4, 5}});
migraphx::shape s{sub_shapes};
EXPECT(not s.any_of_dynamic());
}
TEST_CASE(test_shape_packed)
{
migraphx::shape s{migraphx::shape::float_type, {2, 2}, {2, 1}};
......@@ -261,14 +299,13 @@ TEST_CASE(test_shape_ndim_static)
TEST_CASE(test_shape_ndim_dyn)
{
migraphx::shape s0{migraphx::shape::float_type, {{2, 2, 0}, {2, 2, 0}}};
migraphx::shape s0{migraphx::shape::float_type, {{2, 2}, {2, 2}}};
EXPECT(s0.ndim() == 2);
migraphx::shape s1{migraphx::shape::float_type, {{1, 1, 0}, {2, 4, 0}, {2, 4, 0}, {2, 4, 0}}};
migraphx::shape s1{migraphx::shape::float_type, {{1, 1}, {2, 4}, {2, 4}, {2, 4}}};
EXPECT(s1.ndim() == 4);
migraphx::shape s2{migraphx::shape::float_type,
{{1, 1, 0}, {2, 4, 0}, {2, 4, 0}, {1, 1, 1}, {3, 3, 0}}};
migraphx::shape s2{migraphx::shape::float_type, {{1, 1}, {2, 4}, {2, 4}, {1, 1}, {3, 3}}};
EXPECT(s2.ndim() == 5);
}
......@@ -303,17 +340,60 @@ TEST_CASE(test_shape_static_to_dynamic)
{
migraphx::shape s0{migraphx::shape::float_type, {1, 2, 4, 4}};
migraphx::shape s1 = s0.to_dynamic();
migraphx::shape s2{migraphx::shape::float_type, {{1, 1, 0}, {2, 2, 0}, {4, 4, 0}, {4, 4, 0}}};
migraphx::shape s2{migraphx::shape::float_type, {{1, 1}, {2, 2}, {4, 4}, {4, 4}}};
EXPECT(s1 == s2);
}
TEST_CASE(test_shape_dyn_to_dynamic)
{
migraphx::shape s0{migraphx::shape::float_type, {{1, 1, 0}, {2, 4, 0}, {2, 4, 0}, {2, 4, 0}}};
migraphx::shape s0{migraphx::shape::float_type, {{1, 1}, {2, 4}, {2, 4}, {2, 4}}};
migraphx::shape s1 = s0.to_dynamic();
EXPECT(s0 == s1);
}
TEST_CASE(test_shape_subshapes_to_dynamic)
{
std::vector<migraphx::shape> sub_shapes0 = {};
sub_shapes0.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 4}}});
sub_shapes0.push_back(migraphx::shape{migraphx::shape::float_type, {3, 4, 5}});
migraphx::shape s0{sub_shapes0};
migraphx::shape s1 = s0.to_dynamic();
std::vector<migraphx::shape> sub_shapes1 = {};
sub_shapes1.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 4}}});
sub_shapes1.push_back(migraphx::shape{migraphx::shape::float_type, {{3, 3}, {4, 4}, {5, 5}}});
migraphx::shape s2{sub_shapes1};
EXPECT(s1 == s2);
}
TEST_CASE(test_shape_dyn_to_static)
{
migraphx::shape s0{migraphx::shape::float_type, {{1, 1}, {2, 2}, {2, 10}, {2, 10}}};
migraphx::shape s1 = s0.to_static(4);
migraphx::shape s2{migraphx::shape::float_type, {1, 2, 4, 4}};
EXPECT(s1 == s2);
}
TEST_CASE(test_shape_static_to_static)
{
migraphx::shape s0{migraphx::shape::float_type, {1, 2, 4, 4}};
migraphx::shape s1 = s0.to_static(8);
EXPECT(s0 == s1);
}
TEST_CASE(test_shape_subshapes_to_static)
{
std::vector<migraphx::shape> sub_shapes0 = {};
sub_shapes0.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 4}}});
sub_shapes0.push_back(migraphx::shape{migraphx::shape::float_type, {3, 4, 5}});
migraphx::shape s0{sub_shapes0};
migraphx::shape s1 = s0.to_static(3);
std::vector<migraphx::shape> sub_shapes1 = {};
sub_shapes1.push_back(migraphx::shape{migraphx::shape::float_type, {3, 4}});
sub_shapes1.push_back(migraphx::shape{migraphx::shape::float_type, {3, 4, 5}});
migraphx::shape s2{sub_shapes1};
EXPECT(s1 == s2);
}
TEST_CASE(test_shape_overlap)
{
migraphx::shape s{migraphx::shape::float_type, {2, 2, 3}, {6, 3, 2}};
......@@ -864,4 +944,16 @@ TEST_CASE(test_with_type)
EXPECT(s.strides() == new_s.strides());
}
TEST_CASE(test_multi_index)
{
migraphx::shape s{migraphx::shape::float_type, {2, 4, 6}};
EXPECT(migraphx::verify_range(s.multi(0), std::vector<size_t>{0, 0, 0}));
EXPECT(migraphx::verify_range(s.multi(4), std::vector<size_t>{0, 0, 4}));
EXPECT(migraphx::verify_range(s.multi(6), std::vector<size_t>{0, 1, 0}));
EXPECT(migraphx::verify_range(s.multi(8), std::vector<size_t>{0, 1, 2}));
EXPECT(migraphx::verify_range(s.multi(24), std::vector<size_t>{1, 0, 0}));
EXPECT(migraphx::verify_range(s.multi(30), std::vector<size_t>{1, 1, 0}));
EXPECT(migraphx::verify_range(s.multi(34), std::vector<size_t>{1, 1, 4}));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -509,6 +509,34 @@ TEST_CASE(simplify_dot_add)
EXPECT(m1 == m2);
}
TEST_CASE(simplify_conv_add)
{
migraphx::shape s{migraphx::shape::float_type, {1, 3, 32, 32}};
migraphx::shape ws{migraphx::shape::float_type, {4, 3, 3, 3}};
migraphx::module m1;
{
auto x = m1.add_parameter("x", s);
auto c = m1.add_literal(migraphx::generate_literal(s, 1));
auto w = m1.add_literal(migraphx::generate_literal(ws, 2));
auto sum = m1.add_instruction(migraphx::make_op("add"), c, x);
auto conv = m1.add_instruction(migraphx::make_op("convolution"), sum, w);
m1.add_instruction(pass_op{}, conv);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", s);
auto c = m2.add_literal(migraphx::generate_literal(s, 1));
auto w = m2.add_literal(migraphx::generate_literal(ws, 2));
auto conv1 = m2.add_instruction(migraphx::make_op("convolution"), c, w);
auto conv2 = m2.add_instruction(migraphx::make_op("convolution"), x, w);
auto sum = m2.add_instruction(migraphx::make_op("add"), conv1, conv2);
m2.add_instruction(pass_op{}, sum);
}
EXPECT(m1 == m2);
}
TEST_CASE(simplify_inner_broadcast1)
{
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}};
......@@ -585,6 +613,60 @@ TEST_CASE(simplify_inner_broadcast_scalar)
EXPECT(m1 == m2);
}
TEST_CASE(simplify_inner_broadcast_different_dims)
{
auto b = migraphx::op::multibroadcast{{2, 384, 768}};
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {384, 768}});
auto y = m1.add_parameter("y", {migraphx::shape::int32_type, {768}});
auto xb = m1.add_instruction(b, x);
auto yb = m1.add_instruction(b, y);
auto sum = m1.add_instruction(migraphx::make_op("add"), xb, yb);
m1.add_instruction(pass_op{}, sum);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {384, 768}});
auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {768}});
auto yb = m2.add_instruction(migraphx::op::multibroadcast{{384, 768}}, y);
auto sum = m2.add_instruction(migraphx::make_op("add"), x, yb);
auto sumb = m2.add_instruction(b, sum);
m2.add_instruction(pass_op{}, sumb);
}
EXPECT(m1 == m2);
}
TEST_CASE(simplify_inner_broadcast_different_broadcasts)
{
auto b = migraphx::op::broadcast{1, {1, 24, 112, 112}};
auto mb = migraphx::op::multibroadcast{{1, 24, 112, 112}};
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {24}});
auto y = m1.add_parameter("y", {migraphx::shape::int32_type, {24, 1, 1}});
auto xb = m1.add_instruction(b, x);
auto yb = m1.add_instruction(mb, y);
auto sum = m1.add_instruction(migraphx::make_op("add"), xb, yb);
m1.add_instruction(pass_op{}, sum);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {24}});
auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {24, 1, 1}});
auto xs = m2.add_instruction(migraphx::make_op("squeeze"), x);
auto ys = m2.add_instruction(migraphx::make_op("squeeze"), y);
auto sum = m2.add_instruction(migraphx::make_op("add"), xs, ys);
auto sumb = m2.add_instruction(b, sum);
m2.add_instruction(pass_op{}, sumb);
}
EXPECT(m1 == m2);
}
TEST_CASE(simplify_add_conv1)
{
migraphx::module m;
......@@ -1067,16 +1149,18 @@ TEST_CASE(simplify_neg_unit_mult_const)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto unit = m1.add_literal(-1);
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1, 6}});
auto unit = m1.add_literal(
migraphx::literal{{migraphx::shape::int32_type, {1, 6}}, std::vector<int>(6, -1)});
m1.add_instruction(migraphx::make_op("mul"), x, unit);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1}});
m2.add_instruction(migraphx::make_op("neg"), x);
auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1, 6}});
auto x2 = m2.add_instruction(migraphx::make_op("neg"), x);
m2.add_instruction(migraphx::make_op("identity"), x2);
}
EXPECT((m1 == m2));
......@@ -1095,7 +1179,29 @@ TEST_CASE(simplify_neg_unit_mult_const2)
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1}});
m2.add_instruction(migraphx::make_op("neg"), x);
auto x2 = m2.add_instruction(migraphx::make_op("neg"), x);
m2.add_instruction(migraphx::make_op("identity"), x2);
}
EXPECT((m1 == m2));
}
TEST_CASE(simplify_neg_unit_mult_const_add)
{
migraphx::module m1;
{
auto unit = m1.add_literal(-1);
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto x2 = m1.add_instruction(migraphx::make_op("mul"), unit, x);
m1.add_instruction(migraphx::make_op("add"), x2, x2);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto x2 = m2.add_instruction(migraphx::make_op("neg"), x);
m2.add_instruction(migraphx::make_op("add"), x2, x2);
}
EXPECT((m1 == m2));
......@@ -1118,7 +1224,8 @@ TEST_CASE(simplify_neg_unit_mul_const_vec)
migraphx::module m2;
{
auto x = m2.add_parameter("x", x_shape);
m2.add_instruction(migraphx::make_op("neg"), x);
auto x2 = m2.add_instruction(migraphx::make_op("neg"), x);
m2.add_instruction(migraphx::make_op("identity"), x2);
}
EXPECT(m1 == m2);
......@@ -1141,7 +1248,8 @@ TEST_CASE(simplify_neg_unit_mul_const_vec2)
migraphx::module m2;
{
auto x = m2.add_parameter("x", x_shape);
m2.add_instruction(migraphx::make_op("neg"), x);
auto x2 = m2.add_instruction(migraphx::make_op("neg"), x);
m2.add_instruction(migraphx::make_op("identity"), x2);
}
EXPECT(m1 == m2);
......@@ -1160,7 +1268,8 @@ TEST_CASE(simplify_neg_unit_div_const)
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1}});
m2.add_instruction(migraphx::make_op("neg"), x);
auto x2 = m2.add_instruction(migraphx::make_op("neg"), x);
m2.add_instruction(migraphx::make_op("identity"), x2);
}
EXPECT(m1 == m2);
......@@ -1183,7 +1292,8 @@ TEST_CASE(simplify_neg_unit_div_const_vec)
migraphx::module m2;
{
auto x = m2.add_parameter("x", x_shape);
m2.add_instruction(migraphx::make_op("neg"), x);
auto x2 = m2.add_instruction(migraphx::make_op("neg"), x);
m2.add_instruction(migraphx::make_op("identity"), x2);
}
EXPECT(m1 == m2);
......@@ -1243,7 +1353,8 @@ TEST_CASE(simplify_sub_neg_zero_const)
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1}});
m2.add_instruction(migraphx::make_op("neg"), x);
auto x2 = m2.add_instruction(migraphx::make_op("neg"), x);
m2.add_instruction(migraphx::make_op("identity"), x2);
}
EXPECT(m1 == m2);
}
......@@ -1265,7 +1376,8 @@ TEST_CASE(simplify_sub_neg_zero_const_vec)
migraphx::module m2;
{
auto x = m2.add_parameter("x", x_shape);
m2.add_instruction(migraphx::make_op("neg"), x);
auto x2 = m2.add_instruction(migraphx::make_op("neg"), x);
m2.add_instruction(migraphx::make_op("identity"), x2);
}
EXPECT(m1 == m2);
......@@ -2945,6 +3057,38 @@ TEST_CASE(reorder_slice_ins_deps)
EXPECT(m == create_module());
}
TEST_CASE(dot_broadcast_different_rank)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::float_type, {768}});
auto y = m1.add_parameter("y", {migraphx::shape::float_type, {768, 3072}});
auto xb = m1.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 384, 768}}}), x);
auto yb = m1.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 768, 3072}}}), y);
auto dot = m1.add_instruction(migraphx::make_op("dot"), xb, yb);
m1.add_return({dot});
};
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {768}});
auto y = m2.add_parameter("y", {migraphx::shape::float_type, {768, 3072}});
auto xb =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {384, 768}}}), x);
auto yb =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {768, 3072}}}), y);
auto dot = m2.add_instruction(migraphx::make_op("dot"), xb, yb);
auto broadcast = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 384, 3072}}}), dot);
m2.add_return({broadcast});
};
run_pass(m1);
EXPECT(m1.sort() == m2.sort());
}
TEST_CASE(dot_fusion_reshape)
{
migraphx::module m1;
......@@ -2994,4 +3138,257 @@ TEST_CASE(dot_fusion_reshape)
EXPECT(m1.sort() == m2.sort());
}
TEST_CASE(mul_dot_a)
{
migraphx::shape as{migraphx::shape::float_type, {2, 256, 32}};
migraphx::shape bs{migraphx::shape::float_type, {2, 32, 128}};
migraphx::module m1;
{
auto a = m1.add_parameter("input", as);
auto lit =
m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {1, 1, 32}}));
auto litb =
m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", as.lens()}}), lit);
auto mul = m1.add_instruction(migraphx::make_op("mul"), a, litb);
auto b = m1.add_literal(migraphx::generate_literal(bs));
auto dot = m1.add_instruction(migraphx::make_op("dot"), mul, b);
m1.add_return({dot});
};
run_pass(m1);
migraphx::module m2;
{
auto a = m2.add_parameter("input", as);
auto lit =
m2.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {1, 1, 32}}));
auto litb = m2.add_instruction(
migraphx::make_op("multibroadcast",
{{"out_lens", migraphx::reorder_dims(bs.lens(), {0, 2, 1})}}),
lit);
auto litt =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), litb);
auto b = m2.add_literal(migraphx::generate_literal(bs));
auto mul = m2.add_instruction(migraphx::make_op("mul"), b, litt);
auto dot = m2.add_instruction(migraphx::make_op("dot"), a, mul);
m2.add_return({dot});
};
EXPECT(m1.sort() == m2.sort());
}
TEST_CASE(mul_dot_b)
{
migraphx::shape as{migraphx::shape::float_type, {2, 256, 32}};
migraphx::shape bs{migraphx::shape::float_type, {2, 32, 128}};
migraphx::module m1;
{
auto b = m1.add_parameter("input", bs);
auto lit =
m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {1, 32, 1}}));
auto litb =
m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", bs.lens()}}), lit);
auto mul = m1.add_instruction(migraphx::make_op("mul"), b, litb);
auto a = m1.add_literal(migraphx::generate_literal(as));
auto dot = m1.add_instruction(migraphx::make_op("dot"), a, mul);
m1.add_return({dot});
};
run_pass(m1);
migraphx::module m2;
{
auto b = m2.add_parameter("input", bs);
auto lit =
m2.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {1, 32, 1}}));
auto litb = m2.add_instruction(
migraphx::make_op("multibroadcast",
{{"out_lens", migraphx::reorder_dims(as.lens(), {0, 2, 1})}}),
lit);
auto litt =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), litb);
auto a = m2.add_literal(migraphx::generate_literal(as));
auto mul = m2.add_instruction(migraphx::make_op("mul"), a, litt);
auto dot = m2.add_instruction(migraphx::make_op("dot"), mul, b);
m2.add_return({dot});
};
EXPECT(m1.sort() == m2.sort());
}
TEST_CASE(mul_dot_a_not_k_broadcast)
{
migraphx::shape as{migraphx::shape::float_type, {2, 256, 32}};
migraphx::shape bs{migraphx::shape::float_type, {2, 32, 128}};
migraphx::module m1;
{
auto a = m1.add_parameter("input", as);
auto lit =
m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {1, 256, 1}}));
auto litb =
m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", as.lens()}}), lit);
auto mul = m1.add_instruction(migraphx::make_op("mul"), a, litb);
auto b = m1.add_literal(migraphx::generate_literal(bs));
auto dot = m1.add_instruction(migraphx::make_op("dot"), mul, b);
m1.add_return({dot});
};
migraphx::module m2 = m1;
run_pass(m1);
EXPECT(m1.sort() == m2.sort());
}
TEST_CASE(mul_dot_b_not_k_broadcast)
{
migraphx::shape as{migraphx::shape::float_type, {2, 256, 32}};
migraphx::shape bs{migraphx::shape::float_type, {2, 32, 128}};
migraphx::module m1;
{
auto b = m1.add_parameter("input", bs);
auto lit =
m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {1, 1, 128}}));
auto litb =
m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", bs.lens()}}), lit);
auto mul = m1.add_instruction(migraphx::make_op("mul"), b, litb);
auto a = m1.add_literal(migraphx::generate_literal(as));
auto dot = m1.add_instruction(migraphx::make_op("dot"), a, mul);
m1.add_return({dot});
};
migraphx::module m2 = m1;
run_pass(m1);
EXPECT(m1.sort() == m2.sort());
}
TEST_CASE(dot_mul_a)
{
migraphx::shape as{migraphx::shape::float_type, {2, 256, 32}};
migraphx::shape bs{migraphx::shape::float_type, {2, 32, 128}};
migraphx::module m1;
{
auto a = m1.add_parameter("input", as);
auto b = m1.add_literal(migraphx::generate_literal(bs));
auto dot = m1.add_instruction(migraphx::make_op("dot"), a, b);
auto lit =
m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {1, 1, 128}}));
auto litb = m1.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", dot->get_shape().lens()}}), lit);
auto mul = m1.add_instruction(migraphx::make_op("mul"), dot, litb);
m1.add_return({mul});
};
run_pass(m1);
migraphx::module m2;
{
auto a = m2.add_parameter("input", as);
auto b = m2.add_literal(migraphx::generate_literal(bs));
auto lit =
m2.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {1, 1, 128}}));
auto litb =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", bs.lens()}}), lit);
auto mul = m2.add_instruction(migraphx::make_op("mul"), b, litb);
auto dot = m2.add_instruction(migraphx::make_op("dot"), a, mul);
m2.add_return({dot});
};
EXPECT(m1.sort() == m2.sort());
}
TEST_CASE(dot_mul_a_non_const)
{
migraphx::shape as{migraphx::shape::float_type, {2, 256, 32}};
migraphx::shape bs{migraphx::shape::float_type, {2, 32, 128}};
migraphx::module m1;
{
auto a = m1.add_parameter("input", as);
auto b = m1.add_literal(migraphx::generate_literal(bs));
auto dot = m1.add_instruction(migraphx::make_op("dot"), a, b);
auto lit =
m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {1, 256, 1}}));
auto litb = m1.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", dot->get_shape().lens()}}), lit);
auto mul = m1.add_instruction(migraphx::make_op("mul"), dot, litb);
m1.add_return({mul});
};
migraphx::module m2 = m1;
run_pass(m1);
EXPECT(m1.sort() == m2.sort());
}
TEST_CASE(dot_mul_b)
{
migraphx::shape as{migraphx::shape::float_type, {2, 256, 32}};
migraphx::shape bs{migraphx::shape::float_type, {2, 32, 128}};
migraphx::module m1;
{
auto a = m1.add_literal(migraphx::generate_literal(as));
auto b = m1.add_parameter("input", bs);
auto dot = m1.add_instruction(migraphx::make_op("dot"), a, b);
auto lit =
m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {1, 256, 1}}));
auto litb = m1.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", dot->get_shape().lens()}}), lit);
auto mul = m1.add_instruction(migraphx::make_op("mul"), dot, litb);
m1.add_return({mul});
};
run_pass(m1);
migraphx::module m2;
{
auto a = m2.add_literal(migraphx::generate_literal(as));
auto b = m2.add_parameter("input", bs);
auto lit =
m2.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {1, 256, 1}}));
auto litb =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", as.lens()}}), lit);
auto mul = m2.add_instruction(migraphx::make_op("mul"), a, litb);
auto dot = m2.add_instruction(migraphx::make_op("dot"), mul, b);
m2.add_return({dot});
};
EXPECT(m1.sort() == m2.sort());
}
TEST_CASE(dot_mul_b_non_const)
{
migraphx::shape as{migraphx::shape::float_type, {2, 256, 32}};
migraphx::shape bs{migraphx::shape::float_type, {2, 32, 128}};
migraphx::module m1;
{
auto a = m1.add_literal(migraphx::generate_literal(as));
auto b = m1.add_parameter("input", bs);
auto dot = m1.add_instruction(migraphx::make_op("dot"), a, b);
auto lit =
m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {1, 1, 128}}));
auto litb = m1.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", dot->get_shape().lens()}}), lit);
auto mul = m1.add_instruction(migraphx::make_op("mul"), dot, litb);
m1.add_return({mul});
};
migraphx::module m2 = m1;
run_pass(m1);
EXPECT(m1.sort() == m2.sort());
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -23,7 +23,7 @@
*/
#include <migraphx/simplify_qdq.hpp>
#include <migraphx/program.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/instruction.hpp>
#include <test.hpp>
#include <migraphx/make_op.hpp>
......@@ -402,9 +402,10 @@ TEST_CASE(conv_bias_add)
auto bias = m1.add_parameter("bias", s6);
auto scale = m1.add_literal(0.5f);
auto zero = m1.add_literal(std::int8_t{0});
auto zero32 = m1.add_literal(std::int32_t{0});
auto d1 = add_quantize_op(m1, "dequantizelinear", weights, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", bias, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", bias, scale, zero32);
auto q1 = add_quantize_op(m1, "quantizelinear", input, scale, zero);
auto d5 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto c1 = m1.add_instruction(migraphx::make_op("convolution",
......@@ -428,9 +429,10 @@ TEST_CASE(conv_bias_add)
auto bias = m2.add_parameter("bias", s6);
auto scale = m2.add_literal(0.5f);
auto zero = m2.add_literal(std::int8_t{0});
auto zero32 = m2.add_literal(std::int32_t{0});
auto scale1 = m2.add_literal(0.25f);
auto d2 = add_quantize_op(m2, "dequantizelinear", bias, scale, zero);
auto d2 = add_quantize_op(m2, "dequantizelinear", bias, scale, zero32);
auto q1 = add_quantize_op(m2, "quantizelinear", input, scale, zero);
auto c1 = m2.add_instruction(migraphx::make_op("quant_convolution",
{{"padding", {0, 0, 0, 0}},
......@@ -468,9 +470,10 @@ TEST_CASE(conv_pooling_dot)
auto input = m1.add_parameter("input", s7);
auto scale = m1.add_literal(0.5f);
auto zero = m1.add_literal(std::int8_t{0});
auto zero32 = m1.add_literal(std::int32_t{0});
auto d1 = add_quantize_op(m1, "dequantizelinear", weights, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", bias, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", bias, scale, zero32);
auto d3 = add_quantize_op(m1, "dequantizelinear", ab, scale, zero);
auto d4 = add_quantize_op(m1, "dequantizelinear", db, scale, zero);
auto q1 = add_quantize_op(m1, "quantizelinear", input, scale, zero);
......@@ -515,10 +518,11 @@ TEST_CASE(conv_pooling_dot)
auto input = m2.add_parameter("input", s7);
auto scale = m2.add_literal(0.5f);
auto zero = m2.add_literal(std::int8_t{0});
auto zero32 = m2.add_literal(std::int32_t{0});
auto scale1 = m2.add_literal(0.25f);
auto scale2 = m2.add_literal(0.25f);
auto d2 = add_quantize_op(m2, "dequantizelinear", bias, scale, zero);
auto d2 = add_quantize_op(m2, "dequantizelinear", bias, scale, zero32);
auto d3 = add_quantize_op(m2, "dequantizelinear", ab, scale, zero);
auto q1 = add_quantize_op(m2, "quantizelinear", input, scale, zero);
auto c1 = m2.add_instruction(migraphx::make_op("quant_convolution",
......@@ -572,9 +576,10 @@ TEST_CASE(mobilenet_snippet)
auto input = mm.add_parameter("input", s7);
auto scale = mm.add_literal(0.5f);
auto zero = mm.add_literal(std::int8_t{0});
auto zero32 = mm.add_literal(std::int32_t{0});
auto d1 = add_quantize_op(mm, "dequantizelinear", weights, scale, zero);
auto d2 = add_quantize_op(mm, "dequantizelinear", bias, scale, zero);
auto d2 = add_quantize_op(mm, "dequantizelinear", bias, scale, zero32);
auto d3 = add_quantize_op(mm, "dequantizelinear", ab, scale, zero);
auto d4 = add_quantize_op(mm, "dequantizelinear", db, scale, zero);
auto q1 = add_quantize_op(mm, "quantizelinear", input, scale, zero);
......@@ -686,8 +691,8 @@ TEST_CASE(conv_correctness)
auto input = migraphx::argument(si, iv.data());
std::vector<float> wv(sw.elements(), 10);
auto weights = migraphx::argument(sw, wv.data());
p1.compile(migraphx::target(migraphx::ref::target{}));
p2.compile(migraphx::target(migraphx::ref::target{}));
p1.compile(migraphx::target(migraphx::make_target("ref")));
p2.compile(migraphx::target(migraphx::make_target("ref")));
auto result1 = p1.eval({{"input", input}, {"weights", weights}}).back();
std::vector<float> rv1(16);
......@@ -736,8 +741,8 @@ TEST_CASE(dot_correctness)
auto a = migraphx::argument(sh1, av.data());
std::vector<float> bv(sh2.elements(), 10);
auto b = migraphx::argument(sh2, bv.data());
p1.compile(migraphx::target(migraphx::ref::target{}));
p2.compile(migraphx::target(migraphx::ref::target{}));
p1.compile(migraphx::target(migraphx::make_target("ref")));
p2.compile(migraphx::target(migraphx::make_target("ref")));
auto result1 = p1.eval({{"a", a}, {"b", b}}).back();
std::vector<float> rv1(sh3.elements());
......
......@@ -1322,6 +1322,46 @@ TEST_CASE(transpose_slice)
EXPECT(m1 == m2);
}
TEST_CASE(transpose_slice_unsqueeze)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::float_type, {4, 1024, 96, 64}});
auto transpose1 =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x);
auto slice1 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {8}}}),
transpose1);
auto slice2 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {16}}, {"ends", {24}}}),
transpose1);
auto slice3 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {32}}, {"ends", {40}}}),
transpose1);
m1.add_return({slice1, slice2, slice3});
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {4, 1024, 96, 64}});
auto unsq =
m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {12}}}), x);
auto transpose = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {2, 0, 3, 4, 1}}}), unsq);
auto slice1 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), transpose);
auto sq1 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slice1);
auto slice2 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {3}}}), transpose);
auto sq2 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slice2);
auto slice3 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {4}}, {"ends", {5}}}), transpose);
auto sq3 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slice3);
m2.add_return({sq1, sq2, sq3});
}
EXPECT(m1 == m2);
}
TEST_CASE(transpose_slice_diff_perm)
{
migraphx::module m1;
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/split_single_dyn_dim.hpp>
#include <migraphx/program.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <test.hpp>
void run_pass(migraphx::program& p)
{
migraphx::run_passes(p, {migraphx::split_single_dyn_dim{}, migraphx::dead_code_elimination{}});
}
TEST_CASE(dynamic_batch)
{
// Slightly different from ref_ops_test in that the literal is copied over the submodules.
// A different compiler pass will pull the literals from the submodules to the main module.
migraphx::program p0;
{
auto* mm0 = p0.get_main_module();
// create batch submodules
auto create_submodule = [&](std::size_t batch_size, const std::string& module_name) {
auto* submod = p0.create_module(module_name);
migraphx::shape sm_shape{migraphx::shape::float_type, {batch_size, 4}};
auto sm_input = submod->add_parameter("data", sm_shape);
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {1}}};
auto literal_ins = submod->add_literal(migraphx::literal{lit_s, {6}});
auto broadcast_lit = submod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sm_shape.lens()}}), literal_ins);
auto add_ins =
submod->add_instruction(migraphx::make_op("add"), sm_input, broadcast_lit);
submod->add_return({add_ins});
return submod;
};
auto* dim1 = create_submodule(1, "dim_1");
auto* dim2 = create_submodule(2, "dim_2");
auto* dim3 = create_submodule(3, "dim_3");
auto* dim4 = create_submodule(4, "dim_4");
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
auto input0 = mm0->add_parameter("data", s);
std::vector<migraphx::shape> sub_shapes = {};
sub_shapes.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 4}}});
migraphx::shape out_attr = migraphx::shape{sub_shapes};
auto sm_ins = mm0->add_instruction(
migraphx::make_op("select_module",
{{"output_dyn_shapes", migraphx::to_value(out_attr)}}),
{input0},
{dim1, dim2, dim3, dim4});
auto ret =
mm0->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), sm_ins);
mm0->add_return({ret});
}
migraphx::program p1;
{
auto* mm1 = p1.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
auto input1 = mm1->add_parameter("data", s);
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {1}}};
auto literal_ins = mm1->add_literal(migraphx::literal{lit_s, {6}});
auto broadcast_lit =
mm1->add_instruction(migraphx::make_op("multibroadcast"), literal_ins, input1);
auto add_ins = mm1->add_instruction(migraphx::make_op("add"), input1, broadcast_lit);
mm1->add_return({add_ins});
}
run_pass(p1);
EXPECT(p0 == p1);
}
TEST_CASE(multiple_outputs)
{
migraphx::program p0;
{
auto* mm0 = p0.get_main_module();
// create batch submodules
auto create_submodule = [&](std::size_t batch_size, const std::string& module_name) {
auto* submod = p0.create_module(module_name);
migraphx::shape sm_shape{migraphx::shape::float_type, {batch_size, 4}};
auto sm_input = submod->add_parameter("data", sm_shape);
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {1}}};
auto literal_ins = submod->add_literal(migraphx::literal{lit_s, {6}});
auto broadcast_lit = submod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sm_shape.lens()}}), literal_ins);
auto add0_ins =
submod->add_instruction(migraphx::make_op("add"), sm_input, broadcast_lit);
auto add1_ins = submod->add_instruction(migraphx::make_op("add"), sm_input, sm_input);
submod->add_return({add0_ins, add1_ins});
return submod;
};
auto* dim1 = create_submodule(1, "dim_1");
auto* dim2 = create_submodule(2, "dim_2");
auto* dim3 = create_submodule(3, "dim_3");
auto* dim4 = create_submodule(4, "dim_4");
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
auto input0 = mm0->add_parameter("data", s);
std::vector<migraphx::shape> sub_shapes = {};
migraphx::shape tmp_s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
sub_shapes.push_back(tmp_s);
sub_shapes.push_back(tmp_s);
migraphx::shape out_attr = migraphx::shape{sub_shapes};
auto sm_ins = mm0->add_instruction(
migraphx::make_op("select_module",
{{"output_dyn_shapes", migraphx::to_value(out_attr)}}),
{input0},
{dim1, dim2, dim3, dim4});
auto ret0 =
mm0->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), sm_ins);
auto ret1 =
mm0->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), sm_ins);
mm0->add_return({ret0, ret1});
}
migraphx::program p1;
{
auto* mm1 = p1.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
auto input1 = mm1->add_parameter("data", s);
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {1}}};
auto literal_ins = mm1->add_literal(migraphx::literal{lit_s, {6}});
auto broadcast_lit =
mm1->add_instruction(migraphx::make_op("multibroadcast"), literal_ins, input1);
auto add0_ins = mm1->add_instruction(migraphx::make_op("add"), input1, broadcast_lit);
auto add1_ins = mm1->add_instruction(migraphx::make_op("add"), input1, input1);
mm1->add_return({add0_ins, add1_ins});
}
run_pass(p1);
EXPECT(p0 == p1);
}
TEST_CASE(broadcast_match)
{
// Slightly different from ref_ops_test in that the literal is copied over the submodules.
// A different compiler pass will pull the literals from the submodules to the main module.
migraphx::program p0;
{
auto* mm0 = p0.get_main_module();
// create batch submodules
auto create_submodule = [&](std::size_t batch_size, const std::string& module_name) {
auto* submod = p0.create_module(module_name);
migraphx::shape sm_shape{migraphx::shape::float_type, {batch_size, 4}};
auto sm_input = submod->add_parameter("data", sm_shape);
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {4}}};
auto literal_ins = submod->add_literal(migraphx::literal{lit_s, {6, 5, 4, 3}});
auto broadcast_lit = submod->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", sm_shape.lens()}}),
literal_ins);
auto add_ins =
submod->add_instruction(migraphx::make_op("add"), sm_input, broadcast_lit);
submod->add_return({add_ins});
return submod;
};
auto* dim1 = create_submodule(1, "dim_1");
auto* dim2 = create_submodule(2, "dim_2");
auto* dim3 = create_submodule(3, "dim_3");
auto* dim4 = create_submodule(4, "dim_4");
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
auto input0 = mm0->add_parameter("data", s);
std::vector<migraphx::shape> sub_shapes = {};
sub_shapes.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 4}}});
migraphx::shape out_attr = migraphx::shape{sub_shapes};
auto sm_ins = mm0->add_instruction(
migraphx::make_op("select_module",
{{"output_dyn_shapes", migraphx::to_value(out_attr)}}),
{input0},
{dim1, dim2, dim3, dim4});
auto ret =
mm0->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), sm_ins);
mm0->add_return({ret});
}
migraphx::program p1;
{
auto* mm1 = p1.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
auto input1 = mm1->add_parameter("data", s);
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {4}}};
auto literal_ins = mm1->add_literal(migraphx::literal{lit_s, {6, 5, 4, 3}});
auto broadcast_lit = mm1->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}}), literal_ins, input1);
auto add_ins = mm1->add_instruction(migraphx::make_op("add"), input1, broadcast_lit);
mm1->add_return({add_ins});
}
run_pass(p1);
EXPECT(p0 == p1);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -22,7 +22,6 @@
* THE SOFTWARE.
*/
#include <migraphx/register_target.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/target.hpp>
#include "test.hpp"
......@@ -42,8 +41,13 @@ TEST_CASE(make_invalid_target)
TEST_CASE(targets)
{
// GCC doesn't load libmigraphx_ref unless necesssary even though it is linked to the test.
// Force it to load by making ref target
#if defined(__GNUC__) && !defined(__clang__)
auto ref_target = migraphx::make_target("ref");
#endif
auto ts = migraphx::get_targets();
EXPECT(ts.size() > 0);
EXPECT(ts.size() == 1);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -67,7 +67,14 @@ int main(int argc, const char* argv[])
{
run_verify rv;
rv.add_validation_for("gpu", &validate_gpu);
rv.disable_test_for("cpu", {"test_if_lp", "test_if_param", "test_if_literal"});
rv.disable_test_for("cpu",
{"test_if_lp",
"test_if_param",
"test_if_literal",
"test_select_module_add",
"test_select_module_reduce",
"test_select_module_conv",
"test_split_single_dyn_dim"});
rv.disable_test_for("gpu", {"test_conv_bn_add"});
rv.run(argc, argv);
}
......@@ -26,7 +26,7 @@
#include "verify_program.hpp"
#include "test.hpp"
#include <migraphx/env.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/load_save.hpp>
......@@ -67,15 +67,17 @@ inline void verify_load_save(const migraphx::program& p)
EXPECT(p == loaded);
}
inline void compile_check(migraphx::program& p, const migraphx::target& t, bool show_trace = false)
inline void compile_check(migraphx::program& p,
const migraphx::target& t,
migraphx::compile_options c_opts,
bool show_trace = false)
{
auto name = t.name();
auto shapes = p.get_output_shapes();
std::stringstream ss;
migraphx::compile_options options;
if(show_trace)
options.trace = migraphx::tracer{std::cout};
p.compile(t, options);
c_opts.trace = migraphx::tracer{std::cout};
p.compile(t, c_opts);
if(shapes.size() != p.get_output_shapes().size())
{
std::cout << ss.str() << std::endl;
......@@ -115,19 +117,23 @@ void run_verify::validate(const migraphx::target& t,
}
std::vector<migraphx::argument> run_verify::run_ref(migraphx::program p,
migraphx::parameter_map inputs) const
migraphx::parameter_map inputs,
const migraphx::compile_options& c_opts) const
{
migraphx::ref::target t{};
migraphx::target t = migraphx::make_target("ref");
auto_print pp{p, t.name()};
compile_check(p, t);
compile_check(p, t, c_opts);
return p.eval(std::move(inputs));
}
std::pair<migraphx::program, std::vector<migraphx::argument>> run_verify::run_target(
const migraphx::target& t, migraphx::program p, const migraphx::parameter_map& inputs) const
std::pair<migraphx::program, std::vector<migraphx::argument>>
run_verify::run_target(const migraphx::target& t,
migraphx::program p,
const migraphx::parameter_map& inputs,
const migraphx::compile_options& c_opts) const
{
auto_print pp{p, t.name()};
auto trace_target = migraphx::string_value_of(MIGRAPHX_TRACE_TEST_COMPILE{});
compile_check(p, t, (trace_target == t.name()));
compile_check(p, t, c_opts, (trace_target == t.name()));
migraphx::parameter_map m;
for(auto&& input : inputs)
{
......@@ -157,7 +163,9 @@ auto get_hash(const T& x)
return std::hash<T>{}(x);
}
void run_verify::verify(const std::string& name, const migraphx::program& p) const
void run_verify::verify(const std::string& name,
const migraphx::program& p,
const migraphx::compile_options& c_opts) const
{
using result_future =
std::future<std::pair<migraphx::program, std::vector<migraphx::argument>>>;
......@@ -184,17 +192,26 @@ void run_verify::verify(const std::string& name, const migraphx::program& p) con
std::vector<std::pair<std::string, result_future>> results;
migraphx::parameter_map m;
for(auto&& x : p.get_parameter_shapes())
{
if(x.second.dynamic())
{
// create static shape using maximum dimensions
migraphx::shape static_shape{x.second.type(), x.second.max_lens()};
m[x.first] = migraphx::generate_argument(static_shape, get_hash(x.first));
}
else
{
m[x.first] = migraphx::generate_argument(x.second, get_hash(x.first));
}
}
auto gold_f = detach_async([=] { return run_ref(p, m); });
auto gold_f = detach_async([=] { return run_ref(p, m, c_opts); });
for(const auto& tname : target_names)
{
target_info ti = get_target_info(tname);
auto t = migraphx::make_target(tname);
results.emplace_back(tname,
detach_async([=] { return run_target(t, p, m); }, ti.parallel));
results.emplace_back(
tname, detach_async([=] { return run_target(t, p, m, c_opts); }, ti.parallel));
}
assert(gold_f.valid());
......@@ -235,7 +252,7 @@ void run_verify::run(int argc, const char* argv[]) const
for(auto&& p : get_programs())
{
labels[p.section].push_back(p.name);
test::add_test_case(p.name, [=] { verify(p.name, p.get_program()); });
test::add_test_case(p.name, [=] { verify(p.name, p.get_program(), p.compile_options); });
}
test::driver d{};
d.get_case_names = [&](const std::string& name) -> std::vector<std::string> {
......
......@@ -40,15 +40,19 @@ struct target_info
struct run_verify
{
std::vector<migraphx::argument> run_ref(migraphx::program p,
migraphx::parameter_map inputs) const;
migraphx::parameter_map inputs,
const migraphx::compile_options& c_opts) const;
std::pair<migraphx::program, std::vector<migraphx::argument>>
run_target(const migraphx::target& t,
migraphx::program p,
const migraphx::parameter_map& inputs) const;
const migraphx::parameter_map& inputs,
const migraphx::compile_options& c_opts) const;
void validate(const migraphx::target& t,
const migraphx::program& p,
const migraphx::parameter_map& m) const;
void verify(const std::string& name, const migraphx::program& p) const;
void verify(const std::string& name,
const migraphx::program& p,
const migraphx::compile_options& c_opts) const;
void run(int argc, const char* argv[]) const;
target_info get_target_info(const std::string& name) const;
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_add_conv_constant : verify_program<test_add_conv_constant>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {1, 3, 32, 32}};
migraphx::shape ws{migraphx::shape::float_type, {4, 3, 3, 3}};
auto x = mm->add_parameter("x", s);
auto c = mm->add_literal(migraphx::generate_literal(s, 1));
auto w = mm->add_literal(migraphx::generate_literal(ws, 2));
auto sum = mm->add_instruction(migraphx::make_op("add"), c, x);
mm->add_instruction(migraphx::make_op("convolution"), sum, w);
return p;
}
};
......@@ -33,13 +33,12 @@ struct test_concat_axis_2 : verify_program<test_concat_axis_2>
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s0{migraphx::shape::int32_type, {3, 2, 1}};
migraphx::shape s1{migraphx::shape::int32_type, {3, 2, 1}};
migraphx::shape s2{migraphx::shape::int32_type, {3, 2, 1}};
auto l0 = mm->add_parameter("x", s0);
auto l1 = mm->add_parameter("y", s1);
auto l2 = mm->add_parameter("z", s2);
mm->add_instruction(migraphx::make_op("concat", {{"axis", 2}}), l0, l1, l2);
migraphx::shape s{migraphx::shape::int32_type, {3, 2, 1}};
auto x0 = mm->add_parameter("x0", s);
auto x1 = mm->add_parameter("x1", s);
auto x2 = mm->add_parameter("x2", s);
auto x3 = mm->add_parameter("x3", s);
mm->add_instruction(migraphx::make_op("concat", {{"axis", 2}}), x0, x1, x2, x3);
return p;
}
};
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>
struct test_dot_mul_a : verify_program<test_dot_mul_a>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape as{migraphx::shape::float_type, {2, 256, 32}};
migraphx::shape bs{migraphx::shape::float_type, {2, 32, 128}};
auto a = mm->add_parameter("input", as);
auto b = mm->add_literal(migraphx::generate_literal(bs));
auto dot = mm->add_instruction(migraphx::make_op("dot"), a, b);
auto lit =
mm->add_literal(migraphx::generate_literal({migraphx::shape::float_type, {1, 1, 128}}));
auto litb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", dot->get_shape().lens()}}), lit);
auto mul = mm->add_instruction(migraphx::make_op("mul"), dot, litb);
mm->add_return({mul});
return p;
}
};
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>
struct test_dot_mul_b : verify_program<test_dot_mul_b>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape as{migraphx::shape::float_type, {2, 256, 32}};
migraphx::shape bs{migraphx::shape::float_type, {2, 32, 128}};
auto a = mm->add_literal(migraphx::generate_literal(as));
auto b = mm->add_parameter("input", bs);
auto dot = mm->add_instruction(migraphx::make_op("dot"), a, b);
auto lit =
mm->add_literal(migraphx::generate_literal({migraphx::shape::float_type, {1, 256, 1}}));
auto litb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", dot->get_shape().lens()}}), lit);
auto mul = mm->add_instruction(migraphx::make_op("mul"), dot, litb);
mm->add_return({mul});
return p;
}
};
......@@ -24,31 +24,30 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/reduce_mean.hpp>
migraphx::instruction_ref add_layernorm(migraphx::module& m,
migraphx::instruction_ref x,
std::vector<size_t> dims,
float eps = 1e-12f)
{
auto scale =
m.add_parameter("scale", migraphx::shape{migraphx::shape::float_type, {dims.back()}});
auto bias =
m.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {dims.back()}});
auto epsilon = m.add_literal(eps);
auto exponent = m.add_literal(2.0f);
auto mgx_type = x->get_shape().type();
auto scale = m.add_parameter("scale", migraphx::shape{mgx_type, {dims.back()}});
auto bias = m.add_parameter("bias", migraphx::shape{mgx_type, {dims.back()}});
auto epsilon = m.add_literal(migraphx::literal{migraphx::shape{mgx_type}, {eps}});
auto exponent = m.add_literal(migraphx::literal{migraphx::shape{mgx_type}, {2.0f}});
auto mean = m.add_instruction(migraphx::op::reduce_mean({2}), x);
auto mean = m.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2}}}), x);
auto mean_mbcast =
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), mean);
auto sub = m.add_instruction(migraphx::make_op("sub"), x, mean_mbcast);
auto exponent_mbcast =
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), exponent);
auto pow = m.add_instruction(migraphx::make_op("pow"), sub, exponent_mbcast);
auto var = m.add_instruction(migraphx::op::reduce_mean({2}), pow);
auto var = m.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2}}}), pow);
auto epsilon_mbcast = m.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {1, dims.at(1), 1}}}), epsilon);
auto add_epsilon = m.add_instruction(migraphx::make_op("add"), var, epsilon_mbcast);
......@@ -90,6 +89,32 @@ struct test_layernorm2 : verify_program<test_layernorm2>
}
};
struct test_layernorm_large : verify_program<test_layernorm_large>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<size_t> dims = {1, 32, 262144};
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, dims});
add_layernorm(*mm, x, dims);
return p;
}
};
struct test_layernorm_fp16 : verify_program<test_layernorm_fp16>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<size_t> dims = {1, 24, 64};
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::half_type, dims});
add_layernorm(*mm, x, dims);
return p;
}
};
struct test_layernorm_eps : verify_program<test_layernorm_eps>
{
migraphx::program create_program() const
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/op/max.hpp>
#include <migraphx/op/min.hpp>
template <class Op, migraphx::shape::type_t T>
struct test_min_max : verify_program<test_min_max<Op, T>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{T, {128}};
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
mm->add_instruction(Op{}, x, y);
return p;
}
};
template struct test_min_max<migraphx::op::max, migraphx::shape::float_type>;
template struct test_min_max<migraphx::op::max, migraphx::shape::half_type>;
template struct test_min_max<migraphx::op::max, migraphx::shape::double_type>;
template struct test_min_max<migraphx::op::min, migraphx::shape::float_type>;
template struct test_min_max<migraphx::op::min, migraphx::shape::half_type>;
template struct test_min_max<migraphx::op::min, migraphx::shape::double_type>;
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_mul_dot_a : verify_program<test_mul_dot_a>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape as{migraphx::shape::float_type, {2, 256, 32}};
migraphx::shape bs{migraphx::shape::float_type, {2, 32, 128}};
auto a = mm->add_parameter("input", as);
auto lit =
mm->add_literal(migraphx::generate_literal({migraphx::shape::float_type, {1, 1, 32}}));
auto litb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", as.lens()}}), lit);
auto mul = mm->add_instruction(migraphx::make_op("mul"), a, litb);
auto b = mm->add_literal(migraphx::generate_literal(bs));
auto dot = mm->add_instruction(migraphx::make_op("dot"), mul, b);
mm->add_return({dot});
return p;
}
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment