Commit d88d8961 authored by eellison, committed by Francisco Massa

Make maskrcnn scriptable (#1407)

* almost working...

* respond to comments

* add empty tensor op, handle different output types in generalized rcnn

* clean ups

* address comments

* more changes

* it's working!

* torchscript bugs

* add script/ eager test

* eval script model

* fix flake

* division import

* py2 compat

* update test, fix arange bug

* import division statement

* fix linter

* fixes

* changes needed for JIT master

* cleanups

* remove imagelist_to

* requested changes

* Make FPN backwards-compatible and torchscript compatible

We remove support for feature channels=0, but support for it was already a bit limited

* Fix ONNX regression
parent b590f8c6
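
The upshot of the patch is that the detection models now compile with TorchScript. A minimal sketch of what this enables (the call pattern below is illustrative, not part of the diff):

    import torch
    import torchvision

    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=False)
    model.eval()  # the commit log notes the scripted model is exercised in eval mode
    scripted = torch.jit.script(model)  # compiles after this change

    images = [torch.rand(3, 300, 400)]  # list of CHW images, as in eager mode
    outputs = scripted(images)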
@@ -2,7 +2,9 @@ from collections import OrderedDict
import torch
import torch.nn.functional as F
from torch import nn
from torch import nn, Tensor
from torch.jit.annotations import Tuple, List, Dict
class FeaturePyramidNetwork(nn.Module):
@@ -42,14 +44,13 @@ class FeaturePyramidNetwork(nn.Module):
>>> ('feat3', torch.Size([1, 5, 8, 8]))]
"""
def __init__(self, in_channels_list, out_channels, extra_blocks=None):
super(FeaturePyramidNetwork, self).__init__()
self.inner_blocks = nn.ModuleList()
self.layer_blocks = nn.ModuleList()
for in_channels in in_channels_list:
if in_channels == 0:
continue
raise ValueError("in_channels=0 is currently not supported")
inner_block_module = nn.Conv2d(in_channels, out_channels, 1)
layer_block_module = nn.Conv2d(out_channels, out_channels, 3, padding=1)
self.inner_blocks.append(inner_block_module)
@@ -65,7 +66,46 @@ class FeaturePyramidNetwork(nn.Module):
assert isinstance(extra_blocks, ExtraFPNBlock)
self.extra_blocks = extra_blocks
def get_result_from_inner_blocks(self, x, idx):
# type: (Tensor, int)
"""
This is equivalent to self.inner_blocks[idx](x),
but torchscript doesn't support this yet
"""
num_blocks = 0
for m in self.inner_blocks:
num_blocks += 1
if idx < 0:
idx += num_blocks
i = 0
out = x
for module in self.inner_blocks:
if i == idx:
out = module(x)
i += 1
return out
def get_result_from_layer_blocks(self, x, idx):
# type: (Tensor, int)
"""
This is equivalent to self.layer_blocks[idx](x),
but torchscript doesn't support this yet
"""
num_blocks = 0
for m in self.layer_blocks:
num_blocks += 1
if idx < 0:
idx += num_blocks
i = 0
out = x
for module in self.layer_blocks:
if i == idx:
out = module(x)
i += 1
return out
def forward(self, x):
# type: (Dict[str, Tensor])
"""
Computes the FPN for a set of feature maps.
@@ -80,19 +120,16 @@ class FeaturePyramidNetwork(nn.Module):
names = list(x.keys())
x = list(x.values())
last_inner = self.inner_blocks[-1](x[-1])
last_inner = self.get_result_from_inner_blocks(x[-1], -1)
results = []
results.append(self.layer_blocks[-1](last_inner))
for feature, inner_block, layer_block in zip(
x[:-1][::-1], self.inner_blocks[:-1][::-1], self.layer_blocks[:-1][::-1]
):
if not inner_block:
continue
inner_lateral = inner_block(feature)
results.append(self.get_result_from_layer_blocks(last_inner, -1))
for idx in range(len(x) - 2, -1, -1):
inner_lateral = self.get_result_from_inner_blocks(x[idx], idx)
feat_shape = inner_lateral.shape[-2:]
inner_top_down = F.interpolate(last_inner, size=feat_shape, mode="nearest")
last_inner = inner_lateral + inner_top_down
results.insert(0, layer_block(last_inner))
results.insert(0, self.get_result_from_layer_blocks(last_inner, idx))
if self.extra_blocks is not None:
results, names = self.extra_blocks(results, x, names)
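
The get_result_from_* helpers above exist because TorchScript cannot index an nn.ModuleList with a runtime integer, so the code walks the list with a counter and applies only the module at the requested position. A standalone sketch of the pattern (toy module and names of my own choosing):

    import torch
    from torch import nn, Tensor

    class PickOne(nn.Module):
        def __init__(self):
            super(PickOne, self).__init__()
            self.blocks = nn.ModuleList(nn.Linear(4, 4) for _ in range(3))

        def forward(self, x: Tensor, idx: int) -> Tensor:
            # Equivalent to self.blocks[idx](x), written so TorchScript accepts it.
            out = x
            i = 0
            for module in self.blocks:
                if i == idx:
                    out = module(x)
                i += 1
            return out

    scripted = torch.jit.script(PickOne())
    print(scripted(torch.rand(2, 4), 1).shape)  # torch.Size([2, 4])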
@@ -127,6 +164,7 @@ class LastLevelMaxPool(ExtraFPNBlock):
Applies a max_pool2d on top of the last feature map
"""
def forward(self, x, y, names):
# type: (List[Tensor], List[Tensor], List[str])
names.append("pool")
x.append(F.max_pool2d(x[-1], 1, 2, 0))
return x, names
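
With those helpers in place the FPN itself scripts cleanly. A quick check that mirrors the module's own docstring example (shapes illustrative):

    import torch
    from collections import OrderedDict
    from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork

    fpn = FeaturePyramidNetwork([10, 20, 30], 5)
    scripted_fpn = torch.jit.script(fpn)  # possible after this change

    x = OrderedDict()
    x['feat0'] = torch.rand(1, 10, 64, 64)
    x['feat1'] = torch.rand(1, 20, 16, 16)
    x['feat2'] = torch.rand(1, 30, 8, 8)
    output = scripted_fpn(x)
    print([(k, v.shape) for k, v in output.items()])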
......
from __future__ import division
from collections import OrderedDict
from torch.jit.annotations import Optional, List
from torch import Tensor
"""
helper class that supports empty tensors on some nn functions.
@@ -12,40 +15,9 @@ is implemented
import math
import torch
from torch.nn.modules.utils import _ntuple
class _NewEmptyTensorOp(torch.autograd.Function):
@staticmethod
def forward(ctx, x, new_shape):
ctx.shape = x.shape
return x.new_empty(new_shape)
@staticmethod
def backward(ctx, grad):
shape = ctx.shape
return _NewEmptyTensorOp.apply(grad, shape), None
class Conv2d(torch.nn.Conv2d):
"""
Equivalent to nn.Conv2d, but with support for empty batch sizes.
This will eventually be supported natively by PyTorch, and this
class can go away.
"""
def forward(self, x):
if x.numel() > 0:
return super(Conv2d, self).forward(x)
# get output shape
output_shape = [
(i + 2 * p - (di * (k - 1) + 1)) // d + 1
for i, p, di, k, d in zip(
x.shape[-2:], self.padding, self.dilation, self.kernel_size, self.stride
)
]
output_shape = [x.shape[0], self.weight.shape[0]] + output_shape
return _NewEmptyTensorOp.apply(x, output_shape)
from torchvision.ops import _new_empty_tensor
from torch.nn import Module, Conv2d
import torch.nn.functional as F
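
The custom Conv2d shim (and the _NewEmptyTensorOp autograd function) can be deleted because recent PyTorch handles zero-sized batches in convolutions natively, which is why plain torch.nn.Conv2d is imported above. A quick check of that assumption:

    import torch
    from torch import nn

    conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
    out = conv(torch.rand(0, 3, 32, 32))  # zero-sized batch
    print(out.shape)                      # torch.Size([0, 8, 32, 32])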
class ConvTranspose2d(torch.nn.ConvTranspose2d):
@@ -56,22 +28,33 @@ class ConvTranspose2d(torch.nn.ConvTranspose2d):
"""
def forward(self, x):
if x.numel() > 0:
return super(ConvTranspose2d, self).forward(x)
return self.super_forward(x)
# get output shape
output_shape = [
(i - 1) * d - 2 * p + (di * (k - 1) + 1) + op
for i, p, di, k, d, op in zip(
x.shape[-2:],
self.padding,
self.dilation,
self.kernel_size,
self.stride,
self.output_padding,
list(self.padding),
list(self.dilation),
list(self.kernel_size),
list(self.stride),
list(self.output_padding),
)
]
output_shape = [x.shape[0], self.bias.shape[0]] + output_shape
return _NewEmptyTensorOp.apply(x, output_shape)
return _new_empty_tensor(x, output_shape)
def super_forward(self, input, output_size=None):
# type: (Tensor, Optional[List[int]]) -> Tensor
if self.padding_mode != 'zeros':
raise ValueError('Only `zeros` padding mode is supported for ConvTranspose2d')
output_padding = self._output_padding(input, output_size, self.stride, self.padding, self.kernel_size)
return F.conv_transpose2d(
input, self.weight, self.bias, self.stride, self.padding,
output_padding, self.groups, self.dilation)
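
The output-shape computation above is the standard transposed-convolution formula, H_out = (H_in - 1) * stride - 2 * padding + dilation * (kernel - 1) + 1 + output_padding. A worked check against PyTorch (numbers of my own choosing):

    import torch
    from torch import nn

    conv_t = nn.ConvTranspose2d(3, 8, kernel_size=3, stride=2,
                                padding=1, output_padding=1)
    x = torch.rand(1, 3, 16, 16)
    # (16 - 1) * 2 - 2 * 1 + 1 * (3 - 1) + 1 + 1 = 32
    print(conv_t(x).shape)  # torch.Size([1, 8, 32, 32])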
class BatchNorm2d(torch.nn.BatchNorm2d):
@@ -85,12 +68,39 @@ class BatchNorm2d(torch.nn.BatchNorm2d):
return super(BatchNorm2d, self).forward(x)
# get output shape
output_shape = x.shape
return _NewEmptyTensorOp.apply(x, output_shape)
return _new_empty_tensor(x, output_shape)
def _check_size_scale_factor(dim, size, scale_factor):
# type: (int, Optional[List[int]], Optional[float]) -> None
if size is None and scale_factor is None:
raise ValueError("either size or scale_factor should be defined")
if size is not None and scale_factor is not None:
raise ValueError("only one of size or scale_factor should be defined")
if scale_factor is not None and isinstance(scale_factor, tuple) and len(scale_factor) != dim:
raise ValueError(
"scale_factor shape must match input shape. "
"Input is {}D, scale_factor size is {}".format(dim, len(scale_factor))
)
def interpolate(
input, size=None, scale_factor=None, mode="nearest", align_corners=None
):
def _output_size(dim, input, size, scale_factor):
# type: (int, Tensor, Optional[List[int]], Optional[float]) -> List[int]
assert dim == 2
_check_size_scale_factor(dim, size, scale_factor)
if size is not None:
return size
# if dim is not 2 or scale_factor is iterable use _ntuple instead of concat
assert scale_factor is not None and isinstance(scale_factor, (int, float))
scale_factors = [scale_factor, scale_factor]
# math.floor might return float in py2.7
return [
int(math.floor(input.size(i + 2) * scale_factors[i])) for i in range(dim)
]
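
_check_size_scale_factor and _output_size are hoisted to module level (with explicit type comments) because TorchScript cannot compile closures over enclosing-function variables. A small sanity check of the hoisted helper (it is private, so this is purely illustrative and assumes the corrected guard above):

    import torch
    from torchvision.ops.misc import _output_size

    x = torch.rand(1, 3, 10, 12)
    print(_output_size(2, x, None, 2.0))     # [20, 24]
    print(_output_size(2, x, [7, 9], None))  # [7, 9]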
def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
# type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
"""
Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
This will eventually be supported natively by PyTorch, and this
@@ -101,34 +111,9 @@ def interpolate(
input, size, scale_factor, mode, align_corners
)
def _check_size_scale_factor(dim):
if size is None and scale_factor is None:
raise ValueError("either size or scale_factor should be defined")
if size is not None and scale_factor is not None:
raise ValueError("only one of size or scale_factor should be defined")
if (
scale_factor is not None and
isinstance(scale_factor, tuple) and
len(scale_factor) != dim
):
raise ValueError(
"scale_factor shape must match input shape. "
"Input is {}D, scale_factor size is {}".format(dim, len(scale_factor))
)
def _output_size(dim):
_check_size_scale_factor(dim)
if size is not None:
return size
scale_factors = _ntuple(dim)(scale_factor)
# math.floor might return float in py2.7
return [
int(math.floor(input.size(i + 2) * scale_factors[i])) for i in range(dim)
]
output_shape = tuple(_output_size(2))
output_shape = _output_size(2, input, size, scale_factor)
output_shape = input.shape[:-2] + output_shape
return _NewEmptyTensorOp.apply(input, output_shape)
return _new_empty_tensor(input, output_shape)
# This is not in nn
......
import torch
from torch.jit.annotations import List
from torch import Tensor
def _new_empty_tensor(x, shape):
# type: (Tensor, List[int]) -> Tensor
"""
Arguments:
x (Tensor): input tensor
shape (List[int]): the new empty tensor shape
Returns:
output (Tensor)
"""
return torch.ops.torchvision._new_empty_tensor_op(x, shape)
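
The helper wraps a registered torchvision operator, so the empty-tensor trick stays visible to both the autograd engine and TorchScript instead of relying on a Python-only autograd.Function. A usage sketch:

    import torch
    from torchvision.ops import _new_empty_tensor

    x = torch.rand(0, 3, 32, 32)              # empty batch
    y = _new_empty_tensor(x, [0, 8, 30, 30])  # same dtype/device, new shape
    print(y.shape)                            # torch.Size([0, 8, 30, 30])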
from __future__ import division
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch
import torch.nn.functional as F
from torch import nn
from torch import nn, Tensor
import torchvision
from torchvision.ops import roi_align
from torchvision.ops.boxes import box_area
from torch.jit.annotations import Optional, List, Dict, Tuple
import torchvision
# copying result_idx_in_level to a specific index in result[]
# is not supported by ONNX tracing yet.
# _onnx_merge_levels() is an implementation supported by ONNX
# that merges the levels to the right indices
@torch.jit.unused
def _onnx_merge_levels(levels, unmerged_results):
# type: (Tensor, List[Tensor]) -> Tensor
first_result = unmerged_results[0]
dtype, device = first_result.dtype, first_result.device
res = torch.zeros((levels.size(0), first_result.size(1),
@@ -28,6 +33,13 @@ def _onnx_merge_levels(levels, unmerged_results):
return res
# TODO: (eellison) T54974082 https://github.com/pytorch/pytorch/issues/26744
def initLevelMapper(k_min, k_max, canonical_scale=224, canonical_level=4, eps=1e-6):
# type: (int, int, int, int, float)
return LevelMapper(k_min, k_max, canonical_scale, canonical_level, eps)
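
initLevelMapper is a plain function wrapper around the script-class constructor (see the TODO above). The mapper itself implements Eq. 1 of the FPN paper: k = floor(canonical_level + log2(sqrt(box_area) / canonical_scale)), clamped to [k_min, k_max] and returned relative to k_min. A worked example (box coordinates of my own choosing):

    import torch
    from torchvision.ops.poolers import initLevelMapper

    mapper = initLevelMapper(k_min=2, k_max=5)
    boxes = [torch.tensor([[0., 0., 224., 224.],    # canonical scale -> level 4
                           [0., 0., 112., 112.]])]  # half the scale  -> level 3
    print(mapper(boxes))  # tensor([2, 1]): levels 4 and 3, offset by k_min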
@torch.jit.script
class LevelMapper(object):
"""Determine which FPN level each RoI in a set of RoIs should map to based
on the heuristic in the FPN paper.
@@ -41,6 +53,7 @@ class LevelMapper(object):
"""
def __init__(self, k_min, k_max, canonical_scale=224, canonical_level=4, eps=1e-6):
# type: (int, int, int, int, float)
self.k_min = k_min
self.k_max = k_max
self.s0 = canonical_scale
@@ -48,6 +61,7 @@
self.eps = eps
def __call__(self, boxlists):
# type: (List[Tensor])
"""
Arguments:
boxlists (list[BoxList])
@@ -90,6 +104,11 @@ class MultiScaleRoIAlign(nn.Module):
"""
__annotations__ = {
'scales': Optional[List[float]],
'map_levels': Optional[LevelMapper]
}
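
The __annotations__ dict is how one tells the TorchScript compiler the type of an attribute that starts out as None; scales and map_levels are only filled in lazily by setup_scales below. A toy sketch of the same pattern (module and names of my own choosing):

    import torch
    from torch import nn, Tensor
    from torch.jit.annotations import Optional

    class LazyScale(nn.Module):
        __annotations__ = {'scale': Optional[float]}

        def __init__(self):
            super(LazyScale, self).__init__()
            self.scale = None  # computed on first use

        def forward(self, x: Tensor) -> Tensor:
            if self.scale is None:
                self.scale = float(x.shape[-1])
            scale = self.scale
            assert scale is not None  # copy to a local and refine, as in forward below
            return x * scale

    scripted = torch.jit.script(LazyScale())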
def __init__(self, featmap_names, output_size, sampling_ratio):
super(MultiScaleRoIAlign, self).__init__()
if isinstance(output_size, int):
@@ -101,11 +120,12 @@ class MultiScaleRoIAlign(nn.Module):
self.map_levels = None
def convert_to_roi_format(self, boxes):
# type: (List[Tensor])
concat_boxes = torch.cat(boxes, dim=0)
device, dtype = concat_boxes.device, concat_boxes.dtype
ids = torch.cat(
[
torch.full_like(b[:, :1], i, dtype=dtype, device=device)
torch.full_like(b[:, :1], i, dtype=dtype, layout=torch.strided, device=device)
for i, b in enumerate(boxes)
],
dim=0,
@@ -114,27 +134,37 @@ class MultiScaleRoIAlign(nn.Module):
return rois
def infer_scale(self, feature, original_size):
# type: (Tensor, List[int])
# assumption: the scale is of the form 2 ** (-k), with k integer
size = feature.shape[-2:]
possible_scales = []
possible_scales = torch.jit.annotate(List[float], [])
for s1, s2 in zip(size, original_size):
approx_scale = float(s1) / s2
scale = 2 ** torch.tensor(approx_scale).log2().round().item()
scale = 2 ** float(torch.tensor(approx_scale).log2().round())
possible_scales.append(scale)
assert possible_scales[0] == possible_scales[1]
return possible_scales[0]
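
infer_scale snaps the observed feature/image ratio to the nearest power of two; e.g. for an 800-pixel image whose feature map is 200 pixels wide:

    import torch

    approx_scale = 200.0 / 800.0  # 0.25
    scale = 2 ** float(torch.tensor(approx_scale).log2().round())
    print(scale)  # 0.25, i.e. 2 ** -2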
def setup_scales(self, features, image_shapes):
original_input_shape = tuple(max(s) for s in zip(*image_shapes))
# type: (List[Tensor], List[Tuple[int, int]])
assert len(image_shapes) != 0
max_x = 0
max_y = 0
for shape in image_shapes:
max_x = max(shape[0], max_x)
max_y = max(shape[1], max_y)
original_input_shape = (max_x, max_y)
scales = [self.infer_scale(feat, original_input_shape) for feat in features]
# get the levels in the feature map by leveraging the fact that the network always
# downsamples by a factor of 2 at each level.
lvl_min = -torch.log2(torch.tensor(scales[0], dtype=torch.float32)).item()
lvl_max = -torch.log2(torch.tensor(scales[-1], dtype=torch.float32)).item()
self.scales = scales
self.map_levels = LevelMapper(lvl_min, lvl_max)
self.map_levels = initLevelMapper(int(lvl_min), int(lvl_max))
def forward(self, x, boxes, image_shapes):
# type: (Dict[str, Tensor], List[Tensor], List[Tuple[int, int]])
"""
Arguments:
x (OrderedDict[Tensor]): feature maps for each level. They are assumed to have
@@ -148,34 +178,43 @@ class MultiScaleRoIAlign(nn.Module):
Returns:
result (Tensor)
"""
x = [v for k, v in x.items() if k in self.featmap_names]
num_levels = len(x)
x_filtered = []
for k, v in x.items():
if k in self.featmap_names:
x_filtered.append(v)
num_levels = len(x_filtered)
rois = self.convert_to_roi_format(boxes)
if self.scales is None:
self.setup_scales(x, image_shapes)
self.setup_scales(x_filtered, image_shapes)
scales = self.scales
assert scales is not None
if num_levels == 1:
return roi_align(
x[0], rois,
x_filtered[0], rois,
output_size=self.output_size,
spatial_scale=self.scales[0],
spatial_scale=scales[0],
sampling_ratio=self.sampling_ratio
)
levels = self.map_levels(boxes)
mapper = self.map_levels
assert mapper is not None
levels = mapper(boxes)
num_rois = len(rois)
num_channels = x[0].shape[1]
num_channels = x_filtered[0].shape[1]
dtype, device = x[0].dtype, x[0].device
dtype, device = x_filtered[0].dtype, x_filtered[0].device
result = torch.zeros(
(num_rois, num_channels,) + self.output_size,
dtype=dtype,
device=device,
)
results = []
for level, (per_level_feature, scale) in enumerate(zip(x, self.scales)):
tracing_results = []
for level, (per_level_feature, scale) in enumerate(zip(x_filtered, scales)):
idx_in_level = torch.nonzero(levels == level).squeeze(1)
rois_per_level = rois[idx_in_level]
@@ -185,10 +224,11 @@ class MultiScaleRoIAlign(nn.Module):
spatial_scale=scale, sampling_ratio=self.sampling_ratio)
if torchvision._is_tracing():
results.append(result_idx_in_level.to(dtype))
tracing_results.append(result_idx_in_level.to(dtype))
else:
result[idx_in_level] = result_idx_in_level
if torchvision._is_tracing():
result = _onnx_merge_levels(levels, results)
result = _onnx_merge_levels(levels, tracing_results)
return result
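
End to end the pooler behaves as before and now also scripts. A usage sketch (feature-map names and shapes are illustrative):

    import torch
    from collections import OrderedDict
    from torchvision.ops import MultiScaleRoIAlign

    pooler = MultiScaleRoIAlign(featmap_names=['feat1', 'feat2'],
                                output_size=7, sampling_ratio=2)
    features = OrderedDict()
    features['feat1'] = torch.rand(1, 5, 64, 64)  # stride 8 w.r.t. the 512px image
    features['feat2'] = torch.rand(1, 5, 32, 32)  # stride 16
    boxes = [torch.tensor([[0., 0., 100., 100.],
                           [50., 50., 250., 250.]])]
    image_shapes = [(512, 512)]
    output = pooler(features, boxes, image_shapes)
    print(output.shape)  # torch.Size([2, 5, 7, 7])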
@@ -2,13 +2,13 @@ import torch
from torch import nn, Tensor
from torch.nn.modules.utils import _pair
from torch.jit.annotations import List
from torch.jit.annotations import List, BroadcastingList2
from ._utils import convert_boxes_to_roi_format
def roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1):
# type: (Tensor, Tensor, int, float, int) -> Tensor
# type: (Tensor, Tensor, BroadcastingList2[int], float, int) -> Tensor
"""
Performs Region of Interest (RoI) Align operator described in Mask R-CNN
......
@@ -2,13 +2,13 @@ import torch
from torch import nn, Tensor
from torch.nn.modules.utils import _pair
from torch.jit.annotations import List
from torch.jit.annotations import List, BroadcastingList2
from ._utils import convert_boxes_to_roi_format
def roi_pool(input, boxes, output_size, spatial_scale=1.0):
# type: (Tensor, Tensor, int, float) -> Tensor
# type: (Tensor, Tensor, BroadcastingList2[int], float) -> Tensor
"""
Performs Region of Interest (RoI) Pool operator described in Fast R-CNN
......
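
BroadcastingList2[int] lets the scripted ops accept either a single int or an explicit (height, width) pair for output_size, matching what _pair already did in eager mode. For example (tensor contents illustrative):

    import torch
    from torchvision.ops import roi_align

    feats = torch.rand(1, 3, 32, 32)
    rois = torch.tensor([[0., 0., 0., 16., 16.]])  # (batch_idx, x1, y1, x2, y2)
    a = roi_align(feats, rois, output_size=7, spatial_scale=1.0, sampling_ratio=2)
    b = roi_align(feats, rois, output_size=(7, 7), spatial_scale=1.0, sampling_ratio=2)
    print(a.shape, b.shape)  # torch.Size([1, 3, 7, 7]) twice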