Commit 1d6145d1 authored by Lara Haidar, committed by Francisco Massa

Support Exporting RPN to ONNX (#1329)

* Support Exporting RPN to ONNX

* address PR comments

* fix cat

* add flatten

* replace cat by stack

* update test to run only on rpn module

* use tolerate_small_mismatch
parent f16b6723
test/test_onnx.py

 import io
 import torch
 from torchvision import ops
+from torchvision.models.detection.image_list import ImageList
 from torchvision.models.detection.transform import GeneralizedRCNNTransform
+from torchvision.models.detection.rpn import AnchorGenerator, RPNHead, RegionProposalNetwork
+from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
 from collections import OrderedDict
@@ -20,7 +23,7 @@ class ONNXExporterTester(unittest.TestCase):
     def setUpClass(cls):
         torch.manual_seed(123)
 
-    def run_model(self, model, inputs_list):
+    def run_model(self, model, inputs_list, tolerate_small_mismatch=False):
         model.eval()
         onnx_io = io.BytesIO()
@@ -36,9 +39,9 @@ class ONNXExporterTester(unittest.TestCase):
         test_ouputs = model(*test_inputs)
         if isinstance(test_ouputs, torch.Tensor):
             test_ouputs = (test_ouputs,)
-        self.ort_validate(onnx_io, test_inputs, test_ouputs)
+        self.ort_validate(onnx_io, test_inputs, test_ouputs, tolerate_small_mismatch)
 
-    def ort_validate(self, onnx_io, inputs, outputs):
+    def ort_validate(self, onnx_io, inputs, outputs, tolerate_small_mismatch=False):
         inputs, _ = torch.jit._flatten(inputs)
         outputs, _ = torch.jit._flatten(outputs)
@@ -58,7 +61,13 @@ class ONNXExporterTester(unittest.TestCase):
         ort_outs = ort_session.run(None, ort_inputs)
         for i in range(0, len(outputs)):
-            torch.testing.assert_allclose(outputs[i], ort_outs[i], rtol=1e-03, atol=1e-05)
+            try:
+                torch.testing.assert_allclose(outputs[i], ort_outs[i], rtol=1e-03, atol=1e-05)
+            except AssertionError as error:
+                if tolerate_small_mismatch:
+                    assert ("(0.00%)" in str(error)), str(error)
+                else:
+                    assert False, str(error)
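Note: the harness above exports with `torch.onnx.export` into a `BytesIO` and replays the same inputs through ONNX Runtime. A minimal, self-contained sketch of that round trip (the `export_and_check` helper is illustrative, not part of the patch; it assumes `onnxruntime` is installed, as these tests do):

```python
import io

import onnxruntime  # assumed available, as in these tests
import torch


def export_and_check(model, inputs):
    # Trace-export the eval-mode model into an in-memory buffer.
    model.eval()
    onnx_io = io.BytesIO()
    torch.onnx.export(model, inputs, onnx_io)

    # Feed the same inputs to ONNX Runtime.
    sess = onnxruntime.InferenceSession(onnx_io.getvalue())
    ort_inputs = {n.name: t.numpy() for n, t in zip(sess.get_inputs(), inputs)}
    ort_out = sess.run(None, ort_inputs)[0]

    # Same tolerances as ort_validate above.
    torch.testing.assert_allclose(model(*inputs), torch.as_tensor(ort_out),
                                  rtol=1e-03, atol=1e-05)


export_and_check(torch.nn.Linear(4, 2), (torch.rand(3, 4),))
```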
 
     def test_nms(self):
         boxes = torch.rand(5, 4)

@@ -91,11 +100,7 @@ class ONNXExporterTester(unittest.TestCase):
         class TransformModule(torch.nn.Module):
             def __init__(self_module):
                 super(TransformModule, self_module).__init__()
-                min_size = 800
-                max_size = 1333
-                image_mean = [0.485, 0.456, 0.406]
-                image_std = [0.229, 0.224, 0.225]
-                self_module.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)
+                self_module.transform = self._init_test_generalized_rcnn_transform()
 
             def forward(self_module, images):
                 return self_module.transform(images)[0].tensors
@@ -104,6 +109,66 @@ class ONNXExporterTester(unittest.TestCase):
         input_test = [torch.rand(3, 800, 1280), torch.rand(3, 800, 800)]
         self.run_model(TransformModule(), [input, input_test])
+
+    def _init_test_generalized_rcnn_transform(self):
+        min_size = 800
+        max_size = 1333
+        image_mean = [0.485, 0.456, 0.406]
+        image_std = [0.229, 0.224, 0.225]
+        transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)
+        return transform
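For reference, the transform built here resizes each image so its shorter side reaches `min_size` (capped at `max_size`), normalizes, and batches into a padded `ImageList`. A small illustration of its output (a sketch; the printed sizes are what these parameters should yield given the resize-and-pad-to-multiple-of-32 behavior):

```python
import torch
from torchvision.models.detection.transform import GeneralizedRCNNTransform

transform = GeneralizedRCNNTransform(min_size=800, max_size=1333,
                                     image_mean=[0.485, 0.456, 0.406],
                                     image_std=[0.229, 0.224, 0.225])
images = [torch.rand(3, 600, 400), torch.rand(3, 500, 500)]
image_list, _ = transform(images)  # targets default to None
print(image_list.tensors.shape)    # torch.Size([2, 3, 1216, 800]) -- batched and padded
print(image_list.image_sizes)      # [(1200, 800), (800, 800)] -- per-image sizes after resize
```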
+
+    def _init_test_rpn(self):
+        anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
+        aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
+        rpn_anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
+        out_channels = 256
+        rpn_head = RPNHead(out_channels, rpn_anchor_generator.num_anchors_per_location()[0])
+        rpn_fg_iou_thresh = 0.7
+        rpn_bg_iou_thresh = 0.3
+        rpn_batch_size_per_image = 256
+        rpn_positive_fraction = 0.5
+        rpn_pre_nms_top_n = dict(training=2000, testing=1000)
+        rpn_post_nms_top_n = dict(training=2000, testing=1000)
+        rpn_nms_thresh = 0.7
+        rpn = RegionProposalNetwork(
+            rpn_anchor_generator, rpn_head,
+            rpn_fg_iou_thresh, rpn_bg_iou_thresh,
+            rpn_batch_size_per_image, rpn_positive_fraction,
+            rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_nms_thresh)
+        return rpn
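One size per level with three aspect ratios means three anchors per spatial location on every FPN level, which is why `RPNHead` is built from `num_anchors_per_location()[0]` alone. A quick check (illustrative):

```python
from torchvision.models.detection.rpn import AnchorGenerator

anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
gen = AnchorGenerator(anchor_sizes, aspect_ratios)
print(gen.num_anchors_per_location())  # [3, 3, 3, 3, 3] -- 1 size x 3 ratios per level
```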
+
+    def test_rpn(self):
+        class RPNModule(torch.nn.Module):
+            def __init__(self_module, images):
+                super(RPNModule, self_module).__init__()
+                self_module.rpn = self._init_test_rpn()
+                self_module.images = ImageList(images, [i.shape[-2:] for i in images])
+
+            def forward(self_module, features):
+                return self_module.rpn(self_module.images, features)
+
+        def get_features(images):
+            s0, s1 = images.shape[-2:]
+            features = [
+                ('0', torch.rand(2, 256, s0 // 4, s1 // 4)),
+                ('1', torch.rand(2, 256, s0 // 8, s1 // 8)),
+                ('2', torch.rand(2, 256, s0 // 16, s1 // 16)),
+                ('3', torch.rand(2, 256, s0 // 32, s1 // 32)),
+                ('4', torch.rand(2, 256, s0 // 64, s1 // 64)),
+            ]
+            features = OrderedDict(features)
+            return features
+
+        images = torch.rand(2, 3, 600, 600)
+        features = get_features(images)
+        test_features = get_features(images)
+
+        model = RPNModule(images)
+        model.eval()
+        model(features)
+        self.run_model(model, [(features,), (test_features,)], tolerate_small_mismatch=True)
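In eval mode the RPN returns one proposal tensor per image plus an empty loss dict, so the proposals are what the exported graph produces. A condensed, self-contained version of what this test wires up (numbers mirror the test above; the variable names are mine):

```python
from collections import OrderedDict

import torch
from torchvision.models.detection.image_list import ImageList
from torchvision.models.detection.rpn import AnchorGenerator, RPNHead, RegionProposalNetwork

sizes = ((32,), (64,), (128,), (256,), (512,))
ratios = ((0.5, 1.0, 2.0),) * len(sizes)
rpn = RegionProposalNetwork(
    AnchorGenerator(sizes, ratios), RPNHead(256, 3),
    0.7, 0.3, 256, 0.5,
    dict(training=2000, testing=1000), dict(training=2000, testing=1000), 0.7)
rpn.eval()

images = torch.rand(2, 3, 600, 600)
image_list = ImageList(images, [i.shape[-2:] for i in images])
# Five FPN-like levels, 256 channels, strides 4 through 64.
features = OrderedDict((str(i), torch.rand(2, 256, 600 // s, 600 // s))
                       for i, s in enumerate((4, 8, 16, 32, 64)))

boxes, losses = rpn(image_list, features)
print(len(boxes), boxes[0].shape)  # 2, up to post_nms_top_n=1000 proposals of shape (N, 4)
print(losses)                      # {} when not training
```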
 
     def test_multi_scale_roi_align(self):
         class TransformModule(torch.nn.Module):
...

torchvision/models/detection/_utils.py

@@ -3,6 +3,7 @@ from __future__ import division
 import math
 import torch
+import torchvision
 
 
 class BalancedPositiveNegativeSampler(object):
@@ -162,7 +163,7 @@ class BoxCoder(object):
         if isinstance(rel_codes, (list, tuple)):
             rel_codes = torch.cat(rel_codes, dim=0)
         assert isinstance(rel_codes, torch.Tensor)
-        boxes_per_image = [len(b) for b in boxes]
+        boxes_per_image = [b.size(0) for b in boxes]
         concat_boxes = torch.cat(boxes, dim=0)
         pred_boxes = self.decode_single(
             rel_codes.reshape(sum(boxes_per_image), -1), concat_boxes
@@ -201,16 +202,11 @@ class BoxCoder(object):
         pred_w = torch.exp(dw) * widths[:, None]
         pred_h = torch.exp(dh) * heights[:, None]
 
-        pred_boxes = torch.zeros_like(rel_codes)
-        # x1
-        pred_boxes[:, 0::4] = pred_ctr_x - 0.5 * pred_w
-        # y1
-        pred_boxes[:, 1::4] = pred_ctr_y - 0.5 * pred_h
-        # x2
-        pred_boxes[:, 2::4] = pred_ctr_x + 0.5 * pred_w
-        # y2
-        pred_boxes[:, 3::4] = pred_ctr_y + 0.5 * pred_h
+        pred_boxes1 = pred_ctr_x - torch.tensor(0.5, dtype=pred_ctr_x.dtype) * pred_w
+        pred_boxes2 = pred_ctr_y - torch.tensor(0.5, dtype=pred_ctr_y.dtype) * pred_h
+        pred_boxes3 = pred_ctr_x + torch.tensor(0.5, dtype=pred_ctr_x.dtype) * pred_w
+        pred_boxes4 = pred_ctr_y + torch.tensor(0.5, dtype=pred_ctr_y.dtype) * pred_h
+        pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=2).flatten(1)
         return pred_boxes
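Strided in-place writes like `pred_boxes[:, 0::4] = ...` don't trace to ONNX, so the new code builds the same interleaved `x1, y1, x2, y2` layout out-of-place: stacking along a fresh trailing dim and flattening it reproduces the strided-assignment layout exactly. A small equivalence check (illustrative):

```python
import torch

x1, y1 = torch.tensor([[1.], [5.]]), torch.tensor([[2.], [6.]])
x2, y2 = torch.tensor([[3.], [7.]]), torch.tensor([[4.], [8.]])

# Old approach: strided in-place writes into a preallocated buffer.
old = torch.zeros(2, 4)
old[:, 0::4], old[:, 1::4], old[:, 2::4], old[:, 3::4] = x1, y1, x2, y2

# New approach: stack on a new trailing dim, then flatten it away.
new = torch.stack((x1, y1, x2, y2), dim=2).flatten(1)
assert torch.equal(old, new)  # both: [[1., 2., 3., 4.], [5., 6., 7., 8.]]
```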
...

torchvision/models/detection/rpn.py

@@ -3,11 +3,25 @@ import torch
 from torch.nn import functional as F
 from torch import nn
 
+import torchvision
 from torchvision.ops import boxes as box_ops
 
 from . import _utils as det_utils
+
+@torch.jit.unused
+def _onnx_get_num_anchors_and_pre_nms_top_n(ob, orig_pre_nms_top_n):
+    from torch.onnx import operators
+    num_anchors = operators.shape_as_tensor(ob)[1].unsqueeze(0)
+    # TODO: remove cast to IntTensor/num_anchors.dtype when
+    # ONNX Runtime version is updated with ReduceMin int64 support
+    pre_nms_top_n = torch.min(torch.cat(
+        (torch.tensor([orig_pre_nms_top_n], dtype=num_anchors.dtype),
+         num_anchors), 0).to(torch.int32)).to(num_anchors.dtype)
+    return num_anchors, pre_nms_top_n
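`shape_as_tensor` records the anchor count as a graph node rather than a Python constant, so the exported model re-derives `pre_nms_top_n` from whatever input it sees at run time; in eager mode it simply returns the concrete shape. A quick illustration (the values are made up):

```python
import torch
from torch.onnx import operators

ob = torch.rand(2, 4693)  # objectness scores for one feature level
num_anchors = operators.shape_as_tensor(ob)[1].unsqueeze(0)
print(num_anchors)        # tensor([4693])

# Tensor-level min against the configured top-n keeps the clamp in the graph.
pre_nms_top_n = torch.min(torch.cat(
    (torch.tensor([1000], dtype=num_anchors.dtype), num_anchors), 0))
print(pre_nms_top_n)      # tensor(1000)
```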
 
 class AnchorGenerator(nn.Module):
     """
     Module that generates anchors for a set of feature maps and

@@ -85,6 +99,10 @@ class AnchorGenerator(nn.Module):
     ):
         grid_height, grid_width = size
         stride_height, stride_width = stride
+        if torchvision._is_tracing():
+            # required in ONNX export for mult operation with float32
+            stride_width = torch.tensor(stride_width, dtype=torch.float32)
+            stride_height = torch.tensor(stride_height, dtype=torch.float32)
         device = base_anchors.device
         shifts_x = torch.arange(
             0, grid_width, dtype=torch.float32, device=device
@@ -92,6 +110,12 @@ class AnchorGenerator(nn.Module):
         shifts_y = torch.arange(
             0, grid_height, dtype=torch.float32, device=device
         ) * stride_height
+        # TODO: remove tracing pass when exporting torch.meshgrid()
+        # is supported in ONNX
+        if torchvision._is_tracing():
+            shift_y = shifts_y.view(-1, 1).expand(grid_height, grid_width)
+            shift_x = shifts_x.view(1, -1).expand(grid_height, grid_width)
+        else:
-        shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
+            shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
         shift_x = shift_x.reshape(-1)
         shift_y = shift_y.reshape(-1)
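The `view`/`expand` pair reproduces `torch.meshgrid`'s row/column broadcasting without the op that ONNX export couldn't handle. A quick equivalence check (illustrative):

```python
import torch

shifts_y = torch.arange(0, 3, dtype=torch.float32) * 16  # grid_height = 3
shifts_x = torch.arange(0, 4, dtype=torch.float32) * 16  # grid_width = 4

mesh_y, mesh_x = torch.meshgrid(shifts_y, shifts_x)

assert torch.equal(mesh_y, shifts_y.view(-1, 1).expand(3, 4))
assert torch.equal(mesh_x, shifts_x.view(1, -1).expand(3, 4))
```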
@@ -114,7 +138,9 @@ class AnchorGenerator(nn.Module):
     def forward(self, image_list, feature_maps):
         grid_sizes = tuple([feature_map.shape[-2:] for feature_map in feature_maps])
         image_size = image_list.tensors.shape[-2:]
-        strides = tuple((image_size[0] / g[0], image_size[1] / g[1]) for g in grid_sizes)
+        strides = tuple((float(image_size[0]) / float(g[0]),
+                         float(image_size[1]) / float(g[1]))
+                        for g in grid_sizes)
         dtype, device = feature_maps[0].dtype, feature_maps[0].device
         self.set_cell_anchors(dtype, device)
         anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides)
@@ -300,6 +326,9 @@ class RegionProposalNetwork(torch.nn.Module):
         r = []
         offset = 0
         for ob in objectness.split(num_anchors_per_level, 1):
+            if torchvision._is_tracing():
+                num_anchors, pre_nms_top_n = _onnx_get_num_anchors_and_pre_nms_top_n(ob, self.pre_nms_top_n)
+            else:
-            num_anchors = ob.shape[1]
-            pre_nms_top_n = min(self.pre_nms_top_n, num_anchors)
+                num_anchors = ob.shape[1]
+                pre_nms_top_n = min(self.pre_nms_top_n, num_anchors)
             _, top_n_idx = ob.topk(pre_nms_top_n, dim=1)
...

torchvision/ops/misc.py

@@ -132,7 +132,7 @@ def interpolate(
 
 # This is not in nn
-class FrozenBatchNorm2d(torch.jit.ScriptModule):
+class FrozenBatchNorm2d(torch.nn.Module):
     """
     BatchNorm2d where the batch statistics and the affine parameters
     are fixed
@@ -145,7 +145,6 @@ class FrozenBatchNorm2d(torch.jit.ScriptModule):
         self.register_buffer("running_mean", torch.zeros(n))
         self.register_buffer("running_var", torch.ones(n))
 
-    @torch.jit.script_method
     def forward(self, x):
         # move reshapes to the beginning
         # to make it fuser-friendly
...
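Dropping `ScriptModule`/`@torch.jit.script_method` lets `FrozenBatchNorm2d` go through the regular tracing path during ONNX export. Its forward is just a per-channel affine computed from the frozen statistics; a sketch of that computation (standard batch-norm algebra, not copied from the file):

```python
import torch

def frozen_bn(x, weight, bias, running_mean, running_var):
    # y = (x - mean) / sqrt(var) * weight + bias, with reshapes done
    # up front so the elementwise ops stay fuser-friendly.
    w = weight.reshape(1, -1, 1, 1)
    b = bias.reshape(1, -1, 1, 1)
    rm = running_mean.reshape(1, -1, 1, 1)
    rv = running_var.reshape(1, -1, 1, 1)
    scale = w * rv.rsqrt()
    return x * scale + (b - rm * scale)

x = torch.rand(2, 8, 16, 16)
out = frozen_bn(x, torch.ones(8), torch.zeros(8), torch.zeros(8), torch.ones(8))
assert torch.allclose(out, x)  # identity when stats are (0, 1) and affine is (1, 0)
```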