Unverified Commit 6cca6343 authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Reimplement where op (#676)



* reimplement the where op to avoid inf value issue

* clang format

* fixed a bug in a unit test

* clang format

* clang format
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
Co-authored-by: default avatarPaul Fultz II <pfultz2@yahoo.com>
parent fe23d31c
......@@ -2751,13 +2751,45 @@ struct onnx_parser
instruction_ref
parse_where(const std::string&, const node_info&, std::vector<instruction_ref> args)
{
auto type = args[1]->get_shape().type();
// the operation of if cond == 1 select x; else select y,
// is equivalent to cond * (x - y) + y
auto cond = mm->add_instruction(make_op("convert", {{"target_type", type}}), args[0]);
auto diff = add_broadcastable_binary_op(args[1], args[2], "sub");
auto cd = add_broadcastable_binary_op(diff, cond, "mul");
return add_broadcastable_binary_op(cd, args[2], "add");
auto cond =
mm->add_instruction(make_op("convert", {{"target_type", shape::int32_type}}), args[0]);
auto lens = compute_broadcasted_lens(cond->get_shape().lens(), args[1]->get_shape().lens());
lens = compute_broadcasted_lens(lens, args[2]->get_shape().lens());
if(cond->get_shape().lens() != lens)
{
cond = mm->add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), cond);
}
if(args[1]->get_shape().lens() != lens)
{
args[1] =
mm->add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), args[1]);
}
if(args[2]->get_shape().lens() != lens)
{
args[2] =
mm->add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), args[2]);
}
// compute index
auto elem_num = args[1]->get_shape().elements();
// concatenation of input data
auto concat_data = mm->add_instruction(make_op("concat", {{"axis", 0}}), args[2], args[1]);
std::vector<int64_t> dims = {static_cast<int64_t>(2 * elem_num)};
auto rsp_data = mm->add_instruction(make_op("reshape", {{"dims", dims}}), concat_data);
std::vector<int> ind(elem_num);
std::iota(ind.begin(), ind.end(), 0);
shape ind_s{shape::int32_type, lens};
auto l_ind = mm->add_literal(literal(ind_s, ind));
std::vector<int> offset(elem_num, elem_num);
auto l_offset = mm->add_literal(literal({shape::int32_type, lens}, offset));
auto ins_offset = mm->add_instruction(make_op("mul"), l_offset, cond);
auto ins_ind = mm->add_instruction(make_op("add"), ins_offset, l_ind);
return mm->add_instruction(make_op("gather", {{"axis", 0}}), rsp_data, ins_ind);
}
void parse_from(std::istream& is, std::string name = "")
......
......@@ -1993,8 +1993,7 @@ TEST_CASE(reshape_non_standard_test)
TEST_CASE(resize_downsample_f_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto* mm = p.get_main_module();
std::vector<float> ds = {1.0f, 1.0f, 0.6f, 0.6f};
migraphx::shape ss{migraphx::shape::float_type, {4}};
mm->add_literal(migraphx::literal{ss, ds});
......@@ -2659,18 +2658,30 @@ TEST_CASE(variable_batch_leq_zero_test)
TEST_CASE(where_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto lc = mm->add_parameter("c", migraphx::shape{migraphx::shape::bool_type, {2}});
auto lx = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2, 2}});
auto ly = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 1, 2, 2}});
auto lcc = mm->add_instruction(migraphx::op::convert{migraphx::shape::float_type}, lc);
auto lxm = mm->add_instruction(migraphx::op::multibroadcast{{2, 2, 2, 2}}, lx);
auto lym = mm->add_instruction(migraphx::op::multibroadcast{{2, 2, 2, 2}}, ly);
auto lxy = mm->add_instruction(migraphx::op::sub{}, lxm, lym);
auto lccm = mm->add_instruction(migraphx::op::multibroadcast{{2, 2, 2, 2}}, lcc);
auto lm = mm->add_instruction(migraphx::op::mul{}, lxy, lccm);
auto lym1 = mm->add_instruction(migraphx::op::multibroadcast{{2, 2, 2, 2}}, ly);
auto r = mm->add_instruction(migraphx::op::add{}, lm, lym1);
auto* mm = p.get_main_module();
auto lc = mm->add_parameter("c", migraphx::shape{migraphx::shape::bool_type, {2}});
auto lx = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2, 2}});
auto ly = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 1, 2, 2}});
auto int_c = mm->add_instruction(migraphx::op::convert{migraphx::shape::int32_type}, lc);
auto lccm = mm->add_instruction(migraphx::op::multibroadcast{{2, 2, 2, 2}}, int_c);
auto lxm = mm->add_instruction(migraphx::op::multibroadcast{{2, 2, 2, 2}}, lx);
auto lym = mm->add_instruction(migraphx::op::multibroadcast{{2, 2, 2, 2}}, ly);
auto concat_data = mm->add_instruction(migraphx::op::concat{0}, lym, lxm);
auto rsp_data = mm->add_instruction(migraphx::op::reshape{{32}}, concat_data);
std::vector<int> offset(16, 16);
std::vector<int> ind(16);
std::iota(ind.begin(), ind.end(), 0);
migraphx::shape ind_s{migraphx::shape::int32_type, {2, 2, 2, 2}};
auto lind = mm->add_literal(migraphx::literal(ind_s, ind));
auto loffset = mm->add_literal(migraphx::literal(ind_s, offset));
auto ins_co = mm->add_instruction(migraphx::op::mul{}, loffset, lccm);
auto ins_ind = mm->add_instruction(migraphx::op::add{}, ins_co, lind);
auto r = mm->add_instruction(migraphx::op::gather{0}, rsp_data, ins_ind);
mm->add_return({r});
auto prog = migraphx::parse_onnx("where_test.onnx");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment