Commit d88d8961 authored by eellison, committed by Francisco Massa

Make maskrcnn scriptable (#1407)

* almost working...

* respond to comments

* add empty tensor op, handle different output types in generalized rcnn

* clean ups

* address comments

* more changes

* it's working!

* torchscript bugs

* add script/ eager test

* eval script model

* fix flake

* division import

* py2 compat

* update test, fix arange bug

* import division statement

* fix linter

* fixes

* changes needed for JIT master

* cleanups

* remove imagelist_to

* requested changes

* Make FPN backwards-compatible and torchscript compatible

We remove support for feature channels=0, but support for it was already a bit limited

* Fix ONNX regression
parent b590f8c6
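For context, a minimal sketch of what this commit enables (assuming a torchvision build that includes it; the model choice and input shape below are illustrative):

import torch
import torchvision

model = torchvision.models.detection.maskrcnn_resnet50_fpn(
    pretrained=False, pretrained_backbone=False)
model.eval()
scripted = torch.jit.script(model)      # now compiles end to end
images = [torch.rand(3, 300, 400)]
losses, detections = scripted(images)   # scripted models always return a tuple
print(list(detections[0].keys()))       # e.g. ['boxes', 'labels', 'scores', 'masks']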
......@@ -208,7 +208,7 @@ def patched_make_field(self, types, domain, items, **kw):
# `kw` catches `env=None` needed for newer sphinx while maintaining
# backwards compatibility when passed along further down!
# type: (list, unicode, tuple) -> nodes.field
# type: (list, unicode, tuple) -> nodes.field # noqa: F821
def handle_item(fieldarg, content):
par = nodes.paragraph()
par += addnodes.literal_strong('', fieldarg) # Patch: this line added
......
......@@ -15,11 +15,11 @@ class ResnetFPNBackboneTester(unittest.TestCase):
x = torch.rand(1, 3, 300, 300, dtype=self.dtype, device=device)
resnet18_fpn = resnet_fpn_backbone(backbone_name='resnet18', pretrained=False)
y = resnet18_fpn(x)
self.assertEqual(list(y.keys()), [0, 1, 2, 3, 'pool'])
self.assertEqual(list(y.keys()), ['0', '1', '2', '3', 'pool'])
def test_resnet50_fpn_backbone(self):
device = torch.device('cpu')
x = torch.rand(1, 3, 300, 300, dtype=self.dtype, device=device)
resnet50_fpn = resnet_fpn_backbone(backbone_name='resnet50', pretrained=False)
y = resnet50_fpn(x)
self.assertEqual(list(y.keys()), [0, 1, 2, 3, 'pool'])
self.assertEqual(list(y.keys()), ['0', '1', '2', '3', 'pool'])
......@@ -51,7 +51,10 @@ script_test_models = [
"squeezenet1_0",
"vgg11",
"inception_v3",
'r3d_18',
"r3d_18",
"fasterrcnn_resnet50_fpn",
"maskrcnn_resnet50_fpn",
"keypointrcnn_resnet50_fpn",
]
......@@ -95,7 +98,6 @@ class ModelTester(TestCase):
def _test_detection_model(self, name):
set_rng_seed(0)
model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False)
self.check_script(model, name)
model.eval()
input_shape = (3, 300, 300)
x = torch.rand(input_shape)
......@@ -130,9 +132,19 @@ class ModelTester(TestCase):
else:
self.assertExpected(map_nested_tensor_object(out, tensor_map_fn=subsample_tensor))
scripted_model = torch.jit.script(model)
scripted_model.eval()
scripted_out = scripted_model(model_input)[1]
self.assertNestedTensorObjectsEqual(scripted_out[0]["boxes"], out[0]["boxes"])
self.assertNestedTensorObjectsEqual(scripted_out[0]["scores"], out[0]["scores"])
# labels currently float in script: need to investigate (though same result)
self.assertNestedTensorObjectsEqual(scripted_out[0]["labels"].to(dtype=torch.long), out[0]["labels"])
self.assertTrue("boxes" in out[0])
self.assertTrue("scores" in out[0])
self.assertTrue("labels" in out[0])
# don't check script because we are compiling it here:
# TODO: refactor tests
# self.check_script(model, name)
def _test_video_model(self, name):
# the default input shape is
......
......@@ -367,5 +367,14 @@ class NMSTester(unittest.TestCase):
self.assertTrue(torch.allclose(r_cpu, r_cuda.cpu()), err_msg.format(iou))
class NewEmptyTensorTester(unittest.TestCase):
def test_new_empty_tensor(self):
input = torch.tensor([2., 2.], requires_grad=True)
new_shape = [3, 3]
out = torch.ops.torchvision._new_empty_tensor_op(input, new_shape)
assert out.size() == torch.Size([3, 3])
assert out.dtype == input.dtype
if __name__ == '__main__':
unittest.main()
......@@ -8,6 +8,7 @@ from torchvision import utils
from torchvision import io
from .extension import _HAS_OPS
import torch
try:
from .version import __version__ # noqa: F401
......@@ -70,5 +71,4 @@ def get_video_backend():
def _is_tracing():
import torch
return torch._C._get_tracing_state()
#pragma once
// All pure C++ headers for the C++ frontend.
#include <torch/all.h>
// Python bindings for the C++ frontend (includes Python.h).
#include <torch/python.h>
using namespace at;
using torch::Tensor;
using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;
class NewEmptyTensorOp : public torch::autograd::Function<NewEmptyTensorOp> {
public:
static variable_list forward(
AutogradContext* ctx,
Variable input,
c10::List<int64_t> new_shape) {
ctx->saved_data["shape"] = input.sizes();
std::vector<int64_t> shape(new_shape.begin(), new_shape.end());
return {input.new_empty(shape, TensorOptions())};
}
static variable_list backward(
AutogradContext* ctx,
variable_list grad_output) {
// Use data saved in forward
auto shape = ctx->saved_data["shape"].toIntList();
auto out = forward(ctx, grad_output[0], shape);
return {out[0], at::Tensor()};
}
};
Tensor new_empty_tensor(const Tensor& input, c10::List<int64_t> shape) {
return NewEmptyTensorOp::apply(input, shape)[0];
}
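A Python-side sketch of calling the op defined above (assuming the compiled torchvision extension is importable):

import torch
import torchvision  # importing torchvision loads the custom C++ ops

x = torch.tensor([2., 2.], requires_grad=True)
out = torch.ops.torchvision._new_empty_tensor_op(x, [3, 3])
assert out.size() == torch.Size([3, 3]) and out.dtype == x.dtype
# backward returns an empty tensor with the input's original shape, so
# autograd can flow through shape-only bookkeeping code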
......@@ -9,6 +9,7 @@
#include "PSROIPool.h"
#include "ROIAlign.h"
#include "ROIPool.h"
#include "empty_tensor_op.h"
#include "nms.h"
// If we are in a Windows environment, we need to define
......@@ -43,6 +44,7 @@ static auto registry =
.op("torchvision::roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio) -> Tensor",
&roi_align)
.op("torchvision::roi_pool", &roi_pool)
.op("torchvision::_new_empty_tensor_op", &new_empty_tensor)
.op("torchvision::ps_roi_align", &ps_roi_align)
.op("torchvision::ps_roi_pool", &ps_roi_pool)
.op("torchvision::_cuda_version", &_cuda_version);
......@@ -37,7 +37,6 @@ class IntermediateLayerGetter(nn.ModuleDict):
>>> ('feat2', torch.Size([1, 256, 14, 14]))]
"""
_version = 2
__constants__ = ['layers']
__annotations__ = {
"return_layers": Dict[str, str],
}
......@@ -46,7 +45,7 @@ class IntermediateLayerGetter(nn.ModuleDict):
if not set(return_layers).issubset([name for name, _ in model.named_children()]):
raise ValueError("return_layers are not present in model")
orig_return_layers = return_layers
return_layers = {k: v for k, v in return_layers.items()}
return_layers = {str(k): str(v) for k, v in return_layers.items()}
layers = OrderedDict()
for name, module in model.named_children():
layers[name] = module
......
......@@ -3,15 +3,28 @@ from __future__ import division
import math
import torch
from torch.jit.annotations import List, Tuple
from torch import Tensor
import torchvision
# TODO: https://github.com/pytorch/pytorch/issues/26727
def zeros_like(tensor, dtype):
# type: (Tensor, int) -> Tensor
if tensor.dtype == dtype:
return tensor.detach().clone()
else:
return tensor.to(dtype)
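Since torch.zeros_like(x, dtype=...) was not scriptable at the time (see the TODO above), this helper stands in for it; a small usage sketch with made-up values:

matched = torch.tensor([1, -1, 0, 2])
mask = zeros_like(matched, torch.uint8)  # zeros with matched's shape, dtype uint8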
@torch.jit.script
class BalancedPositiveNegativeSampler(object):
"""
This class samples batches, ensuring that they contain a fixed proportion of positives
"""
def __init__(self, batch_size_per_image, positive_fraction):
# type: (int, float)
"""
Arguments:
batch_size_per_image (int): number of elements to be selected per image
......@@ -21,6 +34,7 @@ class BalancedPositiveNegativeSampler(object):
self.positive_fraction = positive_fraction
def __call__(self, matched_idxs):
# type: (List[Tensor])
"""
Arguments:
matched idxs: list of tensors containing -1, 0 or positive values.
......@@ -57,14 +71,15 @@ class BalancedPositiveNegativeSampler(object):
neg_idx_per_image = negative[perm2]
# create binary mask from indices
pos_idx_per_image_mask = torch.zeros_like(
pos_idx_per_image_mask = zeros_like(
matched_idxs_per_image, dtype=torch.uint8
)
neg_idx_per_image_mask = torch.zeros_like(
neg_idx_per_image_mask = zeros_like(
matched_idxs_per_image, dtype=torch.uint8
)
pos_idx_per_image_mask[pos_idx_per_image] = 1
neg_idx_per_image_mask[neg_idx_per_image] = 1
pos_idx_per_image_mask[pos_idx_per_image] = torch.tensor(1)
neg_idx_per_image_mask[neg_idx_per_image] = torch.tensor(1)
pos_idx.append(pos_idx_per_image_mask)
neg_idx.append(neg_idx_per_image_mask)
......@@ -120,6 +135,7 @@ def encode_boxes(reference_boxes, proposals, weights):
return targets
@torch.jit.script
class BoxCoder(object):
"""
This class encodes and decodes a set of bounding boxes into
......@@ -127,6 +143,7 @@ class BoxCoder(object):
"""
def __init__(self, weights, bbox_xform_clip=math.log(1000. / 16)):
# type: (Tuple[float, float, float, float], float)
"""
Arguments:
weights (4-element tuple)
......@@ -136,6 +153,7 @@ class BoxCoder(object):
self.bbox_xform_clip = bbox_xform_clip
def encode(self, reference_boxes, proposals):
# type: (List[Tensor], List[Tensor])
boxes_per_image = [len(b) for b in reference_boxes]
reference_boxes = torch.cat(reference_boxes, dim=0)
proposals = torch.cat(proposals, dim=0)
......@@ -159,16 +177,18 @@ class BoxCoder(object):
return targets
def decode(self, rel_codes, boxes):
# type: (Tensor, List[Tensor])
assert isinstance(boxes, (list, tuple))
if isinstance(rel_codes, (list, tuple)):
rel_codes = torch.cat(rel_codes, dim=0)
assert isinstance(rel_codes, torch.Tensor)
boxes_per_image = [b.size(0) for b in boxes]
concat_boxes = torch.cat(boxes, dim=0)
box_sum = 0
for val in boxes_per_image:
box_sum += val
pred_boxes = self.decode_single(
rel_codes.reshape(sum(boxes_per_image), -1), concat_boxes
rel_codes.reshape(box_sum, -1), concat_boxes
)
return pred_boxes.reshape(sum(boxes_per_image), -1, 4)
return pred_boxes.reshape(box_sum, -1, 4)
def decode_single(self, rel_codes, boxes):
"""
......@@ -210,6 +230,7 @@ class BoxCoder(object):
return pred_boxes
@torch.jit.script
class Matcher(object):
"""
This class assigns to each predicted "element" (e.g., a box) a ground-truth
......@@ -228,7 +249,13 @@ class Matcher(object):
BELOW_LOW_THRESHOLD = -1
BETWEEN_THRESHOLDS = -2
__annotations__ = {
'BELOW_LOW_THRESHOLD': int,
'BETWEEN_THRESHOLDS': int,
}
def __init__(self, high_threshold, low_threshold, allow_low_quality_matches=False):
# type: (float, float, bool)
"""
Args:
high_threshold (float): quality values greater than or equal to
......@@ -242,6 +269,8 @@ class Matcher(object):
for predictions that have only low-quality match candidates. See
set_low_quality_matches_ for more details.
"""
self.BELOW_LOW_THRESHOLD = -1
self.BETWEEN_THRESHOLDS = -2
assert low_threshold <= high_threshold
self.high_threshold = high_threshold
self.low_threshold = low_threshold
......@@ -274,16 +303,19 @@ class Matcher(object):
matched_vals, matches = match_quality_matrix.max(dim=0)
if self.allow_low_quality_matches:
all_matches = matches.clone()
else:
all_matches = None
# Assign candidate matches with low quality to negative (unassigned) values
below_low_threshold = matched_vals < self.low_threshold
between_thresholds = (matched_vals >= self.low_threshold) & (
matched_vals < self.high_threshold
)
matches[below_low_threshold] = Matcher.BELOW_LOW_THRESHOLD
matches[between_thresholds] = Matcher.BETWEEN_THRESHOLDS
matches[below_low_threshold] = torch.tensor(self.BELOW_LOW_THRESHOLD)
matches[between_thresholds] = torch.tensor(self.BETWEEN_THRESHOLDS)
if self.allow_low_quality_matches:
assert all_matches is not None
self.set_low_quality_matches_(matches, all_matches, match_quality_matrix)
return matches
......
......@@ -7,14 +7,12 @@ from .._utils import IntermediateLayerGetter
from .. import resnet
class BackboneWithFPN(nn.Sequential):
class BackboneWithFPN(nn.Module):
"""
Adds a FPN on top of a model.
Internally, it uses torchvision.models._utils.IntermediateLayerGetter to
extract a submodel that returns the feature maps specified in return_layers.
The same limitations of IntermediateLayerGetter apply here.
Arguments:
backbone (nn.Module)
return_layers (Dict[name, new_name]): a dict containing the names
......@@ -24,21 +22,24 @@ class BackboneWithFPN(nn.Sequential):
in_channels_list (List[int]): number of channels for each feature map
that is returned, in the order they are present in the OrderedDict
out_channels (int): number of channels in the FPN.
Attributes:
out_channels (int): the number of channels in the FPN
"""
def __init__(self, backbone, return_layers, in_channels_list, out_channels):
body = IntermediateLayerGetter(backbone, return_layers=return_layers)
fpn = FeaturePyramidNetwork(
super(BackboneWithFPN, self).__init__()
self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
self.fpn = FeaturePyramidNetwork(
in_channels_list=in_channels_list,
out_channels=out_channels,
extra_blocks=LastLevelMaxPool(),
)
super(BackboneWithFPN, self).__init__(OrderedDict(
[("body", body), ("fpn", fpn)]))
self.out_channels = out_channels
def forward(self, x):
x = self.body(x)
x = self.fpn(x)
return x
def resnet_fpn_backbone(backbone_name, pretrained):
backbone = resnet.__dict__[backbone_name](
......@@ -49,7 +50,7 @@ def resnet_fpn_backbone(backbone_name, pretrained):
if 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
parameter.requires_grad_(False)
return_layers = {'layer1': 0, 'layer2': 1, 'layer3': 2, 'layer4': 3}
return_layers = {'layer1': '0', 'layer2': '1', 'layer3': '2', 'layer4': '3'}
in_channels_stage2 = backbone.inplanes // 8
in_channels_list = [
......
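With the string-keyed return_layers above, the backbone output dict is keyed by strings, matching the featmap_names=['0', '1', '2', '3'] used by MultiScaleRoIAlign later in this diff. A hedged usage sketch:

import torch
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone

backbone = resnet_fpn_backbone('resnet18', pretrained=False)
feats = backbone(torch.rand(1, 3, 224, 224))
print(list(feats.keys()))  # ['0', '1', '2', '3', 'pool']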
......@@ -199,7 +199,7 @@ class FasterRCNN(GeneralizedRCNN):
if box_roi_pool is None:
box_roi_pool = MultiScaleRoIAlign(
featmap_names=[0, 1, 2, 3],
featmap_names=['0', '1', '2', '3'],
output_size=7,
sampling_ratio=2)
......@@ -273,7 +273,7 @@ class FastRCNNPredictor(nn.Module):
self.bbox_pred = nn.Linear(in_channels, num_classes * 4)
def forward(self, x):
if x.ndimension() == 4:
if x.dim() == 4:
assert list(x.shape[2:]) == [1, 1]
x = x.flatten(start_dim=1)
scores = self.cls_score(x)
......
......@@ -6,6 +6,9 @@ Implements the Generalized R-CNN framework
from collections import OrderedDict
import torch
from torch import nn
import warnings
from torch.jit.annotations import Tuple, List, Dict, Optional
from torch import Tensor
class GeneralizedRCNN(nn.Module):
......@@ -28,7 +31,16 @@ class GeneralizedRCNN(nn.Module):
self.rpn = rpn
self.roi_heads = roi_heads
@torch.jit.unused
def eager_outputs(self, losses, detections):
# type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
if self.training:
return losses
return detections
def forward(self, images, targets=None):
# type: (List[Tensor], Optional[List[Dict[str, Tensor]]])
"""
Arguments:
images (list[Tensor]): images to be processed
......@@ -43,7 +55,12 @@ class GeneralizedRCNN(nn.Module):
"""
if self.training and targets is None:
raise ValueError("In training mode, targets should be passed")
original_image_sizes = [img.shape[-2:] for img in images]
original_image_sizes = torch.jit.annotate(List[Tuple[int, int]], [])
for img in images:
val = img.shape[-2:]
assert len(val) == 2
original_image_sizes.append((val[0], val[1]))
images, targets = self.transform(images, targets)
features = self.backbone(images.tensors)
if isinstance(features, torch.Tensor):
......@@ -56,7 +73,8 @@ class GeneralizedRCNN(nn.Module):
losses.update(detector_losses)
losses.update(proposal_losses)
if self.training:
return losses
return detections
if torch.jit.is_scripting():
warnings.warn("RCNN always returns a (Losses, Detections tuple in scripting)")
return (losses, detections)
else:
return self.eager_outputs(losses, detections)
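A minimal sketch of the scripting pattern used here: @torch.jit.unused keeps the eager-only return path out of the compiled graph, and torch.jit.is_scripting() selects the tuple return when compiled (the class and tensors below are made up):

import torch
from torch import nn, Tensor
from torch.jit.annotations import Tuple

class ToyRCNN(nn.Module):
    @torch.jit.unused
    def eager_outputs(self, losses, detections):
        # type: (Tensor, Tensor) -> Tuple[Tensor, Tensor]
        # eager-only behaviour; the script compiler never compiles this body
        return losses if self.training else detections

    def forward(self, x):
        losses, detections = x.sum(), x * 2
        if torch.jit.is_scripting():
            return losses, detections
        else:
            return self.eager_outputs(losses, detections)

toy = ToyRCNN().eval()
eager_out = toy(torch.ones(2))                       # eager: a single value
losses, dets = torch.jit.script(toy)(torch.ones(2))  # scripted: always a tuple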
......@@ -2,6 +2,8 @@
from __future__ import division
import torch
from torch.jit.annotations import List, Tuple
from torch import Tensor
class ImageList(object):
......@@ -13,6 +15,7 @@ class ImageList(object):
"""
def __init__(self, tensors, image_sizes):
# type: (Tensor, List[Tuple[int, int]])
"""
Arguments:
tensors (tensor)
......@@ -21,6 +24,7 @@ class ImageList(object):
self.tensors = tensors
self.image_sizes = image_sizes
def to(self, *args, **kwargs):
cast_tensor = self.tensors.to(*args, **kwargs)
def to(self, device):
# type: (Device) # noqa
cast_tensor = self.tensors.to(device)
return ImageList(cast_tensor, self.image_sizes)
......@@ -2,6 +2,7 @@ import torch
from torch import nn
from torchvision.ops import misc as misc_nn_ops
from torchvision.ops import MultiScaleRoIAlign
from ..utils import load_state_dict_from_url
......@@ -179,7 +180,7 @@ class KeypointRCNN(FasterRCNN):
if keypoint_roi_pool is None:
keypoint_roi_pool = MultiScaleRoIAlign(
featmap_names=[0, 1, 2, 3],
featmap_names=['0', '1', '2', '3'],
output_size=14,
sampling_ratio=2)
......@@ -252,7 +253,7 @@ class KeypointRCNNPredictor(nn.Module):
def forward(self, x):
x = self.kps_score_lowres(x)
x = misc_nn_ops.interpolate(
x, scale_factor=self.up_scale, mode="bilinear", align_corners=False
x, scale_factor=float(self.up_scale), mode="bilinear", align_corners=False
)
return x
......
......@@ -178,7 +178,7 @@ class MaskRCNN(FasterRCNN):
if mask_roi_pool is None:
mask_roi_pool = MultiScaleRoIAlign(
featmap_names=[0, 1, 2, 3],
featmap_names=['0', '1', '2', '3'],
output_size=14,
sampling_ratio=2)
......
......@@ -3,16 +3,20 @@ import torch
import torchvision
import torch.nn.functional as F
from torch import nn
from torch import nn, Tensor
from torchvision.ops import boxes as box_ops
from torchvision.ops import misc as misc_nn_ops
from torchvision.ops import roi_align
from . import _utils as det_utils
from torch.jit.annotations import Optional, List, Dict, Tuple
def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
# type: (Tensor, Tensor, List[Tensor], List[Tensor])
"""
Computes the loss for Faster R-CNN.
......@@ -51,6 +55,7 @@ def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
def maskrcnn_inference(x, labels):
# type: (Tensor, List[Tensor])
"""
From the results of the CNN, post process the masks
by taking the mask corresponding to the class with max
......@@ -77,14 +82,16 @@ def maskrcnn_inference(x, labels):
if len(boxes_per_image) == 1:
# TODO : remove when dynamic split supported in ONNX
mask_prob = (mask_prob,)
# and remove assignment to mask_prob_list, just assign to mask_prob
mask_prob_list = [mask_prob]
else:
mask_prob = mask_prob.split(boxes_per_image, dim=0)
mask_prob_list = mask_prob.split(boxes_per_image, dim=0)
return mask_prob
return mask_prob_list
def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
# type: (Tensor, Tensor, Tensor, int)
"""
Given segmentation masks and the bounding boxes corresponding
to the location of the masks in the image, this function
......@@ -95,10 +102,11 @@ def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
matched_idxs = matched_idxs.to(boxes)
rois = torch.cat([matched_idxs[:, None], boxes], dim=1)
gt_masks = gt_masks[:, None].to(rois)
return roi_align(gt_masks, rois, (M, M), 1)[:, 0]
return roi_align(gt_masks, rois, (M, M), 1.)[:, 0]
def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs):
# type: (Tensor, List[Tensor], List[Tensor], List[Tensor], List[Tensor])
"""
Arguments:
proposals (list[BoxList])
......@@ -131,6 +139,7 @@ def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs
def keypoints_to_heatmap(keypoints, rois, heatmap_size):
# type: (Tensor, Tensor, int)
offset_x = rois[:, 0]
offset_y = rois[:, 1]
scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
......@@ -152,8 +161,8 @@ def keypoints_to_heatmap(keypoints, rois, heatmap_size):
y = (y - offset_y) * scale_y
y = y.floor().long()
x[x_boundary_inds] = heatmap_size - 1
y[y_boundary_inds] = heatmap_size - 1
x[x_boundary_inds] = torch.tensor(heatmap_size - 1)
y[y_boundary_inds] = torch.tensor(heatmap_size - 1)
valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size)
vis = keypoints[..., 2] > 0
......@@ -217,6 +226,17 @@ def _onnx_heatmaps_to_keypoints_loop(maps, rois, widths_ceil, heights_ceil,
return xy_preds, end_scores
# workaround for issue pytorch 27512
def tensor_floordiv(tensor, int_div):
# type: (Tensor, int)
result = tensor / int_div
# TODO: https://github.com/pytorch/pytorch/issues/26731
floating_point_types = (torch.float, torch.double, torch.half)
if result.dtype in floating_point_types:
result = result.trunc()
return result
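A tiny sanity check of the workaround above (values chosen for illustration):

pos = torch.tensor([7., 9.])
assert torch.equal(tensor_floordiv(pos, 4), torch.tensor([1., 2.]))  # 7 // 4, 9 // 4 on floats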
def heatmaps_to_keypoints(maps, rois):
"""Extract predicted keypoint locations from heatmaps. Output has shape
(#rois, 4, #keypoints) with the 4 rows corresponding to (x, y, logit, prob)
......@@ -258,8 +278,9 @@ def heatmaps_to_keypoints(maps, rois):
# roi_map_probs = scores_to_probs(roi_map.copy())
w = roi_map.shape[2]
pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
x_int = pos % w
y_int = (pos - x_int) // w
y_int = tensor_floordiv((pos - x_int), w)
# assert (roi_map_probs[k, y_int, x_int] ==
# roi_map_probs[k, :, :].max())
x = (x_int.float() + 0.5) * width_correction
......@@ -273,6 +294,7 @@ def heatmaps_to_keypoints(maps, rois):
def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs):
# type: (Tensor, List[Tensor], List[Tensor], List[Tensor])
N, K, H, W = keypoint_logits.shape
assert H == W
discretization_size = H
......@@ -302,6 +324,7 @@ def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched
def keypointrcnn_inference(x, boxes):
# type: (Tensor, List[Tensor])
kp_probs = []
kp_scores = []
......@@ -323,6 +346,7 @@ def keypointrcnn_inference(x, boxes):
def _onnx_expand_boxes(boxes, scale):
# type: (Tensor, float)
w_half = (boxes[:, 2] - boxes[:, 0]) * .5
h_half = (boxes[:, 3] - boxes[:, 1]) * .5
x_c = (boxes[:, 2] + boxes[:, 0]) * .5
......@@ -343,6 +367,7 @@ def _onnx_expand_boxes(boxes, scale):
# but are kept here for the moment while we need them
# temporarily for paste_mask_in_image
def expand_boxes(boxes, scale):
# type: (Tensor, float)
if torchvision._is_tracing():
return _onnx_expand_boxes(boxes, scale)
w_half = (boxes[:, 2] - boxes[:, 0]) * .5
......@@ -361,10 +386,17 @@ def expand_boxes(boxes, scale):
return boxes_exp
@torch.jit.unused
def expand_masks_tracing_scale(M, padding):
# type: (int, int) -> float
return torch.tensor(M + 2 * padding).to(torch.float32) / torch.tensor(M).to(torch.float32)
def expand_masks(mask, padding):
# type: (Tensor, int)
M = mask.shape[-1]
if torchvision._is_tracing():
scale = (M + 2 * padding).to(torch.float32) / M.to(torch.float32)
if torch._C._get_tracing_state(): # could not import is_tracing(), not sure why
scale = expand_masks_tracing_scale(M, padding)
else:
scale = float(M + 2 * padding) / M
padded_mask = torch.nn.functional.pad(mask, (padding,) * 4)
......@@ -372,6 +404,7 @@ def expand_masks(mask, padding):
def paste_mask_in_image(mask, box, im_h, im_w):
# type: (Tensor, Tensor, int, int)
TO_REMOVE = 1
w = int(box[2] - box[0] + TO_REMOVE)
h = int(box[3] - box[1] + TO_REMOVE)
......@@ -449,29 +482,33 @@ def _onnx_paste_masks_in_image_loop(masks, boxes, im_h, im_w):
def paste_masks_in_image(masks, boxes, img_shape, padding=1):
# type: (Tensor, Tensor, Tuple[int, int], int)
masks, scale = expand_masks(masks, padding=padding)
boxes = expand_boxes(boxes, scale).to(dtype=torch.int64)
# im_h, im_w = img_shape.tolist()
im_h, im_w = img_shape
if torchvision._is_tracing():
return _onnx_paste_masks_in_image_loop(masks, boxes,
torch.scalar_tensor(im_h, dtype=torch.int64),
torch.scalar_tensor(im_w, dtype=torch.int64))[:, None]
boxes = boxes.tolist()
res = [
paste_mask_in_image(m[0], b, im_h, im_w)
for m, b in zip(masks, boxes)
]
if len(res) > 0:
res = torch.stack(res, dim=0)[:, None]
ret = torch.stack(res, dim=0)[:, None]
else:
res = masks.new_empty((0, 1, im_h, im_w))
return res
ret = masks.new_empty((0, 1, im_h, im_w))
return ret
class RoIHeads(torch.nn.Module):
__annotations__ = {
'box_coder': det_utils.BoxCoder,
'proposal_matcher': det_utils.Matcher,
'fg_bg_sampler': det_utils.BalancedPositiveNegativeSampler,
}
def __init__(self,
box_roi_pool,
box_head,
......@@ -525,7 +562,6 @@ class RoIHeads(torch.nn.Module):
self.keypoint_head = keypoint_head
self.keypoint_predictor = keypoint_predictor
@property
def has_mask(self):
if self.mask_roi_pool is None:
return False
......@@ -535,7 +571,6 @@ class RoIHeads(torch.nn.Module):
return False
return True
@property
def has_keypoint(self):
if self.keypoint_roi_pool is None:
return False
......@@ -546,10 +581,12 @@ class RoIHeads(torch.nn.Module):
return True
def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
# type: (List[Tensor], List[Tensor], List[Tensor])
matched_idxs = []
labels = []
for proposals_in_image, gt_boxes_in_image, gt_labels_in_image in zip(proposals, gt_boxes, gt_labels):
match_quality_matrix = self.box_similarity(gt_boxes_in_image, proposals_in_image)
# set to self.box_similarity when https://github.com/pytorch/pytorch/issues/27495 lands
match_quality_matrix = box_ops.box_iou(gt_boxes_in_image, proposals_in_image)
matched_idxs_in_image = self.proposal_matcher(match_quality_matrix)
clamped_matched_idxs_in_image = matched_idxs_in_image.clamp(min=0)
......@@ -559,17 +596,18 @@ class RoIHeads(torch.nn.Module):
# Label background (below the low threshold)
bg_inds = matched_idxs_in_image == self.proposal_matcher.BELOW_LOW_THRESHOLD
labels_in_image[bg_inds] = 0
labels_in_image[bg_inds] = torch.tensor(0)
# Label ignore proposals (between low and high thresholds)
ignore_inds = matched_idxs_in_image == self.proposal_matcher.BETWEEN_THRESHOLDS
labels_in_image[ignore_inds] = -1 # -1 is ignored by sampler
labels_in_image[ignore_inds] = torch.tensor(-1) # -1 is ignored by sampler
matched_idxs.append(clamped_matched_idxs_in_image)
labels.append(labels_in_image)
return matched_idxs, labels
def subsample(self, labels):
# type: (List[Tensor])
sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
sampled_inds = []
for img_idx, (pos_inds_img, neg_inds_img) in enumerate(
......@@ -580,6 +618,7 @@ class RoIHeads(torch.nn.Module):
return sampled_inds
def add_gt_proposals(self, proposals, gt_boxes):
# type: (List[Tensor], List[Tensor])
proposals = [
torch.cat((proposal, gt_box))
for proposal, gt_box in zip(proposals, gt_boxes)
......@@ -587,15 +626,25 @@ class RoIHeads(torch.nn.Module):
return proposals
def DELTEME_all(self, the_list):
# type: (List[bool])
for i in the_list:
if not i:
return False
return True
def check_targets(self, targets):
# type: (Optional[List[Dict[str, Tensor]]])
assert targets is not None
assert all("boxes" in t for t in targets)
assert all("labels" in t for t in targets)
if self.has_mask:
assert all("masks" in t for t in targets)
assert self.DELTEME_all(["boxes" in t for t in targets])
assert self.DELTEME_all(["labels" in t for t in targets])
if self.has_mask():
assert self.DELTEME_all(["masks" in t for t in targets])
def select_training_samples(self, proposals, targets):
# type: (List[Tensor], Optional[List[Dict[str, Tensor]]])
self.check_targets(targets)
assert targets is not None
dtype = proposals[0].dtype
gt_boxes = [t["boxes"].to(dtype) for t in targets]
gt_labels = [t["labels"] for t in targets]
......@@ -620,6 +669,7 @@ class RoIHeads(torch.nn.Module):
return proposals, matched_idxs, labels, regression_targets
def postprocess_detections(self, class_logits, box_regression, proposals, image_shapes):
# type: (Tensor, Tensor, List[Tensor], List[Tuple[int, int]])
device = class_logits.device
num_classes = class_logits.shape[-1]
......@@ -631,16 +681,17 @@ class RoIHeads(torch.nn.Module):
# split boxes and scores per image
if len(boxes_per_image) == 1:
# TODO : remove this when ONNX support dynamic split sizes
pred_boxes = (pred_boxes,)
pred_scores = (pred_scores,)
# and just assign to pred_boxes instead of pred_boxes_list
pred_boxes_list = [pred_boxes]
pred_scores_list = [pred_scores]
else:
pred_boxes = pred_boxes.split(boxes_per_image, 0)
pred_scores = pred_scores.split(boxes_per_image, 0)
pred_boxes_list = pred_boxes.split(boxes_per_image, 0)
pred_scores_list = pred_scores.split(boxes_per_image, 0)
all_boxes = []
all_scores = []
all_labels = []
for boxes, scores, image_shape in zip(pred_boxes, pred_scores, image_shapes):
for boxes, scores, image_shape in zip(pred_boxes_list, pred_scores_list, image_shapes):
boxes = box_ops.clip_boxes_to_image(boxes, image_shape)
# create labels for each prediction
......@@ -678,6 +729,7 @@ class RoIHeads(torch.nn.Module):
return all_boxes, all_scores, all_labels
def forward(self, features, proposals, image_shapes, targets=None):
# type: (Dict[str, Tensor], List[Tensor], List[Tuple[int, int]], Optional[List[Dict[str, Tensor]]])
"""
Arguments:
features (List[Tensor])
......@@ -687,38 +739,50 @@ class RoIHeads(torch.nn.Module):
"""
if targets is not None:
for t in targets:
assert t["boxes"].dtype.is_floating_point, 'target boxes must of float type'
# TODO: https://github.com/pytorch/pytorch/issues/26731
floating_point_types = (torch.float, torch.double, torch.half)
assert t["boxes"].dtype in floating_point_types, 'target boxes must of float type'
assert t["labels"].dtype == torch.int64, 'target labels must of int64 type'
if self.has_keypoint:
if self.has_keypoint():
assert t["keypoints"].dtype == torch.float32, 'target keypoints must of float type'
if self.training:
proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
else:
labels = None
regression_targets = None
matched_idxs = None
box_features = self.box_roi_pool(features, proposals, image_shapes)
box_features = self.box_head(box_features)
class_logits, box_regression = self.box_predictor(box_features)
result, losses = [], {}
result = torch.jit.annotate(List[Dict[str, torch.Tensor]], [])
losses = {}
if self.training:
assert labels is not None and regression_targets is not None
loss_classifier, loss_box_reg = fastrcnn_loss(
class_logits, box_regression, labels, regression_targets)
losses = dict(loss_classifier=loss_classifier, loss_box_reg=loss_box_reg)
losses = {
"loss_classifier": loss_classifier,
"loss_box_reg": loss_box_reg
}
else:
boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
num_images = len(boxes)
for i in range(num_images):
result.append(
dict(
boxes=boxes[i],
labels=labels[i],
scores=scores[i],
)
{
"boxes": boxes[i],
"labels": labels[i],
"scores": scores[i],
}
)
if self.has_mask:
if self.has_mask():
mask_proposals = [p["boxes"] for p in result]
if self.training:
assert matched_idxs is not None
# during training, only focus on positive boxes
num_images = len(proposals)
mask_proposals = []
......@@ -727,19 +791,31 @@ class RoIHeads(torch.nn.Module):
pos = torch.nonzero(labels[img_id] > 0).squeeze(1)
mask_proposals.append(proposals[img_id][pos])
pos_matched_idxs.append(matched_idxs[img_id][pos])
else:
pos_matched_idxs = None
mask_features = self.mask_roi_pool(features, mask_proposals, image_shapes)
mask_features = self.mask_head(mask_features)
mask_logits = self.mask_predictor(mask_features)
if self.mask_roi_pool is not None:
mask_features = self.mask_roi_pool(features, mask_proposals, image_shapes)
mask_features = self.mask_head(mask_features)
mask_logits = self.mask_predictor(mask_features)
else:
mask_logits = torch.tensor(0)
raise Exception("Expected mask_roi_pool to be not None")
loss_mask = {}
if self.training:
assert targets is not None
assert pos_matched_idxs is not None
assert mask_logits is not None
gt_masks = [t["masks"] for t in targets]
gt_labels = [t["labels"] for t in targets]
loss_mask = maskrcnn_loss(
rcnn_loss_mask = maskrcnn_loss(
mask_logits, mask_proposals,
gt_masks, gt_labels, pos_matched_idxs)
loss_mask = dict(loss_mask=loss_mask)
loss_mask = {
"loss_mask": rcnn_loss_mask
}
else:
labels = [r["labels"] for r in result]
masks_probs = maskrcnn_inference(mask_logits, labels)
......@@ -748,17 +824,23 @@ class RoIHeads(torch.nn.Module):
losses.update(loss_mask)
if self.has_keypoint:
# keep none checks in if conditional so torchscript will conditionally
# compile each branch
if self.keypoint_roi_pool is not None and self.keypoint_head is not None \
and self.keypoint_predictor is not None:
keypoint_proposals = [p["boxes"] for p in result]
if self.training:
# during training, only focus on positive boxes
num_images = len(proposals)
keypoint_proposals = []
pos_matched_idxs = []
assert matched_idxs is not None
for img_id in range(num_images):
pos = torch.nonzero(labels[img_id] > 0).squeeze(1)
keypoint_proposals.append(proposals[img_id][pos])
pos_matched_idxs.append(matched_idxs[img_id][pos])
else:
pos_matched_idxs = None
keypoint_features = self.keypoint_roi_pool(features, keypoint_proposals, image_shapes)
keypoint_features = self.keypoint_head(keypoint_features)
......@@ -766,12 +848,20 @@ class RoIHeads(torch.nn.Module):
loss_keypoint = {}
if self.training:
assert targets is not None
assert pos_matched_idxs is not None
gt_keypoints = [t["keypoints"] for t in targets]
loss_keypoint = keypointrcnn_loss(
rcnn_loss_keypoint = keypointrcnn_loss(
keypoint_logits, keypoint_proposals,
gt_keypoints, pos_matched_idxs)
loss_keypoint = dict(loss_keypoint=loss_keypoint)
loss_keypoint = {
"loss_keypoint": rcnn_loss_keypoint
}
else:
assert keypoint_logits is not None
assert keypoint_proposals is not None
keypoints_probs, kp_scores = keypointrcnn_inference(keypoint_logits, keypoint_proposals)
for keypoint_prob, kps, r in zip(keypoints_probs, kp_scores, result):
r["keypoints"] = keypoint_prob
......
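The explicit None checks above (in place of the has_mask / has_keypoint properties) follow the TorchScript idiom for optional submodules: a branch is only compiled when it is guarded by a direct `is not None` test on the attribute. A minimal illustrative sketch, with made-up names:

import torch
from torch import nn

class OptionalHead(nn.Module):
    def __init__(self, with_extra):
        super(OptionalHead, self).__init__()
        # the submodule may legitimately be absent, as mask/keypoint heads are
        self.extra = nn.Linear(4, 4) if with_extra else None

    def forward(self, x):
        if self.extra is not None:
            x = self.extra(x)
        return x

torch.jit.script(OptionalHead(True))   # branch compiled with the submodule
torch.jit.script(OptionalHead(False))  # branch statically pruned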
from __future__ import division
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch
from torch.nn import functional as F
from torch import nn
from torch import nn, Tensor
import torchvision
from torchvision.ops import boxes as box_ops
from . import _utils as det_utils
from .image_list import ImageList
from torch.jit.annotations import List, Optional, Dict, Tuple
@torch.jit.unused
def _onnx_get_num_anchors_and_pre_nms_top_n(ob, orig_pre_nms_top_n):
# type: (Tensor, int) -> Tuple[int, int]
from torch.onnx import operators
num_anchors = operators.shape_as_tensor(ob)[1].unsqueeze(0)
# TODO : remove cast to IntTensor/num_anchors.dtype when
......@@ -23,6 +29,11 @@ def _onnx_get_num_anchors_and_pre_nms_top_n(ob, orig_pre_nms_top_n):
class AnchorGenerator(nn.Module):
__annotations__ = {
"cell_anchors": Optional[List[torch.Tensor]],
"_cache": Dict[str, List[torch.Tensor]]
}
"""
Module that generates anchors for a set of feature maps and
image sizes.
......@@ -62,8 +73,9 @@ class AnchorGenerator(nn.Module):
self.cell_anchors = None
self._cache = {}
@staticmethod
def generate_anchors(scales, aspect_ratios, dtype=torch.float32, device="cpu"):
# TODO: https://github.com/pytorch/pytorch/issues/26792
def generate_anchors(self, scales, aspect_ratios, dtype=torch.float32, device="cpu"):
# type: (List[int], List[float], int, Device) # noqa: F821
scales = torch.as_tensor(scales, dtype=dtype, device=device)
aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device)
h_ratios = torch.sqrt(aspect_ratios)
......@@ -76,8 +88,10 @@ class AnchorGenerator(nn.Module):
return base_anchors.round()
def set_cell_anchors(self, dtype, device):
# type: (int, Device) -> None # noqa: F821
if self.cell_anchors is not None:
return self.cell_anchors
return
cell_anchors = [
self.generate_anchors(
sizes,
......@@ -93,9 +107,13 @@ class AnchorGenerator(nn.Module):
return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)]
def grid_anchors(self, grid_sizes, strides):
# type: (List[List[int]], List[List[int]])
anchors = []
cell_anchors = self.cell_anchors
assert cell_anchors is not None
for size, stride, base_anchors in zip(
grid_sizes, strides, self.cell_anchors
grid_sizes, strides, cell_anchors
):
grid_height, grid_width = size
stride_height, stride_width = stride
......@@ -122,7 +140,8 @@ class AnchorGenerator(nn.Module):
return anchors
def cached_grid_anchors(self, grid_sizes, strides):
key = tuple(grid_sizes) + tuple(strides)
# type: (List[List[int]], List[List[int]])
key = str(grid_sizes + strides)
if key in self._cache:
return self._cache[key]
anchors = self.grid_anchors(grid_sizes, strides)
......@@ -130,15 +149,14 @@ class AnchorGenerator(nn.Module):
return anchors
def forward(self, image_list, feature_maps):
grid_sizes = tuple([feature_map.shape[-2:] for feature_map in feature_maps])
# type: (ImageList, List[Tensor])
grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps])
image_size = image_list.tensors.shape[-2:]
strides = tuple((float(image_size[0]) / float(g[0]),
float(image_size[1]) / float(g[1]))
for g in grid_sizes)
strides = [[int(image_size[0] / g[0]), int(image_size[1] / g[1])] for g in grid_sizes]
dtype, device = feature_maps[0].dtype, feature_maps[0].device
self.set_cell_anchors(dtype, device)
anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides)
anchors = []
anchors = torch.jit.annotate(List[List[torch.Tensor]], [])
for i, (image_height, image_width) in enumerate(image_list.image_sizes):
anchors_in_image = []
for anchors_per_feature_map in anchors_over_all_feature_maps:
......@@ -172,6 +190,7 @@ class RPNHead(nn.Module):
torch.nn.init.constant_(l.bias, 0)
def forward(self, x):
# type: (List[Tensor])
logits = []
bbox_reg = []
for feature in x:
......@@ -182,6 +201,7 @@ class RPNHead(nn.Module):
def permute_and_flatten(layer, N, A, C, H, W):
# type: (Tensor, int, int, int, int, int)
layer = layer.view(N, -1, C, H, W)
layer = layer.permute(0, 3, 4, 1, 2)
layer = layer.reshape(N, -1, C)
......@@ -189,12 +209,14 @@ def permute_and_flatten(layer, N, A, C, H, W):
def concat_box_prediction_layers(box_cls, box_regression):
# type: (List[Tensor], List[Tensor])
box_cls_flattened = []
box_regression_flattened = []
# for each feature level, permute the outputs to make them be in the
# same format as the labels. Note that the labels are computed for
# all feature levels concatenated, so we keep the same representation
# for the objectness and the box_regression
last_C = torch.jit.annotate(Optional[int], None)
for box_cls_per_level, box_regression_per_level in zip(
box_cls, box_regression
):
......@@ -207,14 +229,16 @@ def concat_box_prediction_layers(box_cls, box_regression):
)
box_cls_flattened.append(box_cls_per_level)
last_C = C
box_regression_per_level = permute_and_flatten(
box_regression_per_level, N, A, 4, H, W
)
box_regression_flattened.append(box_regression_per_level)
assert last_C is not None
# concatenate on the first dimension (representing the feature levels), to
# take into account the way the labels were generated (with all feature maps
# being concatenated as well)
box_cls = torch.cat(box_cls_flattened, dim=1).reshape(-1, C)
box_cls = torch.cat(box_cls_flattened, dim=1).reshape(-1, last_C)
box_regression = torch.cat(box_regression_flattened, dim=1).reshape(-1, 4)
return box_cls, box_regression
......@@ -244,6 +268,13 @@ class RegionProposalNetwork(torch.nn.Module):
nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
"""
__annotations__ = {
'box_coder': det_utils.BoxCoder,
'proposal_matcher': det_utils.Matcher,
'fg_bg_sampler': det_utils.BalancedPositiveNegativeSampler,
'pre_nms_top_n': Dict[str, int],
'post_nms_top_n': Dict[str, int],
}
def __init__(self,
anchor_generator,
......@@ -276,24 +307,23 @@ class RegionProposalNetwork(torch.nn.Module):
self.nms_thresh = nms_thresh
self.min_size = 1e-3
@property
def pre_nms_top_n(self):
if self.training:
return self._pre_nms_top_n['training']
return self._pre_nms_top_n['testing']
@property
def post_nms_top_n(self):
if self.training:
return self._post_nms_top_n['training']
return self._post_nms_top_n['testing']
def assign_targets_to_anchors(self, anchors, targets):
# type: (List[Tensor], List[Dict[str, Tensor]])
labels = []
matched_gt_boxes = []
for anchors_per_image, targets_per_image in zip(anchors, targets):
gt_boxes = targets_per_image["boxes"]
match_quality_matrix = self.box_similarity(gt_boxes, anchors_per_image)
match_quality_matrix = box_ops.box_iou(gt_boxes, anchors_per_image)
matched_idxs = self.proposal_matcher(match_quality_matrix)
# get the targets corresponding GT for each proposal
# NB: need to clamp the indices because we can have a single
......@@ -306,31 +336,33 @@ class RegionProposalNetwork(torch.nn.Module):
# Background (negative examples)
bg_indices = matched_idxs == self.proposal_matcher.BELOW_LOW_THRESHOLD
labels_per_image[bg_indices] = 0
labels_per_image[bg_indices] = torch.tensor(0)
# discard indices that are between thresholds
inds_to_discard = matched_idxs == self.proposal_matcher.BETWEEN_THRESHOLDS
labels_per_image[inds_to_discard] = -1
labels_per_image[inds_to_discard] = torch.tensor(-1)
labels.append(labels_per_image)
matched_gt_boxes.append(matched_gt_boxes_per_image)
return labels, matched_gt_boxes
def _get_top_n_idx(self, objectness, num_anchors_per_level):
# type: (Tensor, List[int])
r = []
offset = 0
for ob in objectness.split(num_anchors_per_level, 1):
if torchvision._is_tracing():
num_anchors, pre_nms_top_n = _onnx_get_num_anchors_and_pre_nms_top_n(ob, self.pre_nms_top_n)
num_anchors, pre_nms_top_n = _onnx_get_num_anchors_and_pre_nms_top_n(ob, self.pre_nms_top_n())
else:
num_anchors = ob.shape[1]
pre_nms_top_n = min(self.pre_nms_top_n, num_anchors)
pre_nms_top_n = min(self.pre_nms_top_n(), num_anchors)
_, top_n_idx = ob.topk(pre_nms_top_n, dim=1)
r.append(top_n_idx + offset)
offset += num_anchors
return torch.cat(r, dim=1)
def filter_proposals(self, proposals, objectness, image_shapes, num_anchors_per_level):
# type: (Tensor, Tensor, List[Tuple[int, int]], List[int])
num_images = proposals.shape[0]
device = proposals.device
# do not backprop through objectness
......@@ -346,7 +378,10 @@ class RegionProposalNetwork(torch.nn.Module):
# select top_n boxes independently per level before applying nms
top_n_idx = self._get_top_n_idx(objectness, num_anchors_per_level)
batch_idx = torch.arange(num_images, device=device)[:, None]
image_range = torch.arange(num_images, device=device)
batch_idx = image_range[:, None]
objectness = objectness[batch_idx, top_n_idx]
levels = levels[batch_idx, top_n_idx]
proposals = proposals[batch_idx, top_n_idx]
......@@ -360,13 +395,14 @@ class RegionProposalNetwork(torch.nn.Module):
# non-maximum suppression, independently done per level
keep = box_ops.batched_nms(boxes, scores, lvl, self.nms_thresh)
# keep only topk scoring predictions
keep = keep[:self.post_nms_top_n]
keep = keep[:self.post_nms_top_n()]
boxes, scores = boxes[keep], scores[keep]
final_boxes.append(boxes)
final_scores.append(scores)
return final_boxes, final_scores
def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets):
# type: (Tensor, Tensor, List[Tensor], List[Tensor])
"""
Arguments:
objectness (Tensor)
......@@ -403,6 +439,7 @@ class RegionProposalNetwork(torch.nn.Module):
return objectness_loss, box_loss
def forward(self, images, features, targets=None):
# type: (ImageList, Dict[str, Tensor], Optional[List[Dict[str, Tensor]]])
"""
Arguments:
images (ImageList): images for which we want to compute the predictions
......@@ -437,6 +474,7 @@ class RegionProposalNetwork(torch.nn.Module):
losses = {}
if self.training:
assert targets is not None
labels, matched_gt_boxes = self.assign_targets_to_anchors(anchors, targets)
regression_targets = self.box_coder.encode(matched_gt_boxes, anchors)
loss_objectness, loss_rpn_box_reg = self.compute_loss(
......
from __future__ import division
import random
import math
import torch
from torch import nn
from torch import nn, Tensor
import torchvision
from torch.jit.annotations import List, Tuple, Dict, Optional
from torchvision.ops import misc as misc_nn_ops
from .image_list import ImageList
......@@ -31,22 +34,29 @@ class GeneralizedRCNNTransform(nn.Module):
self.image_std = image_std
def forward(self, images, targets=None):
# type: (List[Tensor], Optional[List[Dict[str, Tensor]]])
images = [img for img in images]
for i in range(len(images)):
image = images[i]
target = targets[i] if targets is not None else targets
target_index = targets[i] if targets is not None else None
if image.dim() != 3:
raise ValueError("images is expected to be a list of 3d tensors "
"of shape [C, H, W], got {}".format(image.shape))
image = self.normalize(image)
image, target = self.resize(image, target)
image, target_index = self.resize(image, target_index)
images[i] = image
if targets is not None:
targets[i] = target
if targets is not None and target_index is not None:
targets[i] = target_index
image_sizes = [img.shape[-2:] for img in images]
images = self.batch_images(images)
image_list = ImageList(images, image_sizes)
image_sizes_list = torch.jit.annotate(List[Tuple[int, int]], [])
for image_size in image_sizes:
assert len(image_size) == 2
image_sizes_list.append((image_size[0], image_size[1]))
image_list = ImageList(images, image_sizes_list)
return image_list, targets
def normalize(self, image):
......@@ -55,16 +65,27 @@ class GeneralizedRCNNTransform(nn.Module):
std = torch.as_tensor(self.image_std, dtype=dtype, device=device)
return (image - mean[:, None, None]) / std[:, None, None]
def torch_choice(self, l):
# type: (List[int])
"""
Implements `random.choice` via torch ops so it can be compiled with
TorchScript. Remove if https://github.com/pytorch/pytorch/issues/25803
is fixed.
"""
index = int(torch.empty(1).uniform_(0., float(len(l))).item())
return l[index]
def resize(self, image, target):
# type: (Tensor, Optional[Dict[str, Tensor]])
h, w = image.shape[-2:]
im_shape = torch.tensor(image.shape[-2:])
min_size = float(torch.min(im_shape))
max_size = float(torch.max(im_shape))
if self.training:
size = random.choice(self.min_size)
size = float(self.torch_choice(self.min_size))
else:
# FIXME assume for now that testing uses the largest scale
size = self.min_size[-1]
size = float(self.min_size[-1])
scale_factor = size / min_size
if max_size * scale_factor > self.max_size:
scale_factor = self.max_size / max_size
......@@ -91,7 +112,9 @@ class GeneralizedRCNNTransform(nn.Module):
# _onnx_batch_images() is an implementation of
# batch_images() that is supported by ONNX tracing.
@torch.jit.unused
def _onnx_batch_images(self, images, size_divisible=32):
# type: (List[Tensor], int) -> Tensor
max_size = []
for i in range(images[0].dim()):
max_size_i = torch.max(torch.stack([img.shape[i] for img in images]).to(torch.float32)).to(torch.int64)
......@@ -112,27 +135,36 @@ class GeneralizedRCNNTransform(nn.Module):
return torch.stack(padded_imgs)
def max_by_axis(self, the_list):
# type: (List[List[int]]) -> List[int]
maxes = the_list[0]
for sublist in the_list[1:]:
for index, item in enumerate(sublist):
maxes[index] = max(maxes[index], item)
return maxes
def batch_images(self, images, size_divisible=32):
# type: (List[Tensor], int)
if torchvision._is_tracing():
# batch_images() does not export well to ONNX
# call _onnx_batch_images() instead
return self._onnx_batch_images(images, size_divisible)
max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
stride = size_divisible
max_size = self.max_by_axis([list(img.shape) for img in images])
stride = float(size_divisible)
max_size = list(max_size)
max_size[1] = int(math.ceil(float(max_size[1]) / stride) * stride)
max_size[2] = int(math.ceil(float(max_size[2]) / stride) * stride)
max_size = tuple(max_size)
batch_shape = (len(images),) + max_size
batched_imgs = images[0].new(*batch_shape).zero_()
batch_shape = [len(images)] + max_size
batched_imgs = images[0].new_full(batch_shape, 0)
for img, pad_img in zip(images, batched_imgs):
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
return batched_imgs
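A hedged sketch of what batch_images produces (assuming `transform` is an already-constructed GeneralizedRCNNTransform): variable-size images are zero-padded to a common shape whose spatial dimensions are rounded up to a multiple of size_divisible.

imgs = [torch.rand(3, 200, 300), torch.rand(3, 240, 220)]
batched = transform.batch_images(imgs, size_divisible=32)
print(batched.shape)  # torch.Size([2, 3, 256, 320])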
def postprocess(self, result, image_shapes, original_image_sizes):
# type: (List[Dict[str, Tensor]], List[Tuple[int, int]], List[Tuple[int, int]])
if self.training:
return result
for i, (pred, im_s, o_im_s) in enumerate(zip(result, image_shapes, original_image_sizes)):
......@@ -151,7 +183,8 @@ class GeneralizedRCNNTransform(nn.Module):
def resize_keypoints(keypoints, original_size, new_size):
ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(new_size, original_size))
# type: (Tensor, List[int], List[int])
ratios = [float(s) / float(s_orig) for s, s_orig in zip(new_size, original_size)]
ratio_h, ratio_w = ratios
resized_data = keypoints.clone()
if torch._C._get_tracing_state():
......@@ -165,7 +198,8 @@ def resize_keypoints(keypoints, original_size, new_size):
def resize_boxes(boxes, original_size, new_size):
ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(new_size, original_size))
# type: (Tensor, List[int], List[int])
ratios = [float(s) / float(s_orig) for s, s_orig in zip(new_size, original_size)]
ratio_height, ratio_width = ratios
xmin, ymin, xmax, ymax = boxes.unbind(1)
......
from .boxes import nms, box_iou
from .new_empty_tensor import _new_empty_tensor
from .roi_align import roi_align, RoIAlign
from .roi_pool import roi_pool, RoIPool
from .ps_roi_align import ps_roi_align, PSRoIAlign
......@@ -12,7 +13,7 @@ _register_custom_op()
__all__ = [
'nms', 'roi_align', 'RoIAlign', 'roi_pool', 'RoIPool',
'nms', 'roi_align', 'RoIAlign', 'roi_pool', 'RoIPool', '_new_empty_tensor',
'ps_roi_align', 'PSRoIAlign', 'ps_roi_pool', 'PSRoIPool',
'MultiScaleRoIAlign', 'FeaturePyramidNetwork'
]
from __future__ import division
import torch
from torch.jit.annotations import Tuple
from torch import Tensor
def nms(boxes, scores, iou_threshold):
# type: (Tensor, Tensor, float)
"""
Performs non-maximum suppression (NMS) on the boxes according
to their intersection-over-union (IoU).
......@@ -32,6 +37,7 @@ def nms(boxes, scores, iou_threshold):
def batched_nms(boxes, scores, idxs, iou_threshold):
# type: (Tensor, Tensor, Tensor, float)
"""
Performs non-maximum suppression in a batched fashion.
......@@ -72,12 +78,13 @@ def batched_nms(boxes, scores, idxs, iou_threshold):
def remove_small_boxes(boxes, min_size):
# type: (Tensor, float)
"""
Remove boxes which contain at least one side smaller than min_size.
Arguments:
boxes (Tensor[N, 4]): boxes in (x1, y1, x2, y2) format
min_size (int): minimum size
min_size (float): minimum size
Returns:
keep (Tensor[K]): indices of the boxes that have both sides
......@@ -90,6 +97,7 @@ def remove_small_boxes(boxes, min_size):
def clip_boxes_to_image(boxes, size):
# type: (Tensor, Tuple[int, int])
"""
Clip boxes so that they lie inside an image of size `size`.
......
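With the type annotations added in this file, the box utilities themselves compile under TorchScript; a hedged usage sketch:

import torch
from torchvision.ops import boxes as box_ops

scripted_nms = torch.jit.script(box_ops.nms)
boxes = torch.tensor([[0., 0., 10., 10.], [1., 1., 11., 11.]])
scores = torch.tensor([0.9, 0.8])
keep = scripted_nms(boxes, scores, 0.5)  # indices of the boxes kept by NMS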