Commit f612182c authored by Lara Haidar's avatar Lara Haidar Committed by Francisco Massa
Browse files

Support Exporting Mask Rcnn to ONNX (#1461)

* Support Exporting Mask Rcnn to ONNX

* update test

* add control flow test

* fix

* update test and fix img_shape
parent 30cb4e10
...@@ -8,6 +8,7 @@ from torchvision.models.detection.rpn import AnchorGenerator, RPNHead, RegionPro ...@@ -8,6 +8,7 @@ from torchvision.models.detection.rpn import AnchorGenerator, RPNHead, RegionPro
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from torchvision.models.detection.roi_heads import RoIHeads from torchvision.models.detection.roi_heads import RoIHeads
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead
from torchvision.models.detection.mask_rcnn import MaskRCNNHeads, MaskRCNNPredictor
from collections import OrderedDict from collections import OrderedDict
...@@ -259,7 +260,7 @@ class ONNXExporterTester(unittest.TestCase): ...@@ -259,7 +260,7 @@ class ONNXExporterTester(unittest.TestCase):
model = RoiHeadsModule(images) model = RoiHeadsModule(images)
model.eval() model.eval()
model(features) model(features)
self.run_model(model, [(features,), (test_features,)], tolerate_small_mismatch=True) self.run_model(model, [(features,), (test_features,)])
def get_image_from_url(self, url): def get_image_from_url(self, url):
import requests import requests
...@@ -294,6 +295,45 @@ class ONNXExporterTester(unittest.TestCase): ...@@ -294,6 +295,45 @@ class ONNXExporterTester(unittest.TestCase):
model(images) model(images)
self.run_model(model, [(images,), (test_images,)]) self.run_model(model, [(images,), (test_images,)])
# Verify that paste_masks_in_image behaves the same in tracing.
# This test also compares both paste_masks_in_image and _onnx_paste_masks_in_image_loop
# (since jit_trace will call _onnx_paste_masks_in_image_loop).
def test_paste_mask_in_image(self):
    """Check that a traced ``paste_masks_in_image`` matches the eager result.

    The function is traced once with 10 masks on a 100x100 image. The
    traced callable is then run on that input and on a second input with
    a different number of masks (20) and a different image size
    (200x200); both outputs must equal the eager outputs element-wise.
    """
    from torchvision.models.detection.roi_heads import paste_masks_in_image

    def shape_as_tensors(shape):
        # The traced call is given the image size as a list of 0-dim tensors.
        return [torch.tensor(shape[0]), torch.tensor(shape[1])]

    masks = torch.rand(10, 1, 26, 26)
    boxes = torch.rand(10, 4)
    # Adding a non-negative offset to columns 2 and 3 keeps x1 >= x0 and y1 >= y0.
    boxes[:, 2:] += torch.rand(10, 2)
    boxes *= 50
    o_im_s = (100, 100)
    out = paste_masks_in_image(masks, boxes, o_im_s)
    jit_trace = torch.jit.trace(paste_masks_in_image,
                                (masks, boxes, shape_as_tensors(o_im_s)))
    out_trace = jit_trace(masks, boxes, shape_as_tensors(o_im_s))
    # assertTrue rather than a bare assert, so the check survives `python -O`.
    self.assertTrue(torch.all(out.eq(out_trace)))

    # Re-use the same trace with a different batch size and image size.
    masks2 = torch.rand(20, 1, 26, 26)
    boxes2 = torch.rand(20, 4)
    boxes2[:, 2:] += torch.rand(20, 2)
    boxes2 *= 100
    o_im_s2 = (200, 200)
    out2 = paste_masks_in_image(masks2, boxes2, o_im_s2)
    out_trace2 = jit_trace(masks2, boxes2, shape_as_tensors(o_im_s2))
    self.assertTrue(torch.all(out2.eq(out_trace2)))
@unittest.skip("Disable test until Resize opset 11 is implemented in ONNX Runtime")
def test_mask_rcnn(self):
    """Run the pretrained Mask R-CNN (ResNet-50 FPN) model through ``run_model``.

    Currently skipped; see the reason given in the decorator.
    """
    images, test_images = self.get_test_images()
    mask_rcnn = models.detection.mask_rcnn.maskrcnn_resnet50_fpn(pretrained=True)
    mask_rcnn.eval()
    # One eager forward pass before the model is handed to run_model.
    mask_rcnn(images)
    input_sets = [(images,), (test_images,)]
    self.run_model(mask_rcnn, input_sets)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
import torch import torch
import torchvision
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn from torch import nn
...@@ -73,7 +74,11 @@ def maskrcnn_inference(x, labels): ...@@ -73,7 +74,11 @@ def maskrcnn_inference(x, labels):
index = torch.arange(num_masks, device=labels.device) index = torch.arange(num_masks, device=labels.device)
mask_prob = mask_prob[index, labels][:, None] mask_prob = mask_prob[index, labels][:, None]
mask_prob = mask_prob.split(boxes_per_image, dim=0) if len(boxes_per_image) == 1:
# TODO : remove when dynamic split supported in ONNX
mask_prob = (mask_prob,)
else:
mask_prob = mask_prob.split(boxes_per_image, dim=0)
return mask_prob return mask_prob
...@@ -250,10 +255,29 @@ def keypointrcnn_inference(x, boxes): ...@@ -250,10 +255,29 @@ def keypointrcnn_inference(x, boxes):
return kp_probs, kp_scores return kp_probs, kp_scores
def _onnx_expand_boxes(boxes, scale):
w_half = (boxes[:, 2] - boxes[:, 0]) * .5
h_half = (boxes[:, 3] - boxes[:, 1]) * .5
x_c = (boxes[:, 2] + boxes[:, 0]) * .5
y_c = (boxes[:, 3] + boxes[:, 1]) * .5
w_half = w_half.to(dtype=torch.float32) * scale
h_half = h_half.to(dtype=torch.float32) * scale
boxes_exp0 = x_c - w_half
boxes_exp1 = y_c - h_half
boxes_exp2 = x_c + w_half
boxes_exp3 = y_c + h_half
boxes_exp = torch.stack((boxes_exp0, boxes_exp1, boxes_exp2, boxes_exp3), 1)
return boxes_exp
# the next two functions should be merged inside Masker # the next two functions should be merged inside Masker
# but are kept here for the moment while we need them # but are kept here for the moment while we need them
# temporarily for paste_mask_in_image # temporarily for paste_mask_in_image
def expand_boxes(boxes, scale): def expand_boxes(boxes, scale):
if torchvision._is_tracing():
return _onnx_expand_boxes(boxes, scale)
w_half = (boxes[:, 2] - boxes[:, 0]) * .5 w_half = (boxes[:, 2] - boxes[:, 0]) * .5
h_half = (boxes[:, 3] - boxes[:, 1]) * .5 h_half = (boxes[:, 3] - boxes[:, 1]) * .5
x_c = (boxes[:, 2] + boxes[:, 0]) * .5 x_c = (boxes[:, 2] + boxes[:, 0]) * .5
...@@ -272,7 +296,10 @@ def expand_boxes(boxes, scale): ...@@ -272,7 +296,10 @@ def expand_boxes(boxes, scale):
def expand_masks(mask, padding):
    """Zero-pad ``mask`` and return it with the padded-to-original size ratio.

    Args:
        mask: tensor whose last dimension gives the mask size.
        padding: number of zeros added on each side of the last two
            dimensions.

    Returns:
        Tuple ``(padded_mask, scale)`` where ``scale`` is
        ``(size + 2 * padding) / size``.
    """
    size = mask.shape[-1]
    padded_size = size + 2 * padding
    if torchvision._is_tracing():
        # ``.to`` is called on the sizes here, so they are expected to be
        # tensors while tracing; the ratio stays a tensor.
        scale = padded_size.to(torch.float32) / size.to(torch.float32)
    else:
        scale = float(padded_size) / size
    padded_mask = torch.nn.functional.pad(mask, (padding,) * 4)
    return padded_mask, scale
...@@ -303,11 +330,69 @@ def paste_mask_in_image(mask, box, im_h, im_w): ...@@ -303,11 +330,69 @@ def paste_mask_in_image(mask, box, im_h, im_w):
return im_mask return im_mask
def _onnx_paste_mask_in_image(mask, box, im_h, im_w):
    """Paste one mask into an image-sized canvas (ONNX export variant).

    The mask is resized to the box size, cropped to the part of the box
    that lies inside the image, and then surrounded with zeros by
    concatenation (see the TODO below about dynamic padding in ONNX).

    Args:
        mask: tensor whose first two dimensions are read via ``size(0)``
            and ``size(1)``.
        box: indexed as ``box[0]``..``box[3]`` and used as (x0, y0, x1, y1).
            NOTE(review): presumably a 1-D int64 tensor, since it is
            combined with int64 constants -- confirm against the caller.
        im_h: tensor holding the image height (``unsqueeze`` is called on it).
        im_w: tensor holding the image width (``unsqueeze`` is called on it).

    Returns:
        2-D tensor made of zero blocks around the float32-cast mask crop,
        truncated to at most ``im_h`` rows and ``im_w`` columns.
    """
    one = torch.ones(1, dtype=torch.int64)
    zero = torch.zeros(1, dtype=torch.int64)
    # Box extent with both end coordinates included (hence the "+ one").
    w = (box[2] - box[0] + one)
    h = (box[3] - box[1] + one)
    # max over (value, 1) clamps the extent to at least 1.
    w = torch.max(torch.cat((w, one)))
    h = torch.max(torch.cat((h, one)))
    # Set shape to [batchxCxHxW]
    mask = mask.expand((1, 1, mask.size(0), mask.size(1)))
    # Resize mask
    mask = torch.nn.functional.interpolate(mask, size=(int(h), int(w)), mode='bilinear', align_corners=False)
    # Drop the batch and channel dimensions added above.
    mask = mask[0][0]
    # Clip the box to the image: lower bounds at 0, upper bounds at im_w / im_h.
    x_0 = torch.max(torch.cat((box[0].unsqueeze(0), zero)))
    x_1 = torch.min(torch.cat((box[2].unsqueeze(0) + one, im_w.unsqueeze(0))))
    y_0 = torch.max(torch.cat((box[1].unsqueeze(0), zero)))
    y_1 = torch.min(torch.cat((box[3].unsqueeze(0) + one, im_h.unsqueeze(0))))
    # Crop the resized mask to the clipped region, in box-relative coordinates.
    unpaded_im_mask = mask[(y_0 - box[1]):(y_1 - box[1]),
                           (x_0 - box[0]):(x_1 - box[0])]
    # TODO : replace below with a dynamic padding when support is added in ONNX
    # pad y
    # y_0 zero rows above and (im_h - y_1) zero rows below the crop.
    zeros_y0 = torch.zeros(y_0, unpaded_im_mask.size(1))
    zeros_y1 = torch.zeros(im_h - y_1, unpaded_im_mask.size(1))
    concat_0 = torch.cat((zeros_y0,
                          unpaded_im_mask.to(dtype=torch.float32),
                          zeros_y1), 0)[0:im_h, :]
    # pad x
    # x_0 zero columns to the left and (im_w - x_1) zero columns to the right.
    zeros_x0 = torch.zeros(concat_0.size(0), x_0)
    zeros_x1 = torch.zeros(concat_0.size(0), im_w - x_1)
    im_mask = torch.cat((zeros_x0,
                         concat_0,
                         zeros_x1), 1)[:, :im_w]
    return im_mask
@torch.jit.script
def _onnx_paste_masks_in_image_loop(masks, boxes, im_h, im_w):
    """Paste every mask into its own image-sized canvas and stack the results.

    Each ``masks[i][0]`` is pasted with ``boxes[i]`` by
    ``_onnx_paste_mask_in_image``; the pasted masks are concatenated along
    a new leading dimension, starting from an empty ``(0, im_h, im_w)``
    tensor.
    """
    pasted = torch.zeros(0, im_h, im_w)
    for idx in range(masks.size(0)):
        single = _onnx_paste_mask_in_image(masks[idx][0], boxes[idx], im_h, im_w)
        # Add a leading dimension so the result can be appended to the stack.
        pasted = torch.cat((pasted, single.unsqueeze(0)))
    return pasted
def paste_masks_in_image(masks, boxes, img_shape, padding=1): def paste_masks_in_image(masks, boxes, img_shape, padding=1):
masks, scale = expand_masks(masks, padding=padding) masks, scale = expand_masks(masks, padding=padding)
boxes = expand_boxes(boxes, scale).to(dtype=torch.int64).tolist() boxes = expand_boxes(boxes, scale).to(dtype=torch.int64)
# im_h, im_w = img_shape.tolist() # im_h, im_w = img_shape.tolist()
im_h, im_w = img_shape im_h, im_w = img_shape
if torchvision._is_tracing():
return _onnx_paste_masks_in_image_loop(masks, boxes,
torch.scalar_tensor(im_h, dtype=torch.int64),
torch.scalar_tensor(im_w, dtype=torch.int64))[:, None]
boxes = boxes.tolist()
res = [ res = [
paste_mask_in_image(m[0], b, im_h, im_w) paste_mask_in_image(m[0], b, im_h, im_w)
for m, b in zip(masks, boxes) for m, b in zip(masks, boxes)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment