Commit 85d9fbef authored by Brian Pickrell's avatar Brian Pickrell
Browse files

more tidy; reduces branch count

parent 75dc1ac3
...@@ -209,7 +209,7 @@ struct parse_resize : op_parser<parse_resize> ...@@ -209,7 +209,7 @@ struct parse_resize : op_parser<parse_resize>
bool mostly_fixed = bool mostly_fixed =
std::all_of(some_dims.begin() + 1, std::all_of(some_dims.begin() + 1,
some_dims.end(), some_dims.end(),
[](shape::dynamic_dimension dd) { return dd.is_fixed(); }); [](const shape::dynamic_dimension& dd) { return dd.is_fixed(); });
if(not mostly_fixed) if(not mostly_fixed)
MIGRAPHX_THROW("PARSE_" + opd.op_name + MIGRAPHX_THROW("PARSE_" + opd.op_name +
...@@ -261,12 +261,12 @@ struct parse_resize : op_parser<parse_resize> ...@@ -261,12 +261,12 @@ struct parse_resize : op_parser<parse_resize>
// Put the value into index vector // Put the value into index vector
} }
// Create a 1D shape literal // Create a 1D shape literal
auto index_litA = info.add_literal(literal( auto index_lit = info.add_literal(literal(
migraphx::shape(migraphx::shape::int64_type, {fixed_out_lens[ii]}), in_idx)); migraphx::shape(migraphx::shape::int64_type, {fixed_out_lens[ii]}), in_idx));
// add a "gather" instruction // add a "gather" instruction
gather_ins = info.add_instruction( gather_ins = info.add_instruction(
make_op("gather", {{"axis", 1 + ii}}), gather_ins, index_litA); make_op("gather", {{"axis", 1 + ii}}), gather_ins, index_lit);
} }
return gather_ins; return gather_ins;
} }
...@@ -314,20 +314,15 @@ struct parse_resize : op_parser<parse_resize> ...@@ -314,20 +314,15 @@ struct parse_resize : op_parser<parse_resize>
// Look at inputs and infer either output size or scale, depending on input type // Look at inputs and infer either output size or scale, depending on input type
for(const auto& arg : args) for(const auto& arg : args)
{ {
if(arg->name() == "undefined" or arg == args.front())
{
continue;
}
if(arg != args[0] and arg->get_shape().dynamic()) if(arg != args[0] and arg->get_shape().dynamic())
{ {
MIGRAPHX_THROW("PARSE_" + opd.op_name + MIGRAPHX_THROW("PARSE_" + opd.op_name +
": no dynamic input shapes allowed except the first one"); ": no dynamic input shapes allowed except the first one");
} }
// skip any empty inputs // skip first input and any empty inputs
auto lens = arg->get_shape().lens(); auto lens = arg->get_shape().to_static(1).lens();
if(lens.empty()) if(arg->name() == "undefined" or arg == args.front() or lens.empty())
{ {
continue; continue;
} }
......
...@@ -6616,13 +6616,12 @@ def resize_downsample_f_dyn_test(): ...@@ -6616,13 +6616,12 @@ def resize_downsample_f_dyn_test():
X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [None, 1, 5, 9]) X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [None, 1, 5, 9])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, []) Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [])
node = onnx.helper.make_node( node = onnx.helper.make_node('Resize',
'Resize', inputs=['X', '', 'scales'],
inputs=['X', '', 'scales'], outputs=['Y'],
outputs=['Y'], coordinate_transformation_mode='asymmetric',
coordinate_transformation_mode='asymmetric', mode='nearest',
mode='nearest', nearest_mode='floor')
nearest_mode='floor')
return ([node], [X], [Y], [scale_tensor]) return ([node], [X], [Y], [scale_tensor])
...@@ -6638,13 +6637,12 @@ def resize_upsample_f_dyn_test(): ...@@ -6638,13 +6637,12 @@ def resize_upsample_f_dyn_test():
X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [None, 1, 3, 5]) X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [None, 1, 3, 5])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, []) Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [])
node = onnx.helper.make_node( node = onnx.helper.make_node('Resize',
'Resize', inputs=['X', '', 'scales'],
inputs=['X', '', 'scales'], outputs=['Y'],
outputs=['Y'], coordinate_transformation_mode='half_pixel',
coordinate_transformation_mode='half_pixel', mode='nearest',
mode='nearest', nearest_mode='round_prefer_ceil')
nearest_mode='round_prefer_ceil')
return ([node], [X], [Y], [scale_tensor]) return ([node], [X], [Y], [scale_tensor])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment