Unverified Commit c790216a authored by Aditya Oke, committed by GitHub

Add typing Annotations to detection/utils (#4583)



* Start annotating utils

* checking

* Add annotations at _utils.py

* Remove unnecessary comments.

* re-checked typings

* Update typing

* Ignore small error

* Use optional tensor

* Ignore for JIT
Co-authored-by: Khushi Agrawal <khushiagrawal411@gmail.com>
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent 9ad33265
@@ -17,7 +17,55 @@ ignore_errors = True
ignore_errors=True
[mypy-torchvision.models.detection.*]
[mypy-torchvision.models.detection.anchor_utils]
ignore_errors = True
[mypy-torchvision.models.detection.backbone_utils]
ignore_errors = True
[mypy-torchvision.models.detection.image_list]
ignore_errors = True
[mypy-torchvision.models.detection.transform]
ignore_errors = True
[mypy-torchvision.models.detection.rpn]
ignore_errors = True
[mypy-torchvision.models.detection.roi_heads]
ignore_errors = True
[mypy-torchvision.models.detection.generalized_rcnn]
ignore_errors = True
[mypy-torchvision.models.detection.faster_rcnn]
ignore_errors = True
[mypy-torchvision.models.detection.mask_rcnn]
ignore_errors = True
[mypy-torchvision.models.detection.keypoint_rcnn]
ignore_errors = True
[mypy-torchvision.models.detection.retinanet]
ignore_errors = True
[mypy-torchvision.models.detection.ssd]
ignore_errors = True
[mypy-torchvision.models.detection.ssdlite]
ignore_errors = True
@@ -3,7 +3,7 @@ from collections import OrderedDict
from typing import List, Tuple
import torch
from torch import Tensor
from torch import Tensor, nn
from torchvision.ops.misc import FrozenBatchNorm2d
@@ -12,18 +12,16 @@ class BalancedPositiveNegativeSampler(object):
This class samples batches, ensuring that they contain a fixed proportion of positives
"""
def __init__(self, batch_size_per_image, positive_fraction):
# type: (int, float) -> None
def __init__(self, batch_size_per_image: int, positive_fraction: float) -> None:
"""
Args:
batch_size_per_image (int): number of elements to be selected per image
positive_fraction (float): percentace of positive elements per batch
positive_fraction (float): percentage of positive elements per batch
"""
self.batch_size_per_image = batch_size_per_image
self.positive_fraction = positive_fraction
def __call__(self, matched_idxs):
# type: (List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
def __call__(self, matched_idxs: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
"""
Args:
matched idxs: list of tensors containing -1, 0 or positive values.
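
For context, a minimal usage sketch of the sampler whose signatures are annotated above. It is not part of the diff; the batch size, positive fraction, and random matched indices are illustrative values only, assuming torchvision is installed.

import torch
from torchvision.models.detection._utils import BalancedPositiveNegativeSampler

# Illustrative settings: sample up to 256 anchors per image, half of them positive.
sampler = BalancedPositiveNegativeSampler(batch_size_per_image=256, positive_fraction=0.5)

# One tensor per image: -1 = ignore, 0 = negative, >= 1 = matched to a ground truth.
matched_idxs = [torch.randint(-1, 2, (1000,))]

# Returns two lists of binary masks (one entry per image) selecting the
# sampled positive and negative locations respectively.
pos_masks, neg_masks = sampler(matched_idxs)
print(int(pos_masks[0].sum()), int(neg_masks[0].sum()))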
@@ -73,8 +71,7 @@ class BalancedPositiveNegativeSampler(object):
@torch.jit._script_if_tracing
def encode_boxes(reference_boxes, proposals, weights):
# type: (torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
def encode_boxes(reference_boxes: Tensor, proposals: Tensor, weights: Tensor) -> Tensor:
"""
Encode a set of proposals with respect to some
reference boxes
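
For readers unfamiliar with the parameterization, this is a standalone sketch of the standard R-CNN box encoding that encode_boxes computes. The function name encode_boxes_sketch and the (x1, y1, x2, y2) box layout are assumptions for illustration, not part of the diff.

import torch

def encode_boxes_sketch(reference_boxes: torch.Tensor, proposals: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    """Weighted (dx, dy, dw, dh) deltas from proposals to reference boxes."""
    wx, wy, ww, wh = weights[0], weights[1], weights[2], weights[3]

    # Widths, heights and centers of the proposal (anchor) boxes.
    ex_widths = proposals[:, 2] - proposals[:, 0]
    ex_heights = proposals[:, 3] - proposals[:, 1]
    ex_ctr_x = proposals[:, 0] + 0.5 * ex_widths
    ex_ctr_y = proposals[:, 1] + 0.5 * ex_heights

    # Same quantities for the reference (ground-truth) boxes.
    gt_widths = reference_boxes[:, 2] - reference_boxes[:, 0]
    gt_heights = reference_boxes[:, 3] - reference_boxes[:, 1]
    gt_ctr_x = reference_boxes[:, 0] + 0.5 * gt_widths
    gt_ctr_y = reference_boxes[:, 1] + 0.5 * gt_heights

    # Weighted center offsets and log-scale size ratios.
    dx = wx * (gt_ctr_x - ex_ctr_x) / ex_widths
    dy = wy * (gt_ctr_y - ex_ctr_y) / ex_heights
    dw = ww * torch.log(gt_widths / ex_widths)
    dh = wh * torch.log(gt_heights / ex_heights)

    return torch.stack((dx, dy, dw, dh), dim=1)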
@@ -127,8 +124,9 @@ class BoxCoder(object):
the representation used for training the regressors.
"""
def __init__(self, weights, bbox_xform_clip=math.log(1000.0 / 16)):
# type: (Tuple[float, float, float, float], float) -> None
def __init__(
self, weights: Tuple[float, float, float, float], bbox_xform_clip: float = math.log(1000.0 / 16)
) -> None:
"""
Args:
weights (4-element tuple)
@@ -137,15 +135,14 @@ class BoxCoder(object):
self.weights = weights
self.bbox_xform_clip = bbox_xform_clip
def encode(self, reference_boxes, proposals):
# type: (List[Tensor], List[Tensor]) -> List[Tensor]
def encode(self, reference_boxes: List[Tensor], proposals: List[Tensor]) -> List[Tensor]:
boxes_per_image = [len(b) for b in reference_boxes]
reference_boxes = torch.cat(reference_boxes, dim=0)
proposals = torch.cat(proposals, dim=0)
targets = self.encode_single(reference_boxes, proposals)
return targets.split(boxes_per_image, 0)
def encode_single(self, reference_boxes, proposals):
def encode_single(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
"""
Encode a set of proposals with respect to some
reference boxes
@@ -161,8 +158,7 @@ class BoxCoder(object):
return targets
def decode(self, rel_codes, boxes):
# type: (Tensor, List[Tensor]) -> Tensor
def decode(self, rel_codes: Tensor, boxes: List[Tensor]) -> Tensor:
assert isinstance(boxes, (list, tuple))
assert isinstance(rel_codes, torch.Tensor)
boxes_per_image = [b.size(0) for b in boxes]
@@ -177,7 +173,7 @@ class BoxCoder(object):
pred_boxes = pred_boxes.reshape(box_sum, -1, 4)
return pred_boxes
def decode_single(self, rel_codes, boxes):
def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
"""
From a set of original boxes and encoded relative box offsets,
get the decoded boxes.
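
As a quick sanity check of the annotated Tensor-in/Tensor-out interface, a hedged encode/decode roundtrip sketch; the weights and box coordinates are arbitrary example values, not taken from the diff.

import torch
from torchvision.models.detection._utils import BoxCoder

coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))  # illustrative weights

proposals = torch.tensor([[0.0, 0.0, 10.0, 10.0],
                          [5.0, 5.0, 20.0, 25.0]])
reference = torch.tensor([[1.0, 1.0, 11.0, 12.0],
                          [4.0, 6.0, 22.0, 24.0]])

# encode_single: Tensor -> Tensor of (dx, dy, dw, dh) deltas.
rel_codes = coder.encode_single(reference, proposals)

# decode_single applies the deltas back onto the proposals, which should
# recover the reference boxes up to floating-point error.
decoded = coder.decode_single(rel_codes, proposals)
print(torch.allclose(decoded, reference, atol=1e-4))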
@@ -244,8 +240,7 @@ class Matcher(object):
"BETWEEN_THRESHOLDS": int,
}
def __init__(self, high_threshold, low_threshold, allow_low_quality_matches=False):
# type: (float, float, bool) -> None
def __init__(self, high_threshold: float, low_threshold: float, allow_low_quality_matches: bool = False) -> None:
"""
Args:
high_threshold (float): quality values greater than or equal to
@@ -266,7 +261,7 @@ class Matcher(object):
self.low_threshold = low_threshold
self.allow_low_quality_matches = allow_low_quality_matches
def __call__(self, match_quality_matrix):
def __call__(self, match_quality_matrix: Tensor) -> Tensor:
"""
Args:
match_quality_matrix (Tensor[float]): an MxN tensor, containing the
@@ -290,7 +285,7 @@ class Matcher(object):
if self.allow_low_quality_matches:
all_matches = matches.clone()
else:
all_matches = None
all_matches = None # type: ignore[assignment]
# Assign candidate matches with low quality to negative (unassigned) values
below_low_threshold = matched_vals < self.low_threshold
@@ -304,7 +299,7 @@ class Matcher(object):
return matches
def set_low_quality_matches_(self, matches, all_matches, match_quality_matrix):
def set_low_quality_matches_(self, matches: Tensor, all_matches: Tensor, match_quality_matrix: Tensor) -> None:
"""
Produce additional matches for predictions that have only low-quality matches.
Specifically, for each ground-truth find the set of predictions that have
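
A short sketch of how the Matcher is typically driven. The thresholds mirror common RPN settings and the boxes are toy values; neither comes from the diff.

import torch
from torchvision.models.detection._utils import Matcher
from torchvision.ops import box_iou

matcher = Matcher(high_threshold=0.7, low_threshold=0.3, allow_low_quality_matches=True)

gt_boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0]])
anchors = torch.tensor([[0.0, 0.0, 9.0, 9.0],
                        [20.0, 20.0, 30.0, 30.0]])

# match_quality_matrix is M (ground truth) x N (predicted/anchor) IoU values.
match_quality_matrix = box_iou(gt_boxes, anchors)

# One value per anchor: the index of its matched ground truth, or a negative
# sentinel (BELOW_LOW_THRESHOLD / BETWEEN_THRESHOLDS) if unmatched.
matches = matcher(match_quality_matrix)
print(matches)  # e.g. tensor([ 0, -1])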
@@ -335,10 +330,10 @@ class Matcher(object):
class SSDMatcher(Matcher):
def __init__(self, threshold):
def __init__(self, threshold: float) -> None:
super().__init__(threshold, threshold, allow_low_quality_matches=False)
def __call__(self, match_quality_matrix):
def __call__(self, match_quality_matrix: Tensor) -> Tensor:
matches = super().__call__(match_quality_matrix)
# For each gt, find the prediction with which it has the highest quality
@@ -350,7 +345,7 @@ class SSDMatcher(Matcher):
return matches
def overwrite_eps(model, eps):
def overwrite_eps(model: nn.Module, eps: float) -> None:
"""
This method overwrites the default eps values of all the
FrozenBatchNorm2d layers of the model with the provided value.
@@ -368,7 +363,7 @@ def overwrite_eps(model, eps):
module.eps = eps
def retrieve_out_channels(model, size):
def retrieve_out_channels(model: nn.Module, size: Tuple[int, int]) -> List[int]:
"""
This method retrieves the number of output channels of a specific model.
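
Finally, a hedged sketch of the two module-level helpers whose signatures are annotated above. The toy backbone is made up for illustration and is not part of the diff.

from torch import nn
from torchvision.models.detection._utils import overwrite_eps, retrieve_out_channels
from torchvision.ops.misc import FrozenBatchNorm2d

# A made-up backbone: a single conv followed by a frozen batch-norm layer.
backbone = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    FrozenBatchNorm2d(16),
    nn.ReLU(),
)

# retrieve_out_channels runs a dummy forward pass at the given spatial size
# and reports the channel count of each returned feature map.
print(retrieve_out_channels(backbone, (64, 64)))  # -> [16]

# overwrite_eps patches the eps of every FrozenBatchNorm2d in the model, which
# torchvision uses to keep older pretrained detection weights numerically stable.
overwrite_eps(backbone, 0.0)
print(backbone[1].eps)  # -> 0.0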