Unverified Commit 5bb81c8e authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

RetinaNet object detection (take 2) (#2784)



* Add rough implementation of RetinaNet.

* Move AnchorGenerator to a seperate file.

* Move box similarity to Matcher.

* Expose extra blocks in FPN.

* Expose retinanet in __init__.py.

* Use P6 and P7 in FPN for retinanet.

* Use parameters from retinanet for anchor generation.

* General fixes for retinanet model.

* Implement loss for retinanet heads.

* Output reshaped outputs from retinanet heads.

* Add postprocessing of detections.

* Small fixes.

* Remove unused argument.

* Remove python2 invocation of super.

* Add postprocessing for additional outputs.

* Add missing import of ImageList.

* Remove redundant import.

* Simplify class correction.

* Fix pylint warnings.

* Remove the label adjustment for background class.

* Set default score threshold to 0.05.

* Add weight initialization for regression layer.

* Allow training on images with no annotations.

* Use smooth_l1_loss with beta value.

* Add more typehints for TorchScript conversions.

* Fix linting issues.

* Fix type hints in postprocess_detections.

* Fix type annotations for TorchScript.

* Fix inconsistency with matched_idxs.

* Add retinanet model test.

* Add missing JIT annotations.

* Remove redundant model construction

Make tests pass

* Fix bugs during training on newer PyTorch and unused params in DDP

Needs cleanup and to add back support for images with no annotations

* Cleanup resnet_fpn_backbone

* Use L1 loss for regression

Gives 1mAP improvement over smooth l1

* Disable support for images with no annotations

Need to fix distributed first

* Fix retinanet tests

Need to deduplicate those box checks

* Fix Lint

* Add pretrained model

* Add training info for retinanet
Co-authored-by: default avatarHans Gaiser <hansg91@gmail.com>
Co-authored-by: default avatarHans Gaiser <hans.gaiser@robovalley.com>
Co-authored-by: default avatarHans Gaiser <hans.gaiser@robohouse.com>
parent 42e7f1f0
...@@ -350,6 +350,7 @@ the instances set of COCO train2017 and evaluated on COCO val2017. ...@@ -350,6 +350,7 @@ the instances set of COCO train2017 and evaluated on COCO val2017.
Network box AP mask AP keypoint AP Network box AP mask AP keypoint AP
================================ ======= ======== =========== ================================ ======= ======== ===========
Faster R-CNN ResNet-50 FPN 37.0 - - Faster R-CNN ResNet-50 FPN 37.0 - -
RetinaNet ResNet-50 FPN 36.4 - -
Mask R-CNN ResNet-50 FPN 37.9 34.6 - Mask R-CNN ResNet-50 FPN 37.9 34.6 -
================================ ======= ======== =========== ================================ ======= ======== ===========
...@@ -405,6 +406,7 @@ precision-recall. ...@@ -405,6 +406,7 @@ precision-recall.
Network train time (s / it) test time (s / it) memory (GB) Network train time (s / it) test time (s / it) memory (GB)
============================== =================== ================== =========== ============================== =================== ================== ===========
Faster R-CNN ResNet-50 FPN 0.2288 0.0590 5.2 Faster R-CNN ResNet-50 FPN 0.2288 0.0590 5.2
RetinaNet ResNet-50 FPN 0.2514 0.0939 4.1
Mask R-CNN ResNet-50 FPN 0.2728 0.0903 5.4 Mask R-CNN ResNet-50 FPN 0.2728 0.0903 5.4
Keypoint R-CNN ResNet-50 FPN 0.3789 0.1242 6.8 Keypoint R-CNN ResNet-50 FPN 0.3789 0.1242 6.8
============================== =================== ================== =========== ============================== =================== ================== ===========
...@@ -416,6 +418,12 @@ Faster R-CNN ...@@ -416,6 +418,12 @@ Faster R-CNN
.. autofunction:: torchvision.models.detection.fasterrcnn_resnet50_fpn .. autofunction:: torchvision.models.detection.fasterrcnn_resnet50_fpn
RetinaNet
------------
.. autofunction:: torchvision.models.detection.retinanet_resnet50_fpn
Mask R-CNN Mask R-CNN
---------- ----------
......
...@@ -27,6 +27,13 @@ python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\ ...@@ -27,6 +27,13 @@ python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
--lr-steps 16 22 --aspect-ratio-group-factor 3 --lr-steps 16 22 --aspect-ratio-group-factor 3
``` ```
### RetinaNet
```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
--dataset coco --model retinanet_resnet50_fpn --epochs 26\
--lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01
```
### Mask R-CNN ### Mask R-CNN
``` ```
......
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
from .faster_rcnn import * from .faster_rcnn import *
from .mask_rcnn import * from .mask_rcnn import *
from .keypoint_rcnn import * from .keypoint_rcnn import *
from .retinanet import *
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch
from torch import nn
from torch.jit.annotations import List, Optional, Dict
from .image_list import ImageList
class AnchorGenerator(nn.Module):
"""
Module that generates anchors for a set of feature maps and
image sizes.
The module support computing anchors at multiple sizes and aspect ratios
per feature map. This module assumes aspect ratio = height / width for
each anchor.
sizes and aspect_ratios should have the same number of elements, and it should
correspond to the number of feature maps.
sizes[i] and aspect_ratios[i] can have an arbitrary number of elements,
and AnchorGenerator will output a set of sizes[i] * aspect_ratios[i] anchors
per spatial location for feature map i.
Arguments:
sizes (Tuple[Tuple[int]]):
aspect_ratios (Tuple[Tuple[float]]):
"""
__annotations__ = {
"cell_anchors": Optional[List[torch.Tensor]],
"_cache": Dict[str, List[torch.Tensor]]
}
def __init__(
self,
sizes=((128, 256, 512),),
aspect_ratios=((0.5, 1.0, 2.0),),
):
super(AnchorGenerator, self).__init__()
if not isinstance(sizes[0], (list, tuple)):
# TODO change this
sizes = tuple((s,) for s in sizes)
if not isinstance(aspect_ratios[0], (list, tuple)):
aspect_ratios = (aspect_ratios,) * len(sizes)
assert len(sizes) == len(aspect_ratios)
self.sizes = sizes
self.aspect_ratios = aspect_ratios
self.cell_anchors = None
self._cache = {}
# TODO: https://github.com/pytorch/pytorch/issues/26792
# For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values.
# (scales, aspect_ratios) are usually an element of zip(self.scales, self.aspect_ratios)
# This method assumes aspect ratio = height / width for an anchor.
def generate_anchors(self, scales, aspect_ratios, dtype=torch.float32, device="cpu"):
# type: (List[int], List[float], int, Device) -> Tensor # noqa: F821
scales = torch.as_tensor(scales, dtype=dtype, device=device)
aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device)
h_ratios = torch.sqrt(aspect_ratios)
w_ratios = 1 / h_ratios
ws = (w_ratios[:, None] * scales[None, :]).view(-1)
hs = (h_ratios[:, None] * scales[None, :]).view(-1)
base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1) / 2
return base_anchors.round()
def set_cell_anchors(self, dtype, device):
# type: (int, Device) -> None # noqa: F821
if self.cell_anchors is not None:
cell_anchors = self.cell_anchors
assert cell_anchors is not None
# suppose that all anchors have the same device
# which is a valid assumption in the current state of the codebase
if cell_anchors[0].device == device:
return
cell_anchors = [
self.generate_anchors(
sizes,
aspect_ratios,
dtype,
device
)
for sizes, aspect_ratios in zip(self.sizes, self.aspect_ratios)
]
self.cell_anchors = cell_anchors
def num_anchors_per_location(self):
return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)]
# For every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:2),
# output g[i] anchors that are s[i] distance apart in direction i, with the same dimensions as a.
def grid_anchors(self, grid_sizes, strides):
# type: (List[List[int]], List[List[Tensor]]) -> List[Tensor]
anchors = []
cell_anchors = self.cell_anchors
assert cell_anchors is not None
assert len(grid_sizes) == len(strides) == len(cell_anchors)
for size, stride, base_anchors in zip(
grid_sizes, strides, cell_anchors
):
grid_height, grid_width = size
stride_height, stride_width = stride
device = base_anchors.device
# For output anchor, compute [x_center, y_center, x_center, y_center]
shifts_x = torch.arange(
0, grid_width, dtype=torch.float32, device=device
) * stride_width
shifts_y = torch.arange(
0, grid_height, dtype=torch.float32, device=device
) * stride_height
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
shift_x = shift_x.reshape(-1)
shift_y = shift_y.reshape(-1)
shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)
# For every (base anchor, output anchor) pair,
# offset each zero-centered base anchor by the center of the output anchor.
anchors.append(
(shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4)
)
return anchors
def cached_grid_anchors(self, grid_sizes, strides):
# type: (List[List[int]], List[List[Tensor]]) -> List[Tensor]
key = str(grid_sizes) + str(strides)
if key in self._cache:
return self._cache[key]
anchors = self.grid_anchors(grid_sizes, strides)
self._cache[key] = anchors
return anchors
def forward(self, image_list, feature_maps):
# type: (ImageList, List[Tensor]) -> List[Tensor]
grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps])
image_size = image_list.tensors.shape[-2:]
dtype, device = feature_maps[0].dtype, feature_maps[0].device
strides = [[torch.tensor(image_size[0] // g[0], dtype=torch.int64, device=device),
torch.tensor(image_size[1] // g[1], dtype=torch.int64, device=device)] for g in grid_sizes]
self.set_cell_anchors(dtype, device)
anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides)
anchors = torch.jit.annotate(List[List[torch.Tensor]], [])
for i, (image_height, image_width) in enumerate(image_list.image_sizes):
anchors_in_image = []
for anchors_per_feature_map in anchors_over_all_feature_maps:
anchors_in_image.append(anchors_per_feature_map)
anchors.append(anchors_in_image)
anchors = [torch.cat(anchors_per_image) for anchors_per_image in anchors]
# Clear the cache in case that memory leaks.
self._cache.clear()
return anchors
...@@ -25,13 +25,17 @@ class BackboneWithFPN(nn.Module): ...@@ -25,13 +25,17 @@ class BackboneWithFPN(nn.Module):
Attributes: Attributes:
out_channels (int): the number of channels in the FPN out_channels (int): the number of channels in the FPN
""" """
def __init__(self, backbone, return_layers, in_channels_list, out_channels): def __init__(self, backbone, return_layers, in_channels_list, out_channels, extra_blocks=None):
super(BackboneWithFPN, self).__init__() super(BackboneWithFPN, self).__init__()
if extra_blocks is None:
extra_blocks = LastLevelMaxPool()
self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
self.fpn = FeaturePyramidNetwork( self.fpn = FeaturePyramidNetwork(
in_channels_list=in_channels_list, in_channels_list=in_channels_list,
out_channels=out_channels, out_channels=out_channels,
extra_blocks=LastLevelMaxPool(), extra_blocks=extra_blocks,
) )
self.out_channels = out_channels self.out_channels = out_channels
...@@ -41,7 +45,14 @@ class BackboneWithFPN(nn.Module): ...@@ -41,7 +45,14 @@ class BackboneWithFPN(nn.Module):
return x return x
def resnet_fpn_backbone(backbone_name, pretrained, norm_layer=misc_nn_ops.FrozenBatchNorm2d, trainable_layers=3): def resnet_fpn_backbone(
backbone_name,
pretrained,
norm_layer=misc_nn_ops.FrozenBatchNorm2d,
trainable_layers=3,
returned_layers=None,
extra_blocks=None
):
""" """
Constructs a specified ResNet backbone with FPN on top. Freezes the specified number of layers in the backbone. Constructs a specified ResNet backbone with FPN on top. Freezes the specified number of layers in the backbone.
...@@ -82,14 +93,15 @@ def resnet_fpn_backbone(backbone_name, pretrained, norm_layer=misc_nn_ops.Frozen ...@@ -82,14 +93,15 @@ def resnet_fpn_backbone(backbone_name, pretrained, norm_layer=misc_nn_ops.Frozen
if all([not name.startswith(layer) for layer in layers_to_train]): if all([not name.startswith(layer) for layer in layers_to_train]):
parameter.requires_grad_(False) parameter.requires_grad_(False)
return_layers = {'layer1': '0', 'layer2': '1', 'layer3': '2', 'layer4': '3'} if extra_blocks is None:
extra_blocks = LastLevelMaxPool()
if returned_layers is None:
returned_layers = [1, 2, 3, 4]
assert min(returned_layers) > 0 and max(returned_layers) < 5
return_layers = {f'layer{k}': str(v) for v, k in enumerate(returned_layers)}
in_channels_stage2 = backbone.inplanes // 8 in_channels_stage2 = backbone.inplanes // 8
in_channels_list = [ in_channels_list = [in_channels_stage2 * 2 ** (i - 1) for i in returned_layers]
in_channels_stage2,
in_channels_stage2 * 2,
in_channels_stage2 * 4,
in_channels_stage2 * 8,
]
out_channels = 256 out_channels = 256
return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels) return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks)
...@@ -9,8 +9,9 @@ from torchvision.ops import MultiScaleRoIAlign ...@@ -9,8 +9,9 @@ from torchvision.ops import MultiScaleRoIAlign
from ..utils import load_state_dict_from_url from ..utils import load_state_dict_from_url
from .anchor_utils import AnchorGenerator
from .generalized_rcnn import GeneralizedRCNN from .generalized_rcnn import GeneralizedRCNN
from .rpn import AnchorGenerator, RPNHead, RegionProposalNetwork from .rpn import RPNHead, RegionProposalNetwork
from .roi_heads import RoIHeads from .roi_heads import RoIHeads
from .transform import GeneralizedRCNNTransform from .transform import GeneralizedRCNNTransform
from .backbone_utils import resnet_fpn_backbone from .backbone_utils import resnet_fpn_backbone
......
...@@ -103,7 +103,7 @@ class KeypointRCNN(FasterRCNN): ...@@ -103,7 +103,7 @@ class KeypointRCNN(FasterRCNN):
>>> import torch >>> import torch
>>> import torchvision >>> import torchvision
>>> from torchvision.models.detection import KeypointRCNN >>> from torchvision.models.detection import KeypointRCNN
>>> from torchvision.models.detection.rpn import AnchorGenerator >>> from torchvision.models.detection.anchor_utils import AnchorGenerator
>>> >>>
>>> # load a pre-trained model for classification and return >>> # load a pre-trained model for classification and return
>>> # only the features >>> # only the features
......
...@@ -107,7 +107,7 @@ class MaskRCNN(FasterRCNN): ...@@ -107,7 +107,7 @@ class MaskRCNN(FasterRCNN):
>>> import torch >>> import torch
>>> import torchvision >>> import torchvision
>>> from torchvision.models.detection import MaskRCNN >>> from torchvision.models.detection import MaskRCNN
>>> from torchvision.models.detection.rpn import AnchorGenerator >>> from torchvision.models.detection.anchor_utils import AnchorGenerator
>>> >>>
>>> # load a pre-trained model for classification and return >>> # load a pre-trained model for classification and return
>>> # only the features >>> # only the features
......
import math
from collections import OrderedDict
import warnings
import torch
import torch.nn as nn
from torch import Tensor
from torch.jit.annotations import Dict, List, Tuple
from ..utils import load_state_dict_from_url
from . import _utils as det_utils
from .anchor_utils import AnchorGenerator
from .transform import GeneralizedRCNNTransform
from .backbone_utils import resnet_fpn_backbone
from ...ops.feature_pyramid_network import LastLevelP6P7
from ...ops import sigmoid_focal_loss
from ...ops import boxes as box_ops
__all__ = [
"RetinaNet", "retinanet_resnet50_fpn",
]
def _sum(x: List[Tensor]) -> Tensor:
res = x[0]
for i in x[1:]:
res = res + i
return res
class RetinaNetHead(nn.Module):
"""
A regression and classification head for use in RetinaNet.
Arguments:
in_channels (int): number of channels of the input feature
num_anchors (int): number of anchors to be predicted
num_classes (int): number of classes to be predicted
"""
def __init__(self, in_channels, num_anchors, num_classes):
super().__init__()
self.classification_head = RetinaNetClassificationHead(in_channels, num_anchors, num_classes)
self.regression_head = RetinaNetRegressionHead(in_channels, num_anchors)
def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
# type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[Tensor]) -> Dict[str, Tensor]
return {
'classification': self.classification_head.compute_loss(targets, head_outputs, matched_idxs),
'bbox_regression': self.regression_head.compute_loss(targets, head_outputs, anchors, matched_idxs),
}
def forward(self, x):
# type: (List[Tensor]) -> Dict[str, Tensor]
return {
'cls_logits': self.classification_head(x),
'bbox_regression': self.regression_head(x)
}
class RetinaNetClassificationHead(nn.Module):
"""
A classification head for use in RetinaNet.
Arguments:
in_channels (int): number of channels of the input feature
num_anchors (int): number of anchors to be predicted
num_classes (int): number of classes to be predicted
"""
def __init__(self, in_channels, num_anchors, num_classes, prior_probability=0.01):
super().__init__()
conv = []
for _ in range(4):
conv.append(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1))
conv.append(nn.ReLU())
self.conv = nn.Sequential(*conv)
for layer in self.conv.children():
if isinstance(layer, nn.Conv2d):
torch.nn.init.normal_(layer.weight, std=0.01)
torch.nn.init.constant_(layer.bias, 0)
self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1)
torch.nn.init.normal_(self.cls_logits.weight, std=0.01)
torch.nn.init.constant_(self.cls_logits.bias, -math.log((1 - prior_probability) / prior_probability))
self.num_classes = num_classes
self.num_anchors = num_anchors
# This is to fix using det_utils.Matcher.BETWEEN_THRESHOLDS in TorchScript.
# TorchScript doesn't support class attributes.
# https://github.com/pytorch/vision/pull/1697#issuecomment-630255584
self.BETWEEN_THRESHOLDS = det_utils.Matcher.BETWEEN_THRESHOLDS
def compute_loss(self, targets, head_outputs, matched_idxs):
# type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor]) -> Tensor
losses = []
cls_logits = head_outputs['cls_logits']
for targets_per_image, cls_logits_per_image, matched_idxs_per_image in zip(targets, cls_logits, matched_idxs):
# determine only the foreground
foreground_idxs_per_image = matched_idxs_per_image >= 0
num_foreground = foreground_idxs_per_image.sum()
# no matched_idxs means there were no annotations in this image
# TODO: enable support for images without annotations that works on distributed
if False: # matched_idxs_per_image.numel() == 0:
gt_classes_target = torch.zeros_like(cls_logits_per_image)
valid_idxs_per_image = torch.arange(cls_logits_per_image.shape[0])
else:
# create the target classification
gt_classes_target = torch.zeros_like(cls_logits_per_image)
gt_classes_target[
foreground_idxs_per_image,
targets_per_image['labels'][matched_idxs_per_image[foreground_idxs_per_image]]
] = 1.0
# find indices for which anchors should be ignored
valid_idxs_per_image = matched_idxs_per_image != self.BETWEEN_THRESHOLDS
# compute the classification loss
losses.append(sigmoid_focal_loss(
cls_logits_per_image[valid_idxs_per_image],
gt_classes_target[valid_idxs_per_image],
reduction='sum',
) / max(1, num_foreground))
return _sum(losses) / len(targets)
def forward(self, x):
# type: (List[Tensor]) -> Tensor
all_cls_logits = []
for features in x:
cls_logits = self.conv(features)
cls_logits = self.cls_logits(cls_logits)
# Permute classification output from (N, A * K, H, W) to (N, HWA, K).
N, _, H, W = cls_logits.shape
cls_logits = cls_logits.view(N, -1, self.num_classes, H, W)
cls_logits = cls_logits.permute(0, 3, 4, 1, 2)
cls_logits = cls_logits.reshape(N, -1, self.num_classes) # Size=(N, HWA, 4)
all_cls_logits.append(cls_logits)
return torch.cat(all_cls_logits, dim=1)
class RetinaNetRegressionHead(nn.Module):
"""
A regression head for use in RetinaNet.
Arguments:
in_channels (int): number of channels of the input feature
num_anchors (int): number of anchors to be predicted
"""
__annotations__ = {
'box_coder': det_utils.BoxCoder,
}
def __init__(self, in_channels, num_anchors):
super().__init__()
conv = []
for _ in range(4):
conv.append(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1))
conv.append(nn.ReLU())
self.conv = nn.Sequential(*conv)
self.bbox_reg = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, stride=1, padding=1)
torch.nn.init.normal_(self.bbox_reg.weight, std=0.01)
torch.nn.init.zeros_(self.bbox_reg.bias)
for layer in self.conv.children():
if isinstance(layer, nn.Conv2d):
torch.nn.init.normal_(layer.weight, std=0.01)
torch.nn.init.zeros_(layer.bias)
self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
# type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[Tensor]) -> Tensor
losses = []
bbox_regression = head_outputs['bbox_regression']
for targets_per_image, bbox_regression_per_image, anchors_per_image, matched_idxs_per_image in \
zip(targets, bbox_regression, anchors, matched_idxs):
# no matched_idxs means there were no annotations in this image
# TODO enable support for images without annotations with distributed support
# if matched_idxs_per_image.numel() == 0:
# continue
# get the targets corresponding GT for each proposal
# NB: need to clamp the indices because we can have a single
# GT in the image, and matched_idxs can be -2, which goes
# out of bounds
matched_gt_boxes_per_image = targets_per_image['boxes'][matched_idxs_per_image.clamp(min=0)]
# determine only the foreground indices, ignore the rest
foreground_idxs_per_image = matched_idxs_per_image >= 0
num_foreground = foreground_idxs_per_image.sum()
# select only the foreground boxes
matched_gt_boxes_per_image = matched_gt_boxes_per_image[foreground_idxs_per_image, :]
bbox_regression_per_image = bbox_regression_per_image[foreground_idxs_per_image, :]
anchors_per_image = anchors_per_image[foreground_idxs_per_image, :]
# compute the regression targets
target_regression = self.box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
# compute the loss
losses.append(torch.nn.functional.l1_loss(
bbox_regression_per_image,
target_regression,
size_average=False
) / max(1, num_foreground))
return _sum(losses) / max(1, len(targets))
def forward(self, x):
# type: (List[Tensor]) -> Tensor
all_bbox_regression = []
for features in x:
bbox_regression = self.conv(features)
bbox_regression = self.bbox_reg(bbox_regression)
# Permute bbox regression output from (N, 4 * A, H, W) to (N, HWA, 4).
N, _, H, W = bbox_regression.shape
bbox_regression = bbox_regression.view(N, -1, 4, H, W)
bbox_regression = bbox_regression.permute(0, 3, 4, 1, 2)
bbox_regression = bbox_regression.reshape(N, -1, 4) # Size=(N, HWA, 4)
all_bbox_regression.append(bbox_regression)
return torch.cat(all_bbox_regression, dim=1)
class RetinaNet(nn.Module):
"""
Implements RetinaNet.
The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
image, and should be in 0-1 range. Different images can have different sizes.
The behavior of the model changes depending if it is in training or evaluation mode.
During training, the model expects both the input tensors, as well as a targets (list of dictionary),
containing:
- boxes (FloatTensor[N, 4]): the ground-truth boxes in [x1, y1, x2, y2] format, with values
between 0 and H and 0 and W
- labels (Int64Tensor[N]): the class label for each ground-truth box
The model returns a Dict[Tensor] during training, containing the classification and regression
losses.
During inference, the model requires only the input tensors, and returns the post-processed
predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
follows:
- boxes (FloatTensor[N, 4]): the predicted boxes in [x1, y1, x2, y2] format, with values between
0 and H and 0 and W
- labels (Int64Tensor[N]): the predicted labels for each image
- scores (Tensor[N]): the scores for each prediction
Arguments:
backbone (nn.Module): the network used to compute the features for the model.
It should contain an out_channels attribute, which indicates the number of output
channels that each feature map has (and it should be the same for all feature maps).
The backbone should return a single Tensor or an OrderedDict[Tensor].
num_classes (int): number of output classes of the model (excluding the background).
min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
image_mean (Tuple[float, float, float]): mean values used for input normalization.
They are generally the mean values of the dataset on which the backbone has been trained
on
image_std (Tuple[float, float, float]): std values used for input normalization.
They are generally the std values of the dataset on which the backbone has been trained on
anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
maps.
head (nn.Module): Module run on top of the feature pyramid.
Defaults to a module containing a classification and regression module.
score_thresh (float): Score threshold used for postprocessing the detections.
nms_thresh (float): NMS threshold used for postprocessing the detections.
detections_per_img (int): Number of best detections to keep after NMS.
fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
considered as positive during training.
bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
considered as negative during training.
Example:
>>> import torch
>>> import torchvision
>>> from torchvision.models.detection import RetinaNet
>>> from torchvision.models.detection.anchor_utils import AnchorGenerator
>>> # load a pre-trained model for classification and return
>>> # only the features
>>> backbone = torchvision.models.mobilenet_v2(pretrained=True).features
>>> # RetinaNet needs to know the number of
>>> # output channels in a backbone. For mobilenet_v2, it's 1280
>>> # so we need to add it here
>>> backbone.out_channels = 1280
>>>
>>> # let's make the network generate 5 x 3 anchors per spatial
>>> # location, with 5 different sizes and 3 different aspect
>>> # ratios. We have a Tuple[Tuple[int]] because each feature
>>> # map could potentially have different sizes and
>>> # aspect ratios
>>> anchor_generator = AnchorGenerator(
>>> sizes=tuple((x, int(x * 2 ** (1.0 / 3)), int(x * 2 ** (2.0 / 3))) for x in [32, 64, 128, 256, 512]),
>>> aspect_ratios=((0.5, 1.0, 2.0),) * 5
>>> )
>>>
>>> # put the pieces together inside a RetinaNet model
>>> model = RetinaNet(backbone,
>>> num_classes=2,
>>> anchor_generator=anchor_generator)
>>> model.eval()
>>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
>>> predictions = model(x)
"""
__annotations__ = {
'box_coder': det_utils.BoxCoder,
'proposal_matcher': det_utils.Matcher,
}
def __init__(self, backbone, num_classes,
# transform parameters
min_size=800, max_size=1333,
image_mean=None, image_std=None,
# Anchor parameters
anchor_generator=None, head=None,
proposal_matcher=None,
score_thresh=0.05,
nms_thresh=0.5,
detections_per_img=300,
fg_iou_thresh=0.5, bg_iou_thresh=0.4):
super().__init__()
if not hasattr(backbone, "out_channels"):
raise ValueError(
"backbone should contain an attribute out_channels "
"specifying the number of output channels (assumed to be the "
"same for all the levels)")
self.backbone = backbone
assert isinstance(anchor_generator, (AnchorGenerator, type(None)))
if anchor_generator is None:
anchor_sizes = tuple((x, int(x * 2 ** (1.0 / 3)), int(x * 2 ** (2.0 / 3))) for x in [32, 64, 128, 256, 512])
aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
anchor_generator = AnchorGenerator(
anchor_sizes, aspect_ratios
)
self.anchor_generator = anchor_generator
if head is None:
head = RetinaNetHead(backbone.out_channels, anchor_generator.num_anchors_per_location()[0], num_classes)
self.head = head
if proposal_matcher is None:
proposal_matcher = det_utils.Matcher(
fg_iou_thresh,
bg_iou_thresh,
allow_low_quality_matches=True,
)
self.proposal_matcher = proposal_matcher
self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
if image_mean is None:
image_mean = [0.485, 0.456, 0.406]
if image_std is None:
image_std = [0.229, 0.224, 0.225]
self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)
self.score_thresh = score_thresh
self.nms_thresh = nms_thresh
self.detections_per_img = detections_per_img
# used only on torchscript mode
self._has_warned = False
@torch.jit.unused
def eager_outputs(self, losses, detections):
# type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
if self.training:
return losses
return detections
def compute_loss(self, targets, head_outputs, anchors):
# type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor]) -> Dict[str, Tensor]
matched_idxs = []
for anchors_per_image, targets_per_image in zip(anchors, targets):
if targets_per_image['boxes'].numel() == 0:
matched_idxs.append(torch.empty((0,), dtype=torch.int32))
continue
match_quality_matrix = box_ops.box_iou(targets_per_image['boxes'], anchors_per_image)
matched_idxs.append(self.proposal_matcher(match_quality_matrix))
return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs)
def postprocess_detections(self, head_outputs, anchors, image_shapes):
# type: (Dict[str, Tensor], List[Tensor], List[Tuple[int, int]]) -> List[Dict[str, Tensor]]
# TODO: Merge this with roi_heads.RoIHeads.postprocess_detections ?
class_logits = head_outputs.pop('cls_logits')
box_regression = head_outputs.pop('bbox_regression')
other_outputs = head_outputs
device = class_logits.device
num_classes = class_logits.shape[-1]
scores = torch.sigmoid(class_logits)
# create labels for each score
labels = torch.arange(num_classes, device=device)
labels = labels.view(1, -1).expand_as(scores)
detections = torch.jit.annotate(List[Dict[str, Tensor]], [])
for index, (box_regression_per_image, scores_per_image, labels_per_image, anchors_per_image, image_shape) in \
enumerate(zip(box_regression, scores, labels, anchors, image_shapes)):
boxes_per_image = self.box_coder.decode_single(box_regression_per_image, anchors_per_image)
boxes_per_image = box_ops.clip_boxes_to_image(boxes_per_image, image_shape)
other_outputs_per_image = [(k, v[index]) for k, v in other_outputs.items()]
image_boxes = []
image_scores = []
image_labels = []
image_other_outputs = torch.jit.annotate(Dict[str, List[Tensor]], {})
for class_index in range(num_classes):
# remove low scoring boxes
inds = torch.gt(scores_per_image[:, class_index], self.score_thresh)
boxes_per_class, scores_per_class, labels_per_class = \
boxes_per_image[inds], scores_per_image[inds, class_index], labels_per_image[inds, class_index]
other_outputs_per_class = [(k, v[inds]) for k, v in other_outputs_per_image]
# remove empty boxes
keep = box_ops.remove_small_boxes(boxes_per_class, min_size=1e-2)
boxes_per_class, scores_per_class, labels_per_class = \
boxes_per_class[keep], scores_per_class[keep], labels_per_class[keep]
other_outputs_per_class = [(k, v[keep]) for k, v in other_outputs_per_class]
# non-maximum suppression, independently done per class
keep = box_ops.nms(boxes_per_class, scores_per_class, self.nms_thresh)
# keep only topk scoring predictions
keep = keep[:self.detections_per_img]
boxes_per_class, scores_per_class, labels_per_class = \
boxes_per_class[keep], scores_per_class[keep], labels_per_class[keep]
other_outputs_per_class = [(k, v[keep]) for k, v in other_outputs_per_class]
image_boxes.append(boxes_per_class)
image_scores.append(scores_per_class)
image_labels.append(labels_per_class)
for k, v in other_outputs_per_class:
if k not in image_other_outputs:
image_other_outputs[k] = []
image_other_outputs[k].append(v)
detections.append({
'boxes': torch.cat(image_boxes, dim=0),
'scores': torch.cat(image_scores, dim=0),
'labels': torch.cat(image_labels, dim=0),
})
for k, v in image_other_outputs.items():
detections[-1].update({k: torch.cat(v, dim=0)})
return detections
def forward(self, images, targets=None):
# type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
"""
Arguments:
images (list[Tensor]): images to be processed
targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional)
Returns:
result (list[BoxList] or dict[Tensor]): the output from the model.
During training, it returns a dict[Tensor] which contains the losses.
During testing, it returns list[BoxList] contains additional fields
like `scores`, `labels` and `mask` (for Mask R-CNN models).
"""
if self.training and targets is None:
raise ValueError("In training mode, targets should be passed")
if self.training:
assert targets is not None
for target in targets:
boxes = target["boxes"]
if isinstance(boxes, torch.Tensor):
if len(boxes.shape) != 2 or boxes.shape[-1] != 4:
raise ValueError("Expected target boxes to be a tensor"
"of shape [N, 4], got {:}.".format(
boxes.shape))
else:
raise ValueError("Expected target boxes to be of type "
"Tensor, got {:}.".format(type(boxes)))
# get the original image sizes
original_image_sizes = torch.jit.annotate(List[Tuple[int, int]], [])
for img in images:
val = img.shape[-2:]
assert len(val) == 2
original_image_sizes.append((val[0], val[1]))
# transform the input
images, targets = self.transform(images, targets)
# Check for degenerate boxes
# TODO: Move this to a function
if targets is not None:
for target_idx, target in enumerate(targets):
boxes = target["boxes"]
degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
if degenerate_boxes.any():
# print the first degenerate box
bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
degen_bb: List[float] = boxes[bb_idx].tolist()
raise ValueError("All bounding boxes should have positive height and width."
" Found invalid box {} for target at index {}."
.format(degen_bb, target_idx))
# get the features from the backbone
features = self.backbone(images.tensors)
if isinstance(features, torch.Tensor):
features = OrderedDict([('0', features)])
# TODO: Do we want a list or a dict?
features = list(features.values())
# compute the retinanet heads outputs using the features
head_outputs = self.head(features)
# create the set of anchors
anchors = self.anchor_generator(images, features)
losses = {}
detections = torch.jit.annotate(List[Dict[str, Tensor]], [])
if self.training:
assert targets is not None
# compute the losses
losses = self.compute_loss(targets, head_outputs, anchors)
else:
# compute the detections
detections = self.postprocess_detections(head_outputs, anchors, images.image_sizes)
detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)
if torch.jit.is_scripting():
if not self._has_warned:
warnings.warn("RetinaNet always returns a (Losses, Detections) tuple in scripting")
self._has_warned = True
return (losses, detections)
return self.eager_outputs(losses, detections)
model_urls = {
'retinanet_resnet50_fpn_coco':
'https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth',
}
def retinanet_resnet50_fpn(pretrained=False, progress=True,
num_classes=91, pretrained_backbone=True, **kwargs):
"""
Constructs a RetinaNet model with a ResNet-50-FPN backbone.
The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
image, and should be in ``0-1`` range. Different images can have different sizes.
The behavior of the model changes depending if it is in training or evaluation mode.
During training, the model expects both the input tensors, as well as a targets (list of dictionary),
containing:
- boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with values
between ``0`` and ``H`` and ``0`` and ``W``
- labels (``Int64Tensor[N]``): the class label for each ground-truth box
The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
losses.
During inference, the model requires only the input tensors, and returns the post-processed
predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
follows:
- boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with values between
``0`` and ``H`` and ``0`` and ``W``
- labels (``Int64Tensor[N]``): the predicted labels for each image
- scores (``Tensor[N]``): the scores or each prediction
Example::
>>> model = torchvision.models.detection.retinanet_resnet50_fpn(pretrained=True)
>>> model.eval()
>>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
>>> predictions = model(x)
Arguments:
pretrained (bool): If True, returns a model pre-trained on COCO train2017
progress (bool): If True, displays a progress bar of the download to stderr
"""
if pretrained:
# no need to download the backbone if pretrained is set
pretrained_backbone = False
# skip P2 because it generates too many anchors (according to their paper)
backbone = resnet_fpn_backbone('resnet50', pretrained_backbone,
returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256))
model = RetinaNet(backbone, num_classes, **kwargs)
if pretrained:
state_dict = load_state_dict_from_url(model_urls['retinanet_resnet50_fpn_coco'],
progress=progress)
model.load_state_dict(state_dict)
return model
...@@ -11,6 +11,9 @@ from .image_list import ImageList ...@@ -11,6 +11,9 @@ from .image_list import ImageList
from torch.jit.annotations import List, Optional, Dict, Tuple from torch.jit.annotations import List, Optional, Dict, Tuple
# Import AnchorGenerator to keep compatibility.
from .anchor_utils import AnchorGenerator
@torch.jit.unused @torch.jit.unused
def _onnx_get_num_anchors_and_pre_nms_top_n(ob, orig_pre_nms_top_n): def _onnx_get_num_anchors_and_pre_nms_top_n(ob, orig_pre_nms_top_n):
...@@ -24,159 +27,6 @@ def _onnx_get_num_anchors_and_pre_nms_top_n(ob, orig_pre_nms_top_n): ...@@ -24,159 +27,6 @@ def _onnx_get_num_anchors_and_pre_nms_top_n(ob, orig_pre_nms_top_n):
return num_anchors, pre_nms_top_n return num_anchors, pre_nms_top_n
class AnchorGenerator(nn.Module):
__annotations__ = {
"cell_anchors": Optional[List[torch.Tensor]],
"_cache": Dict[str, List[torch.Tensor]]
}
"""
Module that generates anchors for a set of feature maps and
image sizes.
The module support computing anchors at multiple sizes and aspect ratios
per feature map. This module assumes aspect ratio = height / width for
each anchor.
sizes and aspect_ratios should have the same number of elements, and it should
correspond to the number of feature maps.
sizes[i] and aspect_ratios[i] can have an arbitrary number of elements,
and AnchorGenerator will output a set of sizes[i] * aspect_ratios[i] anchors
per spatial location for feature map i.
Arguments:
sizes (Tuple[Tuple[int]]):
aspect_ratios (Tuple[Tuple[float]]):
"""
def __init__(
self,
sizes=((128, 256, 512),),
aspect_ratios=((0.5, 1.0, 2.0),),
):
super(AnchorGenerator, self).__init__()
if not isinstance(sizes[0], (list, tuple)):
# TODO change this
sizes = tuple((s,) for s in sizes)
if not isinstance(aspect_ratios[0], (list, tuple)):
aspect_ratios = (aspect_ratios,) * len(sizes)
assert len(sizes) == len(aspect_ratios)
self.sizes = sizes
self.aspect_ratios = aspect_ratios
self.cell_anchors = None
self._cache = {}
# TODO: https://github.com/pytorch/pytorch/issues/26792
# For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values.
# (scales, aspect_ratios) are usually an element of zip(self.scales, self.aspect_ratios)
# This method assumes aspect ratio = height / width for an anchor.
def generate_anchors(self, scales, aspect_ratios, dtype=torch.float32, device="cpu"):
# type: (List[int], List[float], int, Device) -> Tensor # noqa: F821
scales = torch.as_tensor(scales, dtype=dtype, device=device)
aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device)
h_ratios = torch.sqrt(aspect_ratios)
w_ratios = 1 / h_ratios
ws = (w_ratios[:, None] * scales[None, :]).view(-1)
hs = (h_ratios[:, None] * scales[None, :]).view(-1)
base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1) / 2
return base_anchors.round()
def set_cell_anchors(self, dtype, device):
# type: (int, Device) -> None # noqa: F821
if self.cell_anchors is not None:
cell_anchors = self.cell_anchors
assert cell_anchors is not None
# suppose that all anchors have the same device
# which is a valid assumption in the current state of the codebase
if cell_anchors[0].device == device:
return
cell_anchors = [
self.generate_anchors(
sizes,
aspect_ratios,
dtype,
device
)
for sizes, aspect_ratios in zip(self.sizes, self.aspect_ratios)
]
self.cell_anchors = cell_anchors
def num_anchors_per_location(self):
return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)]
# For every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:2),
# output g[i] anchors that are s[i] distance apart in direction i, with the same dimensions as a.
def grid_anchors(self, grid_sizes, strides):
# type: (List[List[int]], List[List[Tensor]]) -> List[Tensor]
anchors = []
cell_anchors = self.cell_anchors
assert cell_anchors is not None
assert len(grid_sizes) == len(strides) == len(cell_anchors)
for size, stride, base_anchors in zip(
grid_sizes, strides, cell_anchors
):
grid_height, grid_width = size
stride_height, stride_width = stride
device = base_anchors.device
# For output anchor, compute [x_center, y_center, x_center, y_center]
shifts_x = torch.arange(
0, grid_width, dtype=torch.float32, device=device
) * stride_width
shifts_y = torch.arange(
0, grid_height, dtype=torch.float32, device=device
) * stride_height
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
shift_x = shift_x.reshape(-1)
shift_y = shift_y.reshape(-1)
shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)
# For every (base anchor, output anchor) pair,
# offset each zero-centered base anchor by the center of the output anchor.
anchors.append(
(shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4)
)
return anchors
def cached_grid_anchors(self, grid_sizes, strides):
# type: (List[List[int]], List[List[Tensor]]) -> List[Tensor]
key = str(grid_sizes) + str(strides)
if key in self._cache:
return self._cache[key]
anchors = self.grid_anchors(grid_sizes, strides)
self._cache[key] = anchors
return anchors
def forward(self, image_list, feature_maps):
# type: (ImageList, List[Tensor]) -> List[Tensor]
grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps])
image_size = image_list.tensors.shape[-2:]
dtype, device = feature_maps[0].dtype, feature_maps[0].device
strides = [[torch.tensor(image_size[0] // g[0], dtype=torch.int64, device=device),
torch.tensor(image_size[1] // g[1], dtype=torch.int64, device=device)] for g in grid_sizes]
self.set_cell_anchors(dtype, device)
anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides)
anchors = torch.jit.annotate(List[List[torch.Tensor]], [])
for i, (image_height, image_width) in enumerate(image_list.image_sizes):
anchors_in_image = []
for anchors_per_feature_map in anchors_over_all_feature_maps:
anchors_in_image.append(anchors_per_feature_map)
anchors.append(anchors_in_image)
anchors = [torch.cat(anchors_per_image) for anchors_per_image in anchors]
# Clear the cache in case that memory leaks.
self._cache.clear()
return anchors
class RPNHead(nn.Module): class RPNHead(nn.Module):
""" """
Adds a simple RPN Head with classification and regression heads Adds a simple RPN Head with classification and regression heads
...@@ -338,7 +188,7 @@ class RegionProposalNetwork(torch.nn.Module): ...@@ -338,7 +188,7 @@ class RegionProposalNetwork(torch.nn.Module):
matched_gt_boxes_per_image = torch.zeros(anchors_per_image.shape, dtype=torch.float32, device=device) matched_gt_boxes_per_image = torch.zeros(anchors_per_image.shape, dtype=torch.float32, device=device)
labels_per_image = torch.zeros((anchors_per_image.shape[0],), dtype=torch.float32, device=device) labels_per_image = torch.zeros((anchors_per_image.shape[0],), dtype=torch.float32, device=device)
else: else:
match_quality_matrix = box_ops.box_iou(gt_boxes, anchors_per_image) match_quality_matrix = self.box_similarity(gt_boxes, anchors_per_image)
matched_idxs = self.proposal_matcher(match_quality_matrix) matched_idxs = self.proposal_matcher(match_quality_matrix)
# get the targets corresponding GT for each proposal # get the targets corresponding GT for each proposal
# NB: need to clamp the indices because we can have a single # NB: need to clamp the indices because we can have a single
......
...@@ -8,6 +8,7 @@ from .ps_roi_align import ps_roi_align, PSRoIAlign ...@@ -8,6 +8,7 @@ from .ps_roi_align import ps_roi_align, PSRoIAlign
from .ps_roi_pool import ps_roi_pool, PSRoIPool from .ps_roi_pool import ps_roi_pool, PSRoIPool
from .poolers import MultiScaleRoIAlign from .poolers import MultiScaleRoIAlign
from .feature_pyramid_network import FeaturePyramidNetwork from .feature_pyramid_network import FeaturePyramidNetwork
from .focal_loss import sigmoid_focal_loss
from ._register_onnx_ops import _register_custom_op from ._register_onnx_ops import _register_custom_op
...@@ -19,5 +20,6 @@ __all__ = [ ...@@ -19,5 +20,6 @@ __all__ = [
'clip_boxes_to_image', 'box_convert', 'clip_boxes_to_image', 'box_convert',
'box_area', 'box_iou', 'generalized_box_iou', 'roi_align', 'RoIAlign', 'roi_pool', 'box_area', 'box_iou', 'generalized_box_iou', 'roi_align', 'RoIAlign', 'roi_pool',
'RoIPool', '_new_empty_tensor', 'ps_roi_align', 'PSRoIAlign', 'ps_roi_pool', 'RoIPool', '_new_empty_tensor', 'ps_roi_align', 'PSRoIAlign', 'ps_roi_pool',
'PSRoIPool', 'MultiScaleRoIAlign', 'FeaturePyramidNetwork' 'PSRoIPool', 'MultiScaleRoIAlign', 'FeaturePyramidNetwork',
'sigmoid_focal_loss'
] ]
import torch
import torch.nn.functional as F
def sigmoid_focal_loss(
inputs,
targets,
alpha: float = 0.25,
gamma: float = 2,
reduction: str = "none",
):
"""
Original implementation from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py .
Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
Args:
inputs: A float tensor of arbitrary shape.
The predictions for each example.
targets: A float tensor with the same shape as inputs. Stores the binary
classification label for each element in inputs
(0 for the negative class and 1 for the positive class).
alpha: (optional) Weighting factor in range (0,1) to balance
positive vs negative examples or -1 for ignore. Default = 0.25
gamma: Exponent of the modulating factor (1 - p_t) to
balance easy vs hard examples.
reduction: 'none' | 'mean' | 'sum'
'none': No reduction will be applied to the output.
'mean': The output will be averaged.
'sum': The output will be summed.
Returns:
Loss tensor with the reduction option applied.
"""
p = torch.sigmoid(inputs)
ce_loss = F.binary_cross_entropy_with_logits(
inputs, targets, reduction="none"
)
p_t = p * targets + (1 - p) * (1 - targets)
loss = ce_loss * ((1 - p_t) ** gamma)
if alpha >= 0:
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
loss = alpha_t * loss
if reduction == "mean":
loss = loss.mean()
elif reduction == "sum":
loss = loss.sum()
return loss
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment