Unverified commit ffde465d authored by kahmed10, committed by GitHub

Add OneHot op (#502)



* add onehot and tests, add missing gen_tf for onehot

* formatting

* fix cast

* specify onehot for error msg
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 10d072b9
@@ -111,6 +111,7 @@ struct onnx_parser
        add_mem_op("MatMul", &onnx_parser::parse_matmul<op::dot>);
        add_mem_op("MatMulInteger", &onnx_parser::parse_matmul<op::quant_dot>);
        add_mem_op("MaxPool", &onnx_parser::parse_pooling);
        add_mem_op("OneHot", &onnx_parser::parse_onehot);
        add_mem_op("ReduceL1", &onnx_parser::parse_reduce_l1);
        add_mem_op("ReduceL2", &onnx_parser::parse_reduce_l2);
        add_mem_op("ReduceLogSum", &onnx_parser::parse_reduce_log_sum);
@@ -1807,6 +1808,39 @@ struct onnx_parser
        return ret_ins;
    }
    instruction_ref
    parse_onehot(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        migraphx::argument depth_arg = args[1]->eval();
        check_arg_empty(depth_arg, "ONEHOT: depth - dynamic shape not supported");
        std::size_t depth = depth_arg.at<std::size_t>();

        int64_t axis = -1;
        std::vector<float> on_off_vals;
        migraphx::argument values_arg = args[2]->eval();
        check_arg_empty(values_arg, "ONEHOT: values - dynamic shape not supported");
        values_arg.visit([&](auto v) { copy(v, std::back_inserter(on_off_vals)); });
        float off_value = on_off_vals[0];
        float on_value  = on_off_vals[1];

        // Build a depth x depth table whose i-th row is the one-hot vector for i.
        std::vector<float> depth_input(depth * depth, off_value);
        for(std::size_t i = 0; i < depth; i++)
        {
            depth_input[depth * i + i] = on_value;
        }

        if(contains(info.attributes, "axis"))
            axis = info.attributes.at("axis").i();
        if(axis == -1)
        {
            shape s{shape::float_type, {depth, depth}};
            auto l0 = prog.add_literal({s, depth_input});
            // Gathering rows of the table by the input indices yields the encoding.
            return prog.add_instruction(op::gather{0}, {l0, args[0]});
        }
        MIGRAPHX_THROW("ONEHOT: MIGraphX does not support axis != -1");
    }
    void parse_from(std::istream& is)
    {
        onnx::ModelProto model;
...
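Note: the lowering above avoids a dedicated OneHot kernel by materializing a depth x depth table of one-hot rows and gathering rows by index. A minimal NumPy sketch of the same construction (the helper name onehot_via_gather is illustrative, not part of the codebase):

import numpy as np

# Build a depth x depth table whose i-th row is the one-hot vector for i,
# then gather rows by the input indices (axis 0), as parse_onehot does.
def onehot_via_gather(indices, depth, on_value=1.0, off_value=0.0):
    table = np.full((depth, depth), off_value, dtype=np.float32)
    np.fill_diagonal(table, on_value)
    return table[indices]

print(onehot_via_gather(np.array([1, 1, 1, 1, 1]), 2))
# -> five rows of [0., 1.], shape (5, 2), matching the tests below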
@@ -1410,6 +1410,39 @@ def no_pad_test():
    return ([node], [x], [y])
@onnx_test
def onehot_test():
    # output value info is named 'y' to match the node's output below
    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [5, 2])
    indices = np.ones((5))
    axis_value = -1
    on_value = 1.0
    off_value = 0.0
    values = np.array([off_value, on_value])
    depth = np.array([2])
    indices_tensor = helper.make_tensor(name='indices',
                                        data_type=TensorProto.INT32,
                                        dims=indices.shape,
                                        vals=indices.astype(int))
    depth_tensor = helper.make_tensor(name='depth',
                                      data_type=TensorProto.INT32,
                                      dims=None,
                                      vals=depth.astype(int))
    values_tensor = helper.make_tensor(name='values',
                                       data_type=TensorProto.FLOAT,
                                       dims=values.shape,
                                       vals=values.astype(float))
    node = onnx.helper.make_node('OneHot',
                                 inputs=['indices', 'depth', 'values'],
                                 outputs=['y'],
                                 axis=axis_value)

    return ([node], [], [y], [indices_tensor, depth_tensor, values_tensor])
@onnx_test
def pad_test():
    x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 2])
...
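For reference, this is what the generated model should compute per the ONNX OneHot spec with axis=-1 and values=[off_value, on_value]; a plain NumPy check, not part of the test suite:

import numpy as np

indices = np.ones(5, dtype=np.int32)
depth = 2
off_value, on_value = 0.0, 1.0

# Start from all-off, then set the indexed position in each row to on.
expected = np.full((indices.size, depth), off_value, dtype=np.float32)
expected[np.arange(indices.size), indices] = on_value
print(expected)  # five rows of [0. 1.], shape (5, 2)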
@@ -1097,6 +1097,22 @@ TEST_CASE(no_pad_test)
    EXPECT(p == prog);
}
TEST_CASE(onehot_test)
{
    migraphx::program p;
    auto l0 = p.add_literal(
        migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {5}}, {1, 1, 1, 1, 1}});
    // depth and values initializers remain as literals in the parsed program
    p.add_literal(2);
    p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type, {2}}, {0, 1}});
    // the diagonal {2, 2} table built by parse_onehot, gathered by the indices
    auto l1 = p.add_literal(
        migraphx::literal{migraphx::shape{migraphx::shape::float_type, {2, 2}}, {1, 0, 0, 1}});
    int axis = 0;
    p.add_instruction(migraphx::op::gather{axis}, l1, l0);
    auto prog = optimize_onnx("onehot_test.onnx");

    EXPECT(p == prog);
}
TEST_CASE(pad_test)
{
    migraphx::program p;
...
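If the build includes the MIGraphX Python bindings, the parsed program can also be inspected directly; a quick sketch, assuming migraphx.parse_onnx is available in your build:

import migraphx

# Should show the diagonal {2, 2} literal followed by gather[axis=0],
# mirroring the expected program in the TEST_CASE above.
p = migraphx.parse_onnx("onehot_test.onnx")
print(p)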
@@ -202,6 +202,13 @@ def mul_test(g1):
        tf.multiply(g1_input, g2_input, name='mul1')
@tf_test
def onehot_test(g1):
    with g1.as_default():
        g1_input = tf.constant((1, 1, 1, 1, 1), dtype=tf.int32)
        tf.one_hot(g1_input, 2, name='onehot1')
@tf_test
def pack_test(g1):
    with g1.as_default():
...
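As a sanity check on the TF side, tf.one_hot with the same arguments produces the matching output; an eager-mode sketch, assuming TF 2.x:

import tensorflow as tf

# indices all 1, depth 2 -> five rows of [0., 1.], shape (5, 2)
print(tf.one_hot([1, 1, 1, 1, 1], 2))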