Unverified Commit 699398ad authored by zengxiang68's avatar zengxiang68 Committed by GitHub
Browse files

Add type hints for mmcv/ops (#1987)

* add type hints for mmcv/ops/...

* add type hints for mmcv/ops/...

* add type hints for mmcv/ops/...
parent 84a544fb
...@@ -6,7 +6,7 @@ import torch.nn.functional as F ...@@ -6,7 +6,7 @@ import torch.nn.functional as F
from mmcv.cnn import PLUGIN_LAYERS, Scale from mmcv.cnn import PLUGIN_LAYERS, Scale
def NEG_INF_DIAG(n, device): def NEG_INF_DIAG(n: int, device: torch.device) -> torch.Tensor:
"""Returns a diagonal matrix of size [n, n]. """Returns a diagonal matrix of size [n, n].
The diagonal are all "-inf". This is for avoiding calculating the The diagonal are all "-inf". This is for avoiding calculating the
...@@ -41,7 +41,7 @@ class CrissCrossAttention(nn.Module): ...@@ -41,7 +41,7 @@ class CrissCrossAttention(nn.Module):
in_channels (int): Channels of the input feature map. in_channels (int): Channels of the input feature map.
""" """
def __init__(self, in_channels): def __init__(self, in_channels: int) -> None:
super().__init__() super().__init__()
self.query_conv = nn.Conv2d(in_channels, in_channels // 8, 1) self.query_conv = nn.Conv2d(in_channels, in_channels // 8, 1)
self.key_conv = nn.Conv2d(in_channels, in_channels // 8, 1) self.key_conv = nn.Conv2d(in_channels, in_channels // 8, 1)
...@@ -49,7 +49,7 @@ class CrissCrossAttention(nn.Module): ...@@ -49,7 +49,7 @@ class CrissCrossAttention(nn.Module):
self.gamma = Scale(0.) self.gamma = Scale(0.)
self.in_channels = in_channels self.in_channels = in_channels
def forward(self, x): def forward(self, x: torch.Tensor) -> torch.Tensor:
"""forward function of Criss-Cross Attention. """forward function of Criss-Cross Attention.
Args: Args:
...@@ -78,7 +78,7 @@ class CrissCrossAttention(nn.Module): ...@@ -78,7 +78,7 @@ class CrissCrossAttention(nn.Module):
return out return out
def __repr__(self): def __repr__(self) -> str:
s = self.__class__.__name__ s = self.__class__.__name__
s += f'(in_channels={self.in_channels})' s += f'(in_channels={self.in_channels})'
return s return s
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Union
import numpy as np import numpy as np
import torch import torch
...@@ -7,8 +9,9 @@ from ..utils import ext_loader ...@@ -7,8 +9,9 @@ from ..utils import ext_loader
ext_module = ext_loader.load_ext('_ext', ['contour_expand']) ext_module = ext_loader.load_ext('_ext', ['contour_expand'])
def contour_expand(kernel_mask, internal_kernel_label, min_kernel_area, def contour_expand(kernel_mask: Union[np.array, torch.Tensor],
kernel_num): internal_kernel_label: Union[np.array, torch.Tensor],
min_kernel_area: int, kernel_num: int) -> list:
"""Expand kernel contours so that foreground pixels are assigned into """Expand kernel contours so that foreground pixels are assigned into
instances. instances.
...@@ -42,7 +45,7 @@ def contour_expand(kernel_mask, internal_kernel_label, min_kernel_area, ...@@ -42,7 +45,7 @@ def contour_expand(kernel_mask, internal_kernel_label, min_kernel_area,
internal_kernel_label, internal_kernel_label,
min_kernel_area=min_kernel_area, min_kernel_area=min_kernel_area,
kernel_num=kernel_num) kernel_num=kernel_num)
label = label.tolist() label = label.tolist() # type: ignore
else: else:
label = ext_module.contour_expand(kernel_mask, internal_kernel_label, label = ext_module.contour_expand(kernel_mask, internal_kernel_label,
min_kernel_area, kernel_num) min_kernel_area, kernel_num)
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple
import torch
from ..utils import ext_loader from ..utils import ext_loader
ext_module = ext_loader.load_ext('_ext', ['convex_iou', 'convex_giou']) ext_module = ext_loader.load_ext('_ext', ['convex_iou', 'convex_giou'])
def convex_giou(pointsets, polygons): def convex_giou(pointsets: torch.Tensor,
polygons: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Return generalized intersection-over-union (Jaccard index) between point """Return generalized intersection-over-union (Jaccard index) between point
sets and polygons. sets and polygons.
...@@ -26,7 +31,8 @@ def convex_giou(pointsets, polygons): ...@@ -26,7 +31,8 @@ def convex_giou(pointsets, polygons):
return convex_giou, points_grad return convex_giou, points_grad
def convex_iou(pointsets, polygons): def convex_iou(pointsets: torch.Tensor,
polygons: torch.Tensor) -> torch.Tensor:
"""Return intersection-over-union (Jaccard index) between point sets and """Return intersection-over-union (Jaccard index) between point sets and
polygons. polygons.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment