Unverified Commit 369b9f60 authored by kahmed10's avatar kahmed10 Committed by GitHub
Browse files

Add range op (#529)



* add range op and test

* add float test

* use generate for vector

* formatting

* fix delta arg check

* fix assert

* formatting
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 0079028a
......@@ -116,6 +116,7 @@ struct onnx_parser
add_mem_op("MatMulInteger", &onnx_parser::parse_matmul<op::quant_dot>);
add_mem_op("MaxPool", &onnx_parser::parse_pooling);
add_mem_op("OneHot", &onnx_parser::parse_onehot);
add_mem_op("Range", &onnx_parser::parse_range);
add_mem_op("ReduceL1", &onnx_parser::parse_reduce_l1);
add_mem_op("ReduceL2", &onnx_parser::parse_reduce_l2);
add_mem_op("ReduceLogSum", &onnx_parser::parse_reduce_log_sum);
......@@ -1989,6 +1990,47 @@ struct onnx_parser
return l0;
}
instruction_ref
parse_range(const std::string&, const node_info&, std::vector<instruction_ref> args)
{
auto start_arg = args[0]->eval();
check_arg_empty(start_arg, "PARSE_RANGE: start arg dynamic shape is not supported");
auto limit_arg = args[1]->eval();
check_arg_empty(limit_arg, "PARSE_RANGE: limit arg dynamic shape is not supported");
auto delta_arg = args[2]->eval();
check_arg_empty(delta_arg, "PARSE_RANGE: delta arg dynamic shape is not supported");
assert(args[0]->get_shape().elements() == 1 and args[1]->get_shape().elements() == 1 and
args[2]->get_shape().elements() == 1);
instruction_ref l0;
visit_all(start_arg, limit_arg, delta_arg)([&](auto start, auto limit, auto delta) {
auto start_val = start.front();
auto limit_val = limit.front();
auto delta_val = delta.front();
size_t num_elements = static_cast<size_t>(
ceil(static_cast<double>(limit_val - start_val) / static_cast<double>(delta_val)));
assert(num_elements > 0);
using type = decltype(start_val);
std::vector<type> range_vals(num_elements);
std::generate(range_vals.begin(), range_vals.end(), [&]() {
auto result = start_val;
start_val += delta_val;
return result;
});
l0 = prog.add_literal({shape{args[0]->get_shape().type(), {num_elements}}, range_vals});
});
return l0;
}
void parse_from(std::istream& is)
{
onnx::ModelProto model;
......
......@@ -1554,6 +1554,92 @@ def prelu_brcst_test():
return ([node], [arg0, arg1], [arg_out])
@onnx_test
def range_test():
start_val = np.array([10])
limit_val = np.array([6])
delta_val = np.array([-3])
start_tensor = helper.make_tensor(name='start_val',
data_type=TensorProto.INT64,
dims=start_val.reshape(()).shape,
vals=start_val.astype(np.int64))
start = onnx.helper.make_node('Constant',
inputs=[],
outputs=['start'],
value=start_tensor)
limit_tensor = helper.make_tensor(name='limit_val',
data_type=TensorProto.INT64,
dims=limit_val.reshape(()).shape,
vals=limit_val.astype(np.int64))
limit = onnx.helper.make_node('Constant',
inputs=[],
outputs=['limit'],
value=limit_tensor)
delta_tensor = helper.make_tensor(name='delta_val',
data_type=TensorProto.INT64,
dims=delta_val.reshape(()).shape,
vals=delta_val.astype(np.int64))
delta = onnx.helper.make_node('Constant',
inputs=[],
outputs=['delta'],
value=delta_tensor)
node = onnx.helper.make_node('Range',
inputs=['start', 'limit', 'delta'],
outputs=['1'])
y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3])
return ([start, limit, delta, node], [], [y])
@onnx_test
def range_float_test():
start_val = np.array([2])
limit_val = np.array([11])
delta_val = np.array([2])
start_tensor = helper.make_tensor(name='start_val',
data_type=TensorProto.FLOAT,
dims=start_val.reshape(()).shape,
vals=start_val.astype(np.float))
start = onnx.helper.make_node('Constant',
inputs=[],
outputs=['start'],
value=start_tensor)
limit_tensor = helper.make_tensor(name='limit_val',
data_type=TensorProto.FLOAT,
dims=limit_val.reshape(()).shape,
vals=limit_val.astype(np.float))
limit = onnx.helper.make_node('Constant',
inputs=[],
outputs=['limit'],
value=limit_tensor)
delta_tensor = helper.make_tensor(name='delta_val',
data_type=TensorProto.FLOAT,
dims=delta_val.reshape(()).shape,
vals=delta_val.astype(np.float))
delta = onnx.helper.make_node('Constant',
inputs=[],
outputs=['delta'],
value=delta_tensor)
node = onnx.helper.make_node('Range',
inputs=['start', 'limit', 'delta'],
outputs=['1'])
y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3])
return ([start, limit, delta, node], [], [y])
@onnx_test
def recip_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3])
......
......@@ -1228,6 +1228,32 @@ TEST_CASE(prelu_brcst_test)
EXPECT(p == prog);
}
TEST_CASE(range_test)
{
migraphx::program p;
p.add_literal(int64_t{10});
p.add_literal(int64_t{6});
p.add_literal(int64_t{-3});
p.add_literal(migraphx::literal{{migraphx::shape::int64_type, {2}}, {10, 7}});
auto prog = optimize_onnx("range_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(range_float_test)
{
migraphx::program p;
p.add_literal(float{2});
p.add_literal(float{11});
p.add_literal(float{2});
p.add_literal(migraphx::literal{{migraphx::shape::float_type, {5}}, {2, 4, 6, 8, 10}});
auto prog = optimize_onnx("range_float_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(recip_test)
{
migraphx::program p;
......

range_test:
/start"Constant*
value*:
B start_val
/limit"Constant*
value*:B limit_val
8delta"Constant*%
value*:
B delta_val

start
limit
delta1"Range
range_testb
1

B
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment