Unverified Commit 39dca6ac authored by kahmed10, committed by GitHub

Embedding bag op (limited support) (#545)

* initial progress on embedding_bag

* formatting

* add test

* move enum

* formatting

* fix tidy

* improve test for more coverage

* formatting

* update arg and test

* formatting

* add more tests

* formatting

* fix enum

* formatting
parent cb722cf9
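
The new `parse_embedding_bag` handler lowers an ATen `embedding_bag` node to a `gather` on the embedding table followed by a reduction over axis 0 selected by the node's `mode` attribute, and it only accepts an offsets input with a single element (a single bag). A rough NumPy sketch of the intended semantics is below; the array values are illustrative, and only the shapes mirror the `embedding_bag_test` fixture added in this change.

import numpy as np

# Gather the indexed rows of the embedding table, then reduce them over axis 0;
# this is what the parser emits as op::gather{} followed by reduce_{sum,mean,max}{{0}}.
weight = np.arange(8, dtype=np.float32).reshape(4, 2)  # embedding table, shape [4, 2]
index = np.array([1, 0, 2], dtype=np.int32)            # one bag of three indices
offsets = np.array([0], dtype=np.int32)                # a single offset, the only supported case
assert offsets.size == 1                               # parse_embedding_bag throws otherwise

gathered = weight[index]                         # shape [3, 2]
out_sum = gathered.sum(axis=0, keepdims=True)    # mode=0 -> reduce_sum,  shape [1, 2]
out_mean = gathered.mean(axis=0, keepdims=True)  # mode=1 -> reduce_mean, shape [1, 2]
out_max = gathered.max(axis=0, keepdims=True)    # mode=2 -> reduce_max,  shape [1, 2]

The [1, 2] output shapes match the y1/y2/y3 value infos in the test generator below, since the MIGraphX reductions keep the reduced axis as a dimension of size 1.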
@@ -86,6 +86,7 @@ struct onnx_parser
add_variadic_op("Max", op::max{});
add_variadic_op("Min", op::min{});
add_mem_op("ATen", &onnx_parser::parse_aten);
add_mem_op("AveragePool", &onnx_parser::parse_pooling);
add_mem_op("ArgMax", &onnx_parser::parse_arg_op<op::argmax>);
add_mem_op("ArgMin", &onnx_parser::parse_arg_op<op::argmin>);
@@ -112,10 +113,12 @@ struct onnx_parser
add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu);
add_mem_op("LogSoftmax", &onnx_parser::parse_softmax<op::logsoftmax>);
add_mem_op("LRN", &onnx_parser::parse_lrn);
add_mem_op("LSTM", &onnx_parser::parse_lstm);
add_mem_op("MatMul", &onnx_parser::parse_matmul<op::dot>);
add_mem_op("MatMulInteger", &onnx_parser::parse_matmul<op::quant_dot>);
add_mem_op("MaxPool", &onnx_parser::parse_pooling);
add_mem_op("OneHot", &onnx_parser::parse_onehot);
add_mem_op("Pad", &onnx_parser::parse_pad);
add_mem_op("Range", &onnx_parser::parse_range);
add_mem_op("ReduceL1", &onnx_parser::parse_reduce_l1);
add_mem_op("ReduceL2", &onnx_parser::parse_reduce_l2);
@@ -129,7 +132,6 @@ struct onnx_parser
add_mem_op("ReduceSumSquare", &onnx_parser::parse_reduce_sum_square);
add_mem_op("Reshape", &onnx_parser::parse_reshape);
add_mem_op("RNN", &onnx_parser::parse_rnn);
add_mem_op("Pad", &onnx_parser::parse_pad);
add_mem_op("Shape", &onnx_parser::parse_shape);
add_mem_op("Slice", &onnx_parser::parse_slice);
add_mem_op("Softmax", &onnx_parser::parse_softmax<op::softmax>);
@@ -138,7 +140,6 @@ struct onnx_parser
add_mem_op("Tile", &onnx_parser::parse_tile);
add_mem_op("Transpose", &onnx_parser::parse_transpose);
add_mem_op("Unsqueeze", &onnx_parser::parse_unsqueeze);
add_mem_op("LSTM", &onnx_parser::parse_lstm);
// init the activation function map
init_actv_func();
@@ -2105,6 +2106,47 @@ struct onnx_parser
        return l0;
    }

    enum class reduce_mode_t
    {
        sum  = 0,
        mean = 1,
        max  = 2
    };

    instruction_ref parse_embedding_bag(const node_info& info, std::vector<instruction_ref> args)
    {
        // Only a single bag is supported: the offsets input must contain exactly one element.
        if(args[2]->get_shape().elements() != 1)
            MIGRAPHX_THROW("PARSE_EMBEDDING_BAG: MIGraphX only supports offsets of size 1");
        // The ATen "mode" attribute selects the reduction; the default is sum.
        reduce_mode_t reduce_mode = reduce_mode_t::sum;
        if(contains(info.attributes, "mode"))
        {
            reduce_mode = static_cast<reduce_mode_t>(info.attributes.at("mode").i());
        }
        // Gather the selected rows of the weight matrix, then reduce them along axis 0.
        auto l0 = prog.add_instruction(op::gather{}, args[0], args[1]);
        switch(reduce_mode)
        {
        case reduce_mode_t::sum: l0 = prog.add_instruction(op::reduce_sum{{0}}, l0); break;
        case reduce_mode_t::mean: l0 = prog.add_instruction(op::reduce_mean{{0}}, l0); break;
        case reduce_mode_t::max: l0 = prog.add_instruction(op::reduce_max{{0}}, l0); break;
        }
        return l0;
    }

    instruction_ref
    parse_aten(const std::string&, const node_info& info, std::vector<instruction_ref> args)
    {
        // Dispatch on the ATen "operator" attribute; only embedding_bag is handled.
        if(contains(info.attributes, "operator"))
        {
            auto op_name = info.attributes.at("operator").s();
            if(op_name.find("embedding_bag") != std::string::npos)
            {
                return parse_embedding_bag(info, std::move(args));
            }
        }
        MIGRAPHX_THROW("PARSE_ATEN: unsupported custom operator");
    }

    void parse_from(std::istream& is)
    {
        onnx::ModelProto model;
......
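
For anyone exercising this path outside the unit tests: a model produced by the generator below can also be loaded through the Python bindings, assuming the `migraphx` Python module is built and the generated .onnx file is in the working directory (this mirrors the C++ test below rather than adding new behavior).

import migraphx

# Parses the ATen embedding_bag model into gather + reduce instructions.
prog = migraphx.parse_onnx("embedding_bag_test.onnx")
print(prog)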
@@ -941,6 +941,98 @@ def elu_test():
    return ([node], [x], [y])


@onnx_test
def embedding_bag_test():
    index_val = np.array([1, 0, 2])
    offset_val = np.array([0])
    index_tensor = helper.make_tensor(name='index_val',
                                      data_type=TensorProto.INT32,
                                      dims=index_val.shape,
                                      vals=index_val.astype(np.int32))

    index = onnx.helper.make_node('Constant',
                                  inputs=[],
                                  outputs=['index'],
                                  value=index_tensor)

    offset_tensor = helper.make_tensor(name='offset_val',
                                       data_type=TensorProto.INT32,
                                       dims=offset_val.reshape(()).shape,
                                       vals=offset_val.astype(np.int32))

    offset = onnx.helper.make_node('Constant',
                                   inputs=[],
                                   outputs=['offset'],
                                   value=offset_tensor)

    weight = helper.make_tensor_value_info('weight', TensorProto.FLOAT, [4, 2])
    y1 = helper.make_tensor_value_info('y1', TensorProto.FLOAT, [1, 2])
    y2 = helper.make_tensor_value_info('y2', TensorProto.FLOAT, [1, 2])
    y3 = helper.make_tensor_value_info('y3', TensorProto.FLOAT, [1, 2])

    # One ATen embedding_bag node per reduction mode: 0 = sum, 1 = mean, 2 = max.
    node1 = onnx.helper.make_node('ATen',
                                  inputs=['weight', 'index', 'offset'],
                                  outputs=['y1'],
                                  mode=0,
                                  operator='embedding_bag')

    node2 = onnx.helper.make_node('ATen',
                                  inputs=['weight', 'index', 'offset'],
                                  outputs=['y2'],
                                  mode=1,
                                  operator='embedding_bag')

    node3 = onnx.helper.make_node('ATen',
                                  inputs=['weight', 'index', 'offset'],
                                  outputs=['y3'],
                                  mode=2,
                                  operator='embedding_bag')

    return ([index, offset, node1, node2, node3], [weight], [y1, y2, y3])

@onnx_test
def embedding_bag_offset_test():
    # Two offsets (two bags); the parser only supports a single bag, so
    # parsing this model is expected to fail.
    index_val = np.array([1, 0])
    offset_val = np.array([0, 1])
    index_tensor = helper.make_tensor(name='index_val',
                                      data_type=TensorProto.INT32,
                                      dims=index_val.shape,
                                      vals=index_val.astype(np.int32))

    index = onnx.helper.make_node('Constant',
                                  inputs=[],
                                  outputs=['index'],
                                  value=index_tensor)

    offset_tensor = helper.make_tensor(name='offset_val',
                                       data_type=TensorProto.INT32,
                                       dims=offset_val.shape,
                                       vals=offset_val.astype(np.int32))

    offset = onnx.helper.make_node('Constant',
                                   inputs=[],
                                   outputs=['offset'],
                                   value=offset_tensor)

    weight = helper.make_tensor_value_info('weight', TensorProto.FLOAT, [2, 3])
    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 3])

    node = onnx.helper.make_node('ATen',
                                 inputs=['weight', 'index', 'offset'],
                                 outputs=['y'],
                                 mode=0,
                                 operator='embedding_bag')

    return ([index, offset, node], [weight], [y])

@onnx_test
def erf_test():
    x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 15])
@@ -2477,6 +2569,23 @@ def unknown_test():
    return ([node, node2], [x, y], [a])


@onnx_test
def unknown_aten_test():
    x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3, 4, 5])
    y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3, 4])
    helper.make_tensor_value_info('2', TensorProto.FLOAT, [2, 3, 4, 5])
    a = helper.make_tensor_value_info('3', TensorProto.FLOAT, [2, 3, 4, 5])
    # ATen node with an unrecognized "operator" attribute; the parser should throw.
    node = onnx.helper.make_node('ATen',
                                 inputs=['0', '1'],
                                 outputs=['2'],
                                 operator='unknown')

    return ([node], [x, y], [a])


@onnx_test
def variable_batch_test():
    x = helper.make_tensor_value_info('0', TensorProto.FLOAT,
......
@@ -685,6 +685,31 @@ TEST_CASE(elu_test)
    EXPECT(p == prog);
}

TEST_CASE(embedding_bag_test)
{
    migraphx::program p;
    auto l0 = p.add_parameter("weight", migraphx::shape{migraphx::shape::float_type, {4, 2}});
    migraphx::literal l{migraphx::shape{migraphx::shape::int32_type, {3}}, {1, 0, 2}};
    auto l1 = p.add_literal(l);
    // offset literal from the Constant node; the parser checks it but does not use it
    p.add_literal(0);
    auto l4 = p.add_instruction(migraphx::op::gather{}, l0, l1);
    auto r1 = p.add_instruction(migraphx::op::reduce_sum{{0}}, l4);
    auto l5 = p.add_instruction(migraphx::op::gather{}, l0, l1);
    auto r2 = p.add_instruction(migraphx::op::reduce_mean{{0}}, l5);
    auto l6 = p.add_instruction(migraphx::op::gather{}, l0, l1);
    auto r3 = p.add_instruction(migraphx::op::reduce_max{{0}}, l6);
    p.add_return({r1, r2, r3});

    auto prog = migraphx::parse_onnx("embedding_bag_test.onnx");
    EXPECT(p == prog);
}

TEST_CASE(embedding_bag_offset_test)
{
    // offsets with more than one element are not supported and should throw
    EXPECT(test::throws([&] { migraphx::parse_onnx("embedding_bag_offset_test.onnx"); }));
}

TEST_CASE(erf_test)
{
    migraphx::program p;
@@ -1854,6 +1879,11 @@ TEST_CASE(unknown_test)
    EXPECT(p == prog);
}

TEST_CASE(unknown_aten_test)
{
    // an ATen node with an unrecognized operator attribute should throw
    EXPECT(test::throws([&] { migraphx::parse_onnx("unknown_aten_test.onnx"); }));
}

TEST_CASE(unknown_test_throw)
{
    EXPECT(test::throws([&] { migraphx::parse_onnx("unknown_test.onnx"); }));
......
unknown_aten_test.onnx: new binary ONNX model (graph "unknown_aten_test" containing a single ATen node with operator="unknown"); binary protobuf content not shown.