Unverified Commit c2bbefc2 authored by 江胤佐's avatar 江胤佐 Committed by GitHub
Browse files

fix type hints and spelling mistake in generalized_rcnn and poolers (#2550)

* fix type hints and move degenerate boxes to a function in torchvision.models.detection.generalized_rcnn

* format code

* format code

* changed to static method

* revert imports

* changed to method

* revert procedure for degenerating boxes
parent df6a7960
...@@ -4,6 +4,7 @@ Implements the Generalized R-CNN framework ...@@ -4,6 +4,7 @@ Implements the Generalized R-CNN framework
""" """
from collections import OrderedDict from collections import OrderedDict
from typing import Union
import torch import torch
from torch import nn from torch import nn
import warnings import warnings
...@@ -35,7 +36,7 @@ class GeneralizedRCNN(nn.Module): ...@@ -35,7 +36,7 @@ class GeneralizedRCNN(nn.Module):
@torch.jit.unused @torch.jit.unused
def eager_outputs(self, losses, detections): def eager_outputs(self, losses, detections):
# type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]] # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Union[Dict[str, Tensor], List[Dict[str, Tensor]]]
if self.training: if self.training:
return losses return losses
...@@ -85,11 +86,11 @@ class GeneralizedRCNN(nn.Module): ...@@ -85,11 +86,11 @@ class GeneralizedRCNN(nn.Module):
boxes = target["boxes"] boxes = target["boxes"]
degenerate_boxes = boxes[:, 2:] <= boxes[:, :2] degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
if degenerate_boxes.any(): if degenerate_boxes.any():
# print the first degenrate box # print the first degenerate box
bb_idx = degenerate_boxes.any(dim=1).nonzero().view(-1)[0] bb_idx = degenerate_boxes.any(dim=1).nonzero().view(-1)[0]
degen_bb: List[float] = boxes[bb_idx].tolist() degen_bb: List[float] = boxes[bb_idx].tolist()
raise ValueError("All bounding boxes should have positive height and width." raise ValueError("All bounding boxes should have positive height and width."
" Found invaid box {} for target at index {}." " Found invalid box {} for target at index {}."
.format(degen_bb, target_idx)) .format(degen_bb, target_idx))
features = self.backbone(images.tensors) features = self.backbone(images.tensors)
......
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
from typing import Union
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn, Tensor from torch import nn, Tensor
...@@ -119,7 +121,7 @@ class MultiScaleRoIAlign(nn.Module): ...@@ -119,7 +121,7 @@ class MultiScaleRoIAlign(nn.Module):
def __init__( def __init__(
self, self,
featmap_names: List[str], featmap_names: List[str],
output_size: List[int], output_size: Union[int, Tuple[int], List[int]],
sampling_ratio: int, sampling_ratio: int,
): ):
super(MultiScaleRoIAlign, self).__init__() super(MultiScaleRoIAlign, self).__init__()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment