Commit 76702a03 authored by Lara Haidar, committed by Francisco Massa

Support Exporting MultiScaleRoiAlign to ONNX (#1324)

* Support Exporting MultiScaleRoiAlign to ONNX

* remove cast

* fix dtype

* move cast
parent f0d3daa7
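
This commit makes ops.MultiScaleRoIAlign traceable so it can be exported to ONNX. For context, a minimal export built on it might look like the following sketch; the wrapper module, tensor shapes, output file name, and opset_version=10 are illustrative assumptions, not part of this commit.

import torch
from collections import OrderedDict
from torchvision import ops

class Wrapper(torch.nn.Module):
    # Hypothetical wrapper: fixes the feature-map names and image size
    # so that only plain tensors cross the ONNX graph boundary.
    def __init__(self):
        super(Wrapper, self).__init__()
        self.pool = ops.MultiScaleRoIAlign(['feat1', 'feat2'], 3, 2)

    def forward(self, feat1, feat2, boxes):
        feats = OrderedDict([('feat1', feat1), ('feat2', feat2)])
        return self.pool(feats, [boxes], [(512, 512)])

feat1 = torch.rand(1, 5, 64, 64)
feat2 = torch.rand(1, 5, 16, 16)
boxes = torch.rand(6, 4) * 256
boxes[:, 2:] += boxes[:, :2]  # make (x1, y1, x2, y2) well-formed
torch.onnx.export(Wrapper(), (feat1, feat2, boxes), 'msroialign.onnx',
                  opset_version=10)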
test/test_onnx.py
@@ -3,6 +3,8 @@
 import torch
 from torchvision import ops
 from torchvision.models.detection.transform import GeneralizedRCNNTransform
+from collections import OrderedDict
+
 # onnxruntime requires python 3.5 or above
 try:
     import onnxruntime
@@ -69,19 +71,18 @@ class ONNXExporterTester(unittest.TestCase):
         self.run_model(Module(), [(boxes, scores)])

-    def test_roi_pool(self):
+    def test_roi_align(self):
         x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
         single_roi = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32)
         model = ops.RoIAlign((5, 5), 1, 2)
         self.run_model(model, [(x, single_roi)])

-    def test_roi_align(self):
+    def test_roi_pool(self):
         x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
         rois = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32)
         pool_h = 5
         pool_w = 5
         model = ops.RoIPool((pool_h, pool_w), 2)
-        model.eval()
         self.run_model(model, [(x, rois)])

     @unittest.skip("Disable test until Resize opset 11 is implemented in ONNX Runtime")
@@ -103,6 +104,31 @@ class ONNXExporterTester(unittest.TestCase):
         input_test = [torch.rand(3, 800, 1280), torch.rand(3, 800, 800)]
         self.run_model(TransformModule(), [input, input_test])

+    def test_multi_scale_roi_align(self):
+
+        class TransformModule(torch.nn.Module):
+            def __init__(self):
+                super(TransformModule, self).__init__()
+                self.model = ops.MultiScaleRoIAlign(['feat1', 'feat2'], 3, 2)
+                self.image_sizes = [(512, 512)]
+
+            def forward(self, input, boxes):
+                return self.model(input, boxes, self.image_sizes)
+
+        i = OrderedDict()
+        i['feat1'] = torch.rand(1, 5, 64, 64)
+        i['feat2'] = torch.rand(1, 5, 16, 16)
+        boxes = torch.rand(6, 4) * 256
+        boxes[:, 2:] += boxes[:, :2]
+
+        i1 = OrderedDict()
+        i1['feat1'] = torch.rand(1, 5, 64, 64)
+        i1['feat2'] = torch.rand(1, 5, 16, 16)
+        boxes1 = torch.rand(6, 4) * 256
+        boxes1[:, 2:] += boxes1[:, :2]
+
+        self.run_model(TransformModule(), [(i, [boxes],), (i1, [boxes1],)])

 if __name__ == '__main__':
     unittest.main()
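
The run_model helper used by these tests is defined earlier in the file and not shown in this diff. A simplified sketch of the pattern such a helper typically follows (export to an in-memory buffer, run under ONNX Runtime, compare against eager mode), assuming flat tensor inputs, a single output, and illustrative opset/tolerance choices:

import io
import numpy as np
import onnxruntime
import torch

def run_model_sketch(model, inputs):
    # Hypothetical helper: export the model to an in-memory ONNX
    # buffer, then check that ONNX Runtime reproduces the eager-mode
    # output (opset version and tolerances are assumptions).
    model.eval()
    with torch.no_grad():
        expected = model(*inputs)
    buf = io.BytesIO()
    torch.onnx.export(model, inputs, buf, opset_version=10)
    sess = onnxruntime.InferenceSession(buf.getvalue())
    ort_inputs = {arg.name: t.numpy()
                  for arg, t in zip(sess.get_inputs(), inputs)}
    (actual,) = sess.run(None, ort_inputs)
    np.testing.assert_allclose(expected.numpy(), actual,
                               rtol=1e-3, atol=1e-5)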
torchvision/ops/poolers.py
@@ -3,10 +3,31 @@
 import torch
 import torch.nn.functional as F
 from torch import nn

+import torchvision
 from torchvision.ops import roi_align
 from torchvision.ops.boxes import box_area
+
+# copying result_idx_in_level to a specific index in result[]
+# is not supported by ONNX tracing yet.
+# _onnx_merge_levels() is an implementation supported by ONNX
+# that merges the levels to the right indices
+def _onnx_merge_levels(levels, unmerged_results):
+    first_result = unmerged_results[0]
+    dtype, device = first_result.dtype, first_result.device
+    res = torch.zeros((levels.size(0), first_result.size(1),
+                       first_result.size(2), first_result.size(3)),
+                      dtype=dtype, device=device)
+    for l in range(len(unmerged_results)):
+        index = (levels == l).nonzero().view(-1, 1, 1, 1)
+        index = index.expand(index.size(0),
+                             unmerged_results[l].size(1),
+                             unmerged_results[l].size(2),
+                             unmerged_results[l].size(3))
+        res = res.scatter(0, index, unmerged_results[l])
+    return res
+

 class LevelMapper(object):
     """Determine which FPN level each RoI in a set of RoIs should map to based
     on the heuristic in the FPN paper.
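
To see what the scatter-based merge above does: each per-level result tensor is written back to the row positions its RoIs occupy in the original ordering. A toy check (values and shapes are illustrative only):

import torch

# Three RoIs: RoIs 0 and 2 were pooled at level 0, RoI 1 at level 1.
# The merged tensor must restore the original RoI order.
levels = torch.tensor([0, 1, 0])
r0 = torch.full((2, 1, 1, 1), 10.)  # level-0 results
r1 = torch.full((1, 1, 1, 1), 20.)  # level-1 result
merged = _onnx_merge_levels(levels, [r0, r1])
print(merged.view(-1))  # tensor([10., 20., 10.])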
@@ -35,9 +56,9 @@ class LevelMapper(object):
         s = torch.sqrt(torch.cat([box_area(boxlist) for boxlist in boxlists]))

         # Eqn.(1) in FPN paper
-        target_lvls = torch.floor(self.lvl0 + torch.log2(s / self.s0 + self.eps))
+        target_lvls = torch.floor(self.lvl0 + torch.log2(s / self.s0) + torch.tensor(self.eps, dtype=s.dtype))
         target_lvls = torch.clamp(target_lvls, min=self.k_min, max=self.k_max)
-        return target_lvls.to(torch.int64) - self.k_min
+        return (target_lvls.to(torch.int64) - self.k_min).to(torch.int64)


 class MultiScaleRoIAlign(nn.Module):
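
As a worked instance of Eqn.(1): with the commonly used FPN settings k_min=2, k_max=5, canonical scale s0=224, canonical level lvl0=4 (assumed defaults for illustration), a 224x224 box maps to level 4 and a 112x112 box to level 3; subtracting k_min turns these into indices 2 and 1 into the list of feature maps:

import torch

# Eqn.(1) with assumed defaults: k_min=2, k_max=5, s0=224, lvl0=4, eps=1e-6.
k_min, k_max, s0, lvl0, eps = 2, 5, 224., 4., 1e-6
s = torch.tensor([224., 112., 20., 2000.])  # sqrt of the box areas

target_lvls = torch.floor(lvl0 + torch.log2(s / s0) + eps)
target_lvls = torch.clamp(target_lvls, min=k_min, max=k_max)
print(target_lvls.to(torch.int64) - k_min)  # tensor([2, 1, 0, 3])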
@@ -84,7 +105,7 @@ class MultiScaleRoIAlign(nn.Module):
         device, dtype = concat_boxes.device, concat_boxes.dtype
         ids = torch.cat(
             [
-                torch.full((len(b), 1), i, dtype=dtype, device=device)
+                torch.full_like(b[:, :1], i, dtype=dtype, device=device)
                 for i, b in enumerate(boxes)
             ],
             dim=0,
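
The switch from torch.full((len(b), 1), ...) to torch.full_like(b[:, :1], ...) avoids baking the box count into the traced graph as a constant, so the exported model accepts a variable number of boxes. The ids column tags each box with its image index, yielding the (batch_idx, x1, y1, x2, y2) rows that roi_align consumes. A small illustration (shapes are arbitrary):

import torch

# Two images with 2 and 1 boxes respectively.
boxes = [torch.rand(2, 4), torch.rand(1, 4)]
ids = torch.cat(
    [torch.full_like(b[:, :1], i) for i, b in enumerate(boxes)],
    dim=0,
)
rois = torch.cat([ids, torch.cat(boxes, dim=0)], dim=1)
print(rois[:, 0])  # tensor([0., 0., 1.])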
@@ -153,14 +174,21 @@ class MultiScaleRoIAlign(nn.Module):
             device=device,
         )

+        results = []
         for level, (per_level_feature, scale) in enumerate(zip(x, self.scales)):
             idx_in_level = torch.nonzero(levels == level).squeeze(1)
             rois_per_level = rois[idx_in_level]

-            result[idx_in_level] = roi_align(
+            result_idx_in_level = roi_align(
                 per_level_feature, rois_per_level,
                 output_size=self.output_size,
-                spatial_scale=scale, sampling_ratio=self.sampling_ratio
-            )
+                spatial_scale=scale, sampling_ratio=self.sampling_ratio)
+
+            if torchvision._is_tracing():
+                results.append(result_idx_in_level.to(dtype))
+            else:
+                result[idx_in_level] = result_idx_in_level
+
+        if torchvision._is_tracing():
+            result = _onnx_merge_levels(levels, results)

         return result
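
Under tracing, in-place indexed assignment into result is replaced by collecting the per-level outputs and merging them once at the end with _onnx_merge_levels. A quick sanity check that the two paths agree, using the _onnx_merge_levels defined above (random data, two levels, purely illustrative):

import torch

num_rois, channels, size = 8, 5, 3
levels = torch.tensor([0, 1, 0, 1, 1, 0, 0, 1])
result = torch.zeros(num_rois, channels, size, size)
results = []
for level in range(2):
    idx_in_level = torch.nonzero(levels == level).squeeze(1)
    # Stand-in for the roi_align output at this level.
    result_idx_in_level = torch.rand(idx_in_level.numel(), channels, size, size)
    result[idx_in_level] = result_idx_in_level   # eager path
    results.append(result_idx_in_level)          # traced path

assert torch.equal(result, _onnx_merge_levels(levels, results))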