Commit 1f4f6d44 authored by Shucai Xiao, committed by mvermeulen

Support initializer not in inputs (#396)

* support onnx file with initializer data not an input

* fix format

* fix review comments and format issue

* format

* reorder alphabetically
parent 1bacc1df
@@ -1422,21 +1422,14 @@ struct onnx_parser
     void parse_graph(const onnx::GraphProto& graph)
     {
         nodes = get_nodes(graph);
-        std::unordered_map<std::string, onnx::TensorProto> initializer_data;
         for(auto&& f : graph.initializer())
-        {
-            initializer_data[f.name()] = f;
-        }
+            instructions[f.name()] = prog.add_literal(parse_tensor(f));
+
         for(auto&& input : graph.input())
         {
             const std::string& name = input.name();
-            // Does the input have an initializer?
-            if(contains(initializer_data, name))
-            {
-                auto t = initializer_data[name];
-                instructions[name] = prog.add_literal(parse_tensor(t));
-            }
-            else
+            // input not in initializer_data, so it is a real input
+            if(!contains(instructions, name))
             {
                 // TODO: Get shape of input parameter
                 shape s = parse_type(input.type());
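The new parser logic is worth spelling out: every tensor in graph.initializer() is turned into a literal up front, and a graph input only becomes a runtime parameter if no literal with that name already exists. Below is a minimal Python sketch of that same rule, using only the standard onnx API; it is an illustration, not the MIGraphX implementation, and the helper name split_inputs is made up.

    # Sketch only: mirrors the parser's new rule in Python.
    import onnx

    def split_inputs(model: onnx.ModelProto):
        graph = model.graph
        # Every initializer becomes constant data (a literal in MIGraphX terms).
        literal_names = {init.name for init in graph.initializer}
        # Only inputs without an initializer of the same name stay runtime parameters.
        parameter_names = [i.name for i in graph.input if i.name not in literal_names]
        return literal_names, parameter_names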
@@ -838,6 +838,26 @@ def implicit_sub_bcast_test():
     return ([node], [arg0, arg1], [arg_out])
 
 
+@onnx_test
+def initializer_not_an_input():
+    values = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])
+    w = helper.make_tensor(name='w',
+                           data_type=TensorProto.FLOAT,
+                           dims=values.shape,
+                           vals=values.flatten().astype(np.float))
+
+    x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [5, 2])
+    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [5, 4])
+
+    node = onnx.helper.make_node(
+        'Gemm',
+        inputs=['x', 'w'],
+        outputs=['y'],
+    )
+
+    return ([node], [x], [y], [w])
+
+
 @onnx_test
 def leaky_relu_test():
     x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3])
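The generator above registers 'w' only through the initializer list (the fourth element of the returned tuple), so the saved model declares a single graph input. A quick, hypothetical sanity check with the onnx Python package, assuming the @onnx_test machinery writes the model to initializer_not_an_input.onnx as the C++ test below expects:

    import onnx

    model = onnx.load('initializer_not_an_input.onnx')

    input_names       = {i.name for i in model.graph.input}
    initializer_names = {t.name for t in model.graph.initializer}

    # 'w' carries constant data and is deliberately not listed as a graph input.
    assert initializer_names == {'w'}
    assert input_names == {'x'}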
@@ -48,8 +48,8 @@ TEST_CASE(add_fp16_test)
 TEST_CASE(add_scalar_test)
 {
     migraphx::program p;
-    auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
     auto l1 = p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {1}});
+    auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
     auto m0 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0);
     auto m1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
     p.add_instruction(migraphx::op::add{}, m0, m1);
@@ -585,6 +585,19 @@ TEST_CASE(implicit_sub_bcast_test)
     EXPECT(p == prog);
 }
 
+TEST_CASE(initializer_not_an_input)
+{
+    migraphx::program p;
+    std::vector<float> w = {1, 2, 3, 4, 5, 6, 7, 8};
+    auto l1 = p.add_literal(migraphx::literal({migraphx::shape::float_type, {2, 4}}, w));
+    auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {5, 2}});
+    p.add_instruction(migraphx::op::dot{}, l0, l1);
+
+    auto prog = migraphx::parse_onnx("initializer_not_an_input.onnx");
+
+    EXPECT(p == prog);
+}
+
 TEST_CASE(leaky_relu_test)
 {
     migraphx::program p;
@@ -858,9 +871,9 @@ TEST_CASE(reshape_test)
     migraphx::program p;
     migraphx::op::reshape op;
     std::vector<int64_t> reshape_dims{3, 8};
-    auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 2, 3}});
     p.add_literal(
         migraphx::literal{migraphx::shape{migraphx::shape::int64_type, {2}}, reshape_dims});
+    auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 2, 3}});
     op.dims = reshape_dims;
     p.add_instruction(op, l0);
     p.add_instruction(op, l0);
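For reference, the expected program built in TEST_CASE(initializer_not_an_input) computes a plain matrix product: the only runtime parameter is x of shape {5, 2}, the initializer w is baked in as a {2, 4} literal, and dot produces a {5, 4} result. A small NumPy sketch of that arithmetic, with values taken from the generator above and a ones-filled x as an example feed:

    import numpy as np

    w = np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=np.float32)  # the literal
    x = np.ones((5, 2), dtype=np.float32)                         # example value for the parameter

    y = x @ w          # what Gemm / migraphx::op::dot computes here
    assert y.shape == (5, 4)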