Unverified commit 63e35438, authored by Charlie Lin, committed by GitHub

Dynamic ConstantOfShape (#2111)

Makes ConstantOfShape work when its dimensions input is variable rather than a compile-time constant
parent 36eaf9e5
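
In outline: when the shape input can be evaluated at parse time, ConstantOfShape is still folded into a literal; when it cannot, the parser now emits an allocate instruction whose output dims are read from the shape tensor at run time, followed by a fill of the scalar default value. Below is a minimal sketch of the program the dynamic path produces, written against the MIGraphX C++ API that the tests further down use; the function name make_dyn_const_of_shape is illustrative, and the float type and fill value 10 are taken from const_of_shape_dyn_float_test.

#include <migraphx/program.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/literal.hpp>

migraphx::program make_dyn_const_of_shape()
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    // runtime shape tensor: three int64 dimension values supplied as a parameter
    auto dims = mm->add_parameter("output_dims",
                                  migraphx::shape{migraphx::shape::int64_type, {3}});
    // allocate an output buffer whose dims are read from output_dims at run time
    auto buf = mm->add_instruction(
        migraphx::make_op("allocate", {{"buf_type", migraphx::shape::float_type}}), dims);
    // the default value as a scalar literal (lens {1}, strides {0})
    auto val = mm->add_literal(
        migraphx::literal{migraphx::shape{migraphx::shape::float_type, {1}, {0}}, {10}});
    // fill the allocated buffer with the scalar value
    auto out = mm->add_instruction(migraphx::make_op("fill"), val, buf);
    mm->add_return({out});
    return p;
}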
@@ -49,6 +49,8 @@ struct parse_constant_of_shape : op_parser<parse_constant_of_shape>
             {
                 MIGRAPHX_THROW("ConstantOfShape: attribute value can contain only 1 elements!");
             }
+            // convert to a scalar literal
+            l_val = literal(shape{l_val.get_shape().type(), {1}, {0}}, l_val.data());
         }
         else
         {
@@ -64,30 +66,37 @@ struct parse_constant_of_shape : op_parser<parse_constant_of_shape>
             migraphx::shape s;
-            // input is empty, output is a scalar
             auto type = l_val.get_shape().type();
-            // empty input tensor, output is a scalar
-            if(args[0]->get_shape().elements() == 0)
+            migraphx::argument input = args[0]->eval();
+            if(not input.empty())
             {
-                s = migraphx::shape{type, {1}, {0}};
+                // empty input tensor, output is a scalar
+                if(args[0]->get_shape().elements() == 0)
+                {
+                    s = migraphx::shape{type, {1}, {0}};
+                }
+                else
+                {
+                    std::vector<std::size_t> dims;
+                    input.visit([&](auto ia) { dims.assign(ia.begin(), ia.end()); });
+                    s = migraphx::shape{type, dims};
+                }
+                literal l_out{};
+                l_val.visit([&](auto val) {
+                    using val_type = std::remove_cv_t<typename decltype(val)::value_type>;
+                    // l_val contains only one element
+                    std::vector<val_type> out_vec(s.elements(), val.front());
+                    l_out = literal(s, out_vec);
+                });
+                return info.add_literal(l_out);
             }
+            // has variable input (dynamic shape buffer)
             else
             {
-                migraphx::argument in = args[0]->eval();
-                check_arg_empty(in, "ConstantOfShape: dynamic shape is not supported");
-                std::vector<std::size_t> dims;
-                in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
-                s = migraphx::shape{type, dims};
+                auto dv_lit = info.add_literal(l_val);
+                auto alloc_ins =
+                    info.add_instruction(make_op("allocate", {{"buf_type", type}}), args[0]);
+                return info.add_instruction(make_op("fill"), dv_lit, alloc_ins);
             }
-            literal l_out{};
-            l_val.visit([&](auto val) {
-                using val_type = std::remove_cv_t<typename decltype(val)::value_type>;
-                // l_val contains only one element
-                std::vector<val_type> out_vec(s.elements(), val.front());
-                l_out = literal(s, out_vec);
-            });
-            return info.add_literal(l_out);
         }
     }
 };
......
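A note on the shape{type, {1}, {0}} pattern that recurs above and in the tests below: lens of {1} with strides of {0} is MIGraphX's convention for a scalar, so the single default value can be broadcast over any number of output elements by fill. A quick standalone check of that convention (a sketch; it assumes only migraphx/shape.hpp and the lens/strides/scalar accessors):

#include <migraphx/shape.hpp>
#include <cassert>
#include <cstddef>
#include <vector>

int main()
{
    migraphx::shape scalar{migraphx::shape::float_type, {1}, {0}};
    assert(scalar.elements() == 1);                           // one stored value
    assert(scalar.strides() == std::vector<std::size_t>{0});  // zero stride broadcasts it
    assert(scalar.scalar());                                  // MIGraphX's scalar predicate
    return 0;
}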
[binary ONNX test data; protobuf contents not rendered]
- const_of_shape_default_test.onnx (added)
- const_of_shape_dyn_int64_test.onnx (added)
- const_of_shape_int64_test.onnx (regenerated; graph renamed from constant_of_shape)
- const_of_shape_no_value_attr_test.onnx (regenerated; graph renamed from constant_of_shape)
@@ -1007,9 +1007,9 @@ def const_of_shape_empty_input_test():
                                          [10])
     empty_val = np.array([]).astype(np.int64)
     empty_ts = helper.make_tensor(name='empty_tensor',
-                                  data_type=TensorProto.INT32,
+                                  data_type=TensorProto.INT64,
                                   dims=empty_val.shape,
-                                  vals=empty_val.flatten().astype(int))
+                                  vals=empty_val.flatten().astype(np.int64))
     shape_const = helper.make_node(
         'Constant',
         inputs=[],
@@ -1035,9 +1035,9 @@ def const_of_shape_float_test():
     shape_val = np.array([2, 3, 4]).astype(np.int64)
     shape_ts = helper.make_tensor(name='shape_tensor',
-                                  data_type=TensorProto.INT32,
+                                  data_type=TensorProto.INT64,
                                   dims=shape_val.shape,
-                                  vals=shape_val.flatten().astype(int))
+                                  vals=shape_val.flatten().astype(np.int64))
     shape_const = helper.make_node(
         'Constant',
@@ -1055,22 +1055,44 @@ def const_of_shape_float_test():
     return ([shape_const, node], [], [y])
 
 
+@onnx_test()
+def const_of_shape_default_test():
+    shape_val = np.array([2, 3, 4]).astype(np.int64)
+    shape_ts = helper.make_tensor(name='shape_tensor',
+                                  data_type=TensorProto.INT64,
+                                  dims=shape_val.shape,
+                                  vals=shape_val.flatten().astype(np.int64))
+    shape_const = helper.make_node(
+        'Constant',
+        inputs=[],
+        outputs=['shape'],
+        value=shape_ts,
+    )
+    y = helper.make_tensor_value_info('y', TensorProto.INT64, [2, 3, 4])
+    node = onnx.helper.make_node('ConstantOfShape',
+                                 inputs=['shape'],
+                                 outputs=['y'])
+    return ([shape_const, node], [], [y])
+
+
 @onnx_test()
 def const_of_shape_int64_test():
     tensor_val = onnx.helper.make_tensor('value', onnx.TensorProto.INT64, [1],
                                          [10])
     shape_val = np.array([2, 3, 4]).astype(np.int64)
     shape_ts = helper.make_tensor(name='shape_tensor',
-                                  data_type=TensorProto.INT32,
+                                  data_type=TensorProto.INT64,
                                   dims=shape_val.shape,
-                                  vals=shape_val.flatten().astype(int))
+                                  vals=shape_val.flatten().astype(np.int64))
     shape_const = helper.make_node(
         'Constant',
         inputs=[],
         outputs=['shape'],
         value=shape_ts,
     )
-    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 3, 4])
+    y = helper.make_tensor_value_info('y', TensorProto.INT64, [2, 3, 4])
     node = onnx.helper.make_node('ConstantOfShape',
                                  inputs=['shape'],
@@ -1084,9 +1106,9 @@ def const_of_shape_int64_test():
 def const_of_shape_no_value_attr_test():
     shape_val = np.array([2, 3, 4]).astype(np.int64)
     shape_ts = helper.make_tensor(name='shape_tensor',
-                                  data_type=TensorProto.INT32,
+                                  data_type=TensorProto.INT64,
                                   dims=shape_val.shape,
-                                  vals=shape_val.flatten().astype(int))
+                                  vals=shape_val.flatten().astype(np.int64))
     shape_const = helper.make_node(
         'Constant',
         inputs=[],
@@ -1104,6 +1126,40 @@ def const_of_shape_no_value_attr_test():
     return ([shape_const, node], [], [y])
 
 
+@onnx_test()
+def const_of_shape_dyn_float_test():
+    tensor_val = onnx.helper.make_tensor('value', onnx.TensorProto.FLOAT, [1],
+                                         [10])
+    output_dims = helper.make_tensor_value_info('output_dims',
+                                                TensorProto.INT64, [3])
+    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 3, 4])
+    node = onnx.helper.make_node('ConstantOfShape',
+                                 inputs=['output_dims'],
+                                 outputs=['y'],
+                                 value=tensor_val)
+    return ([node], [output_dims], [y])
+
+
+@onnx_test()
+def const_of_shape_dyn_int64_test():
+    tensor_val = onnx.helper.make_tensor('value', onnx.TensorProto.INT64, [1],
+                                         [10])
+    output_dims = helper.make_tensor_value_info('output_dims',
+                                                TensorProto.INT64, [3])
+    y = helper.make_tensor_value_info('y', TensorProto.INT64, [2, 3, 4])
+    node = onnx.helper.make_node('ConstantOfShape',
+                                 inputs=['output_dims'],
+                                 outputs=['y'],
+                                 value=tensor_val)
+    return ([node], [output_dims], [y])
+
+
 @onnx_test()
 def conv_1d_test():
     x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 5])
......
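What makes the two new dyn models exercise the dynamic path is that output_dims is a graph input (the second element of the tuple the generator returns) rather than a Constant node, so the parser cannot fold it. A tiny sketch of that distinction using the same API the parser relies on (standalone; names illustrative):

#include <migraphx/program.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/instruction.hpp>
#include <cassert>

int main()
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    migraphx::shape s{migraphx::shape::int64_type, {3}};

    // a literal shape input can be evaluated at parse time
    auto lit = mm->add_literal(migraphx::literal(s, {2, 3, 4}));
    assert(not lit->eval().empty());

    // a parameter only gets its value at run time, so eval() yields an empty argument
    auto param = mm->add_parameter("output_dims", s);
    assert(param->eval().empty());
    return 0;
}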
@@ -1040,11 +1040,25 @@ TEST_CASE(constant_one_val_int64_test)
     EXPECT(p == prog);
 }
 
+TEST_CASE(const_of_shape_default_test)
+{
+    migraphx::program p;
+    auto* mm = p.get_main_module();
+    migraphx::shape output_dims_shape(migraphx::shape::int64_type, {3});
+    mm->add_literal(migraphx::literal(output_dims_shape, {2, 3, 4}));
+    migraphx::shape output_shape{migraphx::shape::float_type, {2, 3, 4}};
+    std::vector<float> vec(output_shape.elements(), 0.0);
+    mm->add_literal(migraphx::literal(output_shape, vec));
+
+    auto prog = optimize_onnx("const_of_shape_default_test.onnx");
+    EXPECT(p == prog);
+}
+
 TEST_CASE(const_of_shape_empty_input_test)
 {
     migraphx::program p;
     auto* mm = p.get_main_module();
-    mm->add_literal(migraphx::literal(migraphx::shape::int32_type));
+    mm->add_literal(migraphx::literal(migraphx::shape::int64_type));
     migraphx::shape s(migraphx::shape::int64_type, {1}, {0});
     std::vector<int64_t> vec(s.elements(), 10);
     mm->add_literal(migraphx::literal(s, vec));
@@ -1057,7 +1071,7 @@ TEST_CASE(const_of_shape_float_test)
 {
     migraphx::program p;
     auto* mm = p.get_main_module();
-    migraphx::shape ss(migraphx::shape::int32_type, {3});
+    migraphx::shape ss(migraphx::shape::int64_type, {3});
     mm->add_literal(migraphx::literal(ss, {2, 3, 4}));
     migraphx::shape s(migraphx::shape::float_type, {2, 3, 4});
     std::vector<float> vec(s.elements(), 10.0f);
@@ -1071,8 +1085,10 @@ TEST_CASE(const_of_shape_int64_test)
 {
     migraphx::program p;
     auto* mm = p.get_main_module();
-    migraphx::shape ss(migraphx::shape::int32_type, {3});
+    // output_dims
+    migraphx::shape ss(migraphx::shape::int64_type, {3});
     mm->add_literal(migraphx::literal(ss, {2, 3, 4}));
+    // constant shape literal
     migraphx::shape s(migraphx::shape::int64_type, {2, 3, 4});
     std::vector<int64_t> vec(s.elements(), 10);
     mm->add_literal(migraphx::literal(s, vec));
@@ -1085,7 +1101,7 @@ TEST_CASE(const_of_shape_no_value_attr_test)
 {
     migraphx::program p;
     auto* mm = p.get_main_module();
-    migraphx::shape ss(migraphx::shape::int32_type, {3});
+    migraphx::shape ss(migraphx::shape::int64_type, {3});
     mm->add_literal(migraphx::literal(ss, {2, 3, 4}));
     migraphx::shape s(migraphx::shape::float_type, {2, 3, 4});
     std::vector<float> vec(s.elements(), 0.0f);
@@ -1095,6 +1111,42 @@ TEST_CASE(const_of_shape_no_value_attr_test)
     EXPECT(p == prog);
 }
 
+TEST_CASE(const_of_shape_dyn_float_test)
+{
+    migraphx::program p;
+    auto* mm = p.get_main_module();
+    auto od_param =
+        mm->add_parameter("output_dims", migraphx::shape{migraphx::shape::int64_type, {3}});
+    auto alloc_ins = mm->add_instruction(
+        migraphx::make_op("allocate", {{"buf_type", migraphx::shape::float_type}}), od_param);
+    migraphx::shape dv_shape(migraphx::shape::float_type, {1}, {0});
+    auto dv_lit = mm->add_literal(migraphx::literal(dv_shape, {10}));
+    auto fill_ins = mm->add_instruction(migraphx::make_op("fill"), dv_lit, alloc_ins);
+    mm->add_return({fill_ins});
+
+    migraphx::onnx_options options;
+    auto prog = parse_onnx("const_of_shape_dyn_float_test.onnx", options);
+    EXPECT(p == prog);
+}
+
+TEST_CASE(const_of_shape_dyn_int64_test)
+{
+    migraphx::program p;
+    auto* mm = p.get_main_module();
+    auto od_param =
+        mm->add_parameter("output_dims", migraphx::shape{migraphx::shape::int64_type, {3}});
+    auto alloc_ins = mm->add_instruction(
+        migraphx::make_op("allocate", {{"buf_type", migraphx::shape::int64_type}}), od_param);
+    migraphx::shape dv_shape(migraphx::shape::int64_type, {1}, {0});
+    auto dv_lit = mm->add_literal(migraphx::literal(dv_shape, {10}));
+    auto fill_ins = mm->add_instruction(migraphx::make_op("fill"), dv_lit, alloc_ins);
+    mm->add_return({fill_ins});
+
+    migraphx::onnx_options options;
+    auto prog = parse_onnx("const_of_shape_dyn_int64_test.onnx", options);
+    EXPECT(p == prog);
+}
+
 TEST_CASE(conv_autopad_fail_test)
 {
     EXPECT(test::throws([&] { optimize_onnx("conv_autopad_fail_test.onnx"); }));
......
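The test cases above only compare the parsed graph instruction-for-instruction; as an illustration of what the dynamic path does at run time, here is a hypothetical end-to-end check (not part of this PR's test suite) that parses the generated model, compiles it for the reference target, and supplies concrete dims. It assumes the ref target can execute the allocate/fill pair emitted by this change:

#include <migraphx/onnx.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/argument.hpp>
#include <cassert>
#include <cstdint>
#include <vector>

int main()
{
    // parse the model produced by gen_onnx.py above
    auto p = migraphx::parse_onnx("const_of_shape_dyn_float_test.onnx");
    p.compile(migraphx::make_target("ref"));

    // supply the output dims at run time
    std::vector<int64_t> dims = {2, 3, 4};
    migraphx::parameter_map params;
    params["output_dims"] =
        migraphx::argument(migraphx::shape{migraphx::shape::int64_type, {3}}, dims.data());

    auto result = p.eval(params).back();
    assert(result.get_shape().elements() == 2 * 3 * 4);
    // every element should be the fill value 10
    result.visit([](auto out) {
        for(auto v : out)
            assert(v == 10.0f);
    });
    return 0;
}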