Unverified commit 971c3e45 authored by Vukašin Manojlović, committed by GitHub

Type annotations for torchvision.ops (#2331)

* Add type annotations for torchvision.ops

* Fix type annotations for torchvision.ops

* Fix typo in import

* Fix undefined name in FeaturePyramidNetwork
parent 67f5fcf7
import torch
from torch import Tensor
from torch.jit.annotations import List
from torch.jit.annotations import List, Tuple
def _cat(tensors, dim=0):
# type: (List[Tensor], int) -> Tensor
def _cat(tensors: List[Tensor], dim: int = 0) -> Tensor:
"""
Efficient version of torch.cat that avoids a copy if there is only a single element in the list
"""
@@ -15,8 +14,7 @@ def _cat(tensors, dim=0):
return torch.cat(tensors, dim)
def convert_boxes_to_roi_format(boxes):
# type: (List[Tensor]) -> Tensor
def convert_boxes_to_roi_format(boxes: List[Tensor]) -> Tensor:
concat_boxes = _cat([b for b in boxes], dim=0)
temp = []
for i, b in enumerate(boxes):
@@ -26,7 +24,7 @@ def convert_boxes_to_roi_format(boxes):
return rois
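For illustration, a minimal sketch of the roi format this helper produces (the box values are made up):

import torch

# each row gains a leading batch index, yielding a (K, 5) tensor of
# [batch_idx, x1, y1, x2, y2] rows
boxes = [torch.tensor([[0., 0., 10., 10.]]),   # boxes of image 0
         torch.tensor([[5., 5., 20., 20.]])]   # boxes of image 1
rois = convert_boxes_to_roi_format(boxes)
# tensor([[ 0.,  0.,  0., 10., 10.],
#         [ 1.,  5.,  5., 20., 20.]])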
def check_roi_boxes_shape(boxes):
def check_roi_boxes_shape(boxes: Tensor):
if isinstance(boxes, (list, tuple)):
for _tensor in boxes:
assert _tensor.size(1) == 4, \
......
@@ -4,8 +4,7 @@ from torch import Tensor
import torchvision
def nms(boxes, scores, iou_threshold):
# type: (Tensor, Tensor, float) -> Tensor
def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor:
"""
Performs non-maximum suppression (NMS) on the boxes according
to their intersection-over-union (IoU).
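A minimal usage sketch of the annotated signature (boxes and scores are made up):

import torch
from torchvision.ops import nms

boxes = torch.tensor([[0., 0., 10., 10.],
                      [1., 1., 11., 11.],    # IoU ~0.68 with the first box
                      [50., 50., 60., 60.]])
scores = torch.tensor([0.9, 0.8, 0.7])
keep = nms(boxes, scores, iou_threshold=0.5)  # tensor([0, 2]); box 1 is suppressed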
@@ -41,8 +40,12 @@ def nms(boxes, scores, iou_threshold):
@torch.jit._script_if_tracing
def batched_nms(boxes, scores, idxs, iou_threshold):
# type: (Tensor, Tensor, Tensor, float) -> Tensor
def batched_nms(
boxes: Tensor,
scores: Tensor,
idxs: Tensor,
iou_threshold: float,
) -> Tensor:
"""
Performs non-maximum suppression in a batched fashion.
@@ -83,8 +86,7 @@ def batched_nms(boxes, scores, idxs, iou_threshold):
return keep
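A small sketch showing how idxs keeps suppression within a category:

import torch
from torchvision.ops import batched_nms

boxes = torch.tensor([[0., 0., 10., 10.],
                      [1., 1., 11., 11.]])   # heavily overlapping pair
scores = torch.tensor([0.9, 0.8])
idxs = torch.tensor([0, 1])                  # two different categories
keep = batched_nms(boxes, scores, idxs, iou_threshold=0.5)
# tensor([0, 1]); both survive because NMS never crosses category boundaries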
def remove_small_boxes(boxes, min_size):
# type: (Tensor, float) -> Tensor
def remove_small_boxes(boxes: Tensor, min_size: float) -> Tensor:
"""
Remove boxes which contain at least one side smaller than min_size.
@@ -102,8 +104,7 @@ def remove_small_boxes(boxes, min_size):
return keep
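For example, with a box only 2 px wide:

import torch
from torchvision.ops import remove_small_boxes

boxes = torch.tensor([[0., 0., 10., 10.],
                      [0., 0., 2., 10.]])
keep = remove_small_boxes(boxes, min_size=5.)  # tensor([0])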
def clip_boxes_to_image(boxes, size):
# type: (Tensor, Tuple[int, int]) -> Tensor
def clip_boxes_to_image(boxes: Tensor, size: Tuple[int, int]) -> Tensor:
"""
Clip boxes so that they lie inside an image of size `size`.
@@ -132,7 +133,7 @@ def clip_boxes_to_image(boxes, size):
return clipped_boxes.reshape(boxes.shape)
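A quick sketch; note that size is (height, width):

import torch
from torchvision.ops import clip_boxes_to_image

boxes = torch.tensor([[-5., -5., 25., 25.]])
clipped = clip_boxes_to_image(boxes, size=(20, 20))
# tensor([[ 0.,  0., 20., 20.]])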
def box_area(boxes):
def box_area(boxes: Tensor) -> Tensor:
"""
Computes the area of a set of bounding boxes, which are specified by its
(x1, y1, x2, y2) coordinates.
@@ -149,7 +150,7 @@ def box_area(boxes):
# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
# with slight modifications
def box_iou(boxes1, boxes2):
def box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
"""
Return intersection-over-union (Jaccard index) of boxes.
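A minimal sketch of the pairwise result:

import torch
from torchvision.ops import box_iou

a = torch.tensor([[0., 0., 10., 10.]])
b = torch.tensor([[0., 0., 10., 10.],
                  [5., 5., 15., 15.]])
iou = box_iou(a, b)  # (N, M) matrix: tensor([[1.0000, 0.1429]])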
......
@@ -8,8 +8,15 @@ from torch.nn.modules.utils import _pair
from torch.jit.annotations import Optional, Tuple
def deform_conv2d(input, offset, weight, bias=None, stride=(1, 1), padding=(0, 0), dilation=(1, 1)):
# type: (Tensor, Tensor, Tensor, Optional[Tensor], Tuple[int, int], Tuple[int, int], Tuple[int, int]) -> Tensor
def deform_conv2d(
input: Tensor,
offset: Tensor,
weight: Tensor,
bias: Optional[Tensor] = None,
stride: Tuple[int, int] = (1, 1),
padding: Tuple[int, int] = (0, 0),
dilation: Tuple[int, int] = (1, 1),
) -> Tensor:
"""
Performs Deformable Convolution, described in Deformable Convolutional Networks
@@ -80,8 +87,17 @@ class DeformConv2d(nn.Module):
"""
See deform_conv2d
"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
dilation=1, groups=1, bias=True):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
padding: int = 0,
dilation: int = 1,
groups: int = 1,
bias: bool = True,
):
super(DeformConv2d, self).__init__()
if in_channels % groups != 0:
@@ -107,14 +123,14 @@ class DeformConv2d(nn.Module):
self.reset_parameters()
def reset_parameters(self):
def reset_parameters(self) -> None:
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
if self.bias is not None:
fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
bound = 1 / math.sqrt(fan_in)
init.uniform_(self.bias, -bound, bound)
def forward(self, input, offset):
def forward(self, input: Tensor, offset: Tensor) -> Tensor:
"""
Arguments:
input (Tensor[batch_size, in_channels, in_height, in_width]): input tensor
@@ -125,7 +141,7 @@ class DeformConv2d(nn.Module):
return deform_conv2d(input, offset, self.weight, self.bias, stride=self.stride,
padding=self.padding, dilation=self.dilation)
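A usage sketch of the module; the offset channel count (two per kernel position) is the key constraint:

import torch
from torchvision.ops import DeformConv2d

m = DeformConv2d(in_channels=3, out_channels=8, kernel_size=3, padding=1)
x = torch.rand(1, 3, 16, 16)
# one (dy, dx) pair per kernel position: 2 * 3 * 3 = 18 offset channels,
# at the spatial size of the output; zero offsets reduce to a regular conv
offset = torch.zeros(1, 18, 16, 16)
out = m(x, offset)  # shape (1, 8, 16, 16)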
def __repr__(self):
def __repr__(self) -> str:
s = self.__class__.__name__ + '('
s += '{in_channels}'
s += ', {out_channels}'
......
@@ -4,7 +4,31 @@ import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torch.jit.annotations import Tuple, List, Dict
from torch.jit.annotations import Tuple, List, Dict, Optional
class ExtraFPNBlock(nn.Module):
"""
Base class for the extra block in the FPN.
Arguments:
results (List[Tensor]): the result of the FPN
x (List[Tensor]): the original feature maps
names (List[str]): the names for each one of the
original feature maps
Returns:
results (List[Tensor]): the extended set of results
of the FPN
names (List[str]): the extended set of names for the results
"""
def forward(
self,
results: List[Tensor],
x: List[Tensor],
names: List[str],
) -> Tuple[List[Tensor], List[str]]:
pass
class FeaturePyramidNetwork(nn.Module):
@@ -44,7 +68,12 @@ class FeaturePyramidNetwork(nn.Module):
>>> ('feat3', torch.Size([1, 5, 8, 8]))]
"""
def __init__(self, in_channels_list, out_channels, extra_blocks=None):
def __init__(
self,
in_channels_list: List[int],
out_channels: int,
extra_blocks: Optional[ExtraFPNBlock] = None,
):
super(FeaturePyramidNetwork, self).__init__()
self.inner_blocks = nn.ModuleList()
self.layer_blocks = nn.ModuleList()
@@ -66,8 +95,7 @@ class FeaturePyramidNetwork(nn.Module):
assert isinstance(extra_blocks, ExtraFPNBlock)
self.extra_blocks = extra_blocks
def get_result_from_inner_blocks(self, x, idx):
# type: (Tensor, int) -> Tensor
def get_result_from_inner_blocks(self, x: Tensor, idx: int) -> Tensor:
"""
This is equivalent to self.inner_blocks[idx](x),
but torchscript doesn't support this yet
@@ -85,8 +113,7 @@ class FeaturePyramidNetwork(nn.Module):
i += 1
return out
def get_result_from_layer_blocks(self, x, idx):
# type: (Tensor, int) -> Tensor
def get_result_from_layer_blocks(self, x: Tensor, idx: int) -> Tensor:
"""
This is equivalent to self.layer_blocks[idx](x),
but torchscript doesn't support this yet
@@ -104,8 +131,7 @@ class FeaturePyramidNetwork(nn.Module):
i += 1
return out
def forward(self, x):
# type: (Dict[str, Tensor]) -> Dict[str, Tensor]
def forward(self, x: Dict[str, Tensor]) -> Dict[str, Tensor]:
"""
Computes the FPN for a set of feature maps.
@@ -140,31 +166,16 @@ class FeaturePyramidNetwork(nn.Module):
return out
class ExtraFPNBlock(nn.Module):
"""
Base class for the extra block in the FPN.
Arguments:
results (List[Tensor]): the result of the FPN
x (List[Tensor]): the original feature maps
names (List[str]): the names for each one of the
original feature maps
Returns:
results (List[Tensor]): the extended set of results
of the FPN
names (List[str]): the extended set of names for the results
"""
def forward(self, results, x, names):
pass
class LastLevelMaxPool(ExtraFPNBlock):
"""
Applies a max_pool2d on top of the last feature map
"""
def forward(self, x, y, names):
# type: (List[Tensor], List[Tensor], List[str]) -> Tuple[List[Tensor], List[str]]
def forward(
self,
x: List[Tensor],
y: List[Tensor],
names: List[str],
) -> Tuple[List[Tensor], List[str]]:
names.append("pool")
x.append(F.max_pool2d(x[-1], 1, 2, 0))
return x, names
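Putting the pieces together, a sketch mirroring the FeaturePyramidNetwork docstring example (the feature names are illustrative):

import torch
from collections import OrderedDict
from torchvision.ops import FeaturePyramidNetwork
from torchvision.ops.feature_pyramid_network import LastLevelMaxPool

fpn = FeaturePyramidNetwork([10, 20, 30], 5, extra_blocks=LastLevelMaxPool())
x = OrderedDict([('feat0', torch.rand(1, 10, 64, 64)),
                 ('feat2', torch.rand(1, 20, 16, 16)),
                 ('feat3', torch.rand(1, 30, 8, 8))])
out = fpn(x)
# list(out.keys()) == ['feat0', 'feat2', 'feat3', 'pool']; all outputs have 5 channels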
@@ -174,7 +185,7 @@ class LastLevelP6P7(ExtraFPNBlock):
"""
This module is used in RetinaNet to generate extra layers, P6 and P7.
"""
def __init__(self, in_channels, out_channels):
def __init__(self, in_channels: int, out_channels: int):
super(LastLevelP6P7, self).__init__()
self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1)
self.p7 = nn.Conv2d(out_channels, out_channels, 3, 2, 1)
@@ -183,7 +194,12 @@ class LastLevelP6P7(ExtraFPNBlock):
nn.init.constant_(module.bias, 0)
self.use_P5 = in_channels == out_channels
def forward(self, p, c, names):
def forward(
self,
p: List[Tensor],
c: List[Tensor],
names: List[str],
) -> Tuple[List[Tensor], List[str]]:
p5, c5 = p[-1], c[-1]
x = p5 if self.use_P5 else c5
p6 = self.p6(x)
......
@@ -10,6 +10,8 @@ is implemented
import warnings
import torch
from torch import Tensor, Size
from torch.jit.annotations import List, Optional, Tuple
class Conv2d(torch.nn.Conv2d):
@@ -46,7 +48,12 @@ class FrozenBatchNorm2d(torch.nn.Module):
are fixed
"""
def __init__(self, num_features, eps=0., n=None):
def __init__(
self,
num_features: Tuple[int, ...],
eps: float = 0.,
n: Optional[Tuple[int, ...]] = None,
):
# n=None for backward-compatibility
if n is not None:
warnings.warn("`n` argument is deprecated and has been renamed `num_features`",
@@ -59,8 +66,16 @@ class FrozenBatchNorm2d(torch.nn.Module):
self.register_buffer("running_mean", torch.zeros(num_features))
self.register_buffer("running_var", torch.ones(num_features))
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
def _load_from_state_dict(
self,
state_dict: dict,
prefix: str,
local_metadata: dict,
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
):
num_batches_tracked_key = prefix + 'num_batches_tracked'
if num_batches_tracked_key in state_dict:
del state_dict[num_batches_tracked_key]
@@ -69,7 +84,7 @@ class FrozenBatchNorm2d(torch.nn.Module):
state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)
def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
# move reshapes to the beginning
# to make it fuser-friendly
w = self.weight.reshape(1, -1, 1, 1)
@@ -80,5 +95,5 @@ class FrozenBatchNorm2d(torch.nn.Module):
bias = b - rm * scale
return x * scale + bias
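The precomputed scale and bias are just the usual batch-norm affine transform, y = (x - running_mean) / sqrt(running_var + eps) * weight + bias, with all statistics frozen as buffers. A quick sketch:

import torch
from torchvision.ops.misc import FrozenBatchNorm2d

bn = FrozenBatchNorm2d(4)
x = torch.rand(2, 4, 8, 8)
y = bn(x)  # with the default buffers (mean 0, var 1) this is the identity
assert torch.allclose(y, x)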
def __repr__(self):
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.weight.shape[0]})"
@@ -3,8 +3,7 @@ from torch.jit.annotations import List
from torch import Tensor
def _new_empty_tensor(x, shape):
# type: (Tensor, List[int]) -> Tensor
def _new_empty_tensor(x: Tensor, shape: List[int]) -> Tensor:
"""
Arguments:
input (Tensor): input tensor
......
@@ -15,8 +15,7 @@ import torchvision
# _onnx_merge_levels() is an implementation supported by ONNX
# that merges the levels to the right indices
@torch.jit.unused
def _onnx_merge_levels(levels, unmerged_results):
# type: (Tensor, List[Tensor]) -> Tensor
def _onnx_merge_levels(levels: Tensor, unmerged_results: List[Tensor]) -> Tensor:
first_result = unmerged_results[0]
dtype, device = first_result.dtype, first_result.device
res = torch.zeros((levels.size(0), first_result.size(1),
@@ -33,8 +32,13 @@ def _onnx_merge_levels(levels, unmerged_results):
# TODO: (eellison) T54974082 https://github.com/pytorch/pytorch/issues/26744
def initLevelMapper(k_min, k_max, canonical_scale=224, canonical_level=4, eps=1e-6):
# type: (int, int, int, int, float) -> LevelMapper
def initLevelMapper(
k_min: int,
k_max: int,
canonical_scale: int = 224,
canonical_level: int = 4,
eps: float = 1e-6,
):
return LevelMapper(k_min, k_max, canonical_scale, canonical_level, eps)
@@ -50,16 +54,21 @@ class LevelMapper(object):
eps (float)
"""
def __init__(self, k_min, k_max, canonical_scale=224, canonical_level=4, eps=1e-6):
# type: (int, int, int, int, float) -> None
def __init__(
self,
k_min: int,
k_max: int,
canonical_scale: int = 224,
canonical_level: int = 4,
eps: float = 1e-6,
):
self.k_min = k_min
self.k_max = k_max
self.s0 = canonical_scale
self.lvl0 = canonical_level
self.eps = eps
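For reference, a sketch of the FPN-paper heuristic (its Eq. (1)) that this mapper applies, clamped to [k_min, k_max]:

import torch

# a box of scale s = sqrt(area) is assigned to level
#   lvl = floor(canonical_level + log2(s / canonical_scale))
s = torch.tensor(448. * 448.).sqrt()        # a 448x448 box
lvl = torch.floor(4 + torch.log2(s / 224))  # tensor(5.), one level above canonical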
def __call__(self, boxlists):
# type: (List[Tensor]) -> Tensor
def __call__(self, boxlists: List[Tensor]) -> Tensor:
"""
Arguments:
boxlists (list[BoxList])
@@ -107,7 +116,12 @@ class MultiScaleRoIAlign(nn.Module):
'map_levels': Optional[LevelMapper]
}
def __init__(self, featmap_names, output_size, sampling_ratio):
def __init__(
self,
featmap_names: List[str],
output_size: List[int],
sampling_ratio: int,
):
super(MultiScaleRoIAlign, self).__init__()
if isinstance(output_size, int):
output_size = (output_size, output_size)
@@ -117,8 +131,7 @@ class MultiScaleRoIAlign(nn.Module):
self.scales = None
self.map_levels = None
def convert_to_roi_format(self, boxes):
# type: (List[Tensor]) -> Tensor
def convert_to_roi_format(self, boxes: List[Tensor]) -> Tensor:
concat_boxes = torch.cat(boxes, dim=0)
device, dtype = concat_boxes.device, concat_boxes.dtype
ids = torch.cat(
@@ -131,8 +144,7 @@ class MultiScaleRoIAlign(nn.Module):
rois = torch.cat([ids, concat_boxes], dim=1)
return rois
def infer_scale(self, feature, original_size):
# type: (Tensor, List[int]) -> float
def infer_scale(self, feature: Tensor, original_size: List[int]) -> float:
# assumption: the scale is of the form 2 ** (-k), with k integer
size = feature.shape[-2:]
possible_scales = torch.jit.annotate(List[float], [])
@@ -143,8 +155,11 @@ class MultiScaleRoIAlign(nn.Module):
assert possible_scales[0] == possible_scales[1]
return possible_scales[0]
def setup_scales(self, features, image_shapes):
# type: (List[Tensor], List[Tuple[int, int]]) -> None
def setup_scales(
self,
features: List[Tensor],
image_shapes: List[Tuple[int, int]],
) -> None:
assert len(image_shapes) != 0
max_x = 0
max_y = 0
@@ -161,8 +176,12 @@ class MultiScaleRoIAlign(nn.Module):
self.scales = scales
self.map_levels = initLevelMapper(int(lvl_min), int(lvl_max))
def forward(self, x, boxes, image_shapes):
# type: (Dict[str, Tensor], List[Tensor], List[Tuple[int, int]]) -> Tensor
def forward(
self,
x: Dict[str, Tensor],
boxes: List[Tensor],
image_shapes: List[Tuple[int, int]],
) -> Tensor:
"""
Arguments:
x (OrderedDict[Tensor]): feature maps for each level. They are assumed to have
......
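End to end, the pooler can be exercised like this (feature names and sizes are illustrative; scales are inferred from the feature/image size ratio):

import torch
from torchvision.ops import MultiScaleRoIAlign

pooler = MultiScaleRoIAlign(featmap_names=['feat1', 'feat2'],
                            output_size=7, sampling_ratio=2)
x = {'feat1': torch.rand(1, 5, 64, 64),    # stride 8 w.r.t. the image
     'feat2': torch.rand(1, 5, 32, 32)}    # stride 16
boxes = [torch.tensor([[10., 10., 120., 120.],
                       [30., 40., 200., 220.]])]  # one (N, 4) tensor per image
image_shapes = [(512, 512)]
out = pooler(x, boxes, image_shapes)       # shape (2, 5, 7, 7)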
@@ -2,13 +2,18 @@ import torch
from torch import nn, Tensor
from torch.nn.modules.utils import _pair
from torch.jit.annotations import List
from torch.jit.annotations import List, Tuple
from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape
def ps_roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1):
# type: (Tensor, Tensor, int, float, int) -> Tensor
def ps_roi_align(
input: Tensor,
boxes: Tensor,
output_size: int,
spatial_scale: float = 1.0,
sampling_ratio: int = -1,
) -> Tensor:
"""
Performs Position-Sensitive Region of Interest (RoI) Align operator
mentioned in Light-Head R-CNN.
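A sketch of the channel constraint that makes the op position-sensitive:

import torch
from torchvision.ops import ps_roi_align

# the 18 input channels are divided among the 3*3 output bins,
# so the output has 18 / (3*3) = 2 channels
x = torch.rand(1, 18, 32, 32)
rois = torch.tensor([[0., 0., 0., 16., 16.]])  # [batch_idx, x1, y1, x2, y2]
out = ps_roi_align(x, rois, output_size=3)     # shape (1, 2, 3, 3)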
@@ -49,17 +54,22 @@ class PSRoIAlign(nn.Module):
"""
See ps_roi_align
"""
def __init__(self, output_size, spatial_scale, sampling_ratio):
def __init__(
self,
output_size: int,
spatial_scale: float,
sampling_ratio: int,
):
super(PSRoIAlign, self).__init__()
self.output_size = output_size
self.spatial_scale = spatial_scale
self.sampling_ratio = sampling_ratio
def forward(self, input, rois):
def forward(self, input: Tensor, rois: Tensor) -> Tensor:
return ps_roi_align(input, rois, self.output_size, self.spatial_scale,
self.sampling_ratio)
def __repr__(self):
def __repr__(self) -> str:
tmpstr = self.__class__.__name__ + '('
tmpstr += 'output_size=' + str(self.output_size)
tmpstr += ', spatial_scale=' + str(self.spatial_scale)
......
@@ -2,13 +2,17 @@ import torch
from torch import nn, Tensor
from torch.nn.modules.utils import _pair
from torch.jit.annotations import List
from torch.jit.annotations import List, Tuple
from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape
def ps_roi_pool(input, boxes, output_size, spatial_scale=1.0):
# type: (Tensor, Tensor, int, float) -> Tensor
def ps_roi_pool(
input: Tensor,
boxes: Tensor,
output_size: int,
spatial_scale: float = 1.0,
) -> Tensor:
"""
Performs Position-Sensitive Region of Interest (RoI) Pool operator
described in R-FCN
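The same channel constraint applies here; a brief sketch:

import torch
from torchvision.ops import ps_roi_pool

x = torch.rand(1, 18, 32, 32)                  # channels divisible by 3*3
rois = torch.tensor([[0., 0., 0., 16., 16.]])
out = ps_roi_pool(x, rois, output_size=3)      # shape (1, 2, 3, 3)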
@@ -43,15 +47,15 @@ class PSRoIPool(nn.Module):
"""
See ps_roi_pool
"""
def __init__(self, output_size, spatial_scale):
def __init__(self, output_size: int, spatial_scale: float):
super(PSRoIPool, self).__init__()
self.output_size = output_size
self.spatial_scale = spatial_scale
def forward(self, input, rois):
def forward(self, input: Tensor, rois: Tensor) -> Tensor:
return ps_roi_pool(input, rois, self.output_size, self.spatial_scale)
def __repr__(self):
def __repr__(self) -> str:
tmpstr = self.__class__.__name__ + '('
tmpstr += 'output_size=' + str(self.output_size)
tmpstr += ', spatial_scale=' + str(self.spatial_scale)
......
@@ -7,8 +7,14 @@ from torch.jit.annotations import List, BroadcastingList2
from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape
def roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1, aligned=False):
# type: (Tensor, Tensor, BroadcastingList2[int], float, int, bool) -> Tensor
def roi_align(
input: Tensor,
boxes: Tensor,
output_size: BroadcastingList2[int],
spatial_scale: float = 1.0,
sampling_ratio: int = -1,
aligned: bool = False,
) -> Tensor:
"""
Performs Region of Interest (RoI) Align operator described in Mask R-CNN
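A usage sketch; aligned=True applies the half-pixel-shift variant of the op:

import torch
from torchvision.ops import roi_align

x = torch.rand(1, 3, 32, 32)
rois = torch.tensor([[0., 0., 0., 16., 16.]])  # [batch_idx, x1, y1, x2, y2]
out = roi_align(x, rois, output_size=(7, 7), spatial_scale=1.0,
                sampling_ratio=2, aligned=True)
# out.shape == (1, 3, 7, 7); aligned=True shifts boxes by -0.5 px so sampling
# points line up with pixel centers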
@@ -49,17 +55,23 @@ class RoIAlign(nn.Module):
"""
See roi_align
"""
def __init__(self, output_size, spatial_scale, sampling_ratio, aligned=False):
def __init__(
self,
output_size: BroadcastingList2[int],
spatial_scale: float,
sampling_ratio: int,
aligned: bool = False,
):
super(RoIAlign, self).__init__()
self.output_size = output_size
self.spatial_scale = spatial_scale
self.sampling_ratio = sampling_ratio
self.aligned = aligned
def forward(self, input, rois):
def forward(self, input: Tensor, rois: Tensor) -> Tensor:
return roi_align(input, rois, self.output_size, self.spatial_scale, self.sampling_ratio, self.aligned)
def __repr__(self):
def __repr__(self) -> str:
tmpstr = self.__class__.__name__ + '('
tmpstr += 'output_size=' + str(self.output_size)
tmpstr += ', spatial_scale=' + str(self.spatial_scale)
......
@@ -7,8 +7,12 @@ from torch.jit.annotations import List, BroadcastingList2
from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape
def roi_pool(input, boxes, output_size, spatial_scale=1.0):
# type: (Tensor, Tensor, BroadcastingList2[int], float) -> Tensor
def roi_pool(
input: Tensor,
boxes: Tensor,
output_size: BroadcastingList2[int],
spatial_scale: float = 1.0,
) -> Tensor:
"""
Performs Region of Interest (RoI) Pool operator described in Fast R-CNN
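And the corresponding pooling op:

import torch
from torchvision.ops import roi_pool

x = torch.rand(1, 3, 32, 32)
rois = torch.tensor([[0., 0., 0., 16., 16.]])
out = roi_pool(x, rois, output_size=7, spatial_scale=1.0)  # shape (1, 3, 7, 7)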
@@ -41,15 +45,15 @@ class RoIPool(nn.Module):
"""
See roi_pool
"""
def __init__(self, output_size, spatial_scale):
def __init__(self, output_size: BroadcastingList2[int], spatial_scale: float):
super(RoIPool, self).__init__()
self.output_size = output_size
self.spatial_scale = spatial_scale
def forward(self, input, rois):
def forward(self, input: Tensor, rois: Tensor) -> Tensor:
return roi_pool(input, rois, self.output_size, self.spatial_scale)
def __repr__(self):
def __repr__(self) -> str:
tmpstr = self.__class__.__name__ + '('
tmpstr += 'output_size=' + str(self.output_size)
tmpstr += ', spatial_scale=' + str(self.spatial_scale)
......