Commit 52db0473 authored by Khalique's avatar Khalique
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into scalar_parsing

parents fb4b631d ca69c190
...@@ -63,7 +63,7 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const ...@@ -63,7 +63,7 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
auto l_weights = p.add_literal({weights.get_shape(), new_weights.data()}); auto l_weights = p.add_literal({weights.get_shape(), new_weights.data()});
auto l_bias = p.add_literal({new_bias.get_shape(), new_bias.data()}); auto l_bias = p.add_literal({new_bias.get_shape(), new_bias.data()});
auto c = p.replace_instruction(conv_ins, conv_op, {conv_ins->inputs()[0], l_weights}); auto c = p.replace_instruction(conv_ins, conv_op, {conv_ins->inputs()[0], l_weights});
auto b = p.insert_instruction(ins, op::broadcast{1, c->get_shape()}, l_bias); auto b = p.insert_instruction(ins, op::broadcast{1, c->get_shape().lens()}, l_bias);
p.replace_instruction(ins, op::add{}, {c, b}); p.replace_instruction(ins, op::add{}, {c, b});
} }
} }
......
...@@ -27,38 +27,36 @@ namespace op { ...@@ -27,38 +27,36 @@ namespace op {
struct broadcast struct broadcast
{ {
uint64_t axis = 0; uint64_t axis = 0;
std::vector<std::size_t> broadcast_lens;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
return pack(f(self.axis, "axis")); return pack(f(self.axis, "axis"), f(self.broadcast_lens, "dims"));
} }
shape broadcast_shape;
std::string name() const { return "broadcast"; } std::string name() const { return "broadcast"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
auto t = inputs.at(0).type(); auto t = inputs.at(0).type();
auto input = inputs.at(0); auto input = inputs.at(0);
std::vector<size_t> bcast_strides(broadcast_shape.lens().size(), 0); std::vector<size_t> bcast_strides(broadcast_lens.size(), 0);
if(std::all_of(broadcast_shape.lens().cbegin(), broadcast_shape.lens().cend(), [&](auto x) { if(std::all_of(
return x == 1; broadcast_lens.cbegin(), broadcast_lens.cend(), [&](auto x) { return x == 1; }))
}))
{ {
if(axis != 0) if(axis != 0)
MIGRAPHX_THROW("when broadcasting tensor of size 1, axis should be 0"); MIGRAPHX_THROW("BROADCAST: when broadcasting tensor of size 1, axis should be 0");
return {t, broadcast_shape.lens(), std::move(bcast_strides)}; return {t, broadcast_lens, std::move(bcast_strides)};
} }
else else
{ {
assert(broadcast_shape.lens().size() - axis >= input.lens().size()); assert(broadcast_lens.size() - axis >= input.lens().size());
if(!std::equal( if(!std::equal(input.lens().begin(), input.lens().end(), broadcast_lens.begin() + axis))
input.lens().begin(), input.lens().end(), broadcast_shape.lens().begin() + axis)) MIGRAPHX_THROW("BROADCAST: when broadcasting success sizes must match");
MIGRAPHX_THROW("when broadcasting success sizes must match");
std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin() + axis); std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin() + axis);
return {t, broadcast_shape.lens(), std::move(bcast_strides)}; return {t, broadcast_lens, std::move(bcast_strides)};
} }
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
......
...@@ -18,7 +18,13 @@ namespace op { ...@@ -18,7 +18,13 @@ namespace op {
struct scalar struct scalar
{ {
shape scalar_bcast; std::vector<std::size_t> scalar_bcast_lens;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.scalar_bcast_lens, "scalar_bcst_dims"));
}
std::string name() const { return "scalar"; } std::string name() const { return "scalar"; }
...@@ -26,8 +32,8 @@ struct scalar ...@@ -26,8 +32,8 @@ struct scalar
{ {
assert(check_shapes{inputs}.has(1).only_dims(1).size() == 1); assert(check_shapes{inputs}.has(1).only_dims(1).size() == 1);
auto t = inputs.at(0).type(); auto t = inputs.at(0).type();
std::vector<std::size_t> strides(scalar_bcast.lens().size(), 0); std::vector<std::size_t> strides(scalar_bcast_lens.size(), 0);
return {t, scalar_bcast.lens(), strides}; return {t, scalar_bcast_lens, strides};
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
......
...@@ -141,8 +141,8 @@ struct onnx_parser ...@@ -141,8 +141,8 @@ struct onnx_parser
if(broadcasted != 0) if(broadcasted != 0)
{ {
uint64_t axis = parse_value(attributes.at("axis")).at<uint64_t>(); uint64_t axis = parse_value(attributes.at("axis")).at<uint64_t>();
auto l = auto l = prog.add_instruction(op::broadcast{axis, args[0]->get_shape().lens()},
prog.add_instruction(op::broadcast{axis, args[0]->get_shape()}, args[1]); args[1]);
return prog.add_instruction(x, args[0], l); return prog.add_instruction(x, args[0], l);
} }
return prog.add_instruction(x, args); return prog.add_instruction(x, args);
...@@ -306,7 +306,7 @@ struct onnx_parser ...@@ -306,7 +306,7 @@ struct onnx_parser
{ {
uint64_t axis = 1; uint64_t axis = 1;
auto l1 = prog.add_instruction(op, args[0], args[1]); auto l1 = prog.add_instruction(op, args[0], args[1]);
auto l2 = prog.add_instruction(op::broadcast{axis, l1->get_shape()}, args[2]); auto l2 = prog.add_instruction(op::broadcast{axis, l1->get_shape().lens()}, args[2]);
return prog.add_instruction(op::add{}, l1, l2); return prog.add_instruction(op::add{}, l1, l2);
} }
return prog.add_instruction(op, l0, args[1]); return prog.add_instruction(op, l0, args[1]);
...@@ -671,15 +671,15 @@ struct onnx_parser ...@@ -671,15 +671,15 @@ struct onnx_parser
auto&& bias_floats = attributes["bias"].floats(); auto&& bias_floats = attributes["bias"].floats();
bias = std::vector<float>(bias_floats.begin(), bias_floats.end()); bias = std::vector<float>(bias_floats.begin(), bias_floats.end());
} }
auto input_shape = args.front()->get_shape(); auto input_lens = args.front()->get_shape().lens();
auto scale_val = prog.add_literal(scale); auto scale_val = prog.add_literal(scale);
auto bias_vals = prog.add_literal( auto bias_vals = prog.add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::float_type, {bias.size()}}, bias}); migraphx::literal{migraphx::shape{migraphx::shape::float_type, {bias.size()}}, bias});
auto scale_tensor = prog.add_instruction(migraphx::op::scalar{input_shape}, scale_val); auto scale_tensor = prog.add_instruction(migraphx::op::scalar{input_lens}, scale_val);
auto img_scaled = prog.add_instruction(migraphx::op::mul{}, args.front(), scale_tensor); auto img_scaled = prog.add_instruction(migraphx::op::mul{}, args.front(), scale_tensor);
auto bias_bcast = prog.add_instruction(migraphx::op::broadcast{1, input_shape}, bias_vals); auto bias_bcast = prog.add_instruction(migraphx::op::broadcast{1, input_lens}, bias_vals);
return prog.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast); return prog.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast);
} }
......
...@@ -214,7 +214,7 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward, ...@@ -214,7 +214,7 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
auto wb = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias); auto wb = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
auto rb = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias); auto rb = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
auto b = prog.insert_instruction(ins, op::add{}, wb, rb); auto b = prog.insert_instruction(ins, op::add{}, wb, rb);
bias = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, b); bias = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape().lens()}, b);
} }
instruction_ref hidden_out = prog.end(); instruction_ref hidden_out = prog.end();
...@@ -521,25 +521,26 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward, ...@@ -521,25 +521,26 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
instruction_ref brcst_bh{}; instruction_ref brcst_bh{};
if(bias != prog.end()) if(bias != prog.end())
{ {
auto broadcast_lens = sih->get_shape().lens();
auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias); auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
auto wbz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias); auto wbz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
auto wbr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias); auto wbr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
auto wbh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sbias); auto wbh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sbias);
brcst_wbh = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, wbh); brcst_wbh = prog.insert_instruction(ins, op::broadcast{1, broadcast_lens}, wbh);
auto rbz = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sbias); auto rbz = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sbias);
auto rbr = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, sbias); auto rbr = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, sbias);
auto rbh = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias); auto rbh = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
brcst_rbh = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, rbh); brcst_rbh = prog.insert_instruction(ins, op::broadcast{1, broadcast_lens}, rbh);
auto bz = prog.insert_instruction(ins, op::add{}, wbz, rbz); auto bz = prog.insert_instruction(ins, op::add{}, wbz, rbz);
brcst_bz = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, bz); brcst_bz = prog.insert_instruction(ins, op::broadcast{1, broadcast_lens}, bz);
auto br = prog.insert_instruction(ins, op::add{}, wbr, rbr); auto br = prog.insert_instruction(ins, op::add{}, wbr, rbr);
brcst_br = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, br); brcst_br = prog.insert_instruction(ins, op::broadcast{1, broadcast_lens}, br);
auto bh = prog.insert_instruction(ins, op::add{}, wbh, rbh); auto bh = prog.insert_instruction(ins, op::add{}, wbh, rbh);
brcst_bh = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, bh); brcst_bh = prog.insert_instruction(ins, op::broadcast{1, broadcast_lens}, bh);
} }
for(long i = 0; i < seq_len; i++) for(long i = 0; i < seq_len; i++)
...@@ -947,7 +948,7 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -947,7 +948,7 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
// initial cell state // initial cell state
auto sic = prog.insert_instruction(ins, op::squeeze{{0}}, ic); auto sic = prog.insert_instruction(ins, op::squeeze{{0}}, ic);
auto ic_shape = sic->get_shape(); auto ic_lens = sic->get_shape().lens();
// bias // bias
instruction_ref bi_brcst{}; instruction_ref bi_brcst{};
...@@ -956,26 +957,27 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -956,26 +957,27 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
instruction_ref bc_brcst{}; instruction_ref bc_brcst{};
if(bias != prog.end()) if(bias != prog.end())
{ {
auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias); auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
auto bxi = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias); auto bxi = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
auto bhi = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, sbias); auto bhi = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, sbias);
auto bi = prog.insert_instruction(ins, op::add{}, bxi, bhi); auto bi = prog.insert_instruction(ins, op::add{}, bxi, bhi);
bi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, bi); bi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, bi);
auto bxo = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias); auto bxo = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
auto bho = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias); auto bho = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
auto bo = prog.insert_instruction(ins, op::add{}, bxo, bho); auto bo = prog.insert_instruction(ins, op::add{}, bxo, bho);
bo_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, bo); bo_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, bo);
auto bxf = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sbias); auto bxf = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sbias);
auto bhf = prog.insert_instruction(ins, op::slice{{0}, {6 * hs}, {7 * hs}}, sbias); auto bhf = prog.insert_instruction(ins, op::slice{{0}, {6 * hs}, {7 * hs}}, sbias);
auto bf = prog.insert_instruction(ins, op::add{}, bxf, bhf); auto bf = prog.insert_instruction(ins, op::add{}, bxf, bhf);
bf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, bf); bf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, bf);
auto bxc = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sbias); auto bxc = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sbias);
auto bhc = prog.insert_instruction(ins, op::slice{{0}, {7 * hs}, {8 * hs}}, sbias); auto bhc = prog.insert_instruction(ins, op::slice{{0}, {7 * hs}, {8 * hs}}, sbias);
auto bc = prog.insert_instruction(ins, op::add{}, bxc, bhc); auto bc = prog.insert_instruction(ins, op::add{}, bxc, bhc);
bc_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, bc); bc_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, bc);
} }
// peep hole // peep hole
...@@ -987,13 +989,13 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -987,13 +989,13 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
{ {
auto spph = prog.insert_instruction(ins, op::squeeze{{0}}, pph); auto spph = prog.insert_instruction(ins, op::squeeze{{0}}, pph);
auto pphi = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, spph); auto pphi = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, spph);
pphi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, pphi); pphi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, pphi);
auto ppho = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, spph); auto ppho = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, spph);
ppho_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, ppho); ppho_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, ppho);
auto pphf = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, spph); auto pphf = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, spph);
pphf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, pphf); pphf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, pphf);
} }
for(long i = 0; i < seq_len; ++i) for(long i = 0; i < seq_len; ++i)
......
...@@ -238,7 +238,7 @@ struct tf_parser ...@@ -238,7 +238,7 @@ struct tf_parser
parse_biasadd(const std::string&, const attribute_map&, std::vector<instruction_ref> args) parse_biasadd(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{ {
uint64_t axis = 1; // assume output of previous layer is in NCHW (broadcast on channel) uint64_t axis = 1; // assume output of previous layer is in NCHW (broadcast on channel)
auto l0 = prog.add_instruction(op::broadcast{axis, args[0]->get_shape()}, args[1]); auto l0 = prog.add_instruction(op::broadcast{axis, args[0]->get_shape().lens()}, args[1]);
return prog.add_instruction(op::add{}, args[0], l0); return prog.add_instruction(op::add{}, args[0], l0);
} }
......
...@@ -60,7 +60,7 @@ TEST_CASE(after_literal_broadcast) ...@@ -60,7 +60,7 @@ TEST_CASE(after_literal_broadcast)
auto l2 = p.add_literal(get_2()); auto l2 = p.add_literal(get_2());
EXPECT(p.get_shape().standard()); EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().broadcasted()); EXPECT(not p.get_shape().broadcasted());
auto b = p.add_instruction(migraphx::op::broadcast{0, l1->get_shape()}, l2); auto b = p.add_instruction(migraphx::op::broadcast{0, l1->get_shape().lens()}, l2);
p.add_instruction(pass_op{}, b); p.add_instruction(pass_op{}, b);
EXPECT(not p.get_shape().standard()); EXPECT(not p.get_shape().standard());
EXPECT(p.get_shape().broadcasted()); EXPECT(p.get_shape().broadcasted());
...@@ -91,7 +91,7 @@ TEST_CASE(after_param_broadcast) ...@@ -91,7 +91,7 @@ TEST_CASE(after_param_broadcast)
auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {2}}); auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {2}});
EXPECT(p.get_shape().standard()); EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().broadcasted()); EXPECT(not p.get_shape().broadcasted());
auto b = p.add_instruction(migraphx::op::broadcast{0, l1->get_shape()}, l2); auto b = p.add_instruction(migraphx::op::broadcast{0, l1->get_shape().lens()}, l2);
p.add_instruction(pass_op{}, b); p.add_instruction(pass_op{}, b);
EXPECT(not p.get_shape().standard()); EXPECT(not p.get_shape().standard());
EXPECT(p.get_shape().broadcasted()); EXPECT(p.get_shape().broadcasted());
......
...@@ -351,7 +351,7 @@ TEST_CASE(gemm_mutli_dim1_2_3) ...@@ -351,7 +351,7 @@ TEST_CASE(gemm_mutli_dim1_2_3)
float beta = 0.41; float beta = 0.41;
auto m12_alpha = p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2); auto m12_alpha = p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2);
auto l_beta = p.add_literal(beta); auto l_beta = p.add_literal(beta);
auto b_beta = p.add_instruction(migraphx::op::scalar{m12_alpha->get_shape()}, l_beta); auto b_beta = p.add_instruction(migraphx::op::scalar{m12_alpha->get_shape().lens()}, l_beta);
auto m3_beta = p.add_instruction(migraphx::op::mul{}, b_beta, l3); auto m3_beta = p.add_instruction(migraphx::op::mul{}, b_beta, l3);
p.add_instruction(migraphx::op::add{}, m3_beta, m12_alpha); p.add_instruction(migraphx::op::add{}, m3_beta, m12_alpha);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
......
...@@ -651,7 +651,7 @@ TEST_CASE(broadcast_test) ...@@ -651,7 +651,7 @@ TEST_CASE(broadcast_test)
uint64_t axis = 0; uint64_t axis = 0;
auto l1 = p.add_literal(migraphx::literal{a_shape, a_data}); auto l1 = p.add_literal(migraphx::literal{a_shape, a_data});
auto l2 = p.add_literal(migraphx::literal{b_shape, b_data}); auto l2 = p.add_literal(migraphx::literal{b_shape, b_data});
p.add_instruction(migraphx::op::broadcast{axis, l1->get_shape()}, l2); p.add_instruction(migraphx::op::broadcast{axis, l1->get_shape().lens()}, l2);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
auto output = result.get<int32_t>(); auto output = result.get<int32_t>();
...@@ -671,7 +671,7 @@ TEST_CASE(add_broadcast_test) ...@@ -671,7 +671,7 @@ TEST_CASE(add_broadcast_test)
uint64_t axis = 0; uint64_t axis = 0;
auto l1 = p.add_literal(migraphx::literal{a_shape, a_data}); auto l1 = p.add_literal(migraphx::literal{a_shape, a_data});
auto l2 = p.add_literal(migraphx::literal{b_shape, b_data}); auto l2 = p.add_literal(migraphx::literal{b_shape, b_data});
auto l3 = p.add_instruction(migraphx::op::broadcast{axis, l1->get_shape()}, l2); auto l3 = p.add_instruction(migraphx::op::broadcast{axis, l1->get_shape().lens()}, l2);
p.add_instruction(migraphx::op::add{}, l1, l3); p.add_instruction(migraphx::op::add{}, l1, l3);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
...@@ -809,11 +809,11 @@ TEST_CASE(imagescaler_test) ...@@ -809,11 +809,11 @@ TEST_CASE(imagescaler_test)
0.35, 0.35,
0.45}}); 0.45}});
auto scale_val = p.add_literal(2.f); auto scale_val = p.add_literal(2.f);
auto scaled_tensor = p.add_instruction(migraphx::op::scalar{s}, scale_val); auto scaled_tensor = p.add_instruction(migraphx::op::scalar{s.lens()}, scale_val);
auto img_scaled = p.add_instruction(migraphx::op::mul{}, img, scaled_tensor); auto img_scaled = p.add_instruction(migraphx::op::mul{}, img, scaled_tensor);
auto bias_vals = p.add_literal( auto bias_vals = p.add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3}}, {0.01, 0.02, 0.03}}); migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3}}, {0.01, 0.02, 0.03}});
auto bias_bcast = p.add_instruction(migraphx::op::broadcast{1, s}, bias_vals); auto bias_bcast = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, bias_vals);
p.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast); p.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
......
...@@ -371,7 +371,7 @@ struct test_scale : verify_program<test_scale> ...@@ -371,7 +371,7 @@ struct test_scale : verify_program<test_scale>
migraphx::shape s{migraphx::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", s); auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", migraphx::shape::float_type); auto y = p.add_parameter("y", migraphx::shape::float_type);
auto scale = p.add_instruction(migraphx::op::scalar{s}, y); auto scale = p.add_instruction(migraphx::op::scalar{s.lens()}, y);
p.add_instruction(migraphx::op::mul{}, x, scale); p.add_instruction(migraphx::op::mul{}, x, scale);
return p; return p;
} }
...@@ -417,7 +417,7 @@ struct test_triadd2 : verify_program<test_triadd2> ...@@ -417,7 +417,7 @@ struct test_triadd2 : verify_program<test_triadd2>
auto x = p.add_parameter("x", s); auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s); auto y = p.add_parameter("y", s);
auto z = p.add_parameter("z", b); auto z = p.add_parameter("z", b);
auto zb = p.add_instruction(migraphx::op::broadcast{1, s}, z); auto zb = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, z);
auto sum = p.add_instruction(migraphx::op::add{}, x, y); auto sum = p.add_instruction(migraphx::op::add{}, x, y);
p.add_instruction(migraphx::op::add{}, sum, zb); p.add_instruction(migraphx::op::add{}, sum, zb);
return p; return p;
...@@ -432,7 +432,7 @@ struct test_add_broadcast : verify_program<test_add_broadcast> ...@@ -432,7 +432,7 @@ struct test_add_broadcast : verify_program<test_add_broadcast>
migraphx::shape s{migraphx::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 2, 3}}); auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 2, 3}});
auto y = p.add_parameter("y", {migraphx::shape::float_type, {2, 2}}); auto y = p.add_parameter("y", {migraphx::shape::float_type, {2, 2}});
auto by = p.add_instruction(migraphx::op::broadcast{0, x->get_shape()}, y); auto by = p.add_instruction(migraphx::op::broadcast{0, x->get_shape().lens()}, y);
p.add_instruction(migraphx::op::add{}, x, by); p.add_instruction(migraphx::op::add{}, x, by);
return p; return p;
} }
...@@ -446,7 +446,7 @@ struct test_add_broadcast2 : verify_program<test_add_broadcast2> ...@@ -446,7 +446,7 @@ struct test_add_broadcast2 : verify_program<test_add_broadcast2>
migraphx::shape s{migraphx::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 3, 4}}); auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 3, 4}});
auto y = p.add_parameter("y", {migraphx::shape::float_type, {3}}); auto y = p.add_parameter("y", {migraphx::shape::float_type, {3}});
auto by = p.add_instruction(migraphx::op::broadcast{1, x->get_shape()}, y); auto by = p.add_instruction(migraphx::op::broadcast{1, x->get_shape().lens()}, y);
p.add_instruction(migraphx::op::add{}, x, by); p.add_instruction(migraphx::op::add{}, x, by);
return p; return p;
} }
...@@ -460,7 +460,7 @@ struct test_add_broadcast3 : verify_program<test_add_broadcast3> ...@@ -460,7 +460,7 @@ struct test_add_broadcast3 : verify_program<test_add_broadcast3>
migraphx::shape s{migraphx::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 4, 5}}); auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 4, 5}});
auto y = p.add_parameter("y", {migraphx::shape::float_type, {4}}); auto y = p.add_parameter("y", {migraphx::shape::float_type, {4}});
auto by = p.add_instruction(migraphx::op::broadcast{1, x->get_shape()}, y); auto by = p.add_instruction(migraphx::op::broadcast{1, x->get_shape().lens()}, y);
p.add_instruction(migraphx::op::add{}, x, by); p.add_instruction(migraphx::op::add{}, x, by);
return p; return p;
} }
...@@ -474,7 +474,7 @@ struct test_add_broadcast4 : verify_program<test_add_broadcast4> ...@@ -474,7 +474,7 @@ struct test_add_broadcast4 : verify_program<test_add_broadcast4>
migraphx::shape s{migraphx::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 3, 5}}); auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 3, 5}});
auto y = p.add_parameter("y", {migraphx::shape::float_type, {3}}); auto y = p.add_parameter("y", {migraphx::shape::float_type, {3}});
auto by = p.add_instruction(migraphx::op::broadcast{1, x->get_shape()}, y); auto by = p.add_instruction(migraphx::op::broadcast{1, x->get_shape().lens()}, y);
p.add_instruction(migraphx::op::add{}, x, by); p.add_instruction(migraphx::op::add{}, x, by);
return p; return p;
} }
...@@ -488,7 +488,7 @@ struct test_add_broadcast5 : verify_program<test_add_broadcast5> ...@@ -488,7 +488,7 @@ struct test_add_broadcast5 : verify_program<test_add_broadcast5>
migraphx::shape s{migraphx::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 4, 8}}); auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 4, 8}});
auto y = p.add_parameter("y", {migraphx::shape::float_type, {4}}); auto y = p.add_parameter("y", {migraphx::shape::float_type, {4}});
auto by = p.add_instruction(migraphx::op::broadcast{1, x->get_shape()}, y); auto by = p.add_instruction(migraphx::op::broadcast{1, x->get_shape().lens()}, y);
p.add_instruction(migraphx::op::add{}, x, by); p.add_instruction(migraphx::op::add{}, x, by);
return p; return p;
} }
...@@ -503,7 +503,7 @@ struct test_triadd_broadcast : verify_program<test_triadd_broadcast> ...@@ -503,7 +503,7 @@ struct test_triadd_broadcast : verify_program<test_triadd_broadcast>
auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 2, 3}}); auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 2, 3}});
auto y = p.add_parameter("y", {migraphx::shape::float_type, {2, 2}}); auto y = p.add_parameter("y", {migraphx::shape::float_type, {2, 2}});
auto z = p.add_parameter("z", {migraphx::shape::float_type, {2, 2, 3}}); auto z = p.add_parameter("z", {migraphx::shape::float_type, {2, 2, 3}});
auto by = p.add_instruction(migraphx::op::broadcast{0, x->get_shape()}, y); auto by = p.add_instruction(migraphx::op::broadcast{0, x->get_shape().lens()}, y);
auto sum = p.add_instruction(migraphx::op::add{}, x, by); auto sum = p.add_instruction(migraphx::op::add{}, x, by);
p.add_instruction(migraphx::op::add{}, sum, z); p.add_instruction(migraphx::op::add{}, sum, z);
return p; return p;
...@@ -535,7 +535,7 @@ struct test_sub2 : verify_program<test_sub2> ...@@ -535,7 +535,7 @@ struct test_sub2 : verify_program<test_sub2>
auto x = p.add_parameter("x", s); auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s); auto y = p.add_parameter("y", s);
auto z = p.add_parameter("z", b); auto z = p.add_parameter("z", b);
auto zb = p.add_instruction(migraphx::op::broadcast{1, s}, z); auto zb = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, z);
auto diff = p.add_instruction(migraphx::op::sub{}, x, y); auto diff = p.add_instruction(migraphx::op::sub{}, x, y);
p.add_instruction(migraphx::op::sub{}, diff, zb); p.add_instruction(migraphx::op::sub{}, diff, zb);
return p; return p;
......
...@@ -15,7 +15,7 @@ TEST_CASE(pytorch_conv_bias_test) ...@@ -15,7 +15,7 @@ TEST_CASE(pytorch_conv_bias_test)
auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {1}}); auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {1}});
uint64_t axis = 1; uint64_t axis = 1;
auto l3 = p.add_instruction(migraphx::op::convolution{}, l0, l1); auto l3 = p.add_instruction(migraphx::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape()}, l2); auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape().lens()}, l2);
p.add_instruction(migraphx::op::add{}, l3, l4); p.add_instruction(migraphx::op::add{}, l3, l4);
auto prog = migraphx::parse_onnx("conv.onnx"); auto prog = migraphx::parse_onnx("conv.onnx");
...@@ -30,7 +30,7 @@ TEST_CASE(pytorch_conv_relu_maxpool) ...@@ -30,7 +30,7 @@ TEST_CASE(pytorch_conv_relu_maxpool)
auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {1}}); auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {1}});
uint64_t axis = 1; uint64_t axis = 1;
auto l3 = p.add_instruction(migraphx::op::convolution{}, l0, l1); auto l3 = p.add_instruction(migraphx::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape()}, l2); auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape().lens()}, l2);
auto l5 = p.add_instruction(migraphx::op::add{}, l3, l4); auto l5 = p.add_instruction(migraphx::op::add{}, l3, l4);
auto l6 = p.add_instruction(migraphx::op::relu{}, l5); auto l6 = p.add_instruction(migraphx::op::relu{}, l5);
p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l6); p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l6);
...@@ -52,7 +52,7 @@ TEST_CASE(pytorch_conv_bn_relu_maxpool) ...@@ -52,7 +52,7 @@ TEST_CASE(pytorch_conv_bn_relu_maxpool)
auto p6 = p.add_parameter("6", {migraphx::shape::float_type, {1}}); auto p6 = p.add_parameter("6", {migraphx::shape::float_type, {1}});
uint64_t axis = 1; uint64_t axis = 1;
auto l3 = p.add_instruction(migraphx::op::convolution{}, l0, l1); auto l3 = p.add_instruction(migraphx::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape()}, l2); auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape().lens()}, l2);
auto l5 = p.add_instruction(migraphx::op::add{}, l3, l4); auto l5 = p.add_instruction(migraphx::op::add{}, l3, l4);
auto l6 = p.add_instruction(migraphx::op::batch_norm_inference{1.0e-5f}, l5, p3, p4, p5, p6); auto l6 = p.add_instruction(migraphx::op::batch_norm_inference{1.0e-5f}, l5, p3, p4, p5, p6);
auto l7 = p.add_instruction(migraphx::op::relu{}, l6); auto l7 = p.add_instruction(migraphx::op::relu{}, l6);
...@@ -70,7 +70,7 @@ TEST_CASE(pytorch_conv_relu_maxpool_x2) ...@@ -70,7 +70,7 @@ TEST_CASE(pytorch_conv_relu_maxpool_x2)
auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {5}}); auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {5}});
uint64_t axis = 1; uint64_t axis = 1;
auto l3 = p.add_instruction(migraphx::op::convolution{}, l0, l1); auto l3 = p.add_instruction(migraphx::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape()}, l2); auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape().lens()}, l2);
auto l5 = p.add_instruction(migraphx::op::add{}, l3, l4); auto l5 = p.add_instruction(migraphx::op::add{}, l3, l4);
auto l6 = p.add_instruction(migraphx::op::relu{}, l5); auto l6 = p.add_instruction(migraphx::op::relu{}, l5);
auto l7 = p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l6); auto l7 = p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l6);
...@@ -78,7 +78,7 @@ TEST_CASE(pytorch_conv_relu_maxpool_x2) ...@@ -78,7 +78,7 @@ TEST_CASE(pytorch_conv_relu_maxpool_x2)
auto l8 = p.add_parameter("3", {migraphx::shape::float_type, {1, 5, 5, 5}}); auto l8 = p.add_parameter("3", {migraphx::shape::float_type, {1, 5, 5, 5}});
auto l9 = p.add_parameter("4", {migraphx::shape::float_type, {1}}); auto l9 = p.add_parameter("4", {migraphx::shape::float_type, {1}});
auto l10 = p.add_instruction(migraphx::op::convolution{}, l7, l8); auto l10 = p.add_instruction(migraphx::op::convolution{}, l7, l8);
auto l11 = p.add_instruction(migraphx::op::broadcast{axis, l10->get_shape()}, l9); auto l11 = p.add_instruction(migraphx::op::broadcast{axis, l10->get_shape().lens()}, l9);
auto l12 = p.add_instruction(migraphx::op::add{}, l10, l11); auto l12 = p.add_instruction(migraphx::op::add{}, l10, l11);
auto l13 = p.add_instruction(migraphx::op::relu{}, l12); auto l13 = p.add_instruction(migraphx::op::relu{}, l12);
p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l13); p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l13);
...@@ -108,9 +108,9 @@ TEST_CASE(imagescaler_test) ...@@ -108,9 +108,9 @@ TEST_CASE(imagescaler_test)
auto scale_val = p.add_literal(0.5f); auto scale_val = p.add_literal(0.5f);
auto bias_vals = p.add_literal( auto bias_vals = p.add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3}}, {0.01, 0.02, 0.03}}); migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3}}, {0.01, 0.02, 0.03}});
auto scaled_tensor = p.add_instruction(migraphx::op::scalar{s}, scale_val); auto scaled_tensor = p.add_instruction(migraphx::op::scalar{s.lens()}, scale_val);
auto img_scaled = p.add_instruction(migraphx::op::mul{}, l0, scaled_tensor); auto img_scaled = p.add_instruction(migraphx::op::mul{}, l0, scaled_tensor);
auto bias_bcast = p.add_instruction(migraphx::op::broadcast{1, s}, bias_vals); auto bias_bcast = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, bias_vals);
p.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast); p.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast);
auto prog = migraphx::parse_onnx("imagescaler_test.onnx"); auto prog = migraphx::parse_onnx("imagescaler_test.onnx");
...@@ -338,7 +338,7 @@ TEST_CASE(add_bcast_test) ...@@ -338,7 +338,7 @@ TEST_CASE(add_bcast_test)
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}}); auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}});
auto l2 = p.add_instruction(migraphx::op::broadcast{1, l0->get_shape()}, l1); auto l2 = p.add_instruction(migraphx::op::broadcast{1, l0->get_shape().lens()}, l1);
p.add_instruction(migraphx::op::add{}, l0, l2); p.add_instruction(migraphx::op::add{}, l0, l2);
auto prog = migraphx::parse_onnx("add_bcast_test.onnx"); auto prog = migraphx::parse_onnx("add_bcast_test.onnx");
...@@ -365,7 +365,7 @@ TEST_CASE(sub_bcast_test) ...@@ -365,7 +365,7 @@ TEST_CASE(sub_bcast_test)
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}}); auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}});
auto l2 = p.add_instruction(migraphx::op::broadcast{1, l0->get_shape()}, l1); auto l2 = p.add_instruction(migraphx::op::broadcast{1, l0->get_shape().lens()}, l1);
p.add_instruction(migraphx::op::sub{}, l0, l2); p.add_instruction(migraphx::op::sub{}, l0, l2);
auto prog = migraphx::parse_onnx("sub_bcast_test.onnx"); auto prog = migraphx::parse_onnx("sub_bcast_test.onnx");
......
...@@ -229,6 +229,36 @@ TEST_CASE(multibroadcast) ...@@ -229,6 +229,36 @@ TEST_CASE(multibroadcast)
} }
} }
TEST_CASE(broadcast)
{
{
std::vector<std::size_t> lens{1, 1};
migraphx::shape input{migraphx::shape::float_type, {4, 1, 3}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1}, {0, 0}},
migraphx::op::broadcast{0, lens},
input);
}
{
std::vector<std::size_t> lens{1, 1};
migraphx::shape input{migraphx::shape::float_type, {4, 1, 3}};
throws_shape(migraphx::op::broadcast{1, lens}, input);
}
{
std::vector<std::size_t> lens{3, 2, 4, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 3}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 2, 4, 3}, {0, 0, 3, 1}},
migraphx::op::broadcast{2, lens},
input);
}
{
std::vector<std::size_t> lens{3, 2, 4, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 4}};
throws_shape(migraphx::op::broadcast{2, lens}, input);
}
}
TEST_CASE(gather) TEST_CASE(gather)
{ {
{ {
......
...@@ -63,7 +63,7 @@ TEST_CASE(biasadd_test) ...@@ -63,7 +63,7 @@ TEST_CASE(biasadd_test)
uint64_t axis = 1; uint64_t axis = 1;
auto l0 = p.add_parameter("0", s0); auto l0 = p.add_parameter("0", s0);
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {500}}); auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {500}});
auto l2 = p.add_instruction(migraphx::op::broadcast{axis, l0->get_shape()}, l1); auto l2 = p.add_instruction(migraphx::op::broadcast{axis, l0->get_shape().lens()}, l1);
p.add_instruction(migraphx::op::add{}, l0, l2); p.add_instruction(migraphx::op::add{}, l0, l2);
auto prog = migraphx::parse_tf("biasadd_test.pb", true); auto prog = migraphx::parse_tf("biasadd_test.pb", true);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment