Commit d1239978 authored by Brian Pickrell's avatar Brian Pickrell
Browse files

fixed Upsample-7 special case handling

parent fe11d758
...@@ -313,8 +313,6 @@ struct parse_resize : op_parser<parse_resize> ...@@ -313,8 +313,6 @@ struct parse_resize : op_parser<parse_resize>
// scale // scale
std::vector<double> vec_scale; std::vector<double> vec_scale;
// Look at inputs and infer either output size or scale, depending on input type
// The input ROI is not currently suported
for(const auto& arg : args) for(const auto& arg : args)
{ {
if(arg != args[0] and arg->get_shape().dynamic()) if(arg != args[0] and arg->get_shape().dynamic())
...@@ -322,14 +320,44 @@ struct parse_resize : op_parser<parse_resize> ...@@ -322,14 +320,44 @@ struct parse_resize : op_parser<parse_resize>
MIGRAPHX_THROW("PARSE_" + opd.onnx_name + MIGRAPHX_THROW("PARSE_" + opd.onnx_name +
": no dynamic input shapes allowed except the first one"); ": no dynamic input shapes allowed except the first one");
} }
}
// Special-case handling for the deprecated Upsample-7 op, which handles scales as an
// attribute rather than an input.
if(opd.onnx_name == "Upsample" and contains(info.attributes, "scales"))
{
// "scales" attribute is a vector of float
literal scales = parser.parse_value(info.attributes.at("scales"));
scales.visit([&](auto s) { vec_scale.assign(s.begin(), s.end()); });
// skip first input (if dynamic) and any empty inputs if(in_dims.size() != vec_scale.size())
{
MIGRAPHX_THROW("PARSE_" + opd.onnx_name +
": specified scales rank does not match input shape rank");
}
std::transform(in_dims.begin(),
in_dims.end(),
vec_scale.begin(),
out_lens.begin(),
[&](auto idx, auto scale) {
// inferred output size is floor(idx.max * scale)
return idx.max * scale;
});
}
else
{
// Search the input list for either output size or scale, depending on input type.
// The first non-empty input after the input shape is assumed to be it, since we can't read
// Onnx input names.
// Todo: The input "ROI" is not currently supported
for(const auto& arg : args)
{
// skip first input and any empty inputs
auto lens = arg->get_shape().to_static(1).lens(); auto lens = arg->get_shape().to_static(1).lens();
if(arg->name() == "undefined" or (arg == args[0] and arg->get_shape().dynamic()) or lens.empty()) if(arg->name() == "undefined" or (arg == args[0]) or lens.empty())
{ {
continue; continue;
} }
auto type = arg->get_shape().type(); auto type = arg->get_shape().type();
// This input is inferred to mean output size if type == int64_type; otherwise // This input is inferred to mean output size if type == int64_type; otherwise
...@@ -362,26 +390,9 @@ struct parse_resize : op_parser<parse_resize> ...@@ -362,26 +390,9 @@ struct parse_resize : op_parser<parse_resize>
{ {
// scale input // scale input
auto arg_scale = arg->eval(); auto arg_scale = arg->eval();
// Special-case for the deprecated Upsample-7 operation: scales is an attribute
// rather than an input. Upsample-9 uses scales as an input
if(opd.onnx_name == "Upsample" and arg_scale.empty())
{
if(contains(info.attributes, "scales"))
{
// vec_scale = info.attributes.at("scales").vector<float>();
literal scales = parser.parse_value(info.attributes.at("scales"));
scales.visit([&](auto s) { vec_scale.assign(s.begin(), s.end()); });
}
else
MIGRAPHX_THROW("PARSE_" + opd.onnx_name +
": scales attribute missing");
}
else
{
check_arg_empty(arg_scale, check_arg_empty(arg_scale,
"PARSE_" + opd.onnx_name + ": dynamic input scale is not supported!"); "PARSE_" + opd.onnx_name +
": dynamic input scale is not supported!");
arg_scale.visit([&](auto v) { vec_scale.assign(v.begin(), v.end()); }); arg_scale.visit([&](auto v) { vec_scale.assign(v.begin(), v.end()); });
if(in_dims.size() != vec_scale.size()) if(in_dims.size() != vec_scale.size())
...@@ -389,8 +400,6 @@ struct parse_resize : op_parser<parse_resize> ...@@ -389,8 +400,6 @@ struct parse_resize : op_parser<parse_resize>
MIGRAPHX_THROW("PARSE_" + opd.onnx_name + MIGRAPHX_THROW("PARSE_" + opd.onnx_name +
": specified scale rank does not match input rank"); ": specified scale rank does not match input rank");
} }
}
std::transform(in_dims.begin(), std::transform(in_dims.begin(),
in_dims.end(), in_dims.end(),
vec_scale.begin(), vec_scale.begin(),
...@@ -402,11 +411,13 @@ struct parse_resize : op_parser<parse_resize> ...@@ -402,11 +411,13 @@ struct parse_resize : op_parser<parse_resize>
break; break;
} }
} }
}
if(out_lens.size() == 0) if(out_lens.size() == 0)
MIGRAPHX_THROW("PARSE_" + opd.onnx_name + ": no input was given for scale or output size"); MIGRAPHX_THROW("PARSE_" + opd.onnx_name +
": no input was given for scale or output size");
// Dynamic batch: Only args[0] can have a dynamic shape, only the 0'th // Dynamic batch: Only args[0] can have a dynamic shape, only the 0'th
// dimension--batch size--can be non-fixed, and the only resize mode allowed is "nearest" // dimension--batch size--can be non-fixed, and the only resize mode supported is "nearest"
if(args[0]->get_shape().dynamic()) if(args[0]->get_shape().dynamic())
{ {
return dynamic_nearest_parse(out_lens, vec_scale, opd, info, args); return dynamic_nearest_parse(out_lens, vec_scale, opd, info, args);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment