Unverified Commit b9a96e56 authored by WINDSKY45, committed by GitHub

[Enhance] Add type hints in /ops: (#2030)



* [Enhance] Add type hints in /ops:
`fused_bias_leakyrelu.py`, `gather_points.py`, `group_points.py`.
There is no need to add type hints in `furthest_point_sample.py` and
`info.py`.
As for `focal_loss.py`, please see #1994.

* Modified the default value of a variable.

* [Enhance] Add type hints in:
`knn.py`, `masked_conv.py`, `merge_cells.py`, `min_area_polygons.py`,
`modulated_deform_conv.py`, `multi_scale_deform_attn.py`.

* Fix type hints.

* Fixed type hints.

* Remove type hints from `symbolic`.

* Add `no_type_check` to skip the mypy check for a method

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
parent 305c2a30
fused_bias_leakyrelu.py

@@ -113,7 +113,8 @@ class FusedBiasLeakyReLUFunctionBackward(Function):
     """

     @staticmethod
-    def forward(ctx, grad_output, out, negative_slope, scale):
+    def forward(ctx, grad_output: torch.Tensor, out: torch.Tensor,
+                negative_slope: float, scale: float) -> tuple:
         ctx.save_for_backward(out)
         ctx.negative_slope = negative_slope
         ctx.scale = scale
@@ -139,7 +140,8 @@ class FusedBiasLeakyReLUFunctionBackward(Function):
         return grad_input, grad_bias

     @staticmethod
-    def backward(ctx, gradgrad_input, gradgrad_bias):
+    def backward(ctx, gradgrad_input: torch.Tensor,
+                 gradgrad_bias: nn.Parameter) -> tuple:
         out, = ctx.saved_tensors
         # The second order deviation, in fact, contains two parts, while the
@@ -160,7 +162,8 @@ class FusedBiasLeakyReLUFunctionBackward(Function):
 class FusedBiasLeakyReLUFunction(Function):

     @staticmethod
-    def forward(ctx, input, bias, negative_slope, scale):
+    def forward(ctx, input: torch.Tensor, bias: nn.Parameter,
+                negative_slope: float, scale: float) -> torch.Tensor:
         empty = input.new_empty(0)
         out = ext_module.fused_bias_leakyrelu(
@@ -178,7 +181,7 @@ class FusedBiasLeakyReLUFunction(Function):
         return out

     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(ctx, grad_output: torch.Tensor) -> tuple:
         out, = ctx.saved_tensors
         grad_input, grad_bias = FusedBiasLeakyReLUFunctionBackward.apply(
@@ -204,26 +207,32 @@ class FusedBiasLeakyReLU(nn.Module):
     TODO: Implement the CPU version.

     Args:
-        channel (int): The channel number of the feature map.
+        num_channels (int): The channel number of the feature map.
         negative_slope (float, optional): Same as nn.LeakyRelu.
             Defaults to 0.2.
         scale (float, optional): A scalar to adjust the variance of the feature
            map. Defaults to 2**0.5.
     """

-    def __init__(self, num_channels, negative_slope=0.2, scale=2**0.5):
+    def __init__(self,
+                 num_channels: int,
+                 negative_slope: float = 0.2,
+                 scale: float = 2**0.5):
         super().__init__()
         self.bias = nn.Parameter(torch.zeros(num_channels))
         self.negative_slope = negative_slope
         self.scale = scale

-    def forward(self, input):
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
         return fused_bias_leakyrelu(input, self.bias, self.negative_slope,
                                     self.scale)


-def fused_bias_leakyrelu(input, bias, negative_slope=0.2, scale=2**0.5):
+def fused_bias_leakyrelu(input: torch.Tensor,
+                         bias: nn.Parameter,
+                         negative_slope: float = 0.2,
+                         scale: float = 2**0.5) -> torch.Tensor:
     r"""Fused bias leaky ReLU function.
     This function is introduced in the StyleGAN2:
@@ -256,7 +265,10 @@ def fused_bias_leakyrelu(input, bias, negative_slope=0.2, scale=2**0.5):
                                             negative_slope, scale)


-def bias_leakyrelu_ref(x, bias, negative_slope=0.2, scale=2**0.5):
+def bias_leakyrelu_ref(x: torch.Tensor,
+                       bias: nn.Parameter,
+                       negative_slope: float = 0.2,
+                       scale: float = 2**0.5) -> torch.Tensor:
     if bias is not None:
         assert bias.ndim == 1
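For reviewers who want to try the typed API, here is a minimal usage sketch. A CUDA device is assumed, since the docstring above still lists the CPU version as a TODO; the tensor shapes are illustrative only and not part of this diff.

```python
import torch
import torch.nn as nn
from mmcv.ops import FusedBiasLeakyReLU, fused_bias_leakyrelu

# CUDA-only op (CPU version is a TODO above); the shape is arbitrary.
x = torch.randn(2, 64, 16, 16, device='cuda')

act = FusedBiasLeakyReLU(num_channels=64).cuda()
out = act(x)                      # same shape as the input, per-channel bias added
assert out.shape == x.shape

# The functional form takes the bias explicitly, matching the new annotations.
bias = nn.Parameter(torch.zeros(64, device='cuda'))
out2 = fused_bias_leakyrelu(x, bias, negative_slope=0.2, scale=2**0.5)
```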
gather_points.py

+from typing import Tuple
+
 import torch
 from torch.autograd import Function
@@ -37,7 +39,7 @@ class GatherPoints(Function):
         return output

     @staticmethod
-    def backward(ctx, grad_out):
+    def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, None]:
         idx, C, N = ctx.for_backwards
         B, npoint = idx.size()
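A quick sketch of the op whose backward pass is annotated above. The (B, C, N) features / (B, npoint) int32 indices layout is my reading of the op's docstring rather than something visible in this diff, and a CUDA device is assumed.

```python
import torch
from mmcv.ops import gather_points

# Assumed layout: features (B, C, N), indices (B, npoint) int32; CUDA-only op.
features = torch.randn(1, 8, 128, device='cuda')
indices = torch.randint(0, 128, (1, 16), dtype=torch.int32, device='cuda')

gathered = gather_points(features, indices)   # expected shape: (1, 8, 16)
```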
group_points.py

 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Tuple
+from typing import Optional, Tuple, Union

 import torch
 from torch import nn as nn
@@ -37,15 +37,15 @@ class QueryAndGroup(nn.Module):
     """

     def __init__(self,
-                 max_radius,
-                 sample_num,
-                 min_radius=0,
-                 use_xyz=True,
-                 return_grouped_xyz=False,
-                 normalize_xyz=False,
-                 uniform_sample=False,
-                 return_unique_cnt=False,
-                 return_grouped_idx=False):
+                 max_radius: float,
+                 sample_num: int,
+                 min_radius: float = 0.,
+                 use_xyz: bool = True,
+                 return_grouped_xyz: bool = False,
+                 normalize_xyz: bool = False,
+                 uniform_sample: bool = False,
+                 return_unique_cnt: bool = False,
+                 return_grouped_idx: bool = False):
         super().__init__()
         self.max_radius = max_radius
         self.min_radius = min_radius
@@ -64,7 +64,12 @@ class QueryAndGroup(nn.Module):
         assert not self.normalize_xyz, \
             'can not normalize grouped xyz when max_radius is None'

-    def forward(self, points_xyz, center_xyz, features=None):
+    def forward(
+        self,
+        points_xyz: torch.Tensor,
+        center_xyz: torch.Tensor,
+        features: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple]:
         """
         Args:
             points_xyz (torch.Tensor): (B, N, 3) xyz coordinates of the
@@ -75,7 +80,7 @@ class QueryAndGroup(nn.Module):
                 points.

         Returns:
-            torch.Tensor: (B, 3 + C, npoint, sample_num) Grouped
+            Tuple | torch.Tensor: (B, 3 + C, npoint, sample_num) Grouped
                 concatenated coordinates and features of points.
         """
         # if self.max_radius is None, we will perform kNN instead of ball query
@@ -149,7 +154,7 @@ class GroupAll(nn.Module):
     def forward(self,
                 xyz: torch.Tensor,
                 new_xyz: torch.Tensor,
-                features: torch.Tensor = None):
+                features: Optional[torch.Tensor] = None) -> torch.Tensor:
         """
         Args:
             xyz (Tensor): (B, N, 3) xyz coordinates of the features.
@@ -210,8 +215,7 @@ class GroupingOperation(Function):
         return output

     @staticmethod
-    def backward(ctx,
-                 grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, None]:
         """
         Args:
             grad_out (Tensor): (B, C, npoint, nsample) tensor of the gradients
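A minimal sketch of `QueryAndGroup` using the shapes from the docstring above ((B, N, 3) points, (B, npoint, 3) centers, (B, C, N) features). The ball-query kernel is a compiled extension, so a CUDA device is assumed; the radius and sample numbers are arbitrary.

```python
import torch
from mmcv.ops import QueryAndGroup

points_xyz = torch.randn(2, 256, 3, device='cuda')   # (B, N, 3)
center_xyz = points_xyz[:, :32, :].contiguous()       # (B, npoint, 3)
features = torch.randn(2, 16, 256, device='cuda')     # (B, C, N)

grouper = QueryAndGroup(max_radius=0.4, sample_num=8, use_xyz=True)
grouped = grouper(points_xyz, center_xyz, features)
# With the default flags only a tensor is returned: (B, 3 + C, npoint, sample_num)
```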
knn.py

+from typing import Optional
+
 import torch
 from torch.autograd import Function
@@ -19,7 +21,7 @@ class KNN(Function):
     def forward(ctx,
                 k: int,
                 xyz: torch.Tensor,
-                center_xyz: torch.Tensor = None,
+                center_xyz: Optional[torch.Tensor] = None,
                 transposed: bool = False) -> torch.Tensor:
         """
         Args:
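The typed KNN op above can be exercised as follows; when `center_xyz` is omitted it defaults to `xyz`. A CUDA device is assumed since the op is a compiled extension, and the shapes are illustrative only.

```python
import torch
from mmcv.ops import knn

xyz = torch.randn(2, 256, 3, device='cuda')         # (B, N, 3) candidate points
center_xyz = torch.randn(2, 32, 3, device='cuda')   # (B, npoint, 3) query centers

# Integer indices of the 8 nearest neighbours of each center point.
idx = knn(8, xyz, center_xyz, transposed=False)
```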
masked_conv.py

 # Copyright (c) OpenMMLab. All rights reserved.
 import math
+from typing import Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -27,7 +28,13 @@ class MaskedConv2dFunction(Function):
             stride_i=stride)

     @staticmethod
-    def forward(ctx, features, mask, weight, bias, padding=0, stride=1):
+    def forward(ctx,
+                features: torch.Tensor,
+                mask: torch.Tensor,
+                weight: torch.nn.Parameter,
+                bias: torch.nn.Parameter,
+                padding: int = 0,
+                stride: int = 1) -> torch.Tensor:
         assert mask.dim() == 3 and mask.size(0) == 1
         assert features.dim() == 4 and features.size(0) == 1
         assert features.size()[2:] == mask.size()[1:]
@@ -75,7 +82,7 @@ class MaskedConv2dFunction(Function):
     @staticmethod
     @once_differentiable
-    def backward(ctx, grad_output):
+    def backward(ctx, grad_output: torch.Tensor) -> tuple:
         return (None, ) * 5
@@ -90,18 +97,20 @@ class MaskedConv2d(nn.Conv2d):
     """

     def __init__(self,
-                 in_channels,
-                 out_channels,
-                 kernel_size,
-                 stride=1,
-                 padding=0,
-                 dilation=1,
-                 groups=1,
-                 bias=True):
+                 in_channels: int,
+                 out_channels: int,
+                 kernel_size: Union[int, Tuple[int, ...]],
+                 stride: int = 1,
+                 padding: int = 0,
+                 dilation: int = 1,
+                 groups: int = 1,
+                 bias: bool = True):
         super().__init__(in_channels, out_channels, kernel_size, stride,
                          padding, dilation, groups, bias)

-    def forward(self, input, mask=None):
+    def forward(self,
+                input: torch.Tensor,
+                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
         if mask is None:  # fallback to the normal Conv2d
             return super().forward(input)
         else:
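`MaskedConv2d` keeps the plain `nn.Conv2d` behaviour when no mask is given, so a CPU sketch is enough to see the typed forward; the masked path needs the CUDA extension and, per the asserts above, a batch size of 1 and a (1, H, W) mask.

```python
import torch
from mmcv.ops import MaskedConv2d

conv = MaskedConv2d(in_channels=3, out_channels=8, kernel_size=3, padding=1)

x = torch.randn(1, 3, 32, 32)
out = conv(x)   # mask=None falls back to the ordinary Conv2d forward, CPU is fine

# The masked path is CUDA-only and expects a (1, H, W) mask matching the input:
# mask = (torch.rand(1, 32, 32) > 0.5).float().cuda()
# out = conv.cuda()(x.cuda(), mask)
```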
merge_cells.py

 # Copyright (c) OpenMMLab. All rights reserved.
 import math
 from abc import abstractmethod
+from typing import Optional

 import torch
 import torch.nn as nn
@@ -19,7 +20,7 @@ class BaseMergeCell(nn.Module):
     another convolution layer.

     Args:
-        in_channels (int): number of input channels in out_conv layer.
+        fused_channels (int): number of input channels in out_conv layer.
         out_channels (int): number of output channels in out_conv layer.
         with_out_conv (bool): Whether to use out_conv layer
         out_conv_cfg (dict): Config dict for convolution layer, which should
@@ -42,18 +43,18 @@ class BaseMergeCell(nn.Module):
     """

     def __init__(self,
-                 fused_channels=256,
-                 out_channels=256,
-                 with_out_conv=True,
-                 out_conv_cfg=dict(
+                 fused_channels: Optional[int] = 256,
+                 out_channels: Optional[int] = 256,
+                 with_out_conv: bool = True,
+                 out_conv_cfg: dict = dict(
                      groups=1, kernel_size=3, padding=1, bias=True),
-                 out_norm_cfg=None,
-                 out_conv_order=('act', 'conv', 'norm'),
-                 with_input1_conv=False,
-                 with_input2_conv=False,
-                 input_conv_cfg=None,
-                 input_norm_cfg=None,
-                 upsample_mode='nearest'):
+                 out_norm_cfg: Optional[dict] = None,
+                 out_conv_order: tuple = ('act', 'conv', 'norm'),
+                 with_input1_conv: bool = False,
+                 with_input2_conv: bool = False,
+                 input_conv_cfg: Optional[dict] = None,
+                 input_norm_cfg: Optional[dict] = None,
+                 upsample_mode: str = 'nearest'):
         super().__init__()
         assert upsample_mode in ['nearest', 'bilinear']
         self.with_out_conv = with_out_conv
@@ -111,7 +112,10 @@ class BaseMergeCell(nn.Module):
             x = F.max_pool2d(x, kernel_size=kernel_size, stride=kernel_size)
             return x

-    def forward(self, x1, x2, out_size=None):
+    def forward(self,
+                x1: torch.Tensor,
+                x2: torch.Tensor,
+                out_size: Optional[tuple] = None) -> torch.Tensor:
         assert x1.shape[:2] == x2.shape[:2]
         assert out_size is None or len(out_size) == 2
         if out_size is None:  # resize to larger one
@@ -131,7 +135,7 @@ class BaseMergeCell(nn.Module):
 class SumCell(BaseMergeCell):

-    def __init__(self, in_channels, out_channels, **kwargs):
+    def __init__(self, in_channels: int, out_channels: int, **kwargs):
         super().__init__(in_channels, out_channels, **kwargs)

     def _binary_op(self, x1, x2):
@@ -140,7 +144,7 @@ class SumCell(BaseMergeCell):
 class ConcatCell(BaseMergeCell):

-    def __init__(self, in_channels, out_channels, **kwargs):
+    def __init__(self, in_channels: int, out_channels: int, **kwargs):
         super().__init__(in_channels * 2, out_channels, **kwargs)

     def _binary_op(self, x1, x2):
@@ -150,7 +154,10 @@ class ConcatCell(BaseMergeCell):
 class GlobalPoolingCell(BaseMergeCell):

-    def __init__(self, in_channels=None, out_channels=None, **kwargs):
+    def __init__(self,
+                 in_channels: Optional[int] = None,
+                 out_channels: Optional[int] = None,
+                 **kwargs):
         super().__init__(in_channels, out_channels, **kwargs)
         self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
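The merge cells are plain PyTorch modules, so a CPU sketch is enough to exercise the newly typed `forward`; channel numbers and input sizes below are arbitrary.

```python
import torch
from mmcv.ops import ConcatCell, SumCell

x1 = torch.randn(1, 16, 32, 32)
x2 = torch.randn(1, 16, 16, 16)   # the smaller input is resized to the larger one

sum_cell = SumCell(in_channels=16, out_channels=16)
concat_cell = ConcatCell(in_channels=16, out_channels=16)

y_sum = sum_cell(x1, x2)    # (1, 16, 32, 32)
y_cat = concat_cell(x1, x2)  # concatenation doubles the fused channels before out_conv
```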
min_area_polygons.py

 # Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
 from ..utils import ext_loader

 ext_module = ext_loader.load_ext('_ext', ['min_area_polygons'])

-def min_area_polygons(pointsets):
+def min_area_polygons(pointsets: torch.Tensor) -> torch.Tensor:
     """Find the smallest polygons that surrounds all points in the point sets.

     Args:
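A hedged sketch of the newly typed function. The (N, 18) pointsets / (N, 8) polygons layout is my reading of the op's docstring, not something visible in this diff, and the kernel is CUDA-only.

```python
import torch
from mmcv.ops import min_area_polygons

# Assumed layout: each row holds nine (x, y) points; the result is the four
# corner points of the minimum-area enclosing polygon. CUDA-only op.
pointsets = torch.rand(4, 18, device='cuda')
polygons = min_area_polygons(pointsets)   # expected shape: (4, 8)
```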
modulated_deform_conv.py

 # Copyright (c) OpenMMLab. All rights reserved.
 import math
+from typing import Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -35,16 +36,16 @@ class ModulatedDeformConv2dFunction(Function):
     @staticmethod
     def forward(ctx,
-                input,
-                offset,
-                mask,
-                weight,
-                bias=None,
-                stride=1,
-                padding=0,
-                dilation=1,
-                groups=1,
-                deform_groups=1):
+                input: torch.Tensor,
+                offset: torch.Tensor,
+                mask: torch.Tensor,
+                weight: nn.Parameter,
+                bias: Optional[nn.Parameter] = None,
+                stride: int = 1,
+                padding: int = 0,
+                dilation: int = 1,
+                groups: int = 1,
+                deform_groups: int = 1) -> torch.Tensor:
         if input is not None and input.dim() != 4:
             raise ValueError(
                 f'Expected 4D tensor as input, got {input.dim()}D tensor \
@@ -66,7 +67,7 @@ class ModulatedDeformConv2dFunction(Function):
         # whatever the pytorch version is.
         input = input.type_as(offset)
         weight = weight.type_as(input)
-        bias = bias.type_as(input)
+        bias = bias.type_as(input)  # type: ignore
         ctx.save_for_backward(input, offset, mask, weight, bias)
         output = input.new_empty(
             ModulatedDeformConv2dFunction._output_size(ctx, input, weight))
@@ -95,7 +96,7 @@ class ModulatedDeformConv2dFunction(Function):
     @staticmethod
     @once_differentiable
-    def backward(ctx, grad_output):
+    def backward(ctx, grad_output: torch.Tensor) -> tuple:
         input, offset, mask, weight, bias = ctx.saved_tensors
         grad_input = torch.zeros_like(input)
         grad_offset = torch.zeros_like(offset)
@@ -159,15 +160,15 @@ class ModulatedDeformConv2d(nn.Module):
     @deprecated_api_warning({'deformable_groups': 'deform_groups'},
                             cls_name='ModulatedDeformConv2d')
     def __init__(self,
-                 in_channels,
-                 out_channels,
-                 kernel_size,
-                 stride=1,
-                 padding=0,
-                 dilation=1,
-                 groups=1,
-                 deform_groups=1,
-                 bias=True):
+                 in_channels: int,
+                 out_channels: int,
+                 kernel_size: Union[int, Tuple[int]],
+                 stride: int = 1,
+                 padding: int = 0,
+                 dilation: int = 1,
+                 groups: int = 1,
+                 deform_groups: int = 1,
+                 bias: Union[bool, str] = True):
         super().__init__()
         self.in_channels = in_channels
         self.out_channels = out_channels
@@ -199,7 +200,8 @@ class ModulatedDeformConv2d(nn.Module):
         if self.bias is not None:
             self.bias.data.zero_()

-    def forward(self, x, offset, mask):
+    def forward(self, x: torch.Tensor, offset: torch.Tensor,
+                mask: torch.Tensor) -> torch.Tensor:
         return modulated_deform_conv2d(x, offset, mask, self.weight, self.bias,
                                        self.stride, self.padding,
                                        self.dilation, self.groups,
@@ -238,13 +240,13 @@ class ModulatedDeformConv2dPack(ModulatedDeformConv2d):
             bias=True)
         self.init_weights()

-    def init_weights(self):
+    def init_weights(self) -> None:
         super().init_weights()
         if hasattr(self, 'conv_offset'):
             self.conv_offset.weight.data.zero_()
             self.conv_offset.bias.data.zero_()

-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore
         out = self.conv_offset(x)
         o1, o2, mask = torch.chunk(out, 3, dim=1)
         offset = torch.cat((o1, o2), dim=1)
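A minimal sketch of the `Pack` variant annotated above: it predicts the offset and mask from the input itself (via `conv_offset` in its `forward`), so it can be dropped in like an ordinary `Conv2d`. The deformable kernel is a compiled extension, so a CUDA device is assumed here.

```python
import torch
from mmcv.ops import ModulatedDeformConv2dPack

x = torch.randn(1, 3, 32, 32, device='cuda')

# offset and mask are produced internally by self.conv_offset, as shown in the
# forward() above; constructor arguments mirror nn.Conv2d.
dconv = ModulatedDeformConv2dPack(3, 16, kernel_size=3, padding=1).cuda()
out = dconv(x)   # (1, 16, 32, 32)
```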
multi_scale_deform_attn.py

 # Copyright (c) OpenMMLab. All rights reserved.
 import math
 import warnings
+from typing import Optional, no_type_check

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.autograd.function import Function, once_differentiable

+import mmcv
 from mmcv import deprecated_api_warning
 from mmcv.cnn import constant_init, xavier_init
 from mmcv.cnn.bricks.registry import ATTENTION
@@ -20,8 +22,11 @@ ext_module = ext_loader.load_ext(
 class MultiScaleDeformableAttnFunction(Function):

     @staticmethod
-    def forward(ctx, value, value_spatial_shapes, value_level_start_index,
-                sampling_locations, attention_weights, im2col_step):
+    def forward(ctx, value: torch.Tensor, value_spatial_shapes: torch.Tensor,
+                value_level_start_index: torch.Tensor,
+                sampling_locations: torch.Tensor,
+                attention_weights: torch.Tensor,
+                im2col_step: torch.Tensor) -> torch.Tensor:
         """GPU version of multi-scale deformable attention.

         Args:
@@ -37,7 +42,7 @@ class MultiScaleDeformableAttnFunction(Function):
             attention_weights (torch.Tensor): The weight of sampling points
                 used when calculate the attention, has shape
                 (bs ,num_queries, num_heads, num_levels, num_points),
-            im2col_step (Tensor): The step used in image to column.
+            im2col_step (torch.Tensor): The step used in image to column.

         Returns:
             torch.Tensor: has shape (bs, num_queries, embed_dims)
@@ -58,7 +63,7 @@ class MultiScaleDeformableAttnFunction(Function):
     @staticmethod
     @once_differentiable
-    def backward(ctx, grad_output):
+    def backward(ctx, grad_output: torch.Tensor) -> tuple:
         """GPU version of backward function.

         Args:
@@ -89,8 +94,10 @@ class MultiScaleDeformableAttnFunction(Function):
             grad_sampling_loc, grad_attn_weight, None


-def multi_scale_deformable_attn_pytorch(value, value_spatial_shapes,
-                                        sampling_locations, attention_weights):
+def multi_scale_deformable_attn_pytorch(
+        value: torch.Tensor, value_spatial_shapes: torch.Tensor,
+        sampling_locations: torch.Tensor,
+        attention_weights: torch.Tensor) -> torch.Tensor:
     """CPU version of multi-scale deformable attention.

     Args:
@@ -178,15 +185,15 @@ class MultiScaleDeformableAttention(BaseModule):
     """

     def __init__(self,
-                 embed_dims=256,
-                 num_heads=8,
-                 num_levels=4,
-                 num_points=4,
-                 im2col_step=64,
-                 dropout=0.1,
-                 batch_first=False,
-                 norm_cfg=None,
-                 init_cfg=None):
+                 embed_dims: int = 256,
+                 num_heads: int = 8,
+                 num_levels: int = 4,
+                 num_points: int = 4,
+                 im2col_step: int = 64,
+                 dropout: float = 0.1,
+                 batch_first: bool = False,
+                 norm_cfg: Optional[dict] = None,
+                 init_cfg: Optional[mmcv.ConfigDict] = None):
         super().__init__(init_cfg)
         if embed_dims % num_heads != 0:
             raise ValueError(f'embed_dims must be divisible by num_heads, '
@@ -225,7 +232,7 @@ class MultiScaleDeformableAttention(BaseModule):
         self.output_proj = nn.Linear(embed_dims, embed_dims)
         self.init_weights()

-    def init_weights(self):
+    def init_weights(self) -> None:
         """Default initialization for Parameters of Module."""
         constant_init(self.sampling_offsets, 0.)
         thetas = torch.arange(
@@ -245,19 +252,20 @@ class MultiScaleDeformableAttention(BaseModule):
         xavier_init(self.output_proj, distribution='uniform', bias=0.)
         self._is_init = True

+    @no_type_check
     @deprecated_api_warning({'residual': 'identity'},
                             cls_name='MultiScaleDeformableAttention')
     def forward(self,
-                query,
-                key=None,
-                value=None,
-                identity=None,
-                query_pos=None,
-                key_padding_mask=None,
-                reference_points=None,
-                spatial_shapes=None,
-                level_start_index=None,
-                **kwargs):
+                query: torch.Tensor,
+                key: Optional[torch.Tensor] = None,
+                value: Optional[torch.Tensor] = None,
+                identity: Optional[torch.Tensor] = None,
+                query_pos: Optional[torch.Tensor] = None,
+                key_padding_mask: Optional[torch.Tensor] = None,
+                reference_points: Optional[torch.Tensor] = None,
+                spatial_shapes: Optional[torch.Tensor] = None,
+                level_start_index: Optional[torch.Tensor] = None,
+                **kwargs) -> torch.Tensor:
         """Forward Function of MultiScaleDeformAttention.

         Args:
@@ -272,8 +280,8 @@ class MultiScaleDeformableAttention(BaseModule):
                 `query` will be used.
             query_pos (torch.Tensor): The positional encoding for `query`.
                 Default: None.
-            key_pos (torch.Tensor): The positional encoding for `key`. Default
-                None.
+            key_padding_mask (torch.Tensor): ByteTensor for `query`, with
+                shape [bs, num_key].
             reference_points (torch.Tensor): The normalized reference
                 points with shape (bs, num_query, num_levels, 2),
                 all elements is range in [0, 1], top-left (0,0),
@@ -281,8 +289,6 @@ class MultiScaleDeformableAttention(BaseModule):
                 or (N, Length_{query}, num_levels, 4), add
                 additional two dimensions is (w, h) to
                 form reference boxes.
-            key_padding_mask (torch.Tensor): ByteTensor for `query`, with
-                shape [bs, num_key].
             spatial_shapes (torch.Tensor): Spatial shape of features in
                 different levels. With shape (num_levels, 2),
                 last dimension represents (h, w).
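A CPU sketch of the typed forward: without a CUDA device the module falls back to `multi_scale_deformable_attn_pytorch`, so only standard PyTorch is needed. The shapes follow the docstring (batch_first=False layout); the concrete numbers below are arbitrary.

```python
import torch
from mmcv.ops import MultiScaleDeformableAttention

bs, num_query, embed_dims = 2, 10, 256
spatial_shapes = torch.tensor([[16, 16], [8, 8]])   # (num_levels, 2), (h, w)
level_start_index = torch.tensor([0, 16 * 16])
num_value = int((spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum())

attn = MultiScaleDeformableAttention(
    embed_dims=embed_dims, num_heads=8, num_levels=2)

query = torch.randn(num_query, bs, embed_dims)      # batch_first=False layout
value = torch.randn(num_value, bs, embed_dims)
reference_points = torch.rand(bs, num_query, 2, 2)  # (bs, num_query, num_levels, 2)

out = attn(
    query,
    value=value,
    reference_points=reference_points,
    spatial_shapes=spatial_shapes,
    level_start_index=level_start_index)
# out: (num_query, bs, embed_dims)
```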