Commit d88d8961 authored by eellison, committed by Francisco Massa

Make maskrcnn scriptable (#1407)

* almost working...

* respond to comments

* add empty tensor op, handle different output types in generalized rcnn

* clean ups

* address comments

* more changes

* it's working!

* torchscript bugs

* add script/ eager test

* eval script model

* fix flake

* division import

* py2 compat

* update test, fix arange bug

* import division statement

* fix linter

* fixes

* changes needed for JIT master

* cleanups

* remove imagelist_to

* requested changes

* Make FPN backwards-compatible and torchscript compatible

We remove support for feature channels=0, but support for it was already a bit limited

* Fix ONNX regression
parent b590f8c6
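
The changes below make the detection models scriptable end to end. A minimal usage sketch, not part of the diff (constructor arguments mirror the updated test in test_models.py, and it assumes a torchvision build that contains this commit): scripted detection models always return a (losses, detections) tuple, so inference code takes the second element.

import torch
import torchvision

# Sketch: compile Mask R-CNN with TorchScript and run it in eval mode.
model = torchvision.models.detection.maskrcnn_resnet50_fpn(
    num_classes=50, pretrained=False, pretrained_backbone=False)
model.eval()
scripted_model = torch.jit.script(model)
scripted_model.eval()

images = [torch.rand(3, 300, 300)]
# Scripted detection models return (losses, detections); keep the detections.
detections = scripted_model(images)[1]
print(detections[0]["boxes"].shape, detections[0]["scores"].shape)
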
@@ -208,7 +208,7 @@ def patched_make_field(self, types, domain, items, **kw):
     # `kw` catches `env=None` needed for newer sphinx while maintaining
     # backwards compatibility when passed along further down!
-    # type: (list, unicode, tuple) -> nodes.field
+    # type: (list, unicode, tuple) -> nodes.field  # noqa: F821
     def handle_item(fieldarg, content):
         par = nodes.paragraph()
         par += addnodes.literal_strong('', fieldarg)  # Patch: this line added
......
@@ -15,11 +15,11 @@ class ResnetFPNBackboneTester(unittest.TestCase):
         x = torch.rand(1, 3, 300, 300, dtype=self.dtype, device=device)
         resnet18_fpn = resnet_fpn_backbone(backbone_name='resnet18', pretrained=False)
         y = resnet18_fpn(x)
-        self.assertEqual(list(y.keys()), [0, 1, 2, 3, 'pool'])
+        self.assertEqual(list(y.keys()), ['0', '1', '2', '3', 'pool'])

     def test_resnet50_fpn_backbone(self):
         device = torch.device('cpu')
         x = torch.rand(1, 3, 300, 300, dtype=self.dtype, device=device)
         resnet50_fpn = resnet_fpn_backbone(backbone_name='resnet50', pretrained=False)
         y = resnet50_fpn(x)
-        self.assertEqual(list(y.keys()), [0, 1, 2, 3, 'pool'])
+        self.assertEqual(list(y.keys()), ['0', '1', '2', '3', 'pool'])
@@ -51,7 +51,10 @@ script_test_models = [
     "squeezenet1_0",
     "vgg11",
     "inception_v3",
-    'r3d_18',
+    "r3d_18",
+    "fasterrcnn_resnet50_fpn",
+    "maskrcnn_resnet50_fpn",
+    "keypointrcnn_resnet50_fpn",
 ]
@@ -95,7 +98,6 @@ class ModelTester(TestCase):
     def _test_detection_model(self, name):
         set_rng_seed(0)
         model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False)
-        self.check_script(model, name)
         model.eval()
         input_shape = (3, 300, 300)
         x = torch.rand(input_shape)
@@ -130,9 +132,19 @@ class ModelTester(TestCase):
         else:
             self.assertExpected(map_nested_tensor_object(out, tensor_map_fn=subsample_tensor))

+        scripted_model = torch.jit.script(model)
+        scripted_model.eval()
+        scripted_out = scripted_model(model_input)[1]
+        self.assertNestedTensorObjectsEqual(scripted_out[0]["boxes"], out[0]["boxes"])
+        self.assertNestedTensorObjectsEqual(scripted_out[0]["scores"], out[0]["scores"])
+        # labels currently float in script: need to investigate (though same result)
+        self.assertNestedTensorObjectsEqual(scripted_out[0]["labels"].to(dtype=torch.long), out[0]["labels"])
         self.assertTrue("boxes" in out[0])
         self.assertTrue("scores" in out[0])
         self.assertTrue("labels" in out[0])
+        # don't check script because we are compiling it here:
+        # TODO: refactor tests
+        # self.check_script(model, name)

     def _test_video_model(self, name):
         # the default input shape is
......
@@ -367,5 +367,14 @@ class NMSTester(unittest.TestCase):
             self.assertTrue(torch.allclose(r_cpu, r_cuda.cpu()), err_msg.format(iou))

+class NewEmptyTensorTester(unittest.TestCase):
+    def test_new_empty_tensor(self):
+        input = torch.tensor([2., 2.], requires_grad=True)
+        new_shape = [3, 3]
+        out = torch.ops.torchvision._new_empty_tensor_op(input, new_shape)
+        assert out.size() == torch.Size([3, 3])
+        assert out.dtype == input.dtype
+
 if __name__ == '__main__':
     unittest.main()
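
The operator exercised by this test is registered under torch.ops.torchvision in vision.cpp below, so it is also callable from scripted code. A small hypothetical sketch (assumes a torchvision build containing this commit; the wrapper function name is made up):

import torch
import torchvision  # loads the compiled torchvision ops
from torch import Tensor
from torch.jit.annotations import List

@torch.jit.script
def scripted_new_empty(x, shape):
    # type: (Tensor, List[int]) -> Tensor
    # Allocates an uninitialized tensor of `shape` with x's dtype/device.
    return torch.ops.torchvision._new_empty_tensor_op(x, shape)

out = scripted_new_empty(torch.tensor([2., 2.]), [3, 3])
assert out.size() == torch.Size([3, 3])
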
@@ -8,6 +8,7 @@ from torchvision import utils
 from torchvision import io
 from .extension import _HAS_OPS
+import torch

 try:
     from .version import __version__  # noqa: F401
@@ -70,5 +71,4 @@ def get_video_backend():

 def _is_tracing():
-    import torch
     return torch._C._get_tracing_state()
#pragma once
// All pure C++ headers for the C++ frontend.
#include <torch/all.h>
// Python bindings for the C++ frontend (includes Python.h).
#include <torch/python.h>
using namespace at;
using torch::Tensor;
using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;
class NewEmptyTensorOp : public torch::autograd::Function<NewEmptyTensorOp> {
public:
static variable_list forward(
AutogradContext* ctx,
Variable input,
c10::List<int64_t> new_shape) {
ctx->saved_data["shape"] = input.sizes();
std::vector<int64_t> shape(new_shape.begin(), new_shape.end());
return {input.new_empty(shape, TensorOptions())};
}
static variable_list backward(
AutogradContext* ctx,
variable_list grad_output) {
// Use data saved in forward
auto shape = ctx->saved_data["shape"].toIntList();
auto out = forward(ctx, grad_output[0], shape);
return {out[0], at::Tensor()};
}
};
Tensor new_empty_tensor(const Tensor& input, c10::List<int64_t> shape) {
return NewEmptyTensorOp::apply(input, shape)[0];
}
@@ -9,6 +9,7 @@
 #include "PSROIPool.h"
 #include "ROIAlign.h"
 #include "ROIPool.h"
+#include "empty_tensor_op.h"
 #include "nms.h"

 // If we are in a Windows environment, we need to define
@@ -43,6 +44,7 @@ static auto registry =
         .op("torchvision::roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio) -> Tensor",
             &roi_align)
         .op("torchvision::roi_pool", &roi_pool)
+        .op("torchvision::_new_empty_tensor_op", &new_empty_tensor)
         .op("torchvision::ps_roi_align", &ps_roi_align)
         .op("torchvision::ps_roi_pool", &ps_roi_pool)
         .op("torchvision::_cuda_version", &_cuda_version);
@@ -37,7 +37,6 @@ class IntermediateLayerGetter(nn.ModuleDict):
         >>> ('feat2', torch.Size([1, 256, 14, 14]))]
     """
     _version = 2
-    __constants__ = ['layers']
     __annotations__ = {
         "return_layers": Dict[str, str],
     }
@@ -46,7 +45,7 @@ class IntermediateLayerGetter(nn.ModuleDict):
         if not set(return_layers).issubset([name for name, _ in model.named_children()]):
             raise ValueError("return_layers are not present in model")
         orig_return_layers = return_layers
-        return_layers = {k: v for k, v in return_layers.items()}
+        return_layers = {str(k): str(v) for k, v in return_layers.items()}
         layers = OrderedDict()
         for name, module in model.named_children():
             layers[name] = module
......
@@ -3,15 +3,28 @@ from __future__ import division
 import math

 import torch
+from torch.jit.annotations import List, Tuple
+from torch import Tensor
 import torchvision

+# TODO: https://github.com/pytorch/pytorch/issues/26727
+def zeros_like(tensor, dtype):
+    # type: (Tensor, int) -> Tensor
+    if tensor.dtype == dtype:
+        return tensor.detach().clone()
+    else:
+        return tensor.to(dtype)

+@torch.jit.script
 class BalancedPositiveNegativeSampler(object):
     """
     This class samples batches, ensuring that they contain a fixed proportion of positives
     """

     def __init__(self, batch_size_per_image, positive_fraction):
+        # type: (int, float)
         """
         Arguments:
             batch_size_per_image (int): number of elements to be selected per image
@@ -21,6 +34,7 @@ class BalancedPositiveNegativeSampler(object):
         self.positive_fraction = positive_fraction

     def __call__(self, matched_idxs):
+        # type: (List[Tensor])
         """
         Arguments:
             matched idxs: list of tensors containing -1, 0 or positive values.
@@ -57,14 +71,15 @@ class BalancedPositiveNegativeSampler(object):
             neg_idx_per_image = negative[perm2]

             # create binary mask from indices
-            pos_idx_per_image_mask = torch.zeros_like(
+            pos_idx_per_image_mask = zeros_like(
                 matched_idxs_per_image, dtype=torch.uint8
             )
-            neg_idx_per_image_mask = torch.zeros_like(
+            neg_idx_per_image_mask = zeros_like(
                 matched_idxs_per_image, dtype=torch.uint8
             )
-            pos_idx_per_image_mask[pos_idx_per_image] = 1
-            neg_idx_per_image_mask[neg_idx_per_image] = 1
+
+            pos_idx_per_image_mask[pos_idx_per_image] = torch.tensor(1)
+            neg_idx_per_image_mask[neg_idx_per_image] = torch.tensor(1)

             pos_idx.append(pos_idx_per_image_mask)
             neg_idx.append(neg_idx_per_image_mask)
@@ -120,6 +135,7 @@ def encode_boxes(reference_boxes, proposals, weights):
     return targets

+@torch.jit.script
 class BoxCoder(object):
     """
     This class encodes and decodes a set of bounding boxes into
@@ -127,6 +143,7 @@ class BoxCoder(object):
     """

     def __init__(self, weights, bbox_xform_clip=math.log(1000. / 16)):
+        # type: (Tuple[float, float, float, float], float)
         """
         Arguments:
             weights (4-element tuple)
@@ -136,6 +153,7 @@ class BoxCoder(object):
         self.bbox_xform_clip = bbox_xform_clip

     def encode(self, reference_boxes, proposals):
+        # type: (List[Tensor], List[Tensor])
         boxes_per_image = [len(b) for b in reference_boxes]
         reference_boxes = torch.cat(reference_boxes, dim=0)
         proposals = torch.cat(proposals, dim=0)
@@ -159,16 +177,18 @@ class BoxCoder(object):
         return targets

     def decode(self, rel_codes, boxes):
+        # type: (Tensor, List[Tensor])
         assert isinstance(boxes, (list, tuple))
-        if isinstance(rel_codes, (list, tuple)):
-            rel_codes = torch.cat(rel_codes, dim=0)
         assert isinstance(rel_codes, torch.Tensor)
         boxes_per_image = [b.size(0) for b in boxes]
         concat_boxes = torch.cat(boxes, dim=0)
+        box_sum = 0
+        for val in boxes_per_image:
+            box_sum += val
         pred_boxes = self.decode_single(
-            rel_codes.reshape(sum(boxes_per_image), -1), concat_boxes
+            rel_codes.reshape(box_sum, -1), concat_boxes
         )
-        return pred_boxes.reshape(sum(boxes_per_image), -1, 4)
+        return pred_boxes.reshape(box_sum, -1, 4)

     def decode_single(self, rel_codes, boxes):
         """
@@ -210,6 +230,7 @@ class BoxCoder(object):
         return pred_boxes

+@torch.jit.script
 class Matcher(object):
     """
     This class assigns to each predicted "element" (e.g., a box) a ground-truth
@@ -228,7 +249,13 @@ class Matcher(object):
     BELOW_LOW_THRESHOLD = -1
     BETWEEN_THRESHOLDS = -2

+    __annotations__ = {
+        'BELOW_LOW_THRESHOLD': int,
+        'BETWEEN_THRESHOLDS': int,
+    }
+
     def __init__(self, high_threshold, low_threshold, allow_low_quality_matches=False):
+        # type: (float, float, bool)
         """
         Args:
             high_threshold (float): quality values greater than or equal to
@@ -242,6 +269,8 @@ class Matcher(object):
                 for predictions that have only low-quality match candidates. See
                 set_low_quality_matches_ for more details.
         """
+        self.BELOW_LOW_THRESHOLD = -1
+        self.BETWEEN_THRESHOLDS = -2
         assert low_threshold <= high_threshold
         self.high_threshold = high_threshold
         self.low_threshold = low_threshold
@@ -274,16 +303,19 @@ class Matcher(object):
         matched_vals, matches = match_quality_matrix.max(dim=0)
         if self.allow_low_quality_matches:
             all_matches = matches.clone()
+        else:
+            all_matches = None

         # Assign candidate matches with low quality to negative (unassigned) values
         below_low_threshold = matched_vals < self.low_threshold
         between_thresholds = (matched_vals >= self.low_threshold) & (
             matched_vals < self.high_threshold
         )
-        matches[below_low_threshold] = Matcher.BELOW_LOW_THRESHOLD
-        matches[between_thresholds] = Matcher.BETWEEN_THRESHOLDS
+        matches[below_low_threshold] = torch.tensor(self.BELOW_LOW_THRESHOLD)
+        matches[between_thresholds] = torch.tensor(self.BETWEEN_THRESHOLDS)

         if self.allow_low_quality_matches:
+            assert all_matches is not None
             self.set_low_quality_matches_(matches, all_matches, match_quality_matrix)

         return matches
......
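
The pattern used throughout this file: utility classes are decorated with @torch.jit.script and their methods get Python 2 style `# type:` comments, since the file still imports division from __future__ for py2 compatibility. A stripped-down sketch of the same idiom, with a made-up class that is not part of torchvision:

import torch
from torch import Tensor
from torch.jit.annotations import List

@torch.jit.script
class BoxScaler(object):
    def __init__(self, factor):
        # type: (float) -> None
        self.factor = factor

    def scale(self, boxes):
        # type: (List[Tensor]) -> List[Tensor]
        # Annotate the empty list so TorchScript knows its element type.
        out = torch.jit.annotate(List[Tensor], [])
        for b in boxes:
            out.append(b * self.factor)
        return out
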
@@ -7,14 +7,12 @@ from .._utils import IntermediateLayerGetter
 from .. import resnet

-class BackboneWithFPN(nn.Sequential):
+class BackboneWithFPN(nn.Module):
     """
     Adds a FPN on top of a model.
     Internally, it uses torchvision.models._utils.IntermediateLayerGetter to
     extract a submodel that returns the feature maps specified in return_layers.
     The same limitations of IntermediatLayerGetter apply here.
     Arguments:
         backbone (nn.Module)
         return_layers (Dict[name, new_name]): a dict containing the names
@@ -24,21 +22,24 @@ class BackboneWithFPN(nn.Sequential):
         in_channels_list (List[int]): number of channels for each feature map
             that is returned, in the order they are present in the OrderedDict
         out_channels (int): number of channels in the FPN.
     Attributes:
         out_channels (int): the number of channels in the FPN
     """
     def __init__(self, backbone, return_layers, in_channels_list, out_channels):
-        body = IntermediateLayerGetter(backbone, return_layers=return_layers)
-        fpn = FeaturePyramidNetwork(
+        super(BackboneWithFPN, self).__init__()
+        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
+        self.fpn = FeaturePyramidNetwork(
             in_channels_list=in_channels_list,
             out_channels=out_channels,
             extra_blocks=LastLevelMaxPool(),
         )
-        super(BackboneWithFPN, self).__init__(OrderedDict(
-            [("body", body), ("fpn", fpn)]))
         self.out_channels = out_channels

+    def forward(self, x):
+        x = self.body(x)
+        x = self.fpn(x)
+        return x

 def resnet_fpn_backbone(backbone_name, pretrained):
     backbone = resnet.__dict__[backbone_name](
@@ -49,7 +50,7 @@ def resnet_fpn_backbone(backbone_name, pretrained):
         if 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
             parameter.requires_grad_(False)

-    return_layers = {'layer1': 0, 'layer2': 1, 'layer3': 2, 'layer4': 3}
+    return_layers = {'layer1': '0', 'layer2': '1', 'layer3': '2', 'layer4': '3'}

     in_channels_stage2 = backbone.inplanes // 8
     in_channels_list = [
......
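
With BackboneWithFPN now a plain nn.Module and the return_layers values stringified, the FPN backbone produces an OrderedDict keyed by the strings '0'-'3' plus 'pool', which is what the updated backbone test above asserts. A short usage sketch (assumes a build containing this change):

import torch
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone

backbone = resnet_fpn_backbone(backbone_name='resnet18', pretrained=False)
features = backbone(torch.rand(1, 3, 300, 300))
print(list(features.keys()))   # ['0', '1', '2', '3', 'pool']
print(backbone.out_channels)   # number of FPN channels
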
@@ -199,7 +199,7 @@ class FasterRCNN(GeneralizedRCNN):
         if box_roi_pool is None:
             box_roi_pool = MultiScaleRoIAlign(
-                featmap_names=[0, 1, 2, 3],
+                featmap_names=['0', '1', '2', '3'],
                 output_size=7,
                 sampling_ratio=2)
@@ -273,7 +273,7 @@ class FastRCNNPredictor(nn.Module):
         self.bbox_pred = nn.Linear(in_channels, num_classes * 4)

     def forward(self, x):
-        if x.ndimension() == 4:
+        if x.dim() == 4:
             assert list(x.shape[2:]) == [1, 1]
         x = x.flatten(start_dim=1)
         scores = self.cls_score(x)
......
@@ -6,6 +6,9 @@ Implements the Generalized R-CNN framework
 from collections import OrderedDict
 import torch
 from torch import nn
+import warnings
+from torch.jit.annotations import Tuple, List, Dict, Optional
+from torch import Tensor

 class GeneralizedRCNN(nn.Module):
@@ -28,7 +31,16 @@ class GeneralizedRCNN(nn.Module):
         self.rpn = rpn
         self.roi_heads = roi_heads

+    @torch.jit.unused
+    def eager_outputs(self, losses, detections):
+        # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
+        if self.training:
+            return losses
+
+        return detections
+
     def forward(self, images, targets=None):
+        # type: (List[Tensor], Optional[List[Dict[str, Tensor]]])
         """
         Arguments:
             images (list[Tensor]): images to be processed
@@ -43,7 +55,12 @@ class GeneralizedRCNN(nn.Module):
         """
         if self.training and targets is None:
             raise ValueError("In training mode, targets should be passed")
-        original_image_sizes = [img.shape[-2:] for img in images]
+        original_image_sizes = torch.jit.annotate(List[Tuple[int, int]], [])
+        for img in images:
+            val = img.shape[-2:]
+            assert len(val) == 2
+            original_image_sizes.append((val[0], val[1]))
+
         images, targets = self.transform(images, targets)
         features = self.backbone(images.tensors)
         if isinstance(features, torch.Tensor):
@@ -56,7 +73,8 @@ class GeneralizedRCNN(nn.Module):
         losses.update(detector_losses)
         losses.update(proposal_losses)

-        if self.training:
-            return losses
-
-        return detections
+        if torch.jit.is_scripting():
+            warnings.warn("RCNN always returns a (Losses, Detections tuple in scripting)")
+            return (losses, detections)
+        else:
+            return self.eager_outputs(losses, detections)
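
The @torch.jit.unused / torch.jit.is_scripting() combination above lets the eager model keep its old training/eval return types while the scripted model always returns a tuple. A minimal hypothetical module (not from the diff) sketching the same pattern:

import torch
from torch import nn, Tensor
from torch.jit.annotations import Dict, List, Tuple

class TupleWhenScripted(nn.Module):
    @torch.jit.unused
    def eager_outputs(self, losses, detections):
        # type: (Dict[str, Tensor], List[Tensor]) -> Tuple[Dict[str, Tensor], List[Tensor]]
        if self.training:
            return losses
        return detections

    def forward(self, x):
        # type: (Tensor)
        losses = {"loss": x.sum()}
        detections = [x]
        if torch.jit.is_scripting():
            return losses, detections
        return self.eager_outputs(losses, detections)

scripted = torch.jit.script(TupleWhenScripted())  # scripted forward always yields a tuple
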
@@ -2,6 +2,8 @@
 from __future__ import division

 import torch
+from torch.jit.annotations import List, Tuple
+from torch import Tensor

 class ImageList(object):
@@ -13,6 +15,7 @@ class ImageList(object):
     """

     def __init__(self, tensors, image_sizes):
+        # type: (Tensor, List[Tuple[int, int]])
         """
         Arguments:
             tensors (tensor)
@@ -21,6 +24,7 @@ class ImageList(object):
         self.tensors = tensors
         self.image_sizes = image_sizes

-    def to(self, *args, **kwargs):
-        cast_tensor = self.tensors.to(*args, **kwargs)
+    def to(self, device):
+        # type: (Device) # noqa
+        cast_tensor = self.tensors.to(device)
         return ImageList(cast_tensor, self.image_sizes)
@@ -2,6 +2,7 @@ import torch
 from torch import nn
 from torchvision.ops import misc as misc_nn_ops
 from torchvision.ops import MultiScaleRoIAlign

 from ..utils import load_state_dict_from_url
@@ -179,7 +180,7 @@ class KeypointRCNN(FasterRCNN):
         if keypoint_roi_pool is None:
             keypoint_roi_pool = MultiScaleRoIAlign(
-                featmap_names=[0, 1, 2, 3],
+                featmap_names=['0', '1', '2', '3'],
                 output_size=14,
                 sampling_ratio=2)
@@ -252,7 +253,7 @@ class KeypointRCNNPredictor(nn.Module):
     def forward(self, x):
         x = self.kps_score_lowres(x)
         x = misc_nn_ops.interpolate(
-            x, scale_factor=self.up_scale, mode="bilinear", align_corners=False
+            x, scale_factor=float(self.up_scale), mode="bilinear", align_corners=False
         )
         return x
......
@@ -178,7 +178,7 @@ class MaskRCNN(FasterRCNN):
         if mask_roi_pool is None:
             mask_roi_pool = MultiScaleRoIAlign(
-                featmap_names=[0, 1, 2, 3],
+                featmap_names=['0', '1', '2', '3'],
                 output_size=14,
                 sampling_ratio=2)
......
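
The featmap_names arguments above move from integer indices to strings so that they line up with the string keys the FPN backbone now returns. A small illustrative sketch of MultiScaleRoIAlign with string-keyed features (shapes are made up for the example):

import torch
from collections import OrderedDict
from torchvision.ops import MultiScaleRoIAlign

roi_pool = MultiScaleRoIAlign(featmap_names=['0', '1'], output_size=7, sampling_ratio=2)
features = OrderedDict([
    ('0', torch.rand(1, 16, 64, 64)),
    ('1', torch.rand(1, 16, 32, 32)),
])
boxes = [torch.tensor([[10., 10., 60., 60.]])]
image_sizes = [(256, 256)]
pooled = roi_pool(features, boxes, image_sizes)
print(pooled.shape)  # (num_boxes, 16, 7, 7)
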
@@ -3,16 +3,20 @@ import torch
 import torchvision

 import torch.nn.functional as F
-from torch import nn
+from torch import nn, Tensor

 from torchvision.ops import boxes as box_ops
 from torchvision.ops import misc as misc_nn_ops
 from torchvision.ops import roi_align

 from . import _utils as det_utils

+from torch.jit.annotations import Optional, List, Dict, Tuple

 def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
+    # type: (Tensor, Tensor, List[Tensor], List[Tensor])
     """
     Computes the loss for Faster R-CNN.
@@ -51,6 +55,7 @@ def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):

 def maskrcnn_inference(x, labels):
+    # type: (Tensor, List[Tensor])
     """
     From the results of the CNN, post process the masks
     by taking the mask corresponding to the class with max
@@ -77,14 +82,16 @@ def maskrcnn_inference(x, labels):
     if len(boxes_per_image) == 1:
         # TODO : remove when dynamic split supported in ONNX
-        mask_prob = (mask_prob,)
+        # and remove assignment to mask_prob_list, just assign to mask_prob
+        mask_prob_list = [mask_prob]
     else:
-        mask_prob = mask_prob.split(boxes_per_image, dim=0)
+        mask_prob_list = mask_prob.split(boxes_per_image, dim=0)

-    return mask_prob
+    return mask_prob_list

 def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
+    # type: (Tensor, Tensor, Tensor, int)
     """
     Given segmentation masks and the bounding boxes corresponding
     to the location of the masks in the image, this function
@@ -95,10 +102,11 @@ def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
     matched_idxs = matched_idxs.to(boxes)
     rois = torch.cat([matched_idxs[:, None], boxes], dim=1)
     gt_masks = gt_masks[:, None].to(rois)
-    return roi_align(gt_masks, rois, (M, M), 1)[:, 0]
+    return roi_align(gt_masks, rois, (M, M), 1.)[:, 0]

 def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs):
+    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor], List[Tensor])
     """
     Arguments:
         proposals (list[BoxList])
@@ -131,6 +139,7 @@ def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs):

 def keypoints_to_heatmap(keypoints, rois, heatmap_size):
+    # type: (Tensor, Tensor, int)
     offset_x = rois[:, 0]
     offset_y = rois[:, 1]
     scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
@@ -152,8 +161,8 @@ def keypoints_to_heatmap(keypoints, rois, heatmap_size):
     y = (y - offset_y) * scale_y
     y = y.floor().long()

-    x[x_boundary_inds] = heatmap_size - 1
-    y[y_boundary_inds] = heatmap_size - 1
+    x[x_boundary_inds] = torch.tensor(heatmap_size - 1)
+    y[y_boundary_inds] = torch.tensor(heatmap_size - 1)

     valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size)
     vis = keypoints[..., 2] > 0
@@ -217,6 +226,17 @@ def _onnx_heatmaps_to_keypoints_loop(maps, rois, widths_ceil, heights_ceil,
     return xy_preds, end_scores

+# workaround for issue pytorch 27512
+def tensor_floordiv(tensor, int_div):
+    # type: (Tensor, int)
+    result = tensor / int_div
+    # TODO: https://github.com/pytorch/pytorch/issues/26731
+    floating_point_types = (torch.float, torch.double, torch.half)
+    if result.dtype in floating_point_types:
+        result = result.trunc()
+    return result

 def heatmaps_to_keypoints(maps, rois):
     """Extract predicted keypoint locations from heatmaps. Output has shape
     (#rois, 4, #keypoints) with the 4 rows corresponding to (x, y, logit, prob)
@@ -258,8 +278,9 @@ def heatmaps_to_keypoints(maps, rois):
         # roi_map_probs = scores_to_probs(roi_map.copy())
         w = roi_map.shape[2]
         pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)

         x_int = pos % w
-        y_int = (pos - x_int) // w
+        y_int = tensor_floordiv((pos - x_int), w)
         # assert (roi_map_probs[k, y_int, x_int] ==
         #         roi_map_probs[k, :, :].max())
         x = (x_int.float() + 0.5) * width_correction
@@ -273,6 +294,7 @@ def heatmaps_to_keypoints(maps, rois):

 def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs):
+    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor])
     N, K, H, W = keypoint_logits.shape
     assert H == W
     discretization_size = H
@@ -302,6 +324,7 @@ def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs):

 def keypointrcnn_inference(x, boxes):
+    # type: (Tensor, List[Tensor])
     kp_probs = []
     kp_scores = []
@@ -323,6 +346,7 @@ def keypointrcnn_inference(x, boxes):

 def _onnx_expand_boxes(boxes, scale):
+    # type: (Tensor, float)
     w_half = (boxes[:, 2] - boxes[:, 0]) * .5
     h_half = (boxes[:, 3] - boxes[:, 1]) * .5
     x_c = (boxes[:, 2] + boxes[:, 0]) * .5
@@ -343,6 +367,7 @@ def _onnx_expand_boxes(boxes, scale):
 # but are kept here for the moment while we need them
 # temporarily for paste_mask_in_image
 def expand_boxes(boxes, scale):
+    # type: (Tensor, float)
     if torchvision._is_tracing():
         return _onnx_expand_boxes(boxes, scale)
     w_half = (boxes[:, 2] - boxes[:, 0]) * .5
@@ -361,10 +386,17 @@ def expand_boxes(boxes, scale):
     return boxes_exp

+@torch.jit.unused
+def expand_masks_tracing_scale(M, padding):
+    # type: (int, int) -> float
+    return torch.tensor(M + 2 * padding).to(torch.float32) / torch.tensor(M).to(torch.float32)

 def expand_masks(mask, padding):
+    # type: (Tensor, int)
     M = mask.shape[-1]
-    if torchvision._is_tracing():
-        scale = (M + 2 * padding).to(torch.float32) / M.to(torch.float32)
+    if torch._C._get_tracing_state():  # could not import is_tracing(), not sure why
+        scale = expand_masks_tracing_scale(M, padding)
     else:
         scale = float(M + 2 * padding) / M
     padded_mask = torch.nn.functional.pad(mask, (padding,) * 4)
@@ -372,6 +404,7 @@ def expand_masks(mask, padding):

 def paste_mask_in_image(mask, box, im_h, im_w):
+    # type: (Tensor, Tensor, int, int)
     TO_REMOVE = 1
     w = int(box[2] - box[0] + TO_REMOVE)
     h = int(box[3] - box[1] + TO_REMOVE)
@@ -449,29 +482,33 @@ def _onnx_paste_masks_in_image_loop(masks, boxes, im_h, im_w):

 def paste_masks_in_image(masks, boxes, img_shape, padding=1):
+    # type: (Tensor, Tensor, Tuple[int, int], int)
     masks, scale = expand_masks(masks, padding=padding)
     boxes = expand_boxes(boxes, scale).to(dtype=torch.int64)
-    # im_h, im_w = img_shape.tolist()
     im_h, im_w = img_shape

     if torchvision._is_tracing():
         return _onnx_paste_masks_in_image_loop(masks, boxes,
                                                torch.scalar_tensor(im_h, dtype=torch.int64),
                                                torch.scalar_tensor(im_w, dtype=torch.int64))[:, None]
-    boxes = boxes.tolist()
     res = [
         paste_mask_in_image(m[0], b, im_h, im_w)
         for m, b in zip(masks, boxes)
     ]
     if len(res) > 0:
-        res = torch.stack(res, dim=0)[:, None]
+        ret = torch.stack(res, dim=0)[:, None]
     else:
-        res = masks.new_empty((0, 1, im_h, im_w))
-    return res
+        ret = masks.new_empty((0, 1, im_h, im_w))
+    return ret

 class RoIHeads(torch.nn.Module):
+    __annotations__ = {
+        'box_coder': det_utils.BoxCoder,
+        'proposal_matcher': det_utils.Matcher,
+        'fg_bg_sampler': det_utils.BalancedPositiveNegativeSampler,
+    }

     def __init__(self,
                  box_roi_pool,
                  box_head,
@@ -525,7 +562,6 @@ class RoIHeads(torch.nn.Module):
         self.keypoint_head = keypoint_head
         self.keypoint_predictor = keypoint_predictor

-    @property
     def has_mask(self):
         if self.mask_roi_pool is None:
             return False
@@ -535,7 +571,6 @@ class RoIHeads(torch.nn.Module):
             return False
         return True

-    @property
     def has_keypoint(self):
         if self.keypoint_roi_pool is None:
             return False
@@ -546,10 +581,12 @@ class RoIHeads(torch.nn.Module):
         return True

     def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
+        # type: (List[Tensor], List[Tensor], List[Tensor])
         matched_idxs = []
         labels = []
         for proposals_in_image, gt_boxes_in_image, gt_labels_in_image in zip(proposals, gt_boxes, gt_labels):
-            match_quality_matrix = self.box_similarity(gt_boxes_in_image, proposals_in_image)
+            # set to self.box_similarity when https://github.com/pytorch/pytorch/issues/27495 lands
+            match_quality_matrix = box_ops.box_iou(gt_boxes_in_image, proposals_in_image)
             matched_idxs_in_image = self.proposal_matcher(match_quality_matrix)

             clamped_matched_idxs_in_image = matched_idxs_in_image.clamp(min=0)
@@ -559,17 +596,18 @@ class RoIHeads(torch.nn.Module):

             # Label background (below the low threshold)
             bg_inds = matched_idxs_in_image == self.proposal_matcher.BELOW_LOW_THRESHOLD
-            labels_in_image[bg_inds] = 0
+            labels_in_image[bg_inds] = torch.tensor(0)

             # Label ignore proposals (between low and high thresholds)
             ignore_inds = matched_idxs_in_image == self.proposal_matcher.BETWEEN_THRESHOLDS
-            labels_in_image[ignore_inds] = -1  # -1 is ignored by sampler
+            labels_in_image[ignore_inds] = torch.tensor(-1)  # -1 is ignored by sampler

             matched_idxs.append(clamped_matched_idxs_in_image)
             labels.append(labels_in_image)
         return matched_idxs, labels

     def subsample(self, labels):
+        # type: (List[Tensor])
         sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
         sampled_inds = []
         for img_idx, (pos_inds_img, neg_inds_img) in enumerate(
@@ -580,6 +618,7 @@ class RoIHeads(torch.nn.Module):
         return sampled_inds

     def add_gt_proposals(self, proposals, gt_boxes):
+        # type: (List[Tensor], List[Tensor])
         proposals = [
             torch.cat((proposal, gt_box))
             for proposal, gt_box in zip(proposals, gt_boxes)
@@ -587,15 +626,25 @@ class RoIHeads(torch.nn.Module):
         return proposals

+    def DELTEME_all(self, the_list):
+        # type: (List[bool])
+        for i in the_list:
+            if not i:
+                return False
+        return True

     def check_targets(self, targets):
+        # type: (Optional[List[Dict[str, Tensor]]])
         assert targets is not None
-        assert all("boxes" in t for t in targets)
-        assert all("labels" in t for t in targets)
-        if self.has_mask:
-            assert all("masks" in t for t in targets)
+        assert self.DELTEME_all(["boxes" in t for t in targets])
+        assert self.DELTEME_all(["labels" in t for t in targets])
+        if self.has_mask():
+            assert self.DELTEME_all(["masks" in t for t in targets])

     def select_training_samples(self, proposals, targets):
+        # type: (List[Tensor], Optional[List[Dict[str, Tensor]]])
         self.check_targets(targets)
+        assert targets is not None
         dtype = proposals[0].dtype
         gt_boxes = [t["boxes"].to(dtype) for t in targets]
         gt_labels = [t["labels"] for t in targets]
@@ -620,6 +669,7 @@ class RoIHeads(torch.nn.Module):
         return proposals, matched_idxs, labels, regression_targets

     def postprocess_detections(self, class_logits, box_regression, proposals, image_shapes):
+        # type: (Tensor, Tensor, List[Tensor], List[Tuple[int, int]])
         device = class_logits.device
         num_classes = class_logits.shape[-1]
@@ -631,16 +681,17 @@ class RoIHeads(torch.nn.Module):
         # split boxes and scores per image
         if len(boxes_per_image) == 1:
             # TODO : remove this when ONNX support dynamic split sizes
-            pred_boxes = (pred_boxes,)
-            pred_scores = (pred_scores,)
+            # and just assign to pred_boxes instead of pred_boxes_list
+            pred_boxes_list = [pred_boxes]
+            pred_scores_list = [pred_scores]
         else:
-            pred_boxes = pred_boxes.split(boxes_per_image, 0)
-            pred_scores = pred_scores.split(boxes_per_image, 0)
+            pred_boxes_list = pred_boxes.split(boxes_per_image, 0)
+            pred_scores_list = pred_scores.split(boxes_per_image, 0)

         all_boxes = []
         all_scores = []
         all_labels = []
-        for boxes, scores, image_shape in zip(pred_boxes, pred_scores, image_shapes):
+        for boxes, scores, image_shape in zip(pred_boxes_list, pred_scores_list, image_shapes):
             boxes = box_ops.clip_boxes_to_image(boxes, image_shape)

             # create labels for each prediction
@@ -678,6 +729,7 @@ class RoIHeads(torch.nn.Module):
         return all_boxes, all_scores, all_labels

     def forward(self, features, proposals, image_shapes, targets=None):
+        # type: (Dict[str, Tensor], List[Tensor], List[Tuple[int, int]], Optional[List[Dict[str, Tensor]]])
         """
         Arguments:
             features (List[Tensor])
@@ -687,38 +739,50 @@ class RoIHeads(torch.nn.Module):
         """
         if targets is not None:
             for t in targets:
-                assert t["boxes"].dtype.is_floating_point, 'target boxes must of float type'
+                # TODO: https://github.com/pytorch/pytorch/issues/26731
+                floating_point_types = (torch.float, torch.double, torch.half)
+                assert t["boxes"].dtype in floating_point_types, 'target boxes must of float type'
                 assert t["labels"].dtype == torch.int64, 'target labels must of int64 type'
-                if self.has_keypoint:
+                if self.has_keypoint():
                     assert t["keypoints"].dtype == torch.float32, 'target keypoints must of float type'

         if self.training:
             proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
+        else:
+            labels = None
+            regression_targets = None
+            matched_idxs = None

         box_features = self.box_roi_pool(features, proposals, image_shapes)
         box_features = self.box_head(box_features)
         class_logits, box_regression = self.box_predictor(box_features)

-        result, losses = [], {}
+        result = torch.jit.annotate(List[Dict[str, torch.Tensor]], [])
+        losses = {}
         if self.training:
+            assert labels is not None and regression_targets is not None
             loss_classifier, loss_box_reg = fastrcnn_loss(
                 class_logits, box_regression, labels, regression_targets)
-            losses = dict(loss_classifier=loss_classifier, loss_box_reg=loss_box_reg)
+            losses = {
+                "loss_classifier": loss_classifier,
+                "loss_box_reg": loss_box_reg
+            }
        else:
             boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
             num_images = len(boxes)
             for i in range(num_images):
                 result.append(
-                    dict(
-                        boxes=boxes[i],
-                        labels=labels[i],
-                        scores=scores[i],
-                    )
+                    {
+                        "boxes": boxes[i],
+                        "labels": labels[i],
+                        "scores": scores[i],
+                    }
                 )

-        if self.has_mask:
+        if self.has_mask():
             mask_proposals = [p["boxes"] for p in result]
             if self.training:
+                assert matched_idxs is not None
                 # during training, only focus on positive boxes
                 num_images = len(proposals)
                 mask_proposals = []
@@ -727,19 +791,31 @@ class RoIHeads(torch.nn.Module):
                     pos = torch.nonzero(labels[img_id] > 0).squeeze(1)
                     mask_proposals.append(proposals[img_id][pos])
                     pos_matched_idxs.append(matched_idxs[img_id][pos])
+            else:
+                pos_matched_idxs = None

-            mask_features = self.mask_roi_pool(features, mask_proposals, image_shapes)
-            mask_features = self.mask_head(mask_features)
-            mask_logits = self.mask_predictor(mask_features)
+            if self.mask_roi_pool is not None:
+                mask_features = self.mask_roi_pool(features, mask_proposals, image_shapes)
+                mask_features = self.mask_head(mask_features)
+                mask_logits = self.mask_predictor(mask_features)
+            else:
+                mask_logits = torch.tensor(0)
+                raise Exception("Expected mask_roi_pool to be not None")

             loss_mask = {}
             if self.training:
+                assert targets is not None
+                assert pos_matched_idxs is not None
+                assert mask_logits is not None
                 gt_masks = [t["masks"] for t in targets]
                 gt_labels = [t["labels"] for t in targets]
-                loss_mask = maskrcnn_loss(
+                rcnn_loss_mask = maskrcnn_loss(
                     mask_logits, mask_proposals,
                     gt_masks, gt_labels, pos_matched_idxs)
-                loss_mask = dict(loss_mask=loss_mask)
+                loss_mask = {
+                    "loss_mask": rcnn_loss_mask
+                }
             else:
                 labels = [r["labels"] for r in result]
                 masks_probs = maskrcnn_inference(mask_logits, labels)
@@ -748,17 +824,23 @@ class RoIHeads(torch.nn.Module):
             losses.update(loss_mask)

-        if self.has_keypoint:
+        # keep none checks in if conditional so torchscript will conditionally
+        # compile each branch
+        if self.keypoint_roi_pool is not None and self.keypoint_head is not None \
+                and self.keypoint_predictor is not None:
             keypoint_proposals = [p["boxes"] for p in result]
             if self.training:
                 # during training, only focus on positive boxes
                 num_images = len(proposals)
                 keypoint_proposals = []
                 pos_matched_idxs = []
+                assert matched_idxs is not None
                 for img_id in range(num_images):
                     pos = torch.nonzero(labels[img_id] > 0).squeeze(1)
                     keypoint_proposals.append(proposals[img_id][pos])
                     pos_matched_idxs.append(matched_idxs[img_id][pos])
+            else:
+                pos_matched_idxs = None

             keypoint_features = self.keypoint_roi_pool(features, keypoint_proposals, image_shapes)
             keypoint_features = self.keypoint_head(keypoint_features)
@@ -766,12 +848,20 @@ class RoIHeads(torch.nn.Module):
             loss_keypoint = {}
             if self.training:
+                assert targets is not None
+                assert pos_matched_idxs is not None
                 gt_keypoints = [t["keypoints"] for t in targets]
-                loss_keypoint = keypointrcnn_loss(
+                rcnn_loss_keypoint = keypointrcnn_loss(
                     keypoint_logits, keypoint_proposals,
                     gt_keypoints, pos_matched_idxs)
-                loss_keypoint = dict(loss_keypoint=loss_keypoint)
+                loss_keypoint = {
+                    "loss_keypoint": rcnn_loss_keypoint
+                }
             else:
+                assert keypoint_logits is not None
+                assert keypoint_proposals is not None
                 keypoints_probs, kp_scores = keypointrcnn_inference(keypoint_logits, keypoint_proposals)
                 for keypoint_prob, kps, r in zip(keypoints_probs, kp_scores, result):
                     r["keypoints"] = keypoint_prob
......
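
Two TorchScript idioms recur across the RoIHeads changes above: empty containers are created through torch.jit.annotate so the compiler knows their element type, and Optional values are narrowed with assert before use. A self-contained hypothetical sketch of both (not torchvision code):

import torch
from torch import Tensor
from torch.jit.annotations import Dict, List, Optional

@torch.jit.script
def collect(values, extra):
    # type: (List[Tensor], Optional[Tensor]) -> List[Dict[str, Tensor]]
    # Annotate the empty list so TorchScript knows it holds Dict[str, Tensor].
    result = torch.jit.annotate(List[Dict[str, Tensor]], [])
    for v in values:
        result.append({"value": v})
    # `assert x is not None` refines Optional[Tensor] to Tensor for the lines below.
    assert extra is not None
    result.append({"value": extra})
    return result
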
+from __future__ import division
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
 import torch
 from torch.nn import functional as F
-from torch import nn
+from torch import nn, Tensor

 import torchvision
 from torchvision.ops import boxes as box_ops

 from . import _utils as det_utils
+from .image_list import ImageList

+from torch.jit.annotations import List, Optional, Dict, Tuple

 @torch.jit.unused
 def _onnx_get_num_anchors_and_pre_nms_top_n(ob, orig_pre_nms_top_n):
+    # type: (Tensor, int) -> Tuple[int, int]
     from torch.onnx import operators
     num_anchors = operators.shape_as_tensor(ob)[1].unsqueeze(0)
     # TODO : remove cast to IntTensor/num_anchors.dtype when
@@ -23,6 +29,11 @@ def _onnx_get_num_anchors_and_pre_nms_top_n(ob, orig_pre_nms_top_n):

 class AnchorGenerator(nn.Module):
+    __annotations__ = {
+        "cell_anchors": Optional[List[torch.Tensor]],
+        "_cache": Dict[str, List[torch.Tensor]]
+    }
+
     """
     Module that generates anchors for a set of feature maps and
     image sizes.
@@ -62,8 +73,9 @@ class AnchorGenerator(nn.Module):
         self.cell_anchors = None
         self._cache = {}

-    @staticmethod
-    def generate_anchors(scales, aspect_ratios, dtype=torch.float32, device="cpu"):
+    # TODO: https://github.com/pytorch/pytorch/issues/26792
+    def generate_anchors(self, scales, aspect_ratios, dtype=torch.float32, device="cpu"):
+        # type: (List[int], List[float], int, Device)  # noqa: F821
         scales = torch.as_tensor(scales, dtype=dtype, device=device)
         aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device)
         h_ratios = torch.sqrt(aspect_ratios)
@@ -76,8 +88,10 @@ class AnchorGenerator(nn.Module):
         return base_anchors.round()

     def set_cell_anchors(self, dtype, device):
+        # type: (int, Device) -> None  # noqa: F821
         if self.cell_anchors is not None:
-            return self.cell_anchors
+            return
+
         cell_anchors = [
             self.generate_anchors(
                 sizes,
@@ -93,9 +107,13 @@ class AnchorGenerator(nn.Module):
         return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)]

     def grid_anchors(self, grid_sizes, strides):
+        # type: (List[List[int]], List[List[int]])
         anchors = []
+        cell_anchors = self.cell_anchors
+        assert cell_anchors is not None
         for size, stride, base_anchors in zip(
-            grid_sizes, strides, self.cell_anchors
+            grid_sizes, strides, cell_anchors
         ):
             grid_height, grid_width = size
             stride_height, stride_width = stride
@@ -122,7 +140,8 @@ class AnchorGenerator(nn.Module):
         return anchors

     def cached_grid_anchors(self, grid_sizes, strides):
-        key = tuple(grid_sizes) + tuple(strides)
+        # type: (List[List[int]], List[List[int]])
+        key = str(grid_sizes + strides)
         if key in self._cache:
             return self._cache[key]
         anchors = self.grid_anchors(grid_sizes, strides)
@@ -130,15 +149,14 @@ class AnchorGenerator(nn.Module):
         return anchors

     def forward(self, image_list, feature_maps):
-        grid_sizes = tuple([feature_map.shape[-2:] for feature_map in feature_maps])
+        # type: (ImageList, List[Tensor])
+        grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps])
         image_size = image_list.tensors.shape[-2:]
-        strides = tuple((float(image_size[0]) / float(g[0]),
-                         float(image_size[1]) / float(g[1]))
-                        for g in grid_sizes)
+        strides = [[int(image_size[0] / g[0]), int(image_size[1] / g[1])] for g in grid_sizes]
         dtype, device = feature_maps[0].dtype, feature_maps[0].device
         self.set_cell_anchors(dtype, device)
         anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides)
-        anchors = []
+        anchors = torch.jit.annotate(List[List[torch.Tensor]], [])
         for i, (image_height, image_width) in enumerate(image_list.image_sizes):
             anchors_in_image = []
             for anchors_per_feature_map in anchors_over_all_feature_maps:
@@ -172,6 +190,7 @@ class RPNHead(nn.Module):
                 torch.nn.init.constant_(l.bias, 0)

     def forward(self, x):
+        # type: (List[Tensor])
         logits = []
         bbox_reg = []
         for feature in x:
@@ -182,6 +201,7 @@ class RPNHead(nn.Module):

 def permute_and_flatten(layer, N, A, C, H, W):
+    # type: (Tensor, int, int, int, int, int)
     layer = layer.view(N, -1, C, H, W)
     layer = layer.permute(0, 3, 4, 1, 2)
     layer = layer.reshape(N, -1, C)
@@ -189,12 +209,14 @@ def permute_and_flatten(layer, N, A, C, H, W):

 def concat_box_prediction_layers(box_cls, box_regression):
+    # type: (List[Tensor], List[Tensor])
     box_cls_flattened = []
     box_regression_flattened = []
     # for each feature level, permute the outputs to make them be in the
     # same format as the labels. Note that the labels are computed for
     # all feature levels concatenated, so we keep the same representation
     # for the objectness and the box_regression
+    last_C = torch.jit.annotate(Optional[int], None)
     for box_cls_per_level, box_regression_per_level in zip(
         box_cls, box_regression
     ):
@@ -207,14 +229,16 @@ def concat_box_prediction_layers(box_cls, box_regression):
         )
         box_cls_flattened.append(box_cls_per_level)
+        last_C = C

         box_regression_per_level = permute_and_flatten(
             box_regression_per_level, N, A, 4, H, W
         )
         box_regression_flattened.append(box_regression_per_level)
+    assert last_C is not None
     # concatenate on the first dimension (representing the feature levels), to
     # take into account the way the labels were generated (with all feature maps
# being concatenated as well) # being concatenated as well)
box_cls = torch.cat(box_cls_flattened, dim=1).reshape(-1, C) box_cls = torch.cat(box_cls_flattened, dim=1).reshape(-1, last_C)
box_regression = torch.cat(box_regression_flattened, dim=1).reshape(-1, 4) box_regression = torch.cat(box_regression_flattened, dim=1).reshape(-1, 4)
return box_cls, box_regression return box_cls, box_regression
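To make the reshaping above concrete, here is a tiny standalone check of what permute_and_flatten does to a prediction tensor (shapes are illustrative only):

import torch

def permute_and_flatten(layer, N, A, C, H, W):
    # (N, A*C, H, W) -> (N, H*W*A, C): one row per anchor at each spatial location
    layer = layer.view(N, -1, C, H, W)
    layer = layer.permute(0, 3, 4, 1, 2)
    return layer.reshape(N, -1, C)

# objectness logits for N=2 images, A=3 anchors per location, C=1 class, 25x25 grid
logits = torch.rand(2, 3 * 1, 25, 25)
print(permute_and_flatten(logits, 2, 3, 1, 25, 25).shape)  # torch.Size([2, 1875, 1])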
...@@ -244,6 +268,13 @@ class RegionProposalNetwork(torch.nn.Module): ...@@ -244,6 +268,13 @@ class RegionProposalNetwork(torch.nn.Module):
nms_thresh (float): NMS threshold used for postprocessing the RPN proposals nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
""" """
__annotations__ = {
'box_coder': det_utils.BoxCoder,
'proposal_matcher': det_utils.Matcher,
'fg_bg_sampler': det_utils.BalancedPositiveNegativeSampler,
'pre_nms_top_n': Dict[str, int],
'post_nms_top_n': Dict[str, int],
}
def __init__(self, def __init__(self,
anchor_generator, anchor_generator,
...@@ -276,24 +307,23 @@ class RegionProposalNetwork(torch.nn.Module): ...@@ -276,24 +307,23 @@ class RegionProposalNetwork(torch.nn.Module):
self.nms_thresh = nms_thresh self.nms_thresh = nms_thresh
self.min_size = 1e-3 self.min_size = 1e-3
@property
def pre_nms_top_n(self): def pre_nms_top_n(self):
if self.training: if self.training:
return self._pre_nms_top_n['training'] return self._pre_nms_top_n['training']
return self._pre_nms_top_n['testing'] return self._pre_nms_top_n['testing']
@property
def post_nms_top_n(self): def post_nms_top_n(self):
if self.training: if self.training:
return self._post_nms_top_n['training'] return self._post_nms_top_n['training']
return self._post_nms_top_n['testing'] return self._post_nms_top_n['testing']
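The @property decorators on pre_nms_top_n and post_nms_top_n are dropped above because TorchScript did not support properties at the time of this change, so call sites later in the file switch to plain method calls. A minimal sketch of the pattern under that assumption (illustrative class and values, not the library code):

from typing import Dict
import torch
from torch import nn

class TopNSwitch(nn.Module):
    __annotations__ = {'_pre_nms_top_n': Dict[str, int]}

    def __init__(self, top_n_train=2000, top_n_test=1000):
        super(TopNSwitch, self).__init__()
        self._pre_nms_top_n = {'training': top_n_train, 'testing': top_n_test}

    def pre_nms_top_n(self):
        # a plain method instead of @property so torch.jit.script can compile it
        if self.training:
            return self._pre_nms_top_n['training']
        return self._pre_nms_top_n['testing']

    def forward(self, scores):
        # type: (torch.Tensor) -> torch.Tensor
        return scores[:self.pre_nms_top_n()]

scripted = torch.jit.script(TopNSwitch())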
def assign_targets_to_anchors(self, anchors, targets): def assign_targets_to_anchors(self, anchors, targets):
# type: (List[Tensor], List[Dict[str, Tensor]])
labels = [] labels = []
matched_gt_boxes = [] matched_gt_boxes = []
for anchors_per_image, targets_per_image in zip(anchors, targets): for anchors_per_image, targets_per_image in zip(anchors, targets):
gt_boxes = targets_per_image["boxes"] gt_boxes = targets_per_image["boxes"]
match_quality_matrix = self.box_similarity(gt_boxes, anchors_per_image) match_quality_matrix = box_ops.box_iou(gt_boxes, anchors_per_image)
matched_idxs = self.proposal_matcher(match_quality_matrix) matched_idxs = self.proposal_matcher(match_quality_matrix)
# get the targets corresponding GT for each proposal # get the targets corresponding GT for each proposal
# NB: need to clamp the indices because we can have a single # NB: need to clamp the indices because we can have a single
...@@ -306,31 +336,33 @@ class RegionProposalNetwork(torch.nn.Module): ...@@ -306,31 +336,33 @@ class RegionProposalNetwork(torch.nn.Module):
# Background (negative examples) # Background (negative examples)
bg_indices = matched_idxs == self.proposal_matcher.BELOW_LOW_THRESHOLD bg_indices = matched_idxs == self.proposal_matcher.BELOW_LOW_THRESHOLD
labels_per_image[bg_indices] = 0 labels_per_image[bg_indices] = torch.tensor(0)
# discard indices that are between thresholds # discard indices that are between thresholds
inds_to_discard = matched_idxs == self.proposal_matcher.BETWEEN_THRESHOLDS inds_to_discard = matched_idxs == self.proposal_matcher.BETWEEN_THRESHOLDS
labels_per_image[inds_to_discard] = -1 labels_per_image[inds_to_discard] = torch.tensor(-1)
labels.append(labels_per_image) labels.append(labels_per_image)
matched_gt_boxes.append(matched_gt_boxes_per_image) matched_gt_boxes.append(matched_gt_boxes_per_image)
return labels, matched_gt_boxes return labels, matched_gt_boxes
def _get_top_n_idx(self, objectness, num_anchors_per_level): def _get_top_n_idx(self, objectness, num_anchors_per_level):
# type: (Tensor, List[int])
r = [] r = []
offset = 0 offset = 0
for ob in objectness.split(num_anchors_per_level, 1): for ob in objectness.split(num_anchors_per_level, 1):
if torchvision._is_tracing(): if torchvision._is_tracing():
num_anchors, pre_nms_top_n = _onnx_get_num_anchors_and_pre_nms_top_n(ob, self.pre_nms_top_n) num_anchors, pre_nms_top_n = _onnx_get_num_anchors_and_pre_nms_top_n(ob, self.pre_nms_top_n())
else: else:
num_anchors = ob.shape[1] num_anchors = ob.shape[1]
pre_nms_top_n = min(self.pre_nms_top_n, num_anchors) pre_nms_top_n = min(self.pre_nms_top_n(), num_anchors)
_, top_n_idx = ob.topk(pre_nms_top_n, dim=1) _, top_n_idx = ob.topk(pre_nms_top_n, dim=1)
r.append(top_n_idx + offset) r.append(top_n_idx + offset)
offset += num_anchors offset += num_anchors
return torch.cat(r, dim=1) return torch.cat(r, dim=1)
def filter_proposals(self, proposals, objectness, image_shapes, num_anchors_per_level): def filter_proposals(self, proposals, objectness, image_shapes, num_anchors_per_level):
# type: (Tensor, Tensor, List[Tuple[int, int]], List[int])
num_images = proposals.shape[0] num_images = proposals.shape[0]
device = proposals.device device = proposals.device
# do not backprop through objectness # do not backprop through objectness
...@@ -346,7 +378,10 @@ class RegionProposalNetwork(torch.nn.Module): ...@@ -346,7 +378,10 @@ class RegionProposalNetwork(torch.nn.Module):
# select top_n boxes independently per level before applying nms # select top_n boxes independently per level before applying nms
top_n_idx = self._get_top_n_idx(objectness, num_anchors_per_level) top_n_idx = self._get_top_n_idx(objectness, num_anchors_per_level)
batch_idx = torch.arange(num_images, device=device)[:, None]
image_range = torch.arange(num_images, device=device)
batch_idx = image_range[:, None]
objectness = objectness[batch_idx, top_n_idx] objectness = objectness[batch_idx, top_n_idx]
levels = levels[batch_idx, top_n_idx] levels = levels[batch_idx, top_n_idx]
proposals = proposals[batch_idx, top_n_idx] proposals = proposals[batch_idx, top_n_idx]
...@@ -360,13 +395,14 @@ class RegionProposalNetwork(torch.nn.Module): ...@@ -360,13 +395,14 @@ class RegionProposalNetwork(torch.nn.Module):
# non-maximum suppression, independently done per level # non-maximum suppression, independently done per level
keep = box_ops.batched_nms(boxes, scores, lvl, self.nms_thresh) keep = box_ops.batched_nms(boxes, scores, lvl, self.nms_thresh)
# keep only topk scoring predictions # keep only topk scoring predictions
keep = keep[:self.post_nms_top_n] keep = keep[:self.post_nms_top_n()]
boxes, scores = boxes[keep], scores[keep] boxes, scores = boxes[keep], scores[keep]
final_boxes.append(boxes) final_boxes.append(boxes)
final_scores.append(scores) final_scores.append(scores)
return final_boxes, final_scores return final_boxes, final_scores
def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets): def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets):
# type: (Tensor, Tensor, List[Tensor], List[Tensor])
""" """
Arguments: Arguments:
objectness (Tensor) objectness (Tensor)
...@@ -403,6 +439,7 @@ class RegionProposalNetwork(torch.nn.Module): ...@@ -403,6 +439,7 @@ class RegionProposalNetwork(torch.nn.Module):
return objectness_loss, box_loss return objectness_loss, box_loss
def forward(self, images, features, targets=None): def forward(self, images, features, targets=None):
# type: (ImageList, Dict[str, Tensor], Optional[List[Dict[str, Tensor]]])
""" """
Arguments: Arguments:
images (ImageList): images for which we want to compute the predictions images (ImageList): images for which we want to compute the predictions
...@@ -437,6 +474,7 @@ class RegionProposalNetwork(torch.nn.Module): ...@@ -437,6 +474,7 @@ class RegionProposalNetwork(torch.nn.Module):
losses = {} losses = {}
if self.training: if self.training:
assert targets is not None
labels, matched_gt_boxes = self.assign_targets_to_anchors(anchors, targets) labels, matched_gt_boxes = self.assign_targets_to_anchors(anchors, targets)
regression_targets = self.box_coder.encode(matched_gt_boxes, anchors) regression_targets = self.box_coder.encode(matched_gt_boxes, anchors)
loss_objectness, loss_rpn_box_reg = self.compute_loss( loss_objectness, loss_rpn_box_reg = self.compute_loss(
......
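Several changes above, such as the anchors list in AnchorGenerator.forward, replace a bare [] with torch.jit.annotate so the compiler knows the element type of a container that starts out empty (an unannotated [] defaults to List[Tensor] in TorchScript). A minimal standalone sketch, not taken from the diff:

import torch
from torch.jit.annotations import List, Tuple

@torch.jit.script
def collect_hw(shapes):
    # type: (List[Tuple[int, int]]) -> List[List[int]]
    # a bare [] would be typed List[Tensor]; declare the intended type explicitly
    out = torch.jit.annotate(List[List[int]], [])
    for s in shapes:
        out.append([s[0], s[1]])
    return out

print(collect_hw([(300, 400), (280, 500)]))  # [[300, 400], [280, 500]]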
from __future__ import division
import random import random
import math import math
import torch import torch
from torch import nn from torch import nn, Tensor
import torchvision import torchvision
from torch.jit.annotations import List, Tuple, Dict, Optional
from torchvision.ops import misc as misc_nn_ops from torchvision.ops import misc as misc_nn_ops
from .image_list import ImageList from .image_list import ImageList
...@@ -31,22 +34,29 @@ class GeneralizedRCNNTransform(nn.Module): ...@@ -31,22 +34,29 @@ class GeneralizedRCNNTransform(nn.Module):
self.image_std = image_std self.image_std = image_std
def forward(self, images, targets=None): def forward(self, images, targets=None):
# type: (List[Tensor], Optional[List[Dict[str, Tensor]]])
images = [img for img in images] images = [img for img in images]
for i in range(len(images)): for i in range(len(images)):
image = images[i] image = images[i]
target = targets[i] if targets is not None else targets target_index = targets[i] if targets is not None else None
if image.dim() != 3: if image.dim() != 3:
raise ValueError("images is expected to be a list of 3d tensors " raise ValueError("images is expected to be a list of 3d tensors "
"of shape [C, H, W], got {}".format(image.shape)) "of shape [C, H, W], got {}".format(image.shape))
image = self.normalize(image) image = self.normalize(image)
image, target = self.resize(image, target) image, target_index = self.resize(image, target_index)
images[i] = image images[i] = image
if targets is not None: if targets is not None and target_index is not None:
targets[i] = target targets[i] = target_index
image_sizes = [img.shape[-2:] for img in images] image_sizes = [img.shape[-2:] for img in images]
images = self.batch_images(images) images = self.batch_images(images)
image_list = ImageList(images, image_sizes) image_sizes_list = torch.jit.annotate(List[Tuple[int, int]], [])
for image_size in image_sizes:
assert len(image_size) == 2
image_sizes_list.append((image_size[0], image_size[1]))
image_list = ImageList(images, image_sizes_list)
return image_list, targets return image_list, targets
def normalize(self, image): def normalize(self, image):
...@@ -55,16 +65,27 @@ class GeneralizedRCNNTransform(nn.Module): ...@@ -55,16 +65,27 @@ class GeneralizedRCNNTransform(nn.Module):
std = torch.as_tensor(self.image_std, dtype=dtype, device=device) std = torch.as_tensor(self.image_std, dtype=dtype, device=device)
return (image - mean[:, None, None]) / std[:, None, None] return (image - mean[:, None, None]) / std[:, None, None]
def torch_choice(self, l):
# type: (List[int])
"""
Implements `random.choice` via torch ops so it can be compiled with
TorchScript. Remove if https://github.com/pytorch/pytorch/issues/25803
is fixed.
"""
index = int(torch.empty(1).uniform_(0., float(len(l))).item())
return l[index]
def resize(self, image, target): def resize(self, image, target):
# type: (Tensor, Optional[Dict[str, Tensor]])
h, w = image.shape[-2:] h, w = image.shape[-2:]
im_shape = torch.tensor(image.shape[-2:]) im_shape = torch.tensor(image.shape[-2:])
min_size = float(torch.min(im_shape)) min_size = float(torch.min(im_shape))
max_size = float(torch.max(im_shape)) max_size = float(torch.max(im_shape))
if self.training: if self.training:
size = random.choice(self.min_size) size = float(self.torch_choice(self.min_size))
else: else:
# FIXME assume for now that testing uses the largest scale # FIXME assume for now that testing uses the largest scale
size = self.min_size[-1] size = float(self.min_size[-1])
scale_factor = size / min_size scale_factor = size / min_size
if max_size * scale_factor > self.max_size: if max_size * scale_factor > self.max_size:
scale_factor = self.max_size / max_size scale_factor = self.max_size / max_size
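As a worked example of the scale arithmetic above, using min_size=800 and max_size=1333 (the usual detection defaults, assumed here rather than shown in this hunk): for a 500x1000 image, scale_factor = 800 / 500 = 1.6, but 1000 * 1.6 = 1600 exceeds 1333, so the factor is clamped to 1333 / 1000 ≈ 1.333 and the image is resized to roughly 666x1333.

h, w = 500, 1000                        # illustrative input size
size, max_size = 800.0, 1333.0          # assumed defaults, not part of this hunk
scale = size / min(h, w)                # 1.6
if max(h, w) * scale > max_size:
    scale = max_size / max(h, w)        # clamped to ~1.333
print(int(h * scale), int(w * scale))   # 666 1333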
...@@ -91,7 +112,9 @@ class GeneralizedRCNNTransform(nn.Module): ...@@ -91,7 +112,9 @@ class GeneralizedRCNNTransform(nn.Module):
# _onnx_batch_images() is an implementation of # _onnx_batch_images() is an implementation of
# batch_images() that is supported by ONNX tracing. # batch_images() that is supported by ONNX tracing.
@torch.jit.unused
def _onnx_batch_images(self, images, size_divisible=32): def _onnx_batch_images(self, images, size_divisible=32):
# type: (List[Tensor], int) -> Tensor
max_size = [] max_size = []
for i in range(images[0].dim()): for i in range(images[0].dim()):
max_size_i = torch.max(torch.stack([img.shape[i] for img in images]).to(torch.float32)).to(torch.int64) max_size_i = torch.max(torch.stack([img.shape[i] for img in images]).to(torch.float32)).to(torch.int64)
...@@ -112,27 +135,36 @@ class GeneralizedRCNNTransform(nn.Module): ...@@ -112,27 +135,36 @@ class GeneralizedRCNNTransform(nn.Module):
return torch.stack(padded_imgs) return torch.stack(padded_imgs)
def max_by_axis(self, the_list):
# type: (List[List[int]]) -> List[int]
maxes = the_list[0]
for sublist in the_list[1:]:
for index, item in enumerate(sublist):
maxes[index] = max(maxes[index], item)
return maxes
def batch_images(self, images, size_divisible=32): def batch_images(self, images, size_divisible=32):
# type: (List[Tensor], int)
if torchvision._is_tracing(): if torchvision._is_tracing():
# batch_images() does not export well to ONNX # batch_images() does not export well to ONNX
# call _onnx_batch_images() instead # call _onnx_batch_images() instead
return self._onnx_batch_images(images, size_divisible) return self._onnx_batch_images(images, size_divisible)
max_size = tuple(max(s) for s in zip(*[img.shape for img in images])) max_size = self.max_by_axis([list(img.shape) for img in images])
stride = size_divisible stride = float(size_divisible)
max_size = list(max_size) max_size = list(max_size)
max_size[1] = int(math.ceil(float(max_size[1]) / stride) * stride) max_size[1] = int(math.ceil(float(max_size[1]) / stride) * stride)
max_size[2] = int(math.ceil(float(max_size[2]) / stride) * stride) max_size[2] = int(math.ceil(float(max_size[2]) / stride) * stride)
max_size = tuple(max_size)
batch_shape = (len(images),) + max_size batch_shape = [len(images)] + max_size
batched_imgs = images[0].new(*batch_shape).zero_() batched_imgs = images[0].new_full(batch_shape, 0)
for img, pad_img in zip(images, batched_imgs): for img, pad_img in zip(images, batched_imgs):
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
return batched_imgs return batched_imgs
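A standalone sketch of what max_by_axis and batch_images compute together: every image in the batch is zero-padded up to the per-axis maximum, rounded up to the stride (sizes below are illustrative):

import math
import torch

imgs = [torch.rand(3, 300, 400), torch.rand(3, 280, 500)]
max_size = [max(s) for s in zip(*[img.shape for img in imgs])]   # [3, 300, 500]
stride = 32.0
max_size[1] = int(math.ceil(max_size[1] / stride) * stride)      # 320
max_size[2] = int(math.ceil(max_size[2] / stride) * stride)      # 512
batched = imgs[0].new_full([len(imgs)] + max_size, 0)
for img, pad_img in zip(imgs, batched):
    pad_img[:img.shape[0], :img.shape[1], :img.shape[2]].copy_(img)
print(batched.shape)  # torch.Size([2, 3, 320, 512])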
def postprocess(self, result, image_shapes, original_image_sizes): def postprocess(self, result, image_shapes, original_image_sizes):
# type: (List[Dict[str, Tensor]], List[Tuple[int, int]], List[Tuple[int, int]])
if self.training: if self.training:
return result return result
for i, (pred, im_s, o_im_s) in enumerate(zip(result, image_shapes, original_image_sizes)): for i, (pred, im_s, o_im_s) in enumerate(zip(result, image_shapes, original_image_sizes)):
...@@ -151,7 +183,8 @@ class GeneralizedRCNNTransform(nn.Module): ...@@ -151,7 +183,8 @@ class GeneralizedRCNNTransform(nn.Module):
def resize_keypoints(keypoints, original_size, new_size): def resize_keypoints(keypoints, original_size, new_size):
ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(new_size, original_size)) # type: (Tensor, List[int], List[int])
ratios = [float(s) / float(s_orig) for s, s_orig in zip(new_size, original_size)]
ratio_h, ratio_w = ratios ratio_h, ratio_w = ratios
resized_data = keypoints.clone() resized_data = keypoints.clone()
if torch._C._get_tracing_state(): if torch._C._get_tracing_state():
...@@ -165,7 +198,8 @@ def resize_keypoints(keypoints, original_size, new_size): ...@@ -165,7 +198,8 @@ def resize_keypoints(keypoints, original_size, new_size):
def resize_boxes(boxes, original_size, new_size): def resize_boxes(boxes, original_size, new_size):
ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(new_size, original_size)) # type: (Tensor, List[int], List[int])
ratios = [float(s) / float(s_orig) for s, s_orig in zip(new_size, original_size)]
ratio_height, ratio_width = ratios ratio_height, ratio_width = ratios
xmin, ymin, xmax, ymax = boxes.unbind(1) xmin, ymin, xmax, ymax = boxes.unbind(1)
......
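Box rescaling as done in resize_boxes() above reduces to scaling x coordinates by the width ratio and y coordinates by the height ratio; a minimal sketch of the idea (not the exact library code):

import torch

def rescale_boxes(boxes, original_size, new_size):
    # boxes are (x1, y1, x2, y2); sizes are (height, width)
    ratio_h = float(new_size[0]) / float(original_size[0])
    ratio_w = float(new_size[1]) / float(original_size[1])
    xmin, ymin, xmax, ymax = boxes.unbind(1)
    return torch.stack((xmin * ratio_w, ymin * ratio_h,
                        xmax * ratio_w, ymax * ratio_h), dim=1)

b = torch.tensor([[10., 20., 50., 60.]])
print(rescale_boxes(b, (100, 100), (200, 400)))  # tensor([[ 40.,  40., 200., 120.]])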
from .boxes import nms, box_iou from .boxes import nms, box_iou
from .new_empty_tensor import _new_empty_tensor
from .roi_align import roi_align, RoIAlign from .roi_align import roi_align, RoIAlign
from .roi_pool import roi_pool, RoIPool from .roi_pool import roi_pool, RoIPool
from .ps_roi_align import ps_roi_align, PSRoIAlign from .ps_roi_align import ps_roi_align, PSRoIAlign
...@@ -12,7 +13,7 @@ _register_custom_op() ...@@ -12,7 +13,7 @@ _register_custom_op()
__all__ = [ __all__ = [
'nms', 'roi_align', 'RoIAlign', 'roi_pool', 'RoIPool', 'nms', 'roi_align', 'RoIAlign', 'roi_pool', 'RoIPool', '_new_empty_tensor',
'ps_roi_align', 'PSRoIAlign', 'ps_roi_pool', 'PSRoIPool', 'ps_roi_align', 'PSRoIAlign', 'ps_roi_pool', 'PSRoIPool',
'MultiScaleRoIAlign', 'FeaturePyramidNetwork' 'MultiScaleRoIAlign', 'FeaturePyramidNetwork'
] ]
from __future__ import division
import torch import torch
from torch.jit.annotations import Tuple
from torch import Tensor
def nms(boxes, scores, iou_threshold): def nms(boxes, scores, iou_threshold):
# type: (Tensor, Tensor, float)
""" """
Performs non-maximum suppression (NMS) on the boxes according Performs non-maximum suppression (NMS) on the boxes according
to their intersection-over-union (IoU). to their intersection-over-union (IoU).
...@@ -32,6 +37,7 @@ def nms(boxes, scores, iou_threshold): ...@@ -32,6 +37,7 @@ def nms(boxes, scores, iou_threshold):
def batched_nms(boxes, scores, idxs, iou_threshold): def batched_nms(boxes, scores, idxs, iou_threshold):
# type: (Tensor, Tensor, Tensor, float)
""" """
Performs non-maximum suppression in a batched fashion. Performs non-maximum suppression in a batched fashion.
...@@ -72,12 +78,13 @@ def batched_nms(boxes, scores, idxs, iou_threshold): ...@@ -72,12 +78,13 @@ def batched_nms(boxes, scores, idxs, iou_threshold):
def remove_small_boxes(boxes, min_size): def remove_small_boxes(boxes, min_size):
# type: (Tensor, float)
""" """
Remove boxes which contain at least one side smaller than min_size. Remove boxes which contain at least one side smaller than min_size.
Arguments: Arguments:
boxes (Tensor[N, 4]): boxes in (x1, y1, x2, y2) format boxes (Tensor[N, 4]): boxes in (x1, y1, x2, y2) format
min_size (int): minimum size min_size (float): minimum size
Returns: Returns:
keep (Tensor[K]): indices of the boxes that have both sides keep (Tensor[K]): indices of the boxes that have both sides
...@@ -90,6 +97,7 @@ def remove_small_boxes(boxes, min_size): ...@@ -90,6 +97,7 @@ def remove_small_boxes(boxes, min_size):
def clip_boxes_to_image(boxes, size): def clip_boxes_to_image(boxes, size):
# type: (Tensor, Tuple[int, int])
""" """
Clip boxes so that they lie inside an image of size `size`. Clip boxes so that they lie inside an image of size `size`.
......
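The # type comments added in this file only annotate the existing ops for the TorchScript compiler; their Python-side behaviour is unchanged. A small sanity check of nms with illustrative boxes and scores:

import torch
from torchvision.ops import nms

boxes = torch.tensor([[0., 0., 10., 10.],
                      [1., 1., 11., 11.],
                      [50., 50., 60., 60.]])
scores = torch.tensor([0.9, 0.8, 0.7])
keep = nms(boxes, scores, iou_threshold=0.5)
print(keep)  # tensor([0, 2]): the overlapping lower-scoring box is suppressed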