Unverified Commit 7d4bdd43 authored by Hu Ye, committed by GitHub

add FCOS (#4961)



* add fcos

* update fcos

* add giou_loss

* add BoxLinearCoder for FCOS

* add full code for FCOS

* add giou loss

* add fcos

* add __all__

* Fixing lint

* Fixing lint in giou_loss.py

* Add typing annotation to fcos

* Add trained checkpoints

* Use partial to replace lambda

* Minor fixes to docstrings

* Apply ufmt format

* Fixing docstrings

* Fixing jit scripting

* Minor fixes to docstrings

* Fixing jit scripting

* Ignore mypy in fcos

* Fixing trained checkpoints

* Fixing unit-test of jit script

* Fixing docstrings

* Add test/expect/ModelTester.test_fcos_resnet50_fpn_expect.pkl

* Fixing test_detection_model_trainable_backbone_layers

* Update test_fcos_resnet50_fpn_expect.pkl

* rename stride to box size

* remove TODO and fix some typos

* merge some code for better

* improve the comments

* remove decode and encode of BoxLinearCoder

* remove some unnecessary hints

* use the default value in detectron2.

* update doc

* Add unittest for BoxLinearCoder

* Add types in FCOS

* Add docstring for BoxLinearCoder

* Minor fix for the docstring

* update doc

* Update fcos_resnet50_fpn_coco pretrained weights url

* Update torchvision/models/detection/fcos.py
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>

* Update torchvision/models/detection/fcos.py
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>

* Update torchvision/models/detection/fcos.py
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>

* Update torchvision/models/detection/fcos.py
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>

* Add FCOS model documentation

* Fix typo in FCOS documentation

* Add fcos to the prototype builder

* Capitalize COCO_V1

* Fix params of fcos

* fix bug for partial

* Fixing docs indentation

* Fixing docs format in giou_loss

* Adopt Reference for GIoU Loss

* Rename giou_loss to generalized_box_iou_loss

* remove overwrite_eps

* Update AP test values

* Minor fixes for the docs

* Minor fixes for the docs

* Update torchvision/models/detection/fcos.py
Co-authored-by: Zhiqiang Wang <zhiqwang@foxmail.com>

* Update torchvision/prototype/models/detection/fcos.py
Co-authored-by: Zhiqiang Wang <zhiqwang@foxmail.com>
Co-authored-by: zhiqiang <zhiqwang@foxmail.com>
Co-authored-by: Joao Gomes <jdsgomes@fb.com>
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
Co-authored-by: Joao Gomes <joaopsgomes@gmail.com>
parent fe65d379
...@@ -597,6 +597,7 @@ The models subpackage contains definitions for the following model
architectures for detection:

- `Faster R-CNN <https://arxiv.org/abs/1506.01497>`_
- `FCOS <https://arxiv.org/abs/1904.01355>`_
- `Mask R-CNN <https://arxiv.org/abs/1703.06870>`_
- `RetinaNet <https://arxiv.org/abs/1708.02002>`_
- `SSD <https://arxiv.org/abs/1512.02325>`_
...@@ -642,6 +643,7 @@ Network box AP mask AP keypoint AP
Faster R-CNN ResNet-50 FPN 37.0 - -
Faster R-CNN MobileNetV3-Large FPN 32.8 - -
Faster R-CNN MobileNetV3-Large 320 FPN 22.8 - -
FCOS ResNet-50 FPN 39.2 - -
RetinaNet ResNet-50 FPN 36.4 - -
SSD300 VGG16 25.1 - -
SSDlite320 MobileNetV3-Large 21.3 - -
...@@ -702,6 +704,7 @@ Network train time (s / it) test time (s / it)
Faster R-CNN ResNet-50 FPN 0.2288 0.0590 5.2
Faster R-CNN MobileNetV3-Large FPN 0.1020 0.0415 1.0
Faster R-CNN MobileNetV3-Large 320 FPN 0.0978 0.0376 0.6
FCOS ResNet-50 FPN 0.1450 0.0539 3.3
RetinaNet ResNet-50 FPN 0.2514 0.0939 4.1
SSD300 VGG16 0.2093 0.0744 1.5
SSDlite320 MobileNetV3-Large 0.1773 0.0906 1.5
...@@ -721,6 +724,15 @@ Faster R-CNN
    torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn
    torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn

FCOS
----

.. autosummary::
    :toctree: generated/
    :template: function.rst

    torchvision.models.detection.fcos_resnet50_fpn

RetinaNet
---------
...
...@@ -70,6 +70,10 @@ ignore_errors = True
ignore_errors = True

[mypy-torchvision.models.detection.fcos]

ignore_errors = True

[mypy-torchvision.ops.*]

ignore_errors = True
...
...@@ -41,6 +41,13 @@ torchrun --nproc_per_node=8 train.py\
    --lr-steps 16 22 --aspect-ratio-group-factor 3
```

### FCOS ResNet-50 FPN
```
torchrun --nproc_per_node=8 train.py\
    --dataset coco --model fcos_resnet50_fpn --epochs 26\
    --lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01 --amp
```

### RetinaNet
```
torchrun --nproc_per_node=8 train.py\
...
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
...@@ -218,6 +218,7 @@ script_model_unwrapper = {
    "retinanet_resnet50_fpn": lambda x: x[1],
    "ssd300_vgg16": lambda x: x[1],
    "ssdlite320_mobilenet_v3_large": lambda x: x[1],
    "fcos_resnet50_fpn": lambda x: x[1],
}
...@@ -274,6 +275,13 @@ _model_params = {
        "max_size": 224,
        "input_shape": (3, 224, 224),
    },
    "fcos_resnet50_fpn": {
        "num_classes": 2,
        "score_thresh": 0.05,
        "min_size": 224,
        "max_size": 224,
        "input_shape": (3, 224, 224),
    },
    "maskrcnn_resnet50_fpn": {
        "num_classes": 10,
        "min_size": 224,
...@@ -325,6 +333,10 @@ _model_tests_values = {
        "max_trainable": 6,
        "n_trn_params_per_layer": [96, 99, 138, 200, 239, 257, 266],
    },
    "fcos_resnet50_fpn": {
        "max_trainable": 5,
        "n_trn_params_per_layer": [54, 64, 83, 96, 106, 107],
    },
}
...
...@@ -22,6 +22,19 @@ class TestModelsDetectionUtils:
        assert neg[0].sum() == 3
        assert neg[0][0:6].sum() == 3

    def test_box_linear_coder(self):
        box_coder = _utils.BoxLinearCoder(normalize_by_size=True)
        # Generate a random 10x4 boxes tensor, with coordinates < 50.
        boxes = torch.rand(10, 4) * 50
        boxes.clamp_(min=1.0)  # tiny boxes cause numerical instability in box regression
        boxes[:, 2:] += boxes[:, :2]

        proposals = torch.tensor([0, 0, 101, 101] * 10).reshape(10, 4).float()

        rel_codes = box_coder.encode_single(boxes, proposals)
        pred_boxes = box_coder.decode_single(rel_codes, boxes)
        assert torch.allclose(proposals, pred_boxes)

    @pytest.mark.parametrize("train_layers, exp_froz_params", [(0, 53), (1, 43), (2, 24), (3, 11), (4, 1), (5, 0)])
    def test_resnet_fpn_backbone_frozen_layers(self, train_layers, exp_froz_params):
        # we know how many initial layers and parameters of the network should
...
...@@ -4,3 +4,4 @@ from .keypoint_rcnn import *
from .retinanet import *
from .ssd import *
from .ssdlite import *
from .fcos import *
...@@ -217,6 +217,83 @@ class BoxCoder:
return pred_boxes
class BoxLinearCoder:
"""
The linear box-to-box transform defined in FCOS. The transformation is parameterized
by the distances from the center of the (square) source box to the 4 edges of the target box.
"""
def __init__(self, normalize_by_size: bool = True) -> None:
"""
Args:
normalize_by_size (bool): normalize deltas by the size of src (anchor) boxes.
"""
self.normalize_by_size = normalize_by_size
def encode_single(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
"""
Encode a set of proposals with respect to some reference boxes
Args:
reference_boxes (Tensor): reference boxes
proposals (Tensor): boxes to be encoded
Returns:
Tensor: the encoded relative box offsets that can be used to
decode the boxes.
"""
# get the center of reference_boxes
reference_boxes_ctr_x = 0.5 * (reference_boxes[:, 0] + reference_boxes[:, 2])
reference_boxes_ctr_y = 0.5 * (reference_boxes[:, 1] + reference_boxes[:, 3])
# get box regression transformation deltas
target_l = reference_boxes_ctr_x - proposals[:, 0]
target_t = reference_boxes_ctr_y - proposals[:, 1]
target_r = proposals[:, 2] - reference_boxes_ctr_x
target_b = proposals[:, 3] - reference_boxes_ctr_y
targets = torch.stack((target_l, target_t, target_r, target_b), dim=1)
if self.normalize_by_size:
reference_boxes_w = reference_boxes[:, 2] - reference_boxes[:, 0]
reference_boxes_h = reference_boxes[:, 3] - reference_boxes[:, 1]
reference_boxes_size = torch.stack(
(reference_boxes_w, reference_boxes_h, reference_boxes_w, reference_boxes_h), dim=1
)
targets = targets / reference_boxes_size
return targets
def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
"""
From a set of original boxes and encoded relative box offsets,
get the decoded boxes.
Args:
rel_codes (Tensor): encoded boxes
boxes (Tensor): reference boxes.
Returns:
Tensor: the predicted boxes obtained by applying the relative box offsets to the reference boxes.
"""
boxes = boxes.to(rel_codes.dtype)
ctr_x = 0.5 * (boxes[:, 0] + boxes[:, 2])
ctr_y = 0.5 * (boxes[:, 1] + boxes[:, 3])
if self.normalize_by_size:
boxes_w = boxes[:, 2] - boxes[:, 0]
boxes_h = boxes[:, 3] - boxes[:, 1]
boxes_size = torch.stack((boxes_w, boxes_h, boxes_w, boxes_h), dim=1)
rel_codes = rel_codes * boxes_size
pred_boxes1 = ctr_x - rel_codes[:, 0]
pred_boxes2 = ctr_y - rel_codes[:, 1]
pred_boxes3 = ctr_x + rel_codes[:, 2]
pred_boxes4 = ctr_y + rel_codes[:, 3]
pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=1)
return pred_boxes
class Matcher:
"""
This class assigns to each predicted "element" (e.g., a box) a ground-truth
...
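For reference, a minimal round-trip sketch of the new `BoxLinearCoder` (the tensors below are illustrative, not part of the change): encoding target boxes against reference boxes and then decoding the offsets against the same references recovers the targets, which is also what the new `test_box_linear_coder` unit test checks.

```python
import torch
from torchvision.models.detection import _utils as det_utils

coder = det_utils.BoxLinearCoder(normalize_by_size=True)
reference = torch.tensor([[10.0, 10.0, 20.0, 20.0], [0.0, 0.0, 8.0, 8.0]])  # reference (anchor) boxes
targets = torch.tensor([[8.0, 9.0, 25.0, 24.0], [1.0, 1.0, 10.0, 12.0]])    # boxes to be encoded

# (l, t, r, b) distances from the reference centers, normalized by the reference sizes
rel_codes = coder.encode_single(reference, targets)
# applying the offsets back to the references recovers the targets in (x1, y1, x2, y2)
decoded = coder.decode_single(rel_codes, reference)
assert torch.allclose(decoded, targets)
```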
import math
import warnings
from collections import OrderedDict
from functools import partial
from typing import Callable, Dict, List, Tuple, Optional
import torch
from torch import nn, Tensor
from ..._internally_replaced_utils import load_state_dict_from_url
from ...ops import sigmoid_focal_loss, generalized_box_iou_loss
from ...ops import boxes as box_ops
from ...ops import misc as misc_nn_ops
from ...ops.feature_pyramid_network import LastLevelP6P7
from ...utils import _log_api_usage_once
from ..resnet import resnet50
from . import _utils as det_utils
from .anchor_utils import AnchorGenerator
from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
from .transform import GeneralizedRCNNTransform
__all__ = ["FCOS", "fcos_resnet50_fpn"]
class FCOSHead(nn.Module):
"""
A regression and classification head for use in FCOS.
Args:
in_channels (int): number of channels of the input feature
num_anchors (int): number of anchors to be predicted
num_classes (int): number of classes to be predicted
num_convs (Optional[int]): number of conv layers in the head. Default: 4.
"""
__annotations__ = {
"box_coder": det_utils.BoxLinearCoder,
}
def __init__(self, in_channels: int, num_anchors: int, num_classes: int, num_convs: Optional[int] = 4) -> None:
super().__init__()
self.box_coder = det_utils.BoxLinearCoder(normalize_by_size=True)
self.classification_head = FCOSClassificationHead(in_channels, num_anchors, num_classes, num_convs)
self.regression_head = FCOSRegressionHead(in_channels, num_anchors, num_convs)
def compute_loss(
self,
targets: List[Dict[str, Tensor]],
head_outputs: Dict[str, Tensor],
anchors: List[Tensor],
matched_idxs: List[Tensor],
) -> Dict[str, Tensor]:
cls_logits = head_outputs["cls_logits"] # [N, HWA, C]
bbox_regression = head_outputs["bbox_regression"] # [N, HWA, 4]
bbox_ctrness = head_outputs["bbox_ctrness"] # [N, HWA, 1]
all_gt_classes_targets = []
all_gt_boxes_targets = []
for targets_per_image, matched_idxs_per_image in zip(targets, matched_idxs):
gt_classes_targets = targets_per_image["labels"][matched_idxs_per_image.clip(min=0)]
gt_classes_targets[matched_idxs_per_image < 0] = -1  # background
gt_boxes_targets = targets_per_image["boxes"][matched_idxs_per_image.clip(min=0)]
all_gt_classes_targets.append(gt_classes_targets)
all_gt_boxes_targets.append(gt_boxes_targets)
all_gt_classes_targets = torch.stack(all_gt_classes_targets)
# compute foreground
foreground_mask = all_gt_classes_targets >= 0
num_foreground = foreground_mask.sum().item()
# classification loss
gt_classes_targets = torch.zeros_like(cls_logits)
gt_classes_targets[foreground_mask, all_gt_classes_targets[foreground_mask]] = 1.0
loss_cls = sigmoid_focal_loss(cls_logits, gt_classes_targets, reduction="sum")
# regression loss: GIoU loss
# TODO: vectorize this instead of using a for loop
pred_boxes = [
self.box_coder.decode_single(bbox_regression_per_image, anchors_per_image)
for anchors_per_image, bbox_regression_per_image in zip(anchors, bbox_regression)
]
# amp issue: pred_boxes need to be converted to float
loss_bbox_reg = generalized_box_iou_loss(
torch.stack(pred_boxes)[foreground_mask].float(),
torch.stack(all_gt_boxes_targets)[foreground_mask],
reduction="sum",
)
# ctrness loss
bbox_reg_targets = [
self.box_coder.encode_single(anchors_per_image, boxes_targets_per_image)
for anchors_per_image, boxes_targets_per_image in zip(anchors, all_gt_boxes_targets)
]
bbox_reg_targets = torch.stack(bbox_reg_targets, dim=0)
if len(bbox_reg_targets) == 0:
bbox_reg_targets.new_zeros(len(bbox_reg_targets))
left_right = bbox_reg_targets[:, :, [0, 2]]
top_bottom = bbox_reg_targets[:, :, [1, 3]]
gt_ctrness_targets = torch.sqrt(
(left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0])
* (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
)
pred_centerness = bbox_ctrness.squeeze(dim=2)
loss_bbox_ctrness = nn.functional.binary_cross_entropy_with_logits(
pred_centerness[foreground_mask], gt_ctrness_targets[foreground_mask], reduction="sum"
)
return {
"classification": loss_cls / max(1, num_foreground),
"bbox_regression": loss_bbox_reg / max(1, num_foreground),
"bbox_ctrness": loss_bbox_ctrness / max(1, num_foreground),
}
def forward(self, x: List[Tensor]) -> Dict[str, Tensor]:
cls_logits = self.classification_head(x)
bbox_regression, bbox_ctrness = self.regression_head(x)
return {
"cls_logits": cls_logits,
"bbox_regression": bbox_regression,
"bbox_ctrness": bbox_ctrness,
}
class FCOSClassificationHead(nn.Module):
"""
A classification head for use in FCOS.
Args:
in_channels (int): number of channels of the input feature.
num_anchors (int): number of anchors to be predicted.
num_classes (int): number of classes to be predicted.
num_convs (Optional[int]): number of conv layers. Default: 4.
prior_probability (Optional[float]): probability of prior. Default: 0.01.
norm_layer: Module specifying the normalization layer to use.
"""
def __init__(
self,
in_channels: int,
num_anchors: int,
num_classes: int,
num_convs: int = 4,
prior_probability: float = 0.01,
norm_layer: Optional[Callable[..., nn.Module]] = None,
) -> None:
super().__init__()
self.num_classes = num_classes
self.num_anchors = num_anchors
if norm_layer is None:
norm_layer = partial(nn.GroupNorm, 32)
conv = []
for _ in range(num_convs):
conv.append(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1))
conv.append(norm_layer(in_channels))
conv.append(nn.ReLU())
self.conv = nn.Sequential(*conv)
for layer in self.conv.children():
if isinstance(layer, nn.Conv2d):
torch.nn.init.normal_(layer.weight, std=0.01)
torch.nn.init.constant_(layer.bias, 0)
self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1)
torch.nn.init.normal_(self.cls_logits.weight, std=0.01)
torch.nn.init.constant_(self.cls_logits.bias, -math.log((1 - prior_probability) / prior_probability))
def forward(self, x: List[Tensor]) -> Tensor:
all_cls_logits = []
for features in x:
cls_logits = self.conv(features)
cls_logits = self.cls_logits(cls_logits)
# Permute classification output from (N, A * K, H, W) to (N, HWA, K).
N, _, H, W = cls_logits.shape
cls_logits = cls_logits.view(N, -1, self.num_classes, H, W)
cls_logits = cls_logits.permute(0, 3, 4, 1, 2)
cls_logits = cls_logits.reshape(N, -1, self.num_classes)  # Size=(N, HWA, K)
all_cls_logits.append(cls_logits)
return torch.cat(all_cls_logits, dim=1)
class FCOSRegressionHead(nn.Module):
"""
A regression head for use in FCOS.
Args:
in_channels (int): number of channels of the input feature
num_anchors (int): number of anchors to be predicted
num_convs (Optional[int]): number of conv layers. Default: 4.
norm_layer: Module specifying the normalization layer to use.
"""
def __init__(
self,
in_channels: int,
num_anchors: int,
num_convs: int = 4,
norm_layer: Optional[Callable[..., nn.Module]] = None,
):
super().__init__()
if norm_layer is None:
norm_layer = partial(nn.GroupNorm, 32)
conv = []
for _ in range(num_convs):
conv.append(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1))
conv.append(norm_layer(in_channels))
conv.append(nn.ReLU())
self.conv = nn.Sequential(*conv)
self.bbox_reg = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, stride=1, padding=1)
self.bbox_ctrness = nn.Conv2d(in_channels, num_anchors * 1, kernel_size=3, stride=1, padding=1)
for layer in [self.bbox_reg, self.bbox_ctrness]:
torch.nn.init.normal_(layer.weight, std=0.01)
torch.nn.init.zeros_(layer.bias)
for layer in self.conv.children():
if isinstance(layer, nn.Conv2d):
torch.nn.init.normal_(layer.weight, std=0.01)
torch.nn.init.zeros_(layer.bias)
def forward(self, x: List[Tensor]) -> Tuple[Tensor, Tensor]:
all_bbox_regression = []
all_bbox_ctrness = []
for features in x:
bbox_feature = self.conv(features)
bbox_regression = nn.functional.relu(self.bbox_reg(bbox_feature))
bbox_ctrness = self.bbox_ctrness(bbox_feature)
# permute bbox regression output from (N, 4 * A, H, W) to (N, HWA, 4).
N, _, H, W = bbox_regression.shape
bbox_regression = bbox_regression.view(N, -1, 4, H, W)
bbox_regression = bbox_regression.permute(0, 3, 4, 1, 2)
bbox_regression = bbox_regression.reshape(N, -1, 4) # Size=(N, HWA, 4)
all_bbox_regression.append(bbox_regression)
# permute bbox ctrness output from (N, 1 * A, H, W) to (N, HWA, 1).
bbox_ctrness = bbox_ctrness.view(N, -1, 1, H, W)
bbox_ctrness = bbox_ctrness.permute(0, 3, 4, 1, 2)
bbox_ctrness = bbox_ctrness.reshape(N, -1, 1)
all_bbox_ctrness.append(bbox_ctrness)
return torch.cat(all_bbox_regression, dim=1), torch.cat(all_bbox_ctrness, dim=1)
class FCOS(nn.Module):
"""
Implements FCOS.
The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
image, and should be in 0-1 range. Different images can have different sizes.
The behavior of the model changes depending on whether it is in training or evaluation mode.
During training, the model expects both the input tensors and the targets (a list of dictionaries),
containing:
- boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
- labels (Int64Tensor[N]): the class label for each ground-truth box
The model returns a Dict[Tensor] during training, containing the classification, regression
and centerness losses.
During inference, the model requires only the input tensors, and returns the post-processed
predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
follows:
- boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
- labels (Int64Tensor[N]): the predicted labels for each image
- scores (Tensor[N]): the scores for each prediction
Args:
backbone (nn.Module): the network used to compute the features for the model.
It should contain an out_channels attribute, which indicates the number of output
channels that each feature map has (and it should be the same for all feature maps).
The backbone should return a single Tensor or an OrderedDict[Tensor].
num_classes (int): number of output classes of the model (including the background).
min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
image_mean (Tuple[float, float, float]): mean values used for input normalization.
They are generally the mean values of the dataset on which the backbone has been trained.
image_std (Tuple[float, float, float]): std values used for input normalization.
They are generally the std values of the dataset on which the backbone has been trained.
anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
maps. For FCOS, only one anchor is set per position of each level; its width and height are equal
to the stride of the feature map and its aspect ratio is 1.0, so the anchor center is equivalent
to the point used in the FCOS paper.
head (nn.Module): Module run on top of the feature pyramid.
Defaults to a module containing a classification and regression module.
center_sampling_radius (float): radius of the "center" of a ground-truth box,
within which all anchor points are labeled positive.
score_thresh (float): Score threshold used for postprocessing the detections.
nms_thresh (float): NMS threshold used for postprocessing the detections.
detections_per_img (int): Number of best detections to keep after NMS.
topk_candidates (int): Number of best detections to keep before NMS.
Example:
>>> import torch
>>> import torchvision
>>> from torchvision.models.detection import FCOS
>>> from torchvision.models.detection.anchor_utils import AnchorGenerator
>>> # load a pre-trained model for classification and return
>>> # only the features
>>> backbone = torchvision.models.mobilenet_v2(pretrained=True).features
>>> # FCOS needs to know the number of
>>> # output channels in a backbone. For mobilenet_v2, it's 1280
>>> # so we need to add it here
>>> backbone.out_channels = 1280
>>>
>>> # let's make the network generate 1 anchor per spatial
>>> # location for each of 5 feature levels, with one size per
>>> # level and a single aspect ratio of 1.0 (as FCOS expects).
>>> # We have a Tuple[Tuple[int]] because each feature
>>> # map could potentially have different sizes and
>>> # aspect ratios
>>> anchor_generator = AnchorGenerator(
>>> sizes=((8,), (16,), (32,), (64,), (128,)),
>>> aspect_ratios=((1.0,),)
>>> )
>>>
>>> # put the pieces together inside a FCOS model
>>> model = FCOS(
>>> backbone,
>>> num_classes=80,
>>> anchor_generator=anchor_generator,
>>> )
>>> model.eval()
>>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
>>> predictions = model(x)
"""
__annotations__ = {
"box_coder": det_utils.BoxLinearCoder,
}
def __init__(
self,
backbone: nn.Module,
num_classes: int,
# transform parameters
min_size: int = 800,
max_size: int = 1333,
image_mean: Optional[List[float]] = None,
image_std: Optional[List[float]] = None,
# Anchor parameters
anchor_generator: Optional[AnchorGenerator] = None,
head: Optional[nn.Module] = None,
center_sampling_radius: float = 1.5,
score_thresh: float = 0.2,
nms_thresh: float = 0.6,
detections_per_img: int = 100,
topk_candidates: int = 1000,
):
super().__init__()
_log_api_usage_once(self)
if not hasattr(backbone, "out_channels"):
raise ValueError(
"backbone should contain an attribute out_channels "
"specifying the number of output channels (assumed to be the "
"same for all the levels)"
)
self.backbone = backbone
assert isinstance(anchor_generator, (AnchorGenerator, type(None)))
if anchor_generator is None:
anchor_sizes = ((8,), (16,), (32,), (64,), (128,)) # equal to strides of multi-level feature map
aspect_ratios = ((1.0,),) * len(anchor_sizes) # set only one anchor
anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
self.anchor_generator = anchor_generator
assert self.anchor_generator.num_anchors_per_location()[0] == 1
if head is None:
head = FCOSHead(backbone.out_channels, anchor_generator.num_anchors_per_location()[0], num_classes)
self.head = head
self.box_coder = det_utils.BoxLinearCoder(normalize_by_size=True)
if image_mean is None:
image_mean = [0.485, 0.456, 0.406]
if image_std is None:
image_std = [0.229, 0.224, 0.225]
self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)
self.center_sampling_radius = center_sampling_radius
self.score_thresh = score_thresh
self.nms_thresh = nms_thresh
self.detections_per_img = detections_per_img
self.topk_candidates = topk_candidates
# used only on torchscript mode
self._has_warned = False
@torch.jit.unused
def eager_outputs(
self, losses: Dict[str, Tensor], detections: List[Dict[str, Tensor]]
) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]:
if self.training:
return losses
return detections
def compute_loss(
self,
targets: List[Dict[str, Tensor]],
head_outputs: Dict[str, Tensor],
anchors: List[Tensor],
num_anchors_per_level: List[int],
) -> Dict[str, Tensor]:
matched_idxs = []
for anchors_per_image, targets_per_image in zip(anchors, targets):
if targets_per_image["boxes"].numel() == 0:
matched_idxs.append(
torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64, device=anchors_per_image.device)
)
continue
gt_boxes = targets_per_image["boxes"]
gt_centers = (gt_boxes[:, :2] + gt_boxes[:, 2:]) / 2  # Mx2
anchor_centers = (anchors_per_image[:, :2] + anchors_per_image[:, 2:]) / 2  # Nx2
anchor_sizes = anchors_per_image[:, 2] - anchors_per_image[:, 0]  # N
# center sampling: anchor point must be close enough to gt center.
pairwise_match = (anchor_centers[:, None, :] - gt_centers[None, :, :]).abs_().max(
dim=2
).values < self.center_sampling_radius * anchor_sizes[:, None]
# compute pairwise distance between N points and M boxes
x, y = anchor_centers.unsqueeze(dim=2).unbind(dim=1) # (N, 1)
x0, y0, x1, y1 = gt_boxes.unsqueeze(dim=0).unbind(dim=2) # (1, M)
pairwise_dist = torch.stack([x - x0, y - y0, x1 - x, y1 - y], dim=2) # (N, M)
# anchor point must be inside gt
pairwise_match &= pairwise_dist.min(dim=2).values > 0
# each anchor is only responsible for a certain scale range.
lower_bound = anchor_sizes * 4
lower_bound[: num_anchors_per_level[0]] = 0
upper_bound = anchor_sizes * 8
upper_bound[-num_anchors_per_level[-1] :] = float("inf")
pairwise_dist = pairwise_dist.max(dim=2).values
pairwise_match &= (pairwise_dist > lower_bound[:, None]) & (pairwise_dist < upper_bound[:, None])
# match the GT box with minimum area, if there are multiple GT matches
gt_areas = (gt_boxes[:, 2] - gt_boxes[:, 0]) * (gt_boxes[:, 3] - gt_boxes[:, 1])  # M
pairwise_match = pairwise_match.to(torch.float32) * (1e8 - gt_areas[None, :])
min_values, matched_idx = pairwise_match.max(dim=1) # R, per-anchor match
matched_idx[min_values < 1e-5] = -1 # unmatched anchors are assigned -1
matched_idxs.append(matched_idx)
return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs)
def postprocess_detections(
self, head_outputs: Dict[str, List[Tensor]], anchors: List[List[Tensor]], image_shapes: List[Tuple[int, int]]
) -> List[Dict[str, Tensor]]:
class_logits = head_outputs["cls_logits"]
box_regression = head_outputs["bbox_regression"]
box_ctrness = head_outputs["bbox_ctrness"]
num_images = len(image_shapes)
detections: List[Dict[str, Tensor]] = []
for index in range(num_images):
box_regression_per_image = [br[index] for br in box_regression]
logits_per_image = [cl[index] for cl in class_logits]
box_ctrness_per_image = [bc[index] for bc in box_ctrness]
anchors_per_image, image_shape = anchors[index], image_shapes[index]
image_boxes = []
image_scores = []
image_labels = []
for box_regression_per_level, logits_per_level, box_ctrness_per_level, anchors_per_level in zip(
box_regression_per_image, logits_per_image, box_ctrness_per_image, anchors_per_image
):
num_classes = logits_per_level.shape[-1]
# remove low scoring boxes
scores_per_level = torch.sqrt(
torch.sigmoid(logits_per_level) * torch.sigmoid(box_ctrness_per_level)
).flatten()
keep_idxs = scores_per_level > self.score_thresh
scores_per_level = scores_per_level[keep_idxs]
topk_idxs = torch.where(keep_idxs)[0]
# keep only topk scoring predictions
num_topk = min(self.topk_candidates, topk_idxs.size(0))
scores_per_level, idxs = scores_per_level.topk(num_topk)
topk_idxs = topk_idxs[idxs]
anchor_idxs = torch.div(topk_idxs, num_classes, rounding_mode="floor")
labels_per_level = topk_idxs % num_classes
boxes_per_level = self.box_coder.decode_single(
box_regression_per_level[anchor_idxs], anchors_per_level[anchor_idxs]
)
boxes_per_level = box_ops.clip_boxes_to_image(boxes_per_level, image_shape)
image_boxes.append(boxes_per_level)
image_scores.append(scores_per_level)
image_labels.append(labels_per_level)
image_boxes = torch.cat(image_boxes, dim=0)
image_scores = torch.cat(image_scores, dim=0)
image_labels = torch.cat(image_labels, dim=0)
# non-maximum suppression
keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
keep = keep[: self.detections_per_img]
detections.append(
{
"boxes": image_boxes[keep],
"scores": image_scores[keep],
"labels": image_labels[keep],
}
)
return detections
def forward(
self,
images: List[Tensor],
targets: Optional[List[Dict[str, Tensor]]] = None,
) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]:
"""
Args:
images (list[Tensor]): images to be processed
targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional)
Returns:
result (list[Dict[Tensor]] or dict[Tensor]): the output from the model.
During training, it returns a dict[Tensor] which contains the losses.
During inference, it returns a list[Dict[Tensor]], one for each input image,
with fields such as `boxes`, `scores` and `labels`.
"""
if self.training:
if targets is None:
raise ValueError("In training mode, targets should be passed")
for target in targets:
boxes = target["boxes"]
if isinstance(boxes, torch.Tensor):
if len(boxes.shape) != 2 or boxes.shape[-1] != 4:
raise ValueError(f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.")
else:
raise ValueError(f"Expected target boxes to be of type Tensor, got {type(boxes)}.")
original_image_sizes: List[Tuple[int, int]] = []
for img in images:
val = img.shape[-2:]
assert len(val) == 2
original_image_sizes.append((val[0], val[1]))
# transform the input
images, targets = self.transform(images, targets)
# Check for degenerate boxes
if targets is not None:
for target_idx, target in enumerate(targets):
boxes = target["boxes"]
degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
if degenerate_boxes.any():
# print the first degenerate box
bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
degen_bb: List[float] = boxes[bb_idx].tolist()
raise ValueError(
"All bounding boxes should have positive height and width."
f" Found invalid box {degen_bb} for target at index {target_idx}."
)
# get the features from the backbone
features = self.backbone(images.tensors)
if isinstance(features, torch.Tensor):
features = OrderedDict([("0", features)])
features = list(features.values())
# compute the fcos heads outputs using the features
head_outputs = self.head(features)
# create the set of anchors
anchors = self.anchor_generator(images, features)
# recover level sizes
num_anchors_per_level = [x.size(2) * x.size(3) for x in features]
losses = {}
detections: List[Dict[str, Tensor]] = []
if self.training:
assert targets is not None
# compute the losses
losses = self.compute_loss(targets, head_outputs, anchors, num_anchors_per_level)
else:
# split outputs per level
split_head_outputs: Dict[str, List[Tensor]] = {}
for k in head_outputs:
split_head_outputs[k] = list(head_outputs[k].split(num_anchors_per_level, dim=1))
split_anchors = [list(a.split(num_anchors_per_level)) for a in anchors]
# compute the detections
detections = self.postprocess_detections(split_head_outputs, split_anchors, images.image_sizes)
detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)
if torch.jit.is_scripting():
if not self._has_warned:
warnings.warn("FCOS always returns a (Losses, Detections) tuple in scripting")
self._has_warned = True
return losses, detections
return self.eager_outputs(losses, detections)
model_urls = {
"fcos_resnet50_fpn_coco": "https://download.pytorch.org/models/fcos_resnet50_fpn_coco-99b0c9b7.pth",
}
def fcos_resnet50_fpn(
pretrained: bool = False,
progress: bool = True,
num_classes: int = 91,
pretrained_backbone: bool = True,
trainable_backbone_layers: Optional[int] = None,
**kwargs,
):
"""
Constructs a FCOS model with a ResNet-50-FPN backbone.
Reference: `"FCOS: Fully Convolutional One-Stage Object Detection" <https://arxiv.org/abs/1904.01355>`_.
The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
image, and should be in ``0-1`` range. Different images can have different sizes.
The behavior of the model changes depending on whether it is in training or evaluation mode.
During training, the model expects both the input tensors and the targets (a list of dictionaries),
containing:
- boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
- labels (``Int64Tensor[N]``): the class label for each ground-truth box
The model returns a ``Dict[Tensor]`` during training, containing the classification, regression
and centerness losses.
During inference, the model requires only the input tensors, and returns the post-processed
predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
follows, where ``N`` is the number of detections:
- boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
- labels (``Int64Tensor[N]``): the predicted labels for each detection
- scores (``Tensor[N]``): the scores of each detection
For more details on the output, you may refer to :ref:`instance_seg_output`.
Example:
>>> model = torchvision.models.detection.fcos_resnet50_fpn(pretrained=True)
>>> model.eval()
>>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
>>> predictions = model(x)
Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017
progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int): number of output classes of the model (including the background)
pretrained_backbone (bool): If True, returns a model with backbone pre-trained on ImageNet
trainable_backbone_layers (int, optional): number of trainable (not frozen) resnet layers starting
from final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
trainable. If ``None`` is passed (the default) this value is set to 3. Default: None
"""
trainable_backbone_layers = _validate_trainable_layers(
pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3
)
if pretrained:
# no need to download the backbone if pretrained is set
pretrained_backbone = False
backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
backbone = _resnet_fpn_extractor(
backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256)
)
model = FCOS(backbone, num_classes, **kwargs)
if pretrained:
state_dict = load_state_dict_from_url(model_urls["fcos_resnet50_fpn_coco"], progress=progress)
model.load_state_dict(state_dict)
return model
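To complement the inference example in the docstring, here is a minimal training-mode sketch with hypothetical toy tensors (images in the 0-1 range, boxes in ``[x1, y1, x2, y2]`` as documented above); the returned loss keys come from `FCOSHead.compute_loss`:

```python
import torch
import torchvision

model = torchvision.models.detection.fcos_resnet50_fpn(pretrained=False, num_classes=91)
model.train()

images = [torch.rand(3, 300, 400), torch.rand(3, 400, 400)]
targets = [
    {"boxes": torch.tensor([[10.0, 20.0, 100.0, 150.0]]), "labels": torch.tensor([1])},
    {"boxes": torch.tensor([[30.0, 30.0, 200.0, 250.0]]), "labels": torch.tensor([2])},
]

# dict with "classification", "bbox_regression" and "bbox_ctrness" losses
loss_dict = model(images, targets)
total_loss = sum(loss_dict.values())
```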
...@@ -13,6 +13,7 @@ from .boxes import box_convert
from .deform_conv import deform_conv2d, DeformConv2d
from .feature_pyramid_network import FeaturePyramidNetwork
from .focal_loss import sigmoid_focal_loss
from .generalized_box_iou_loss import generalized_box_iou_loss
from .misc import FrozenBatchNorm2d, ConvNormActivation, SqueezeExcitation
from .poolers import MultiScaleRoIAlign
from .ps_roi_align import ps_roi_align, PSRoIAlign
...@@ -52,4 +53,5 @@ __all__ = [
    "FrozenBatchNorm2d",
    "ConvNormActivation",
    "SqueezeExcitation",
    "generalized_box_iou_loss",
]
import torch
def generalized_box_iou_loss(
boxes1: torch.Tensor,
boxes2: torch.Tensor,
reduction: str = "none",
eps: float = 1e-7,
) -> torch.Tensor:
"""
Original implementation from
https://github.com/facebookresearch/fvcore/blob/bfff2ef/fvcore/nn/giou_loss.py
Gradient-friendly IoU loss with an additional penalty that is non-zero when the
boxes do not overlap and scales with the size of their smallest enclosing box.
This loss is symmetric, so the boxes1 and boxes2 arguments are interchangeable.
Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
``0 <= x1 < x2`` and ``0 <= y1 < y2``, and the two boxes should have the
same dimensions.
Args:
boxes1 (Tensor[N, 4] or Tensor[4]): first set of boxes
boxes2 (Tensor[N, 4] or Tensor[4]): second set of boxes
reduction (string, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: No reduction will be
applied to the output. ``'mean'``: The output will be averaged.
``'sum'``: The output will be summed. Default: ``'none'``
eps (float, optional): small number to prevent division by zero. Default: 1e-7
Reference:
Hamid Rezatofighi et al.: Generalized Intersection over Union:
A Metric and A Loss for Bounding Box Regression:
https://arxiv.org/abs/1902.09630
"""
x1, y1, x2, y2 = boxes1.unbind(dim=-1)
x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1)
assert (x2 >= x1).all(), "bad box: x1 larger than x2"
assert (y2 >= y1).all(), "bad box: y1 larger than y2"
# Intersection keypoints
xkis1 = torch.max(x1, x1g)
ykis1 = torch.max(y1, y1g)
xkis2 = torch.min(x2, x2g)
ykis2 = torch.min(y2, y2g)
intsctk = torch.zeros_like(x1)
mask = (ykis2 > ykis1) & (xkis2 > xkis1)
intsctk[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask])
unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsctk
iouk = intsctk / (unionk + eps)
# smallest enclosing box
xc1 = torch.min(x1, x1g)
yc1 = torch.min(y1, y1g)
xc2 = torch.max(x2, x2g)
yc2 = torch.max(y2, y2g)
area_c = (xc2 - xc1) * (yc2 - yc1)
miouk = iouk - ((area_c - unionk) / (area_c + eps))
loss = 1 - miouk
if reduction == "mean":
loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum()
elif reduction == "sum":
loss = loss.sum()
return loss
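As a quick sanity check on the formula above, a small worked example (the boxes and the quoted numbers are illustrative):

```python
import torch
from torchvision.ops import generalized_box_iou_loss

boxes1 = torch.tensor([[0.0, 0.0, 10.0, 10.0], [0.0, 0.0, 10.0, 10.0]])
boxes2 = torch.tensor([[0.0, 0.0, 10.0, 10.0], [5.0, 5.0, 15.0, 15.0]])

loss = generalized_box_iou_loss(boxes1, boxes2, reduction="none")
# First pair is identical: IoU = 1, no enclosing-box penalty, loss ~= 0.
# Second pair: IoU = 25 / 175 ~= 0.143, penalty (225 - 175) / 225 ~= 0.222, loss ~= 1.079.
```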
from .faster_rcnn import *
from .fcos import *
from .keypoint_rcnn import *
from .mask_rcnn import *
from .retinanet import *
...
from typing import Any, Optional
from torchvision.prototype.transforms import CocoEval
from torchvision.transforms.functional import InterpolationMode
from ....models.detection.fcos import (
_resnet_fpn_extractor,
_validate_trainable_layers,
FCOS,
LastLevelP6P7,
misc_nn_ops,
)
from .._api import WeightsEnum, Weights
from .._meta import _COCO_CATEGORIES
from .._utils import handle_legacy_interface, _ovewrite_value_param
from ..resnet import ResNet50_Weights, resnet50
__all__ = [
"FCOS",
"FCOS_ResNet50_FPN_Weights",
"fcos_resnet50_fpn",
]
class FCOS_ResNet50_FPN_Weights(WeightsEnum):
COCO_V1 = Weights(
url="https://download.pytorch.org/models/fcos_resnet50_fpn_coco-99b0c9b7.pth",
transforms=CocoEval,
meta={
"task": "image_object_detection",
"architecture": "FCOS",
"publication_year": 2019,
"num_params": 32269600,
"categories": _COCO_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#fcos-resnet-50-fpn",
"map": 39.2,
},
)
default = COCO_V1
@handle_legacy_interface(
weights=("pretrained", FCOS_ResNet50_FPN_Weights.COCO_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.ImageNet1K_V1),
)
def fcos_resnet50_fpn(
*,
weights: Optional[FCOS_ResNet50_FPN_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
weights_backbone: Optional[ResNet50_Weights] = None,
trainable_backbone_layers: Optional[int] = None,
**kwargs: Any,
) -> FCOS:
weights = FCOS_ResNet50_FPN_Weights.verify(weights)
weights_backbone = ResNet50_Weights.verify(weights_backbone)
if weights is not None:
weights_backbone = None
num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
elif num_classes is None:
num_classes = 91
trainable_backbone_layers = _validate_trainable_layers(
weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 3
)
backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
backbone = _resnet_fpn_extractor(
backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256)
)
model = FCOS(backbone, num_classes, **kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
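A minimal usage sketch of the prototype builder with the new multi-weight API, as it stands in this commit (the prototype namespace was still subject to change at the time):

```python
from torchvision.prototype.models.detection import fcos_resnet50_fpn, FCOS_ResNet50_FPN_Weights

# COCO_V1 is also aliased as FCOS_ResNet50_FPN_Weights.default
model = fcos_resnet50_fpn(weights=FCOS_ResNet50_FPN_Weights.COCO_V1)
model.eval()
```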