Commit 7e297b13 authored by Paul's avatar Paul
Browse files

Merge

parents 86ea5e91 aa7ff911
...@@ -2,44 +2,44 @@ ...@@ -2,44 +2,44 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/op/common.hpp> #include <migraphx/make_op.hpp>
// struct test_conv_bn_add : verify_program<test_conv_bn_add> struct test_conv_bn_add : verify_program<test_conv_bn_add>
// { {
// static migraphx::instruction_ref add_bn(migraphx::program& p, static migraphx::instruction_ref add_bn(migraphx::module& m,
// migraphx::instruction_ref x, migraphx::instruction_ref x,
// std::size_t channels, std::size_t channels,
// std::size_t seed = 1) std::size_t seed = 1)
// { {
// migraphx::shape vars{migraphx::shape::float_type, {channels}}; migraphx::shape vars{migraphx::shape::float_type, {channels}};
// auto scale = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 1 + auto scale = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1 + seed)));
// seed))); auto bias = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 2 auto bias = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2 + seed)));
// + seed))); auto mean = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, auto mean = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3 + seed)));
// 3 + seed))); auto variance = auto variance = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4 + seed)));
// mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 4 + seed))); return return m.add_instruction(
// mm->add_instruction( migraphx::make_op("batch_norm_inference"), x, scale, bias, mean, variance);
// migraphx::op::batch_norm_inference{}, x, scale, bias, mean, variance); }
// }
// migraphx::program create_program() const migraphx::program create_program() const
// { {
// migraphx::program p; migraphx::program p;
// std::size_t ichannels = 64; auto* mm = p.get_main_module();
// std::size_t ochannels = 256; std::size_t ichannels = 64;
// auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1, ichannels, 56, std::size_t ochannels = 256;
// 56}}); auto w = mm->add_literal(migraphx::generate_literal( auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1, ichannels, 56, 56}});
// {migraphx::shape::float_type, {ochannels, ichannels, 1, 1}}, 1)); auto w = mm->add_literal(migraphx::generate_literal(
// auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, ichannels, 56, {migraphx::shape::float_type, {ochannels, ichannels, 1, 1}}, 1));
// 56}}); auto v = mm->add_literal(migraphx::generate_literal( auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, ichannels, 56, 56}});
// {migraphx::shape::float_type, {ochannels, ichannels, 1, 1}}, 2)); auto v = mm->add_literal(migraphx::generate_literal(
// auto relu1 = mm->add_instruction(migraphx::op::relu{}, x); {migraphx::shape::float_type, {ochannels, ichannels, 1, 1}}, 2));
// auto conv1 = mm->add_instruction(migraphx::op::convolution{}, relu1, w); auto relu1 = mm->add_instruction(migraphx::make_op("relu"), x);
// auto bn1 = add_bn(p, conv1, ochannels, 1); auto conv1 = mm->add_instruction(migraphx::make_op("convolution"), relu1, w);
// auto relu2 = mm->add_instruction(migraphx::op::relu{}, y); auto bn1 = add_bn(*mm, conv1, ochannels, 1);
// auto conv2 = mm->add_instruction(migraphx::op::convolution{}, relu2, v); auto relu2 = mm->add_instruction(migraphx::make_op("relu"), y);
// auto bn2 = add_bn(p, conv2, ochannels, 1); auto conv2 = mm->add_instruction(migraphx::make_op("convolution"), relu2, v);
// auto sum = mm->add_instruction(migraphx::op::add{}, bn1, bn2); auto bn2 = add_bn(*mm, conv2, ochannels, 1);
// mm->add_instruction(migraphx::op::relu{}, sum); auto sum = mm->add_instruction(migraphx::make_op("add"), bn1, bn2);
// return p; mm->add_instruction(migraphx::make_op("relu"), sum);
// } return p;
// }; }
};
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
struct test_conv_bn_relu_pooling : verify_program<test_conv_bn_relu_pooling> struct test_conv_bn_relu_pooling : verify_program<test_conv_bn_relu_pooling>
{ {
...@@ -29,7 +30,7 @@ struct test_conv_bn_relu_pooling : verify_program<test_conv_bn_relu_pooling> ...@@ -29,7 +30,7 @@ struct test_conv_bn_relu_pooling : verify_program<test_conv_bn_relu_pooling>
migraphx::make_op("batch_norm_inference"), conv, scale, bias, mean, variance); migraphx::make_op("batch_norm_inference"), conv, scale, bias, mean, variance);
auto relu = mm->add_instruction(migraphx::make_op("relu"), bn); auto relu = mm->add_instruction(migraphx::make_op("relu"), bn);
mm->add_instruction(migraphx::make_op("pooling", mm->add_instruction(migraphx::make_op("pooling",
{{"mode", "average"}, {{"mode", migraphx::op::pooling_mode::average},
{"padding", {1, 1}}, {"padding", {1, 1}},
{"stride", {2, 2}}, {"stride", {2, 2}},
{"lengths", {3, 3}}}), {"lengths", {3, 3}}}),
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
struct test_conv_bn_relu_pooling2 : verify_program<test_conv_bn_relu_pooling2> struct test_conv_bn_relu_pooling2 : verify_program<test_conv_bn_relu_pooling2>
{ {
...@@ -47,7 +48,7 @@ struct test_conv_bn_relu_pooling2 : verify_program<test_conv_bn_relu_pooling2> ...@@ -47,7 +48,7 @@ struct test_conv_bn_relu_pooling2 : verify_program<test_conv_bn_relu_pooling2>
auto add = mm->add_instruction(migraphx::make_op("add"), bn1, bn2); auto add = mm->add_instruction(migraphx::make_op("add"), bn1, bn2);
auto relu = mm->add_instruction(migraphx::make_op("relu"), add); auto relu = mm->add_instruction(migraphx::make_op("relu"), add);
mm->add_instruction(migraphx::make_op("pooling", mm->add_instruction(migraphx::make_op("pooling",
{{"mode", "average"}, {{"mode", migraphx::op::pooling_mode::average},
{"padding", {1, 1}}, {"padding", {1, 1}},
{"stride", {2, 2}}, {"stride", {2, 2}},
{"lengths", {3, 3}}}), {"lengths", {3, 3}}}),
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
struct test_conv_pooling : verify_program<test_conv_pooling> struct test_conv_pooling : verify_program<test_conv_pooling>
{ {
...@@ -15,7 +16,8 @@ struct test_conv_pooling : verify_program<test_conv_pooling> ...@@ -15,7 +16,8 @@ struct test_conv_pooling : verify_program<test_conv_pooling>
auto weights = auto weights =
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights); auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights);
auto pooling = mm->add_instruction(migraphx::make_op("pooling", {{"mode", "max"}}), conv); auto pooling = mm->add_instruction(
migraphx::make_op("pooling", {{"mode", migraphx::op::pooling_mode::max}}), conv);
mm->add_instruction(migraphx::make_op("relu"), pooling); mm->add_instruction(migraphx::make_op("relu"), pooling);
return p; return p;
} }
......
...@@ -16,7 +16,7 @@ struct test_div2 : verify_program<test_div2> ...@@ -16,7 +16,7 @@ struct test_div2 : verify_program<test_div2>
auto y = mm->add_parameter("y", s); auto y = mm->add_parameter("y", s);
auto z = mm->add_parameter("z", b); auto z = mm->add_parameter("z", b);
auto zb = mm->add_instruction( auto zb = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", s.lens()}}), z); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", s.lens()}}), z);
auto diff = mm->add_instruction(migraphx::make_op("div"), x, y); auto diff = mm->add_instruction(migraphx::make_op("div"), x, y);
mm->add_instruction(migraphx::make_op("div"), diff, zb); mm->add_instruction(migraphx::make_op("div"), diff, zb);
return p; return p;
......
...@@ -13,9 +13,9 @@ struct test_equal_brcst : verify_program<test_equal_brcst> ...@@ -13,9 +13,9 @@ struct test_equal_brcst : verify_program<test_equal_brcst>
migraphx::shape s0{migraphx::shape::float_type, {3, 3}}; migraphx::shape s0{migraphx::shape::float_type, {3, 3}};
auto l0 = mm->add_parameter("x", s0); auto l0 = mm->add_parameter("x", s0);
migraphx::shape s1{migraphx::shape::float_type, {3, 1}}; migraphx::shape s1{migraphx::shape::float_type, {3, 1}};
auto l1 = mm->add_parameter("y", s1); auto l1 = mm->add_parameter("y", s1);
auto bl1 = mm->add_instruction( auto bl1 =
migraphx::make_op("multibroadcast", {{"output_lens", s0.lens()}}), l1); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s0.lens()}}), l1);
auto r = mm->add_instruction(migraphx::make_op("equal"), l0, bl1); auto r = mm->add_instruction(migraphx::make_op("equal"), l0, bl1);
mm->add_return({r}); mm->add_return({r});
......
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_gathernd_batch_dims_1 : verify_program<test_gathernd_batch_dims_1>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape ds{migraphx::shape::float_type, {2, 3, 2, 3}};
migraphx::shape is{migraphx::shape::int64_type, {2, 3, 2}};
std::vector<int64_t> indices{1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0};
auto a0 = mm->add_parameter("data", ds);
auto a1 = mm->add_literal(migraphx::literal{is, indices});
int batch_dims = 1;
mm->add_instruction(migraphx::make_op("gathernd", {{"batch_dims", batch_dims}}), a0, a1);
return p;
}
};
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_gathernd_batch_dims_2 : verify_program<test_gathernd_batch_dims_2>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape ds{migraphx::shape::float_type, {2, 3, 1, 3}};
migraphx::shape is{migraphx::shape::int64_type, {2, 3, 2}};
std::vector<int64_t> indices{0, 0, 0, 1, 0, 2, 0, 2, 0, 1, 0, 0};
auto a0 = mm->add_parameter("data", ds);
auto a1 = mm->add_literal(migraphx::literal{is, indices});
int batch_dims = 2;
mm->add_instruction(migraphx::make_op("gathernd", {{"batch_dims", batch_dims}}), a0, a1);
return p;
}
};
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_gathernd_default : verify_program<test_gathernd_default>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape ds{migraphx::shape::float_type, {2, 2}};
migraphx::shape is{migraphx::shape::int64_type, {2, 2}};
std::vector<int64_t> indices{0, 0, 1, 1};
auto a0 = mm->add_parameter("data", ds);
auto a1 = mm->add_literal(migraphx::literal{is, indices});
mm->add_instruction(migraphx::make_op("gathernd"), a0, a1);
return p;
}
};
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_gathernd_negative_indices : verify_program<test_gathernd_negative_indices>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape ds{migraphx::shape::float_type, {2, 2}};
migraphx::shape is{migraphx::shape::int64_type, {2, 1, 1}};
std::vector<int64_t> indices{-1, 0};
auto a0 = mm->add_parameter("data", ds);
auto a1 = mm->add_literal(migraphx::literal{is, indices});
int batch_dims = 1;
mm->add_instruction(migraphx::make_op("gathernd", {{"batch_dims", batch_dims}}), a0, a1);
return p;
}
};
...@@ -16,14 +16,14 @@ struct test_gelu : verify_program<test_gelu> ...@@ -16,14 +16,14 @@ struct test_gelu : verify_program<test_gelu>
auto one = mm->add_literal(1.0f); auto one = mm->add_literal(1.0f);
auto sqrt2 = mm->add_literal(static_cast<float>(M_SQRT2)); auto sqrt2 = mm->add_literal(static_cast<float>(M_SQRT2));
auto half_mbcast = mm->add_instruction( auto half_mbcast = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", input_lens}}), half); migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), half);
auto mul_half = mm->add_instruction(migraphx::make_op("mul"), x, half_mbcast); auto mul_half = mm->add_instruction(migraphx::make_op("mul"), x, half_mbcast);
auto sqrt2_mbcast = mm->add_instruction( auto sqrt2_mbcast = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", input_lens}}), sqrt2); migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), sqrt2);
auto div = mm->add_instruction(migraphx::make_op("div"), x, sqrt2_mbcast); auto div = mm->add_instruction(migraphx::make_op("div"), x, sqrt2_mbcast);
auto erf = mm->add_instruction(migraphx::make_op("erf"), div); auto erf = mm->add_instruction(migraphx::make_op("erf"), div);
auto one_mbcast = mm->add_instruction( auto one_mbcast = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", input_lens}}), one); migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), one);
auto add_one = mm->add_instruction(migraphx::make_op("add"), erf, one_mbcast); auto add_one = mm->add_instruction(migraphx::make_op("add"), erf, one_mbcast);
mm->add_instruction(migraphx::make_op("mul"), mul_half, add_one); mm->add_instruction(migraphx::make_op("mul"), mul_half, add_one);
return p; return p;
......
#include <migraphx/apply_alpha_beta.hpp>
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
...@@ -12,13 +13,13 @@ struct test_gemm_copy : verify_program<test_gemm_copy> ...@@ -12,13 +13,13 @@ struct test_gemm_copy : verify_program<test_gemm_copy>
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape sa{migraphx::shape::float_type, {2, 16}}; migraphx::shape sa{migraphx::shape::float_type, {2, 16}};
migraphx::shape sb{migraphx::shape::float_type, {16, 8}}; migraphx::shape sb{migraphx::shape::float_type, {16, 8}};
migraphx::shape sc{migraphx::shape::float_type, {2, 8}}; migraphx::shape sc{migraphx::shape::float_type, {1, 8}};
auto pa = mm->add_parameter("a", sa); auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb); auto pb = mm->add_parameter("b", sb);
auto pc = mm->add_parameter("c", sc); auto pc = mm->add_parameter("c", sc);
auto dr = mm->add_instruction(migraphx::make_op("dot"), pa, pb, pc); auto dr =
migraphx::add_apply_alpha_beta(*mm, {pa, pb, pc}, migraphx::make_op("dot"), 1.0f, 1.0f);
mm->add_instruction(migraphx::make_op("add"), dr, dr); mm->add_instruction(migraphx::make_op("add"), dr, dr);
return p; return p;
} }
}; };
...@@ -12,7 +12,7 @@ struct test_gemm_transposea : verify_program<test_gemm_transposea> ...@@ -12,7 +12,7 @@ struct test_gemm_transposea : verify_program<test_gemm_transposea>
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {5, 4}}); auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {5, 4}});
auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {5, 3}}); auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {5, 3}});
auto at = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), a); auto at = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), a);
mm->add_instruction(migraphx::make_op("dot"), at, b); mm->add_instruction(migraphx::make_op("dot"), at, b);
return p; return p;
} }
......
...@@ -12,7 +12,8 @@ struct test_gemm_transposea_ex : verify_program<test_gemm_transposea_ex> ...@@ -12,7 +12,8 @@ struct test_gemm_transposea_ex : verify_program<test_gemm_transposea_ex>
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 4}}); auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 4}});
auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 3}}); auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 3}});
auto at = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), a); auto at =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), a);
mm->add_instruction(migraphx::make_op("dot"), at, b); mm->add_instruction(migraphx::make_op("dot"), at, b);
return p; return p;
} }
......
...@@ -12,8 +12,8 @@ struct test_gemm_transposeab : verify_program<test_gemm_transposeab> ...@@ -12,8 +12,8 @@ struct test_gemm_transposeab : verify_program<test_gemm_transposeab>
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {5, 4}}); auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {5, 4}});
auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {3, 5}}); auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {3, 5}});
auto at = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), a); auto at = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), a);
auto bt = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), b); auto bt = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), b);
mm->add_instruction(migraphx::make_op("dot"), at, bt); mm->add_instruction(migraphx::make_op("dot"), at, bt);
return p; return p;
} }
......
...@@ -12,7 +12,7 @@ struct test_gemm_transposeb : verify_program<test_gemm_transposeb> ...@@ -12,7 +12,7 @@ struct test_gemm_transposeb : verify_program<test_gemm_transposeb>
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {4, 5}}); auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {4, 5}});
auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {3, 5}}); auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {3, 5}});
auto bt = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), b); auto bt = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), b);
mm->add_instruction(migraphx::make_op("dot"), a, bt); mm->add_instruction(migraphx::make_op("dot"), a, bt);
return p; return p;
} }
......
...@@ -12,7 +12,8 @@ struct test_gemm_transposeb_ex : verify_program<test_gemm_transposeb_ex> ...@@ -12,7 +12,8 @@ struct test_gemm_transposeb_ex : verify_program<test_gemm_transposeb_ex>
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {1, 4, 5}}); auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {1, 4, 5}});
auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {1, 3, 5}}); auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {1, 3, 5}});
auto bt = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 1}}}), b); auto bt =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), b);
mm->add_instruction(migraphx::make_op("dot"), a, bt); mm->add_instruction(migraphx::make_op("dot"), a, bt);
return p; return p;
} }
......
...@@ -13,7 +13,7 @@ struct test_global_avg_pooling : verify_program<test_global_avg_pooling> ...@@ -13,7 +13,7 @@ struct test_global_avg_pooling : verify_program<test_global_avg_pooling>
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto input = auto input =
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
auto op = migraphx::op::pooling{"average"}; auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average};
auto lens = input->get_shape().lens(); auto lens = input->get_shape().lens();
op.lengths = {lens[2], lens[3]}; op.lengths = {lens[2], lens[3]};
mm->add_instruction(op, input); mm->add_instruction(op, input);
......
...@@ -13,7 +13,7 @@ struct test_global_max_pooling : verify_program<test_global_max_pooling> ...@@ -13,7 +13,7 @@ struct test_global_max_pooling : verify_program<test_global_max_pooling>
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto input = auto input =
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
auto op = migraphx::op::pooling{"max"}; auto op = migraphx::op::pooling{migraphx::op::pooling_mode::max};
auto lens = input->get_shape().lens(); auto lens = input->get_shape().lens();
op.lengths = {lens[2], lens[3]}; op.lengths = {lens[2], lens[3]};
mm->add_instruction(op, input); mm->add_instruction(op, input);
......
...@@ -13,9 +13,9 @@ struct test_greater_brcst : verify_program<test_greater_brcst> ...@@ -13,9 +13,9 @@ struct test_greater_brcst : verify_program<test_greater_brcst>
migraphx::shape s0{migraphx::shape::float_type, {3, 3}}; migraphx::shape s0{migraphx::shape::float_type, {3, 3}};
auto l0 = mm->add_parameter("x", s0); auto l0 = mm->add_parameter("x", s0);
migraphx::shape s1{migraphx::shape::float_type, {3, 1}}; migraphx::shape s1{migraphx::shape::float_type, {3, 1}};
auto l1 = mm->add_parameter("y", s1); auto l1 = mm->add_parameter("y", s1);
auto bl1 = mm->add_instruction( auto bl1 =
migraphx::make_op("multibroadcast", {{"output_lens", s0.lens()}}), l1); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s0.lens()}}), l1);
auto r = mm->add_instruction(migraphx::make_op("greater"), l0, bl1); auto r = mm->add_instruction(migraphx::make_op("greater"), l0, bl1);
mm->add_return({r}); mm->add_return({r});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment