Unverified Commit dcdf8961 authored by ShawnHu's avatar ShawnHu Committed by GitHub
Browse files

Add type hints for mmcv/ops (#2037)

* Add type hints for mmcv/ops/ball_query.py, border_align.py and correlation.py

* Add type hints for mmcv/ops/deform_conv.py, deform_roi_pool.py and deprecated_wrappers.py

* Remove type hints for deform_conv.py and deform_roi_pool.py

* Fix type hints for other files
parent ea0e8cdb
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple
import torch import torch
from torch.autograd import Function from torch.autograd import Function
...@@ -49,7 +51,7 @@ class BallQuery(Function): ...@@ -49,7 +51,7 @@ class BallQuery(Function):
return idx return idx
@staticmethod @staticmethod
def backward(ctx, a=None): def backward(ctx, a=None) -> Tuple[None, None, None, None]:
return None, None, None, None return None, None, None, None
......
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
# modified from # modified from
# https://github.com/Megvii-BaseDetection/cvpods/blob/master/cvpods/layers/border_align.py # https://github.com/Megvii-BaseDetection/cvpods/blob/master/cvpods/layers/border_align.py
from typing import Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.autograd import Function from torch.autograd import Function
...@@ -21,7 +23,8 @@ class BorderAlignFunction(Function): ...@@ -21,7 +23,8 @@ class BorderAlignFunction(Function):
'mmcv::MMCVBorderAlign', input, boxes, pool_size_i=pool_size) 'mmcv::MMCVBorderAlign', input, boxes, pool_size_i=pool_size)
@staticmethod @staticmethod
def forward(ctx, input, boxes, pool_size): def forward(ctx, input: torch.Tensor, boxes: torch.Tensor,
pool_size: int) -> torch.Tensor:
ctx.pool_size = pool_size ctx.pool_size = pool_size
ctx.input_shape = input.size() ctx.input_shape = input.size()
...@@ -45,7 +48,8 @@ class BorderAlignFunction(Function): ...@@ -45,7 +48,8 @@ class BorderAlignFunction(Function):
@staticmethod @staticmethod
@once_differentiable @once_differentiable
def backward(ctx, grad_output): def backward(ctx,
grad_output: torch.Tensor) -> Tuple[torch.Tensor, None, None]:
boxes, argmax_idx = ctx.saved_tensors boxes, argmax_idx = ctx.saved_tensors
grad_input = grad_output.new_zeros(ctx.input_shape) grad_input = grad_output.new_zeros(ctx.input_shape)
# complex head architecture may cause grad_output uncontiguous # complex head architecture may cause grad_output uncontiguous
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple
import torch import torch
from torch import Tensor, nn from torch import Tensor, nn
from torch.autograd import Function from torch.autograd import Function
...@@ -15,14 +17,14 @@ class CorrelationFunction(Function): ...@@ -15,14 +17,14 @@ class CorrelationFunction(Function):
@staticmethod @staticmethod
def forward(ctx, def forward(ctx,
input1, input1: Tensor,
input2, input2: Tensor,
kernel_size=1, kernel_size: int = 1,
max_displacement=1, max_displacement: int = 1,
stride=1, stride: int = 1,
padding=1, padding: int = 1,
dilation=1, dilation: int = 1,
dilation_patch=1): dilation_patch: int = 1) -> Tensor:
ctx.save_for_backward(input1, input2) ctx.save_for_backward(input1, input2)
...@@ -60,7 +62,9 @@ class CorrelationFunction(Function): ...@@ -60,7 +62,9 @@ class CorrelationFunction(Function):
@staticmethod @staticmethod
@once_differentiable @once_differentiable
def backward(ctx, grad_output): def backward(
ctx, grad_output: Tensor
) -> Tuple[Tensor, Tensor, None, None, None, None, None, None]:
input1, input2 = ctx.saved_tensors input1, input2 = ctx.saved_tensors
kH, kW = ctx.kernel_size kH, kW = ctx.kernel_size
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple, Union from typing import Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -48,16 +48,16 @@ class DeformConv2dFunction(Function): ...@@ -48,16 +48,16 @@ class DeformConv2dFunction(Function):
@staticmethod @staticmethod
def forward(ctx, def forward(ctx,
input, input: Tensor,
offset, offset: Tensor,
weight, weight: Tensor,
stride=1, stride: Union[int, Tuple[int, ...]] = 1,
padding=0, padding: Union[int, Tuple[int, ...]] = 0,
dilation=1, dilation: Union[int, Tuple[int, ...]] = 1,
groups=1, groups: int = 1,
deform_groups=1, deform_groups: int = 1,
bias=False, bias: bool = False,
im2col_step=32): im2col_step: int = 32) -> Tensor:
if input is not None and input.dim() != 4: if input is not None and input.dim() != 4:
raise ValueError( raise ValueError(
f'Expected 4D tensor as input, got {input.dim()}D tensor \ f'Expected 4D tensor as input, got {input.dim()}D tensor \
...@@ -111,7 +111,10 @@ class DeformConv2dFunction(Function): ...@@ -111,7 +111,10 @@ class DeformConv2dFunction(Function):
@staticmethod @staticmethod
@once_differentiable @once_differentiable
def backward(ctx, grad_output): def backward(
ctx, grad_output: Tensor
) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor], None,
None, None, None, None, None, None]:
input, offset, weight = ctx.saved_tensors input, offset, weight = ctx.saved_tensors
grad_input = grad_offset = grad_weight = None grad_input = grad_offset = grad_weight = None
...@@ -371,7 +374,7 @@ class DeformConv2dPack(DeformConv2d): ...@@ -371,7 +374,7 @@ class DeformConv2dPack(DeformConv2d):
self.conv_offset.weight.data.zero_() self.conv_offset.weight.data.zero_()
self.conv_offset.bias.data.zero_() self.conv_offset.bias.data.zero_()
def forward(self, x): def forward(self, x: Tensor) -> Tensor: # type: ignore
offset = self.conv_offset(x) offset = self.conv_offset(x)
return deform_conv2d(x, offset, self.weight, self.stride, self.padding, return deform_conv2d(x, offset, self.weight, self.stride, self.padding,
self.dilation, self.groups, self.deform_groups, self.dilation, self.groups, self.deform_groups,
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from torch import nn from typing import Optional, Tuple
from torch import Tensor, nn
from torch.autograd import Function from torch.autograd import Function
from torch.autograd.function import once_differentiable from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair from torch.nn.modules.utils import _pair
...@@ -28,13 +30,13 @@ class DeformRoIPoolFunction(Function): ...@@ -28,13 +30,13 @@ class DeformRoIPoolFunction(Function):
@staticmethod @staticmethod
def forward(ctx, def forward(ctx,
input, input: Tensor,
rois, rois: Tensor,
offset, offset: Optional[Tensor],
output_size, output_size: Tuple[int, ...],
spatial_scale=1.0, spatial_scale: float = 1.0,
sampling_ratio=0, sampling_ratio: int = 0,
gamma=0.1): gamma: float = 0.1) -> Tensor:
if offset is None: if offset is None:
offset = input.new_zeros(0) offset = input.new_zeros(0)
ctx.output_size = _pair(output_size) ctx.output_size = _pair(output_size)
...@@ -64,7 +66,9 @@ class DeformRoIPoolFunction(Function): ...@@ -64,7 +66,9 @@ class DeformRoIPoolFunction(Function):
@staticmethod @staticmethod
@once_differentiable @once_differentiable
def backward(ctx, grad_output): def backward(
ctx, grad_output: Tensor
) -> Tuple[Tensor, None, Tensor, None, None, None, None]:
input, rois, offset = ctx.saved_tensors input, rois, offset = ctx.saved_tensors
grad_input = grad_output.new_zeros(input.shape) grad_input = grad_output.new_zeros(input.shape)
grad_offset = grad_output.new_zeros(offset.shape) grad_offset = grad_output.new_zeros(offset.shape)
...@@ -92,17 +96,20 @@ deform_roi_pool = DeformRoIPoolFunction.apply ...@@ -92,17 +96,20 @@ deform_roi_pool = DeformRoIPoolFunction.apply
class DeformRoIPool(nn.Module): class DeformRoIPool(nn.Module):
def __init__(self, def __init__(self,
output_size, output_size: Tuple[int, ...],
spatial_scale=1.0, spatial_scale: float = 1.0,
sampling_ratio=0, sampling_ratio: int = 0,
gamma=0.1): gamma: float = 0.1):
super().__init__() super().__init__()
self.output_size = _pair(output_size) self.output_size = _pair(output_size)
self.spatial_scale = float(spatial_scale) self.spatial_scale = float(spatial_scale)
self.sampling_ratio = int(sampling_ratio) self.sampling_ratio = int(sampling_ratio)
self.gamma = float(gamma) self.gamma = float(gamma)
def forward(self, input, rois, offset=None): def forward(self,
input: Tensor,
rois: Tensor,
offset: Optional[Tensor] = None) -> Tensor:
return deform_roi_pool(input, rois, offset, self.output_size, return deform_roi_pool(input, rois, offset, self.output_size,
self.spatial_scale, self.sampling_ratio, self.spatial_scale, self.sampling_ratio,
self.gamma) self.gamma)
...@@ -111,12 +118,12 @@ class DeformRoIPool(nn.Module): ...@@ -111,12 +118,12 @@ class DeformRoIPool(nn.Module):
class DeformRoIPoolPack(DeformRoIPool): class DeformRoIPoolPack(DeformRoIPool):
def __init__(self, def __init__(self,
output_size, output_size: Tuple[int, ...],
output_channels, output_channels: int,
deform_fc_channels=1024, deform_fc_channels: int = 1024,
spatial_scale=1.0, spatial_scale: float = 1.0,
sampling_ratio=0, sampling_ratio: int = 0,
gamma=0.1): gamma: float = 0.1):
super().__init__(output_size, spatial_scale, sampling_ratio, gamma) super().__init__(output_size, spatial_scale, sampling_ratio, gamma)
self.output_channels = output_channels self.output_channels = output_channels
...@@ -134,7 +141,7 @@ class DeformRoIPoolPack(DeformRoIPool): ...@@ -134,7 +141,7 @@ class DeformRoIPoolPack(DeformRoIPool):
self.offset_fc[-1].weight.data.zero_() self.offset_fc[-1].weight.data.zero_()
self.offset_fc[-1].bias.data.zero_() self.offset_fc[-1].bias.data.zero_()
def forward(self, input, rois): def forward(self, input: Tensor, rois: Tensor) -> Tensor: # type: ignore
assert input.size(1) == self.output_channels assert input.size(1) == self.output_channels
x = deform_roi_pool(input, rois, None, self.output_size, x = deform_roi_pool(input, rois, None, self.output_size,
self.spatial_scale, self.sampling_ratio, self.spatial_scale, self.sampling_ratio,
...@@ -151,12 +158,12 @@ class DeformRoIPoolPack(DeformRoIPool): ...@@ -151,12 +158,12 @@ class DeformRoIPoolPack(DeformRoIPool):
class ModulatedDeformRoIPoolPack(DeformRoIPool): class ModulatedDeformRoIPoolPack(DeformRoIPool):
def __init__(self, def __init__(self,
output_size, output_size: Tuple[int, ...],
output_channels, output_channels: int,
deform_fc_channels=1024, deform_fc_channels: int = 1024,
spatial_scale=1.0, spatial_scale: float = 1.0,
sampling_ratio=0, sampling_ratio: int = 0,
gamma=0.1): gamma: float = 0.1):
super().__init__(output_size, spatial_scale, sampling_ratio, gamma) super().__init__(output_size, spatial_scale, sampling_ratio, gamma)
self.output_channels = output_channels self.output_channels = output_channels
...@@ -185,7 +192,7 @@ class ModulatedDeformRoIPoolPack(DeformRoIPool): ...@@ -185,7 +192,7 @@ class ModulatedDeformRoIPoolPack(DeformRoIPool):
self.mask_fc[2].weight.data.zero_() self.mask_fc[2].weight.data.zero_()
self.mask_fc[2].bias.data.zero_() self.mask_fc[2].bias.data.zero_()
def forward(self, input, rois): def forward(self, input: Tensor, rois: Tensor) -> Tensor: # type: ignore
assert input.size(1) == self.output_channels assert input.size(1) == self.output_channels
x = deform_roi_pool(input, rois, None, self.output_size, x = deform_roi_pool(input, rois, None, self.output_size,
self.spatial_scale, self.sampling_ratio, self.spatial_scale, self.sampling_ratio,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment