Commit ccd1b27d authored by Francisco Massa, committed by GitHub

Add Faster R-CNN and Mask R-CNN (#898)

* [Remove] Use stride in 1x1 in resnet

This is temporary

* Move files to torchvision

Inference works

* Now seems to give same results

Was using the wrong number of total iterations in the end...

* Distributed evaluation seems to work

* Factor out transforms into its own file

* Enabling horizontal flips

* MultiStepLR and preparing for launches

* Add warmup

* Clip gt boxes to images

Seems to be crucial to avoid divergence. Also reduces the losses over different processes for better logging

* Single-GPU batch-size 1 of CocoEvaluator works

* Multi-GPU CocoEvaluator works

Gives the exact same results as the other one, and also supports batch size > 1

* Silence prints from pycocotools

* Comment out code not needed for the run

* Fixes

* Improvements and cleanups

* Remove scales from Pooler

It was not a free parameter, and depended only on the feature map dimensions

* Cleanups

* More cleanups

* Add misc ops and totally remove maskrcnn_benchmark

* nit

* Move Pooler to ops

* Make FPN slightly more generic

* Minor improvements for FPN

* Move FPN to ops

* Move functions to utils

* Lint fixes

* More lint

* Minor cleanups

* Add FasterRCNN

* Remove modifications to resnet

* Fixes for Python2

* More lint fixes

* Add aspect ratio grouping

* Move functions around

* Make evaluation use all images for mAP, even those without annotations

* Bugfix with DDP introduced in last commit

* [Check] Remove category mapping

* Lint

* Make GroupedBatchSampler prioritize largest clusters in the end of iteration

* Bugfix for selecting the iou_types during evaluation

Also switch to using the torchvision normalization from now on, given that we are using torchvision base models

* More lint

* Add barrier after init_process_group

Better safe than sorry

* Make evaluation only use one CPU thread per process

When doing multi-gpu evaluation, paste_masks_in_image is multithreaded and throttles evaluation altogether. Also change default for aspect ratio group to match Detectron

* Fix bug in GroupedBatchSampler

After the first epoch, the number of batch elements could be larger than batch_size, because they got accumulated from the previous iteration. Fix this and also rename some variables for more clarity

* Start adding KeypointRCNN

Currently runs and performs inference; still need to do full training

* Remove use of opencv in keypoint inference

PyTorch 1.1 adds support for bicubic interpolation which matches opencv (except for empty boxes, where one of the dimensions is 1, but that's fine)

* Remove Masker

Towards having mask postprocessing done inside the model

* Bugfixes in previous change plus cleanups

* Preparing to run keypoint training

* Zero initialize bias for mask heads

* Minor improvements on print

* Towards moving resize to model

Also remove class mapping specific to COCO

* Remove zero init in bias for mask head

Checking if it decreased accuracy

* [CHECK] See if this change brings back expected accuracy

* Cleanups on model and training script

* Remove BatchCollator

* Some cleanups in coco_eval

* Move postprocess to transform

* Revert back scaling and start adding conversion to coco api

The scaling didn't seem to matter

* Use decorator instead of context manager in evaluate

* Move training and evaluation functions to a separate file

Also adds support for obtaining a coco API object from our dataset

* Remove unused code

* Update location of lr_scheduler

Its behavior has changed in PyTorch 1.1

* Remove debug code

* Typo

* Bugfix

* Move image normalization to model

* Remove legacy tensor constructors

Also move away from Int and instead use int64

* Bugfix in MultiscaleRoiAlign

* Move transforms to its own file

* Add missing file

* Lint

* More lint

* Add some basic test for detection models

* More lint
from __future__ import division

"""
Helpers that support empty tensors on some nn functions.

Ideally, add support directly in PyTorch to empty tensors in
those functions.

This can be removed once https://github.com/pytorch/pytorch/issues/12013
is implemented
"""

import math
import torch
from torch.nn.modules.utils import _ntuple


class _NewEmptyTensorOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, new_shape):
        ctx.shape = x.shape
        return x.new_empty(new_shape)

    @staticmethod
    def backward(ctx, grad):
        shape = ctx.shape
        return _NewEmptyTensorOp.apply(grad, shape), None


class Conv2d(torch.nn.Conv2d):
    def forward(self, x):
        if x.numel() > 0:
            return super(Conv2d, self).forward(x)
        # get output shape
        output_shape = [
            (i + 2 * p - (di * (k - 1) + 1)) // d + 1
            for i, p, di, k, d in zip(
                x.shape[-2:], self.padding, self.dilation, self.kernel_size, self.stride
            )
        ]
        output_shape = [x.shape[0], self.weight.shape[0]] + output_shape
        return _NewEmptyTensorOp.apply(x, output_shape)
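

def _empty_conv_example():
    # Hypothetical helper, added for illustration only (not part of the original
    # commit): with a zero-element batch, the Conv2d subclass above skips the real
    # convolution and only returns an empty tensor with the correct output shape,
    # so detection heads keep working for images that produced no proposals.
    conv = Conv2d(16, 32, kernel_size=3, padding=1)
    empty = torch.empty(0, 16, 14, 14)
    out = conv(empty)
    assert out.shape == (0, 32, 14, 14)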


class ConvTranspose2d(torch.nn.ConvTranspose2d):
    def forward(self, x):
        if x.numel() > 0:
            return super(ConvTranspose2d, self).forward(x)
        # get output shape
        output_shape = [
            (i - 1) * d - 2 * p + (di * (k - 1) + 1) + op
            for i, p, di, k, d, op in zip(
                x.shape[-2:],
                self.padding,
                self.dilation,
                self.kernel_size,
                self.stride,
                self.output_padding,
            )
        ]
        output_shape = [x.shape[0], self.bias.shape[0]] + output_shape
        return _NewEmptyTensorOp.apply(x, output_shape)


class BatchNorm2d(torch.nn.BatchNorm2d):
    def forward(self, x):
        if x.numel() > 0:
            return super(BatchNorm2d, self).forward(x)
        # get output shape
        output_shape = x.shape
        return _NewEmptyTensorOp.apply(x, output_shape)


def interpolate(
    input, size=None, scale_factor=None, mode="nearest", align_corners=None
):
    if input.numel() > 0:
        return torch.nn.functional.interpolate(
            input, size, scale_factor, mode, align_corners
        )

    def _check_size_scale_factor(dim):
        if size is None and scale_factor is None:
            raise ValueError("either size or scale_factor should be defined")
        if size is not None and scale_factor is not None:
            raise ValueError("only one of size or scale_factor should be defined")
        if (
            scale_factor is not None and
            isinstance(scale_factor, tuple) and
            len(scale_factor) != dim
        ):
            raise ValueError(
                "scale_factor shape must match input shape. "
                "Input is {}D, scale_factor size is {}".format(dim, len(scale_factor))
            )

    def _output_size(dim):
        _check_size_scale_factor(dim)
        if size is not None:
            return size
        scale_factors = _ntuple(dim)(scale_factor)
        # math.floor might return float in py2.7
        return [
            int(math.floor(input.size(i + 2) * scale_factors[i])) for i in range(dim)
        ]

    output_shape = tuple(_output_size(2))
    output_shape = input.shape[:-2] + output_shape
    return _NewEmptyTensorOp.apply(input, output_shape)
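

def _empty_interpolate_example():
    # Hypothetical helper, added for illustration only (not part of the original
    # commit): when the input has no elements, the interpolate wrapper above
    # falls back to pure shape inference instead of calling F.interpolate.
    empty = torch.empty(0, 3, 10, 10)
    out = interpolate(empty, scale_factor=2, mode="nearest")
    assert out.shape == (0, 3, 20, 20)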


# This is not in nn
class FrozenBatchNorm2d(torch.jit.ScriptModule):
    """
    BatchNorm2d where the batch statistics and the affine parameters
    are fixed
    """

    def __init__(self, n):
        super(FrozenBatchNorm2d, self).__init__()
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    @torch.jit.script_method
    def forward(self, x):
        # move reshapes to the beginning
        # to make it fuser-friendly
        w = self.weight.reshape(1, -1, 1, 1)
        b = self.bias.reshape(1, -1, 1, 1)
        rv = self.running_var.reshape(1, -1, 1, 1)
        rm = self.running_mean.reshape(1, -1, 1, 1)
        scale = w * rv.rsqrt()
        bias = b - rm * scale
        return x * scale + bias
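

def _frozen_bn_example():
    # Hypothetical helper, added for illustration only (not part of the original
    # commit): with freshly initialized buffers (weight=1, bias=0, running_mean=0,
    # running_var=1) the forward computes x * 1 + 0, i.e. the identity; when a
    # pretrained backbone is loaded, the buffers hold the frozen statistics instead.
    bn = FrozenBatchNorm2d(8)
    x = torch.randn(2, 8, 4, 4)
    assert torch.allclose(bn(x), x)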


# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch
import torch.nn.functional as F
from torch import nn

from torchvision.ops import roi_align
from torchvision.ops.boxes import box_area


class LevelMapper(object):
    """Determine which FPN level each RoI in a set of RoIs should map to based
    on the heuristic in the FPN paper.
    """

    def __init__(self, k_min, k_max, canonical_scale=224, canonical_level=4, eps=1e-6):
        """
        Arguments:
            k_min (int)
            k_max (int)
            canonical_scale (int)
            canonical_level (int)
            eps (float)
        """
        self.k_min = k_min
        self.k_max = k_max
        self.s0 = canonical_scale
        self.lvl0 = canonical_level
        self.eps = eps

    def __call__(self, boxlists):
        """
        Arguments:
            boxlists (list[Tensor[N, 4]]): boxes for each image, in (x1, y1, x2, y2) format
        """
        # Compute level ids
        s = torch.sqrt(torch.cat([box_area(boxlist) for boxlist in boxlists]))

        # Eqn.(1) in FPN paper
        target_lvls = torch.floor(self.lvl0 + torch.log2(s / self.s0 + self.eps))
        target_lvls = torch.clamp(target_lvls, min=self.k_min, max=self.k_max)
        return target_lvls.to(torch.int64) - self.k_min
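

def _level_mapper_example():
    # Hypothetical helper, added for illustration only (not part of the original
    # commit): a worked example of Eqn.(1). With the defaults (canonical_scale=224,
    # canonical_level=4), a 224x224 box maps to level 4 and a 112x112 box to
    # floor(4 + log2(0.5)) = 3; with k_min=2 these become level indices 2 and 1.
    mapper = LevelMapper(k_min=2, k_max=5)
    boxes = [torch.tensor([[0.0, 0.0, 224.0, 224.0], [0.0, 0.0, 112.0, 112.0]])]
    assert mapper(boxes).tolist() == [2, 1]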


class MultiScaleRoIAlign(nn.Module):
    """
    Pooler for detection with or without FPN.

    It currently hard-codes ROIAlign in the implementation,
    but that can be made more generic later on.

    The scales do not need to be passed explicitly: they are inferred from the
    sizes of the feature maps and of the original images the first time the
    module is called.
    """

    def __init__(self, featmap_names, output_size, sampling_ratio):
        """
        Arguments:
            featmap_names (list): keys of the feature maps (from the OrderedDict
                passed to forward) that will be used for pooling
            output_size (list[tuple[int]] or list[int]): output size for the pooled region
            sampling_ratio (int): sampling ratio for ROIAlign
        """
        super(MultiScaleRoIAlign, self).__init__()
        if isinstance(output_size, int):
            output_size = (output_size, output_size)
        self.featmap_names = featmap_names
        self.sampling_ratio = sampling_ratio
        self.output_size = tuple(output_size)
        self.scales = None
        self.map_levels = None

    def convert_to_roi_format(self, boxes):
        concat_boxes = torch.cat(boxes, dim=0)
        device, dtype = concat_boxes.device, concat_boxes.dtype
        ids = torch.cat(
            [
                torch.full((len(b), 1), i, dtype=dtype, device=device)
                for i, b in enumerate(boxes)
            ],
            dim=0,
        )
        rois = torch.cat([ids, concat_boxes], dim=1)
        return rois

    def infer_scale(self, feature, original_size):
        # assumption: the scale is of the form 2 ** (-k), with k integer
        size = feature.shape[-2:]
        possible_scales = []
        for s1, s2 in zip(size, original_size):
            approx_scale = float(s1) / s2
            scale = 2 ** torch.tensor(approx_scale).log2().round().item()
            possible_scales.append(scale)
        assert possible_scales[0] == possible_scales[1]
        return possible_scales[0]

    def setup_scales(self, features, image_shapes):
        original_input_shape = tuple(max(s) for s in zip(*image_shapes))
        scales = [self.infer_scale(feat, original_input_shape) for feat in features]
        # get the levels in the feature map by leveraging the fact that the network always
        # downsamples by a factor of 2 at each level.
        lvl_min = -torch.log2(torch.tensor(scales[0], dtype=torch.float32)).item()
        lvl_max = -torch.log2(torch.tensor(scales[-1], dtype=torch.float32)).item()
        self.scales = scales
        self.map_levels = LevelMapper(lvl_min, lvl_max)

    def forward(self, x, boxes, image_shapes):
        """
        Arguments:
            x (OrderedDict[Tensor]): feature maps for each level
            boxes (list[Tensor[N, 4]]): boxes to be used to perform the pooling
                operation, in (x1, y1, x2, y2) format, one tensor per image
            image_shapes (list[tuple[int, int]]): sizes of the images the feature
                maps were computed from, used to infer the pooling scales
        Returns:
            result (Tensor): pooled features of shape (K, C) + output_size, where
                K is the total number of boxes
        """
        x = [v for k, v in x.items() if k in self.featmap_names]
        num_levels = len(x)
        rois = self.convert_to_roi_format(boxes)
        if self.scales is None:
            self.setup_scales(x, image_shapes)
        if num_levels == 1:
            return roi_align(
                x[0], rois,
                output_size=self.output_size,
                spatial_scale=self.scales[0],
                sampling_ratio=self.sampling_ratio
            )

        levels = self.map_levels(boxes)

        num_rois = len(rois)
        num_channels = x[0].shape[1]
        dtype, device = x[0].dtype, x[0].device
        result = torch.zeros(
            (num_rois, num_channels,) + self.output_size,
            dtype=dtype,
            device=device,
        )
        for level, (per_level_feature, scale) in enumerate(zip(x, self.scales)):
            idx_in_level = torch.nonzero(levels == level).squeeze(1)
            rois_per_level = rois[idx_in_level]
            result[idx_in_level] = roi_align(
                per_level_feature, rois_per_level,
                output_size=self.output_size,
                spatial_scale=scale, sampling_ratio=self.sampling_ratio
            )
        return result
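

if __name__ == "__main__":
    # Illustrative usage sketch, not part of the original commit. The feature map
    # names, shapes and strides below are made-up values for the example, not
    # anything torchvision itself relies on.
    from collections import OrderedDict

    pooler = MultiScaleRoIAlign(featmap_names=[0, 1], output_size=7, sampling_ratio=2)
    image_shapes = [(512, 512)]
    features = OrderedDict([
        (0, torch.rand(1, 256, 128, 128)),  # stride-4 feature map
        (1, torch.rand(1, 256, 64, 64)),    # stride-8 feature map
    ])
    boxes = [torch.tensor([[10.0, 10.0, 100.0, 100.0], [50.0, 50.0, 400.0, 400.0]])]
    pooled = pooler(features, boxes, image_shapes)
    print(pooled.shape)  # torch.Size([2, 256, 7, 7]); smaller boxes map to finer levels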