Unverified commit f76e598d authored by Francisco Massa, committed by GitHub

Improve documentation for detection models (#928)

* Add more documentation for the ops

* Add documentation for Faster R-CNN

* Add documentation for Mask R-CNN and Keypoint R-CNN

* Improve doc for RPN

* Add basic doc for GeneralizedRCNNTransform

* Lint fixes
parent a68db4fa
@@ -4,7 +4,6 @@ import torch
from torch import nn
class IntermediateLayerGetter(nn.ModuleDict):
"""
Module wrapper that returns intermediate layers from a model
@@ -12,7 +11,29 @@ class IntermediateLayerGetter(nn.ModuleDict):
It has a strong assumption that the modules have been registered
into the model in the same order as they are used.
This means that one should **not** reuse the same nn.Module
twice in the forward if you want this to work.
Additionally, it is only able to query submodules that are directly
assigned to the model. So if `model` is passed, `model.feature1` can
be returned, but not `model.feature1.layer2`.
Arguments:
model (nn.Module): model on which we will extract the features
return_layers (Dict[name, new_name]): a dict containing the names
of the modules for which the activations will be returned as
the key of the dict, and the value of the dict is the name
of the returned activation (which the user can specify).
Examples::
>>> m = torchvision.models.resnet18(pretrained=True)
>>> # extract layer1 and layer3, giving as names `feat1` and `feat2`
>>> new_m = torchvision.models._utils.IntermediateLayerGetter(m,
>>> {'layer1': 'feat1', 'layer3': 'feat2'})
>>> out = new_m(torch.rand(1, 3, 224, 224))
>>> print([(k, v.shape) for k, v in out.items()])
>>> [('feat1', torch.Size([1, 64, 56, 56])),
>>> ('feat2', torch.Size([1, 256, 14, 14]))]
""" """
def __init__(self, model, return_layers): def __init__(self, model, return_layers):
if not set(return_layers).issubset([name for name, _ in model.named_children()]): if not set(return_layers).issubset([name for name, _ in model.named_children()]):
......
@@ -8,6 +8,26 @@ from .. import resnet
class BackboneWithFPN(nn.Sequential):
"""
Adds an FPN on top of a model.
Internally, it uses torchvision.models._utils.IntermediateLayerGetter to
extract a submodel that returns the feature maps specified in return_layers.
The same limitations of IntermediateLayerGetter apply here.
Arguments:
backbone (nn.Module)
return_layers (Dict[name, new_name]): a dict containing the names
of the modules for which the activations will be returned as
the key of the dict, and the value of the dict is the name
of the returned activation (which the user can specify).
in_channels_list (List[int]): number of channels for each feature map
that is returned, in the order they are present in the OrderedDict
out_channels (int): number of channels in the FPN.
Attributes:
out_channels (int): the number of channels in the FPN
"""
def __init__(self, backbone, return_layers, in_channels_list, out_channels):
body = IntermediateLayerGetter(backbone, return_layers=return_layers)
fpn = FeaturePyramidNetwork(
...
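A hedged usage sketch of this wrapper may help; the return-layer names below are arbitrary, and the channel counts are the ones a standard ResNet-50 exposes for layer1..layer4 (an assumption about the chosen backbone, not something this class checks):

import torch
from torchvision.models import resnet
from torchvision.models.detection.backbone_utils import BackboneWithFPN

backbone = resnet.resnet50(pretrained=False)
# the four residual stages and the names given to their outputs
return_layers = {'layer1': 'feat0', 'layer2': 'feat1',
                 'layer3': 'feat2', 'layer4': 'feat3'}
# output channels of layer1..layer4 in ResNet-50
in_channels_list = [256, 512, 1024, 2048]
model = BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels=256)
out = model(torch.rand(1, 3, 224, 224))
# out is an OrderedDict of feature maps, each with out_channels=256 channels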
@@ -22,6 +22,88 @@ __all__ = [
class FasterRCNN(GeneralizedRCNN):
"""
Implements Faster R-CNN.
The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
image, and should be in 0-1 range. Different images can have different sizes.
The behavior of the model changes depending on whether it is in training or evaluation mode.
During training, the model expects both the input tensors, as well as a targets dictionary,
containing:
boxes (Tensor[N, 4]): the ground-truth boxes in [x0, y0, x1, y1] format, with values
between 0 and H and 0 and W
labels (Tensor[N]): the class label for each ground-truth box
The model returns a Dict[Tensor] during training, containing the classification and regression
losses for both the RPN and the R-CNN.
During inference, the model requires only the input tensors, and returns the post-processed
predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
follows:
boxes (Tensor[N, 4]): the predicted boxes in [x0, y0, x1, y1] format, with values between
0 and H and 0 and W
labels (Tensor[N]): the predicted labels for each image
scores (Tensor[N]): the scores of each prediction
Arguments:
backbone (nn.Module): the network used to compute the features for the model.
It should contain an out_channels attribute, which indicates the number of output
channels that each feature map has (and it should be the same for all feature maps).
The backbone should return a single Tensor or an OrderedDict[Tensor].
num_classes (int): number of output classes of the model (including the background).
If box_predictor is specified, num_classes should be None.
min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
image_mean (Tuple[float, float, float]): mean values used for input normalization.
They are generally the mean values of the dataset on which the backbone has been
trained
image_std (Tuple[float, float, float]): std values used for input normalization.
They are generally the std values of the dataset on which the backbone has been trained
rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
maps.
rpn_head (nn.Module): module that computes the objectness and regression deltas from the RPN
rpn_pre_nms_top_n_train (int): number of proposals to keep before applying NMS during training
rpn_pre_nms_top_n_test (int): number of proposals to keep before applying NMS during testing
rpn_post_nms_top_n_train (int): number of proposals to keep after applying NMS during training
rpn_post_nms_top_n_test (int): number of proposals to keep after applying NMS during testing
rpn_nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
rpn_fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
considered as positive during training of the RPN.
rpn_bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
considered as negative during training of the RPN.
rpn_batch_size_per_image (int): number of anchors that are sampled during training of the RPN
for computing the loss
rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
of the RPN
box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
the locations indicated by the bounding boxes
box_head (nn.Module): module that takes the cropped feature maps as input
box_predictor (nn.Module): module that takes the output of box_head and returns the
classification logits and box regression deltas.
box_score_thresh (float): during inference, only return proposals with a classification score
greater than box_score_thresh
box_nms_thresh (float): NMS threshold for the prediction head. Used during inference
box_detections_per_img (int): maximum number of detections per image, for all classes.
box_fg_iou_thresh (float): minimum IoU between the proposals and the GT box so that they can be
considered as positive during training of the classification head
box_bg_iou_thresh (float): maximum IoU between the proposals and the GT box so that they can be
considered as negative during training of the classification head
box_batch_size_per_image (int): number of proposals that are sampled during training of the
classification head
box_positive_fraction (float): proportion of positive proposals in a mini-batch during training
of the classification head
bbox_reg_weights (Tuple[float, float, float, float]): weights for the encoding/decoding of the
bounding boxes
Example::
>>> model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
>>> model.eval()
>>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
>>> predictions = model(x)
"""
def __init__(self, backbone, num_classes=None,
# transform parameters
min_size=800, max_size=1333,
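The class example above only exercises the pre-trained builder; a hedged sketch of wiring FasterRCNN around a custom backbone follows, to ground the backbone / rpn_anchor_generator / box_roi_pool arguments. The mobilenet_v2 channel count (1280) and the featmap_names=[0] key for a single-Tensor backbone are assumptions, not guarantees of this API:

import torch
import torchvision
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator

# a backbone returning a single feature map; FasterRCNN reads its
# channel count from the out_channels attribute (see Arguments above)
backbone = torchvision.models.mobilenet_v2(pretrained=True).features
backbone.out_channels = 1280

# one feature map -> one (sizes, aspect_ratios) pair: 5 * 3 anchors per location
anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
                                   aspect_ratios=((0.5, 1.0, 2.0),))

# pool the single feature map, assumed here to be registered under the key 0
roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=[0],
                                                output_size=7,
                                                sampling_ratio=2)

model = FasterRCNN(backbone, num_classes=2,
                   rpn_anchor_generator=anchor_generator,
                   box_roi_pool=roi_pooler)
model.eval()
predictions = model([torch.rand(3, 300, 400)])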
@@ -117,7 +199,11 @@ class FasterRCNN(GeneralizedRCNN):
class TwoMLPHead(nn.Module):
"""
Standard heads for FPN-based models
Arguments:
in_channels (int): number of input channels
representation_size (int): size of the intermediate representation
""" """
def __init__(self, in_channels, representation_size): def __init__(self, in_channels, representation_size):
...@@ -136,6 +222,15 @@ class TwoMLPHead(nn.Module): ...@@ -136,6 +222,15 @@ class TwoMLPHead(nn.Module):
class FastRCNNPredictor(nn.Module): class FastRCNNPredictor(nn.Module):
"""
Standard classification + bounding box regression layers
for Fast R-CNN.
Arguments:
in_channels (int): number of input channels
num_classes (int): number of output classes (including background)
"""
def __init__(self, in_channels, num_classes):
super(FastRCNNPredictor, self).__init__()
self.cls_score = nn.Linear(in_channels, num_classes)
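A common use of this predictor is swapping it into the pre-trained model to fine-tune on a different number of classes. A minimal sketch; the model.roi_heads.box_predictor attribute path is an assumption based on how FasterRCNN assembles its heads:

import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
# read the input size of the existing predictor, then replace it
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes=3)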
@@ -159,6 +254,13 @@ model_urls = {
def fasterrcnn_resnet50_fpn(pretrained=False, progress=True,
num_classes=91, pretrained_backbone=True, **kwargs):
"""
Constructs a Faster R-CNN model with a ResNet-50-FPN backbone.
Arguments:
pretrained (bool): If True, returns a model pre-trained on COCO train2017
progress (bool): If True, displays a progress bar of the download to stderr
"""
if pretrained:
# no need to download the backbone if pretrained is set
pretrained_backbone = False
...
@@ -16,6 +16,95 @@ __all__ = [
class KeypointRCNN(FasterRCNN):
"""
Implements Keypoint R-CNN.
The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
image, and should be in 0-1 range. Different images can have different sizes.
The behavior of the model changes depending on whether it is in training or evaluation mode.
During training, the model expects both the input tensors, as well as a targets dictionary,
containing:
boxes (Tensor[N, 4]): the ground-truth boxes in [x0, y0, x1, y1] format, with values
between 0 and H and 0 and W
labels (Tensor[N]): the class label for each ground-truth box
keypoints (Tensor[N, K, 3]): the K keypoints location for each of the N instances, in the
format [x, y, visibility], where visibility=0 means that the keypoint is not visible.
The model returns a Dict[Tensor] during training, containing the classification and regression
losses for both the RPN and the R-CNN, and the keypoint loss.
During inference, the model requires only the input tensors, and returns the post-processed
predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
follows:
boxes (Tensor[N, 4]): the predicted boxes in [x0, y0, x1, y1] format, with values between
0 and H and 0 and W
labels (Tensor[N]): the predicted labels for each image
scores (Tensor[N]): the scores of each prediction
keypoints (Tensor[N, K, 3]): the locations of the predicted keypoints, in [x, y, v] format.
Arguments:
backbone (nn.Module): the network used to compute the features for the model.
It should contain an out_channels attribute, which indicates the number of output
channels that each feature map has (and it should be the same for all feature maps).
The backbone should return a single Tensor or an OrderedDict[Tensor].
num_classes (int): number of output classes of the model (including the background).
If box_predictor is specified, num_classes should be None.
min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
image_mean (Tuple[float, float, float]): mean values used for input normalization.
They are generally the mean values of the dataset on which the backbone has been
trained
image_std (Tuple[float, float, float]): std values used for input normalization.
They are generally the std values of the dataset on which the backbone has been trained
rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
maps.
rpn_head (nn.Module): module that computes the objectness and regression deltas from the RPN
rpn_pre_nms_top_n_train (int): number of proposals to keep before applying NMS during training
rpn_pre_nms_top_n_test (int): number of proposals to keep before applying NMS during testing
rpn_post_nms_top_n_train (int): number of proposals to keep after applying NMS during training
rpn_post_nms_top_n_test (int): number of proposals to keep after applying NMS during testing
rpn_nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
rpn_fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
considered as positive during training of the RPN.
rpn_bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
considered as negative during training of the RPN.
rpn_batch_size_per_image (int): number of anchors that are sampled during training of the RPN
for computing the loss
rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
of the RPN
box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
the locations indicated by the bounding boxes
box_head (nn.Module): module that takes the cropped feature maps as input
box_predictor (nn.Module): module that takes the output of box_head and returns the
classification logits and box regression deltas.
box_score_thresh (float): during inference, only return proposals with a classification score
greater than box_score_thresh
box_nms_thresh (float): NMS threshold for the prediction head. Used during inference
box_detections_per_img (int): maximum number of detections per image, for all classes.
box_fg_iou_thresh (float): minimum IoU between the proposals and the GT box so that they can be
considered as positive during training of the classification head
box_bg_iou_thresh (float): maximum IoU between the proposals and the GT box so that they can be
considered as negative during training of the classification head
box_batch_size_per_image (int): number of proposals that are sampled during training of the
classification head
box_positive_fraction (float): proportion of positive proposals in a mini-batch during training
of the classification head
bbox_reg_weights (Tuple[float, float, float, float]): weights for the encoding/decoding of the
bounding boxes
keypoint_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
the locations indicated by the bounding boxes, which will be used for the keypoint head.
keypoint_head (nn.Module): module that takes the cropped feature maps as input
keypoint_predictor (nn.Module): module that takes the output of the keypoint_head and returns the
heatmap logits
Example::
>>> model = torchvision.models.detection.keypointrcnn_resnet50_fpn(pretrained=True)
>>> model.eval()
>>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
>>> predictions = model(x)
"""
def __init__(self, backbone, num_classes=None,
# transform parameters
min_size=None, max_size=1333,
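To make the training-mode contract above concrete, here is a hedged sketch of a forward pass with targets; the field names follow the list above, and the tensor values are dummies:

import torch
import torchvision

model = torchvision.models.detection.keypointrcnn_resnet50_fpn(pretrained=True)
model.train()
images = [torch.rand(3, 300, 400)]
targets = [{
    'boxes': torch.tensor([[50., 50., 200., 250.]]),  # [x0, y0, x1, y1]
    'labels': torch.tensor([1]),  # the single non-background class
    # 17 keypoints in [x, y, visibility] format, all marked visible
    'keypoints': torch.cat([torch.rand(1, 17, 2) * 200,
                            torch.ones(1, 17, 1)], dim=2),
}]
loss_dict = model(images, targets)  # Dict[Tensor] with RPN, box and keypoint losses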
@@ -136,6 +225,13 @@ model_urls = {
def keypointrcnn_resnet50_fpn(pretrained=False, progress=True,
num_classes=2, num_keypoints=17,
pretrained_backbone=True, **kwargs):
"""
Constructs a Keypoint R-CNN model with a ResNet-50-FPN backbone.
Arguments:
pretrained (bool): If True, returns a model pre-trained on COCO train2017
progress (bool): If True, displays a progress bar of the download to stderr
"""
if pretrained:
# no need to download the backbone if pretrained is set
pretrained_backbone = False
...
@@ -18,6 +18,96 @@ __all__ = [
class MaskRCNN(FasterRCNN):
"""
Implements Mask R-CNN.
The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
image, and should be in 0-1 range. Different images can have different sizes.
The behavior of the model changes depending on whether it is in training or evaluation mode.
During training, the model expects both the input tensors, as well as a targets dictionary,
containing:
boxes (Tensor[N, 4]): the ground-truth boxes in [x0, y0, x1, y1] format, with values
between 0 and H and 0 and W
labels (Tensor[N]): the class label for each ground-truth box
masks (Tensor[N, H, W]): the segmentation binary masks for each instance
The model returns a Dict[Tensor] during training, containing the classification and regression
losses for both the RPN and the R-CNN, and the mask loss.
During inference, the model requires only the input tensors, and returns the post-processed
predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
follows:
boxes (Tensor[N, 4]): the predicted boxes in [x0, y0, x1, y1] format, with values between
0 and H and 0 and W
labels (Tensor[N]): the predicted labels for each image
scores (Tensor[N]): the scores of each prediction
masks (Tensor[N, H, W]): the predicted masks for each instance, in 0-1 range. In order to
obtain the final segmentation masks, the soft masks can be thresholded, generally
with a value of 0.5 (mask >= 0.5)
Arguments:
backbone (nn.Module): the network used to compute the features for the model.
It should contain an out_channels attribute, which indicates the number of output
channels that each feature map has (and it should be the same for all feature maps).
The backbone should return a single Tensor or an OrderedDict[Tensor].
num_classes (int): number of output classes of the model (including the background).
If box_predictor is specified, num_classes should be None.
min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
image_mean (Tuple[float, float, float]): mean values used for input normalization.
They are generally the mean values of the dataset on which the backbone has been
trained
image_std (Tuple[float, float, float]): std values used for input normalization.
They are generally the std values of the dataset on which the backbone has been trained
rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
maps.
rpn_head (nn.Module): module that computes the objectness and regression deltas from the RPN
rpn_pre_nms_top_n_train (int): number of proposals to keep before applying NMS during training
rpn_pre_nms_top_n_test (int): number of proposals to keep before applying NMS during testing
rpn_post_nms_top_n_train (int): number of proposals to keep after applying NMS during training
rpn_post_nms_top_n_test (int): number of proposals to keep after applying NMS during testing
rpn_nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
rpn_fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
considered as positive during training of the RPN.
rpn_bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
considered as negative during training of the RPN.
rpn_batch_size_per_image (int): number of anchors that are sampled during training of the RPN
for computing the loss
rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
of the RPN
box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
the locations indicated by the bounding boxes
box_head (nn.Module): module that takes the cropped feature maps as input
box_predictor (nn.Module): module that takes the output of box_head and returns the
classification logits and box regression deltas.
box_score_thresh (float): during inference, only return proposals with a classification score
greater than box_score_thresh
box_nms_thresh (float): NMS threshold for the prediction head. Used during inference
box_detections_per_img (int): maximum number of detections per image, for all classes.
box_fg_iou_thresh (float): minimum IoU between the proposals and the GT box so that they can be
considered as positive during training of the classification head
box_bg_iou_thresh (float): maximum IoU between the proposals and the GT box so that they can be
considered as negative during training of the classification head
box_batch_size_per_image (int): number of proposals that are sampled during training of the
classification head
box_positive_fraction (float): proportion of positive proposals in a mini-batch during training
of the classification head
bbox_reg_weights (Tuple[float, float, float, float]): weights for the encoding/decoding of the
bounding boxes
mask_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
the locations indicated by the bounding boxes, which will be used for the mask head.
mask_head (nn.Module): module that takes the cropped feature maps as input
mask_predictor (nn.Module): module that takes the output of the mask_head and returns the
segmentation mask logits
Example::
>>> model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
>>> model.eval()
>>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
>>> predictions = model(x)
"""
def __init__(self, backbone, num_classes=None,
# transform parameters
min_size=800, max_size=1333,
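To spell out the thresholding step mentioned above, a short post-processing sketch; it assumes the soft masks come back under a 'masks' key, matching the field documented in the returned Dict:

import torch
import torchvision

model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
model.eval()
predictions = model([torch.rand(3, 300, 400)])
soft_masks = predictions[0]['masks']  # values in the 0-1 range
binary_masks = soft_masks >= 0.5      # final segmentation masks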
@@ -133,6 +223,13 @@ model_urls = {
def maskrcnn_resnet50_fpn(pretrained=False, progress=True,
num_classes=91, pretrained_backbone=True, **kwargs):
"""
Constructs a Mask R-CNN model with a ResNet-50-FPN backbone.
Arguments:
pretrained (bool): If True, returns a model pre-trained on COCO train2017
progress (bool): If True, displays a progress bar of the download to stderr
"""
if pretrained:
# no need to download the backbone if pretrained is set
pretrained_backbone = False
...
@@ -10,8 +10,22 @@ from . import _utils as det_utils
class AnchorGenerator(nn.Module):
"""
Module that generates anchors for a set of feature maps and
image sizes.
The module supports computing anchors at multiple sizes and aspect ratios
per feature map.
sizes and aspect_ratios should have the same number of elements, which should
match the number of feature maps.
sizes[i] and aspect_ratios[i] can have an arbitrary number of elements,
and AnchorGenerator will output a set of sizes[i] * aspect_ratios[i] anchors
per spatial location for feature map i.
Arguments:
sizes (Tuple[Tuple[int]]):
aspect_ratios (Tuple[Tuple[float]]):
""" """
def __init__( def __init__(
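A minimal construction sketch of the sizes / aspect_ratios contract described above, assuming an FPN-style setup with three feature maps:

# three feature maps: one size and three aspect ratios each,
# so 1 * 3 anchors per spatial location on every map
# (uses the AnchorGenerator class defined above)
anchor_generator = AnchorGenerator(
    sizes=((128,), (256,), (512,)),
    aspect_ratios=((0.5, 1.0, 2.0),) * 3,
)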
@@ -115,14 +129,13 @@ class AnchorGenerator(nn.Module):
class RPNHead(nn.Module):
"""
Adds a simple RPN Head with classification and regression heads
Arguments:
in_channels (int): number of channels of the input feature
num_anchors (int): number of anchors to be predicted
"""
def __init__(self, in_channels, num_anchors):
super(RPNHead, self).__init__()
self.conv = nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=1, padding=1
@@ -185,6 +198,30 @@ def concat_box_prediction_layers(box_cls, box_regression):
class RegionProposalNetwork(torch.nn.Module):
"""
Implements Region Proposal Network (RPN).
Arguments:
anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
maps.
head (nn.Module): module that computes the objectness and regression deltas
fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
considered as positive during training of the RPN.
bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
considered as negative during training of the RPN.
batch_size_per_image (int): number of anchors that are sampled during training of the RPN
for computing the loss
positive_fraction (float): proportion of positive anchors in a mini-batch during training
of the RPN
pre_nms_top_n (Dict[int]): number of proposals to keep before applying NMS. It should
contain two fields: training and testing, to allow for different values depending
on training or evaluation
post_nms_top_n (Dict[int]): number of proposals to keep after applying NMS. It should
contain two fields: training and testing, to allow for different values depending
on training or evaluation
nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
"""
def __init__(self,
anchor_generator,
@@ -194,9 +231,6 @@ class RegionProposalNetwork(torch.nn.Module):
batch_size_per_image, positive_fraction,
#
pre_nms_top_n, post_nms_top_n, nms_thresh):
super(RegionProposalNetwork, self).__init__()
self.anchor_generator = anchor_generator
self.head = head
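As a hedged illustration of the pre_nms_top_n / post_nms_top_n dicts described above, a construction sketch follows; anchor_generator and rpn_head are assumed to have been built as in the classes earlier in this diff, and the numbers mirror the Faster R-CNN defaults (an assumption, since the default values are truncated out of this diff):

# uses the RegionProposalNetwork class defined above
rpn = RegionProposalNetwork(
    anchor_generator, rpn_head,
    fg_iou_thresh=0.7, bg_iou_thresh=0.3,
    batch_size_per_image=256, positive_fraction=0.5,
    pre_nms_top_n=dict(training=2000, testing=1000),
    post_nms_top_n=dict(training=2000, testing=1000),
    nms_thresh=0.7,
)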
@@ -310,10 +344,10 @@ class RegionProposalNetwork(torch.nn.Module):
def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets):
"""
Arguments:
objectness (Tensor)
pred_bbox_deltas (Tensor)
labels (List[Tensor])
regression_targets (List[Tensor])
Returns:
objectness_loss (Tensor)
@@ -347,15 +381,17 @@ class RegionProposalNetwork(torch.nn.Module):
""" """
Arguments: Arguments:
images (ImageList): images for which we want to compute the predictions images (ImageList): images for which we want to compute the predictions
features (list[Tensor]): features computed from the images that are features (List[Tensor]): features computed from the images that are
used for computing the predictions. Each tensor in the list used for computing the predictions. Each tensor in the list
correspond to different feature levels correspond to different feature levels
targets (list[BoxList): ground-truth boxes present in the image (optional) targets (List[Dict[Tensor]): ground-truth boxes present in the image (optional).
If provided, each element in the dict should contain a field `boxes`,
with the locations of the ground-truth boxes.
Returns: Returns:
boxes (list[BoxList]): the predicted boxes from the RPN, one BoxList per boxes (List[Tensor]): the predicted boxes from the RPN, one Tensor per
image. image.
losses (dict[Tensor]): the losses for the model during training. During losses (Dict[Tensor]): the losses for the model during training. During
testing, it is an empty dict. testing, it is an empty dict.
""" """
# RPN uses all feature maps that are available
...
@@ -9,6 +9,17 @@ from .roi_heads import paste_masks_in_image
class GeneralizedRCNNTransform(nn.Module):
"""
Performs input / target transformation before feeding the data to a GeneralizedRCNN
model.
The transformations it performs are:
- input normalization (mean subtraction and std division)
- input / target resizing to match min_size / max_size
It returns an ImageList for the inputs, and a List[Dict[Tensor]] for the targets
"""
def __init__(self, min_size, max_size, image_mean, image_std):
super(GeneralizedRCNNTransform, self).__init__()
if not isinstance(min_size, (list, tuple)):
...
@@ -59,6 +59,17 @@ def batched_nms(boxes, scores, idxs, iou_threshold):
def remove_small_boxes(boxes, min_size):
"""
Remove boxes which contain at least one side smaller than min_size.
Arguments:
boxes (Tensor[N, 4]): boxes in [x0, y0, x1, y1] format
min_size (int): minimum size
Returns:
keep (Tensor[K]): indices of the boxes that have both sides
larger than min_size
"""
ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]
keep = (ws >= min_size) & (hs >= min_size)
keep = keep.nonzero().squeeze(1)
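A worked example of the filter above:

import torch

boxes = torch.tensor([[0., 0., 10., 10.],
                      [0., 0., 2., 10.]])
keep = remove_small_boxes(boxes, min_size=3)
# keep == tensor([0]): the second box is only 2 pixels wide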
@@ -67,9 +78,11 @@ def remove_small_boxes(boxes, min_size):
def clip_boxes_to_image(boxes, size):
"""
Clip boxes so that they lie inside an image of size `size`.
Arguments:
boxes (Tensor[N, 4]): boxes in [x0, y0, x1, y1] format
size (Tuple[height, width]): size of the image
Returns:
clipped_boxes (Tensor[N, 4])
...
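And a worked example of the clipping above, with size given as (height, width):

import torch

boxes = torch.tensor([[-5., 10., 120., 150.]])
clipped = clip_boxes_to_image(boxes, size=(100, 200))
# x is clamped to [0, 200] and y to [0, 100]:
# clipped == tensor([[0., 10., 120., 100.]])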
@@ -7,24 +7,43 @@ from torch import nn
class FeaturePyramidNetwork(nn.Module):
"""
Module that adds an FPN on top of a set of feature maps. This is based on
`"Feature Pyramid Networks for Object Detection" <https://arxiv.org/abs/1612.03144>`_.
The feature maps are currently supposed to be in increasing depth
order.
The input to the model is expected to be an OrderedDict[Tensor], containing
the feature maps on top of which the FPN will be added.
Arguments:
in_channels_list (list[int]): number of channels for each feature map that
is passed to the module
out_channels (int): number of channels of the FPN representation
extra_blocks (ExtraFPNBlock or None): if provided, extra operations will
be performed. It is expected to take the fpn features, the original
features and the names of the original features as input, and return
a new list of feature maps and their corresponding names
Examples::
>>> m = torchvision.ops.FeaturePyramidNetwork([10, 20, 30], 5)
>>> # get some dummy data
>>> x = OrderedDict()
>>> x['feat0'] = torch.rand(1, 10, 64, 64)
>>> x['feat2'] = torch.rand(1, 20, 16, 16)
>>> x['feat3'] = torch.rand(1, 30, 8, 8)
>>> # compute the FPN on top of x
>>> output = m(x)
>>> print([(k, v.shape) for k, v in output.items()])
>>> # returns
>>> [('feat0', torch.Size([1, 5, 64, 64])),
>>> ('feat2', torch.Size([1, 5, 16, 16])),
>>> ('feat3', torch.Size([1, 5, 8, 8]))]
""" """
def __init__(self, in_channels_list, out_channels, extra_blocks=None):
super(FeaturePyramidNetwork, self).__init__() super(FeaturePyramidNetwork, self).__init__()
self.inner_blocks = nn.ModuleList() self.inner_blocks = nn.ModuleList()
self.layer_blocks = nn.ModuleList() self.layer_blocks = nn.ModuleList()
@@ -48,8 +67,11 @@ class FeaturePyramidNetwork(nn.Module):
def forward(self, x):
"""
Computes the FPN for a set of feature maps.
Arguments:
x (OrderedDict[Tensor]): feature maps for each feature level.
Returns:
results (OrderedDict[Tensor]): feature maps after FPN layers.
They are ordered from highest resolution first.
@@ -82,11 +104,28 @@ class FeaturePyramidNetwork(nn.Module):
class ExtraFPNBlock(nn.Module):
"""
Base class for the extra block in the FPN.
Arguments:
results (List[Tensor]): the result of the FPN
x (List[Tensor]): the original feature maps
names (List[str]): the names for each one of the
original feature maps
Returns:
results (List[Tensor]): the extended set of results
of the FPN
names (List[str]): the extended set of names for the results
"""
def forward(self, results, x, names):
pass
class LastLevelMaxPool(ExtraFPNBlock):
"""
Applies a max_pool2d on top of the last feature map
"""
def forward(self, x, y, names):
names.append("pool")
x.append(F.max_pool2d(x[-1], 1, 2, 0))
...
@@ -28,6 +28,11 @@ class _NewEmptyTensorOp(torch.autograd.Function):
class Conv2d(torch.nn.Conv2d):
"""
Equivalent to nn.Conv2d, but with support for empty batch sizes.
This will eventually be supported natively by PyTorch, and this
class can go away.
"""
def forward(self, x):
if x.numel() > 0:
return super(Conv2d, self).forward(x)
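A quick demonstration of the empty-batch support described above:

import torch

# uses the Conv2d wrapper defined above
conv = Conv2d(16, 33, kernel_size=3)
out = conv(torch.zeros(0, 16, 10, 10))  # a batch with zero images
print(out.shape)  # torch.Size([0, 33, 8, 8]); the shape is computed, no error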
@@ -44,6 +49,11 @@ class Conv2d(torch.nn.Conv2d):
class ConvTranspose2d(torch.nn.ConvTranspose2d):
"""
Equivalent to nn.ConvTranspose2d, but with support for empty batch sizes.
This will eventually be supported natively by PyTorch, and this
class can go away.
"""
def forward(self, x):
if x.numel() > 0:
return super(ConvTranspose2d, self).forward(x)
@@ -65,6 +75,11 @@ class ConvTranspose2d(torch.nn.ConvTranspose2d):
class BatchNorm2d(torch.nn.BatchNorm2d):
"""
Equivalent to nn.BatchNorm2d, but with support for empty batch sizes.
This will eventually be supported natively by PyTorch, and this
class can go away.
"""
def forward(self, x):
if x.numel() > 0:
return super(BatchNorm2d, self).forward(x)
@@ -76,6 +91,11 @@ class BatchNorm2d(torch.nn.BatchNorm2d):
def interpolate(
input, size=None, scale_factor=None, mode="nearest", align_corners=None
):
"""
Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
This will eventually be supported natively by PyTorch, and this
function can go away.
"""
if input.numel() > 0:
return torch.nn.functional.interpolate(
input, size, scale_factor, mode, align_corners
...
@@ -4,17 +4,13 @@ import torch.nn.functional as F
from torch import nn
from torchvision.ops import roi_align
from torchvision.ops.boxes import box_area
class LevelMapper(object):
"""Determine which FPN level each RoI in a set of RoIs should map to based
on the heuristic in the FPN paper.
Arguments:
k_min (int)
k_max (int)
@@ -22,6 +18,8 @@ class LevelMapper(object):
canonical_level (int)
eps (float)
"""
def __init__(self, k_min, k_max, canonical_scale=224, canonical_level=4, eps=1e-6):
self.k_min = k_min
self.k_max = k_max
self.s0 = canonical_scale
@@ -44,21 +42,34 @@ class LevelMapper(object):
class MultiScaleRoIAlign(nn.Module):
"""
Multi-scale RoIAlign pooling, which is useful for detection with or without FPN.
It infers the scale of the pooling via the heuristics present in the FPN paper.
Arguments:
featmap_names (List[str]): the names of the feature maps that will be used
for the pooling.
output_size (List[Tuple[int, int]] or List[int]): output size for the pooled region
sampling_ratio (int): sampling ratio for ROIAlign
Examples::
>>> m = torchvision.ops.MultiScaleRoIAlign(['feat1', 'feat3'], 3, 2)
>>> i = OrderedDict()
>>> i['feat1'] = torch.rand(1, 5, 64, 64)
>>> i['feat2'] = torch.rand(1, 5, 32, 32) # this feature won't be used in the pooling
>>> i['feat3'] = torch.rand(1, 5, 16, 16)
>>> # create some random bounding boxes
>>> boxes = torch.rand(6, 4) * 256; boxes[:, 2:] += boxes[:, :2]
>>> # original image size, before computing the feature maps
>>> image_sizes = [(512, 512)]
>>> output = m(i, [boxes], image_sizes)
>>> print(output.shape)
>>> torch.Size([6, 5, 3, 3])
""" """
def __init__(self, featmap_names, output_size, sampling_ratio):
super(MultiScaleRoIAlign, self).__init__() super(MultiScaleRoIAlign, self).__init__()
if isinstance(output_size, int): if isinstance(output_size, int):
output_size = (output_size, output_size) output_size = (output_size, output_size)
@@ -105,8 +116,14 @@ class MultiScaleRoIAlign(nn.Module):
def forward(self, x, boxes, image_shapes):
"""
Arguments:
x (OrderedDict[Tensor]): feature maps for each level. They are all assumed to
have the same number of channels, but they can have different sizes.
boxes (List[Tensor[N, 4]]): boxes to be used to perform the pooling operation, in
[x0, y0, x1, y1] format and in the image reference size, not the feature map
reference.
image_shapes (List[Tuple[height, width]]): the sizes of each image before they
have been fed to a CNN to obtain feature maps. This allows us to infer the
scale factor for each one of the levels to be pooled.
Returns:
result (Tensor)
"""
...