Unverified Commit 54a4550b authored by Aditya Oke's avatar Aditya Oke Committed by GitHub
Browse files

Add typing annotations to detection/rpn (#4619)



* Annotate rpn

* Small fix

* Small fix and ignore
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent cb646f0f
...@@ -29,10 +29,6 @@ ignore_errors = True ...@@ -29,10 +29,6 @@ ignore_errors = True
ignore_errors = True ignore_errors = True
[mypy-torchvision.models.detection.rpn]
ignore_errors = True
[mypy-torchvision.models.detection.roi_heads] [mypy-torchvision.models.detection.roi_heads]
ignore_errors = True ignore_errors = True
......
from typing import List, Optional, Dict, Tuple from typing import List, Optional, Dict, Tuple, cast
import torch import torch
import torchvision import torchvision
...@@ -14,14 +14,14 @@ from .image_list import ImageList ...@@ -14,14 +14,14 @@ from .image_list import ImageList
@torch.jit.unused @torch.jit.unused
def _onnx_get_num_anchors_and_pre_nms_top_n(ob, orig_pre_nms_top_n): def _onnx_get_num_anchors_and_pre_nms_top_n(ob: Tensor, orig_pre_nms_top_n: int) -> Tuple[int, int]:
# type: (Tensor, int) -> Tuple[int, int]
from torch.onnx import operators from torch.onnx import operators
num_anchors = operators.shape_as_tensor(ob)[1].unsqueeze(0) num_anchors = operators.shape_as_tensor(ob)[1].unsqueeze(0)
pre_nms_top_n = torch.min(torch.cat((torch.tensor([orig_pre_nms_top_n], dtype=num_anchors.dtype), num_anchors), 0)) pre_nms_top_n = torch.min(torch.cat((torch.tensor([orig_pre_nms_top_n], dtype=num_anchors.dtype), num_anchors), 0))
return num_anchors, pre_nms_top_n # for mypy we cast at runtime
return cast(int, num_anchors), cast(int, pre_nms_top_n)
class RPNHead(nn.Module): class RPNHead(nn.Module):
...@@ -33,18 +33,17 @@ class RPNHead(nn.Module): ...@@ -33,18 +33,17 @@ class RPNHead(nn.Module):
num_anchors (int): number of anchors to be predicted num_anchors (int): number of anchors to be predicted
""" """
def __init__(self, in_channels: int, num_anchors: int) -> None:
    """Build the shared 3x3 conv and the two 1x1 prediction heads.

    Args:
        in_channels (int): number of channels of the input feature maps.
        num_anchors (int): number of anchors predicted per spatial location.
    """
    super(RPNHead, self).__init__()
    # Shared intermediate representation applied to every feature level.
    self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
    # Per-anchor objectness logits, and 4 box-regression deltas per anchor.
    self.cls_logits = nn.Conv2d(in_channels, num_anchors, kernel_size=1, stride=1)
    self.bbox_pred = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=1, stride=1)

    for module in self.children():
        torch.nn.init.normal_(module.weight, std=0.01)  # type: ignore[arg-type]
        torch.nn.init.constant_(module.bias, 0)  # type: ignore[arg-type]
def forward(self, x): def forward(self, x: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
# type: (List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
logits = [] logits = []
bbox_reg = [] bbox_reg = []
for feature in x: for feature in x:
...@@ -54,16 +53,14 @@ class RPNHead(nn.Module): ...@@ -54,16 +53,14 @@ class RPNHead(nn.Module):
return logits, bbox_reg return logits, bbox_reg
def permute_and_flatten(layer: Tensor, N: int, A: int, C: int, H: int, W: int) -> Tensor:
    """Reorder a head output of shape (N, A*C, H, W) into (N, H*W*A, C).

    Spatial positions and anchors are flattened into a single dimension so
    predictions from different feature levels can later be concatenated.

    Args:
        layer (Tensor): raw head output of shape (N, A*C, H, W).
        N, A, C, H, W (int): batch, anchors, channels-per-anchor, height, width.

    Returns:
        Tensor: the same values reshaped to (N, H*W*A, C).
    """
    per_anchor = layer.view(N, -1, C, H, W)            # (N, A, C, H, W)
    spatial_first = per_anchor.permute(0, 3, 4, 1, 2)  # (N, H, W, A, C)
    return spatial_first.reshape(N, -1, C)
def concat_box_prediction_layers(box_cls, box_regression): def concat_box_prediction_layers(box_cls: List[Tensor], box_regression: List[Tensor]) -> Tuple[Tensor, Tensor]:
# type: (List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
box_cls_flattened = [] box_cls_flattened = []
box_regression_flattened = [] box_regression_flattened = []
# for each feature level, permute the outputs to make them be in the # for each feature level, permute the outputs to make them be in the
...@@ -104,10 +101,10 @@ class RegionProposalNetwork(torch.nn.Module): ...@@ -104,10 +101,10 @@ class RegionProposalNetwork(torch.nn.Module):
for computing the loss for computing the loss
positive_fraction (float): proportion of positive anchors in a mini-batch during training positive_fraction (float): proportion of positive anchors in a mini-batch during training
of the RPN of the RPN
pre_nms_top_n (Dict[int]): number of proposals to keep before applying NMS. It should pre_nms_top_n (Dict[str, int]): number of proposals to keep before applying NMS. It should
contain two fields: training and testing, to allow for different values depending contain two fields: training and testing, to allow for different values depending
on training or evaluation on training or evaluation
post_nms_top_n (Dict[int]): number of proposals to keep after applying NMS. It should post_nms_top_n (Dict[str, int]): number of proposals to keep after applying NMS. It should
contain two fields: training and testing, to allow for different values depending contain two fields: training and testing, to allow for different values depending
on training or evaluation on training or evaluation
nms_thresh (float): NMS threshold used for postprocessing the RPN proposals nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
...@@ -118,25 +115,23 @@ class RegionProposalNetwork(torch.nn.Module): ...@@ -118,25 +115,23 @@ class RegionProposalNetwork(torch.nn.Module):
"box_coder": det_utils.BoxCoder, "box_coder": det_utils.BoxCoder,
"proposal_matcher": det_utils.Matcher, "proposal_matcher": det_utils.Matcher,
"fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler, "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler,
"pre_nms_top_n": Dict[str, int],
"post_nms_top_n": Dict[str, int],
} }
def __init__( def __init__(
self, self,
anchor_generator, anchor_generator: AnchorGenerator,
head, head: nn.Module,
# # Faster-RCNN Training
fg_iou_thresh, fg_iou_thresh: float,
bg_iou_thresh, bg_iou_thresh: float,
batch_size_per_image, batch_size_per_image: int,
positive_fraction, positive_fraction: float,
# # Faster-RCNN Inference
pre_nms_top_n, pre_nms_top_n: Dict[str, int],
post_nms_top_n, post_nms_top_n: Dict[str, int],
nms_thresh, nms_thresh: float,
score_thresh=0.0, score_thresh: float = 0.0,
): ) -> None:
super(RegionProposalNetwork, self).__init__() super(RegionProposalNetwork, self).__init__()
self.anchor_generator = anchor_generator self.anchor_generator = anchor_generator
self.head = head self.head = head
...@@ -159,18 +154,20 @@ class RegionProposalNetwork(torch.nn.Module): ...@@ -159,18 +154,20 @@ class RegionProposalNetwork(torch.nn.Module):
self.score_thresh = score_thresh self.score_thresh = score_thresh
self.min_size = 1e-3 self.min_size = 1e-3
def pre_nms_top_n(self) -> int:
    """Return the number of proposals kept before NMS for the current mode.

    Reads ``self._pre_nms_top_n`` under the key matching ``self.training``.
    """
    mode = "training" if self.training else "testing"
    return self._pre_nms_top_n[mode]
def post_nms_top_n(self) -> int:
    """Return the number of proposals kept after NMS for the current mode.

    Reads ``self._post_nms_top_n`` under the key matching ``self.training``.
    """
    mode = "training" if self.training else "testing"
    return self._post_nms_top_n[mode]
def assign_targets_to_anchors(self, anchors, targets): def assign_targets_to_anchors(
# type: (List[Tensor], List[Dict[str, Tensor]]) -> Tuple[List[Tensor], List[Tensor]] self, anchors: List[Tensor], targets: List[Dict[str, Tensor]]
) -> Tuple[List[Tensor], List[Tensor]]:
labels = [] labels = []
matched_gt_boxes = [] matched_gt_boxes = []
for anchors_per_image, targets_per_image in zip(anchors, targets): for anchors_per_image, targets_per_image in zip(anchors, targets):
...@@ -205,8 +202,7 @@ class RegionProposalNetwork(torch.nn.Module): ...@@ -205,8 +202,7 @@ class RegionProposalNetwork(torch.nn.Module):
matched_gt_boxes.append(matched_gt_boxes_per_image) matched_gt_boxes.append(matched_gt_boxes_per_image)
return labels, matched_gt_boxes return labels, matched_gt_boxes
def _get_top_n_idx(self, objectness, num_anchors_per_level): def _get_top_n_idx(self, objectness: Tensor, num_anchors_per_level: List[int]) -> Tensor:
# type: (Tensor, List[int]) -> Tensor
r = [] r = []
offset = 0 offset = 0
for ob in objectness.split(num_anchors_per_level, 1): for ob in objectness.split(num_anchors_per_level, 1):
...@@ -220,8 +216,14 @@ class RegionProposalNetwork(torch.nn.Module): ...@@ -220,8 +216,14 @@ class RegionProposalNetwork(torch.nn.Module):
offset += num_anchors offset += num_anchors
return torch.cat(r, dim=1) return torch.cat(r, dim=1)
def filter_proposals(self, proposals, objectness, image_shapes, num_anchors_per_level): def filter_proposals(
# type: (Tensor, Tensor, List[Tuple[int, int]], List[int]) -> Tuple[List[Tensor], List[Tensor]] self,
proposals: Tensor,
objectness: Tensor,
image_shapes: List[Tuple[int, int]],
num_anchors_per_level: List[int],
) -> Tuple[List[Tensor], List[Tensor]]:
num_images = proposals.shape[0] num_images = proposals.shape[0]
device = proposals.device device = proposals.device
# do not backprop through objectness # do not backprop through objectness
...@@ -271,8 +273,9 @@ class RegionProposalNetwork(torch.nn.Module): ...@@ -271,8 +273,9 @@ class RegionProposalNetwork(torch.nn.Module):
final_scores.append(scores) final_scores.append(scores)
return final_boxes, final_scores return final_boxes, final_scores
def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets): def compute_loss(
# type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor] self, objectness: Tensor, pred_bbox_deltas: Tensor, labels: List[Tensor], regression_targets: List[Tensor]
) -> Tuple[Tensor, Tensor]:
""" """
Args: Args:
objectness (Tensor) objectness (Tensor)
...@@ -312,25 +315,25 @@ class RegionProposalNetwork(torch.nn.Module): ...@@ -312,25 +315,25 @@ class RegionProposalNetwork(torch.nn.Module):
def forward( def forward(
self, self,
images, # type: ImageList images: ImageList,
features, # type: Dict[str, Tensor] features: Dict[str, Tensor],
targets=None, # type: Optional[List[Dict[str, Tensor]]] targets: Optional[List[Dict[str, Tensor]]] = None,
): ) -> Tuple[List[Tensor], Dict[str, Tensor]]:
# type: (...) -> Tuple[List[Tensor], Dict[str, Tensor]]
""" """
Args: Args:
images (ImageList): images for which we want to compute the predictions images (ImageList): images for which we want to compute the predictions
features (OrderedDict[Tensor]): features computed from the images that are features (Dict[str, Tensor]): features computed from the images that are
used for computing the predictions. Each tensor in the list used for computing the predictions. Each tensor in the list
correspond to different feature levels correspond to different feature levels
targets (List[Dict[Tensor]]): ground-truth boxes present in the image (optional). targets (List[Dict[str, Tensor]]): ground-truth boxes present in the image (optional).
If provided, each element in the dict should contain a field `boxes`, If provided, each element in the dict should contain a field `boxes`,
with the locations of the ground-truth boxes. with the locations of the ground-truth boxes.
Returns: Returns:
boxes (List[Tensor]): the predicted boxes from the RPN, one Tensor per boxes (List[Tensor]): the predicted boxes from the RPN, one Tensor per
image. image.
losses (Dict[Tensor]): the losses for the model during training. During losses (Dict[str, Tensor]): the losses for the model during training. During
testing, it is an empty dict. testing, it is an empty dict.
""" """
# RPN uses all feature maps that are available # RPN uses all feature maps that are available
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment