Unverified Commit 5906d590 authored by Aditya Oke's avatar Aditya Oke Committed by GitHub
Browse files

Add typing annotations to detection/generalized_rcnn (#4631)



* Update typing

* Fix bug

* Unblock mypy

* Ignore small error
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent 9a34c0c9
......@@ -46,10 +46,6 @@ ignore_errors = True
ignore_errors = True
[mypy-torchvision.models.detection.generalized_rcnn]
ignore_errors = True
[mypy-torchvision.models.detection.faster_rcnn]
ignore_errors = True
......
......@@ -25,7 +25,7 @@ class GeneralizedRCNN(nn.Module):
the model
"""
def __init__(self, backbone, rpn, roi_heads, transform):
def __init__(self, backbone: nn.Module, rpn: nn.Module, roi_heads: nn.Module, transform: nn.Module) -> None:
super().__init__()
_log_api_usage_once(self)
self.transform = transform
......@@ -36,19 +36,26 @@ class GeneralizedRCNN(nn.Module):
self._has_warned = False
@torch.jit.unused
def eager_outputs(self, losses, detections):
# type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Union[Dict[str, Tensor], List[Dict[str, Tensor]]]
def eager_outputs(
self,
losses: Dict[str, Tensor],
detections: List[Dict[str, Tensor]],
) -> Union[Dict[str, Tensor], List[Dict[str, Tensor]]]:
if self.training:
return losses
return detections
def forward(self, images, targets=None):
# type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
def forward(
self,
images: List[Tensor],
targets: Optional[List[Dict[str, Tensor]]] = None,
) -> Union[Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]], Dict[str, Tensor], List[Dict[str, Tensor]]]:
"""
Args:
images (list[Tensor]): images to be processed
targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional)
targets (list[Dict[str, Tensor]]): ground-truth boxes present in the image (optional)
Returns:
result (list[BoxList] or dict[Tensor]): the output from the model.
......@@ -97,7 +104,7 @@ class GeneralizedRCNN(nn.Module):
features = OrderedDict([("0", features)])
proposals, proposal_losses = self.rpn(images, features, targets)
detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)
detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes) # type: ignore[operator]
losses = {}
losses.update(detector_losses)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment