Unverified Commit 0928c6cb authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Onehot operator (#510)

* new onehot implemenation by combining other operators

* clang format

* backup

* update unit tests

* clang format
parent f0e530f0
...@@ -1842,33 +1842,46 @@ struct onnx_parser ...@@ -1842,33 +1842,46 @@ struct onnx_parser
parse_onehot(const std::string&, node_info info, std::vector<instruction_ref> args) parse_onehot(const std::string&, node_info info, std::vector<instruction_ref> args)
{ {
migraphx::argument depth_arg = args[1]->eval(); migraphx::argument depth_arg = args[1]->eval();
check_arg_empty(depth_arg, "ONEHOT: depth - dynamic shape not supported"); check_arg_empty(depth_arg, "PARSE_ONEHOT: depth - dynamic shape not supported");
size_t depth = depth_arg.at<size_t>(); size_t depth = depth_arg.at<size_t>();
int64_t axis = -1; int64_t axis = -1;
std::vector<float> on_off_vals; if(contains(info.attributes, "axis"))
{
migraphx::argument values_arg = args[2]->eval(); axis = info.attributes.at("axis").i();
check_arg_empty(values_arg, "ONEHOT: values - dynamic shape not supported"); }
values_arg.visit([&](auto v) { copy(v, std::back_inserter(on_off_vals)); });
float off_value = on_off_vals[0];
float on_value = on_off_vals[1];
std::vector<float> depth_input(depth * depth, off_value); std::vector<float> depth_input(depth * depth, 0.0f);
for(int i = 0; i < depth; i++) for(int i = 0; i < depth; i++)
{ {
depth_input[depth * i + i] = on_value; depth_input[depth * i + i] = 1.0f;
} }
if(contains(info.attributes, "axis")) auto type = args[2]->get_shape().type();
axis = info.attributes.at("axis").i(); shape s{type, {depth, depth}};
if(axis == -1) auto l_val = prog.add_literal({s, depth_input});
auto gather_out = prog.add_instruction(op::gather{0}, {l_val, args[0]});
// Finally, we need a transpose to move the inner most dim to the axis dim
int n_rank = gather_out->get_shape().lens().size();
if(axis < -n_rank or axis >= n_rank)
{ {
shape s{shape::float_type, {depth, depth}}; MIGRAPHX_THROW("PARSE_ONEHOT: axis out of range");
auto l0 = prog.add_literal({s, depth_input});
return prog.add_instruction(op::gather{0}, {l0, args[0]});
} }
MIGRAPHX_THROW("ONEHOT: MIGraphX does not support axis != -1"); int64_t tuned_axis = (axis < 0) ? axis + n_rank : axis;
std::vector<int64_t> perm(n_rank - 1);
std::iota(perm.begin(), perm.end(), 0);
perm.insert(perm.begin() + tuned_axis, n_rank - 1);
auto tr_out = prog.add_instruction(op::transpose{perm}, gather_out);
auto lens = tr_out->get_shape().lens();
auto off_val = prog.add_instruction(op::slice{{0}, {0}, {1}}, args[2]);
auto on_val = prog.add_instruction(op::slice{{0}, {1}, {2}}, args[2]);
auto diff = prog.add_instruction(op::sub{}, on_val, off_val);
auto unsq_off_val = prog.add_instruction(op::multibroadcast{lens}, off_val);
auto unsq_diff_val = prog.add_instruction(op::multibroadcast{lens}, diff);
auto l_mul = prog.add_instruction(op::mul{}, tr_out, unsq_diff_val);
return prog.add_instruction(op::add{}, l_mul, unsq_off_val);
} }
void parse_from(std::istream& is) void parse_from(std::istream& is)
......
...@@ -1412,35 +1412,24 @@ def no_pad_test(): ...@@ -1412,35 +1412,24 @@ def no_pad_test():
@onnx_test @onnx_test
def onehot_test(): def onehot_test():
y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [5, 2]) axis_value = 0
indices = np.ones((5)) depth = np.array([3])
axis_value = -1 indices = helper.make_tensor_value_info("indices", TensorProto.INT32,
on_value = 1.0 [5, 2])
off_value = 0.0 values = helper.make_tensor_value_info("values", TensorProto.FLOAT16, [2])
values = np.array([off_value, on_value]) y = helper.make_tensor_value_info('y', TensorProto.FLOAT16, [3, 5, 2])
depth = np.array([2])
indices_tensor = helper.make_tensor(name="indices",
data_type=TensorProto.INT32,
dims=indices.shape,
vals=indices.astype(int))
depth_tensor = helper.make_tensor(name="depth", depth_tensor = helper.make_tensor(name="depth",
data_type=TensorProto.INT32, data_type=TensorProto.INT32,
dims=None, dims=None,
vals=depth.astype(int)) vals=depth.astype(int))
values_tensor = helper.make_tensor(name="values",
data_type=TensorProto.FLOAT,
dims=values.shape,
vals=values.astype(float))
node = onnx.helper.make_node('OneHot', node = onnx.helper.make_node('OneHot',
inputs=['indices', 'depth', 'values'], inputs=['indices', 'depth', 'values'],
outputs=['y'], outputs=['y'],
axis=axis_value) axis=axis_value)
return ([node], [], [y], [indices_tensor, depth_tensor, values_tensor]) return ([node], [indices, values], [y], [depth_tensor])
@onnx_test @onnx_test
......
...@@ -1100,15 +1100,26 @@ TEST_CASE(no_pad_test) ...@@ -1100,15 +1100,26 @@ TEST_CASE(no_pad_test)
TEST_CASE(onehot_test) TEST_CASE(onehot_test)
{ {
migraphx::program p; migraphx::program p;
auto l0 = p.add_literal( migraphx::shape s_ind{migraphx::shape::int32_type, {5, 2}};
migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {5}}, {1, 1, 1, 1, 1}}); migraphx::shape s_val{migraphx::shape::half_type, {2}};
p.add_literal(2); p.add_literal(3);
p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type, {2}}, {0, 1}}); auto l_ind = p.add_parameter("indices", s_ind);
auto l1 = p.add_literal( auto l_val = p.add_parameter("values", s_val);
migraphx::literal{migraphx::shape{migraphx::shape::float_type, {2, 2}}, {1, 0, 0, 1}}); migraphx::shape s_dep{migraphx::shape::half_type, {3, 3}};
int axis = 0; std::vector<float> data_dep{1, 0, 0, 0, 1, 0, 0, 0, 1};
p.add_instruction(migraphx::op::gather{axis}, l1, l0); auto l_dep = p.add_literal(migraphx::literal(s_dep, data_dep));
auto prog = optimize_onnx("onehot_test.onnx"); auto gather_out = p.add_instruction(migraphx::op::gather{0}, l_dep, l_ind);
auto tr_out = p.add_instruction(migraphx::op::transpose{{2, 0, 1}}, gather_out);
auto off_val = p.add_instruction(migraphx::op::slice{{0}, {0}, {1}}, l_val);
auto on_val = p.add_instruction(migraphx::op::slice{{0}, {1}, {2}}, l_val);
auto diff = p.add_instruction(migraphx::op::sub{}, on_val, off_val);
auto mb_off_val = p.add_instruction(migraphx::op::multibroadcast{{3, 5, 2}}, off_val);
auto mb_diff = p.add_instruction(migraphx::op::multibroadcast{{3, 5, 2}}, diff);
auto mul = p.add_instruction(migraphx::op::mul{}, tr_out, mb_diff);
auto r = p.add_instruction(migraphx::op::add{}, mul, mb_off_val);
p.add_return({r});
auto prog = migraphx::parse_onnx("onehot_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment