Unverified Commit 5bb81c8e authored by Francisco Massa's avatar Francisco Massa Committed by GitHub

RetinaNet object detection (take 2) (#2784)



* Add rough implementation of RetinaNet.

* Move AnchorGenerator to a separate file.

* Move box similarity to Matcher.

* Expose extra blocks in FPN.

* Expose retinanet in __init__.py.

* Use P6 and P7 in FPN for retinanet.

* Use parameters from retinanet for anchor generation.

* General fixes for retinanet model.

* Implement loss for retinanet heads.

* Output reshaped outputs from retinanet heads.

* Add postprocessing of detections.

* Small fixes.

* Remove unused argument.

* Remove python2 invocation of super.

* Add postprocessing for additional outputs.

* Add missing import of ImageList.

* Remove redundant import.

* Simplify class correction.

* Fix pylint warnings.

* Remove the label adjustment for background class.

* Set default score threshold to 0.05.

* Add weight initialization for regression layer.

* Allow training on images with no annotations.

* Use smooth_l1_loss with beta value.

* Add more typehints for TorchScript conversions.

* Fix linting issues.

* Fix type hints in postprocess_detections.

* Fix type annotations for TorchScript.

* Fix inconsistency with matched_idxs.

* Add retinanet model test.

* Add missing JIT annotations.

* Remove redundant model construction

Make tests pass

* Fix bugs during training on newer PyTorch and unused params in DDP

Needs cleanup and to add back support for images with no annotations

* Cleanup resnet_fpn_backbone

* Use L1 loss for regression

Gives a 1 mAP improvement over smooth L1

* Disable support for images with no annotations

Need to fix distributed first

* Fix retinanet tests

Need to deduplicate those box checks

* Fix Lint

* Add pretrained model

* Add training info for retinanet
Co-authored-by: Hans Gaiser <hansg91@gmail.com>
Co-authored-by: Hans Gaiser <hans.gaiser@robovalley.com>
Co-authored-by: Hans Gaiser <hans.gaiser@robohouse.com>
parent 42e7f1f0
...@@ -350,6 +350,7 @@ the instances set of COCO train2017 and evaluated on COCO val2017.
Network                          box AP  mask AP  keypoint AP
================================ ======= ======== ===========
Faster R-CNN ResNet-50 FPN       37.0    -        -
RetinaNet ResNet-50 FPN          36.4    -        -
Mask R-CNN ResNet-50 FPN         37.9    34.6     -
================================ ======= ======== ===========
...@@ -405,6 +406,7 @@ precision-recall.
Network                        train time (s / it) test time (s / it) memory (GB)
============================== =================== ================== ===========
Faster R-CNN ResNet-50 FPN     0.2288              0.0590             5.2
RetinaNet ResNet-50 FPN        0.2514              0.0939             4.1
Mask R-CNN ResNet-50 FPN       0.2728              0.0903             5.4
Keypoint R-CNN ResNet-50 FPN   0.3789              0.1242             6.8
============================== =================== ================== ===========
...@@ -416,6 +418,12 @@ Faster R-CNN
.. autofunction:: torchvision.models.detection.fasterrcnn_resnet50_fpn


RetinaNet
------------

.. autofunction:: torchvision.models.detection.retinanet_resnet50_fpn


Mask R-CNN
----------
......
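For reference, a minimal inference sketch against the newly documented entry point (not part of this diff; it follows the same calling convention as the other torchvision detection models):

```python
import torch
import torchvision

# Load the RetinaNet ResNet-50 FPN model added in this commit.
# pretrained=True pulls the COCO weights reported in the table above (36.4 box AP).
model = torchvision.models.detection.retinanet_resnet50_fpn(pretrained=True)
model.eval()

# Inference takes a list of 3xHxW tensors (sizes may differ per image) and
# returns one dict per image with 'boxes', 'labels' and 'scores'.
images = [torch.rand(3, 480, 640), torch.rand(3, 512, 512)]
with torch.no_grad():
    predictions = model(images)
print(predictions[0]['boxes'].shape, predictions[0]['scores'].shape)
```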
...@@ -27,6 +27,13 @@ python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
    --lr-steps 16 22 --aspect-ratio-group-factor 3
```

### RetinaNet
```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
    --dataset coco --model retinanet_resnet50_fpn --epochs 26\
    --lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01
```

### Mask R-CNN
```
......
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
from .faster_rcnn import *
from .mask_rcnn import *
from .keypoint_rcnn import *
from .retinanet import *
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch
from torch import nn

from torch.jit.annotations import List, Optional, Dict
from .image_list import ImageList


class AnchorGenerator(nn.Module):
    """
    Module that generates anchors for a set of feature maps and
    image sizes.

    The module support computing anchors at multiple sizes and aspect ratios
    per feature map. This module assumes aspect ratio = height / width for
    each anchor.

    sizes and aspect_ratios should have the same number of elements, and it should
    correspond to the number of feature maps.

    sizes[i] and aspect_ratios[i] can have an arbitrary number of elements,
    and AnchorGenerator will output a set of sizes[i] * aspect_ratios[i] anchors
    per spatial location for feature map i.

    Arguments:
        sizes (Tuple[Tuple[int]]):
        aspect_ratios (Tuple[Tuple[float]]):
    """

    __annotations__ = {
        "cell_anchors": Optional[List[torch.Tensor]],
        "_cache": Dict[str, List[torch.Tensor]]
    }

    def __init__(
        self,
        sizes=((128, 256, 512),),
        aspect_ratios=((0.5, 1.0, 2.0),),
    ):
        super(AnchorGenerator, self).__init__()

        if not isinstance(sizes[0], (list, tuple)):
            # TODO change this
            sizes = tuple((s,) for s in sizes)
        if not isinstance(aspect_ratios[0], (list, tuple)):
            aspect_ratios = (aspect_ratios,) * len(sizes)

        assert len(sizes) == len(aspect_ratios)

        self.sizes = sizes
        self.aspect_ratios = aspect_ratios
        self.cell_anchors = None
        self._cache = {}

    # TODO: https://github.com/pytorch/pytorch/issues/26792
    # For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values.
    # (scales, aspect_ratios) are usually an element of zip(self.scales, self.aspect_ratios)
    # This method assumes aspect ratio = height / width for an anchor.
    def generate_anchors(self, scales, aspect_ratios, dtype=torch.float32, device="cpu"):
        # type: (List[int], List[float], int, Device) -> Tensor # noqa: F821
        scales = torch.as_tensor(scales, dtype=dtype, device=device)
        aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device)
        h_ratios = torch.sqrt(aspect_ratios)
        w_ratios = 1 / h_ratios

        ws = (w_ratios[:, None] * scales[None, :]).view(-1)
        hs = (h_ratios[:, None] * scales[None, :]).view(-1)

        base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1) / 2
        return base_anchors.round()

    def set_cell_anchors(self, dtype, device):
        # type: (int, Device) -> None # noqa: F821
        if self.cell_anchors is not None:
            cell_anchors = self.cell_anchors
            assert cell_anchors is not None
            # suppose that all anchors have the same device
            # which is a valid assumption in the current state of the codebase
            if cell_anchors[0].device == device:
                return

        cell_anchors = [
            self.generate_anchors(
                sizes,
                aspect_ratios,
                dtype,
                device
            )
            for sizes, aspect_ratios in zip(self.sizes, self.aspect_ratios)
        ]
        self.cell_anchors = cell_anchors

    def num_anchors_per_location(self):
        return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)]

    # For every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:2),
    # output g[i] anchors that are s[i] distance apart in direction i, with the same dimensions as a.
    def grid_anchors(self, grid_sizes, strides):
        # type: (List[List[int]], List[List[Tensor]]) -> List[Tensor]
        anchors = []
        cell_anchors = self.cell_anchors
        assert cell_anchors is not None
        assert len(grid_sizes) == len(strides) == len(cell_anchors)

        for size, stride, base_anchors in zip(
            grid_sizes, strides, cell_anchors
        ):
            grid_height, grid_width = size
            stride_height, stride_width = stride
            device = base_anchors.device

            # For output anchor, compute [x_center, y_center, x_center, y_center]
            shifts_x = torch.arange(
                0, grid_width, dtype=torch.float32, device=device
            ) * stride_width
            shifts_y = torch.arange(
                0, grid_height, dtype=torch.float32, device=device
            ) * stride_height
            shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
            shift_x = shift_x.reshape(-1)
            shift_y = shift_y.reshape(-1)
            shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)

            # For every (base anchor, output anchor) pair,
            # offset each zero-centered base anchor by the center of the output anchor.
            anchors.append(
                (shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4)
            )

        return anchors

    def cached_grid_anchors(self, grid_sizes, strides):
        # type: (List[List[int]], List[List[Tensor]]) -> List[Tensor]
        key = str(grid_sizes) + str(strides)
        if key in self._cache:
            return self._cache[key]
        anchors = self.grid_anchors(grid_sizes, strides)
        self._cache[key] = anchors
        return anchors

    def forward(self, image_list, feature_maps):
        # type: (ImageList, List[Tensor]) -> List[Tensor]
        grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps])
        image_size = image_list.tensors.shape[-2:]
        dtype, device = feature_maps[0].dtype, feature_maps[0].device
        strides = [[torch.tensor(image_size[0] // g[0], dtype=torch.int64, device=device),
                    torch.tensor(image_size[1] // g[1], dtype=torch.int64, device=device)] for g in grid_sizes]
        self.set_cell_anchors(dtype, device)
        anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides)
        anchors = torch.jit.annotate(List[List[torch.Tensor]], [])
        for i, (image_height, image_width) in enumerate(image_list.image_sizes):
            anchors_in_image = []
            for anchors_per_feature_map in anchors_over_all_feature_maps:
                anchors_in_image.append(anchors_per_feature_map)
            anchors.append(anchors_in_image)
        anchors = [torch.cat(anchors_per_image) for anchors_per_image in anchors]
        # Clear the cache in case that memory leaks.
        self._cache.clear()
        return anchors
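A quick usage sketch of the relocated AnchorGenerator (illustrative only; the sizes and aspect ratios below mirror the RetinaNet defaults introduced in this PR, and the ImageList is built by hand instead of by GeneralizedRCNNTransform):

```python
import torch
from torchvision.models.detection.anchor_utils import AnchorGenerator
from torchvision.models.detection.image_list import ImageList

# Three scales and three aspect ratios per feature map level -> 9 anchors per location.
anchor_sizes = tuple((x, int(x * 2 ** (1.0 / 3)), int(x * 2 ** (2.0 / 3)))
                     for x in [32, 64, 128, 256, 512])
aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)

# Fake one 256x256 image and five FPN levels with strides 8, 16, 32, 64, 128.
images = torch.rand(1, 3, 256, 256)
image_list = ImageList(images, [(256, 256)])
features = [torch.rand(1, 256, 256 // s, 256 // s) for s in (8, 16, 32, 64, 128)]

anchors = anchor_generator(image_list, features)    # one Tensor of [N, 4] boxes per image
print(anchor_generator.num_anchors_per_location())  # [9, 9, 9, 9, 9]
print(anchors[0].shape)                             # anchors from all levels, concatenated
```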
...@@ -25,13 +25,17 @@ class BackboneWithFPN(nn.Module):
    Attributes:
        out_channels (int): the number of channels in the FPN
    """
    def __init__(self, backbone, return_layers, in_channels_list, out_channels, extra_blocks=None):
        super(BackboneWithFPN, self).__init__()

        if extra_blocks is None:
            extra_blocks = LastLevelMaxPool()

        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        self.fpn = FeaturePyramidNetwork(
            in_channels_list=in_channels_list,
            out_channels=out_channels,
            extra_blocks=extra_blocks,
        )
        self.out_channels = out_channels
...@@ -41,7 +45,14 @@ class BackboneWithFPN(nn.Module):
        return x


def resnet_fpn_backbone(
    backbone_name,
    pretrained,
    norm_layer=misc_nn_ops.FrozenBatchNorm2d,
    trainable_layers=3,
    returned_layers=None,
    extra_blocks=None
):
    """
    Constructs a specified ResNet backbone with FPN on top. Freezes the specified number of layers in the backbone.
...@@ -82,14 +93,15 @@ def resnet_fpn_backbone(backbone_name, pretrained, norm_layer=misc_nn_ops.Frozen
        if all([not name.startswith(layer) for layer in layers_to_train]):
            parameter.requires_grad_(False)

    if extra_blocks is None:
        extra_blocks = LastLevelMaxPool()

    if returned_layers is None:
        returned_layers = [1, 2, 3, 4]
    assert min(returned_layers) > 0 and max(returned_layers) < 5
    return_layers = {f'layer{k}': str(v) for v, k in enumerate(returned_layers)}

    in_channels_stage2 = backbone.inplanes // 8
    in_channels_list = [in_channels_stage2 * 2 ** (i - 1) for i in returned_layers]
    out_channels = 256
    return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks)
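To see the new `returned_layers` and `extra_blocks` parameters together, here is a hedged sketch of building a RetinaNet-style P3-P7 backbone, assuming `LastLevelP6P7` is available from `torchvision.ops.feature_pyramid_network` (this mirrors what the collapsed retinanet.py does, but is not copied from it):

```python
import torch
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from torchvision.ops.feature_pyramid_network import LastLevelP6P7

# Skip C2 (returned_layers starts at 2) and replace the default max-pool extra
# level with P6/P7 convolutions on top of the FPN output.
backbone = resnet_fpn_backbone(
    'resnet50',
    pretrained=False,
    returned_layers=[2, 3, 4],
    extra_blocks=LastLevelP6P7(256, 256),
)

x = torch.rand(1, 3, 256, 256)
features = backbone(x)  # OrderedDict with keys '0', '1', '2', 'p6', 'p7'
print([(name, feat.shape) for name, feat in features.items()])
```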
...@@ -9,8 +9,9 @@ from torchvision.ops import MultiScaleRoIAlign
from ..utils import load_state_dict_from_url

from .anchor_utils import AnchorGenerator
from .generalized_rcnn import GeneralizedRCNN
from .rpn import RPNHead, RegionProposalNetwork
from .roi_heads import RoIHeads
from .transform import GeneralizedRCNNTransform
from .backbone_utils import resnet_fpn_backbone
......
...@@ -103,7 +103,7 @@ class KeypointRCNN(FasterRCNN):
        >>> import torch
        >>> import torchvision
        >>> from torchvision.models.detection import KeypointRCNN
        >>> from torchvision.models.detection.anchor_utils import AnchorGenerator
        >>>
        >>> # load a pre-trained model for classification and return
        >>> # only the features
......
...@@ -107,7 +107,7 @@ class MaskRCNN(FasterRCNN):
        >>> import torch
        >>> import torchvision
        >>> from torchvision.models.detection import MaskRCNN
        >>> from torchvision.models.detection.anchor_utils import AnchorGenerator
        >>>
        >>> # load a pre-trained model for classification and return
        >>> # only the features
......
This diff is collapsed.
...@@ -11,6 +11,9 @@ from .image_list import ImageList
from torch.jit.annotations import List, Optional, Dict, Tuple

# Import AnchorGenerator to keep compatibility.
from .anchor_utils import AnchorGenerator


@torch.jit.unused
def _onnx_get_num_anchors_and_pre_nms_top_n(ob, orig_pre_nms_top_n):
...@@ -24,159 +27,6 @@ def _onnx_get_num_anchors_and_pre_nms_top_n(ob, orig_pre_nms_top_n):
    return num_anchors, pre_nms_top_n
(removed from rpn.py: the AnchorGenerator class body, identical to the new anchor_utils.py shown above)
class RPNHead(nn.Module):
    """
    Adds a simple RPN Head with classification and regression heads
...@@ -338,7 +188,7 @@ class RegionProposalNetwork(torch.nn.Module):
                matched_gt_boxes_per_image = torch.zeros(anchors_per_image.shape, dtype=torch.float32, device=device)
                labels_per_image = torch.zeros((anchors_per_image.shape[0],), dtype=torch.float32, device=device)
            else:
                match_quality_matrix = self.box_similarity(gt_boxes, anchors_per_image)
                matched_idxs = self.proposal_matcher(match_quality_matrix)
                # get the targets corresponding GT for each proposal
                # NB: need to clamp the indices because we can have a single
......
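For context, a small illustration (not from this diff) of what the `self.box_similarity` and `self.proposal_matcher` calls above compute, using `torchvision.ops.box_iou` and the internal `Matcher` helper; the boxes and thresholds are placeholders:

```python
import torch
from torchvision.ops import box_iou
from torchvision.models.detection._utils import Matcher

# Two ground-truth boxes and three anchors in (x1, y1, x2, y2) format.
gt_boxes = torch.tensor([[10., 10., 50., 50.], [60., 60., 100., 120.]])
anchors = torch.tensor([[12., 8., 48., 52.], [0., 0., 20., 20.], [55., 65., 105., 115.]])

# box_iou plays the role of box_similarity: a [num_gt, num_anchors] quality matrix.
match_quality_matrix = box_iou(gt_boxes, anchors)

# The Matcher assigns each anchor the index of its best ground-truth box,
# or a negative flag when the IoU falls below / between the thresholds.
matcher = Matcher(high_threshold=0.7, low_threshold=0.3, allow_low_quality_matches=True)
matched_idxs = matcher(match_quality_matrix)
print(matched_idxs)  # >= 0: matched gt index; negative: background / discard
```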
...@@ -8,6 +8,7 @@ from .ps_roi_align import ps_roi_align, PSRoIAlign
from .ps_roi_pool import ps_roi_pool, PSRoIPool
from .poolers import MultiScaleRoIAlign
from .feature_pyramid_network import FeaturePyramidNetwork
from .focal_loss import sigmoid_focal_loss
from ._register_onnx_ops import _register_custom_op
...@@ -19,5 +20,6 @@ __all__ = [
    'clip_boxes_to_image', 'box_convert',
    'box_area', 'box_iou', 'generalized_box_iou', 'roi_align', 'RoIAlign', 'roi_pool',
    'RoIPool', '_new_empty_tensor', 'ps_roi_align', 'PSRoIAlign', 'ps_roi_pool',
    'PSRoIPool', 'MultiScaleRoIAlign', 'FeaturePyramidNetwork',
    'sigmoid_focal_loss'
]
import torch
import torch.nn.functional as F


def sigmoid_focal_loss(
    inputs,
    targets,
    alpha: float = 0.25,
    gamma: float = 2,
    reduction: str = "none",
):
    """
    Original implementation from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py .
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.

    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
        alpha: (optional) Weighting factor in range (0,1) to balance
               positive vs negative examples or -1 for ignore. Default = 0.25
        gamma: Exponent of the modulating factor (1 - p_t) to
               balance easy vs hard examples.
        reduction: 'none' | 'mean' | 'sum'
                 'none': No reduction will be applied to the output.
                 'mean': The output will be averaged.
                 'sum': The output will be summed.
    Returns:
        Loss tensor with the reduction option applied.
    """
    p = torch.sigmoid(inputs)
    ce_loss = F.binary_cross_entropy_with_logits(
        inputs, targets, reduction="none"
    )
    p_t = p * targets + (1 - p) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    if reduction == "mean":
        loss = loss.mean()
    elif reduction == "sum":
        loss = loss.sum()

    return loss
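A toy sanity check of the new loss (illustrative; in RetinaNet the classification loss is typically this focal loss summed over anchors and normalized by the number of foreground anchors):

```python
import torch
from torchvision.ops import sigmoid_focal_loss

# Logits for 4 anchors x 3 classes and one-hot targets. With gamma > 0 the
# well-classified (easy) entries contribute far less than under plain
# binary cross-entropy with logits.
inputs = torch.tensor([[ 3.0, -3.0, -3.0],
                       [-3.0,  3.0, -3.0],
                       [-3.0, -3.0, -3.0],
                       [ 0.1,  0.2, -0.1]])
targets = torch.tensor([[1., 0., 0.],
                        [0., 1., 0.],
                        [0., 0., 0.],
                        [0., 0., 1.]])

print(sigmoid_focal_loss(inputs, targets, alpha=0.25, gamma=2, reduction="sum"))
print(torch.nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction="sum"))
```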