Unverified commit b9a96e56, authored by WINDSKY45 and committed by GitHub

[Enhance] Add type hints in /ops: (#2030)



* [Enhance] Add type hints in /ops:
`fused_bias_leakyrelu.py`, `gather_points.py`, `group_points.py`.
There is no need to add type hints in `furthest_point_sample.py` and
`info.py`.
As for `focal_loss.py`, please see #1994.

* Modified the default value of a variable.

* [Enhance] Add type hints in:
`knn.py`, `masked_conv.py`, `merge_cells.py`, `min_area_polygons.py`,
`modulated_deform_conv.py`, `multi_scale_deform_attn.py`.

* Fix type hints.

* Fixed type hints.

* Remove type hints of `symbolic`.

* Add `no_type_check` to ignore the mypy check for a method.

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
parent 305c2a30
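Before the hunks themselves, it is worth summarizing the convention the PR applies most often. The sketch below is illustrative only (the names are hypothetical, not mmcv code): an argument that defaults to None must be annotated Optional[...], because mypy's strict-optional mode rejects a bare `features: torch.Tensor = None`.

from typing import Optional

import torch


def group(points: torch.Tensor,
          features: Optional[torch.Tensor] = None) -> torch.Tensor:
    # A None default requires Optional[...]; annotating this parameter as
    # `features: torch.Tensor = None` would be flagged by mypy.
    if features is None:
        return points
    return torch.cat([points, features], dim=1)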
......@@ -113,7 +113,8 @@ class FusedBiasLeakyReLUFunctionBackward(Function):
"""
@staticmethod
def forward(ctx, grad_output, out, negative_slope, scale):
def forward(ctx, grad_output: torch.Tensor, out: torch.Tensor,
negative_slope: float, scale: float) -> tuple:
ctx.save_for_backward(out)
ctx.negative_slope = negative_slope
ctx.scale = scale
......@@ -139,7 +140,8 @@ class FusedBiasLeakyReLUFunctionBackward(Function):
return grad_input, grad_bias
@staticmethod
def backward(ctx, gradgrad_input, gradgrad_bias):
def backward(ctx, gradgrad_input: torch.Tensor,
gradgrad_bias: nn.Parameter) -> tuple:
out, = ctx.saved_tensors
# The second order deviation, in fact, contains two parts, while the
......@@ -160,7 +162,8 @@ class FusedBiasLeakyReLUFunctionBackward(Function):
class FusedBiasLeakyReLUFunction(Function):
@staticmethod
def forward(ctx, input, bias, negative_slope, scale):
def forward(ctx, input: torch.Tensor, bias: nn.Parameter,
negative_slope: float, scale: float) -> torch.Tensor:
empty = input.new_empty(0)
out = ext_module.fused_bias_leakyrelu(
......@@ -178,7 +181,7 @@ class FusedBiasLeakyReLUFunction(Function):
return out
@staticmethod
def backward(ctx, grad_output):
def backward(ctx, grad_output: torch.Tensor) -> tuple:
out, = ctx.saved_tensors
grad_input, grad_bias = FusedBiasLeakyReLUFunctionBackward.apply(
......@@ -204,26 +207,32 @@ class FusedBiasLeakyReLU(nn.Module):
TODO: Implement the CPU version.
Args:
channel (int): The channel number of the feature map.
num_channels (int): The channel number of the feature map.
negative_slope (float, optional): Same as nn.LeakyRelu.
Defaults to 0.2.
scale (float, optional): A scalar to adjust the variance of the feature
map. Defaults to 2**0.5.
"""
def __init__(self, num_channels, negative_slope=0.2, scale=2**0.5):
def __init__(self,
num_channels: int,
negative_slope: float = 0.2,
scale: float = 2**0.5):
super().__init__()
self.bias = nn.Parameter(torch.zeros(num_channels))
self.negative_slope = negative_slope
self.scale = scale
def forward(self, input):
def forward(self, input: torch.Tensor) -> torch.Tensor:
return fused_bias_leakyrelu(input, self.bias, self.negative_slope,
self.scale)
def fused_bias_leakyrelu(input, bias, negative_slope=0.2, scale=2**0.5):
def fused_bias_leakyrelu(input: torch.Tensor,
bias: nn.Parameter,
negative_slope: float = 0.2,
scale: float = 2**0.5) -> torch.Tensor:
r"""Fused bias leaky ReLU function.
This function is introduced in the StyleGAN2:
......@@ -256,7 +265,10 @@ def fused_bias_leakyrelu(input, bias, negative_slope=0.2, scale=2**0.5):
negative_slope, scale)
def bias_leakyrelu_ref(x, bias, negative_slope=0.2, scale=2**0.5):
def bias_leakyrelu_ref(x: torch.Tensor,
bias: nn.Parameter,
negative_slope: float = 0.2,
scale: float = 2**0.5) -> torch.Tensor:
if bias is not None:
assert bias.ndim == 1
......
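For reference, the typed activation is used like any other module. A minimal usage sketch, assuming a CUDA build of mmcv (the op has no CPU implementation yet, per the TODO above):

import torch
from mmcv.ops import FusedBiasLeakyReLU

act = FusedBiasLeakyReLU(num_channels=64).cuda()
x = torch.randn(2, 64, 32, 32, device='cuda')
out = act(x)  # same shape as x; bias is added, then scaled leaky ReLU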
from typing import Tuple
import torch
from torch.autograd import Function
......@@ -37,7 +39,7 @@ class GatherPoints(Function):
return output
@staticmethod
def backward(ctx, grad_out):
def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, None]:
idx, C, N = ctx.for_backwards
B, npoint = idx.size()
......
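The Tuple[torch.Tensor, None] annotation on GatherPoints.backward reflects a general autograd rule: backward returns one entry per forward input, with None for non-differentiable arguments such as integer indices. A toy sketch of the pattern (hypothetical code, not from mmcv):

from typing import Tuple

import torch
from torch.autograd import Function


class GatherRows(Function):
    """Differentiable in `features`, not in `idx`."""

    @staticmethod
    def forward(ctx, features: torch.Tensor,
                idx: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(idx)
        ctx.num_rows = features.size(0)
        return features[idx]

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, None]:
        idx, = ctx.saved_tensors
        grad_features = grad_out.new_zeros(ctx.num_rows, grad_out.size(1))
        grad_features.index_add_(0, idx, grad_out)
        # One return value per forward input; the index tensor gets None.
        return grad_features, None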
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple
from typing import Optional, Tuple, Union
import torch
from torch import nn as nn
......@@ -37,15 +37,15 @@ class QueryAndGroup(nn.Module):
"""
def __init__(self,
max_radius,
sample_num,
min_radius=0,
use_xyz=True,
return_grouped_xyz=False,
normalize_xyz=False,
uniform_sample=False,
return_unique_cnt=False,
return_grouped_idx=False):
max_radius: float,
sample_num: int,
min_radius: float = 0.,
use_xyz: bool = True,
return_grouped_xyz: bool = False,
normalize_xyz: bool = False,
uniform_sample: bool = False,
return_unique_cnt: bool = False,
return_grouped_idx: bool = False):
super().__init__()
self.max_radius = max_radius
self.min_radius = min_radius
......@@ -64,7 +64,12 @@ class QueryAndGroup(nn.Module):
assert not self.normalize_xyz, \
'can not normalize grouped xyz when max_radius is None'
def forward(self, points_xyz, center_xyz, features=None):
def forward(
self,
points_xyz: torch.Tensor,
center_xyz: torch.Tensor,
features: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple]:
"""
Args:
points_xyz (torch.Tensor): (B, N, 3) xyz coordinates of the
......@@ -75,7 +80,7 @@ class QueryAndGroup(nn.Module):
points.
Returns:
torch.Tensor: (B, 3 + C, npoint, sample_num) Grouped
Tuple | torch.Tensor: (B, 3 + C, npoint, sample_num) Grouped
concatenated coordinates and features of points.
"""
# if self.max_radius is None, we will perform kNN instead of ball query
......@@ -149,7 +154,7 @@ class GroupAll(nn.Module):
def forward(self,
xyz: torch.Tensor,
new_xyz: torch.Tensor,
features: torch.Tensor = None):
features: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Args:
xyz (Tensor): (B, N, 3) xyz coordinates of the features.
......@@ -210,8 +215,7 @@ class GroupingOperation(Function):
return output
@staticmethod
def backward(ctx,
grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, None]:
"""
Args:
grad_out (Tensor): (B, C, npoint, nsample) tensor of the gradients
......
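The Union[torch.Tensor, Tuple] return type on QueryAndGroup.forward exists because what the method returns depends on flags fixed at construction time (return_grouped_xyz, return_grouped_idx, return_unique_cnt). A minimal sketch of that pattern, with hypothetical names:

from typing import Tuple, Union

import torch


class Grouper:

    def __init__(self, return_idx: bool = False):
        self.return_idx = return_idx

    def group(self, x: torch.Tensor) -> Union[torch.Tensor, Tuple]:
        idx = torch.argsort(x, dim=-1)
        grouped = torch.gather(x, -1, idx)
        # The static return type must cover both branches, hence Union.
        if self.return_idx:
            return grouped, idx
        return grouped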
from typing import Optional
import torch
from torch.autograd import Function
......@@ -19,7 +21,7 @@ class KNN(Function):
def forward(ctx,
k: int,
xyz: torch.Tensor,
center_xyz: torch.Tensor = None,
center_xyz: Optional[torch.Tensor] = None,
transposed: bool = False) -> torch.Tensor:
"""
Args:
......
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
......@@ -27,7 +28,13 @@ class MaskedConv2dFunction(Function):
stride_i=stride)
@staticmethod
def forward(ctx, features, mask, weight, bias, padding=0, stride=1):
def forward(ctx,
features: torch.Tensor,
mask: torch.Tensor,
weight: torch.nn.Parameter,
bias: torch.nn.Parameter,
padding: int = 0,
stride: int = 1) -> torch.Tensor:
assert mask.dim() == 3 and mask.size(0) == 1
assert features.dim() == 4 and features.size(0) == 1
assert features.size()[2:] == mask.size()[1:]
......@@ -75,7 +82,7 @@ class MaskedConv2dFunction(Function):
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
def backward(ctx, grad_output: torch.Tensor) -> tuple:
return (None, ) * 5
......@@ -90,18 +97,20 @@ class MaskedConv2d(nn.Conv2d):
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True):
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, ...]],
stride: int = 1,
padding: int = 0,
dilation: int = 1,
groups: int = 1,
bias: bool = True):
super().__init__(in_channels, out_channels, kernel_size, stride,
padding, dilation, groups, bias)
def forward(self, input, mask=None):
def forward(self,
input: torch.Tensor,
mask: Optional[torch.Tensor] = None) -> torch.Tensor:
if mask is None: # fallback to the normal Conv2d
return super().forward(input)
else:
......
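The Optional[torch.Tensor] annotation on mask matches the runtime contract visible above: with mask=None the layer behaves as a plain Conv2d. A usage sketch, assuming a CUDA build of mmcv for the masked path:

import torch
from mmcv.ops import MaskedConv2d

conv = MaskedConv2d(3, 8, kernel_size=3, padding=1).cuda()
x = torch.randn(1, 3, 16, 16, device='cuda')  # batch size 1, as asserted
y_plain = conv(x)                             # mask=None: ordinary Conv2d
mask = torch.ones(1, 16, 16, device='cuda')   # (1, H, W), matching x
y_masked = conv(x, mask)                      # conv only where mask is set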
# Copyright (c) OpenMMLab. All rights reserved.
import math
from abc import abstractmethod
from typing import Optional
import torch
import torch.nn as nn
......@@ -19,7 +20,7 @@ class BaseMergeCell(nn.Module):
another convolution layer.
Args:
in_channels (int): number of input channels in out_conv layer.
fused_channels (int): number of input channels in out_conv layer.
out_channels (int): number of output channels in out_conv layer.
with_out_conv (bool): Whether to use out_conv layer
out_conv_cfg (dict): Config dict for convolution layer, which should
......@@ -42,18 +43,18 @@ class BaseMergeCell(nn.Module):
"""
def __init__(self,
fused_channels=256,
out_channels=256,
with_out_conv=True,
out_conv_cfg=dict(
fused_channels: Optional[int] = 256,
out_channels: Optional[int] = 256,
with_out_conv: bool = True,
out_conv_cfg: dict = dict(
groups=1, kernel_size=3, padding=1, bias=True),
out_norm_cfg=None,
out_conv_order=('act', 'conv', 'norm'),
with_input1_conv=False,
with_input2_conv=False,
input_conv_cfg=None,
input_norm_cfg=None,
upsample_mode='nearest'):
out_norm_cfg: Optional[dict] = None,
out_conv_order: tuple = ('act', 'conv', 'norm'),
with_input1_conv: bool = False,
with_input2_conv: bool = False,
input_conv_cfg: Optional[dict] = None,
input_norm_cfg: Optional[dict] = None,
upsample_mode: str = 'nearest'):
super().__init__()
assert upsample_mode in ['nearest', 'bilinear']
self.with_out_conv = with_out_conv
......@@ -111,7 +112,10 @@ class BaseMergeCell(nn.Module):
x = F.max_pool2d(x, kernel_size=kernel_size, stride=kernel_size)
return x
def forward(self, x1, x2, out_size=None):
def forward(self,
x1: torch.Tensor,
x2: torch.Tensor,
out_size: Optional[tuple] = None) -> torch.Tensor:
assert x1.shape[:2] == x2.shape[:2]
assert out_size is None or len(out_size) == 2
if out_size is None: # resize to larger one
......@@ -131,7 +135,7 @@ class BaseMergeCell(nn.Module):
class SumCell(BaseMergeCell):
def __init__(self, in_channels, out_channels, **kwargs):
def __init__(self, in_channels: int, out_channels: int, **kwargs):
super().__init__(in_channels, out_channels, **kwargs)
def _binary_op(self, x1, x2):
......@@ -140,7 +144,7 @@ class SumCell(BaseMergeCell):
class ConcatCell(BaseMergeCell):
def __init__(self, in_channels, out_channels, **kwargs):
def __init__(self, in_channels: int, out_channels: int, **kwargs):
super().__init__(in_channels * 2, out_channels, **kwargs)
def _binary_op(self, x1, x2):
......@@ -150,7 +154,10 @@ class ConcatCell(BaseMergeCell):
class GlobalPoolingCell(BaseMergeCell):
def __init__(self, in_channels=None, out_channels=None, **kwargs):
def __init__(self,
in_channels: Optional[int] = None,
out_channels: Optional[int] = None,
**kwargs):
super().__init__(in_channels, out_channels, **kwargs)
self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
......
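Note that fused_channels and out_channels above are annotated Optional[int] despite having non-None defaults: GlobalPoolingCell forwards in_channels=None to this constructor, so None is a legal value. A minimal sketch of the pattern (hypothetical names):

from typing import Optional


class Base:

    def __init__(self, channels: Optional[int] = 256):
        # None is legal because a subclass may forward it explicitly.
        self.channels = channels


class Pooling(Base):

    def __init__(self, channels: Optional[int] = None):
        super().__init__(channels)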
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from ..utils import ext_loader
ext_module = ext_loader.load_ext('_ext', ['min_area_polygons'])
def min_area_polygons(pointsets):
def min_area_polygons(pointsets: torch.Tensor) -> torch.Tensor:
"""Find the smallest polygons that surrounds all points in the point sets.
Args:
......
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
......@@ -35,16 +36,16 @@ class ModulatedDeformConv2dFunction(Function):
@staticmethod
def forward(ctx,
input,
offset,
mask,
weight,
bias=None,
stride=1,
padding=0,
dilation=1,
groups=1,
deform_groups=1):
input: torch.Tensor,
offset: torch.Tensor,
mask: torch.Tensor,
weight: nn.Parameter,
bias: Optional[nn.Parameter] = None,
stride: int = 1,
padding: int = 0,
dilation: int = 1,
groups: int = 1,
deform_groups: int = 1) -> torch.Tensor:
if input is not None and input.dim() != 4:
raise ValueError(
f'Expected 4D tensor as input, got {input.dim()}D tensor \
......@@ -66,7 +67,7 @@ class ModulatedDeformConv2dFunction(Function):
# whatever the pytorch version is.
input = input.type_as(offset)
weight = weight.type_as(input)
bias = bias.type_as(input)
bias = bias.type_as(input) # type: ignore
ctx.save_for_backward(input, offset, mask, weight, bias)
output = input.new_empty(
ModulatedDeformConv2dFunction._output_size(ctx, input, weight))
......@@ -95,7 +96,7 @@ class ModulatedDeformConv2dFunction(Function):
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
def backward(ctx, grad_output: torch.Tensor) -> tuple:
input, offset, mask, weight, bias = ctx.saved_tensors
grad_input = torch.zeros_like(input)
grad_offset = torch.zeros_like(offset)
......@@ -159,15 +160,15 @@ class ModulatedDeformConv2d(nn.Module):
@deprecated_api_warning({'deformable_groups': 'deform_groups'},
cls_name='ModulatedDeformConv2d')
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
deform_groups=1,
bias=True):
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int]],
stride: int = 1,
padding: int = 0,
dilation: int = 1,
groups: int = 1,
deform_groups: int = 1,
bias: Union[bool, str] = True):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
......@@ -199,7 +200,8 @@ class ModulatedDeformConv2d(nn.Module):
if self.bias is not None:
self.bias.data.zero_()
def forward(self, x, offset, mask):
def forward(self, x: torch.Tensor, offset: torch.Tensor,
mask: torch.Tensor) -> torch.Tensor:
return modulated_deform_conv2d(x, offset, mask, self.weight, self.bias,
self.stride, self.padding,
self.dilation, self.groups,
......@@ -238,13 +240,13 @@ class ModulatedDeformConv2dPack(ModulatedDeformConv2d):
bias=True)
self.init_weights()
def init_weights(self):
def init_weights(self) -> None:
super().init_weights()
if hasattr(self, 'conv_offset'):
self.conv_offset.weight.data.zero_()
self.conv_offset.bias.data.zero_()
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore
out = self.conv_offset(x)
o1, o2, mask = torch.chunk(out, 3, dim=1)
offset = torch.cat((o1, o2), dim=1)
......
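The `# type: ignore` on bias.type_as(input) is needed because bias is annotated Optional[nn.Parameter], while its None case appears to be handled through a flag in the surrounding code (not shown in this hunk), which mypy cannot connect to the variable. A hypothetical sketch of why a flag-based guard defeats narrowing:

from typing import Optional

import torch


def add_bias(x: torch.Tensor, bias: Optional[torch.Tensor],
             with_bias: bool) -> torch.Tensor:
    if not with_bias:
        bias = x.new_zeros(1)  # placeholder, mirroring the op's fake bias
    # mypy cannot tie `with_bias` to `bias is None`, so `bias` is still
    # Optional here; a direct `if bias is None:` guard would narrow it
    # and make the ignore unnecessary.
    return x + bias.type_as(x)  # type: ignore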
# Copyright (c) OpenMMLab. All rights reserved.
import math
import warnings
from typing import Optional, no_type_check
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd.function import Function, once_differentiable
import mmcv
from mmcv import deprecated_api_warning
from mmcv.cnn import constant_init, xavier_init
from mmcv.cnn.bricks.registry import ATTENTION
......@@ -20,8 +22,11 @@ ext_module = ext_loader.load_ext(
class MultiScaleDeformableAttnFunction(Function):
@staticmethod
def forward(ctx, value, value_spatial_shapes, value_level_start_index,
sampling_locations, attention_weights, im2col_step):
def forward(ctx, value: torch.Tensor, value_spatial_shapes: torch.Tensor,
value_level_start_index: torch.Tensor,
sampling_locations: torch.Tensor,
attention_weights: torch.Tensor,
im2col_step: torch.Tensor) -> torch.Tensor:
"""GPU version of multi-scale deformable attention.
Args:
......@@ -37,7 +42,7 @@ class MultiScaleDeformableAttnFunction(Function):
attention_weights (torch.Tensor): The weight of sampling points
used when calculate the attention, has shape
(bs ,num_queries, num_heads, num_levels, num_points),
im2col_step (Tensor): The step used in image to column.
im2col_step (torch.Tensor): The step used in image to column.
Returns:
torch.Tensor: has shape (bs, num_queries, embed_dims)
......@@ -58,7 +63,7 @@ class MultiScaleDeformableAttnFunction(Function):
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
def backward(ctx, grad_output: torch.Tensor) -> tuple:
"""GPU version of backward function.
Args:
......@@ -89,8 +94,10 @@ class MultiScaleDeformableAttnFunction(Function):
grad_sampling_loc, grad_attn_weight, None
def multi_scale_deformable_attn_pytorch(value, value_spatial_shapes,
sampling_locations, attention_weights):
def multi_scale_deformable_attn_pytorch(
value: torch.Tensor, value_spatial_shapes: torch.Tensor,
sampling_locations: torch.Tensor,
attention_weights: torch.Tensor) -> torch.Tensor:
"""CPU version of multi-scale deformable attention.
Args:
......@@ -178,15 +185,15 @@ class MultiScaleDeformableAttention(BaseModule):
"""
def __init__(self,
embed_dims=256,
num_heads=8,
num_levels=4,
num_points=4,
im2col_step=64,
dropout=0.1,
batch_first=False,
norm_cfg=None,
init_cfg=None):
embed_dims: int = 256,
num_heads: int = 8,
num_levels: int = 4,
num_points: int = 4,
im2col_step: int = 64,
dropout: float = 0.1,
batch_first: bool = False,
norm_cfg: Optional[dict] = None,
init_cfg: Optional[mmcv.ConfigDict] = None):
super().__init__(init_cfg)
if embed_dims % num_heads != 0:
raise ValueError(f'embed_dims must be divisible by num_heads, '
......@@ -225,7 +232,7 @@ class MultiScaleDeformableAttention(BaseModule):
self.output_proj = nn.Linear(embed_dims, embed_dims)
self.init_weights()
def init_weights(self):
def init_weights(self) -> None:
"""Default initialization for Parameters of Module."""
constant_init(self.sampling_offsets, 0.)
thetas = torch.arange(
......@@ -245,19 +252,20 @@ class MultiScaleDeformableAttention(BaseModule):
xavier_init(self.output_proj, distribution='uniform', bias=0.)
self._is_init = True
@no_type_check
@deprecated_api_warning({'residual': 'identity'},
cls_name='MultiScaleDeformableAttention')
def forward(self,
query,
key=None,
value=None,
identity=None,
query_pos=None,
key_padding_mask=None,
reference_points=None,
spatial_shapes=None,
level_start_index=None,
**kwargs):
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
value: Optional[torch.Tensor] = None,
identity: Optional[torch.Tensor] = None,
query_pos: Optional[torch.Tensor] = None,
key_padding_mask: Optional[torch.Tensor] = None,
reference_points: Optional[torch.Tensor] = None,
spatial_shapes: Optional[torch.Tensor] = None,
level_start_index: Optional[torch.Tensor] = None,
**kwargs) -> torch.Tensor:
"""Forward Function of MultiScaleDeformAttention.
Args:
......@@ -272,8 +280,8 @@ class MultiScaleDeformableAttention(BaseModule):
`query` will be used.
query_pos (torch.Tensor): The positional encoding for `query`.
Default: None.
key_pos (torch.Tensor): The positional encoding for `key`. Default
None.
key_padding_mask (torch.Tensor): ByteTensor for `query`, with
shape [bs, num_key].
reference_points (torch.Tensor): The normalized reference
points with shape (bs, num_query, num_levels, 2),
all elements is range in [0, 1], top-left (0,0),
......@@ -281,8 +289,6 @@ class MultiScaleDeformableAttention(BaseModule):
or (N, Length_{query}, num_levels, 4), add
additional two dimensions is (w, h) to
form reference boxes.
key_padding_mask (torch.Tensor): ByteTensor for `query`, with
shape [bs, num_key].
spatial_shapes (torch.Tensor): Spatial shape of features in
different levels. With shape (num_levels, 2),
last dimension represents (h, w).
......
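Finally, the @no_type_check decorator added to MultiScaleDeformableAttention.forward (see the commit message) is the standard escape hatch when a method is wrapped by another decorator, here deprecated_api_warning, whose argument remapping mypy cannot follow. A hypothetical sketch:

from typing import Optional, no_type_check

import torch
import torch.nn as nn


class TinyAttention(nn.Module):

    # In the real module, forward is also wrapped by a decorator that
    # remaps deprecated keyword arguments; @no_type_check tells mypy to
    # skip checking this method entirely.
    @no_type_check
    def forward(self,
                query: torch.Tensor,
                identity: Optional[torch.Tensor] = None) -> torch.Tensor:
        if identity is None:
            identity = query
        return identity + query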