Commit d1239978 authored by Brian Pickrell's avatar Brian Pickrell
Browse files

fixed Upsample-7 special case handling

parent fe11d758
......@@ -313,8 +313,6 @@ struct parse_resize : op_parser<parse_resize>
// scale
std::vector<double> vec_scale;
// Look at inputs and infer either output size or scale, depending on input type
// The input ROI is not currently suported
for(const auto& arg : args)
{
if(arg != args[0] and arg->get_shape().dynamic())
......@@ -322,14 +320,44 @@ struct parse_resize : op_parser<parse_resize>
MIGRAPHX_THROW("PARSE_" + opd.onnx_name +
": no dynamic input shapes allowed except the first one");
}
}
// Special-case handling for the deprecated Upsample-7 op, which handles scales as an
// attribute rather than an input.
if(opd.onnx_name == "Upsample" and contains(info.attributes, "scales"))
{
// "scales" attribute is a vector of float
literal scales = parser.parse_value(info.attributes.at("scales"));
scales.visit([&](auto s) { vec_scale.assign(s.begin(), s.end()); });
// skip first input (if dynamic) and any empty inputs
if(in_dims.size() != vec_scale.size())
{
MIGRAPHX_THROW("PARSE_" + opd.onnx_name +
": specified scales rank does not match input shape rank");
}
std::transform(in_dims.begin(),
in_dims.end(),
vec_scale.begin(),
out_lens.begin(),
[&](auto idx, auto scale) {
// inferred output size is floor(idx.max * scale)
return idx.max * scale;
});
}
else
{
// Search the input list for either output size or scale, depending on input type.
// The first non-empty input after the input shape is assumed to be it, since we can't read
// Onnx input names.
// Todo: The input "ROI" is not currently supported
for(const auto& arg : args)
{
// skip first input and any empty inputs
auto lens = arg->get_shape().to_static(1).lens();
if(arg->name() == "undefined" or (arg == args[0] and arg->get_shape().dynamic()) or lens.empty())
if(arg->name() == "undefined" or (arg == args[0]) or lens.empty())
{
continue;
}
auto type = arg->get_shape().type();
// This input is inferred to mean output size if type == int64_type; otherwise
......@@ -362,26 +390,9 @@ struct parse_resize : op_parser<parse_resize>
{
// scale input
auto arg_scale = arg->eval();
// Special-case for the deprecated Upsample-7 operation: scales is an attribute
// rather than an input. Upsample-9 uses scales as an input
if(opd.onnx_name == "Upsample" and arg_scale.empty())
{
if(contains(info.attributes, "scales"))
{
// vec_scale = info.attributes.at("scales").vector<float>();
literal scales = parser.parse_value(info.attributes.at("scales"));
scales.visit([&](auto s) { vec_scale.assign(s.begin(), s.end()); });
}
else
MIGRAPHX_THROW("PARSE_" + opd.onnx_name +
": scales attribute missing");
}
else
{
check_arg_empty(arg_scale,
"PARSE_" + opd.onnx_name + ": dynamic input scale is not supported!");
"PARSE_" + opd.onnx_name +
": dynamic input scale is not supported!");
arg_scale.visit([&](auto v) { vec_scale.assign(v.begin(), v.end()); });
if(in_dims.size() != vec_scale.size())
......@@ -389,8 +400,6 @@ struct parse_resize : op_parser<parse_resize>
MIGRAPHX_THROW("PARSE_" + opd.onnx_name +
": specified scale rank does not match input rank");
}
}
std::transform(in_dims.begin(),
in_dims.end(),
vec_scale.begin(),
......@@ -402,11 +411,13 @@ struct parse_resize : op_parser<parse_resize>
break;
}
}
}
if(out_lens.size() == 0)
MIGRAPHX_THROW("PARSE_" + opd.onnx_name + ": no input was given for scale or output size");
MIGRAPHX_THROW("PARSE_" + opd.onnx_name +
": no input was given for scale or output size");
// Dynamic batch: Only args[0] can have a dynamic shape, only the 0'th
// dimension--batch size--can be non-fixed, and the only resize mode allowed is "nearest"
// dimension--batch size--can be non-fixed, and the only resize mode supported is "nearest"
if(args[0]->get_shape().dynamic())
{
return dynamic_nearest_parse(out_lens, vec_scale, opd, info, args);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment