Unverified Commit dcdf8961 authored by ShawnHu's avatar ShawnHu Committed by GitHub
Browse files

Add type hints for mmcv/ops (#2037)

* Add type hints for mmcv/ops/ball_query.py, border_align.py and correlation.py

* Add type hints for mmcv/ops/deform_conv.py, deform_roi_pool.py and deprecated_wrappers.py

* Remove type hints for deform_conv.py and deform_roi_pool.py

* Fix type hints for other files
parent ea0e8cdb
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple
import torch
from torch.autograd import Function
......@@ -49,7 +51,7 @@ class BallQuery(Function):
return idx
@staticmethod
def backward(ctx, a=None):
def backward(ctx, a=None) -> Tuple[None, None, None, None]:
return None, None, None, None
......
......@@ -2,6 +2,8 @@
# modified from
# https://github.com/Megvii-BaseDetection/cvpods/blob/master/cvpods/layers/border_align.py
from typing import Tuple
import torch
import torch.nn as nn
from torch.autograd import Function
......@@ -21,7 +23,8 @@ class BorderAlignFunction(Function):
'mmcv::MMCVBorderAlign', input, boxes, pool_size_i=pool_size)
@staticmethod
def forward(ctx, input, boxes, pool_size):
def forward(ctx, input: torch.Tensor, boxes: torch.Tensor,
pool_size: int) -> torch.Tensor:
ctx.pool_size = pool_size
ctx.input_shape = input.size()
......@@ -45,7 +48,8 @@ class BorderAlignFunction(Function):
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
def backward(ctx,
grad_output: torch.Tensor) -> Tuple[torch.Tensor, None, None]:
boxes, argmax_idx = ctx.saved_tensors
grad_input = grad_output.new_zeros(ctx.input_shape)
# complex head architecture may cause grad_output uncontiguous
......
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple
import torch
from torch import Tensor, nn
from torch.autograd import Function
......@@ -15,14 +17,14 @@ class CorrelationFunction(Function):
@staticmethod
def forward(ctx,
input1,
input2,
kernel_size=1,
max_displacement=1,
stride=1,
padding=1,
dilation=1,
dilation_patch=1):
input1: Tensor,
input2: Tensor,
kernel_size: int = 1,
max_displacement: int = 1,
stride: int = 1,
padding: int = 1,
dilation: int = 1,
dilation_patch: int = 1) -> Tensor:
ctx.save_for_backward(input1, input2)
......@@ -60,7 +62,9 @@ class CorrelationFunction(Function):
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
def backward(
ctx, grad_output: Tensor
) -> Tuple[Tensor, Tensor, None, None, None, None, None, None]:
input1, input2 = ctx.saved_tensors
kH, kW = ctx.kernel_size
......
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple, Union
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
......@@ -48,16 +48,16 @@ class DeformConv2dFunction(Function):
@staticmethod
def forward(ctx,
input,
offset,
weight,
stride=1,
padding=0,
dilation=1,
groups=1,
deform_groups=1,
bias=False,
im2col_step=32):
input: Tensor,
offset: Tensor,
weight: Tensor,
stride: Union[int, Tuple[int, ...]] = 1,
padding: Union[int, Tuple[int, ...]] = 0,
dilation: Union[int, Tuple[int, ...]] = 1,
groups: int = 1,
deform_groups: int = 1,
bias: bool = False,
im2col_step: int = 32) -> Tensor:
if input is not None and input.dim() != 4:
raise ValueError(
f'Expected 4D tensor as input, got {input.dim()}D tensor \
......@@ -111,7 +111,10 @@ class DeformConv2dFunction(Function):
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
def backward(
ctx, grad_output: Tensor
) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor], None,
None, None, None, None, None, None]:
input, offset, weight = ctx.saved_tensors
grad_input = grad_offset = grad_weight = None
......@@ -371,7 +374,7 @@ class DeformConv2dPack(DeformConv2d):
self.conv_offset.weight.data.zero_()
self.conv_offset.bias.data.zero_()
def forward(self, x):
def forward(self, x: Tensor) -> Tensor: # type: ignore
offset = self.conv_offset(x)
return deform_conv2d(x, offset, self.weight, self.stride, self.padding,
self.dilation, self.groups, self.deform_groups,
......
# Copyright (c) OpenMMLab. All rights reserved.
from torch import nn
from typing import Optional, Tuple
from torch import Tensor, nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair
......@@ -28,13 +30,13 @@ class DeformRoIPoolFunction(Function):
@staticmethod
def forward(ctx,
input,
rois,
offset,
output_size,
spatial_scale=1.0,
sampling_ratio=0,
gamma=0.1):
input: Tensor,
rois: Tensor,
offset: Optional[Tensor],
output_size: Tuple[int, ...],
spatial_scale: float = 1.0,
sampling_ratio: int = 0,
gamma: float = 0.1) -> Tensor:
if offset is None:
offset = input.new_zeros(0)
ctx.output_size = _pair(output_size)
......@@ -64,7 +66,9 @@ class DeformRoIPoolFunction(Function):
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
def backward(
ctx, grad_output: Tensor
) -> Tuple[Tensor, None, Tensor, None, None, None, None]:
input, rois, offset = ctx.saved_tensors
grad_input = grad_output.new_zeros(input.shape)
grad_offset = grad_output.new_zeros(offset.shape)
......@@ -92,17 +96,20 @@ deform_roi_pool = DeformRoIPoolFunction.apply
class DeformRoIPool(nn.Module):
    """Deformable RoI pooling module.

    Thin ``nn.Module`` wrapper that forwards to the ``deform_roi_pool``
    autograd function with the configuration stored at construction time.

    Args:
        output_size (Tuple[int, ...]): Pooled output size; an int is
            expanded to ``(h, w)`` by ``_pair``.
        spatial_scale (float): Scale factor mapping RoI coordinates to the
            input feature map. Defaults to 1.0.
        sampling_ratio (int): Number of sampling points per bin
            (0 lets the op choose adaptively — TODO confirm against the
            underlying kernel). Defaults to 0.
        gamma (float): Offset scaling factor. Defaults to 0.1.
    """

    def __init__(self,
                 output_size: Tuple[int, ...],
                 spatial_scale: float = 1.0,
                 sampling_ratio: int = 0,
                 gamma: float = 0.1):
        super().__init__()
        # _pair normalizes a scalar into an (h, w) tuple.
        self.output_size = _pair(output_size)
        # Coerce to canonical types so later C++/CUDA calls see floats/ints.
        self.spatial_scale = float(spatial_scale)
        self.sampling_ratio = int(sampling_ratio)
        self.gamma = float(gamma)

    def forward(self,
                input: Tensor,
                rois: Tensor,
                offset: Optional[Tensor] = None) -> Tensor:
        """Pool features for each RoI.

        Args:
            input (Tensor): Input feature map.
            rois (Tensor): Regions of interest.
            offset (Tensor, optional): Deformation offsets; ``None`` means
                no deformation (the function substitutes an empty tensor).

        Returns:
            Tensor: Pooled features.
        """
        return deform_roi_pool(input, rois, offset, self.output_size,
                               self.spatial_scale, self.sampling_ratio,
                               self.gamma)
......@@ -111,12 +118,12 @@ class DeformRoIPool(nn.Module):
class DeformRoIPoolPack(DeformRoIPool):
def __init__(self,
output_size,
output_channels,
deform_fc_channels=1024,
spatial_scale=1.0,
sampling_ratio=0,
gamma=0.1):
output_size: Tuple[int, ...],
output_channels: int,
deform_fc_channels: int = 1024,
spatial_scale: float = 1.0,
sampling_ratio: int = 0,
gamma: float = 0.1):
super().__init__(output_size, spatial_scale, sampling_ratio, gamma)
self.output_channels = output_channels
......@@ -134,7 +141,7 @@ class DeformRoIPoolPack(DeformRoIPool):
self.offset_fc[-1].weight.data.zero_()
self.offset_fc[-1].bias.data.zero_()
def forward(self, input, rois):
def forward(self, input: Tensor, rois: Tensor) -> Tensor: # type: ignore
assert input.size(1) == self.output_channels
x = deform_roi_pool(input, rois, None, self.output_size,
self.spatial_scale, self.sampling_ratio,
......@@ -151,12 +158,12 @@ class DeformRoIPoolPack(DeformRoIPool):
class ModulatedDeformRoIPoolPack(DeformRoIPool):
def __init__(self,
output_size,
output_channels,
deform_fc_channels=1024,
spatial_scale=1.0,
sampling_ratio=0,
gamma=0.1):
output_size: Tuple[int, ...],
output_channels: int,
deform_fc_channels: int = 1024,
spatial_scale: float = 1.0,
sampling_ratio: int = 0,
gamma: float = 0.1):
super().__init__(output_size, spatial_scale, sampling_ratio, gamma)
self.output_channels = output_channels
......@@ -185,7 +192,7 @@ class ModulatedDeformRoIPoolPack(DeformRoIPool):
self.mask_fc[2].weight.data.zero_()
self.mask_fc[2].bias.data.zero_()
def forward(self, input, rois):
def forward(self, input: Tensor, rois: Tensor) -> Tensor: # type: ignore
assert input.size(1) == self.output_channels
x = deform_roi_pool(input, rois, None, self.output_size,
self.spatial_scale, self.sampling_ratio,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment