Unverified Commit e5d9ef52 authored by gy77's avatar gy77 Committed by GitHub
Browse files

Add type hints in mmcv/ops/point_sample.py (#2019)

* Add type hints in mmcv/ops/point_sample.py

* fix docstring and type hint
parent 966b7428
# Modified from https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend # noqa # Modified from https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend # noqa
from os import path as osp from os import path as osp
from typing import Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import Tensor
from torch.nn.modules.utils import _pair from torch.nn.modules.utils import _pair
from torch.onnx.operators import shape_as_tensor from torch.onnx.operators import shape_as_tensor
def bilinear_grid_sample(im, grid, align_corners=False): def bilinear_grid_sample(im: Tensor,
grid: Tensor,
align_corners: bool = False) -> Tensor:
"""Given an input and a flow-field grid, computes the output using input """Given an input and a flow-field grid, computes the output using input
values and pixel locations from grid. Supported only bilinear interpolation values and pixel locations from grid. Supported only bilinear interpolation
method to sample the input pixels. method to sample the input pixels.
...@@ -17,7 +21,7 @@ def bilinear_grid_sample(im, grid, align_corners=False): ...@@ -17,7 +21,7 @@ def bilinear_grid_sample(im, grid, align_corners=False):
Args: Args:
im (torch.Tensor): Input feature map, shape (N, C, H, W) im (torch.Tensor): Input feature map, shape (N, C, H, W)
grid (torch.Tensor): Point coordinates, shape (N, Hg, Wg, 2) grid (torch.Tensor): Point coordinates, shape (N, Hg, Wg, 2)
align_corners {bool}: If set to True, the extrema (-1 and 1) are align_corners (bool): If set to True, the extrema (-1 and 1) are
considered as referring to the center points of the input’s considered as referring to the center points of the input’s
corner pixels. If set to False, they are instead considered as corner pixels. If set to False, they are instead considered as
referring to the corner points of the input’s corner pixels, referring to the corner points of the input’s corner pixels,
...@@ -85,14 +89,14 @@ def bilinear_grid_sample(im, grid, align_corners=False): ...@@ -85,14 +89,14 @@ def bilinear_grid_sample(im, grid, align_corners=False):
return (Ia * wa + Ib * wb + Ic * wc + Id * wd).reshape(n, c, gh, gw) return (Ia * wa + Ib * wb + Ic * wc + Id * wd).reshape(n, c, gh, gw)
def is_in_onnx_export_without_custom_ops(): def is_in_onnx_export_without_custom_ops() -> bool:
from mmcv.ops import get_onnxruntime_op_path from mmcv.ops import get_onnxruntime_op_path
ort_custom_op_path = get_onnxruntime_op_path() ort_custom_op_path = get_onnxruntime_op_path()
return torch.onnx.is_in_onnx_export( return torch.onnx.is_in_onnx_export(
) and not osp.exists(ort_custom_op_path) ) and not osp.exists(ort_custom_op_path)
def normalize(grid): def normalize(grid: Tensor) -> Tensor:
"""Normalize input grid from [-1, 1] to [0, 1] """Normalize input grid from [-1, 1] to [0, 1]
Args: Args:
...@@ -105,7 +109,7 @@ def normalize(grid): ...@@ -105,7 +109,7 @@ def normalize(grid):
return (grid + 1.0) / 2.0 return (grid + 1.0) / 2.0
def denormalize(grid): def denormalize(grid: Tensor) -> Tensor:
"""Denormalize input grid from range [0, 1] to [-1, 1] """Denormalize input grid from range [0, 1] to [-1, 1]
Args: Args:
...@@ -118,7 +122,8 @@ def denormalize(grid): ...@@ -118,7 +122,8 @@ def denormalize(grid):
return grid * 2.0 - 1.0 return grid * 2.0 - 1.0
def generate_grid(num_grid, size, device): def generate_grid(num_grid: int, size: Tuple[int, int],
device: torch.device) -> Tensor:
"""Generate regular square grid of points in [0, 1] x [0, 1] coordinate """Generate regular square grid of points in [0, 1] x [0, 1] coordinate
space. space.
...@@ -139,7 +144,8 @@ def generate_grid(num_grid, size, device): ...@@ -139,7 +144,8 @@ def generate_grid(num_grid, size, device):
return grid.view(1, -1, 2).expand(num_grid, -1, -1) return grid.view(1, -1, 2).expand(num_grid, -1, -1)
def rel_roi_point_to_abs_img_point(rois, rel_roi_points): def rel_roi_point_to_abs_img_point(rois: Tensor,
rel_roi_points: Tensor) -> Tensor:
"""Convert roi based relative point coordinates to image based absolute """Convert roi based relative point coordinates to image based absolute
point coordinates. point coordinates.
...@@ -170,7 +176,7 @@ def rel_roi_point_to_abs_img_point(rois, rel_roi_points): ...@@ -170,7 +176,7 @@ def rel_roi_point_to_abs_img_point(rois, rel_roi_points):
return abs_img_points return abs_img_points
def get_shape_from_feature_map(x): def get_shape_from_feature_map(x: Tensor) -> Tensor:
"""Get spatial resolution of input feature map considering exporting to """Get spatial resolution of input feature map considering exporting to
onnx mode. onnx mode.
...@@ -189,7 +195,9 @@ def get_shape_from_feature_map(x): ...@@ -189,7 +195,9 @@ def get_shape_from_feature_map(x):
return img_shape return img_shape
def abs_img_point_to_rel_img_point(abs_img_points, img, spatial_scale=1.): def abs_img_point_to_rel_img_point(abs_img_points: Tensor,
img: Union[tuple, Tensor],
spatial_scale: float = 1.) -> Tensor:
"""Convert image based absolute point coordinates to image based relative """Convert image based absolute point coordinates to image based relative
coordinates for sampling. coordinates for sampling.
...@@ -220,10 +228,10 @@ def abs_img_point_to_rel_img_point(abs_img_points, img, spatial_scale=1.): ...@@ -220,10 +228,10 @@ def abs_img_point_to_rel_img_point(abs_img_points, img, spatial_scale=1.):
return abs_img_points / scale * spatial_scale return abs_img_points / scale * spatial_scale
def rel_roi_point_to_rel_img_point(rois, def rel_roi_point_to_rel_img_point(rois: Tensor,
rel_roi_points, rel_roi_points: Tensor,
img, img: Union[tuple, Tensor],
spatial_scale=1.): spatial_scale: float = 1.) -> Tensor:
"""Convert roi based relative point coordinates to image based absolute """Convert roi based relative point coordinates to image based absolute
point coordinates. point coordinates.
...@@ -247,7 +255,10 @@ def rel_roi_point_to_rel_img_point(rois, ...@@ -247,7 +255,10 @@ def rel_roi_point_to_rel_img_point(rois,
return rel_img_point return rel_img_point
def point_sample(input, points, align_corners=False, **kwargs): def point_sample(input: Tensor,
points: Tensor,
align_corners: bool = False,
**kwargs) -> Tensor:
"""A wrapper around :func:`grid_sample` to support 3D point_coords tensors """A wrapper around :func:`grid_sample` to support 3D point_coords tensors
Unlike :func:`torch.nn.functional.grid_sample` it assumes point_coords to Unlike :func:`torch.nn.functional.grid_sample` it assumes point_coords to
lie inside ``[0, 1] x [0, 1]`` square. lie inside ``[0, 1] x [0, 1]`` square.
...@@ -285,7 +296,10 @@ def point_sample(input, points, align_corners=False, **kwargs): ...@@ -285,7 +296,10 @@ def point_sample(input, points, align_corners=False, **kwargs):
class SimpleRoIAlign(nn.Module): class SimpleRoIAlign(nn.Module):
def __init__(self, output_size, spatial_scale, aligned=True): def __init__(self,
output_size: Tuple[int],
spatial_scale: float,
aligned: bool = True) -> None:
"""Simple RoI align in PointRend, faster than standard RoIAlign. """Simple RoI align in PointRend, faster than standard RoIAlign.
Args: Args:
...@@ -303,7 +317,7 @@ class SimpleRoIAlign(nn.Module): ...@@ -303,7 +317,7 @@ class SimpleRoIAlign(nn.Module):
self.use_torchvision = False self.use_torchvision = False
self.aligned = aligned self.aligned = aligned
def forward(self, features, rois): def forward(self, features: Tensor, rois: Tensor) -> Tensor:
num_imgs = features.size(0) num_imgs = features.size(0)
num_rois = rois.size(0) num_rois = rois.size(0)
rel_roi_points = generate_grid( rel_roi_points = generate_grid(
...@@ -339,7 +353,7 @@ class SimpleRoIAlign(nn.Module): ...@@ -339,7 +353,7 @@ class SimpleRoIAlign(nn.Module):
return roi_feats return roi_feats
def __repr__(self): def __repr__(self) -> str:
format_str = self.__class__.__name__ format_str = self.__class__.__name__
format_str += '(output_size={}, spatial_scale={}'.format( format_str += '(output_size={}, spatial_scale={}'.format(
self.output_size, self.spatial_scale) self.output_size, self.spatial_scale)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment