Commit 7f526aa9 authored by Lara Haidar's avatar Lara Haidar Committed by Francisco Massa
Browse files

Lahaidar/export faster rcnn (#1401)

* onnx esport faster rcnn

* test

* address PR comments

* revert unbind workaround

* disable tests for older versions of pytorch
parent 8bc4ab06
import io
import torch
from torchvision import ops
from torchvision import models
from torchvision.models.detection.image_list import ImageList
from torchvision.models.detection.transform import GeneralizedRCNNTransform
from torchvision.models.detection.rpn import AnchorGenerator, RPNHead, RegionProposalNetwork
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from torchvision.models.detection.roi_heads import RoIHeads
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead
from collections import OrderedDict
......@@ -59,7 +62,6 @@ class ONNXExporterTester(unittest.TestCase):
# compute onnxruntime output prediction
ort_inputs = dict((ort_session.get_inputs()[i].name, inpt) for i, inpt in enumerate(inputs))
ort_outs = ort_session.run(None, ort_inputs)
for i in range(0, len(outputs)):
try:
torch.testing.assert_allclose(outputs[i], ort_outs[i], rtol=1e-03, atol=1e-05)
......@@ -138,17 +140,44 @@ class ONNXExporterTester(unittest.TestCase):
rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_nms_thresh)
return rpn
def test_rpn(self):
class RPNModule(torch.nn.Module):
def __init__(self_module, images):
super(RPNModule, self_module).__init__()
self_module.rpn = self._init_test_rpn()
self_module.images = ImageList(images, [i.shape[-2:] for i in images])
def forward(self_module, features):
return self_module.rpn(self_module.images, features)
def get_features(images):
def _init_test_roi_heads_faster_rcnn(self):
out_channels = 256
num_classes = 91
box_fg_iou_thresh = 0.5
box_bg_iou_thresh = 0.5
box_batch_size_per_image = 512
box_positive_fraction = 0.25
bbox_reg_weights = None
box_score_thresh = 0.05
box_nms_thresh = 0.5
box_detections_per_img = 100
box_roi_pool = ops.MultiScaleRoIAlign(
featmap_names=['0', '1', '2', '3'],
output_size=7,
sampling_ratio=2)
resolution = box_roi_pool.output_size[0]
representation_size = 1024
box_head = TwoMLPHead(
out_channels * resolution ** 2,
representation_size)
representation_size = 1024
box_predictor = FastRCNNPredictor(
representation_size,
num_classes)
roi_heads = RoIHeads(
box_roi_pool, box_head, box_predictor,
box_fg_iou_thresh, box_bg_iou_thresh,
box_batch_size_per_image, box_positive_fraction,
bbox_reg_weights,
box_score_thresh, box_nms_thresh, box_detections_per_img)
return roi_heads
def get_features(self, images):
s0, s1 = images.shape[-2:]
features = [
('0', torch.rand(2, 256, s0 // 4, s1 // 4)),
......@@ -160,9 +189,19 @@ class ONNXExporterTester(unittest.TestCase):
features = OrderedDict(features)
return features
def test_rpn(self):
class RPNModule(torch.nn.Module):
def __init__(self_module, images):
super(RPNModule, self_module).__init__()
self_module.rpn = self._init_test_rpn()
self_module.images = ImageList(images, [i.shape[-2:] for i in images])
def forward(self_module, features):
return self_module.rpn(self_module.images, features)
images = torch.rand(2, 3, 600, 600)
features = get_features(images)
test_features = get_features(images)
features = self.get_features(images)
test_features = self.get_features(images)
model = RPNModule(images)
model.eval()
......@@ -194,6 +233,67 @@ class ONNXExporterTester(unittest.TestCase):
self.run_model(TransformModule(), [(i, [boxes],), (i1, [boxes1],)])
@unittest.skipIf(torch.__version__ < "1.4.", "Disable test if torch version is less than 1.4")
def test_roi_heads(self):
class RoiHeadsModule(torch.nn.Module):
def __init__(self_module, images):
super(RoiHeadsModule, self_module).__init__()
self_module.transform = self._init_test_generalized_rcnn_transform()
self_module.rpn = self._init_test_rpn()
self_module.roi_heads = self._init_test_roi_heads_faster_rcnn()
self_module.original_image_sizes = [img.shape[-2:] for img in images]
self_module.images = ImageList(images, [i.shape[-2:] for i in images])
def forward(self_module, features):
proposals, _ = self_module.rpn(self_module.images, features)
detections, _ = self_module.roi_heads(features, proposals, self_module.images.image_sizes)
detections = self_module.transform.postprocess(detections,
self_module.images.image_sizes,
self_module.original_image_sizes)
return detections
images = torch.rand(2, 3, 600, 600)
features = self.get_features(images)
test_features = self.get_features(images)
model = RoiHeadsModule(images)
model.eval()
model(features)
self.run_model(model, [(features,), (test_features,)], tolerate_small_mismatch=True)
def get_image_from_url(self, url):
import requests
import numpy
from PIL import Image
from io import BytesIO
from torchvision import transforms
data = requests.get(url)
image = Image.open(BytesIO(data.content)).convert("RGB")
image = image.resize((800, 1280), Image.BILINEAR)
to_tensor = transforms.ToTensor()
return to_tensor(image)
def get_test_images(self):
image_url = "http://farm3.staticflickr.com/2469/3915380994_2e611b1779_z.jpg"
image = self.get_image_from_url(url=image_url)
image_url2 = "https://pytorch.org/tutorials/_static/img/tv_tutorial/tv_image05.png"
image2 = self.get_image_from_url(url=image_url2)
images = [image]
test_images = [image2]
return images, test_images
@unittest.skip("Disable test until Resize opset 11 is implemented in ONNX Runtime")
@unittest.skipIf(torch.__version__ < "1.4.", "Disable test if torch version is less than 1.4")
def test_faster_rcnn(self):
images, test_images = self.get_test_images()
model = models.detection.faster_rcnn.fasterrcnn_resnet50_fpn(pretrained=True)
model.eval()
model(images)
self.run_model(model, [(images,), (test_images,)])
if __name__ == '__main__':
unittest.main()
......@@ -477,6 +477,11 @@ class RoIHeads(torch.nn.Module):
pred_scores = F.softmax(class_logits, -1)
# split boxes and scores per image
if len(boxes_per_image) == 1:
# TODO : remove this when ONNX support dynamic split sizes
pred_boxes = (pred_boxes,)
pred_scores = (pred_scores,)
else:
pred_boxes = pred_boxes.split(boxes_per_image, 0)
pred_scores = pred_scores.split(boxes_per_image, 0)
......@@ -497,8 +502,8 @@ class RoIHeads(torch.nn.Module):
# batch everything, by making every class prediction be a separate instance
boxes = boxes.reshape(-1, 4)
scores = scores.flatten()
labels = labels.flatten()
scores = scores.reshape(-1)
labels = labels.reshape(-1)
# remove low scoring boxes
inds = torch.nonzero(scores > self.score_thresh).squeeze(1)
......
......@@ -172,6 +172,7 @@ def resize_boxes(boxes, original_size, new_size):
ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(new_size, original_size))
ratio_height, ratio_width = ratios
xmin, ymin, xmax, ymax = boxes.unbind(1)
xmin = xmin * ratio_width
xmax = xmax * ratio_width
ymin = ymin * ratio_height
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment