Unverified Commit e5d9ef52 authored by gy77's avatar gy77 Committed by GitHub
Browse files

Add type hints in mmcv/ops/point_sample.py (#2019)

* Add type hints in mmcv/ops/point_sample.py

* fix docstring and type hint
parent 966b7428
# Modified from https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend # noqa
from os import path as osp
from typing import Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn.modules.utils import _pair
from torch.onnx.operators import shape_as_tensor
def bilinear_grid_sample(im, grid, align_corners=False):
def bilinear_grid_sample(im: Tensor,
grid: Tensor,
align_corners: bool = False) -> Tensor:
"""Given an input and a flow-field grid, computes the output using input
values and pixel locations from grid. Supported only bilinear interpolation
method to sample the input pixels.
......@@ -17,7 +21,7 @@ def bilinear_grid_sample(im, grid, align_corners=False):
Args:
im (torch.Tensor): Input feature map, shape (N, C, H, W)
grid (torch.Tensor): Point coordinates, shape (N, Hg, Wg, 2)
align_corners {bool}: If set to True, the extrema (-1 and 1) are
align_corners (bool): If set to True, the extrema (-1 and 1) are
considered as referring to the center points of the input’s
corner pixels. If set to False, they are instead considered as
referring to the corner points of the input’s corner pixels,
......@@ -85,14 +89,14 @@ def bilinear_grid_sample(im, grid, align_corners=False):
return (Ia * wa + Ib * wb + Ic * wc + Id * wd).reshape(n, c, gh, gw)
def is_in_onnx_export_without_custom_ops():
def is_in_onnx_export_without_custom_ops() -> bool:
from mmcv.ops import get_onnxruntime_op_path
ort_custom_op_path = get_onnxruntime_op_path()
return torch.onnx.is_in_onnx_export(
) and not osp.exists(ort_custom_op_path)
def normalize(grid):
def normalize(grid: Tensor) -> Tensor:
"""Normalize input grid from [-1, 1] to [0, 1]
Args:
......@@ -105,7 +109,7 @@ def normalize(grid):
return (grid + 1.0) / 2.0
def denormalize(grid):
def denormalize(grid: Tensor) -> Tensor:
"""Denormalize input grid from range [0, 1] to [-1, 1]
Args:
......@@ -118,7 +122,8 @@ def denormalize(grid):
return grid * 2.0 - 1.0
def generate_grid(num_grid, size, device):
def generate_grid(num_grid: int, size: Tuple[int, int],
device: torch.device) -> Tensor:
"""Generate regular square grid of points in [0, 1] x [0, 1] coordinate
space.
......@@ -139,7 +144,8 @@ def generate_grid(num_grid, size, device):
return grid.view(1, -1, 2).expand(num_grid, -1, -1)
def rel_roi_point_to_abs_img_point(rois, rel_roi_points):
def rel_roi_point_to_abs_img_point(rois: Tensor,
rel_roi_points: Tensor) -> Tensor:
"""Convert roi based relative point coordinates to image based absolute
point coordinates.
......@@ -170,7 +176,7 @@ def rel_roi_point_to_abs_img_point(rois, rel_roi_points):
return abs_img_points
def get_shape_from_feature_map(x):
def get_shape_from_feature_map(x: Tensor) -> Tensor:
"""Get spatial resolution of input feature map considering exporting to
onnx mode.
......@@ -189,7 +195,9 @@ def get_shape_from_feature_map(x):
return img_shape
def abs_img_point_to_rel_img_point(abs_img_points, img, spatial_scale=1.):
def abs_img_point_to_rel_img_point(abs_img_points: Tensor,
img: Union[tuple, Tensor],
spatial_scale: float = 1.) -> Tensor:
"""Convert image based absolute point coordinates to image based relative
coordinates for sampling.
......@@ -220,10 +228,10 @@ def abs_img_point_to_rel_img_point(abs_img_points, img, spatial_scale=1.):
return abs_img_points / scale * spatial_scale
def rel_roi_point_to_rel_img_point(rois,
rel_roi_points,
img,
spatial_scale=1.):
def rel_roi_point_to_rel_img_point(rois: Tensor,
rel_roi_points: Tensor,
img: Union[tuple, Tensor],
spatial_scale: float = 1.) -> Tensor:
"""Convert roi based relative point coordinates to image based absolute
point coordinates.
......@@ -247,7 +255,10 @@ def rel_roi_point_to_rel_img_point(rois,
return rel_img_point
def point_sample(input, points, align_corners=False, **kwargs):
def point_sample(input: Tensor,
points: Tensor,
align_corners: bool = False,
**kwargs) -> Tensor:
"""A wrapper around :func:`grid_sample` to support 3D point_coords tensors
Unlike :func:`torch.nn.functional.grid_sample` it assumes point_coords to
lie inside ``[0, 1] x [0, 1]`` square.
......@@ -285,7 +296,10 @@ def point_sample(input, points, align_corners=False, **kwargs):
class SimpleRoIAlign(nn.Module):
def __init__(self, output_size, spatial_scale, aligned=True):
def __init__(self,
output_size: Tuple[int],
spatial_scale: float,
aligned: bool = True) -> None:
"""Simple RoI align in PointRend, faster than standard RoIAlign.
Args:
......@@ -303,7 +317,7 @@ class SimpleRoIAlign(nn.Module):
self.use_torchvision = False
self.aligned = aligned
def forward(self, features, rois):
def forward(self, features: Tensor, rois: Tensor) -> Tensor:
num_imgs = features.size(0)
num_rois = rois.size(0)
rel_roi_points = generate_grid(
......@@ -339,7 +353,7 @@ class SimpleRoIAlign(nn.Module):
return roi_feats
def __repr__(self):
def __repr__(self) -> str:
format_str = self.__class__.__name__
format_str += '(output_size={}, spatial_scale={}'.format(
self.output_size, self.spatial_scale)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment