Unverified Commit aea2bb28 authored by ShawnHu's avatar ShawnHu Committed by GitHub
Browse files

Add type hints for mmcv/ops (#2032)

* Add type hints in mmcv/ops/carafe.py

* Add type hints in mmcv/ops/corner_pool.py

* Add type hints in mmcv/ops/diff_iou_rotated.py

* Add type hints for other methods for mmcv/ops/corner_pool.py

* Add type hints for other methods in mmcv/ops/carafe.py

* Add type hints for symbolic method
parent bcd32914
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import Tensor
from torch.autograd import Function from torch.autograd import Function
from torch.nn.modules.module import Module from torch.nn.modules.module import Module
...@@ -17,7 +20,8 @@ ext_module = ext_loader.load_ext('_ext', [ ...@@ -17,7 +20,8 @@ ext_module = ext_loader.load_ext('_ext', [
class CARAFENaiveFunction(Function): class CARAFENaiveFunction(Function):
@staticmethod @staticmethod
def symbolic(g, features, masks, kernel_size, group_size, scale_factor): def symbolic(g, features: Tensor, masks: Tensor, kernel_size: int,
group_size: int, scale_factor: int) -> Tensor:
return g.op( return g.op(
'mmcv::MMCVCARAFENaive', 'mmcv::MMCVCARAFENaive',
features, features,
...@@ -27,7 +31,8 @@ class CARAFENaiveFunction(Function): ...@@ -27,7 +31,8 @@ class CARAFENaiveFunction(Function):
scale_factor_f=scale_factor) scale_factor_f=scale_factor)
@staticmethod @staticmethod
def forward(ctx, features, masks, kernel_size, group_size, scale_factor): def forward(ctx, features: Tensor, masks: Tensor, kernel_size: int,
group_size: int, scale_factor: int) -> Tensor:
assert scale_factor >= 1 assert scale_factor >= 1
assert masks.size(1) == kernel_size * kernel_size * group_size assert masks.size(1) == kernel_size * kernel_size * group_size
assert masks.size(-1) == features.size(-1) * scale_factor assert masks.size(-1) == features.size(-1) * scale_factor
...@@ -56,7 +61,9 @@ class CARAFENaiveFunction(Function): ...@@ -56,7 +61,9 @@ class CARAFENaiveFunction(Function):
return output return output
@staticmethod @staticmethod
def backward(ctx, grad_output): def backward(
ctx,
grad_output: Tensor) -> Tuple[Tensor, Tensor, None, None, None]:
assert grad_output.is_cuda assert grad_output.is_cuda
features, masks = ctx.saved_tensors features, masks = ctx.saved_tensors
...@@ -84,7 +91,7 @@ carafe_naive = CARAFENaiveFunction.apply ...@@ -84,7 +91,7 @@ carafe_naive = CARAFENaiveFunction.apply
class CARAFENaive(Module): class CARAFENaive(Module):
def __init__(self, kernel_size, group_size, scale_factor): def __init__(self, kernel_size: int, group_size: int, scale_factor: int):
super().__init__() super().__init__()
assert isinstance(kernel_size, int) and isinstance( assert isinstance(kernel_size, int) and isinstance(
...@@ -93,7 +100,7 @@ class CARAFENaive(Module): ...@@ -93,7 +100,7 @@ class CARAFENaive(Module):
self.group_size = group_size self.group_size = group_size
self.scale_factor = scale_factor self.scale_factor = scale_factor
def forward(self, features, masks): def forward(self, features: Tensor, masks: Tensor) -> Tensor:
return carafe_naive(features, masks, self.kernel_size, self.group_size, return carafe_naive(features, masks, self.kernel_size, self.group_size,
self.scale_factor) self.scale_factor)
...@@ -101,7 +108,8 @@ class CARAFENaive(Module): ...@@ -101,7 +108,8 @@ class CARAFENaive(Module):
class CARAFEFunction(Function): class CARAFEFunction(Function):
@staticmethod @staticmethod
def symbolic(g, features, masks, kernel_size, group_size, scale_factor): def symbolic(g, features: Tensor, masks: Tensor, kernel_size: int,
group_size: int, scale_factor: int) -> Tensor:
return g.op( return g.op(
'mmcv::MMCVCARAFE', 'mmcv::MMCVCARAFE',
features, features,
...@@ -111,7 +119,8 @@ class CARAFEFunction(Function): ...@@ -111,7 +119,8 @@ class CARAFEFunction(Function):
scale_factor_f=scale_factor) scale_factor_f=scale_factor)
@staticmethod @staticmethod
def forward(ctx, features, masks, kernel_size, group_size, scale_factor): def forward(ctx, features: Tensor, masks: Tensor, kernel_size: int,
group_size: int, scale_factor: int) -> Tensor:
assert scale_factor >= 1 assert scale_factor >= 1
assert masks.size(1) == kernel_size * kernel_size * group_size assert masks.size(1) == kernel_size * kernel_size * group_size
assert masks.size(-1) == features.size(-1) * scale_factor assert masks.size(-1) == features.size(-1) * scale_factor
...@@ -146,7 +155,9 @@ class CARAFEFunction(Function): ...@@ -146,7 +155,9 @@ class CARAFEFunction(Function):
return output return output
@staticmethod @staticmethod
def backward(ctx, grad_output): def backward(
ctx,
grad_output: Tensor) -> Tuple[Tensor, Tensor, None, None, None]:
assert grad_output.is_cuda assert grad_output.is_cuda
features, masks, rfeatures = ctx.saved_tensors features, masks, rfeatures = ctx.saved_tensors
...@@ -194,7 +205,7 @@ class CARAFE(Module): ...@@ -194,7 +205,7 @@ class CARAFE(Module):
upsampled feature map upsampled feature map
""" """
def __init__(self, kernel_size, group_size, scale_factor): def __init__(self, kernel_size: int, group_size: int, scale_factor: int):
super().__init__() super().__init__()
assert isinstance(kernel_size, int) and isinstance( assert isinstance(kernel_size, int) and isinstance(
...@@ -203,7 +214,7 @@ class CARAFE(Module): ...@@ -203,7 +214,7 @@ class CARAFE(Module):
self.group_size = group_size self.group_size = group_size
self.scale_factor = scale_factor self.scale_factor = scale_factor
def forward(self, features, masks): def forward(self, features: Tensor, masks: Tensor) -> Tensor:
return carafe(features, masks, self.kernel_size, self.group_size, return carafe(features, masks, self.kernel_size, self.group_size,
self.scale_factor) self.scale_factor)
...@@ -231,13 +242,13 @@ class CARAFEPack(nn.Module): ...@@ -231,13 +242,13 @@ class CARAFEPack(nn.Module):
""" """
def __init__(self, def __init__(self,
channels, channels: int,
scale_factor, scale_factor: int,
up_kernel=5, up_kernel: int = 5,
up_group=1, up_group: int = 1,
encoder_kernel=3, encoder_kernel: int = 3,
encoder_dilation=1, encoder_dilation: int = 1,
compressed_channels=64): compressed_channels: int = 64):
super().__init__() super().__init__()
self.channels = channels self.channels = channels
self.scale_factor = scale_factor self.scale_factor = scale_factor
...@@ -264,7 +275,7 @@ class CARAFEPack(nn.Module): ...@@ -264,7 +275,7 @@ class CARAFEPack(nn.Module):
xavier_init(m, distribution='uniform') xavier_init(m, distribution='uniform')
normal_init(self.content_encoder, std=0.001) normal_init(self.content_encoder, std=0.001)
def kernel_normalizer(self, mask): def kernel_normalizer(self, mask: Tensor) -> Tensor:
mask = F.pixel_shuffle(mask, self.scale_factor) mask = F.pixel_shuffle(mask, self.scale_factor)
n, mask_c, h, w = mask.size() n, mask_c, h, w = mask.size()
# use float division explicitly, # use float division explicitly,
...@@ -277,11 +288,11 @@ class CARAFEPack(nn.Module): ...@@ -277,11 +288,11 @@ class CARAFEPack(nn.Module):
return mask return mask
def feature_reassemble(self, x, mask): def feature_reassemble(self, x: Tensor, mask: Tensor) -> Tensor:
x = carafe(x, mask, self.up_kernel, self.up_group, self.scale_factor) x = carafe(x, mask, self.up_kernel, self.up_group, self.scale_factor)
return x return x
def forward(self, x): def forward(self, x: Tensor) -> Tensor:
compressed_x = self.channel_compressor(x) compressed_x = self.channel_compressor(x)
mask = self.content_encoder(compressed_x) mask = self.content_encoder(compressed_x)
mask = self.kernel_normalizer(mask) mask = self.kernel_normalizer(mask)
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import torch import torch
from torch import nn from torch import Tensor, nn
from torch.autograd import Function from torch.autograd import Function
_mode_dict = {'top': 0, 'bottom': 1, 'left': 2, 'right': 3} _mode_dict = {'top': 0, 'bottom': 1, 'left': 2, 'right': 3}
def _corner_pool(x, dim, flip): def _corner_pool(x: Tensor, dim: int, flip: bool) -> Tensor:
size = x.size(dim) size = x.size(dim)
output = x.clone() output = x.clone()
...@@ -38,52 +38,52 @@ def _corner_pool(x, dim, flip): ...@@ -38,52 +38,52 @@ def _corner_pool(x, dim, flip):
class TopPoolFunction(Function): class TopPoolFunction(Function):
@staticmethod @staticmethod
def symbolic(g, input): def symbolic(g, input: Tensor) -> Tensor:
output = g.op( output = g.op(
'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['top'])) 'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['top']))
return output return output
@staticmethod @staticmethod
def forward(ctx, input): def forward(ctx, input: Tensor) -> Tensor:
return _corner_pool(input, 2, True) return _corner_pool(input, 2, True)
class BottomPoolFunction(Function): class BottomPoolFunction(Function):
@staticmethod @staticmethod
def symbolic(g, input): def symbolic(g, input: Tensor) -> Tensor:
output = g.op( output = g.op(
'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['bottom'])) 'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['bottom']))
return output return output
@staticmethod @staticmethod
def forward(ctx, input): def forward(ctx, input: Tensor) -> Tensor:
return _corner_pool(input, 2, False) return _corner_pool(input, 2, False)
class LeftPoolFunction(Function): class LeftPoolFunction(Function):
@staticmethod @staticmethod
def symbolic(g, input): def symbolic(g, input: Tensor) -> Tensor:
output = g.op( output = g.op(
'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['left'])) 'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['left']))
return output return output
@staticmethod @staticmethod
def forward(ctx, input): def forward(ctx, input: Tensor) -> Tensor:
return _corner_pool(input, 3, True) return _corner_pool(input, 3, True)
class RightPoolFunction(Function): class RightPoolFunction(Function):
@staticmethod @staticmethod
def symbolic(g, input): def symbolic(g, input: Tensor) -> Tensor:
output = g.op( output = g.op(
'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['right'])) 'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['right']))
return output return output
@staticmethod @staticmethod
def forward(ctx, input): def forward(ctx, input: Tensor) -> Tensor:
return _corner_pool(input, 3, False) return _corner_pool(input, 3, False)
...@@ -124,13 +124,13 @@ class CornerPool(nn.Module): ...@@ -124,13 +124,13 @@ class CornerPool(nn.Module):
'top': (2, True), 'top': (2, True),
} }
def __init__(self, mode): def __init__(self, mode: str):
super().__init__() super().__init__()
assert mode in self.pool_functions assert mode in self.pool_functions
self.mode = mode self.mode = mode
self.corner_pool = self.pool_functions[mode] self.corner_pool: Function = self.pool_functions[mode]
def forward(self, x): def forward(self, x: Tensor) -> Tensor:
if torch.__version__ != 'parrots' and torch.__version__ >= '1.5.0': if torch.__version__ != 'parrots' and torch.__version__ >= '1.5.0':
if torch.onnx.is_in_onnx_export(): if torch.onnx.is_in_onnx_export():
assert torch.__version__ >= '1.7.0', \ assert torch.__version__ >= '1.7.0', \
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
# Adapted from https://github.com/lilanxiao/Rotated_IoU/blob/master/box_intersection_2d.py # noqa # Adapted from https://github.com/lilanxiao/Rotated_IoU/blob/master/box_intersection_2d.py # noqa
# Adapted from https://github.com/lilanxiao/Rotated_IoU/blob/master/oriented_iou_loss.py # noqa # Adapted from https://github.com/lilanxiao/Rotated_IoU/blob/master/oriented_iou_loss.py # noqa
from typing import Tuple
import torch import torch
from torch import Tensor
from torch.autograd import Function from torch.autograd import Function
from ..utils import ext_loader from ..utils import ext_loader
...@@ -26,7 +29,8 @@ class SortVertices(Function): ...@@ -26,7 +29,8 @@ class SortVertices(Function):
return () return ()
def box_intersection(corners1, corners2): def box_intersection(corners1: Tensor,
corners2: Tensor) -> Tuple[Tensor, Tensor]:
"""Find intersection points of rectangles. """Find intersection points of rectangles.
Convention: if two edges are collinear, there is no intersection point. Convention: if two edges are collinear, there is no intersection point.
...@@ -68,7 +72,7 @@ def box_intersection(corners1, corners2): ...@@ -68,7 +72,7 @@ def box_intersection(corners1, corners2):
return intersections, mask return intersections, mask
def box1_in_box2(corners1, corners2): def box1_in_box2(corners1: Tensor, corners2: Tensor) -> Tensor:
"""Check if corners of box1 lie in box2. """Check if corners of box1 lie in box2.
Convention: if a corner is exactly on the edge of the other box, Convention: if a corner is exactly on the edge of the other box,
it's also a valid point. it's also a valid point.
...@@ -101,7 +105,7 @@ def box1_in_box2(corners1, corners2): ...@@ -101,7 +105,7 @@ def box1_in_box2(corners1, corners2):
return cond1 * cond2 return cond1 * cond2
def box_in_box(corners1, corners2): def box_in_box(corners1: Tensor, corners2: Tensor) -> Tuple[Tensor, Tensor]:
"""Check if corners of two boxes lie in each other. """Check if corners of two boxes lie in each other.
Args: Args:
...@@ -118,8 +122,9 @@ def box_in_box(corners1, corners2): ...@@ -118,8 +122,9 @@ def box_in_box(corners1, corners2):
return c1_in_2, c2_in_1 return c1_in_2, c2_in_1
def build_vertices(corners1, corners2, c1_in_2, c2_in_1, intersections, def build_vertices(corners1: Tensor, corners2: Tensor, c1_in_2: Tensor,
valid_mask): c2_in_1: Tensor, intersections: Tensor,
valid_mask: Tensor) -> Tuple[Tensor, Tensor]:
"""Find vertices of intersection area. """Find vertices of intersection area.
Args: Args:
...@@ -149,7 +154,7 @@ def build_vertices(corners1, corners2, c1_in_2, c2_in_1, intersections, ...@@ -149,7 +154,7 @@ def build_vertices(corners1, corners2, c1_in_2, c2_in_1, intersections,
return vertices, mask return vertices, mask
def sort_indices(vertices, mask): def sort_indices(vertices: Tensor, mask: Tensor) -> Tensor:
"""Sort indices. """Sort indices.
Note: Note:
why 9? the polygon has maximal 8 vertices. why 9? the polygon has maximal 8 vertices.
...@@ -176,7 +181,8 @@ def sort_indices(vertices, mask): ...@@ -176,7 +181,8 @@ def sort_indices(vertices, mask):
return SortVertices.apply(vertices_normalized, mask, num_valid).long() return SortVertices.apply(vertices_normalized, mask, num_valid).long()
def calculate_area(idx_sorted, vertices): def calculate_area(idx_sorted: Tensor,
vertices: Tensor) -> Tuple[Tensor, Tensor]:
"""Calculate area of intersection. """Calculate area of intersection.
Args: Args:
...@@ -197,7 +203,8 @@ def calculate_area(idx_sorted, vertices): ...@@ -197,7 +203,8 @@ def calculate_area(idx_sorted, vertices):
return area, selected return area, selected
def oriented_box_intersection_2d(corners1, corners2): def oriented_box_intersection_2d(corners1: Tensor,
corners2: Tensor) -> Tuple[Tensor, Tensor]:
"""Calculate intersection area of 2d rotated boxes. """Calculate intersection area of 2d rotated boxes.
Args: Args:
...@@ -217,7 +224,7 @@ def oriented_box_intersection_2d(corners1, corners2): ...@@ -217,7 +224,7 @@ def oriented_box_intersection_2d(corners1, corners2):
return calculate_area(sorted_indices, vertices) return calculate_area(sorted_indices, vertices)
def box2corners(box): def box2corners(box: Tensor) -> Tensor:
"""Convert rotated 2d box coordinate to corners. """Convert rotated 2d box coordinate to corners.
Args: Args:
...@@ -245,7 +252,7 @@ def box2corners(box): ...@@ -245,7 +252,7 @@ def box2corners(box):
return rotated return rotated
def diff_iou_rotated_2d(box1, box2): def diff_iou_rotated_2d(box1: Tensor, box2: Tensor) -> Tensor:
"""Calculate differentiable iou of rotated 2d boxes. """Calculate differentiable iou of rotated 2d boxes.
Args: Args:
...@@ -266,7 +273,7 @@ def diff_iou_rotated_2d(box1, box2): ...@@ -266,7 +273,7 @@ def diff_iou_rotated_2d(box1, box2):
return iou return iou
def diff_iou_rotated_3d(box3d1, box3d2): def diff_iou_rotated_3d(box3d1: Tensor, box3d2: Tensor) -> Tensor:
"""Calculate differentiable iou of rotated 3d boxes. """Calculate differentiable iou of rotated 3d boxes.
Args: Args:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment