Commit 4a39a0f7 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into add-conv_bn_add-test

parents 5564172e bb827865
...@@ -16,7 +16,7 @@ struct test_div2 : verify_program<test_div2> ...@@ -16,7 +16,7 @@ struct test_div2 : verify_program<test_div2>
auto y = mm->add_parameter("y", s); auto y = mm->add_parameter("y", s);
auto z = mm->add_parameter("z", b); auto z = mm->add_parameter("z", b);
auto zb = mm->add_instruction( auto zb = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", s.lens()}}), z); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", s.lens()}}), z);
auto diff = mm->add_instruction(migraphx::make_op("div"), x, y); auto diff = mm->add_instruction(migraphx::make_op("div"), x, y);
mm->add_instruction(migraphx::make_op("div"), diff, zb); mm->add_instruction(migraphx::make_op("div"), diff, zb);
return p; return p;
......
...@@ -13,9 +13,9 @@ struct test_equal_brcst : verify_program<test_equal_brcst> ...@@ -13,9 +13,9 @@ struct test_equal_brcst : verify_program<test_equal_brcst>
migraphx::shape s0{migraphx::shape::float_type, {3, 3}}; migraphx::shape s0{migraphx::shape::float_type, {3, 3}};
auto l0 = mm->add_parameter("x", s0); auto l0 = mm->add_parameter("x", s0);
migraphx::shape s1{migraphx::shape::float_type, {3, 1}}; migraphx::shape s1{migraphx::shape::float_type, {3, 1}};
auto l1 = mm->add_parameter("y", s1); auto l1 = mm->add_parameter("y", s1);
auto bl1 = mm->add_instruction( auto bl1 =
migraphx::make_op("multibroadcast", {{"output_lens", s0.lens()}}), l1); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s0.lens()}}), l1);
auto r = mm->add_instruction(migraphx::make_op("equal"), l0, bl1); auto r = mm->add_instruction(migraphx::make_op("equal"), l0, bl1);
mm->add_return({r}); mm->add_return({r});
......
...@@ -16,14 +16,14 @@ struct test_gelu : verify_program<test_gelu> ...@@ -16,14 +16,14 @@ struct test_gelu : verify_program<test_gelu>
auto one = mm->add_literal(1.0f); auto one = mm->add_literal(1.0f);
auto sqrt2 = mm->add_literal(static_cast<float>(M_SQRT2)); auto sqrt2 = mm->add_literal(static_cast<float>(M_SQRT2));
auto half_mbcast = mm->add_instruction( auto half_mbcast = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", input_lens}}), half); migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), half);
auto mul_half = mm->add_instruction(migraphx::make_op("mul"), x, half_mbcast); auto mul_half = mm->add_instruction(migraphx::make_op("mul"), x, half_mbcast);
auto sqrt2_mbcast = mm->add_instruction( auto sqrt2_mbcast = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", input_lens}}), sqrt2); migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), sqrt2);
auto div = mm->add_instruction(migraphx::make_op("div"), x, sqrt2_mbcast); auto div = mm->add_instruction(migraphx::make_op("div"), x, sqrt2_mbcast);
auto erf = mm->add_instruction(migraphx::make_op("erf"), div); auto erf = mm->add_instruction(migraphx::make_op("erf"), div);
auto one_mbcast = mm->add_instruction( auto one_mbcast = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", input_lens}}), one); migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), one);
auto add_one = mm->add_instruction(migraphx::make_op("add"), erf, one_mbcast); auto add_one = mm->add_instruction(migraphx::make_op("add"), erf, one_mbcast);
mm->add_instruction(migraphx::make_op("mul"), mul_half, add_one); mm->add_instruction(migraphx::make_op("mul"), mul_half, add_one);
return p; return p;
......
#include <migraphx/apply_alpha_beta.hpp>
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
...@@ -12,13 +13,13 @@ struct test_gemm_copy : verify_program<test_gemm_copy> ...@@ -12,13 +13,13 @@ struct test_gemm_copy : verify_program<test_gemm_copy>
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape sa{migraphx::shape::float_type, {2, 16}}; migraphx::shape sa{migraphx::shape::float_type, {2, 16}};
migraphx::shape sb{migraphx::shape::float_type, {16, 8}}; migraphx::shape sb{migraphx::shape::float_type, {16, 8}};
migraphx::shape sc{migraphx::shape::float_type, {2, 8}}; migraphx::shape sc{migraphx::shape::float_type, {1, 8}};
auto pa = mm->add_parameter("a", sa); auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb); auto pb = mm->add_parameter("b", sb);
auto pc = mm->add_parameter("c", sc); auto pc = mm->add_parameter("c", sc);
auto dr = mm->add_instruction(migraphx::make_op("dot"), pa, pb, pc); auto dr =
migraphx::add_apply_alpha_beta(*mm, {pa, pb, pc}, migraphx::make_op("dot"), 1.0f, 1.0f);
mm->add_instruction(migraphx::make_op("add"), dr, dr); mm->add_instruction(migraphx::make_op("add"), dr, dr);
return p; return p;
} }
}; };
...@@ -12,7 +12,7 @@ struct test_gemm_transposea : verify_program<test_gemm_transposea> ...@@ -12,7 +12,7 @@ struct test_gemm_transposea : verify_program<test_gemm_transposea>
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {5, 4}}); auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {5, 4}});
auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {5, 3}}); auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {5, 3}});
auto at = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), a); auto at = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), a);
mm->add_instruction(migraphx::make_op("dot"), at, b); mm->add_instruction(migraphx::make_op("dot"), at, b);
return p; return p;
} }
......
...@@ -12,7 +12,8 @@ struct test_gemm_transposea_ex : verify_program<test_gemm_transposea_ex> ...@@ -12,7 +12,8 @@ struct test_gemm_transposea_ex : verify_program<test_gemm_transposea_ex>
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 4}}); auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 4}});
auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 3}}); auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 3}});
auto at = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), a); auto at =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), a);
mm->add_instruction(migraphx::make_op("dot"), at, b); mm->add_instruction(migraphx::make_op("dot"), at, b);
return p; return p;
} }
......
...@@ -12,8 +12,8 @@ struct test_gemm_transposeab : verify_program<test_gemm_transposeab> ...@@ -12,8 +12,8 @@ struct test_gemm_transposeab : verify_program<test_gemm_transposeab>
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {5, 4}}); auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {5, 4}});
auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {3, 5}}); auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {3, 5}});
auto at = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), a); auto at = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), a);
auto bt = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), b); auto bt = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), b);
mm->add_instruction(migraphx::make_op("dot"), at, bt); mm->add_instruction(migraphx::make_op("dot"), at, bt);
return p; return p;
} }
......
...@@ -12,7 +12,7 @@ struct test_gemm_transposeb : verify_program<test_gemm_transposeb> ...@@ -12,7 +12,7 @@ struct test_gemm_transposeb : verify_program<test_gemm_transposeb>
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {4, 5}}); auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {4, 5}});
auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {3, 5}}); auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {3, 5}});
auto bt = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), b); auto bt = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), b);
mm->add_instruction(migraphx::make_op("dot"), a, bt); mm->add_instruction(migraphx::make_op("dot"), a, bt);
return p; return p;
} }
......
...@@ -12,7 +12,8 @@ struct test_gemm_transposeb_ex : verify_program<test_gemm_transposeb_ex> ...@@ -12,7 +12,8 @@ struct test_gemm_transposeb_ex : verify_program<test_gemm_transposeb_ex>
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {1, 4, 5}}); auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {1, 4, 5}});
auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {1, 3, 5}}); auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {1, 3, 5}});
auto bt = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 1}}}), b); auto bt =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), b);
mm->add_instruction(migraphx::make_op("dot"), a, bt); mm->add_instruction(migraphx::make_op("dot"), a, bt);
return p; return p;
} }
......
...@@ -13,9 +13,9 @@ struct test_greater_brcst : verify_program<test_greater_brcst> ...@@ -13,9 +13,9 @@ struct test_greater_brcst : verify_program<test_greater_brcst>
migraphx::shape s0{migraphx::shape::float_type, {3, 3}}; migraphx::shape s0{migraphx::shape::float_type, {3, 3}};
auto l0 = mm->add_parameter("x", s0); auto l0 = mm->add_parameter("x", s0);
migraphx::shape s1{migraphx::shape::float_type, {3, 1}}; migraphx::shape s1{migraphx::shape::float_type, {3, 1}};
auto l1 = mm->add_parameter("y", s1); auto l1 = mm->add_parameter("y", s1);
auto bl1 = mm->add_instruction( auto bl1 =
migraphx::make_op("multibroadcast", {{"output_lens", s0.lens()}}), l1); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s0.lens()}}), l1);
auto r = mm->add_instruction(migraphx::make_op("greater"), l0, bl1); auto r = mm->add_instruction(migraphx::make_op("greater"), l0, bl1);
mm->add_return({r}); mm->add_return({r});
......
...@@ -26,7 +26,8 @@ struct test_if_literal : verify_program<test_if_literal> ...@@ -26,7 +26,8 @@ struct test_if_literal : verify_program<test_if_literal>
else_mod->add_return({l2}); else_mod->add_return({l2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
mm->add_return({ret}); auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({r});
return p; return p;
} }
......
...@@ -27,7 +27,9 @@ struct test_if_lp : verify_program<test_if_lp> ...@@ -27,7 +27,9 @@ struct test_if_lp : verify_program<test_if_lp>
else_mod->add_return({s2, l2}); else_mod->add_return({s2, l2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
mm->add_return({ret}); auto r0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
auto r1 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), ret);
mm->add_return({r0, r1});
return p; return p;
} }
......
...@@ -29,7 +29,8 @@ struct test_if_param : verify_program<test_if_param> ...@@ -29,7 +29,8 @@ struct test_if_param : verify_program<test_if_param>
else_mod->add_return({a2}); else_mod->add_return({a2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
mm->add_return({ret}); auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({r});
return p; return p;
} }
......
...@@ -18,24 +18,24 @@ add_layernorm(migraphx::module& m, migraphx::instruction_ref x, std::vector<size ...@@ -18,24 +18,24 @@ add_layernorm(migraphx::module& m, migraphx::instruction_ref x, std::vector<size
auto mean = m.add_instruction(migraphx::op::reduce_mean({2}), x); auto mean = m.add_instruction(migraphx::op::reduce_mean({2}), x);
auto mean_mbcast = auto mean_mbcast =
m.add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", dims}}), mean); m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), mean);
auto sub = m.add_instruction(migraphx::make_op("sub"), x, mean_mbcast); auto sub = m.add_instruction(migraphx::make_op("sub"), x, mean_mbcast);
auto exponent_mbcast = auto exponent_mbcast =
m.add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", dims}}), exponent); m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), exponent);
auto pow = m.add_instruction(migraphx::make_op("pow"), sub, exponent_mbcast); auto pow = m.add_instruction(migraphx::make_op("pow"), sub, exponent_mbcast);
auto var = m.add_instruction(migraphx::op::reduce_mean({2}), pow); auto var = m.add_instruction(migraphx::op::reduce_mean({2}), pow);
auto epsilon_mbcast = m.add_instruction( auto epsilon_mbcast = m.add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {1, dims.at(1), 1}}}), epsilon); migraphx::make_op("multibroadcast", {{"out_lens", {1, dims.at(1), 1}}}), epsilon);
auto add_epsilon = m.add_instruction(migraphx::make_op("add"), var, epsilon_mbcast); auto add_epsilon = m.add_instruction(migraphx::make_op("add"), var, epsilon_mbcast);
auto sqrt = m.add_instruction(migraphx::make_op("sqrt"), add_epsilon); auto sqrt = m.add_instruction(migraphx::make_op("sqrt"), add_epsilon);
auto sqrt_mbcast = auto sqrt_mbcast =
m.add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", dims}}), sqrt); m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), sqrt);
auto div = m.add_instruction(migraphx::make_op("div"), sub, sqrt_mbcast); auto div = m.add_instruction(migraphx::make_op("div"), sub, sqrt_mbcast);
auto scale_mbcast = auto scale_mbcast =
m.add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", dims}}), scale); m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), scale);
auto mul = m.add_instruction(migraphx::make_op("mul"), scale_mbcast, div); auto mul = m.add_instruction(migraphx::make_op("mul"), scale_mbcast, div);
auto bias_mbcast = auto bias_mbcast =
m.add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", dims}}), bias); m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), bias);
return m.add_instruction(migraphx::make_op("add"), mul, bias_mbcast); return m.add_instruction(migraphx::make_op("add"), mul, bias_mbcast);
} }
...@@ -81,3 +81,20 @@ struct test_layernorm_triadd : verify_program<test_layernorm_triadd> ...@@ -81,3 +81,20 @@ struct test_layernorm_triadd : verify_program<test_layernorm_triadd>
return p; return p;
} }
}; };
struct test_layernorm_triadd_large : verify_program<test_layernorm_triadd_large>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<size_t> dims = {1, 384, 1024};
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, dims});
auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, dims});
auto z = mm->add_parameter("z", migraphx::shape{migraphx::shape::float_type, dims});
auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y);
auto add2 = mm->add_instruction(migraphx::make_op("add"), add1, z);
add_layernorm(*mm, add2, dims);
return p;
}
};
...@@ -13,9 +13,9 @@ struct test_less_brcst : verify_program<test_less_brcst> ...@@ -13,9 +13,9 @@ struct test_less_brcst : verify_program<test_less_brcst>
migraphx::shape s0{migraphx::shape::float_type, {3, 3}}; migraphx::shape s0{migraphx::shape::float_type, {3, 3}};
auto l0 = mm->add_parameter("x", s0); auto l0 = mm->add_parameter("x", s0);
migraphx::shape s1{migraphx::shape::float_type, {3, 1}}; migraphx::shape s1{migraphx::shape::float_type, {3, 1}};
auto l1 = mm->add_parameter("y", s1); auto l1 = mm->add_parameter("y", s1);
auto bl1 = mm->add_instruction( auto bl1 =
migraphx::make_op("multibroadcast", {{"output_lens", s0.lens()}}), l1); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s0.lens()}}), l1);
auto r = mm->add_instruction(migraphx::make_op("less"), l0, bl1); auto r = mm->add_instruction(migraphx::make_op("less"), l0, bl1);
mm->add_return({r}); mm->add_return({r});
......
...@@ -11,8 +11,9 @@ struct test_logsoftmax1 : verify_program<test_logsoftmax1> ...@@ -11,8 +11,9 @@ struct test_logsoftmax1 : verify_program<test_logsoftmax1>
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {5, 3, 3, 4}}); auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {5, 3, 3, 4}});
auto tx = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {2, 3, 0, 1}}}), x); auto tx =
auto r = mm->add_instruction(migraphx::make_op("logsoftmax", {{"axis", 0}}), tx); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 3, 0, 1}}}), x);
auto r = mm->add_instruction(migraphx::make_op("logsoftmax", {{"axis", 0}}), tx);
mm->add_return({r}); mm->add_return({r});
return p; return p;
} }
......
#include "verify_program.hpp"
#include <migraphx/literal.hpp>
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_loop : verify_program<test_loop>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape si{migraphx::shape::int64_type};
migraphx::shape s{migraphx::shape::int64_type, {1}};
migraphx::shape sc{migraphx::shape::bool_type};
int64_t iter_num = 10;
auto in_iter = mm->add_literal(migraphx::literal(si, {iter_num}));
auto in_cond = mm->add_parameter("ccond", sc);
int64_t value = 5;
auto in_val = mm->add_literal(migraphx::literal(s, {value}));
auto* body = p.create_module("loop_module");
auto iter = body->add_parameter("iter_num", si);
body->add_parameter("cond", sc);
auto in_v = body->add_parameter("input", s);
std::vector<int64_t> vd = {3};
auto l = body->add_literal(migraphx::literal(si, vd));
auto ad = body->add_instruction(migraphx::make_op("add"), iter, l);
auto val = body->add_instruction(migraphx::make_op("add"), in_v, ad);
auto eq = body->add_instruction(migraphx::make_op("equal"), iter, l);
auto beq = body->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::bool_type}}), eq);
auto neq = body->add_instruction(migraphx::make_op("not"), beq);
body->add_return({neq, val, val});
auto rl = mm->add_instruction(
migraphx::make_op("loop", {{"max_iterations", 8}}), {in_iter, in_cond, in_val}, {body});
auto r0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), rl);
auto r1 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), rl);
mm->add_return({r0, r1});
return p;
}
};
...@@ -16,9 +16,9 @@ struct test_mul_add : verify_program<test_mul_add> ...@@ -16,9 +16,9 @@ struct test_mul_add : verify_program<test_mul_add>
auto a = mm->add_parameter("a", bs); auto a = mm->add_parameter("a", bs);
auto b = mm->add_parameter("b", bs); auto b = mm->add_parameter("b", bs);
auto ab = mm->add_instruction( auto ab = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", s.lens()}}), a); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", s.lens()}}), a);
auto bb = mm->add_instruction( auto bb = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", s.lens()}}), b); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", s.lens()}}), b);
auto mul = mm->add_instruction(migraphx::make_op("mul"), x, ab); auto mul = mm->add_instruction(migraphx::make_op("mul"), x, ab);
mm->add_instruction(migraphx::make_op("add"), mul, bb); mm->add_instruction(migraphx::make_op("add"), mul, bb);
return p; return p;
......
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_multinomial : verify_program<test_multinomial>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
size_t sample_size = 10;
size_t batch_size = 2;
float seed = 0.0f;
std::mt19937 gen(seed);
std::uniform_real_distribution<> dis(0.0, 1.0);
std::vector<float> rand_samples(batch_size * sample_size);
std::generate(rand_samples.begin(), rand_samples.end(), [&]() { return dis(gen); });
migraphx::shape rs{migraphx::shape::float_type, {batch_size, sample_size}};
auto rs_lit = mm->add_literal(migraphx::literal{rs, rand_samples});
migraphx::shape s{migraphx::shape::float_type, {batch_size, 5}};
auto input = mm->add_parameter("input", s);
auto maxes = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {1}}}), input);
auto mb_maxes = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {batch_size, 5}}}), maxes);
auto cdf = mm->add_instruction(migraphx::make_op("sub"), input, mb_maxes);
cdf = mm->add_instruction(migraphx::make_op("exp"), cdf);
cdf = mm->add_instruction(
migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}), cdf);
mm->add_instruction(migraphx::make_op("multinomial"), cdf, rs_lit);
return p;
}
};
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_nonzero : verify_program<test_nonzero>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 5}};
auto x = mm->add_parameter("data", s);
auto r = mm->add_instruction(migraphx::make_op("nonzero"), x);
mm->add_return({r});
return p;
}
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment