Unverified Commit cb2eb576 authored by q.yao's avatar q.yao Committed by GitHub
Browse files

Remove onnx (#2295)

* remove onnx

* remove print
parent c9ed3f52
# Copyright (c) OpenMMLab. All rights reserved.
from .symbolic import register_extra_symbolics
__all__ = ['register_extra_symbolics']
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
"""Modified from https://github.com/pytorch/pytorch."""
import warnings
from functools import wraps
from sys import maxsize
import torch
import torch.onnx
# This import monkey-patches graph manipulation methods on Graph, used for the
# ONNX symbolics
import torch.onnx.utils
from torch._C import ListType
# ---------------------------------------------------------------------------------
# Helper functions
# ---------------------------------------------------------------------------------
# Save some builtins as locals, because we'll shadown them below
_sum = sum
def _parse_arg(value, desc):
if desc == 'none':
return value
if desc == 'v' or not _is_value(value):
return value
if value.node().mustBeNone():
return None
if value.node().kind() == 'onnx::Constant':
tval = value.node()['value']
if desc == 'i':
return int(tval)
elif desc == 'f':
return float(tval)
elif desc == 'b':
return bool(tval)
elif desc == 's':
return str(tval)
elif desc == 't':
return tval
elif desc == 'is':
return [int(v) for v in tval]
elif desc == 'fs':
return [float(v) for v in tval]
else:
raise RuntimeError(
"ONNX symbolic doesn't know to interpret Constant node")
elif value.node().kind() == 'prim::ListConstruct':
if desc == 'is':
for v in value.node().inputs():
if v.node().kind() != 'onnx::Constant':
raise RuntimeError(
"Failed to export an ONNX attribute '" +
v.node().kind() +
"', since it's not constant, please try to make "
'things (e.g., kernel size) static if possible')
return [int(v.node()['value']) for v in value.node().inputs()]
else:
raise RuntimeError(
"ONNX symbolic doesn't know to interpret ListConstruct node")
raise RuntimeError(f'Unexpected node type: {value.node().kind()}')
def _maybe_get_const(value, desc):
if _is_value(value) and value.node().kind() == 'onnx::Constant':
return _parse_arg(value, desc)
return value
def _maybe_get_scalar(value):
value_t = _maybe_get_const(value, 't')
if isinstance(value_t, torch.Tensor) and value_t.shape == ():
return value_t
return value
def _get_const(value, desc, arg_name):
if _is_value(value) and value.node().kind() not in ('onnx::Constant',
'prim::Constant'):
raise RuntimeError('ONNX symbolic expected a constant'
' value of the {} argument, got `{}`'.format(
arg_name, value))
return _parse_arg(value, desc)
def _unpack_list(list_value):
list_node = list_value.node()
assert list_node.kind() == 'prim::ListConstruct'
return list(list_node.inputs())
# Check if list_value is output from prim::ListConstruct
# This is usually called before _unpack_list to ensure the list can be
# unpacked.
def _is_packed_list(list_value):
return _is_value(
list_value) and list_value.node().kind() == 'prim::ListConstruct'
def parse_args(*arg_descriptors):
def decorator(fn):
fn._arg_descriptors = arg_descriptors
def wrapper(g, *args):
# some args may be optional, so the length may be smaller
assert len(arg_descriptors) >= len(args)
args = [
_parse_arg(arg, arg_desc)
for arg, arg_desc in zip(args, arg_descriptors)
]
return fn(g, *args)
# In Python 2 functools.wraps chokes on partially applied functions, so
# we need this as a workaround
try:
wrapper = wraps(fn)(wrapper)
except Exception:
pass
return wrapper
return decorator
def _scalar(x):
"""Convert a scalar tensor into a Python value."""
assert x.numel() == 1
return x.item()
def _if_scalar_type_as(g, self, tensor):
"""Convert self into the same type of tensor, as necessary."""
if isinstance(self, torch._C.Value):
return self
scalar_type = tensor.type().scalarType()
if scalar_type:
ty = scalar_type.lower()
return getattr(self, ty)()
return self
def _is_none(x):
return x.node().mustBeNone()
def _is_value(x):
return isinstance(x, torch._C.Value)
def _is_tensor_list(x):
return x.type().isSubtypeOf(ListType.ofTensors())
def _unimplemented(op, msg):
warnings.warn('ONNX export failed on ' + op + ' because ' + msg +
' not supported')
def _try_get_scalar_type(*args):
for arg in args:
try:
return arg.type().scalarType()
except RuntimeError:
pass
return None
def _topk_helper(g, input, k, dim, largest=True, sorted=False, out=None):
if out is not None:
_unimplemented('TopK', 'Out parameter is not supported')
if not _is_value(k):
k = g.op('Constant', value_t=torch.tensor([k], dtype=torch.int64))
else:
k = g.op('Reshape', k, g.op('Constant', value_t=torch.tensor([1])))
return g.op(
'TopK',
input,
k,
axis_i=dim,
largest_i=largest,
sorted_i=sorted,
outputs=2)
def _slice_helper(g,
input,
axes,
starts,
ends,
steps=None,
dynamic_slice=False):
# TODO(ruobing): add support for opset<10
from torch.onnx.symbolic_opset10 import _slice
return _slice(g, input, axes, starts, ends, steps, dynamic_slice)
def _unsqueeze_helper(g, input, dim):
from torch.onnx.symbolic_opset9 import unsqueeze
return unsqueeze(g, input, dim)
def _interpolate_size_to_scales(g, input, output_size, dim):
output_size = _maybe_get_const(output_size, 'is')
if _is_value(output_size):
offset = 2
offsets = g.op(
'Constant', value_t=torch.ones(offset, dtype=torch.float32))
dividend = g.op(
'Cast', output_size, to_i=cast_pytorch_to_onnx['Float'])
divisor = _slice_helper(
g, g.op('Shape', input), axes=[0], ends=[maxsize], starts=[offset])
divisor = g.op('Cast', divisor, to_i=cast_pytorch_to_onnx['Float'])
scale_dims = g.op('Div', dividend, divisor)
scales = g.op('Concat', offsets, scale_dims, axis_i=0)
else:
scales_constant = [
1. if i < 2 else float(output_size[-(dim - i)]) /
float(input.type().sizes()[-(dim - i)]) for i in range(0, dim)
]
scales = g.op(
'Constant',
value_t=torch.tensor(scales_constant, dtype=torch.float32))
return scales
def _interpolate_get_scales_if_available(g, scales):
if len(scales) == 0:
return None
# scales[0] is NoneType in Pytorch == 1.5.1
# scales[0] is TensorType with sizes = [] in Pytorch == 1.6.0
# scales[0] is ListType in Pytorch == 1.7.0
# scales[0] is TensorType with sizes = [2] in Pytorch == 1.8.0
scale_desc = 'fs' if scales[0].type().kind() == 'ListType' or (
scales[0].type().kind() == 'TensorType' and
(sum(scales[0].type().sizes()) > 1)) else 'f'
available_scales = _maybe_get_const(
scales[0], scale_desc) != -1 and not _is_none(scales[0])
if not available_scales:
return None
offsets = g.op('Constant', value_t=torch.ones(2, dtype=torch.float32))
if scale_desc == 'fs':
scales_list = g.op(
'Constant',
value_t=torch.tensor(_maybe_get_const(scales[0], scale_desc)))
# modify to support PyTorch==1.7.0
# https://github.com/pytorch/pytorch/blob/75ee5756715e7161314ce037474843b68f69fc04/torch/onnx/symbolic_helper.py#L375 # noqa: E501
scales = g.op('Concat', offsets, scales_list, axis_i=0)
else:
# for PyTorch < 1.7.0
scales_list = []
for scale in scales:
unsqueezed_scale = _unsqueeze_helper(g, scale, 0)
# ONNX only supports float for the scales. double -> float.
unsqueezed_scale = g.op(
'Cast', unsqueezed_scale, to_i=cast_pytorch_to_onnx['Float'])
scales_list.append(unsqueezed_scale)
scales = g.op('Concat', offsets, *scales_list, axis_i=0)
return scales
def _get_interpolate_attributes(g, mode, args):
if mode == 'nearest':
align_corners = None
scales = args[0:]
else:
align_corners = args[0]
scales = args[1:]
scales = _interpolate_get_scales_if_available(g, scales)
return scales, align_corners
def _interpolate_get_scales(g, scale_factor, dim):
offsets = g.op('Constant', value_t=torch.ones(2, dtype=torch.float32))
if isinstance(scale_factor.type(), torch._C.ListType):
return g.op('Concat', offsets, scale_factor, axis_i=0)
else:
scale_factor = _unsqueeze_helper(g, scale_factor, 0)
scale_factor = g.op(
'Cast', scale_factor, to_i=cast_pytorch_to_onnx['Float'])
scales = [scale_factor for i in range(dim - 2)]
scale_factor = g.op('Concat', offsets, *scales, axis_i=0)
return scale_factor
def _size_helper(g, self, dim):
full_shape = g.op('Shape', self)
from torch.onnx.symbolic_opset9 import select
return select(g, full_shape, g.op('Constant', value_t=torch.tensor([0])),
dim)
def _avgpool_helper(tuple_fn, padding, kernel_size, stride, divisor_override,
name):
if divisor_override and divisor_override.node().kind() != 'prim::Constant':
return _unimplemented(name, 'divisor_override')
if not stride:
stride = kernel_size
padding = tuple(tuple_fn(padding))
return padding
# Metaprogram symbolics for each ATen native specialized cast operator.
# For e.g. we specify a function named `_cast_uint8_t` that instantiates an
# ONNX cast node with `to` attribute 'UINT8'
#
# TODO: remove these once we support Type's in the JIT IR and we can once again
# use the unified toType operator
cast_pytorch_to_onnx = {
'Byte': torch.onnx.TensorProtoDataType.UINT8,
'Char': torch.onnx.TensorProtoDataType.INT8,
'Double': torch.onnx.TensorProtoDataType.DOUBLE,
'Float': torch.onnx.TensorProtoDataType.FLOAT,
'Half': torch.onnx.TensorProtoDataType.FLOAT16,
'Int': torch.onnx.TensorProtoDataType.INT32,
'Long': torch.onnx.TensorProtoDataType.INT64,
'Short': torch.onnx.TensorProtoDataType.INT16,
'Bool': torch.onnx.TensorProtoDataType.BOOL,
'ComplexFloat': torch.onnx.TensorProtoDataType.COMPLEX64,
'ComplexDouble': torch.onnx.TensorProtoDataType.COMPLEX128,
'Undefined': torch.onnx.TensorProtoDataType.UNDEFINED,
}
# Global set to store the list of quantized operators in the network.
# This is currently only used in the conversion of quantized ops from PT
# -> C2 via ONNX.
_quantized_ops: set = set()
# Copyright (c) OpenMMLab. All rights reserved.
"""Modified from https://github.com/pytorch/pytorch."""
import os
import warnings
import numpy as np
import torch
from torch.nn.modules.utils import _pair, _single, _triple
from torch.onnx.symbolic_helper import parse_args
from torch.onnx.symbolic_registry import register_op
from .onnx_utils import symbolic_helper as sym_help
def _interpolate(name, dim, interpolate_mode):
def symbolic_fn(g, input, output_size, *args):
scales, align_corners = sym_help._get_interpolate_attributes(
g, interpolate_mode, args)
align_corners = sym_help._maybe_get_scalar(align_corners)
transformation_mode = 'asymmetric' \
if interpolate_mode == 'nearest' \
else 'align_corners' if align_corners else 'pytorch_half_pixel'
empty_tensor = g.op(
'Constant', value_t=torch.tensor([], dtype=torch.float32))
if scales is None:
if 'ONNX_BACKEND' in os.environ and os.environ[
'ONNX_BACKEND'] == 'TensorRT':
input_size = input.type().sizes()
# slice the first two dim
input_size = input_size[:2]
# convert output_size to int type
output_size = sym_help._maybe_get_const(output_size, 'is')
input_size.extend(output_size)
output_size = g.op(
'Constant',
value_t=torch.tensor(input_size, dtype=torch.int64))
else:
input_size = g.op('Shape', input)
input_size_beg = sym_help._slice_helper(
g, input_size, axes=[0], ends=[2], starts=[0])
output_size = g.op(
'Cast',
output_size,
to_i=sym_help.cast_pytorch_to_onnx['Long'])
output_size = g.op(
'Concat', input_size_beg, output_size, axis_i=0)
scales = g.op(
'Constant', value_t=torch.tensor([], dtype=torch.float32))
return g.op(
'Resize',
input,
empty_tensor,
# roi only takes effect with
# coordinate_transformation_mode="tf_crop_and_resize"
scales, # scales is not needed since we are sending out_size
output_size,
coordinate_transformation_mode_s=transformation_mode,
cubic_coeff_a_f=-0.75, # only valid when mode="cubic"
mode_s=interpolate_mode, # nearest, linear, or cubic
nearest_mode_s='floor') # only valid when mode="nearest"
else:
return g.op(
'Resize',
input,
empty_tensor,
# roi only takes effect with
# coordinate_transformation_mode="tf_crop_and_resize"
scales, # scales is not needed since we are sending out_size
coordinate_transformation_mode_s=transformation_mode,
cubic_coeff_a_f=-0.75, # only valid when mode="cubic"
mode_s=interpolate_mode, # nearest, linear, or cubic
nearest_mode_s='floor') # only valid when mode="nearest"
return symbolic_fn
upsample_nearest1d = _interpolate('upsample_nearest1d', 3, 'nearest')
upsample_nearest2d = _interpolate('upsample_nearest2d', 4, 'nearest')
upsample_nearest3d = _interpolate('upsample_nearest3d', 5, 'nearest')
upsample_linear1d = _interpolate('upsample_linear1d', 3, 'linear')
upsample_bilinear2d = _interpolate('upsample_bilinear2d', 4, 'linear')
upsample_trilinear3d = _interpolate('upsample_trilinear3d', 5, 'linear')
upsample_bicubic2d = _interpolate('upsample_bicubic2d', 4, 'cubic')
@parse_args('v', 'v', 'i', 'i', 'i', 'none')
def topk(g, self, k, dim, largest, sorted, out=None):
return sym_help._topk_helper(
g, self, k, dim, largest=largest, sorted=sorted, out=out)
def masked_select(g, self, mask):
from torch.onnx.symbolic_opset9 import expand_as, nonzero
index = nonzero(g, expand_as(g, mask, self))
return g.op('GatherND', self, index)
def _prepare_onnx_paddings(g, dim, pad):
pad_len = torch.onnx.symbolic_opset9.size(
g, pad, g.op('Constant', value_t=torch.tensor([0])))
# Set extension = [0] * (dim * 2 - len(pad))
extension = g.op(
'Sub',
g.op('Mul',
g.op('Constant', value_t=torch.tensor(dim, dtype=torch.int64)),
g.op('Constant', value_t=torch.tensor(2, dtype=torch.int64))),
pad_len)
pad = g.op('Cast', pad, to_i=sym_help.cast_pytorch_to_onnx['Long'])
paddings = g.op(
'Concat',
pad,
g.op(
'ConstantOfShape',
extension,
value_t=torch.tensor([0], dtype=torch.int64)),
axis_i=0)
paddings = g.op('Reshape', paddings,
g.op('Constant', value_t=torch.tensor([-1, 2])))
paddings = g.op(
'Transpose',
torch.onnx.symbolic_opset10.flip(g, paddings, [0]),
perm_i=[1, 0])
paddings = g.op('Reshape', paddings,
g.op('Constant', value_t=torch.tensor([-1])))
padding_c = g.op(
'Cast', paddings, to_i=sym_help.cast_pytorch_to_onnx['Long'])
return padding_c
def constant_pad_nd(g, input, padding, value=None):
mode = 'constant'
value = sym_help._maybe_get_scalar(value)
value = sym_help._if_scalar_type_as(g, value, input)
pad = _prepare_onnx_paddings(g, input.type().dim(), padding)
return g.op('Pad', input, pad, value, mode_s=mode)
def reflection_pad(g, input, padding):
mode = 'reflect'
paddings = _prepare_onnx_paddings(g, input.type().dim(), padding)
return g.op('Pad', input, paddings, mode_s=mode)
reflection_pad1d = reflection_pad
reflection_pad2d = reflection_pad
reflection_pad3d = reflection_pad
def _avg_pool(name, tuple_fn):
@parse_args('v', 'is', 'is', 'is', 'i', 'i', 'none')
def symbolic_fn(g,
input,
kernel_size,
stride,
padding,
ceil_mode,
count_include_pad,
divisor_override=None):
padding = sym_help._avgpool_helper(tuple_fn, padding, kernel_size,
stride, divisor_override, name)
if not stride:
stride = kernel_size
if count_include_pad:
input = g.op(
'Pad',
input,
g.op(
'Constant',
value_t=torch.tensor(((0, ) * 2 + padding) * 2)),
mode_s='constant')
padding = (0, ) * len(padding)
output = g.op(
'AveragePool',
input,
kernel_shape_i=tuple_fn(kernel_size),
strides_i=tuple_fn(stride),
pads_i=padding * 2,
ceil_mode_i=ceil_mode)
return output
return symbolic_fn
avg_pool1d = _avg_pool('avg_pool1d', _single)
avg_pool2d = _avg_pool('avg_pool2d', _pair)
avg_pool3d = _avg_pool('avg_pool3d', _triple)
def _get_im2col_indices_along_dim(g, input_d, kernel_size_d, dilation_d,
padding_d, stride_d):
# Input is always 4-D (N, C, H, W)
# Calculate indices of sliding blocks along spatial dimension
# Slide kernel over input each dim d:
# each dimension d ranges from 0 to
# input[d]+2xpadding[d]-dilation[d]x(kernel_size[d]-1)
# with steps = stride
blocks_d = g.op('Add', input_d,
g.op('Constant', value_t=torch.tensor(padding_d * 2)))
blocks_d = g.op(
'Sub', blocks_d,
g.op(
'Constant',
value_t=torch.tensor(dilation_d * (kernel_size_d - 1))))
# Stride kernel over input and find starting indices along dim d
blocks_d_indices = g.op('Range', g.op('Constant', value_t=torch.tensor(0)),
blocks_d,
g.op('Constant', value_t=torch.tensor(stride_d)))
# Apply dilation on kernel and find its indices along dim d
kernel_grid = np.arange(0, kernel_size_d * dilation_d, dilation_d)
kernel_grid = g.op('Constant', value_t=torch.tensor([kernel_grid]))
# Broadcast and add kernel staring positions (indices) with
# kernel_grid along dim d, to get block indices along dim d
blocks_d_indices = g.op(
'Unsqueeze', blocks_d_indices, axes_i=[0]) # Reshape to [1, -1]
kernel_mask = g.op('Reshape', kernel_grid,
g.op('Constant', value_t=torch.tensor([-1, 1])))
block_mask = g.op('Add', blocks_d_indices, kernel_mask)
return block_mask
def _get_im2col_padded_input(g, input, padding_h, padding_w):
# Input is always 4-D tensor (N, C, H, W)
# Padding tensor has the following format: (padding_h, padding_w)
# Reshape the padding to follow ONNX format:
# (dim1_begin, dim2_begin,...,dim1_end, dim2_end,...)
pad = g.op(
'Constant', value_t=torch.LongTensor([0, 0, padding_h, padding_w] * 2))
return g.op('Pad', input, pad)
def _get_im2col_output_shape(g, input, kernel_h, kernel_w):
batch_dim = size(g, input, g.op('Constant', value_t=torch.tensor(0)))
channel_dim = size(g, input, g.op('Constant', value_t=torch.tensor(1)))
channel_unfolded = g.op(
'Mul', channel_dim,
g.op('Constant', value_t=torch.tensor(kernel_h * kernel_w)))
return g.op(
'Concat',
g.op('Unsqueeze', batch_dim, axes_i=[0]),
g.op('Unsqueeze', channel_unfolded, axes_i=[0]),
g.op('Constant', value_t=torch.tensor([-1])),
axis_i=0)
def size(g, self, dim=None):
if dim is None:
return g.op('Shape', self)
return sym_help._size_helper(g, self, dim)
@parse_args('v', 'is', 'is', 'is', 'is')
def im2col(g, input, kernel_size, dilation, padding, stride):
# Input is always 4-D tensor (N, C, H, W)
# All other args are int[2]
input_h = size(g, input, g.op('Constant', value_t=torch.tensor(2)))
input_w = size(g, input, g.op('Constant', value_t=torch.tensor(3)))
stride_h, stride_w = stride[0], stride[1]
padding_h, padding_w = padding[0], padding[1]
dilation_h, dilation_w = dilation[0], dilation[1]
kernel_h, kernel_w = kernel_size[0], kernel_size[1]
blocks_row_indices = _get_im2col_indices_along_dim(g, input_h, kernel_h,
dilation_h, padding_h,
stride_h)
blocks_col_indices = _get_im2col_indices_along_dim(g, input_w, kernel_w,
dilation_w, padding_w,
stride_w)
output_shape = _get_im2col_output_shape(g, input, kernel_h, kernel_w)
padded_input = _get_im2col_padded_input(g, input, padding_h, padding_w)
output = g.op('Gather', padded_input, blocks_row_indices, axis_i=2)
output = g.op('Gather', output, blocks_col_indices, axis_i=4)
output = g.op('Transpose', output, perm_i=[0, 1, 2, 4, 3, 5])
return g.op('Reshape', output, output_shape)
@parse_args('v', 'i')
def one_hot(g, self, num_classes):
values = g.op('Constant', value_t=torch.LongTensor([0, 1]))
depth = g.op('Constant', value_t=torch.LongTensor([num_classes]))
return g.op('OneHot', self, depth, values, axis_i=-1)
@parse_args('v', 'i', 'none')
def softmax(g, input, dim, dtype=None):
input_dim = input.type().dim()
if input_dim:
# TODO: remove this as onnx opset 11 spec allows negative axes
if dim < 0:
dim = input_dim + dim
if input_dim == dim + 1:
softmax = g.op('Softmax', input, axis_i=dim)
if dtype and dtype.node().kind() != 'prim::Constant':
parsed_dtype = sym_help._get_const(dtype, 'i', 'dtype')
softmax = g.op(
'Cast',
softmax,
to_i=sym_help.scalar_type_to_onnx[parsed_dtype])
return softmax
max_value = g.op('ReduceMax', input, axes_i=[dim], keepdims_i=1)
input = g.op('Sub', input, max_value)
exp = g.op('Exp', input)
sum = g.op('ReduceSum', exp, axes_i=[dim])
softmax = g.op('Div', exp, sum)
if dtype and dtype.node().kind() != 'prim::Constant':
parsed_dtype = sym_help._get_const(dtype, 'i', 'dtype')
softmax = g.op(
'Cast', softmax, to_i=sym_help.scalar_type_to_onnx[parsed_dtype])
return softmax
def _adaptive_pool(name, type, tuple_fn, fn=None):
@parse_args('v', 'is')
def symbolic_fn(g, input, output_size):
if output_size == [1] * len(output_size) and type == 'AveragePool':
return g.op('GlobalAveragePool', input)
if not input.isCompleteTensor():
if output_size == [1] * len(output_size):
return g.op('GlobalMaxPool', input), None
raise NotImplementedError(
'[Adaptive pool]:input size not accessible')
dim = input.type().sizes()[2:]
if output_size == [1] * len(output_size) and type == 'MaxPool':
return g.op('GlobalMaxPool', input), None
# compute stride = floor(input_size / output_size)
s = [int(dim[i] / output_size[i]) for i in range(0, len(dim))]
# compute kernel_size = input_size - (output_size - 1) * stride
k = [dim[i] - (output_size[i] - 1) * s[i] for i in range(0, len(dim))]
# call max_poolxd_with_indices to get indices in the output
if type == 'MaxPool':
return fn(g, input, k, k, (0, ) * len(dim), (1, ) * len(dim),
False)
output = g.op(
type,
input,
kernel_shape_i=tuple_fn(k),
strides_i=tuple_fn(s),
ceil_mode_i=False)
return output
return symbolic_fn
adaptive_avg_pool1d = _adaptive_pool('adaptive_avg_pool1d', 'AveragePool',
_single)
adaptive_avg_pool2d = _adaptive_pool('adaptive_avg_pool2d', 'AveragePool',
_pair)
adaptive_avg_pool3d = _adaptive_pool('adaptive_avg_pool3d', 'AveragePool',
_triple)
def new_full(g,
self,
size,
fill_value,
dtype,
layout,
device,
pin_memory=False):
from torch.onnx.symbolic_opset9 import full
if dtype is None and self.isCompleteTensor():
dtype = self.type().scalarType()
dtype = sym_help.scalar_type_to_onnx.index(
sym_help.cast_pytorch_to_onnx[dtype])
return full(g, size, fill_value, dtype, layout, device, pin_memory)
@parse_args('v', 'v', 'i', 'i', 'i')
def grid_sampler(g,
input,
grid,
interpolation_mode,
padding_mode,
align_corners=False):
return g.op(
'mmcv::grid_sampler',
input,
grid,
interpolation_mode_i=interpolation_mode,
padding_mode_i=padding_mode,
align_corners_i=align_corners)
@parse_args('v', 'i')
def cummax(g, input, dim):
return g.op('mmcv::cummax', input, dim_i=dim, outputs=2)
@parse_args('v', 'i')
def cummin(g, input, dim):
return g.op('mmcv::cummin', input, dim_i=dim, outputs=2)
@parse_args('v', 'v', 'is')
def roll(g, input, shifts, dims):
from packaging import version
from torch.onnx.symbolic_opset9 import squeeze
input_shape = g.op('Shape', input)
need_flatten = len(dims) == 0
# If dims is not specified, the tensor will be flattened before
# rolling and then restored to the original shape.
if need_flatten:
resize_shape = input_shape
input = g.op('Reshape', input,
g.op('Constant', value_t=torch.LongTensor([1, -1])))
input_shape = g.op('Shape', input)
dims = [1]
for index, dim in enumerate(dims):
end_size = sym_help._slice_helper(
g, input_shape, axes=[0], ends=[dim + 1], starts=[dim])
shift_size = sym_help._slice_helper(
g, shifts, axes=[0], ends=[index + 1], starts=[index])
slice_size = g.op('Sub', end_size, shift_size)
# Can not use Mod because tensorrt does not support
div_size = g.op('Div', slice_size, end_size)
slice_size = g.op('Sub', slice_size, g.op('Mul', end_size, div_size))
if version.parse(torch.__version__) >= version.parse('1.7.0'):
# add dim=0 for pytorch 1.9.0
end_size = squeeze(g, end_size, 0)
slice_size = squeeze(g, slice_size, 0)
else:
end_size = g.op('Squeeze', end_size)
slice_size = g.op('Squeeze', slice_size)
dim = torch.LongTensor([dim])
input_slice0 = sym_help._slice_helper(
g,
input,
axes=dim,
starts=torch.LongTensor([0]),
ends=slice_size,
dynamic_slice=True)
input_slice1 = sym_help._slice_helper(
g,
input,
axes=dim,
ends=end_size,
starts=slice_size,
dynamic_slice=True)
input = g.op('Concat', input_slice1, input_slice0, axis_i=dim)
if need_flatten:
input = g.op('Reshape', input, resize_shape)
return input
def register_extra_symbolics(opset=11):
# Following strings of text style are from colorama package
bright_style, reset_style = '\x1b[1m', '\x1b[0m'
red_text, blue_text = '\x1b[31m', '\x1b[34m'
white_background = '\x1b[107m'
msg = white_background + bright_style + red_text
msg += 'DeprecationWarning: This function will be deprecated in future. '
msg += blue_text + 'Welcome to use the unified model deployment toolbox '
msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy'
msg += reset_style
warnings.warn(msg)
register_op('one_hot', one_hot, '', opset)
register_op('im2col', im2col, '', opset)
register_op('topk', topk, '', opset)
register_op('softmax', softmax, '', opset)
register_op('constant_pad_nd', constant_pad_nd, '', opset)
register_op('reflection_pad1d', reflection_pad1d, '', opset)
register_op('reflection_pad2d', reflection_pad2d, '', opset)
register_op('reflection_pad3d', reflection_pad3d, '', opset)
register_op('avg_pool1d', avg_pool1d, '', opset)
register_op('avg_pool2d', avg_pool2d, '', opset)
register_op('avg_pool3d', avg_pool3d, '', opset)
register_op('adaptive_avg_pool1d', adaptive_avg_pool1d, '', opset)
register_op('adaptive_avg_pool2d', adaptive_avg_pool2d, '', opset)
register_op('adaptive_avg_pool3d', adaptive_avg_pool3d, '', opset)
register_op('masked_select', masked_select, '', opset)
register_op('upsample_nearest1d', upsample_nearest1d, '', opset)
register_op('upsample_nearest2d', upsample_nearest2d, '', opset)
register_op('upsample_nearest3d', upsample_nearest3d, '', opset)
register_op('upsample_linear1d', upsample_linear1d, '', opset)
register_op('upsample_bilinear2d', upsample_bilinear2d, '', opset)
register_op('upsample_trilinear3d', upsample_trilinear3d, '', opset)
register_op('upsample_bicubic2d', upsample_bicubic2d, '', opset)
register_op('new_full', new_full, '', opset)
register_op('grid_sampler', grid_sampler, '', opset)
register_op('cummax', cummax, '', opset)
register_op('cummin', cummin, '', opset)
register_op('roll', roll, '', opset)
...@@ -33,38 +33,6 @@ class NMSop(torch.autograd.Function): ...@@ -33,38 +33,6 @@ class NMSop(torch.autograd.Function):
inds = valid_inds[inds] inds = valid_inds[inds]
return inds return inds
@staticmethod
def symbolic(g, bboxes, scores, iou_threshold, offset, score_threshold,
max_num):
from torch.onnx.symbolic_opset9 import select, squeeze, unsqueeze
from ..onnx.onnx_utils.symbolic_helper import _size_helper
boxes = unsqueeze(g, bboxes, 0)
scores = unsqueeze(g, unsqueeze(g, scores, 0), 0)
if max_num > 0:
max_num = g.op(
'Constant', value_t=torch.tensor(max_num, dtype=torch.long))
else:
dim = g.op('Constant', value_t=torch.tensor(0))
max_num = _size_helper(g, bboxes, dim)
max_output_per_class = max_num
iou_threshold = g.op(
'Constant',
value_t=torch.tensor([iou_threshold], dtype=torch.float))
score_threshold = g.op(
'Constant',
value_t=torch.tensor([score_threshold], dtype=torch.float))
nms_out = g.op('NonMaxSuppression', boxes, scores,
max_output_per_class, iou_threshold, score_threshold)
return squeeze(
g,
select(
g, nms_out, 1,
g.op('Constant', value_t=torch.tensor([2], dtype=torch.long))),
1)
class SoftNMSop(torch.autograd.Function): class SoftNMSop(torch.autograd.Function):
...@@ -330,7 +298,7 @@ def batched_nms(boxes: Tensor, ...@@ -330,7 +298,7 @@ def batched_nms(boxes: Tensor,
split_thr = nms_cfg_.pop('split_thr', 10000) split_thr = nms_cfg_.pop('split_thr', 10000)
# Won't split to multiple nms nodes when exporting to onnx # Won't split to multiple nms nodes when exporting to onnx
if boxes_for_nms.shape[0] < split_thr or torch.onnx.is_in_onnx_export(): if boxes_for_nms.shape[0] < split_thr:
dets, keep = nms_op(boxes_for_nms, scores, **nms_cfg_) dets, keep = nms_op(boxes_for_nms, scores, **nms_cfg_)
boxes = boxes[keep] boxes = boxes[keep]
......
...@@ -7,7 +7,6 @@ import torch.nn as nn ...@@ -7,7 +7,6 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import Tensor from torch import Tensor
from torch.nn.modules.utils import _pair from torch.nn.modules.utils import _pair
from torch.onnx.operators import shape_as_tensor
def bilinear_grid_sample(im: Tensor, def bilinear_grid_sample(im: Tensor,
...@@ -178,12 +177,8 @@ def get_shape_from_feature_map(x: Tensor) -> Tensor: ...@@ -178,12 +177,8 @@ def get_shape_from_feature_map(x: Tensor) -> Tensor:
Returns: Returns:
torch.Tensor: Spatial resolution (width, height), shape (1, 1, 2) torch.Tensor: Spatial resolution (width, height), shape (1, 1, 2)
""" """
if torch.onnx.is_in_onnx_export(): img_shape = torch.tensor(x.shape[2:]).flip(0).view(1, 1,
img_shape = shape_as_tensor(x)[2:].flip(0).view(1, 1, 2).to( 2).to(x.device).float()
x.device).float()
else:
img_shape = torch.tensor(x.shape[2:]).flip(0).view(1, 1, 2).to(
x.device).float()
return img_shape return img_shape
...@@ -272,15 +267,8 @@ def point_sample(input: Tensor, ...@@ -272,15 +267,8 @@ def point_sample(input: Tensor,
if points.dim() == 3: if points.dim() == 3:
add_dim = True add_dim = True
points = points.unsqueeze(2) points = points.unsqueeze(2)
if torch.onnx.is_in_onnx_export(): output = F.grid_sample(
# If custom ops for onnx runtime not compiled use python input, denormalize(points), align_corners=align_corners, **kwargs)
# implementation of grid_sample function to make onnx graph
# with supported nodes
output = bilinear_grid_sample(
input, denormalize(points), align_corners=align_corners)
else:
output = F.grid_sample(
input, denormalize(points), align_corners=align_corners, **kwargs)
if add_dim: if add_dim:
output = output.squeeze(3) output = output.squeeze(3)
return output return output
...@@ -315,33 +303,25 @@ class SimpleRoIAlign(nn.Module): ...@@ -315,33 +303,25 @@ class SimpleRoIAlign(nn.Module):
rel_roi_points = generate_grid( rel_roi_points = generate_grid(
num_rois, self.output_size, device=rois.device) num_rois, self.output_size, device=rois.device)
if torch.onnx.is_in_onnx_export(): point_feats = []
rel_img_points = rel_roi_point_to_rel_img_point( for batch_ind in range(num_imgs):
rois, rel_roi_points, features, self.spatial_scale) # unravel batch dim
rel_img_points = rel_img_points.reshape(num_imgs, -1, feat = features[batch_ind].unsqueeze(0)
*rel_img_points.shape[1:]) inds = (rois[:, 0].long() == batch_ind)
point_feats = point_sample( if inds.any():
features, rel_img_points, align_corners=not self.aligned) rel_img_points = rel_roi_point_to_rel_img_point(
point_feats = point_feats.transpose(1, 2) rois[inds], rel_roi_points[inds], feat,
else: self.spatial_scale).unsqueeze(0)
point_feats = [] point_feat = point_sample(
for batch_ind in range(num_imgs): feat, rel_img_points, align_corners=not self.aligned)
# unravel batch dim point_feat = point_feat.squeeze(0).transpose(0, 1)
feat = features[batch_ind].unsqueeze(0) point_feats.append(point_feat)
inds = (rois[:, 0].long() == batch_ind)
if inds.any(): point_feats_t = torch.cat(point_feats, dim=0)
rel_img_points = rel_roi_point_to_rel_img_point(
rois[inds], rel_roi_points[inds], feat,
self.spatial_scale).unsqueeze(0)
point_feat = point_sample(
feat, rel_img_points, align_corners=not self.aligned)
point_feat = point_feat.squeeze(0).transpose(0, 1)
point_feats.append(point_feat)
point_feats = torch.cat(point_feats, dim=0)
channels = features.size(1) channels = features.size(1)
roi_feats = point_feats.reshape(num_rois, channels, *self.output_size) roi_feats = point_feats_t.reshape(num_rois, channels,
*self.output_size)
return roi_feats return roi_feats
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import os import os
from functools import partial
import numpy as np import numpy as np
import onnx import onnx
...@@ -8,7 +7,6 @@ import onnxruntime as rt ...@@ -8,7 +7,6 @@ import onnxruntime as rt
import pytest import pytest
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
onnx_file = 'tmp.onnx' onnx_file = 'tmp.onnx'
if torch.__version__ == 'parrots': if torch.__version__ == 'parrots':
...@@ -38,49 +36,6 @@ class WrapFunction(nn.Module): ...@@ -38,49 +36,6 @@ class WrapFunction(nn.Module):
return self.wrapped_function(*args, **kwargs) return self.wrapped_function(*args, **kwargs)
def test_nms():
from mmcv.ops import nms
np_boxes = np.array([[6.0, 3.0, 8.0, 7.0], [3.0, 6.0, 9.0, 11.0],
[3.0, 7.0, 10.0, 12.0], [1.0, 4.0, 13.0, 7.0]],
dtype=np.float32)
np_scores = np.array([0.6, 0.9, 0.7, 0.2], dtype=np.float32)
boxes = torch.from_numpy(np_boxes)
scores = torch.from_numpy(np_scores)
nms = partial(
nms, iou_threshold=0.3, offset=0, score_threshold=0, max_num=0)
pytorch_dets, _ = nms(boxes, scores)
pytorch_score = pytorch_dets[:, 4]
wrapped_model = WrapFunction(nms)
wrapped_model.cpu().eval()
with torch.no_grad():
torch.onnx.export(
wrapped_model, (boxes, scores),
onnx_file,
export_params=True,
keep_initializers_as_inputs=True,
input_names=['boxes', 'scores'],
opset_version=11)
onnx_model = onnx.load(onnx_file)
session_options = rt.SessionOptions()
# get onnx output
input_all = [node.name for node in onnx_model.graph.input]
input_initializer = [node.name for node in onnx_model.graph.initializer]
net_feed_input = list(set(input_all) - set(input_initializer))
assert (len(net_feed_input) == 2)
sess = rt.InferenceSession(
onnx_file, session_options, providers=['CPUExecutionProvider'])
onnx_dets, _ = sess.run(None, {
'scores': scores.detach().numpy(),
'boxes': boxes.detach().numpy()
})
onnx_score = onnx_dets[:, 4]
assert np.allclose(pytorch_score, onnx_score, atol=1e-3)
def test_roialign(): def test_roialign():
try: try:
from mmcv.ops import roi_align from mmcv.ops import roi_align
...@@ -212,69 +167,6 @@ def test_roipool(): ...@@ -212,69 +167,6 @@ def test_roipool():
assert np.allclose(pytorch_output, onnx_output, atol=1e-3) assert np.allclose(pytorch_output, onnx_output, atol=1e-3)
def test_interpolate():
from mmcv.onnx.symbolic import register_extra_symbolics
opset_version = 11
register_extra_symbolics(opset_version)
def func(feat, scale_factor=2):
out = F.interpolate(feat, scale_factor=scale_factor)
return out
net = WrapFunction(func)
net = net.cpu().eval()
dummy_input = torch.randn(2, 4, 8, 8).cpu()
torch.onnx.export(
net,
dummy_input,
onnx_file,
input_names=['input'],
opset_version=opset_version)
sess = rt.InferenceSession(onnx_file, providers=['CPUExecutionProvider'])
onnx_result = sess.run(None, {'input': dummy_input.detach().numpy()})
pytorch_result = func(dummy_input).detach().numpy()
assert np.allclose(pytorch_result, onnx_result, atol=1e-3)
@pytest.mark.parametrize('shifts_dims_pair', [([-3, 5], [2, 0]), (5, None)])
def test_roll(shifts_dims_pair):
opset = 11
from mmcv.onnx.symbolic import register_extra_symbolics
register_extra_symbolics(opset)
input = torch.arange(0, 4 * 5 * 6, dtype=torch.float32).view(4, 5, 6)
shifts, dims = shifts_dims_pair
func = partial(torch.roll, shifts=shifts, dims=dims)
wrapped_model = WrapFunction(func).eval()
with torch.no_grad():
torch.onnx.export(
wrapped_model,
input,
onnx_file,
export_params=True,
keep_initializers_as_inputs=True,
input_names=['input'],
output_names=['output'],
opset_version=opset)
onnx_model = onnx.load(onnx_file)
input_all = [node.name for node in onnx_model.graph.input]
input_initializer = [node.name for node in onnx_model.graph.initializer]
net_feed_input = list(set(input_all) - set(input_initializer))
assert (len(net_feed_input) == 1)
sess = rt.InferenceSession(onnx_file, providers=['CPUExecutionProvider'])
ort_output = sess.run(None, {'input': input.detach().numpy()})[0]
with torch.no_grad():
pytorch_output = wrapped_model(input.clone())
torch.testing.assert_allclose(ort_output, pytorch_output)
def _test_symbolic(model, inputs, symbol_name): def _test_symbolic(model, inputs, symbol_name):
with torch.no_grad(): with torch.no_grad():
torch.onnx.export(model, inputs, onnx_file, opset_version=11) torch.onnx.export(model, inputs, onnx_file, opset_version=11)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment