Commit d1239978 authored by Brian Pickrell's avatar Brian Pickrell
Browse files

fixed Upsample-7 special case handling

parent fe11d758
...@@ -313,100 +313,111 @@ struct parse_resize : op_parser<parse_resize> ...@@ -313,100 +313,111 @@ struct parse_resize : op_parser<parse_resize>
// scale // scale
std::vector<double> vec_scale; std::vector<double> vec_scale;
// Look at inputs and infer either output size or scale, depending on input type
// The input ROI is not currently suported
for(const auto& arg : args) for(const auto& arg : args)
{ {
if(arg != args[0] and arg->get_shape().dynamic()) if(arg != args[0] and arg->get_shape().dynamic())
{ {
MIGRAPHX_THROW("PARSE_" + opd.onnx_name + MIGRAPHX_THROW("PARSE_" + opd.onnx_name +
": no dynamic input shapes allowed except the first one"); ": no dynamic input shapes allowed except the first one");
} }
}
// Special-case handling for the deprecated Upsample-7 op, which handles scales as an
// attribute rather than an input.
if(opd.onnx_name == "Upsample" and contains(info.attributes, "scales"))
{
// "scales" attribute is a vector of float
literal scales = parser.parse_value(info.attributes.at("scales"));
scales.visit([&](auto s) { vec_scale.assign(s.begin(), s.end()); });
// skip first input (if dynamic) and any empty inputs if(in_dims.size() != vec_scale.size())
auto lens = arg->get_shape().to_static(1).lens();
if(arg->name() == "undefined" or (arg == args[0] and arg->get_shape().dynamic()) or lens.empty())
{ {
continue; MIGRAPHX_THROW("PARSE_" + opd.onnx_name +
": specified scales rank does not match input shape rank");
} }
std::transform(in_dims.begin(),
auto type = arg->get_shape().type(); in_dims.end(),
vec_scale.begin(),
// This input is inferred to mean output size if type == int64_type; otherwise out_lens.begin(),
// read it as the scales [&](auto idx, auto scale) {
if(type == shape::int64_type) // inferred output size is floor(idx.max * scale)
return idx.max * scale;
});
}
else
{
// Search the input list for either output size or scale, depending on input type.
// The first non-empty input after the input shape is assumed to be it, since we can't read
// Onnx input names.
// Todo: The input "ROI" is not currently supported
for(const auto& arg : args)
{ {
auto arg_out_s = arg->eval(); // skip first input and any empty inputs
check_arg_empty(arg_out_s, auto lens = arg->get_shape().to_static(1).lens();
"PARSE_" + opd.onnx_name + ": dynamic output size is not supported!"); if(arg->name() == "undefined" or (arg == args[0]) or lens.empty())
out_lens.clear();
arg_out_s.visit([&](auto ol) { out_lens.assign(ol.begin(), ol.end()); });
if(out_lens.size() != in_s.ndim())
{ {
MIGRAPHX_THROW("PARSE_" + opd.onnx_name + continue;
": specified output rank does not match input rank");
} }
auto type = arg->get_shape().type();
// compute the scale in each dimension // This input is inferred to mean output size if type == int64_type; otherwise
vec_scale.resize(in_s.ndim()); // read it as the scales
if(type == shape::int64_type)
std::transform(in_dims.begin(),
in_dims.end(),
out_lens.begin(),
vec_scale.begin(),
[](auto iss, auto oss) { return double(1.0 * oss / iss.max); });
break;
}
else
{
// scale input
auto arg_scale = arg->eval();
// Special-case for the deprecated Upsample-7 operation: scales is an attribute
// rather than an input. Upsample-9 uses scales as an input
if(opd.onnx_name == "Upsample" and arg_scale.empty())
{ {
if(contains(info.attributes, "scales")) auto arg_out_s = arg->eval();
check_arg_empty(arg_out_s,
"PARSE_" + opd.onnx_name + ": dynamic output size is not supported!");
out_lens.clear();
arg_out_s.visit([&](auto ol) { out_lens.assign(ol.begin(), ol.end()); });
if(out_lens.size() != in_s.ndim())
{ {
// vec_scale = info.attributes.at("scales").vector<float>();
literal scales = parser.parse_value(info.attributes.at("scales"));
scales.visit([&](auto s) { vec_scale.assign(s.begin(), s.end()); });
}
else
MIGRAPHX_THROW("PARSE_" + opd.onnx_name + MIGRAPHX_THROW("PARSE_" + opd.onnx_name +
": scales attribute missing"); ": specified output rank does not match input rank");
}
// compute the scale in each dimension
vec_scale.resize(in_s.ndim());
std::transform(in_dims.begin(),
in_dims.end(),
out_lens.begin(),
vec_scale.begin(),
[](auto iss, auto oss) { return double(1.0 * oss / iss.max); });
break;
} }
else else
{ {
// scale input
auto arg_scale = arg->eval();
check_arg_empty(arg_scale, check_arg_empty(arg_scale,
"PARSE_" + opd.onnx_name + ": dynamic input scale is not supported!"); "PARSE_" + opd.onnx_name +
": dynamic input scale is not supported!");
arg_scale.visit([&](auto v) { vec_scale.assign(v.begin(), v.end()); }); arg_scale.visit([&](auto v) { vec_scale.assign(v.begin(), v.end()); });
if(in_dims.size() != vec_scale.size()) if(in_dims.size() != vec_scale.size())
{ {
MIGRAPHX_THROW("PARSE_" + opd.onnx_name + MIGRAPHX_THROW("PARSE_" + opd.onnx_name +
": specified scale rank does not match input rank"); ": specified scale rank does not match input rank");
} }
std::transform(in_dims.begin(),
} in_dims.end(),
std::transform(in_dims.begin(), vec_scale.begin(),
in_dims.end(), out_lens.begin(),
vec_scale.begin(), [&](auto idx, auto scale) {
out_lens.begin(), // inferred output size is floor(idx.max * scale)
[&](auto idx, auto scale) { return idx.max * scale;
// inferred output size is floor(idx.max * scale) });
return idx.max * scale; break;
}); }
break;
} }
} }
if(out_lens.size() == 0) if(out_lens.size() == 0)
MIGRAPHX_THROW("PARSE_" + opd.onnx_name + ": no input was given for scale or output size"); MIGRAPHX_THROW("PARSE_" + opd.onnx_name +
": no input was given for scale or output size");
// Dynamic batch: Only args[0] can have a dynamic shape, only the 0'th // Dynamic batch: Only args[0] can have a dynamic shape, only the 0'th
// dimension--batch size--can be non-fixed, and the only resize mode allowed is "nearest" // dimension--batch size--can be non-fixed, and the only resize mode supported is "nearest"
if(args[0]->get_shape().dynamic()) if(args[0]->get_shape().dynamic())
{ {
return dynamic_nearest_parse(out_lens, vec_scale, opd, info, args); return dynamic_nearest_parse(out_lens, vec_scale, opd, info, args);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment