Unverified Commit 5906d590 authored by Aditya Oke's avatar Aditya Oke Committed by GitHub
Browse files

Add typing annotations to detection/generalized_rcnn (#4631)



* Update typing

* Fix bug

* Unblock mypy

* Ignore small error
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent 9a34c0c9
...@@ -46,10 +46,6 @@ ignore_errors = True ...@@ -46,10 +46,6 @@ ignore_errors = True
ignore_errors = True ignore_errors = True
[mypy-torchvision.models.detection.generalized_rcnn]
ignore_errors = True
[mypy-torchvision.models.detection.faster_rcnn] [mypy-torchvision.models.detection.faster_rcnn]
ignore_errors = True ignore_errors = True
......
...@@ -25,7 +25,7 @@ class GeneralizedRCNN(nn.Module): ...@@ -25,7 +25,7 @@ class GeneralizedRCNN(nn.Module):
the model the model
""" """
def __init__(self, backbone, rpn, roi_heads, transform): def __init__(self, backbone: nn.Module, rpn: nn.Module, roi_heads: nn.Module, transform: nn.Module) -> None:
super().__init__() super().__init__()
_log_api_usage_once(self) _log_api_usage_once(self)
self.transform = transform self.transform = transform
...@@ -36,19 +36,26 @@ class GeneralizedRCNN(nn.Module): ...@@ -36,19 +36,26 @@ class GeneralizedRCNN(nn.Module):
self._has_warned = False self._has_warned = False
@torch.jit.unused @torch.jit.unused
def eager_outputs(self, losses, detections): def eager_outputs(
# type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Union[Dict[str, Tensor], List[Dict[str, Tensor]]] self,
losses: Dict[str, Tensor],
detections: List[Dict[str, Tensor]],
) -> Union[Dict[str, Tensor], List[Dict[str, Tensor]]]:
if self.training: if self.training:
return losses return losses
return detections return detections
def forward(self, images, targets=None): def forward(
# type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]] self,
images: List[Tensor],
targets: Optional[List[Dict[str, Tensor]]] = None,
) -> Union[Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]], Dict[str, Tensor], List[Dict[str, Tensor]]]:
""" """
Args: Args:
images (list[Tensor]): images to be processed images (list[Tensor]): images to be processed
targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional) targets (list[Dict[str, Tensor]]): ground-truth boxes present in the image (optional)
Returns: Returns:
result (list[BoxList] or dict[Tensor]): the output from the model. result (list[BoxList] or dict[Tensor]): the output from the model.
...@@ -97,7 +104,7 @@ class GeneralizedRCNN(nn.Module): ...@@ -97,7 +104,7 @@ class GeneralizedRCNN(nn.Module):
features = OrderedDict([("0", features)]) features = OrderedDict([("0", features)])
proposals, proposal_losses = self.rpn(images, features, targets) proposals, proposal_losses = self.rpn(images, features, targets)
detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets) detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes) detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes) # type: ignore[operator]
losses = {} losses = {}
losses.update(detector_losses) losses.update(detector_losses)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment