Commit 7f526aa9 authored by Lara Haidar's avatar Lara Haidar Committed by Francisco Massa

Lahaidar/export faster rcnn (#1401)

* onnx export faster rcnn

* test

* address PR comments

* revert unbind workaround

* disable tests for older versions of pytorch
parent 8bc4ab06
import io
import torch
from torchvision import ops
from torchvision import models
from torchvision.models.detection.image_list import ImageList
from torchvision.models.detection.transform import GeneralizedRCNNTransform
from torchvision.models.detection.rpn import AnchorGenerator, RPNHead, RegionProposalNetwork
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from torchvision.models.detection.roi_heads import RoIHeads
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead
from collections import OrderedDict
...@@ -59,7 +62,6 @@ class ONNXExporterTester(unittest.TestCase):
        # compute onnxruntime output prediction
        ort_inputs = dict((ort_session.get_inputs()[i].name, inpt) for i, inpt in enumerate(inputs))
        ort_outs = ort_session.run(None, ort_inputs)
        for i in range(0, len(outputs)):
            try:
                torch.testing.assert_allclose(outputs[i], ort_outs[i], rtol=1e-03, atol=1e-05)
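As context for these assertions, the surrounding harness presumably exports the traced model to an in-memory buffer and runs the same inputs through ONNX Runtime. A minimal sketch of that round trip (export_and_compare and onnx_io are illustrative names, not the test's actual helpers, which also handle tolerate_small_mismatch and multiple input sets):

import io
import onnxruntime
import torch

def export_and_compare(model, inputs):
    # trace to ONNX into an in-memory buffer rather than a file on disk
    onnx_io = io.BytesIO()
    torch.onnx.export(model, inputs, onnx_io, opset_version=11)
    # ONNX Runtime accepts the serialized model as raw bytes
    ort_session = onnxruntime.InferenceSession(onnx_io.getvalue())
    ort_inputs = dict((ort_session.get_inputs()[i].name, inpt.numpy())
                      for i, inpt in enumerate(inputs))
    return ort_session.run(None, ort_inputs)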
...@@ -138,6 +140,55 @@ class ONNXExporterTester(unittest.TestCase):
                                     rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_nms_thresh)
        return rpn
    def _init_test_roi_heads_faster_rcnn(self):
        out_channels = 256
        num_classes = 91

        box_fg_iou_thresh = 0.5
        box_bg_iou_thresh = 0.5
        box_batch_size_per_image = 512
        box_positive_fraction = 0.25
        bbox_reg_weights = None
        box_score_thresh = 0.05
        box_nms_thresh = 0.5
        box_detections_per_img = 100

        box_roi_pool = ops.MultiScaleRoIAlign(
            featmap_names=['0', '1', '2', '3'],
            output_size=7,
            sampling_ratio=2)

        resolution = box_roi_pool.output_size[0]
        representation_size = 1024
        box_head = TwoMLPHead(
            out_channels * resolution ** 2,
            representation_size)

        representation_size = 1024
        box_predictor = FastRCNNPredictor(
            representation_size,
            num_classes)

        roi_heads = RoIHeads(
            box_roi_pool, box_head, box_predictor,
            box_fg_iou_thresh, box_bg_iou_thresh,
            box_batch_size_per_image, box_positive_fraction,
            bbox_reg_weights,
            box_score_thresh, box_nms_thresh, box_detections_per_img)
        return roi_heads
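As a usage note, MultiScaleRoIAlign assigns each box to a pyramid level based on its scale and pools every region to a fixed output_size grid, which is what lets the box head use a fixed-size MLP: out_channels * resolution ** 2 = 256 * 49 above is exactly the flattened input TwoMLPHead expects. A minimal standalone sketch (shapes here are illustrative):

import torch
from torchvision import ops
from collections import OrderedDict

pool = ops.MultiScaleRoIAlign(featmap_names=['0', '1'], output_size=7, sampling_ratio=2)
feats = OrderedDict([('0', torch.rand(1, 256, 64, 64)), ('1', torch.rand(1, 256, 32, 32))])
boxes = [torch.tensor([[10., 10., 100., 100.]])]  # one box for the single image
pooled = pool(feats, boxes, [(256, 256)])
# pooled.shape -> torch.Size([1, 256, 7, 7]): [num_boxes, channels, 7, 7]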
    def get_features(self, images):
        s0, s1 = images.shape[-2:]
        features = [
            ('0', torch.rand(2, 256, s0 // 4, s1 // 4)),
            ('1', torch.rand(2, 256, s0 // 8, s1 // 8)),
            ('2', torch.rand(2, 256, s0 // 16, s1 // 16)),
            ('3', torch.rand(2, 256, s0 // 32, s1 // 32)),
            ('4', torch.rand(2, 256, s0 // 64, s1 // 64)),
        ]
        features = OrderedDict(features)
        return features
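These random maps stand in for a real feature pyramid: 256 channels at strides 4, 8, 16, 32 and 64, matching what resnet_fpn_backbone (imported at the top of this file) produces for the RPN and RoI heads. A rough sketch of obtaining real features instead (the exact key names of the returned dict vary across torchvision versions, so treat the details as an assumption):

backbone = resnet_fpn_backbone('resnet50', pretrained=False)
real_features = backbone(torch.rand(2, 3, 600, 600))
# an OrderedDict of five 256-channel maps, one per pyramid level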
    def test_rpn(self):
        class RPNModule(torch.nn.Module):
            def __init__(self_module, images):
...@@ -148,21 +199,9 @@ class ONNXExporterTester(unittest.TestCase):
            def forward(self_module, features):
                return self_module.rpn(self_module.images, features)
        def get_features(images):
            s0, s1 = images.shape[-2:]
            features = [
                ('0', torch.rand(2, 256, s0 // 4, s1 // 4)),
                ('1', torch.rand(2, 256, s0 // 8, s1 // 8)),
                ('2', torch.rand(2, 256, s0 // 16, s1 // 16)),
                ('3', torch.rand(2, 256, s0 // 32, s1 // 32)),
                ('4', torch.rand(2, 256, s0 // 64, s1 // 64)),
            ]
            features = OrderedDict(features)
            return features
        images = torch.rand(2, 3, 600, 600)
        features = self.get_features(images)
        test_features = self.get_features(images)

        model = RPNModule(images)
        model.eval()
...@@ -194,6 +233,67 @@ class ONNXExporterTester(unittest.TestCase):
        self.run_model(TransformModule(), [(i, [boxes],), (i1, [boxes1],)])
    @unittest.skipIf(torch.__version__ < "1.4.", "Disable test if torch version is less than 1.4")
    def test_roi_heads(self):
        class RoiHeadsModule(torch.nn.Module):
            def __init__(self_module, images):
                super(RoiHeadsModule, self_module).__init__()
                self_module.transform = self._init_test_generalized_rcnn_transform()
                self_module.rpn = self._init_test_rpn()
                self_module.roi_heads = self._init_test_roi_heads_faster_rcnn()
                self_module.original_image_sizes = [img.shape[-2:] for img in images]
                self_module.images = ImageList(images, [i.shape[-2:] for i in images])

            def forward(self_module, features):
                proposals, _ = self_module.rpn(self_module.images, features)
                detections, _ = self_module.roi_heads(features, proposals, self_module.images.image_sizes)
                detections = self_module.transform.postprocess(detections,
                                                               self_module.images.image_sizes,
                                                               self_module.original_image_sizes)
                return detections

        images = torch.rand(2, 3, 600, 600)
        features = self.get_features(images)
        test_features = self.get_features(images)

        model = RoiHeadsModule(images)
        model.eval()
        model(features)
        self.run_model(model, [(features,), (test_features,)], tolerate_small_mismatch=True)
    def get_image_from_url(self, url):
        import requests
        import numpy
        from PIL import Image
        from io import BytesIO
        from torchvision import transforms

        data = requests.get(url)
        image = Image.open(BytesIO(data.content)).convert("RGB")
        image = image.resize((800, 1280), Image.BILINEAR)

        to_tensor = transforms.ToTensor()
        return to_tensor(image)
    def get_test_images(self):
        image_url = "http://farm3.staticflickr.com/2469/3915380994_2e611b1779_z.jpg"
        image = self.get_image_from_url(url=image_url)
        image_url2 = "https://pytorch.org/tutorials/_static/img/tv_tutorial/tv_image05.png"
        image2 = self.get_image_from_url(url=image_url2)
        images = [image]
        test_images = [image2]
        return images, test_images
@unittest.skip("Disable test until Resize opset 11 is implemented in ONNX Runtime")
@unittest.skipIf(torch.__version__ < "1.4.", "Disable test if torch version is less than 1.4")
def test_faster_rcnn(self):
images, test_images = self.get_test_images()
model = models.detection.faster_rcnn.fasterrcnn_resnet50_fpn(pretrained=True)
model.eval()
model(images)
self.run_model(model, [(images,), (test_images,)])
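Under the hood, run_model presumably reduces to a torch.onnx.export call; given the skip above mentions Resize in opset 11, exporting this detector standalone would look roughly like the following (the output filename and opset choice are assumptions):

torch.onnx.export(model, (images,), "faster_rcnn.onnx",  # hypothetical output path
                  opset_version=11,
                  do_constant_folding=True)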
if __name__ == '__main__':
    unittest.main()
...@@ -477,8 +477,13 @@ class RoIHeads(torch.nn.Module):
        pred_scores = F.softmax(class_logits, -1)

        # split boxes and scores per image
        if len(boxes_per_image) == 1:
            # TODO : remove this when ONNX supports dynamic split sizes
            pred_boxes = (pred_boxes,)
            pred_scores = (pred_scores,)
        else:
            pred_boxes = pred_boxes.split(boxes_per_image, 0)
            pred_scores = pred_scores.split(boxes_per_image, 0)
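The two branches agree when there is exactly one image; the tuple form just avoids tracing a Split op whose sizes ONNX could not yet express dynamically. A quick check of the equivalence:

pred_boxes = torch.rand(5, 4)
boxes_per_image = [5]
# a one-element split returns a single view of the whole tensor,
# identical to wrapping the tensor in a one-element tuple
assert torch.equal(pred_boxes.split(boxes_per_image, 0)[0], (pred_boxes,)[0])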
        all_boxes = []
        all_scores = []
...@@ -497,8 +502,8 @@ class RoIHeads(torch.nn.Module):
            # batch everything, by making every class prediction be a separate instance
            boxes = boxes.reshape(-1, 4)
            scores = scores.reshape(-1)
            labels = labels.reshape(-1)
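The switch from flatten() to reshape(-1) changes nothing numerically; presumably reshape simply traced to a better-supported ONNX op at the time. The equivalence:

scores = torch.rand(100, 91)
assert torch.equal(scores.flatten(), scores.reshape(-1))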
            # remove low scoring boxes
            inds = torch.nonzero(scores > self.score_thresh).squeeze(1)
...
...@@ -172,6 +172,7 @@ def resize_boxes(boxes, original_size, new_size):
    ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(new_size, original_size))
    ratio_height, ratio_width = ratios
    xmin, ymin, xmax, ymax = boxes.unbind(1)

    xmin = xmin * ratio_width
    xmax = xmax * ratio_width
    ymin = ymin * ratio_height
...
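For a concrete sense of resize_boxes: sizes are (height, width), so y coordinates scale by the height ratio and x coordinates by the width ratio, independently. A worked example (assuming the function is imported from this module):

import torch
from torchvision.models.detection.transform import resize_boxes

boxes = torch.tensor([[10., 20., 50., 80.]])  # xmin, ymin, xmax, ymax
resized = resize_boxes(boxes, original_size=(100, 100), new_size=(200, 400))
# height ratio 2.0, width ratio 4.0 -> tensor([[ 40.,  40., 200., 160.]])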