Commit 69e4955e authored by charlie's avatar charlie
Browse files

Formatting

parent abf62c5d
......@@ -70,7 +70,7 @@ struct onnx_parser
onnx_parser&, const node_info&, std::vector<instruction_ref>)>;
node_map nodes;
std::unordered_map<std::string, instruction_ref> instructions;
program prog = program();
program prog = program();
shape::dynamic_dimension default_dyn_dim_value = {1, 1, 0};
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims;
std::unordered_map<std::string, std::vector<shape::dynamic_dimension>> map_dyn_input_dims;
......
......@@ -18,10 +18,10 @@ template <class... Ts>
program parse_onnx_from(const onnx_options& options, Ts&&... xs)
{
onnx::onnx_parser parser;
parser.map_input_dims = options.map_input_dims;
parser.map_dyn_input_dims = options.map_dyn_input_dims;
auto dim_val = options.default_dim_value;
if (dim_val != 0)
parser.map_input_dims = options.map_input_dims;
parser.map_dyn_input_dims = options.map_dyn_input_dims;
auto dim_val = options.default_dim_value;
if(dim_val != 0)
{
parser.default_dyn_dim_value = {dim_val, dim_val, 0};
}
......
......@@ -256,18 +256,18 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
MIGRAPHX_THROW("module \"" + mod->name() + "\" has parameter name \"" + name +
"\" existing in parent graph!");
}
shape s;
std::vector<std::size_t> dims;
if(map_input_dims.count(name) > 0)
{
dims = map_input_dims.at(name);
s = parse_type(input.type(), dims);
s = parse_type(input.type(), dims);
}
else if(map_dyn_input_dims.count(name) > 0)
{
shape::type_t shape_type = get_type(input.type().tensor_type().elem_type());
s = {shape_type, map_dyn_input_dims.at(name)};
s = {shape_type, map_dyn_input_dims.at(name)};
}
else
{
......@@ -447,22 +447,18 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t,
return default_dyn_dim_value;
}
});
if(dynamic_dims.empty())
{
return {shape_type};
}
if (std::all_of(dynamic_dims.begin(), dynamic_dims.end(), [](auto dd) { return dd.is_fixed(); }))
if(std::all_of(dynamic_dims.begin(), dynamic_dims.end(), [](auto dd) { return dd.is_fixed(); }))
{
std::vector<std::size_t> dims;
std::transform(
dynamic_dims.begin(),
dynamic_dims.end(),
std::back_inserter(dims),
[](auto d) {
return d.max;
}
);
std::transform(dynamic_dims.begin(),
dynamic_dims.end(),
std::back_inserter(dims),
[](auto d) { return d.max; });
return {shape_type, dims};
}
return {shape_type, dynamic_dims};
......
......@@ -5431,7 +5431,9 @@ TEST_CASE(variable_batch_user_input_test2)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {{2, 5, 0}, {3, 3, 0}, {16, 16, 0}, {16, 16, 0}}});
auto l0 = mm->add_parameter("0",
migraphx::shape{migraphx::shape::float_type,
{{2, 5, 0}, {3, 3, 0}, {16, 16, 0}, {16, 16, 0}}});
auto r = mm->add_instruction(migraphx::make_op("identity"), l0);
mm->add_return({r});
......@@ -5447,7 +5449,9 @@ TEST_CASE(variable_batch_user_input_test3)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {{2, 5, 0}, {3, 3, 0}, {16, 16, 0}, {16, 16, 0}}});
auto l0 = mm->add_parameter("0",
migraphx::shape{migraphx::shape::float_type,
{{2, 5, 0}, {3, 3, 0}, {16, 16, 0}, {16, 16, 0}}});
auto r = mm->add_instruction(migraphx::make_op("identity"), l0);
mm->add_return({r});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment