Unverified Commit 19902d89 authored by gy77's avatar gy77 Committed by GitHub
Browse files

Add type hints in mmcv/ops/points_sampler.py (#2015)



* add typehint in mmcv/ops/points_sampler.py

* Update mmcv/ops/points_sampler.py
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmcv/ops/points_sampler.py
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>
parent 1211b06b
from typing import List
import torch
from torch import Tensor
from torch import nn as nn
from mmcv.runner import force_fp32
......@@ -8,7 +9,9 @@ from .furthest_point_sample import (furthest_point_sample,
furthest_point_sample_with_dist)
def calc_square_dist(point_feat_a, point_feat_b, norm=True):
def calc_square_dist(point_feat_a: Tensor,
point_feat_b: Tensor,
norm: bool = True) -> Tensor:
"""Calculating square distance between a and b.
Args:
......@@ -34,7 +37,7 @@ def calc_square_dist(point_feat_a, point_feat_b, norm=True):
return dist
def get_sampler_cls(sampler_type):
def get_sampler_cls(sampler_type: str) -> nn.Module:
"""Get the type and mode of points sampler.
Args:
......@@ -74,7 +77,7 @@ class PointsSampler(nn.Module):
def __init__(self,
num_point: List[int],
fps_mod_list: List[str] = ['D-FPS'],
fps_sample_range_list: List[int] = [-1]):
fps_sample_range_list: List[int] = [-1]) -> None:
super().__init__()
# FPS would be applied to different fps_mod in the list,
# so the length of the num_point should be equal to
......@@ -89,7 +92,7 @@ class PointsSampler(nn.Module):
self.fp16_enabled = False
@force_fp32()
def forward(self, points_xyz, features):
def forward(self, points_xyz: Tensor, features: Tensor) -> Tensor:
"""
Args:
points_xyz (torch.Tensor): (B, N, 3) xyz coordinates of
......@@ -134,10 +137,10 @@ class PointsSampler(nn.Module):
class DFPSSampler(nn.Module):
"""Using Euclidean distances of points for FPS."""
def __init__(self):
def __init__(self) -> None:
super().__init__()
def forward(self, points, features, npoint):
def forward(self, points: Tensor, features: Tensor, npoint: int) -> Tensor:
"""Sampling points with D-FPS."""
fps_idx = furthest_point_sample(points.contiguous(), npoint)
return fps_idx
......@@ -146,10 +149,10 @@ class DFPSSampler(nn.Module):
class FFPSSampler(nn.Module):
"""Using feature distances for FPS."""
def __init__(self):
def __init__(self) -> None:
super().__init__()
def forward(self, points, features, npoint):
def forward(self, points: Tensor, features: Tensor, npoint: int) -> Tensor:
"""Sampling points with F-FPS."""
assert features is not None, \
'feature input to FFPS_Sampler should not be None'
......@@ -163,10 +166,10 @@ class FFPSSampler(nn.Module):
class FSSampler(nn.Module):
"""Using F-FPS and D-FPS simultaneously."""
def __init__(self):
def __init__(self) -> None:
super().__init__()
def forward(self, points, features, npoint):
def forward(self, points: Tensor, features: Tensor, npoint: int) -> Tensor:
"""Sampling points with FS_Sampling."""
assert features is not None, \
'feature input to FS_Sampler should not be None'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment