"git@developer.sourcefind.cn:OpenDAS/torch-cluster.git" did not exist on "c67425b070ec8fc2d6e4757cd74cf4b171d48902"
Unverified Commit c790216a authored by Aditya Oke's avatar Aditya Oke Committed by GitHub
Browse files

Add typing Annotations to detection/utils (#4583)



* Start annotating utils

* checking

* Add annotations at _utils.py

* Remove unnecessary comments.

* re-checked typings

* Update typing

* Ignore small error

* Use optional tensor

* Ignore for JIT
Co-authored-by: default avatarKhushi Agrawal <khushiagrawal411@gmail.com>
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent 9ad33265
...@@ -17,7 +17,55 @@ ignore_errors = True ...@@ -17,7 +17,55 @@ ignore_errors = True
ignore_errors=True ignore_errors=True
[mypy-torchvision.models.detection.*] [mypy-torchvision.models.detection.anchor_utils]
ignore_errors = True
[mypy-torchvision.models.detection.backbone_utils]
ignore_errors = True
[mypy-torchvision.models.detection.image_list]
ignore_errors = True
[mypy-torchvision.models.detection.transform]
ignore_errors = True
[mypy-torchvision.models.detection.rpn]
ignore_errors = True
[mypy-torchvision.models.detection.roi_heads]
ignore_errors = True
[mypy-torchvision.models.detection.generalized_rcnn]
ignore_errors = True
[mypy-torchvision.models.detection.faster_rcnn]
ignore_errors = True
[mypy-torchvision.models.detection.mask_rcnn]
ignore_errors = True
[mypy-torchvision.models.detection.keypoint_rcnn]
ignore_errors = True
[mypy-torchvision.models.detection.retinanet]
ignore_errors = True
[mypy-torchvision.models.detection.ssd]
ignore_errors = True
[mypy-torchvision.models.detection.ssdlite]
ignore_errors = True ignore_errors = True
......
...@@ -3,7 +3,7 @@ from collections import OrderedDict ...@@ -3,7 +3,7 @@ from collections import OrderedDict
from typing import List, Tuple from typing import List, Tuple
import torch import torch
from torch import Tensor from torch import Tensor, nn
from torchvision.ops.misc import FrozenBatchNorm2d from torchvision.ops.misc import FrozenBatchNorm2d
...@@ -12,18 +12,16 @@ class BalancedPositiveNegativeSampler(object): ...@@ -12,18 +12,16 @@ class BalancedPositiveNegativeSampler(object):
This class samples batches, ensuring that they contain a fixed proportion of positives This class samples batches, ensuring that they contain a fixed proportion of positives
""" """
def __init__(self, batch_size_per_image, positive_fraction): def __init__(self, batch_size_per_image: int, positive_fraction: float) -> None:
# type: (int, float) -> None
""" """
Args: Args:
batch_size_per_image (int): number of elements to be selected per image batch_size_per_image (int): number of elements to be selected per image
positive_fraction (float): percentace of positive elements per batch positive_fraction (float): percentage of positive elements per batch
""" """
self.batch_size_per_image = batch_size_per_image self.batch_size_per_image = batch_size_per_image
self.positive_fraction = positive_fraction self.positive_fraction = positive_fraction
def __call__(self, matched_idxs): def __call__(self, matched_idxs: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
# type: (List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
""" """
Args: Args:
matched idxs: list of tensors containing -1, 0 or positive values. matched idxs: list of tensors containing -1, 0 or positive values.
...@@ -73,8 +71,7 @@ class BalancedPositiveNegativeSampler(object): ...@@ -73,8 +71,7 @@ class BalancedPositiveNegativeSampler(object):
@torch.jit._script_if_tracing @torch.jit._script_if_tracing
def encode_boxes(reference_boxes, proposals, weights): def encode_boxes(reference_boxes: Tensor, proposals: Tensor, weights: Tensor) -> Tensor:
# type: (torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
""" """
Encode a set of proposals with respect to some Encode a set of proposals with respect to some
reference boxes reference boxes
...@@ -127,8 +124,9 @@ class BoxCoder(object): ...@@ -127,8 +124,9 @@ class BoxCoder(object):
the representation used for training the regressors. the representation used for training the regressors.
""" """
def __init__(self, weights, bbox_xform_clip=math.log(1000.0 / 16)): def __init__(
# type: (Tuple[float, float, float, float], float) -> None self, weights: Tuple[float, float, float, float], bbox_xform_clip: float = math.log(1000.0 / 16)
) -> None:
""" """
Args: Args:
weights (4-element tuple) weights (4-element tuple)
...@@ -137,15 +135,14 @@ class BoxCoder(object): ...@@ -137,15 +135,14 @@ class BoxCoder(object):
self.weights = weights self.weights = weights
self.bbox_xform_clip = bbox_xform_clip self.bbox_xform_clip = bbox_xform_clip
def encode(self, reference_boxes, proposals): def encode(self, reference_boxes: List[Tensor], proposals: List[Tensor]) -> List[Tensor]:
# type: (List[Tensor], List[Tensor]) -> List[Tensor]
boxes_per_image = [len(b) for b in reference_boxes] boxes_per_image = [len(b) for b in reference_boxes]
reference_boxes = torch.cat(reference_boxes, dim=0) reference_boxes = torch.cat(reference_boxes, dim=0)
proposals = torch.cat(proposals, dim=0) proposals = torch.cat(proposals, dim=0)
targets = self.encode_single(reference_boxes, proposals) targets = self.encode_single(reference_boxes, proposals)
return targets.split(boxes_per_image, 0) return targets.split(boxes_per_image, 0)
def encode_single(self, reference_boxes, proposals): def encode_single(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
""" """
Encode a set of proposals with respect to some Encode a set of proposals with respect to some
reference boxes reference boxes
...@@ -161,8 +158,7 @@ class BoxCoder(object): ...@@ -161,8 +158,7 @@ class BoxCoder(object):
return targets return targets
def decode(self, rel_codes, boxes): def decode(self, rel_codes: Tensor, boxes: List[Tensor]) -> Tensor:
# type: (Tensor, List[Tensor]) -> Tensor
assert isinstance(boxes, (list, tuple)) assert isinstance(boxes, (list, tuple))
assert isinstance(rel_codes, torch.Tensor) assert isinstance(rel_codes, torch.Tensor)
boxes_per_image = [b.size(0) for b in boxes] boxes_per_image = [b.size(0) for b in boxes]
...@@ -177,7 +173,7 @@ class BoxCoder(object): ...@@ -177,7 +173,7 @@ class BoxCoder(object):
pred_boxes = pred_boxes.reshape(box_sum, -1, 4) pred_boxes = pred_boxes.reshape(box_sum, -1, 4)
return pred_boxes return pred_boxes
def decode_single(self, rel_codes, boxes): def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
""" """
From a set of original boxes and encoded relative box offsets, From a set of original boxes and encoded relative box offsets,
get the decoded boxes. get the decoded boxes.
...@@ -244,8 +240,7 @@ class Matcher(object): ...@@ -244,8 +240,7 @@ class Matcher(object):
"BETWEEN_THRESHOLDS": int, "BETWEEN_THRESHOLDS": int,
} }
def __init__(self, high_threshold, low_threshold, allow_low_quality_matches=False): def __init__(self, high_threshold: float, low_threshold: float, allow_low_quality_matches: bool = False) -> None:
# type: (float, float, bool) -> None
""" """
Args: Args:
high_threshold (float): quality values greater than or equal to high_threshold (float): quality values greater than or equal to
...@@ -266,7 +261,7 @@ class Matcher(object): ...@@ -266,7 +261,7 @@ class Matcher(object):
self.low_threshold = low_threshold self.low_threshold = low_threshold
self.allow_low_quality_matches = allow_low_quality_matches self.allow_low_quality_matches = allow_low_quality_matches
def __call__(self, match_quality_matrix): def __call__(self, match_quality_matrix: Tensor) -> Tensor:
""" """
Args: Args:
match_quality_matrix (Tensor[float]): an MxN tensor, containing the match_quality_matrix (Tensor[float]): an MxN tensor, containing the
...@@ -290,7 +285,7 @@ class Matcher(object): ...@@ -290,7 +285,7 @@ class Matcher(object):
if self.allow_low_quality_matches: if self.allow_low_quality_matches:
all_matches = matches.clone() all_matches = matches.clone()
else: else:
all_matches = None all_matches = None # type: ignore[assignment]
# Assign candidate matches with low quality to negative (unassigned) values # Assign candidate matches with low quality to negative (unassigned) values
below_low_threshold = matched_vals < self.low_threshold below_low_threshold = matched_vals < self.low_threshold
...@@ -304,7 +299,7 @@ class Matcher(object): ...@@ -304,7 +299,7 @@ class Matcher(object):
return matches return matches
def set_low_quality_matches_(self, matches, all_matches, match_quality_matrix): def set_low_quality_matches_(self, matches: Tensor, all_matches: Tensor, match_quality_matrix: Tensor) -> None:
""" """
Produce additional matches for predictions that have only low-quality matches. Produce additional matches for predictions that have only low-quality matches.
Specifically, for each ground-truth find the set of predictions that have Specifically, for each ground-truth find the set of predictions that have
...@@ -335,10 +330,10 @@ class Matcher(object): ...@@ -335,10 +330,10 @@ class Matcher(object):
class SSDMatcher(Matcher): class SSDMatcher(Matcher):
def __init__(self, threshold): def __init__(self, threshold: float) -> None:
super().__init__(threshold, threshold, allow_low_quality_matches=False) super().__init__(threshold, threshold, allow_low_quality_matches=False)
def __call__(self, match_quality_matrix): def __call__(self, match_quality_matrix: Tensor) -> Tensor:
matches = super().__call__(match_quality_matrix) matches = super().__call__(match_quality_matrix)
# For each gt, find the prediction with which it has the highest quality # For each gt, find the prediction with which it has the highest quality
...@@ -350,7 +345,7 @@ class SSDMatcher(Matcher): ...@@ -350,7 +345,7 @@ class SSDMatcher(Matcher):
return matches return matches
def overwrite_eps(model, eps): def overwrite_eps(model: nn.Module, eps: float) -> None:
""" """
This method overwrites the default eps values of all the This method overwrites the default eps values of all the
FrozenBatchNorm2d layers of the model with the provided value. FrozenBatchNorm2d layers of the model with the provided value.
...@@ -368,7 +363,7 @@ def overwrite_eps(model, eps): ...@@ -368,7 +363,7 @@ def overwrite_eps(model, eps):
module.eps = eps module.eps = eps
def retrieve_out_channels(model, size): def retrieve_out_channels(model: nn.Module, size: Tuple[int, int]) -> List[int]:
""" """
This method retrieves the number of output channels of a specific model. This method retrieves the number of output channels of a specific model.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment