"examples/sampling/vscode:/vscode.git/clone" did not exist on "8909d1ff03974c9012b50a978faca31e3c86d9b3"
Unverified Commit 19902d89 authored by gy77's avatar gy77 Committed by GitHub
Browse files

Add type hints in mmcv/ops/points_sampler.py (#2015)



* add typehint in mmcv/ops/points_sampler.py

* Update mmcv/ops/points_sampler.py
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmcv/ops/points_sampler.py
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>
parent 1211b06b
from typing import List from typing import List
import torch import torch
from torch import Tensor
from torch import nn as nn from torch import nn as nn
from mmcv.runner import force_fp32 from mmcv.runner import force_fp32
...@@ -8,7 +9,9 @@ from .furthest_point_sample import (furthest_point_sample, ...@@ -8,7 +9,9 @@ from .furthest_point_sample import (furthest_point_sample,
furthest_point_sample_with_dist) furthest_point_sample_with_dist)
def calc_square_dist(point_feat_a, point_feat_b, norm=True): def calc_square_dist(point_feat_a: Tensor,
point_feat_b: Tensor,
norm: bool = True) -> Tensor:
"""Calculating square distance between a and b. """Calculating square distance between a and b.
Args: Args:
...@@ -34,7 +37,7 @@ def calc_square_dist(point_feat_a, point_feat_b, norm=True): ...@@ -34,7 +37,7 @@ def calc_square_dist(point_feat_a, point_feat_b, norm=True):
return dist return dist
def get_sampler_cls(sampler_type): def get_sampler_cls(sampler_type: str) -> nn.Module:
"""Get the type and mode of points sampler. """Get the type and mode of points sampler.
Args: Args:
...@@ -74,7 +77,7 @@ class PointsSampler(nn.Module): ...@@ -74,7 +77,7 @@ class PointsSampler(nn.Module):
def __init__(self, def __init__(self,
num_point: List[int], num_point: List[int],
fps_mod_list: List[str] = ['D-FPS'], fps_mod_list: List[str] = ['D-FPS'],
fps_sample_range_list: List[int] = [-1]): fps_sample_range_list: List[int] = [-1]) -> None:
super().__init__() super().__init__()
# FPS would be applied to different fps_mod in the list, # FPS would be applied to different fps_mod in the list,
# so the length of the num_point should be equal to # so the length of the num_point should be equal to
...@@ -89,7 +92,7 @@ class PointsSampler(nn.Module): ...@@ -89,7 +92,7 @@ class PointsSampler(nn.Module):
self.fp16_enabled = False self.fp16_enabled = False
@force_fp32() @force_fp32()
def forward(self, points_xyz, features): def forward(self, points_xyz: Tensor, features: Tensor) -> Tensor:
""" """
Args: Args:
points_xyz (torch.Tensor): (B, N, 3) xyz coordinates of points_xyz (torch.Tensor): (B, N, 3) xyz coordinates of
...@@ -134,10 +137,10 @@ class PointsSampler(nn.Module): ...@@ -134,10 +137,10 @@ class PointsSampler(nn.Module):
class DFPSSampler(nn.Module): class DFPSSampler(nn.Module):
"""Using Euclidean distances of points for FPS.""" """Using Euclidean distances of points for FPS."""
def __init__(self): def __init__(self) -> None:
super().__init__() super().__init__()
def forward(self, points, features, npoint): def forward(self, points: Tensor, features: Tensor, npoint: int) -> Tensor:
"""Sampling points with D-FPS.""" """Sampling points with D-FPS."""
fps_idx = furthest_point_sample(points.contiguous(), npoint) fps_idx = furthest_point_sample(points.contiguous(), npoint)
return fps_idx return fps_idx
...@@ -146,10 +149,10 @@ class DFPSSampler(nn.Module): ...@@ -146,10 +149,10 @@ class DFPSSampler(nn.Module):
class FFPSSampler(nn.Module): class FFPSSampler(nn.Module):
"""Using feature distances for FPS.""" """Using feature distances for FPS."""
def __init__(self): def __init__(self) -> None:
super().__init__() super().__init__()
def forward(self, points, features, npoint): def forward(self, points: Tensor, features: Tensor, npoint: int) -> Tensor:
"""Sampling points with F-FPS.""" """Sampling points with F-FPS."""
assert features is not None, \ assert features is not None, \
'feature input to FFPS_Sampler should not be None' 'feature input to FFPS_Sampler should not be None'
...@@ -163,10 +166,10 @@ class FFPSSampler(nn.Module): ...@@ -163,10 +166,10 @@ class FFPSSampler(nn.Module):
class FSSampler(nn.Module): class FSSampler(nn.Module):
"""Using F-FPS and D-FPS simultaneously.""" """Using F-FPS and D-FPS simultaneously."""
def __init__(self): def __init__(self) -> None:
super().__init__() super().__init__()
def forward(self, points, features, npoint): def forward(self, points: Tensor, features: Tensor, npoint: int) -> Tensor:
"""Sampling points with FS_Sampling.""" """Sampling points with FS_Sampling."""
assert features is not None, \ assert features is not None, \
'feature input to FS_Sampler should not be None' 'feature input to FS_Sampler should not be None'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment