Commit 1f4f6d44 authored by Shucai Xiao's avatar Shucai Xiao Committed by mvermeulen
Browse files

Support initializer not in inputs (#396)

* support onnx file with initializer data not an input

* fix format

* fix review comments and format issue

* format

* reorder alphabetically
parent 1bacc1df
......@@ -1422,21 +1422,14 @@ struct onnx_parser
void parse_graph(const onnx::GraphProto& graph)
{
nodes = get_nodes(graph);
std::unordered_map<std::string, onnx::TensorProto> initializer_data;
for(auto&& f : graph.initializer())
{
initializer_data[f.name()] = f;
}
instructions[f.name()] = prog.add_literal(parse_tensor(f));
for(auto&& input : graph.input())
{
const std::string& name = input.name();
// Does the input have an initializer?
if(contains(initializer_data, name))
{
auto t = initializer_data[name];
instructions[name] = prog.add_literal(parse_tensor(t));
}
else
// input not in initializer_data, so it is a real input
if(!contains(instructions, name))
{
// TODO: Get shape of input parameter
shape s = parse_type(input.type());
......
......@@ -838,6 +838,26 @@ def implicit_sub_bcast_test():
return ([node], [arg0, arg1], [arg_out])
@onnx_test
def initializer_not_an_input():
values = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])
w = helper.make_tensor(name='w',
data_type=TensorProto.FLOAT,
dims=values.shape,
vals=values.flatten().astype(np.float))
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [5, 2])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [5, 4])
node = onnx.helper.make_node(
'Gemm',
inputs=['x', 'w'],
outputs=['y'],
)
return ([node], [x], [y], [w])
@onnx_test
def leaky_relu_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3])
......
......@@ -48,8 +48,8 @@ TEST_CASE(add_fp16_test)
TEST_CASE(add_scalar_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {1}});
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto m0 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0);
auto m1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::add{}, m0, m1);
......@@ -585,6 +585,19 @@ TEST_CASE(implicit_sub_bcast_test)
EXPECT(p == prog);
}
TEST_CASE(initializer_not_an_input)
{
migraphx::program p;
std::vector<float> w = {1, 2, 3, 4, 5, 6, 7, 8};
auto l1 = p.add_literal(migraphx::literal({migraphx::shape::float_type, {2, 4}}, w));
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {5, 2}});
p.add_instruction(migraphx::op::dot{}, l0, l1);
auto prog = migraphx::parse_onnx("initializer_not_an_input.onnx");
EXPECT(p == prog);
}
TEST_CASE(leaky_relu_test)
{
migraphx::program p;
......@@ -858,9 +871,9 @@ TEST_CASE(reshape_test)
migraphx::program p;
migraphx::op::reshape op;
std::vector<int64_t> reshape_dims{3, 8};
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 2, 3}});
p.add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::int64_type, {2}}, reshape_dims});
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 2, 3}});
op.dims = reshape_dims;
p.add_instruction(op, l0);
p.add_instruction(op, l0);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment