Unverified commit a88d1d28, authored by Dmitry Sidnev, committed by GitHub

[Feature] enable exporting to onnx for PointRend (#953)



* Fix export to onnx for PointRend

* Fix codestyle

* Fix codestyle

* Fix type in docstring

* Minor fix

* Fix export with custom ops

* Fix codestyle

* Add tests for bilinear_grid_sample function

* Remove redundant operation and rename variables

* Fix bug in bilinear_grid_sample and update test

* Fix getting batch size

* skip torch==1.3.1

* remove unused import

* fix lint

* support export with batch

* fix dynamic clip

* skip test for torch<1.5.0

* Add docstrings and comments

* Minor fix

* Recover clipping code

* Fix clamping in pytorch 1.7.0

* Fix bilinear_grid_sampler

* Minor fix
Co-authored-by: maningsheng <maningsheng@sensetime.com>
parent 1076958c
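The diff below (mmcv/ops/point_sample.py plus the accompanying tests) adds a pure-PyTorch `bilinear_grid_sample` fallback and routes `point_sample` and `SimpleRoIAlign` through ONNX-friendly code paths during export. As a rough illustration of what this enables, here is a minimal export sketch. It is not part of the commit; it assumes mmcv with this patch is installed, and the shapes and output filename are illustrative:

    import torch
    from mmcv.ops import SimpleRoIAlign

    # SimpleRoIAlign samples RoI features via point_sample under the hood.
    model = SimpleRoIAlign(output_size=7, spatial_scale=1.0).eval()

    features = torch.rand(1, 4, 32, 32)
    # One RoI in (batch_idx, x1, y1, x2, y2) format.
    rois = torch.tensor([[0., 2., 2., 20., 20.]])

    # With this commit the graph exports even when the onnxruntime
    # custom ops are not compiled.
    torch.onnx.export(
        model, (features, rois), 'simple_roi_align.onnx',
        input_names=['features', 'rois'], opset_version=11)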
mmcv/ops/point_sample.py:

 # Modified from https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend  # noqa
+from os import path as osp
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn.modules.utils import _pair
+from torch.onnx.operators import shape_as_tensor
+
+
+def bilinear_grid_sample(im, grid, align_corners=False):
+    """Given an input and a flow-field grid, compute the output using input
+    values and pixel locations from the grid. Only bilinear interpolation is
+    supported for sampling the input pixels.
+
+    Args:
+        im (torch.Tensor): Input feature map, shape (N, C, H, W)
+        grid (torch.Tensor): Point coordinates, shape (N, Hg, Wg, 2)
+        align_corners (bool): If set to True, the extrema (-1 and 1) are
+            considered as referring to the center points of the input's
+            corner pixels. If set to False, they are instead considered as
+            referring to the corner points of the input's corner pixels,
+            making the sampling more resolution agnostic.
+
+    Returns:
+        torch.Tensor: A tensor with sampled points, shape (N, C, Hg, Wg)
+    """
+    n, c, h, w = im.shape
+    gn, gh, gw, _ = grid.shape
+    assert n == gn
+
+    x = grid[:, :, :, 0]
+    y = grid[:, :, :, 1]
+
+    # Map normalized coordinates in [-1, 1] to pixel coordinates.
+    if align_corners:
+        x = ((x + 1) / 2) * (w - 1)
+        y = ((y + 1) / 2) * (h - 1)
+    else:
+        x = ((x + 1) * w - 1) / 2
+        y = ((y + 1) * h - 1) / 2
+
+    x = x.view(n, -1)
+    y = y.view(n, -1)
+
+    x0 = torch.floor(x).long()
+    y0 = torch.floor(y).long()
+    x1 = x0 + 1
+    y1 = y0 + 1
+
+    # Bilinear interpolation weights of the four neighboring pixels.
+    wa = ((x1 - x) * (y1 - y)).unsqueeze(1)
+    wb = ((x1 - x) * (y - y0)).unsqueeze(1)
+    wc = ((x - x0) * (y1 - y)).unsqueeze(1)
+    wd = ((x - x0) * (y - y0)).unsqueeze(1)
+
+    # Apply zero padding, matching grid_sample's default padding_mode='zeros'
+    im_padded = F.pad(im, pad=[1, 1, 1, 1], mode='constant', value=0)
+    padded_h = h + 2
+    padded_w = w + 2
+    # Shift the point positions to account for the padding
+    x0, x1, y0, y1 = x0 + 1, x1 + 1, y0 + 1, y1 + 1
+
+    # Clip coordinates to the padded image size
+    x0 = torch.where(x0 < 0, torch.tensor(0), x0)
+    x0 = torch.where(x0 > padded_w - 1, torch.tensor(padded_w - 1), x0)
+    x1 = torch.where(x1 < 0, torch.tensor(0), x1)
+    x1 = torch.where(x1 > padded_w - 1, torch.tensor(padded_w - 1), x1)
+    y0 = torch.where(y0 < 0, torch.tensor(0), y0)
+    y0 = torch.where(y0 > padded_h - 1, torch.tensor(padded_h - 1), y0)
+    y1 = torch.where(y1 < 0, torch.tensor(0), y1)
+    y1 = torch.where(y1 > padded_h - 1, torch.tensor(padded_h - 1), y1)
+
+    im_padded = im_padded.view(n, c, -1)
+
+    # Flattened indices of the four neighbors, expanded over channels.
+    x0_y0 = (x0 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
+    x0_y1 = (x0 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)
+    x1_y0 = (x1 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
+    x1_y1 = (x1 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)
+
+    Ia = torch.gather(im_padded, 2, x0_y0)
+    Ib = torch.gather(im_padded, 2, x0_y1)
+    Ic = torch.gather(im_padded, 2, x1_y0)
+    Id = torch.gather(im_padded, 2, x1_y1)
+
+    return (Ia * wa + Ib * wb + Ic * wc + Id * wd).reshape(n, c, gh, gw)
+
+
+def is_in_onnx_export_without_custom_ops():
+    from mmcv.ops import get_onnxruntime_op_path
+    ort_custom_op_path = get_onnxruntime_op_path()
+    return torch.onnx.is_in_onnx_export(
+    ) and not osp.exists(ort_custom_op_path)
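`bilinear_grid_sample` is intended as a drop-in, ONNX-exportable replacement for `F.grid_sample` with `mode='bilinear'` and `padding_mode='zeros'`. A quick equivalence check, sketched here for illustration (it mirrors the unit test added further below):

    import torch
    import torch.nn.functional as F
    from mmcv.ops.point_sample import bilinear_grid_sample

    im = torch.rand(2, 3, 16, 16)
    # Normalized sampling grid in [-1, 1], shape (N, Hg, Wg, 2)
    grid = torch.rand(2, 8, 8, 2) * 2 - 1

    out = bilinear_grid_sample(im, grid, align_corners=False)
    ref = F.grid_sample(
        im, grid, mode='bilinear', padding_mode='zeros', align_corners=False)
    assert torch.allclose(out, ref, atol=1e-3)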
 def normalize(grid):

@@ -70,25 +155,42 @@ def rel_roi_point_to_abs_img_point(rois, rel_roi_points):
     if rois.size(1) == 5:
         rois = rois[:, 1:]
     abs_img_points = rel_roi_points.clone()
-    abs_img_points[:, :, 0] = abs_img_points[:, :, 0] * (
-        rois[:, None, 2] - rois[:, None, 0])
-    abs_img_points[:, :, 1] = abs_img_points[:, :, 1] * (
-        rois[:, None, 3] - rois[:, None, 1])
-    abs_img_points[:, :, 0] += rois[:, None, 0]
-    abs_img_points[:, :, 1] += rois[:, None, 1]
+    # To avoid an error during export to onnx, use independent
+    # variables instead of in-place computation
+    xs = abs_img_points[:, :, 0] * (rois[:, None, 2] - rois[:, None, 0])
+    ys = abs_img_points[:, :, 1] * (rois[:, None, 3] - rois[:, None, 1])
+    xs += rois[:, None, 0]
+    ys += rois[:, None, 1]
+    abs_img_points = torch.stack([xs, ys], dim=2)
     return abs_img_points
-def abs_img_point_to_rel_img_point(abs_img_points,
-                                   img_shape,
-                                   spatial_scale=1.):
+def get_shape_from_feature_map(x):
+    """Get the spatial resolution of the input feature map, taking ONNX
+    export mode into account.
+
+    Args:
+        x (torch.Tensor): Input tensor, shape (N, C, H, W)
+
+    Returns:
+        torch.Tensor: Spatial resolution (width, height), shape (1, 1, 2)
+    """
+    if torch.onnx.is_in_onnx_export():
+        img_shape = shape_as_tensor(x)[2:].flip(0).view(1, 1, 2).to(
+            x.device).float()
+    else:
+        img_shape = torch.tensor(x.shape[2:]).flip(0).view(1, 1, 2).to(
+            x.device).float()
+    return img_shape
+
+
+def abs_img_point_to_rel_img_point(abs_img_points, img, spatial_scale=1.):
"""Convert image based absolute point coordinates to image based relative """Convert image based absolute point coordinates to image based relative
coordinates for sampling. coordinates for sampling.
Args: Args:
abs_img_points (Tensor): Image based absolute point coordinates, abs_img_points (Tensor): Image based absolute point coordinates,
shape (N, P, 2) shape (N, P, 2)
img_shape (tuple): (height, width) of image or feature map. img (tuple/Tensor): (height, width) of image or feature map.
spatial_scale (float): Scale points by this factor. Default: 1. spatial_scale (float): Scale points by this factor. Default: 1.
Returns: Returns:
...@@ -96,20 +198,24 @@ def abs_img_point_to_rel_img_point(abs_img_points, ...@@ -96,20 +198,24 @@ def abs_img_point_to_rel_img_point(abs_img_points,
shape (N, P, 2) shape (N, P, 2)
""" """
assert isinstance(img_shape, tuple) and len(img_shape) == 2 assert (isinstance(img, tuple) and len(img) == 2) or \
h, w = img_shape (isinstance(img, torch.Tensor) and len(img.shape) == 4)
scale = torch.tensor([w, h],
dtype=torch.float,
device=abs_img_points.device)
scale = scale.view(1, 1, 2)
rel_img_points = abs_img_points / scale * spatial_scale
return rel_img_points if isinstance(img, tuple):
h, w = img
scale = torch.tensor([w, h],
dtype=torch.float,
device=abs_img_points.device)
scale = scale.view(1, 1, 2)
else:
scale = get_shape_from_feature_map(img)
return abs_img_points / scale * spatial_scale
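With the new signature, `img` may be either a plain (height, width) tuple or the feature map itself; in the latter case the scale is derived through `get_shape_from_feature_map`, which stays traceable during export. A small sketch with illustrative values (both calls divide x by W and y by H):

    import torch
    from mmcv.ops.point_sample import abs_img_point_to_rel_img_point

    abs_points = torch.tensor([[[8., 4.]]])  # shape (1, 1, 2), (x, y)
    feats = torch.rand(1, 3, 16, 32)         # H=16, W=32

    rel_a = abs_img_point_to_rel_img_point(abs_points, (16, 32))
    rel_b = abs_img_point_to_rel_img_point(abs_points, feats)
    assert torch.allclose(rel_a, rel_b)      # tensor([[[0.25, 0.25]]])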
 def rel_roi_point_to_rel_img_point(rois,
                                    rel_roi_points,
-                                   img_shape,
+                                   img,
                                    spatial_scale=1.):
     """Convert roi based relative point coordinates to image based absolute
     point coordinates.

@@ -118,7 +224,7 @@ def rel_roi_point_to_rel_img_point(rois,
         rois (Tensor): RoIs or BBoxes, shape (N, 4) or (N, 5)
         rel_roi_points (Tensor): Point coordinates inside RoI, relative to
             RoI, location, range (0, 1), shape (N, P, 2)
-        img_shape (tuple): (height, width) of image or feature map.
+        img (tuple or torch.Tensor): (height, width) of the image or feature
+            map, or the feature map itself, shape (N, C, H, W).
         spatial_scale (float): Scale points by this factor. Default: 1.

     Returns:
@@ -127,7 +233,7 @@ def rel_roi_point_to_rel_img_point(rois,
     """
     abs_img_point = rel_roi_point_to_abs_img_point(rois, rel_roi_points)
-    rel_img_point = abs_img_point_to_rel_img_point(abs_img_point, img_shape,
+    rel_img_point = abs_img_point_to_rel_img_point(abs_img_point, img,
                                                    spatial_scale)
     return rel_img_point
@@ -153,8 +259,15 @@ def point_sample(input, points, align_corners=False, **kwargs):
     if points.dim() == 3:
         add_dim = True
         points = points.unsqueeze(2)
-    output = F.grid_sample(
-        input, denormalize(points), align_corners=align_corners, **kwargs)
+    if is_in_onnx_export_without_custom_ops():
+        # If the custom ops for onnxruntime are not compiled, use the Python
+        # implementation of grid_sample so that the onnx graph contains only
+        # supported nodes
+        output = bilinear_grid_sample(
+            input, denormalize(points), align_corners=align_corners)
+    else:
+        output = F.grid_sample(
+            input, denormalize(points), align_corners=align_corners, **kwargs)
     if add_dim:
         output = output.squeeze(3)
     return output
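`point_sample` keeps its original interface; only the backend changes when exporting without custom ops. Eager-mode usage for reference (illustrative shapes):

    import torch
    from mmcv.ops import point_sample

    feats = torch.rand(1, 8, 32, 32)
    points = torch.rand(1, 100, 2)  # relative coordinates in (0, 1)

    sampled = point_sample(feats, points)
    print(sampled.shape)  # torch.Size([1, 8, 100])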
@@ -181,29 +294,38 @@ class SimpleRoIAlign(nn.Module):
         self.aligned = aligned

     def forward(self, features, rois):

         num_imgs = features.size(0)
         num_rois = rois.size(0)
         rel_roi_points = generate_grid(
             num_rois, self.output_size, device=rois.device)

-        point_feats = []
-        for batch_ind in range(num_imgs):
-            # unravel batch dim
-            feat = features[batch_ind].unsqueeze(0)
-            inds = (rois[:, 0].long() == batch_ind)
-            if inds.any():
-                rel_img_points = rel_roi_point_to_rel_img_point(
-                    rois[inds], rel_roi_points[inds], feat.shape[2:],
-                    self.spatial_scale).unsqueeze(0)
-                point_feat = point_sample(
-                    feat, rel_img_points, align_corners=not self.aligned)
-                point_feat = point_feat.squeeze(0).transpose(0, 1)
-                point_feats.append(point_feat)
+        if torch.onnx.is_in_onnx_export():
+            rel_img_points = rel_roi_point_to_rel_img_point(
+                rois, rel_roi_points, features, self.spatial_scale)
+            rel_img_points = rel_img_points.reshape(num_imgs, -1,
+                                                    *rel_img_points.shape[1:])
+            point_feats = point_sample(
+                features, rel_img_points, align_corners=not self.aligned)
+            point_feats = point_feats.transpose(1, 2)
+        else:
+            point_feats = []
+            for batch_ind in range(num_imgs):
+                # unravel batch dim
+                feat = features[batch_ind].unsqueeze(0)
+                inds = (rois[:, 0].long() == batch_ind)
+                if inds.any():
+                    rel_img_points = rel_roi_point_to_rel_img_point(
+                        rois[inds], rel_roi_points[inds], feat,
+                        self.spatial_scale).unsqueeze(0)
+                    point_feat = point_sample(
+                        feat, rel_img_points, align_corners=not self.aligned)
+                    point_feat = point_feat.squeeze(0).transpose(0, 1)
+                    point_feats.append(point_feat)
+            point_feats = torch.cat(point_feats, dim=0)

         channels = features.size(1)
-        roi_feats = torch.cat(point_feats, dim=0)
-        roi_feats = roi_feats.reshape(num_rois, channels, *self.output_size)
+        roi_feats = point_feats.reshape(num_rois, channels, *self.output_size)

         return roi_feats
...
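In eager mode the forward pass still loops over images and gathers RoIs per batch index; in export mode it runs a single batched `point_sample` call instead, which keeps data-dependent control flow out of the onnx graph. Either path yields features of shape (num_rois, C, *output_size), e.g. (sketch with illustrative values):

    import torch
    from mmcv.ops import SimpleRoIAlign

    roi_align = SimpleRoIAlign(output_size=2, spatial_scale=1.0)
    features = torch.rand(2, 5, 16, 16)
    # (batch_idx, x1, y1, x2, y2); one RoI per image here.
    rois = torch.tensor([[0., 0., 0., 8., 8.],
                         [1., 4., 4., 12., 12.]])

    roi_feats = roi_align(features, rois)
    print(roi_feats.shape)  # torch.Size([2, 5, 2, 2])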
New test file:

+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class TestBilinearGridSample(object):
+
+    def _test_bilinear_grid_sample(self,
+                                   dtype=torch.float,
+                                   align_corners=False,
+                                   multiplier=1,
+                                   precision=1e-3):
+        from mmcv.ops.point_sample import bilinear_grid_sample
+
+        input = torch.rand(1, 1, 20, 20, dtype=dtype)
+        grid = torch.Tensor([[[1, 0, 0], [0, 1, 0]]])
+        grid = nn.functional.affine_grid(grid, (1, 1, 15, 15)).type_as(input)
+        grid *= multiplier
+
+        out = bilinear_grid_sample(input, grid, align_corners=align_corners)
+        ref_out = F.grid_sample(input, grid, align_corners=align_corners)
+
+        assert np.allclose(out.data.detach().cpu().numpy(),
+                           ref_out.data.detach().cpu().numpy(),
+                           rtol=precision)
+
+    def test_bilinear_grid_sample(self):
+        self._test_bilinear_grid_sample(torch.double, False)
+        self._test_bilinear_grid_sample(torch.double, True)
+        self._test_bilinear_grid_sample(torch.float, False)
+        self._test_bilinear_grid_sample(torch.float, True)
+        self._test_bilinear_grid_sample(torch.float, True, 5)
+        self._test_bilinear_grid_sample(torch.float, False, 10)
+        self._test_bilinear_grid_sample(torch.float, True, -6)
+        self._test_bilinear_grid_sample(torch.float, False, -10)
+        self._test_bilinear_grid_sample(torch.double, True, 5)
+        self._test_bilinear_grid_sample(torch.double, False, 10)
+        self._test_bilinear_grid_sample(torch.double, True, -6)
+        self._test_bilinear_grid_sample(torch.double, False, -10)
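Note that `multiplier` scales the affine grid beyond the valid [-1, 1] range, so the negative and larger settings above specifically exercise the zero-padding and coordinate-clipping path of `bilinear_grid_sample` against `F.grid_sample`'s behavior.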
Changes to the ONNX export test module:

@@ -23,31 +23,7 @@ class WrapFunction(nn.Module):
         return self.wrapped_function(*args, **kwargs)


-@pytest.mark.parametrize('mode', ['bilinear', 'nearest'])
-@pytest.mark.parametrize('padding_mode', ['zeros', 'border', 'reflection'])
-@pytest.mark.parametrize('align_corners', [True, False])
-def test_grid_sample(mode, padding_mode, align_corners):
-    from mmcv.onnx.symbolic import register_extra_symbolics
-    opset_version = 11
-    register_extra_symbolics(opset_version)
-
-    from mmcv.ops import get_onnxruntime_op_path
-    ort_custom_op_path = get_onnxruntime_op_path()
-    if not os.path.exists(ort_custom_op_path):
-        pytest.skip('custom ops for onnxruntime are not compiled.')
-
-    input = torch.rand(1, 1, 10, 10)
-    grid = torch.Tensor([[[1, 0, 0], [0, 1, 0]]])
-    grid = nn.functional.affine_grid(grid, (1, 1, 15, 15)).type_as(input)
-
-    def func(input, grid):
-        return nn.functional.grid_sample(
-            input,
-            grid,
-            mode=mode,
-            padding_mode=padding_mode,
-            align_corners=align_corners)
-
+def process_grid_sample(func, input, grid, ort_custom_op_path=''):
     wrapped_model = WrapFunction(func).eval()

     input_names = ['input', 'grid']
@@ -66,7 +42,8 @@ def test_grid_sample(mode, padding_mode, align_corners):
     onnx_model = onnx.load(onnx_file)
     session_options = rt.SessionOptions()
-    session_options.register_custom_ops_library(ort_custom_op_path)
+    if ort_custom_op_path:
+        session_options.register_custom_ops_library(ort_custom_op_path)

     # get onnx output
     input_all = [node.name for node in onnx_model.graph.input]
@@ -83,6 +60,51 @@ def test_grid_sample(mode, padding_mode, align_corners):
     assert np.allclose(pytorch_results, ort_result, atol=1e-3)


+@pytest.mark.parametrize('mode', ['bilinear', 'nearest'])
+@pytest.mark.parametrize('padding_mode', ['zeros', 'border', 'reflection'])
+@pytest.mark.parametrize('align_corners', [True, False])
+def test_grid_sample(mode, padding_mode, align_corners):
+    from mmcv.onnx.symbolic import register_extra_symbolics
+    opset_version = 11
+    register_extra_symbolics(opset_version)
+
+    from mmcv.ops import get_onnxruntime_op_path
+    ort_custom_op_path = get_onnxruntime_op_path()
+    if not os.path.exists(ort_custom_op_path):
+        pytest.skip('custom ops for onnxruntime are not compiled.')
+
+    input = torch.rand(1, 1, 10, 10)
+    grid = torch.Tensor([[[1, 0, 0], [0, 1, 0]]])
+    grid = nn.functional.affine_grid(grid, (1, 1, 15, 15)).type_as(input)
+
+    def func(input, grid):
+        return nn.functional.grid_sample(
+            input,
+            grid,
+            mode=mode,
+            padding_mode=padding_mode,
+            align_corners=align_corners)
+
+    return process_grid_sample(func, input, grid, ort_custom_op_path)
+
+
+@pytest.mark.parametrize('align_corners', [True, False])
+def test_bilinear_grid_sample(align_corners):
+    from mmcv.ops.point_sample import bilinear_grid_sample
+
+    # only supported for pytorch >= 1.5.0
+    if version.parse(torch.__version__) < version.parse('1.5.0'):
+        pytest.skip('Only support PyTorch >= 1.5.0')
+
+    input = torch.rand(1, 1, 10, 10)
+    grid = torch.Tensor([[[1, 0, 0], [0, 1, 0]]])
+    grid = nn.functional.affine_grid(grid, (1, 1, 15, 15)).type_as(input)
+
+    def func(input, grid):
+        return bilinear_grid_sample(input, grid, align_corners=align_corners)
+
+    return process_grid_sample(func, input, grid)
+
+
 def test_nms():
     if torch.__version__ == 'parrots':
         pytest.skip('onnx is not supported in parrots directly')
...