"...git@developer.sourcefind.cn:OpenDAS/torch-harmonics.git" did not exist on "c877cda61fee6b0fa77f5d8faaa985ad00fc2cab"
Commit 6a9fd47c authored by Xiangxu-0103's avatar Xiangxu-0103 Committed by ZwwWayne
Browse files

[Enhance] Add typehint for models/layers (#2014)

* add typehints for models/layers

* Update builder.py
parent 1a47acdd
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple
import numba import numba
import numpy as np import numpy as np
import torch import torch
from mmcv.ops import nms, nms_rotated from mmcv.ops import nms, nms_rotated
from torch import Tensor
def box3d_multiclass_nms(mlvl_bboxes,
mlvl_bboxes_for_nms, def box3d_multiclass_nms(
mlvl_scores, mlvl_bboxes: Tensor,
score_thr, mlvl_bboxes_for_nms: Tensor,
max_num, mlvl_scores: Tensor,
cfg, score_thr: float,
mlvl_dir_scores=None, max_num: int,
mlvl_attr_scores=None, cfg: dict,
mlvl_bboxes2d=None): mlvl_dir_scores: Optional[Tensor] = None,
mlvl_attr_scores: Optional[Tensor] = None,
mlvl_bboxes2d: Optional[Tensor] = None) -> Tuple[Tensor]:
"""Multi-class NMS for 3D boxes. The IoU used for NMS is defined as the 2D """Multi-class NMS for 3D boxes. The IoU used for NMS is defined as the 2D
IoU between BEV boxes. IoU between BEV boxes.
Args: Args:
mlvl_bboxes (torch.Tensor): Multi-level boxes with shape (N, M). mlvl_bboxes (Tensor): Multi-level boxes with shape (N, M).
M is the dimensions of boxes. M is the dimensions of boxes.
mlvl_bboxes_for_nms (torch.Tensor): Multi-level boxes with shape mlvl_bboxes_for_nms (Tensor): Multi-level boxes with shape (N, 5)
(N, 5) ([x1, y1, x2, y2, ry]). N is the number of boxes. ([x1, y1, x2, y2, ry]). N is the number of boxes.
The coordinate system of the BEV boxes is counterclockwise. The coordinate system of the BEV boxes is counterclockwise.
mlvl_scores (torch.Tensor): Multi-level boxes with shape mlvl_scores (Tensor): Multi-level boxes with shape (N, C + 1).
(N, C + 1). N is the number of boxes. C is the number of classes. N is the number of boxes. C is the number of classes.
score_thr (float): Score threshold to filter boxes with low score_thr (float): Score threshold to filter boxes with low confidence.
confidence.
max_num (int): Maximum number of boxes will be kept. max_num (int): Maximum number of boxes will be kept.
cfg (dict): Configuration dict of NMS. cfg (dict): Configuration dict of NMS.
mlvl_dir_scores (torch.Tensor, optional): Multi-level scores mlvl_dir_scores (Tensor, optional): Multi-level scores of direction
of direction classifier. Defaults to None. classifier. Defaults to None.
mlvl_attr_scores (torch.Tensor, optional): Multi-level scores mlvl_attr_scores (Tensor, optional): Multi-level scores of attribute
of attribute classifier. Defaults to None. classifier. Defaults to None.
mlvl_bboxes2d (torch.Tensor, optional): Multi-level 2D bounding mlvl_bboxes2d (Tensor, optional): Multi-level 2D bounding boxes.
boxes. Defaults to None. Defaults to None.
Returns: Returns:
tuple[torch.Tensor]: Return results after nms, including 3D Tuple[Tensor]: Return results after nms, including 3D bounding boxes,
bounding boxes, scores, labels, direction scores, attribute scores, labels, direction scores, attribute scores (optional) and
scores (optional) and 2D bounding boxes (optional). 2D bounding boxes (optional).
""" """
# do multi class nms # do multi class nms
# the fg class id range: [0, num_classes-1] # the fg class id range: [0, num_classes-1]
...@@ -128,17 +131,18 @@ def box3d_multiclass_nms(mlvl_bboxes, ...@@ -128,17 +131,18 @@ def box3d_multiclass_nms(mlvl_bboxes,
return results return results
def aligned_3d_nms(boxes, scores, classes, thresh): def aligned_3d_nms(boxes: Tensor, scores: Tensor, classes: Tensor,
thresh: float) -> Tensor:
"""3D NMS for aligned boxes. """3D NMS for aligned boxes.
Args: Args:
boxes (torch.Tensor): Aligned box with shape [n, 6]. boxes (Tensor): Aligned box with shape [N, 6].
scores (torch.Tensor): Scores of each box. scores (Tensor): Scores of each box.
classes (torch.Tensor): Class of each box. classes (Tensor): Class of each box.
thresh (float): IoU threshold for nms. thresh (float): IoU threshold for nms.
Returns: Returns:
torch.Tensor: Indices of selected boxes. Tensor: Indices of selected boxes.
""" """
x1 = boxes[:, 0] x1 = boxes[:, 0]
y1 = boxes[:, 1] y1 = boxes[:, 1]
...@@ -179,21 +183,20 @@ def aligned_3d_nms(boxes, scores, classes, thresh): ...@@ -179,21 +183,20 @@ def aligned_3d_nms(boxes, scores, classes, thresh):
@numba.jit(nopython=True) @numba.jit(nopython=True)
def circle_nms(dets, thresh, post_max_size=83): def circle_nms(dets: Tensor, thresh: float, post_max_size: int = 83) -> Tensor:
"""Circular NMS. """Circular NMS.
An object is only counted as positive if no other center An object is only counted as positive if no other center with a higher
with a higher confidence exists within a radius r using a confidence exists within a radius r using a bird-eye view distance metric.
bird-eye view distance metric.
Args: Args:
dets (torch.Tensor): Detection results with the shape of [N, 3]. dets (Tensor): Detection results with the shape of [N, 3].
thresh (float): Value of threshold. thresh (float): Value of threshold.
post_max_size (int, optional): Max number of prediction to be kept. post_max_size (int): Max number of prediction to be kept.
Defaults to 83. Defaults to 83.
Returns: Returns:
torch.Tensor: Indexes of the detections to be kept. Tensor: Indexes of the detections to be kept.
""" """
x1 = dets[:, 0] x1 = dets[:, 0]
y1 = dets[:, 1] y1 = dets[:, 1]
...@@ -228,24 +231,28 @@ def circle_nms(dets, thresh, post_max_size=83): ...@@ -228,24 +231,28 @@ def circle_nms(dets, thresh, post_max_size=83):
# This function duplicates functionality of mmcv.ops.iou_3d.nms_bev # This function duplicates functionality of mmcv.ops.iou_3d.nms_bev
# from mmcv<=1.5, but using cuda ops from mmcv.ops.nms.nms_rotated. # from mmcv<=1.5, but using cuda ops from mmcv.ops.nms.nms_rotated.
# Nms api will be unified in mmdetection3d one day. # Nms api will be unified in mmdetection3d one day.
def nms_bev(boxes, scores, thresh, pre_max_size=None, post_max_size=None): def nms_bev(boxes: Tensor,
scores: Tensor,
thresh: float,
pre_max_size: Optional[int] = None,
post_max_size: Optional[int] = None) -> Tensor:
"""NMS function GPU implementation (for BEV boxes). The overlap of two """NMS function GPU implementation (for BEV boxes). The overlap of two
boxes for IoU calculation is defined as the exact overlapping area of the boxes for IoU calculation is defined as the exact overlapping area of the
two boxes. In this function, one can also set ``pre_max_size`` and two boxes. In this function, one can also set ``pre_max_size`` and
``post_max_size``. ``post_max_size``.
Args: Args:
boxes (torch.Tensor): Input boxes with the shape of [N, 5] boxes (Tensor): Input boxes with the shape of [N, 5]
([x1, y1, x2, y2, ry]). ([x1, y1, x2, y2, ry]).
scores (torch.Tensor): Scores of boxes with the shape of [N]. scores (Tensor): Scores of boxes with the shape of [N].
thresh (float): Overlap threshold of NMS. thresh (float): Overlap threshold of NMS.
pre_max_size (int, optional): Max size of boxes before NMS. pre_max_size (int, optional): Max size of boxes before NMS.
Default: None. Defaults to None.
post_max_size (int, optional): Max size of boxes after NMS. post_max_size (int, optional): Max size of boxes after NMS.
Default: None. Defaults to None.
Returns: Returns:
torch.Tensor: Indexes after NMS. Tensor: Indexes after NMS.
""" """
assert boxes.size(1) == 5, 'Input boxes shape should be [N, 5]' assert boxes.size(1) == 5, 'Input boxes shape should be [N, 5]'
order = scores.sort(0, descending=True)[1] order = scores.sort(0, descending=True)[1]
...@@ -271,18 +278,18 @@ def nms_bev(boxes, scores, thresh, pre_max_size=None, post_max_size=None): ...@@ -271,18 +278,18 @@ def nms_bev(boxes, scores, thresh, pre_max_size=None, post_max_size=None):
# This function duplicates functionality of mmcv.ops.iou_3d.nms_normal_bev # This function duplicates functionality of mmcv.ops.iou_3d.nms_normal_bev
# from mmcv<=1.5, but using cuda ops from mmcv.ops.nms.nms. # from mmcv<=1.5, but using cuda ops from mmcv.ops.nms.nms.
# Nms api will be unified in mmdetection3d one day. # Nms api will be unified in mmdetection3d one day.
def nms_normal_bev(boxes, scores, thresh): def nms_normal_bev(boxes: Tensor, scores: Tensor, thresh: float) -> Tensor:
"""Normal NMS function GPU implementation (for BEV boxes). The overlap of """Normal NMS function GPU implementation (for BEV boxes). The overlap of
two boxes for IoU calculation is defined as the exact overlapping area of two boxes for IoU calculation is defined as the exact overlapping area of
the two boxes WITH their yaw angle set to 0. the two boxes WITH their yaw angle set to 0.
Args: Args:
boxes (torch.Tensor): Input boxes with shape (N, 5). boxes (Tensor): Input boxes with shape (N, 5).
scores (torch.Tensor): Scores of predicted boxes with shape (N). scores (Tensor): Scores of predicted boxes with shape (N).
thresh (float): Overlap threshold of NMS. thresh (float): Overlap threshold of NMS.
Returns: Returns:
torch.Tensor: Remaining indices with scores in descending order. Tensor: Remaining indices with scores in descending order.
""" """
assert boxes.shape[1] == 5, 'Input boxes shape should be [N, 5]' assert boxes.shape[1] == 5, 'Input boxes shape should be [N, 5]'
return nms(boxes[:, :-1], scores, thresh)[1] return nms(boxes[:, :-1], scores, thresh)[1]
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import List
import torch import torch
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmengine.model import BaseModule from mmengine.model import BaseModule
from torch import Tensor
from torch import nn as nn from torch import nn as nn
from mmdet3d.utils import ConfigType, OptMultiConfig
class DGCNNFAModule(BaseModule): class DGCNNFAModule(BaseModule):
"""Point feature aggregation module used in DGCNN. """Point feature aggregation module used in DGCNN.
...@@ -11,21 +16,21 @@ class DGCNNFAModule(BaseModule): ...@@ -11,21 +16,21 @@ class DGCNNFAModule(BaseModule):
Aggregate all the features of points. Aggregate all the features of points.
Args: Args:
mlp_channels (list[int]): List of mlp channels. mlp_channels (List[int]): List of mlp channels.
norm_cfg (dict, optional): Type of normalization method. norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
Defaults to dict(type='BN1d'). layer. Defaults to dict(type='BN1d').
act_cfg (dict, optional): Type of activation method. act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer.
Defaults to dict(type='ReLU'). Defaults to dict(type='ReLU').
init_cfg (dict, optional): Initialization config. Defaults to None. init_cfg (:obj:`ConfigDict` or dict or List[:obj:`ConfigDict` or dict],
optional): Initialization config dict. Defaults to None.
""" """
def __init__(self, def __init__(self,
mlp_channels, mlp_channels: List[int],
norm_cfg=dict(type='BN1d'), norm_cfg: ConfigType = dict(type='BN1d'),
act_cfg=dict(type='ReLU'), act_cfg: ConfigType = dict(type='ReLU'),
init_cfg=None): init_cfg: OptMultiConfig = None) -> None:
super().__init__(init_cfg=init_cfg) super(DGCNNFAModule, self).__init__(init_cfg=init_cfg)
self.fp16_enabled = False
self.mlps = nn.Sequential() self.mlps = nn.Sequential()
for i in range(len(mlp_channels) - 1): for i in range(len(mlp_channels) - 1):
self.mlps.add_module( self.mlps.add_module(
...@@ -39,14 +44,14 @@ class DGCNNFAModule(BaseModule): ...@@ -39,14 +44,14 @@ class DGCNNFAModule(BaseModule):
norm_cfg=norm_cfg, norm_cfg=norm_cfg,
act_cfg=act_cfg)) act_cfg=act_cfg))
def forward(self, points): def forward(self, points: List[Tensor]) -> Tensor:
"""forward. """forward.
Args: Args:
points (List[Tensor]): tensor of the features to be aggregated. points (List[Tensor]): Tensor of the features to be aggregated.
Returns: Returns:
Tensor: (B, N, M) M = mlp[-1], tensor of the output points. Tensor: (B, N, M) M = mlp[-1]. Tensor of the output points.
""" """
if len(points) > 1: if len(points) > 1:
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import List
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmengine.model import BaseModule from mmengine.model import BaseModule
from torch import Tensor
from torch import nn as nn from torch import nn as nn
from mmdet3d.utils import ConfigType, OptMultiConfig
class DGCNNFPModule(BaseModule): class DGCNNFPModule(BaseModule):
"""Point feature propagation module used in DGCNN. """Point feature propagation module used in DGCNN.
...@@ -10,21 +15,21 @@ class DGCNNFPModule(BaseModule): ...@@ -10,21 +15,21 @@ class DGCNNFPModule(BaseModule):
Propagate the features from one set to another. Propagate the features from one set to another.
Args: Args:
mlp_channels (list[int]): List of mlp channels. mlp_channels (List[int]): List of mlp channels.
norm_cfg (dict, optional): Type of activation method. norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
Defaults to dict(type='BN1d'). layer. Defaults to dict(type='BN1d').
act_cfg (dict, optional): Type of activation method. act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer.
Defaults to dict(type='ReLU'). Defaults to dict(type='ReLU').
init_cfg (dict, optional): Initialization config. Defaults to None. init_cfg (:obj:`ConfigDict` or dict or List[:obj:`ConfigDict` or dict],
optional): Initialization config dict. Defaults to None.
""" """
def __init__(self, def __init__(self,
mlp_channels, mlp_channels: List[int],
norm_cfg=dict(type='BN1d'), norm_cfg: ConfigType = dict(type='BN1d'),
act_cfg=dict(type='ReLU'), act_cfg: ConfigType = dict(type='ReLU'),
init_cfg=None): init_cfg: OptMultiConfig = None) -> None:
super().__init__(init_cfg=init_cfg) super(DGCNNFPModule, self).__init__(init_cfg=init_cfg)
self.fp16_enabled = False
self.mlps = nn.Sequential() self.mlps = nn.Sequential()
for i in range(len(mlp_channels) - 1): for i in range(len(mlp_channels) - 1):
self.mlps.add_module( self.mlps.add_module(
...@@ -38,14 +43,14 @@ class DGCNNFPModule(BaseModule): ...@@ -38,14 +43,14 @@ class DGCNNFPModule(BaseModule):
norm_cfg=norm_cfg, norm_cfg=norm_cfg,
act_cfg=act_cfg)) act_cfg=act_cfg))
def forward(self, points): def forward(self, points: Tensor) -> Tensor:
"""forward. """Forward.
Args: Args:
points (Tensor): (B, N, C) tensor of the input points. points (Tensor): (B, N, C) Tensor of the input points.
Returns: Returns:
Tensor: (B, N, M) M = mlp[-1], tensor of the new points. Tensor: (B, N, M) M = mlp[-1]. Tensor of the new points.
""" """
if points is not None: if points is not None:
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union
import torch import torch
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmcv.ops.group_points import GroupAll, QueryAndGroup, grouping_operation from mmcv.ops.group_points import GroupAll, QueryAndGroup, grouping_operation
from torch import Tensor
from torch import nn as nn from torch import nn as nn
from torch.nn import functional as F from torch.nn import functional as F
from mmdet3d.utils import ConfigType
class BaseDGCNNGFModule(nn.Module): class BaseDGCNNGFModule(nn.Module):
"""Base module for point graph feature module used in DGCNN. """Base module for point graph feature module used in DGCNN.
Args: Args:
radii (list[float]): List of radius in each knn or ball query. radii (List[float]): List of radius in each knn or ball query.
sample_nums (list[int]): Number of samples in each knn or ball query. sample_nums (List[int]): Number of samples in each knn or ball query.
mlp_channels (list[list[int]]): Specify of the dgcnn before mlp_channels (List[List[int]]): Specify of the dgcnn before the global
the global pooling for each graph feature module. pooling for each graph feature module.
knn_modes (list[str], optional): Type of KNN method, valid mode knn_modes (List[str]): Type of KNN method, valid mode
['F-KNN', 'D-KNN'], Defaults to ['F-KNN']. ['F-KNN', 'D-KNN']. Defaults to ['F-KNN'].
dilated_group (bool, optional): Whether to use dilated ball query. dilated_group (bool): Whether to use dilated ball query.
Defaults to False. Defaults to False.
use_xyz (bool, optional): Whether to use xyz as point features. use_xyz (bool): Whether to use xyz as point features.
Defaults to True. Defaults to True.
pool_mode (str, optional): Type of pooling method. Defaults to 'max'. pool_mode (str): Type of pooling method. Defaults to 'max'.
normalize_xyz (bool, optional): If ball query, whether to normalize normalize_xyz (bool): If ball query, whether to normalize local XYZ
local XYZ with radius. Defaults to False. with radius. Defaults to False.
grouper_return_grouped_xyz (bool, optional): Whether to return grouped grouper_return_grouped_xyz (bool): Whether to return grouped xyz in
xyz in `QueryAndGroup`. Defaults to False. `QueryAndGroup`. Defaults to False.
grouper_return_grouped_idx (bool, optional): Whether to return grouped grouper_return_grouped_idx (bool): Whether to return grouped idx in
idx in `QueryAndGroup`. Defaults to False. `QueryAndGroup`. Defaults to False.
""" """
def __init__(self, def __init__(self,
radii, radii: List[float],
sample_nums, sample_nums: List[int],
mlp_channels, mlp_channels: List[List[int]],
knn_modes=['F-KNN'], knn_modes: List[str] = ['F-KNN'],
dilated_group=False, dilated_group: bool = False,
use_xyz=True, use_xyz: bool = True,
pool_mode='max', pool_mode: str = 'max',
normalize_xyz=False, normalize_xyz: bool = False,
grouper_return_grouped_xyz=False, grouper_return_grouped_xyz: bool = False,
grouper_return_grouped_idx=False): grouper_return_grouped_idx: bool = False) -> None:
super(BaseDGCNNGFModule, self).__init__() super(BaseDGCNNGFModule, self).__init__()
assert len(sample_nums) == len( assert len(sample_nums) == len(
...@@ -82,16 +87,15 @@ class BaseDGCNNGFModule(nn.Module): ...@@ -82,16 +87,15 @@ class BaseDGCNNGFModule(nn.Module):
grouper = GroupAll(use_xyz) grouper = GroupAll(use_xyz)
self.groupers.append(grouper) self.groupers.append(grouper)
def _pool_features(self, features): def _pool_features(self, features: Tensor) -> Tensor:
"""Perform feature aggregation using pooling operation. """Perform feature aggregation using pooling operation.
Args: Args:
features (torch.Tensor): (B, C, N, K) features (Tensor): (B, C, N, K) Features of locally grouped
Features of locally grouped points before pooling. points before pooling.
Returns: Returns:
torch.Tensor: (B, C, N) Tensor: (B, C, N) Pooled features aggregating local information.
Pooled features aggregating local information.
""" """
if self.pool_mode == 'max': if self.pool_mode == 'max':
# (B, C, N, 1) # (B, C, N, 1)
...@@ -106,15 +110,15 @@ class BaseDGCNNGFModule(nn.Module): ...@@ -106,15 +110,15 @@ class BaseDGCNNGFModule(nn.Module):
return new_features.squeeze(-1).contiguous() return new_features.squeeze(-1).contiguous()
def forward(self, points): def forward(self, points: Tensor) -> Tensor:
"""forward. """forward.
Args: Args:
points (Tensor): (B, N, C) input points. points (Tensor): (B, N, C) Input points.
Returns: Returns:
List[Tensor]: (B, N, C1) new points generated from each graph Tensor: (B, N, C1) New points generated from each graph
feature module. feature module.
""" """
new_points_list = [points] new_points_list = [points]
...@@ -155,43 +159,40 @@ class DGCNNGFModule(BaseDGCNNGFModule): ...@@ -155,43 +159,40 @@ class DGCNNGFModule(BaseDGCNNGFModule):
"""Point graph feature module used in DGCNN. """Point graph feature module used in DGCNN.
Args: Args:
mlp_channels (list[int]): Specify of the dgcnn before mlp_channels (List[int]): Specify of the dgcnn before the global
the global pooling for each graph feature module. pooling for each graph feature module.
num_sample (int, optional): Number of samples in each knn or ball num_sample (int, optional): Number of samples in each knn or ball
query. Defaults to None. query. Defaults to None.
knn_mode (str, optional): Type of KNN method, valid mode knn_mode (str): Type of KNN method, valid mode ['F-KNN', 'D-KNN'].
['F-KNN', 'D-KNN']. Defaults to 'F-KNN'. Defaults to 'F-KNN'.
radius (float, optional): Radius to group with. radius (float, optional): Radius to group with. Defaults to None.
Defaults to None. dilated_group (bool): Whether to use dilated ball query.
dilated_group (bool, optional): Whether to use dilated ball query.
Defaults to False. Defaults to False.
norm_cfg (dict, optional): Type of normalization method. norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
Defaults to dict(type='BN2d'). layer. Defaults to dict(type='BN2d').
act_cfg (dict, optional): Type of activation method. act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer.
Defaults to dict(type='ReLU'). Defaults to dict(type='ReLU').
use_xyz (bool, optional): Whether to use xyz as point features. use_xyz (bool): Whether to use xyz as point features. Defaults to True.
Defaults to True. pool_mode (str): Type of pooling method. Defaults to 'max'.
pool_mode (str, optional): Type of pooling method. normalize_xyz (bool): If ball query, whether to normalize local XYZ
Defaults to 'max'. with radius. Defaults to False.
normalize_xyz (bool, optional): If ball query, whether to normalize bias (bool or str): If specified as `auto`, it will be decided by
local XYZ with radius. Defaults to False. `norm_cfg`. `bias` will be set as True if `norm_cfg` is None,
bias (bool | str, optional): If specified as `auto`, it will be decided
by the norm_cfg. Bias will be set as True if `norm_cfg` is None,
otherwise False. Defaults to 'auto'. otherwise False. Defaults to 'auto'.
""" """
def __init__(self, def __init__(self,
mlp_channels, mlp_channels: List[int],
num_sample=None, num_sample: Optional[int] = None,
knn_mode='F-KNN', knn_mode: str = 'F-KNN',
radius=None, radius: Optional[float] = None,
dilated_group=False, dilated_group: bool = False,
norm_cfg=dict(type='BN2d'), norm_cfg: ConfigType = dict(type='BN2d'),
act_cfg=dict(type='ReLU'), act_cfg: ConfigType = dict(type='ReLU'),
use_xyz=True, use_xyz: bool = True,
pool_mode='max', pool_mode: str = 'max',
normalize_xyz=False, normalize_xyz: bool = False,
bias='auto'): bias: Union[bool, str] = 'auto') -> None:
super(DGCNNGFModule, self).__init__( super(DGCNNGFModule, self).__init__(
mlp_channels=[mlp_channels], mlp_channels=[mlp_channels],
sample_nums=[num_sample], sample_nums=[num_sample],
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import List
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmengine.model import BaseModule from mmengine.model import BaseModule
from torch import Tensor
from torch import nn as nn from torch import nn as nn
from torch.nn import functional as F from torch.nn import functional as F
from mmdet3d.utils import ConfigType
class EdgeFusionModule(BaseModule): class EdgeFusionModule(BaseModule):
"""Edge Fusion Module for feature map. """Edge Fusion Module for feature map.
...@@ -12,21 +17,22 @@ class EdgeFusionModule(BaseModule): ...@@ -12,21 +17,22 @@ class EdgeFusionModule(BaseModule):
out_channels (int): The number of output channels. out_channels (int): The number of output channels.
feat_channels (int): The number of channels in feature map feat_channels (int): The number of channels in feature map
during edge feature fusion. during edge feature fusion.
kernel_size (int, optional): Kernel size of convolution. kernel_size (int): Kernel size of convolution. Defaults to 3.
Default: 3. act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer.
act_cfg (dict, optional): Config of activation. Defaults to dict(type='ReLU').
Default: dict(type='ReLU'). norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
norm_cfg (dict, optional): Config of normalization. layer. Defaults to dict(type='BN1d').
Default: dict(type='BN1d')).
""" """
def __init__(self, def __init__(
out_channels, self,
feat_channels, out_channels: int,
kernel_size=3, feat_channels: int,
act_cfg=dict(type='ReLU'), kernel_size: int = 3,
norm_cfg=dict(type='BN1d')): act_cfg: ConfigType = dict(type='ReLU'),
super().__init__() norm_cfg: ConfigType = dict(type='BN1d')
) -> None:
super(EdgeFusionModule, self).__init__()
self.edge_convs = nn.Sequential( self.edge_convs = nn.Sequential(
ConvModule( ConvModule(
feat_channels, feat_channels,
...@@ -39,22 +45,22 @@ class EdgeFusionModule(BaseModule): ...@@ -39,22 +45,22 @@ class EdgeFusionModule(BaseModule):
nn.Conv1d(feat_channels, out_channels, kernel_size=1)) nn.Conv1d(feat_channels, out_channels, kernel_size=1))
self.feat_channels = feat_channels self.feat_channels = feat_channels
def forward(self, features, fused_features, edge_indices, edge_lens, def forward(self, features: Tensor, fused_features: Tensor,
output_h, output_w): edge_indices: Tensor, edge_lens: List[int], output_h: int,
output_w: int) -> Tensor:
"""Forward pass. """Forward pass.
Args: Args:
features (torch.Tensor): Different representative features features (Tensor): Different representative features for fusion.
for fusion. fused_features (Tensor): Different representative features
fused_features (torch.Tensor): Different representative to be fused.
features to be fused. edge_indices (Tensor): Batch image edge indices.
edge_indices (torch.Tensor): Batch image edge indices. edge_lens (List[int]): List of edge length of each image.
edge_lens (list[int]): List of edge length of each image.
output_h (int): Height of output feature map. output_h (int): Height of output feature map.
output_w (int): Width of output feature map. output_w (int): Width of output feature map.
Returns: Returns:
torch.Tensor: Fused feature maps. Tensor: Fused feature maps.
""" """
batch_size = features.shape[0] batch_size = features.shape[0]
# normalize # normalize
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from functools import partial from functools import partial
from typing import Tuple
import torch import torch
from torch import Tensor
from mmdet3d.structures.points import get_points_type from mmdet3d.structures.points import get_points_type
def apply_3d_transformation(pcd, coord_type, img_meta, reverse=False): def apply_3d_transformation(pcd: Tensor,
coord_type: str,
img_meta: dict,
reverse: bool = False) -> Tensor:
"""Apply transformation to input point cloud. """Apply transformation to input point cloud.
Args: Args:
pcd (torch.Tensor): The point cloud to be transformed. pcd (Tensor): The point cloud to be transformed.
coord_type (str): 'DEPTH' or 'CAMERA' or 'LIDAR'. coord_type (str): 'DEPTH' or 'CAMERA' or 'LIDAR'.
img_meta(dict): Meta info regarding data transformation. img_meta(dict): Meta info regarding data transformation.
reverse (bool): Reversed transformation or not. reverse (bool): Reversed transformation or not. Defaults to False.
Note: Note:
The elements in img_meta['transformation_3d_flow']: The elements in img_meta['transformation_3d_flow']:
"T" stands for translation;
"S" stands for scale; - "T" stands for translation;
"R" stands for rotation; - "S" stands for scale;
"HF" stands for horizontal flip; - "R" stands for rotation;
"VF" stands for vertical flip. - "HF" stands for horizontal flip;
- "VF" stands for vertical flip.
Returns: Returns:
torch.Tensor: The transformed point cloud. Tensor: The transformed point cloud.
""" """
dtype = pcd.dtype dtype = pcd.dtype
...@@ -92,16 +98,18 @@ def apply_3d_transformation(pcd, coord_type, img_meta, reverse=False): ...@@ -92,16 +98,18 @@ def apply_3d_transformation(pcd, coord_type, img_meta, reverse=False):
return pcd.coord return pcd.coord
def extract_2d_info(img_meta, tensor): def extract_2d_info(
img_meta: dict,
tensor: Tensor) -> Tuple[int, int, int, int, Tensor, bool, Tensor]:
"""Extract image augmentation information from img_meta. """Extract image augmentation information from img_meta.
Args: Args:
img_meta(dict): Meta info regarding data transformation. img_meta (dict): Meta info regarding data transformation.
tensor(torch.Tensor): Input tensor used to create new ones. tensor (Tensor): Input tensor used to create new ones.
Returns: Returns:
(int, int, int, int, torch.Tensor, bool, torch.Tensor): Tuple[int, int, int, int, torch.Tensor, bool, torch.Tensor]:
The extracted information. The extracted information.
""" """
img_shape = img_meta['img_shape'] img_shape = img_meta['img_shape']
ori_shape = img_meta['ori_shape'] ori_shape = img_meta['ori_shape']
...@@ -120,17 +128,17 @@ def extract_2d_info(img_meta, tensor): ...@@ -120,17 +128,17 @@ def extract_2d_info(img_meta, tensor):
img_crop_offset) img_crop_offset)
def bbox_2d_transform(img_meta, bbox_2d, ori2new): def bbox_2d_transform(img_meta: dict, bbox_2d: Tensor,
ori2new: bool) -> Tensor:
"""Transform 2d bbox according to img_meta. """Transform 2d bbox according to img_meta.
Args: Args:
img_meta(dict): Meta info regarding data transformation. img_meta (dict): Meta info regarding data transformation.
bbox_2d (torch.Tensor): Shape (..., >4) bbox_2d (Tensor): Shape (..., >4) The input 2d bboxes to transform.
The input 2d bboxes to transform.
ori2new (bool): Origin img coord system to new or not. ori2new (bool): Origin img coord system to new or not.
Returns: Returns:
torch.Tensor: The transformed 2d bboxes. Tensor: The transformed 2d bboxes.
""" """
img_h, img_w, ori_h, ori_w, img_scale_factor, img_flip, \ img_h, img_w, ori_h, ori_w, img_scale_factor, img_flip, \
...@@ -174,17 +182,17 @@ def bbox_2d_transform(img_meta, bbox_2d, ori2new): ...@@ -174,17 +182,17 @@ def bbox_2d_transform(img_meta, bbox_2d, ori2new):
return bbox_2d_new return bbox_2d_new
def coord_2d_transform(img_meta, coord_2d, ori2new): def coord_2d_transform(img_meta: dict, coord_2d: Tensor,
ori2new: bool) -> Tensor:
"""Transform 2d pixel coordinates according to img_meta. """Transform 2d pixel coordinates according to img_meta.
Args: Args:
img_meta(dict): Meta info regarding data transformation. img_meta (dict): Meta info regarding data transformation.
coord_2d (torch.Tensor): Shape (..., 2) coord_2d (Tensor): Shape (..., 2) The input 2d coords to transform.
The input 2d coords to transform.
ori2new (bool): Origin img coord system to new or not. ori2new (bool): Origin img coord system to new or not.
Returns: Returns:
torch.Tensor: The transformed 2d coordinates. Tensor: The transformed 2d coordinates.
""" """
img_h, img_w, ori_h, ori_w, img_scale_factor, img_flip, \ img_h, img_w, ori_h, ori_w, img_scale_factor, img_flip, \
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple, Union
import torch import torch
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmengine.model import BaseModule from mmengine.model import BaseModule
from torch import Tensor
from torch import nn as nn from torch import nn as nn
from torch.nn import functional as F from torch.nn import functional as F
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS
from mmdet3d.structures.bbox_3d import (get_proj_mat_by_coord_type, from mmdet3d.structures.bbox_3d import (get_proj_mat_by_coord_type,
points_cam2img, points_img2cam) points_cam2img, points_img2cam)
from mmdet3d.utils import OptConfigType, OptMultiConfig
from . import apply_3d_transformation from . import apply_3d_transformation
def point_sample(img_meta, def point_sample(img_meta: dict,
img_features, img_features: Tensor,
points, points: Tensor,
proj_mat, proj_mat: Tensor,
coord_type, coord_type: str,
img_scale_factor, img_scale_factor: Tensor,
img_crop_offset, img_crop_offset: Tensor,
img_flip, img_flip: bool,
img_pad_shape, img_pad_shape: Tuple[int],
img_shape, img_shape: Tuple[int],
aligned=True, aligned: bool = True,
padding_mode='zeros', padding_mode: str = 'zeros',
align_corners=True, align_corners: bool = True,
valid_flag=False): valid_flag: bool = False) -> Tensor:
"""Obtain image features using points. """Obtain image features using points.
Args: Args:
img_meta (dict): Meta info. img_meta (dict): Meta info.
img_features (torch.Tensor): 1 x C x H x W image features. img_features (Tensor): 1 x C x H x W image features.
points (torch.Tensor): Nx3 point cloud in LiDAR coordinates. points (Tensor): Nx3 point cloud in LiDAR coordinates.
proj_mat (torch.Tensor): 4x4 transformation matrix. proj_mat (Tensor): 4x4 transformation matrix.
coord_type (str): 'DEPTH' or 'CAMERA' or 'LIDAR'. coord_type (str): 'DEPTH' or 'CAMERA' or 'LIDAR'.
img_scale_factor (torch.Tensor): Scale factor with shape of img_scale_factor (Tensor): Scale factor with shape of
(w_scale, h_scale). (w_scale, h_scale).
img_crop_offset (torch.Tensor): Crop offset used to crop img_crop_offset (Tensor): Crop offset used to crop image during
image during data augmentation with shape of (w_offset, h_offset). data augmentation with shape of (w_offset, h_offset).
img_flip (bool): Whether the image is flipped. img_flip (bool): Whether the image is flipped.
img_pad_shape (tuple[int]): int tuple indicates the h & w after img_pad_shape (Tuple[int]): Int tuple indicates the h & w after
padding, this is necessary to obtain features in feature map. padding. This is necessary to obtain features in feature map.
img_shape (tuple[int]): int tuple indicates the h & w before padding img_shape (Tuple[int]): Int tuple indicates the h & w before padding
after scaling, this is necessary for flipping coordinates. after scaling. This is necessary for flipping coordinates.
aligned (bool): Whether use bilinear interpolation when aligned (bool): Whether to use bilinear interpolation when
sampling image features for each point. Defaults to True. sampling image features for each point. Defaults to True.
padding_mode (str): Padding mode when padding values for padding_mode (str): Padding mode when padding values for
features of out-of-image points. Defaults to 'zeros'. features of out-of-image points. Defaults to 'zeros'.
align_corners (bool): Whether to align corners when align_corners (bool): Whether to align corners when
sampling image features for each point. Defaults to True. sampling image features for each point. Defaults to True.
valid_flag (bool): Whether to filter out the points that valid_flag (bool): Whether to filter out the points that are outside
outside the image and with depth smaller than 0. Defaults to the image and with depth smaller than 0. Defaults to False.
False.
Returns: Returns:
torch.Tensor: NxC image features sampled by point coordinates. Tensor: NxC image features sampled by point coordinates.
""" """
# apply transformation based on info in img_meta # apply transformation based on info in img_meta
...@@ -114,55 +117,55 @@ class PointFusion(BaseModule): ...@@ -114,55 +117,55 @@ class PointFusion(BaseModule):
"""Fuse image features from multi-scale features. """Fuse image features from multi-scale features.
Args: Args:
img_channels (list[int] | int): Channels of image features. img_channels (List[int] or int): Channels of image features.
It could be a list if the input is multi-scale image features. It could be a list if the input is multi-scale image features.
pts_channels (int): Channels of point features pts_channels (int): Channels of point features
mid_channels (int): Channels of middle layers mid_channels (int): Channels of middle layers
out_channels (int): Channels of output fused features out_channels (int): Channels of output fused features
img_levels (int, optional): Number of image levels. Defaults to 3. img_levels (List[int] or int): Number of image levels. Defaults to 3.
coord_type (str): 'DEPTH' or 'CAMERA' or 'LIDAR'. coord_type (str): 'DEPTH' or 'CAMERA' or 'LIDAR'. Defaults to 'LIDAR'.
Defaults to 'LIDAR'. conv_cfg (:obj:`ConfigDict` or dict): Config dict for convolution
conv_cfg (dict, optional): Dict config of conv layers of middle layers of middle layers. Defaults to None.
layers. Defaults to None. norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
norm_cfg (dict, optional): Dict config of norm layers of middle layers of middle layers. Defaults to None.
layers. Defaults to None. act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer.
act_cfg (dict, optional): Dict config of activatation layers.
Defaults to None. Defaults to None.
activate_out (bool, optional): Whether to apply relu activation init_cfg (:obj:`ConfigDict` or dict or List[:obj:`ConfigDict` or dict],
to output features. Defaults to True. optional): Initialization config dict. Defaults to None.
fuse_out (bool, optional): Whether apply conv layer to the fused activate_out (bool): Whether to apply relu activation to output
features. Defaults to False. features. Defaults to True.
dropout_ratio (int, float, optional): Dropout ratio of image fuse_out (bool): Whether to apply conv layer to the fused features.
features to prevent overfitting. Defaults to 0. Defaults to False.
aligned (bool, optional): Whether apply aligned feature fusion. dropout_ratio (int or float): Dropout ratio of image features to
prevent overfitting. Defaults to 0.
aligned (bool): Whether to apply aligned feature fusion.
Defaults to True.
align_corners (bool): Whether to align corner when sampling features
according to points. Defaults to True.
padding_mode (str): Mode used to pad the features of points that do not
have corresponding image features. Defaults to 'zeros'.
lateral_conv (bool): Whether to apply lateral convs to image features.
Defaults to True. Defaults to True.
align_corners (bool, optional): Whether to align corner when
sampling features according to points. Defaults to True.
padding_mode (str, optional): Mode used to pad the features of
points that do not have corresponding image features.
Defaults to 'zeros'.
lateral_conv (bool, optional): Whether to apply lateral convs
to image features. Defaults to True.
""" """
def __init__(self, def __init__(self,
img_channels, img_channels: Union[List[int], int],
pts_channels, pts_channels: int,
mid_channels, mid_channels: int,
out_channels, out_channels: int,
img_levels=3, img_levels: Union[List[int], int] = 3,
coord_type='LIDAR', coord_type: str = 'LIDAR',
conv_cfg=None, conv_cfg: OptConfigType = None,
norm_cfg=None, norm_cfg: OptConfigType = None,
act_cfg=None, act_cfg: OptConfigType = None,
init_cfg=None, init_cfg: OptMultiConfig = None,
activate_out=True, activate_out: bool = True,
fuse_out=False, fuse_out: bool = False,
dropout_ratio=0, dropout_ratio: Union[int, float] = 0,
aligned=True, aligned: bool = True,
align_corners=True, align_corners: bool = True,
padding_mode='zeros', padding_mode: str = 'zeros',
lateral_conv=True): lateral_conv: bool = True) -> None:
super(PointFusion, self).__init__(init_cfg=init_cfg) super(PointFusion, self).__init__(init_cfg=init_cfg)
if isinstance(img_levels, int): if isinstance(img_levels, int):
img_levels = [img_levels] img_levels = [img_levels]
...@@ -225,18 +228,19 @@ class PointFusion(BaseModule): ...@@ -225,18 +228,19 @@ class PointFusion(BaseModule):
dict(type='Xavier', layer='Linear', distribution='uniform') dict(type='Xavier', layer='Linear', distribution='uniform')
] ]
def forward(self, img_feats, pts, pts_feats, img_metas): def forward(self, img_feats: List[Tensor], pts: List[Tensor],
pts_feats: Tensor, img_metas: List[dict]) -> Tensor:
"""Forward function. """Forward function.
Args: Args:
img_feats (list[torch.Tensor]): Image features. img_feats (List[Tensor]): Image features.
pts: [list[torch.Tensor]]: A batch of points with shape N x 3. pts (List[Tensor]): A batch of points with shape N x 3.
pts_feats (torch.Tensor): A tensor consist of point features of the pts_feats (Tensor): A tensor consisting of point features of the
total batch. total batch.
img_metas (list[dict]): Meta information of images. img_metas (List[dict]): Meta information of images.
Returns: Returns:
torch.Tensor: Fused features of each point. Tensor: Fused features of each point.
""" """
img_pts = self.obtain_mlvl_feats(img_feats, pts, img_metas) img_pts = self.obtain_mlvl_feats(img_feats, pts, img_metas)
img_pre_fuse = self.img_transform(img_pts) img_pre_fuse = self.img_transform(img_pts)
...@@ -252,17 +256,18 @@ class PointFusion(BaseModule): ...@@ -252,17 +256,18 @@ class PointFusion(BaseModule):
return fuse_out return fuse_out
def obtain_mlvl_feats(self, img_feats, pts, img_metas): def obtain_mlvl_feats(self, img_feats: List[Tensor], pts: List[Tensor],
img_metas: List[dict]) -> Tensor:
"""Obtain multi-level features for each point. """Obtain multi-level features for each point.
Args: Args:
img_feats (list(torch.Tensor)): Multi-scale image features produced img_feats (List[Tensor]): Multi-scale image features produced
by image backbone in shape (N, C, H, W). by image backbone in shape (N, C, H, W).
pts (list[torch.Tensor]): Points of each sample. pts (List[Tensor]): Points of each sample.
img_metas (list[dict]): Meta information for each sample. img_metas (List[dict]): Meta information for each sample.
Returns: Returns:
torch.Tensor: Corresponding image features of each point. Tensor: Corresponding image features of each point.
""" """
if self.lateral_convs is not None: if self.lateral_convs is not None:
img_ins = [ img_ins = [
...@@ -285,17 +290,17 @@ class PointFusion(BaseModule): ...@@ -285,17 +290,17 @@ class PointFusion(BaseModule):
img_pts = torch.cat(img_feats_per_point, dim=0) img_pts = torch.cat(img_feats_per_point, dim=0)
return img_pts return img_pts
def sample_single(self, img_feats, pts, img_meta): def sample_single(self, img_feats: Tensor, pts: Tensor,
img_meta: dict) -> Tensor:
"""Sample features from single level image feature map. """Sample features from single level image feature map.
Args: Args:
img_feats (torch.Tensor): Image feature map in shape img_feats (Tensor): Image feature map in shape (1, C, H, W).
(1, C, H, W). pts (Tensor): Points of a single sample.
pts (torch.Tensor): Points of a single sample.
img_meta (dict): Meta information of the single sample. img_meta (dict): Meta information of the single sample.
Returns: Returns:
torch.Tensor: Single level image features of each point. Tensor: Single level image features of each point.
""" """
# TODO: image transformation also extracted # TODO: image transformation also extracted
img_scale_factor = ( img_scale_factor = (
...@@ -324,49 +329,47 @@ class PointFusion(BaseModule): ...@@ -324,49 +329,47 @@ class PointFusion(BaseModule):
return img_pts return img_pts
def voxel_sample(voxel_features, def voxel_sample(voxel_features: Tensor,
voxel_range, voxel_range: List[float],
voxel_size, voxel_size: List[float],
depth_samples, depth_samples: Tensor,
proj_mat, proj_mat: Tensor,
downsample_factor, downsample_factor: int,
img_scale_factor, img_scale_factor: Tensor,
img_crop_offset, img_crop_offset: Tensor,
img_flip, img_flip: bool,
img_pad_shape, img_pad_shape: Tuple[int],
img_shape, img_shape: Tuple[int],
aligned=True, aligned: bool = True,
padding_mode='zeros', padding_mode: str = 'zeros',
align_corners=True): align_corners: bool = True) -> Tensor:
"""Obtain image features using points. """Obtain image features using points.
Args: Args:
voxel_features (torch.Tensor): 1 x C x Nx x Ny x Nz voxel features. voxel_features (Tensor): 1 x C x Nx x Ny x Nz voxel features.
voxel_range (list): The range of voxel features. voxel_range (List[float]): The range of voxel features.
voxel_size (:obj:`ConfigDict` or dict): The voxel size of voxel voxel_size (List[float]): The voxel size of voxel features.
features. depth_samples (Tensor): N depth samples in LiDAR coordinates.
depth_samples (torch.Tensor): N depth samples in LiDAR coordinates. proj_mat (Tensor): ORIGINAL LiDAR2img projection matrix for N views.
proj_mat (torch.Tensor): ORIGINAL LiDAR2img projection matrix
for N views.
downsample_factor (int): The downsample factor in rescaling. downsample_factor (int): The downsample factor in rescaling.
img_scale_factor (tuple[torch.Tensor]): Scale factor with shape of img_scale_factor (Tensor): Scale factor with shape of
(w_scale, h_scale). (w_scale, h_scale).
img_crop_offset (tuple[torch.Tensor]): Crop offset used to crop img_crop_offset (Tensor): Crop offset used to crop image during
image during data augmentation with shape of (w_offset, h_offset). data augmentation with shape of (w_offset, h_offset).
img_flip (bool): Whether the image is flipped. img_flip (bool): Whether the image is flipped.
img_pad_shape (tuple[int]): int tuple indicates the h & w after img_pad_shape (Tuple[int]): Int tuple indicates the h & w after
padding, this is necessary to obtain features in feature map. padding. This is necessary to obtain features in feature map.
img_shape (tuple[int]): int tuple indicates the h & w before padding img_shape (Tuple[int]): Int tuple indicates the h & w before padding
after scaling, this is necessary for flipping coordinates. after scaling. This is necessary for flipping coordinates.
aligned (bool, optional): Whether use bilinear interpolation when aligned (bool): Whether to use bilinear interpolation when
sampling image features for each point. Defaults to True. sampling image features for each point. Defaults to True.
padding_mode (str, optional): Padding mode when padding values for padding_mode (str): Padding mode when padding values for
features of out-of-image points. Defaults to 'zeros'. features of out-of-image points. Defaults to 'zeros'.
align_corners (bool, optional): Whether to align corners when align_corners (bool): Whether to align corners when
sampling image features for each point. Defaults to True. sampling image features for each point. Defaults to True.
Returns: Returns:
torch.Tensor: 1xCxDxHxW frustum features sampled from voxel features. Tensor: 1xCxDxHxW frustum features sampled from voxel features.
""" """
# construct frustum grid # construct frustum grid
device = voxel_features.device device = voxel_features.device
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple
import torch import torch
from torch import Tensor
from torch import nn as nn from torch import nn as nn
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS
...@@ -14,27 +17,33 @@ class VoteFusion(nn.Module): ...@@ -14,27 +17,33 @@ class VoteFusion(nn.Module):
"""Fuse 2d features from 3d seeds. """Fuse 2d features from 3d seeds.
Args: Args:
num_classes (int): number of classes. num_classes (int): Number of classes.
max_imvote_per_pixel (int): max number of imvotes. max_imvote_per_pixel (int): Max number of imvotes.
""" """
def __init__(self, num_classes=10, max_imvote_per_pixel=3): def __init__(self,
num_classes: int = 10,
max_imvote_per_pixel: int = 3) -> None:
super(VoteFusion, self).__init__() super(VoteFusion, self).__init__()
self.num_classes = num_classes self.num_classes = num_classes
self.max_imvote_per_pixel = max_imvote_per_pixel self.max_imvote_per_pixel = max_imvote_per_pixel
def forward(self, imgs, bboxes_2d_rescaled, seeds_3d_depth, img_metas): def forward(self, imgs: List[Tensor], bboxes_2d_rescaled: List[Tensor],
seeds_3d_depth: List[Tensor],
img_metas: List[dict]) -> Tuple[Tensor]:
"""Forward function. """Forward function.
Args: Args:
imgs (list[torch.Tensor]): Image features. imgs (List[Tensor]): Image features.
bboxes_2d_rescaled (list[torch.Tensor]): 2D bboxes. bboxes_2d_rescaled (List[Tensor]): 2D bboxes.
seeds_3d_depth (torch.Tensor): 3D seeds. seeds_3d_depth (List[Tensor]): 3D seeds.
img_metas (list[dict]): Meta information of images. img_metas (List[dict]): Meta information of images.
Returns: Returns:
torch.Tensor: Concatenated cues of each point. Tuple[Tensor]:
torch.Tensor: Validity mask of each feature.
- img_features: Concatenated cues of each point.
- masks: Validity mask of each feature.
""" """
img_features = [] img_features = []
masks = [] masks = []
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmengine.model import BaseModule from mmengine.model import BaseModule
from torch import Tensor
from torch import nn as nn from torch import nn as nn
from mmdet3d.utils import ConfigType, OptMultiConfig
class MLP(BaseModule): class MLP(BaseModule):
"""A simple MLP module. """A simple MLP module.
...@@ -10,26 +15,28 @@ class MLP(BaseModule): ...@@ -10,26 +15,28 @@ class MLP(BaseModule):
Pass features (B, C, N) through an MLP. Pass features (B, C, N) through an MLP.
Args: Args:
in_channels (int, optional): Number of channels of input features. in_channel (int): Number of channels of input features.
Default: 18. Defaults to 18.
conv_channels (tuple[int], optional): Out channels of the convolution. conv_channels (Tuple[int]): Out channels of the convolution.
Default: (256, 256). Defaults to (256, 256).
conv_cfg (dict, optional): Config of convolution. conv_cfg (:obj:`ConfigDict` or dict): Config dict for convolution
Default: dict(type='Conv1d'). layer. Defaults to dict(type='Conv1d').
norm_cfg (dict, optional): Config of normalization. norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
Default: dict(type='BN1d'). layer. Defaults to dict(type='BN1d').
act_cfg (dict, optional): Config of activation. act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer.
Default: dict(type='ReLU'). Defaults to dict(type='ReLU').
init_cfg (:obj:`ConfigDict` or dict or List[:obj:`ConfigDict` or dict],
optional): Initialization config dict. Defaults to None.
""" """
def __init__(self, def __init__(self,
in_channel=18, in_channel: int = 18,
conv_channels=(256, 256), conv_channels: Tuple[int] = (256, 256),
conv_cfg=dict(type='Conv1d'), conv_cfg: ConfigType = dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'), norm_cfg: ConfigType = dict(type='BN1d'),
act_cfg=dict(type='ReLU'), act_cfg: ConfigType = dict(type='ReLU'),
init_cfg=None): init_cfg: OptMultiConfig = None) -> None:
super().__init__(init_cfg=init_cfg) super(MLP, self).__init__(init_cfg=init_cfg)
self.mlp = nn.Sequential() self.mlp = nn.Sequential()
prev_channels = in_channel prev_channels = in_channel
for i, conv_channel in enumerate(conv_channels): for i, conv_channel in enumerate(conv_channels):
...@@ -47,5 +54,5 @@ class MLP(BaseModule): ...@@ -47,5 +54,5 @@ class MLP(BaseModule):
inplace=True)) inplace=True))
prev_channels = conv_channels[i] prev_channels = conv_channels[i]
def forward(self, img_features): def forward(self, img_features: Tensor) -> Tensor:
return self.mlp(img_features) return self.mlp(img_features)
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import torch import torch
from mmengine.registry import MODELS from mmengine.registry import MODELS
from torch import Tensor
from torch import distributed as dist from torch import distributed as dist
from torch import nn as nn from torch import nn as nn
from torch.autograd.function import Function from torch.autograd.function import Function
...@@ -9,7 +10,7 @@ from torch.autograd.function import Function ...@@ -9,7 +10,7 @@ from torch.autograd.function import Function
class AllReduce(Function): class AllReduce(Function):
@staticmethod @staticmethod
def forward(ctx, input): def forward(ctx, input: Tensor) -> Tensor:
input_list = [ input_list = [
torch.zeros_like(input) for k in range(dist.get_world_size()) torch.zeros_like(input) for k in range(dist.get_world_size())
] ]
...@@ -19,7 +20,7 @@ class AllReduce(Function): ...@@ -19,7 +20,7 @@ class AllReduce(Function):
return torch.sum(inputs, dim=0) return torch.sum(inputs, dim=0)
@staticmethod @staticmethod
def backward(ctx, grad_output): def backward(ctx, grad_output: Tensor) -> Tensor:
dist.all_reduce(grad_output, async_op=False) dist.all_reduce(grad_output, async_op=False)
return grad_output return grad_output
...@@ -43,20 +44,18 @@ class NaiveSyncBatchNorm1d(nn.BatchNorm1d): ...@@ -43,20 +44,18 @@ class NaiveSyncBatchNorm1d(nn.BatchNorm1d):
It is slower than `nn.SyncBatchNorm`. It is slower than `nn.SyncBatchNorm`.
""" """
def __init__(self, *args, **kwargs): def __init__(self, *args: list, **kwargs: dict) -> None:
super().__init__(*args, **kwargs) super(NaiveSyncBatchNorm1d, self).__init__(*args, **kwargs)
self.fp16_enabled = False
def forward(self, input): def forward(self, input: Tensor) -> Tensor:
""" """
Args: Args:
input (tensor): Has shape (N, C) or (N, C, L), where N is input (Tensor): Has shape (N, C) or (N, C, L), where N is
the batch size, C is the number of features or the batch size, C is the number of features or
channels, and L is the sequence length channels, and L is the sequence length
Returns: Returns:
tensor: Has shape (N, C) or (N, C, L), has same shape Tensor: Has shape (N, C) or (N, C, L), same shape as input.
as input.
""" """
assert input.dtype == torch.float32, \ assert input.dtype == torch.float32, \
f'input should be in float32 type, got {input.dtype}' f'input should be in float32 type, got {input.dtype}'
...@@ -112,17 +111,16 @@ class NaiveSyncBatchNorm2d(nn.BatchNorm2d): ...@@ -112,17 +111,16 @@ class NaiveSyncBatchNorm2d(nn.BatchNorm2d):
It is slower than `nn.SyncBatchNorm`. It is slower than `nn.SyncBatchNorm`.
""" """
def __init__(self, *args, **kwargs): def __init__(self, *args: list, **kwargs: dict) -> None:
super().__init__(*args, **kwargs) super(NaiveSyncBatchNorm2d, self).__init__(*args, **kwargs)
self.fp16_enabled = False
def forward(self, input): def forward(self, input: Tensor) -> Tensor:
""" """
Args: Args:
Input (tensor): Feature has shape (N, C, H, W). input (Tensor): Feature has shape (N, C, H, W).
Returns: Returns:
tensor: Has shape (N, C, H, W), same shape as input. Tensor: Has shape (N, C, H, W), same shape as input.
""" """
assert input.dtype == torch.float32, \ assert input.dtype == torch.float32, \
f'input should be in float32 type, got {input.dtype}' f'input should be in float32 type, got {input.dtype}'
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import copy import copy
from typing import List, Tuple, Union
import torch import torch
from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
from mmcv.ops import assign_score_withk as assign_score_cuda from mmcv.ops import assign_score_withk as assign_score_cuda
from mmengine.model import constant_init from mmengine.model import constant_init
from torch import Tensor
from torch import nn as nn from torch import nn as nn
from torch.nn import functional as F from torch.nn import functional as F
from mmdet3d.utils import ConfigType
from .utils import assign_kernel_withoutk, assign_score, calc_euclidian_dist from .utils import assign_kernel_withoutk, assign_score, calc_euclidian_dist
...@@ -17,33 +20,33 @@ class ScoreNet(nn.Module): ...@@ -17,33 +20,33 @@ class ScoreNet(nn.Module):
Args: Args:
mlp_channels (List[int]): Hidden unit sizes of SharedMLP layers. mlp_channels (List[int]): Hidden unit sizes of SharedMLP layers.
last_bn (bool, optional): Whether to use BN on the last output of mlps. last_bn (bool): Whether to use BN on the last output of mlps.
Defaults to False. Defaults to False.
score_norm (str, optional): Normalization function of output scores. score_norm (str): Normalization function of output scores.
Can be 'softmax', 'sigmoid' or 'identity'. Defaults to 'softmax'. Can be 'softmax', 'sigmoid' or 'identity'. Defaults to 'softmax'.
temp_factor (float, optional): Temperature factor to scale the output temp_factor (float): Temperature factor to scale the output
scores before softmax. Defaults to 1.0. scores before softmax. Defaults to 1.0.
norm_cfg (dict, optional): Type of normalization method. norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
Defaults to dict(type='BN2d'). layer. Defaults to dict(type='BN2d').
bias (bool | str, optional): If specified as `auto`, it will be decided bias (bool or str): If specified as `auto`, it will be decided by
by the norm_cfg. Bias will be set as True if `norm_cfg` is None, `norm_cfg`. `bias` will be set as True if `norm_cfg` is None,
otherwise False. Defaults to 'auto'. otherwise False. Defaults to 'auto'.
Note: Note:
The official code applies xavier_init to all Conv layers in ScoreNet, The official code applies xavier_init to all Conv layers in ScoreNet,
see `PAConv <https://github.com/CVMI-Lab/PAConv/blob/main/scene_seg see `PAConv <https://github.com/CVMI-Lab/PAConv/blob/main/scene_seg
/model/pointnet2/paconv.py#L105>`_. However in our experiments, we /model/pointnet2/paconv.py#L105>`_. However in our experiments, we
did not find much difference in applying such xavier initialization did not find much difference in applying such xavier initialization
or not. So we neglect this initialization in our implementation. or not. So we neglect this initialization in our implementation.
""" """
def __init__(self, def __init__(self,
mlp_channels, mlp_channels: List[int],
last_bn=False, last_bn: bool = False,
score_norm='softmax', score_norm: str = 'softmax',
temp_factor=1.0, temp_factor: float = 1.0,
norm_cfg=dict(type='BN2d'), norm_cfg: ConfigType = dict(type='BN2d'),
bias='auto'): bias: Union[bool, str] = 'auto') -> None:
super(ScoreNet, self).__init__() super(ScoreNet, self).__init__()
assert score_norm in ['softmax', 'sigmoid', 'identity'], \ assert score_norm in ['softmax', 'sigmoid', 'identity'], \
...@@ -79,16 +82,16 @@ class ScoreNet(nn.Module): ...@@ -79,16 +82,16 @@ class ScoreNet(nn.Module):
act_cfg=None, act_cfg=None,
bias=bias)) bias=bias))
def forward(self, xyz_features): def forward(self, xyz_features: Tensor) -> Tensor:
"""Forward. """Forward.
Args: Args:
xyz_features (torch.Tensor): (B, C, N, K), features constructed xyz_features (Tensor): (B, C, N, K) Features constructed from xyz
from xyz coordinates of point pairs. May contain relative coordinates of point pairs. May contain relative positions,
positions, Euclidean distance, etc. Euclidean distance, etc.
Returns: Returns:
torch.Tensor: (B, N, K, M), predicted scores for `M` kernels. Tensor: (B, N, K, M) Predicted scores for `M` kernels.
""" """
scores = self.mlps(xyz_features) # (B, M, N, K) scores = self.mlps(xyz_features) # (B, M, N, K)
...@@ -116,43 +119,49 @@ class PAConv(nn.Module): ...@@ -116,43 +119,49 @@ class PAConv(nn.Module):
in_channels (int): Input channels of point features. in_channels (int): Input channels of point features.
out_channels (int): Output channels of point features. out_channels (int): Output channels of point features.
num_kernels (int): Number of kernel weights in the weight bank. num_kernels (int): Number of kernel weights in the weight bank.
norm_cfg (dict, optional): Type of normalization method. norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
Defaults to dict(type='BN2d', momentum=0.1). layer. Defaults to dict(type='BN2d', momentum=0.1).
act_cfg (dict, optional): Type of activation method. act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer.
Defaults to dict(type='ReLU', inplace=True). Defaults to dict(type='ReLU', inplace=True).
scorenet_input (str, optional): Type of input to ScoreNet. scorenet_input (str): Type of input to ScoreNet.
Can be 'identity', 'w_neighbor' or 'w_neighbor_dist'. Can be 'identity', 'w_neighbor' or 'w_neighbor_dist'.
Defaults to 'w_neighbor_dist'. Defaults to 'w_neighbor_dist'.
weight_bank_init (str, optional): Init method of weight bank kernels. weight_bank_init (str): Init method of weight bank kernels.
Can be 'kaiming' or 'xavier'. Defaults to 'kaiming'. Can be 'kaiming' or 'xavier'. Defaults to 'kaiming'.
kernel_input (str, optional): Input features to be multiplied with kernel_input (str): Input features to be multiplied with kernel
kernel weights. Can be 'identity' or 'w_neighbor'. weights. Can be 'identity' or 'w_neighbor'.
Defaults to 'w_neighbor'. Defaults to 'w_neighbor'.
scorenet_cfg (dict, optional): Config of the ScoreNet module, which scorenet_cfg (dict): Config of the ScoreNet module, which may contain
may contain the following keys and values: the following keys and values:
- mlp_channels (List[int]): Hidden units of MLPs. - mlp_channels (List[int]): Hidden units of MLPs.
- score_norm (str): Normalization function of output scores. - score_norm (str): Normalization function of output scores.
Can be 'softmax', 'sigmoid' or 'identity'. Can be 'softmax', 'sigmoid' or 'identity'.
- temp_factor (float): Temperature factor to scale the output - temp_factor (float): Temperature factor to scale the output
scores before softmax. scores before softmax.
- last_bn (bool): Whether to use BN on the last output of mlps. - last_bn (bool): Whether to use BN on the last output of mlps.
Defaults to dict(mlp_channels=[16, 16, 16],
score_norm='softmax',
temp_factor=1.0,
last_bn=False).
""" """
def __init__(self, def __init__(
in_channels, self,
out_channels, in_channels: int,
num_kernels, out_channels: int,
norm_cfg=dict(type='BN2d', momentum=0.1), num_kernels: int,
act_cfg=dict(type='ReLU', inplace=True), norm_cfg: ConfigType = dict(type='BN2d', momentum=0.1),
scorenet_input='w_neighbor_dist', act_cfg: ConfigType = dict(type='ReLU', inplace=True),
weight_bank_init='kaiming', scorenet_input: str = 'w_neighbor_dist',
kernel_input='w_neighbor', weight_bank_init: str = 'kaiming',
scorenet_cfg=dict( kernel_input: str = 'w_neighbor',
mlp_channels=[16, 16, 16], scorenet_cfg: dict = dict(
score_norm='softmax', mlp_channels=[16, 16, 16],
temp_factor=1.0, score_norm='softmax',
last_bn=False)): temp_factor=1.0,
last_bn=False)
) -> None:
super(PAConv, self).__init__() super(PAConv, self).__init__()
# determine weight kernel size according to used features # determine weight kernel size according to used features
...@@ -218,21 +227,20 @@ class PAConv(nn.Module): ...@@ -218,21 +227,20 @@ class PAConv(nn.Module):
self.init_weights() self.init_weights()
def init_weights(self): def init_weights(self) -> None:
"""Initialize weights of shared MLP layers and BN layers.""" """Initialize weights of shared MLP layers and BN layers."""
if self.bn is not None: if self.bn is not None:
constant_init(self.bn, val=1, bias=0) constant_init(self.bn, val=1, bias=0)
def _prepare_scorenet_input(self, points_xyz): def _prepare_scorenet_input(self, points_xyz: Tensor) -> Tensor:
"""Prepare input point pairs features for self.ScoreNet. """Prepare input point pairs features for self.ScoreNet.
Args: Args:
points_xyz (torch.Tensor): (B, 3, npoint, K) points_xyz (Tensor): (B, 3, npoint, K) Coordinates of the
Coordinates of the grouped points. grouped points.
Returns: Returns:
torch.Tensor: (B, C, npoint, K) Tensor: (B, C, npoint, K) The generated features per point pair.
The generated features per point pair.
""" """
B, _, npoint, K = points_xyz.size() B, _, npoint, K = points_xyz.size()
center_xyz = points_xyz[..., :1].repeat(1, 1, 1, K) center_xyz = points_xyz[..., :1].repeat(1, 1, 1, K)
...@@ -250,22 +258,22 @@ class PAConv(nn.Module): ...@@ -250,22 +258,22 @@ class PAConv(nn.Module):
dim=1) dim=1)
return xyz_features return xyz_features
def forward(self, inputs): def forward(self, inputs: Tuple[Tensor]) -> Tuple[Tensor]:
"""Forward. """Forward.
Args: Args:
inputs (tuple(torch.Tensor)): inputs (Tuple[Tensor]):
- features (torch.Tensor): (B, in_c, npoint, K) - features (Tensor): (B, in_c, npoint, K)
Features of the queried points. Features of the queried points.
- points_xyz (torch.Tensor): (B, 3, npoint, K) - points_xyz (Tensor): (B, 3, npoint, K)
Coordinates of the grouped points. Coordinates of the grouped points.
Returns: Returns:
Tuple[torch.Tensor]: Tuple[Tensor]:
- new_features: (B, out_c, npoint, K), features after PAConv. - new_features: (B, out_c, npoint, K) Features after PAConv.
- points_xyz: same as input. - points_xyz: Same as input.
""" """
features, points_xyz = inputs features, points_xyz = inputs
B, _, npoint, K = features.size() B, _, npoint, K = features.size()
...@@ -315,20 +323,22 @@ class PAConvCUDA(PAConv): ...@@ -315,20 +323,22 @@ class PAConvCUDA(PAConv):
more detailed descriptions. more detailed descriptions.
""" """
def __init__(self, def __init__(
in_channels, self,
out_channels, in_channels: int,
num_kernels, out_channels: int,
norm_cfg=dict(type='BN2d', momentum=0.1), num_kernels: int,
act_cfg=dict(type='ReLU', inplace=True), norm_cfg: ConfigType = dict(type='BN2d', momentum=0.1),
scorenet_input='w_neighbor_dist', act_cfg: ConfigType = dict(type='ReLU', inplace=True),
weight_bank_init='kaiming', scorenet_input: str = 'w_neighbor_dist',
kernel_input='w_neighbor', weight_bank_init: str = 'kaiming',
scorenet_cfg=dict( kernel_input: str = 'w_neighbor',
mlp_channels=[8, 16, 16], scorenet_cfg: dict = dict(
score_norm='softmax', mlp_channels=[8, 16, 16],
temp_factor=1.0, score_norm='softmax',
last_bn=False)): temp_factor=1.0,
last_bn=False)
) -> None:
super(PAConvCUDA, self).__init__( super(PAConvCUDA, self).__init__(
in_channels=in_channels, in_channels=in_channels,
out_channels=out_channels, out_channels=out_channels,
...@@ -343,27 +353,27 @@ class PAConvCUDA(PAConv): ...@@ -343,27 +353,27 @@ class PAConvCUDA(PAConv):
assert self.kernel_input == 'w_neighbor', \ assert self.kernel_input == 'w_neighbor', \
'CUDA implemented PAConv only supports w_neighbor kernel_input' 'CUDA implemented PAConv only supports w_neighbor kernel_input'
def forward(self, inputs): def forward(self, inputs: Tuple[Tensor]) -> Tuple[Tensor]:
"""Forward. """Forward.
Args: Args:
inputs (tuple(torch.Tensor)): inputs (Tuple[Tensor]):
- features (torch.Tensor): (B, in_c, N) - features (Tensor): (B, in_c, N)
Features of all points in the current point cloud. Features of all points in the current point cloud.
Different from non-CUDA version PAConv, here the features Different from non-CUDA version PAConv, here the features
are not grouped by each center to form a K dim. are not grouped by each center to form a K dim.
- points_xyz (torch.Tensor): (B, 3, npoint, K) - points_xyz (Tensor): (B, 3, npoint, K)
Coordinates of the grouped points. Coordinates of the grouped points.
- points_idx (torch.Tensor): (B, npoint, K) - points_idx (Tensor): (B, npoint, K)
Index of the grouped points. Index of the grouped points.
Returns: Returns:
Tuple[torch.Tensor]: Tuple[Tensor]:
- new_features: (B, out_c, npoint, K), features after PAConv. - new_features: (B, out_c, npoint, K) Features after PAConv.
- points_xyz: same as input. - points_xyz: Same as input.
- points_idx: same as input. - points_idx: Same as input.
""" """
features, points_xyz, points_idx = inputs features, points_xyz, points_idx = inputs
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple
import torch import torch
from torch import Tensor
def calc_euclidian_dist(xyz1, xyz2): def calc_euclidian_dist(xyz1: Tensor, xyz2: Tensor) -> Tensor:
"""Calculate the Euclidean distance between two sets of points. """Calculate the Euclidean distance between two sets of points.
Args: Args:
xyz1 (torch.Tensor): (N, 3), the first set of points. xyz1 (Tensor): (N, 3) The first set of points.
xyz2 (torch.Tensor): (N, 3), the second set of points. xyz2 (Tensor): (N, 3) The second set of points.
Returns: Returns:
torch.Tensor: (N, ), the Euclidean distance between each point pair. Tensor: (N, ) The Euclidean distance between each point pair.
""" """
assert xyz1.shape[0] == xyz2.shape[0], 'number of points are not the same' assert xyz1.shape[0] == xyz2.shape[0], 'number of points are not the same'
assert xyz1.shape[1] == xyz2.shape[1] == 3, \ assert xyz1.shape[1] == xyz2.shape[1] == 3, \
...@@ -18,25 +21,25 @@ def calc_euclidian_dist(xyz1, xyz2): ...@@ -18,25 +21,25 @@ def calc_euclidian_dist(xyz1, xyz2):
return torch.norm(xyz1 - xyz2, dim=-1) return torch.norm(xyz1 - xyz2, dim=-1)
def assign_score(scores, point_features): def assign_score(scores: Tensor, point_features: Tensor) -> Tensor:
"""Perform weighted sum to aggregate output features according to scores. """Perform weighted sum to aggregate output features according to scores.
This function is used in non-CUDA version of PAConv. This function is used in non-CUDA version of PAConv.
Compared to the cuda op assign_score_withk, this pytorch implementation Compared to the cuda op assign_score_withk, this pytorch implementation
pre-computes output features for the neighbors of all centers, and then pre-computes output features for the neighbors of all centers, and then
performs aggregation. It consumes more GPU memories. performs aggregation. It consumes more GPU memories.
Args: Args:
scores (torch.Tensor): (B, npoint, K, M), predicted scores to scores (Tensor): (B, npoint, K, M) Predicted scores to
aggregate weight matrices in the weight bank. aggregate weight matrices in the weight bank.
`npoint` is the number of sampled centers. `npoint` is the number of sampled centers.
`K` is the number of queried neighbors. `K` is the number of queried neighbors.
`M` is the number of weight matrices in the weight bank. `M` is the number of weight matrices in the weight bank.
point_features (torch.Tensor): (B, npoint, K, M, out_dim) point_features (Tensor): (B, npoint, K, M, out_dim)
Pre-computed point features to be aggregated. Pre-computed point features to be aggregated.
Returns: Returns:
torch.Tensor: (B, npoint, K, out_dim), the aggregated features. Tensor: (B, npoint, K, out_dim) The aggregated features.
""" """
B, npoint, K, M = scores.size() B, npoint, K, M = scores.size()
scores = scores.view(B, npoint, K, 1, M) scores = scores.view(B, npoint, K, 1, M)
...@@ -44,21 +47,22 @@ def assign_score(scores, point_features): ...@@ -44,21 +47,22 @@ def assign_score(scores, point_features):
return output return output
def assign_kernel_withoutk(features, kernels, M): def assign_kernel_withoutk(features: Tensor, kernels: Tensor,
M: int) -> Tuple[Tensor]:
"""Pre-compute features with weight matrices in weight bank. This function """Pre-compute features with weight matrices in weight bank. This function
is used before cuda op assign_score_withk in CUDA version PAConv. is used before cuda op assign_score_withk in CUDA version PAConv.
Args: Args:
features (torch.Tensor): (B, in_dim, N), input features of all points. features (Tensor): (B, in_dim, N) Input features of all points.
`N` is the number of points in current point cloud. `N` is the number of points in current point cloud.
kernels (torch.Tensor): (2 * in_dim, M * out_dim), weight matrices in kernels (Tensor): (2 * in_dim, M * out_dim) Weight matrices in
the weight bank, transformed from (M, 2 * in_dim, out_dim). the weight bank, transformed from (M, 2 * in_dim, out_dim).
`2 * in_dim` is because the input features are concatenation of `2 * in_dim` is because the input features are concatenation of
(point_features - center_features, point_features). (point_features - center_features, point_features).
M (int): Number of weight matrices in the weight bank. M (int): Number of weight matrices in the weight bank.
Returns: Returns:
Tuple[torch.Tensor]: both of shape (B, N, M, out_dim): Tuple[Tensor]: Both of shape (B, N, M, out_dim).
- point_features: Pre-computed features for points. - point_features: Pre-computed features for points.
- center_features: Pre-computed features for centers. - center_features: Pre-computed features for centers.
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Union
from mmengine.registry import Registry from mmengine.registry import Registry
from torch import nn as nn
SA_MODULES = Registry('point_sa_module') SA_MODULES = Registry('point_sa_module')
def build_sa_module(cfg, *args, **kwargs): def build_sa_module(cfg: Union[dict, None], *args, **kwargs) -> nn.Module:
"""Build PointNet2 set abstraction (SA) module. """Build PointNet2 set abstraction (SA) module.
Args: Args:
cfg (None or dict): The SA module config, which should contain: cfg (dict or None): The SA module config, which should contain:
- type (str): Module type. - type (str): Module type.
- module args: Args needed to instantiate an SA module. - module args: Args needed to instantiate an SA module.
args (argument list): Arguments passed to the `__init__` args (argument list): Arguments passed to the `__init__`
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple, Union
import torch import torch
from torch import Tensor
from torch import nn as nn from torch import nn as nn
from mmdet3d.models.layers.paconv import PAConv, PAConvCUDA from mmdet3d.models.layers.paconv import PAConv, PAConvCUDA
from mmdet3d.utils import ConfigType
from .builder import SA_MODULES from .builder import SA_MODULES
from .point_sa_module import BasePointSAModule from .point_sa_module import BasePointSAModule
...@@ -16,52 +20,81 @@ class PAConvSAModuleMSG(BasePointSAModule): ...@@ -16,52 +20,81 @@ class PAConvSAModuleMSG(BasePointSAModule):
See the `paper <https://arxiv.org/abs/2103.14635>`_ for more details. See the `paper <https://arxiv.org/abs/2103.14635>`_ for more details.
Args: Args:
paconv_num_kernels (list[list[int]]): Number of kernel weights in the num_point (int): Number of points.
radii (List[float]): List of radius in each ball query.
sample_nums (List[int]): Number of samples in each ball query.
mlp_channels (List[List[int]]): Specification of the pointnet before
the global pooling for each scale.
paconv_num_kernels (List[List[int]]): Number of kernel weights in the
weight banks of each layer's PAConv. weight banks of each layer's PAConv.
paconv_kernel_input (str, optional): Input features to be multiplied fps_mod (List[str]): Type of FPS method, valid mod
['F-FPS', 'D-FPS', 'FS']. Defaults to ['D-FPS'].
- F-FPS: Using feature distances for FPS.
- D-FPS: Using Euclidean distances of points for FPS.
- FS: Using F-FPS and D-FPS simultaneously.
fps_sample_range_list (List[int]): Range of points to apply FPS.
Defaults to [-1].
dilated_group (bool): Whether to use dilated ball query.
Defaults to False.
norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
layer. Defaults to dict(type='BN2d', momentum=0.1).
use_xyz (bool): Whether to use xyz. Defaults to True.
pool_mod (str): Type of pooling method. Defaults to 'max'.
normalize_xyz (bool): Whether to normalize local XYZ with radius.
Defaults to False.
bias (bool or str): If specified as `auto`, it will be decided by
`norm_cfg`. `bias` will be set as True if `norm_cfg` is None,
otherwise False. Defaults to 'auto'.
paconv_kernel_input (str): Input features to be multiplied
with kernel weights. Can be 'identity' or 'w_neighbor'. with kernel weights. Can be 'identity' or 'w_neighbor'.
Defaults to 'w_neighbor'. Defaults to 'w_neighbor'.
scorenet_input (str, optional): Type of the input to ScoreNet. scorenet_input (str): Type of the input to ScoreNet.
Defaults to 'w_neighbor_dist'. Can be the following values: Defaults to 'w_neighbor_dist'. Can be the following values:
- 'identity': Use xyz coordinates as input. - 'identity': Use xyz coordinates as input.
- 'w_neighbor': Use xyz coordinates and the difference with center - 'w_neighbor': Use xyz coordinates and the difference with center
points as input. points as input.
- 'w_neighbor_dist': Use xyz coordinates, the difference with - 'w_neighbor_dist': Use xyz coordinates, the difference with
center points and the Euclidean distance as input. center points and the Euclidean distance as input.
scorenet_cfg (dict): Config of the ScoreNet module, which
scorenet_cfg (dict, optional): Config of the ScoreNet module, which
may contain the following keys and values: may contain the following keys and values:
- mlp_channels (List[int]): Hidden units of MLPs. - mlp_channels (List[int]): Hidden units of MLPs.
- score_norm (str): Normalization function of output scores. - score_norm (str): Normalization function of output scores.
Can be 'softmax', 'sigmoid' or 'identity'. Can be 'softmax', 'sigmoid' or 'identity'.
- temp_factor (float): Temperature factor to scale the output - temp_factor (float): Temperature factor to scale the output
scores before softmax. scores before softmax.
- last_bn (bool): Whether to use BN on the last output of mlps. - last_bn (bool): Whether to use BN on the last output of mlps.
Defaults to dict(mlp_channels=[16, 16, 16],
score_norm='softmax',
temp_factor=1.0,
last_bn=False).
""" """
def __init__(self, def __init__(
num_point, self,
radii, num_point: int,
sample_nums, radii: List[float],
mlp_channels, sample_nums: List[int],
paconv_num_kernels, mlp_channels: List[List[int]],
fps_mod=['D-FPS'], paconv_num_kernels: List[List[int]],
fps_sample_range_list=[-1], fps_mod: List[str] = ['D-FPS'],
dilated_group=False, fps_sample_range_list: List[int] = [-1],
norm_cfg=dict(type='BN2d', momentum=0.1), dilated_group: bool = False,
use_xyz=True, norm_cfg: ConfigType = dict(type='BN2d', momentum=0.1),
pool_mod='max', use_xyz: bool = True,
normalize_xyz=False, pool_mod: str = 'max',
bias='auto', normalize_xyz: bool = False,
paconv_kernel_input='w_neighbor', bias: Union[bool, str] = 'auto',
scorenet_input='w_neighbor_dist', paconv_kernel_input: str = 'w_neighbor',
scorenet_cfg=dict( scorenet_input: str = 'w_neighbor_dist',
mlp_channels=[16, 16, 16], scorenet_cfg: dict = dict(
score_norm='softmax', mlp_channels=[16, 16, 16],
temp_factor=1.0, score_norm='softmax',
last_bn=False)): temp_factor=1.0,
last_bn=False)
) -> None:
super(PAConvSAModuleMSG, self).__init__( super(PAConvSAModuleMSG, self).__init__(
num_point=num_point, num_point=num_point,
radii=radii, radii=radii,
...@@ -114,25 +147,27 @@ class PAConvSAModule(PAConvSAModuleMSG): ...@@ -114,25 +147,27 @@ class PAConvSAModule(PAConvSAModuleMSG):
<https://arxiv.org/abs/2103.14635>`_ for more details. <https://arxiv.org/abs/2103.14635>`_ for more details.
""" """
def __init__(self, def __init__(
mlp_channels, self,
paconv_num_kernels, mlp_channels: List[int],
num_point=None, paconv_num_kernels: List[int],
radius=None, num_point: Optional[int] = None,
num_sample=None, radius: Optional[float] = None,
norm_cfg=dict(type='BN2d', momentum=0.1), num_sample: Optional[int] = None,
use_xyz=True, norm_cfg: ConfigType = dict(type='BN2d', momentum=0.1),
pool_mod='max', use_xyz: bool = True,
fps_mod=['D-FPS'], pool_mod: str = 'max',
fps_sample_range_list=[-1], fps_mod: List[str] = ['D-FPS'],
normalize_xyz=False, fps_sample_range_list: List[int] = [-1],
paconv_kernel_input='w_neighbor', normalize_xyz: bool = False,
scorenet_input='w_neighbor_dist', paconv_kernel_input: str = 'w_neighbor',
scorenet_cfg=dict( scorenet_input: str = 'w_neighbor_dist',
mlp_channels=[16, 16, 16], scorenet_cfg: dict = dict(
score_norm='softmax', mlp_channels=[16, 16, 16],
temp_factor=1.0, score_norm='softmax',
last_bn=False)): temp_factor=1.0,
last_bn=False)
) -> None:
super(PAConvSAModule, self).__init__( super(PAConvSAModule, self).__init__(
mlp_channels=[mlp_channels], mlp_channels=[mlp_channels],
paconv_num_kernels=[paconv_num_kernels], paconv_num_kernels=[paconv_num_kernels],
...@@ -160,27 +195,29 @@ class PAConvCUDASAModuleMSG(BasePointSAModule): ...@@ -160,27 +195,29 @@ class PAConvCUDASAModuleMSG(BasePointSAModule):
for more details. for more details.
""" """
def __init__(self, def __init__(
num_point, self,
radii, num_point: int,
sample_nums, radii: List[float],
mlp_channels, sample_nums: List[int],
paconv_num_kernels, mlp_channels: List[List[int]],
fps_mod=['D-FPS'], paconv_num_kernels: List[List[int]],
fps_sample_range_list=[-1], fps_mod: List[str] = ['D-FPS'],
dilated_group=False, fps_sample_range_list: List[int] = [-1],
norm_cfg=dict(type='BN2d', momentum=0.1), dilated_group: bool = False,
use_xyz=True, norm_cfg: ConfigType = dict(type='BN2d', momentum=0.1),
pool_mod='max', use_xyz: bool = True,
normalize_xyz=False, pool_mod: str = 'max',
bias='auto', normalize_xyz: bool = False,
paconv_kernel_input='w_neighbor', bias: Union[bool, str] = 'auto',
scorenet_input='w_neighbor_dist', paconv_kernel_input: str = 'w_neighbor',
scorenet_cfg=dict( scorenet_input: str = 'w_neighbor_dist',
mlp_channels=[8, 16, 16], scorenet_cfg: dict = dict(
score_norm='softmax', mlp_channels=[8, 16, 16],
temp_factor=1.0, score_norm='softmax',
last_bn=False)): temp_factor=1.0,
last_bn=False)
) -> None:
super(PAConvCUDASAModuleMSG, self).__init__( super(PAConvCUDASAModuleMSG, self).__init__(
num_point=num_point, num_point=num_point,
radii=radii, radii=radii,
...@@ -230,29 +267,31 @@ class PAConvCUDASAModuleMSG(BasePointSAModule): ...@@ -230,29 +267,31 @@ class PAConvCUDASAModuleMSG(BasePointSAModule):
def forward( def forward(
self, self,
points_xyz, points_xyz: Tensor,
features=None, features: Optional[Tensor] = None,
indices=None, indices: Optional[Tensor] = None,
target_xyz=None, target_xyz: Optional[Tensor] = None,
): ) -> Tuple[Tensor]:
"""forward. """Forward.
Args: Args:
points_xyz (Tensor): (B, N, 3) xyz coordinates of the features. points_xyz (Tensor): (B, N, 3) xyz coordinates of the features.
features (Tensor, optional): (B, C, N) features of each point. features (Tensor, optional): (B, C, N) features of each point.
Default: None. Defaults to None.
indices (Tensor, optional): (B, num_point) Index of the features. indices (Tensor, optional): (B, num_point) Index of the features.
Default: None. Defaults to None.
target_xyz (Tensor, optional): (B, M, 3) new coords of the outputs. target_xyz (Tensor, optional): (B, M, 3) new coords of the outputs.
Default: None. Defaults to None.
Returns: Returns:
Tensor: (B, M, 3) where M is the number of points. Tuple[Tensor]:
New features xyz.
Tensor: (B, M, sum_k(mlps[k][-1])) where M is the number - new_xyz: (B, M, 3) where M is the number of points.
of points. New feature descriptors. New features xyz.
Tensor: (B, M) where M is the number of points. - new_features: (B, M, sum_k(mlps[k][-1])) where M is the
Index of the features. number of points. New feature descriptors.
- indices: (B, M) where M is the number of points.
Index of the features.
""" """
new_features_list = [] new_features_list = []
...@@ -306,25 +345,27 @@ class PAConvCUDASAModule(PAConvCUDASAModuleMSG): ...@@ -306,25 +345,27 @@ class PAConvCUDASAModule(PAConvCUDASAModuleMSG):
for more details. for more details.
""" """
def __init__(self, def __init__(
mlp_channels, self,
paconv_num_kernels, mlp_channels: List[int],
num_point=None, paconv_num_kernels: List[int],
radius=None, num_point: Optional[int] = None,
num_sample=None, radius: Optional[float] = None,
norm_cfg=dict(type='BN2d', momentum=0.1), num_sample: Optional[int] = None,
use_xyz=True, norm_cfg: ConfigType = dict(type='BN2d', momentum=0.1),
pool_mod='max', use_xyz: bool = True,
fps_mod=['D-FPS'], pool_mod: str = 'max',
fps_sample_range_list=[-1], fps_mod: List[str] = ['D-FPS'],
normalize_xyz=False, fps_sample_range_list: List[int] = [-1],
paconv_kernel_input='w_neighbor', normalize_xyz: bool = False,
scorenet_input='w_neighbor_dist', paconv_kernel_input: str = 'w_neighbor',
scorenet_cfg=dict( scorenet_input: str = 'w_neighbor_dist',
mlp_channels=[8, 16, 16], scorenet_cfg: dict = dict(
score_norm='softmax', mlp_channels=[8, 16, 16],
temp_factor=1.0, score_norm='softmax',
last_bn=False)): temp_factor=1.0,
last_bn=False)
) -> None:
super(PAConvCUDASAModule, self).__init__( super(PAConvCUDASAModule, self).__init__(
mlp_channels=[mlp_channels], mlp_channels=[mlp_channels],
paconv_num_kernels=[paconv_num_kernels], paconv_num_kernels=[paconv_num_kernels],
......
...@@ -5,8 +5,11 @@ import torch ...@@ -5,8 +5,11 @@ import torch
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmcv.ops import three_interpolate, three_nn from mmcv.ops import three_interpolate, three_nn
from mmengine.model import BaseModule from mmengine.model import BaseModule
from torch import Tensor
from torch import nn as nn from torch import nn as nn
from mmdet3d.utils import ConfigType, OptMultiConfig
class PointFPModule(BaseModule): class PointFPModule(BaseModule):
"""Point feature propagation module used in PointNets. """Point feature propagation module used in PointNets.
...@@ -15,16 +18,17 @@ class PointFPModule(BaseModule): ...@@ -15,16 +18,17 @@ class PointFPModule(BaseModule):
Args: Args:
mlp_channels (list[int]): List of mlp channels. mlp_channels (list[int]): List of mlp channels.
norm_cfg (dict, optional): Type of normalization method. norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
Default: dict(type='BN2d'). layer. Defaults to dict(type='BN2d').
init_cfg (:obj:`ConfigDict` or dict or List[:obj:`ConfigDict` or dict],
optional): Initialization config dict. Defaults to None.
""" """
def __init__(self, def __init__(self,
mlp_channels: List[int], mlp_channels: List[int],
norm_cfg: dict = dict(type='BN2d'), norm_cfg: ConfigType = dict(type='BN2d'),
init_cfg=None): init_cfg: OptMultiConfig = None) -> None:
super().__init__(init_cfg=init_cfg) super(PointFPModule, self).__init__(init_cfg=init_cfg)
self.fp16_enabled = False
self.mlps = nn.Sequential() self.mlps = nn.Sequential()
for i in range(len(mlp_channels) - 1): for i in range(len(mlp_channels) - 1):
self.mlps.add_module( self.mlps.add_module(
...@@ -37,23 +41,22 @@ class PointFPModule(BaseModule): ...@@ -37,23 +41,22 @@ class PointFPModule(BaseModule):
conv_cfg=dict(type='Conv2d'), conv_cfg=dict(type='Conv2d'),
norm_cfg=norm_cfg)) norm_cfg=norm_cfg))
def forward(self, target: torch.Tensor, source: torch.Tensor, def forward(self, target: Tensor, source: Tensor, target_feats: Tensor,
target_feats: torch.Tensor, source_feats: Tensor) -> Tensor:
source_feats: torch.Tensor) -> torch.Tensor: """Forward.
"""forward.
Args: Args:
target (Tensor): (B, n, 3) tensor of the xyz positions of target (Tensor): (B, n, 3) Tensor of the xyz positions of
the target features. the target features.
source (Tensor): (B, m, 3) tensor of the xyz positions of source (Tensor): (B, m, 3) Tensor of the xyz positions of
the source features. the source features.
target_feats (Tensor): (B, C1, n) tensor of the features to be target_feats (Tensor): (B, C1, n) Tensor of the features to be
propagated to. propagated to.
source_feats (Tensor): (B, C2, m) tensor of features source_feats (Tensor): (B, C2, m) Tensor of features
to be propagated. to be propagated.
Returns: Returns:
Tensor: (B, M, N) M = mlp[-1], tensor of the target features. Tensor: (B, M, N) M = mlp[-1], Tensor of the target features.
""" """
if source is not None: if source is not None:
dist, idx = three_nn(target, source) dist, idx = three_nn(target, source)
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple, Union
import torch import torch
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmcv.ops import GroupAll from mmcv.ops import GroupAll
from mmcv.ops import PointsSampler as Points_Sampler from mmcv.ops import PointsSampler as Points_Sampler
from mmcv.ops import QueryAndGroup, gather_points from mmcv.ops import QueryAndGroup, gather_points
from torch import Tensor
from torch import nn as nn from torch import nn as nn
from torch.nn import functional as F from torch.nn import functional as F
from mmdet3d.models.layers import PAConv from mmdet3d.models.layers import PAConv
from mmdet3d.utils import ConfigType
from .builder import SA_MODULES from .builder import SA_MODULES
...@@ -16,44 +20,43 @@ class BasePointSAModule(nn.Module): ...@@ -16,44 +20,43 @@ class BasePointSAModule(nn.Module):
Args: Args:
num_point (int): Number of points. num_point (int): Number of points.
radii (list[float]): List of radius in each ball query. radii (List[float]): List of radius in each ball query.
sample_nums (list[int]): Number of samples in each ball query. sample_nums (List[int]): Number of samples in each ball query.
mlp_channels (list[list[int]]): Specification of the pointnet before mlp_channels (List[List[int]]): Specification of the pointnet before
the global pooling for each scale. the global pooling for each scale.
fps_mod (list[str], optional): Type of FPS method, valid mod fps_mod (List[str]): Type of FPS method, valid mod
['F-FPS', 'D-FPS', 'FS'], Default: ['D-FPS']. ['F-FPS', 'D-FPS', 'FS']. Defaults to ['D-FPS'].
F-FPS: using feature distances for FPS.
D-FPS: using Euclidean distances of points for FPS. - F-FPS: using feature distances for FPS.
FS: using F-FPS and D-FPS simultaneously. - D-FPS: using Euclidean distances of points for FPS.
fps_sample_range_list (list[int], optional): - FS: using F-FPS and D-FPS simultaneously.
Range of points to apply FPS. Default: [-1]. fps_sample_range_list (List[int]): Range of points to apply FPS.
dilated_group (bool, optional): Whether to use dilated ball query. Defaults to [-1].
Default: False. dilated_group (bool): Whether to use dilated ball query.
use_xyz (bool, optional): Whether to use xyz. Defaults to False.
Default: True. use_xyz (bool): Whether to use xyz. Defaults to True.
pool_mod (str, optional): Type of pooling method. pool_mod (str): Type of pooling method. Defaults to 'max'.
Default: 'max_pool'. normalize_xyz (bool): Whether to normalize local XYZ with radius.
normalize_xyz (bool, optional): Whether to normalize local XYZ Defaults to False.
with radius. Default: False. grouper_return_grouped_xyz (bool): Whether to return grouped xyz
grouper_return_grouped_xyz (bool, optional): Whether to return in `QueryAndGroup`. Defaults to False.
grouped xyz in `QueryAndGroup`. Defaults to False. grouper_return_grouped_idx (bool): Whether to return grouped idx
grouper_return_grouped_idx (bool, optional): Whether to return in `QueryAndGroup`. Defaults to False.
grouped idx in `QueryAndGroup`. Defaults to False.
""" """
def __init__(self, def __init__(self,
num_point, num_point: int,
radii, radii: List[float],
sample_nums, sample_nums: List[int],
mlp_channels, mlp_channels: List[List[int]],
fps_mod=['D-FPS'], fps_mod: List[str] = ['D-FPS'],
fps_sample_range_list=[-1], fps_sample_range_list: List[int] = [-1],
dilated_group=False, dilated_group: bool = False,
use_xyz=True, use_xyz: bool = True,
pool_mod='max', pool_mod: str = 'max',
normalize_xyz=False, normalize_xyz: bool = False,
grouper_return_grouped_xyz=False, grouper_return_grouped_xyz: bool = False,
grouper_return_grouped_idx=False): grouper_return_grouped_idx: bool = False) -> None:
super(BasePointSAModule, self).__init__() super(BasePointSAModule, self).__init__()
assert len(radii) == len(sample_nums) == len(mlp_channels) assert len(radii) == len(sample_nums) == len(mlp_channels)
...@@ -109,7 +112,8 @@ class BasePointSAModule(nn.Module): ...@@ -109,7 +112,8 @@ class BasePointSAModule(nn.Module):
grouper = GroupAll(use_xyz) grouper = GroupAll(use_xyz)
self.groupers.append(grouper) self.groupers.append(grouper)
def _sample_points(self, points_xyz, features, indices, target_xyz): def _sample_points(self, points_xyz: Tensor, features: Tensor,
indices: Tensor, target_xyz: Tensor) -> Tuple[Tensor]:
"""Perform point sampling based on inputs. """Perform point sampling based on inputs.
If `indices` is specified, directly sample corresponding points. If `indices` is specified, directly sample corresponding points.
...@@ -118,13 +122,15 @@ class BasePointSAModule(nn.Module): ...@@ -118,13 +122,15 @@ class BasePointSAModule(nn.Module):
Args: Args:
points_xyz (Tensor): (B, N, 3) xyz coordinates of the features. points_xyz (Tensor): (B, N, 3) xyz coordinates of the features.
features (Tensor): (B, C, N) features of each point. features (Tensor): (B, C, N) Features of each point.
indices (Tensor): (B, num_point) Index of the features. indices (Tensor): (B, num_point) Index of the features.
target_xyz (Tensor): (B, M, 3) new_xyz coordinates of the outputs. target_xyz (Tensor): (B, M, 3) new_xyz coordinates of the outputs.
Returns: Returns:
Tensor: (B, num_point, 3) sampled xyz coordinates of points. Tuple[Tensor]:
Tensor: (B, num_point) sampled points' index.
- new_xyz: (B, num_point, 3) Sampled xyz coordinates of points.
- indices: (B, num_point) Sampled points' index.
""" """
xyz_flipped = points_xyz.transpose(1, 2).contiguous() xyz_flipped = points_xyz.transpose(1, 2).contiguous()
if indices is not None: if indices is not None:
...@@ -143,16 +149,15 @@ class BasePointSAModule(nn.Module): ...@@ -143,16 +149,15 @@ class BasePointSAModule(nn.Module):
return new_xyz, indices return new_xyz, indices
def _pool_features(self, features): def _pool_features(self, features: Tensor) -> Tensor:
"""Perform feature aggregation using pooling operation. """Perform feature aggregation using pooling operation.
Args: Args:
features (torch.Tensor): (B, C, N, K) features (Tensor): (B, C, N, K) Features of locally grouped
Features of locally grouped points before pooling. points before pooling.
Returns: Returns:
torch.Tensor: (B, C, N) Tensor: (B, C, N) Pooled features aggregating local information.
Pooled features aggregating local information.
""" """
if self.pool_mod == 'max': if self.pool_mod == 'max':
# (B, C, N, 1) # (B, C, N, 1)
...@@ -169,29 +174,31 @@ class BasePointSAModule(nn.Module): ...@@ -169,29 +174,31 @@ class BasePointSAModule(nn.Module):
def forward( def forward(
self, self,
points_xyz, points_xyz: Tensor,
features=None, features: Optional[Tensor] = None,
indices=None, indices: Optional[Tensor] = None,
target_xyz=None, target_xyz: Optional[Tensor] = None,
): ) -> Tuple[Tensor]:
"""forward. """Forward.
Args: Args:
points_xyz (Tensor): (B, N, 3) xyz coordinates of the features. points_xyz (Tensor): (B, N, 3) xyz coordinates of the features.
features (Tensor, optional): (B, C, N) features of each point. features (Tensor, optional): (B, C, N) Features of each point.
Default: None. Defaults to None.
indices (Tensor, optional): (B, num_point) Index of the features. indices (Tensor, optional): (B, num_point) Index of the features.
Default: None. Defaults to None.
target_xyz (Tensor, optional): (B, M, 3) new coords of the outputs. target_xyz (Tensor, optional): (B, M, 3) New coords of the outputs.
Default: None. Defaults to None.
Returns: Returns:
Tensor: (B, M, 3) where M is the number of points. Tuple[Tensor]:
New features xyz.
Tensor: (B, M, sum_k(mlps[k][-1])) where M is the number - new_xyz: (B, M, 3) Where M is the number of points.
of points. New feature descriptors. New features xyz.
Tensor: (B, M) where M is the number of points. - new_features: (B, M, sum_k(mlps[k][-1])) Where M is the
Index of the features. number of points. New feature descriptors.
- indices: (B, M) Where M is the number of points.
Index of the features.
""" """
new_features_list = [] new_features_list = []
...@@ -229,45 +236,44 @@ class PointSAModuleMSG(BasePointSAModule): ...@@ -229,45 +236,44 @@ class PointSAModuleMSG(BasePointSAModule):
Args: Args:
num_point (int): Number of points. num_point (int): Number of points.
radii (list[float]): List of radius in each ball query. radii (List[float]): List of radius in each ball query.
sample_nums (list[int]): Number of samples in each ball query. sample_nums (List[int]): Number of samples in each ball query.
mlp_channels (list[list[int]]): Specification of the pointnet before mlp_channels (List[List[int]]): Specification of the pointnet before
the global pooling for each scale. the global pooling for each scale.
fps_mod (list[str], optional): Type of FPS method, valid mod fps_mod (List[str]): Type of FPS method, valid mod
['F-FPS', 'D-FPS', 'FS'], Default: ['D-FPS']. ['F-FPS', 'D-FPS', 'FS']. Defaults to ['D-FPS'].
F-FPS: using feature distances for FPS.
D-FPS: using Euclidean distances of points for FPS. - F-FPS: using feature distances for FPS.
FS: using F-FPS and D-FPS simultaneously. - D-FPS: using Euclidean distances of points for FPS.
fps_sample_range_list (list[int], optional): Range of points to - FS: using F-FPS and D-FPS simultaneously.
apply FPS. Default: [-1]. fps_sample_range_list (List[int]): Range of points to apply FPS.
dilated_group (bool, optional): Whether to use dilated ball query. Defaults to [-1].
Default: False. dilated_group (bool): Whether to use dilated ball query.
norm_cfg (dict, optional): Type of normalization method. Defaults to False.
Default: dict(type='BN2d'). norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
use_xyz (bool, optional): Whether to use xyz. layer. Defaults to dict(type='BN2d').
Default: True. use_xyz (bool): Whether to use xyz. Defaults to True.
pool_mod (str, optional): Type of pooling method. pool_mod (str): Type of pooling method. Defaults to 'max'.
Default: 'max_pool'. normalize_xyz (bool): Whether to normalize local XYZ with radius.
normalize_xyz (bool, optional): Whether to normalize local XYZ Defaults to False.
with radius. Default: False. bias (bool or str): If specified as `auto`, it will be decided by
bias (bool | str, optional): If specified as `auto`, it will be `norm_cfg`. `bias` will be set as True if `norm_cfg` is None,
decided by `norm_cfg`. `bias` will be set as True if otherwise False. Defaults to 'auto'.
`norm_cfg` is None, otherwise False. Default: 'auto'.
""" """
def __init__(self, def __init__(self,
num_point, num_point: int,
radii, radii: List[float],
sample_nums, sample_nums: List[int],
mlp_channels, mlp_channels: List[List[int]],
fps_mod=['D-FPS'], fps_mod: List[str] = ['D-FPS'],
fps_sample_range_list=[-1], fps_sample_range_list: List[int] = [-1],
dilated_group=False, dilated_group: bool = False,
norm_cfg=dict(type='BN2d'), norm_cfg: ConfigType = dict(type='BN2d'),
use_xyz=True, use_xyz: bool = True,
pool_mod='max', pool_mod: str = 'max',
normalize_xyz=False, normalize_xyz: bool = False,
bias='auto'): bias: Union[bool, str] = 'auto') -> None:
super(PointSAModuleMSG, self).__init__( super(PointSAModuleMSG, self).__init__(
num_point=num_point, num_point=num_point,
radii=radii, radii=radii,
...@@ -306,39 +312,35 @@ class PointSAModule(PointSAModuleMSG): ...@@ -306,39 +312,35 @@ class PointSAModule(PointSAModuleMSG):
PointNets. PointNets.
Args: Args:
mlp_channels (list[int]): Specify of the pointnet before mlp_channels (List[int]): Specify of the pointnet before
the global pooling for each scale. the global pooling for each scale.
num_point (int, optional): Number of points. num_point (int, optional): Number of points. Defaults to None.
Default: None. radius (float, optional): Radius to group with. Defaults to None.
radius (float, optional): Radius to group with.
Default: None.
num_sample (int, optional): Number of samples in each ball query. num_sample (int, optional): Number of samples in each ball query.
Default: None. Defaults to None.
norm_cfg (dict, optional): Type of normalization method. norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
Default: dict(type='BN2d'). layer. Defaults to dict(type='BN2d').
use_xyz (bool, optional): Whether to use xyz. use_xyz (bool): Whether to use xyz. Defaults to True.
Default: True. pool_mod (str): Type of pooling method. Defaults to 'max'.
pool_mod (str, optional): Type of pooling method. fps_mod (List[str]): Type of FPS method, valid mod
Default: 'max_pool'. ['F-FPS', 'D-FPS', 'FS']. Defaults to ['D-FPS'].
fps_mod (list[str], optional): Type of FPS method, valid mod fps_sample_range_list (List[int]): Range of points to apply FPS.
['F-FPS', 'D-FPS', 'FS'], Default: ['D-FPS']. Defaults to [-1].
fps_sample_range_list (list[int], optional): Range of points normalize_xyz (bool): Whether to normalize local XYZ with radius.
to apply FPS. Default: [-1]. Defaults to False.
normalize_xyz (bool, optional): Whether to normalize local XYZ
with radius. Default: False.
""" """
def __init__(self, def __init__(self,
mlp_channels, mlp_channels: List[int],
num_point=None, num_point: Optional[int] = None,
radius=None, radius: Optional[float] = None,
num_sample=None, num_sample: Optional[int] = None,
norm_cfg=dict(type='BN2d'), norm_cfg: ConfigType = dict(type='BN2d'),
use_xyz=True, use_xyz: bool = True,
pool_mod='max', pool_mod: str = 'max',
fps_mod=['D-FPS'], fps_mod: List[str] = ['D-FPS'],
fps_sample_range_list=[-1], fps_sample_range_list: List[int] = [-1],
normalize_xyz=False): normalize_xyz: bool = False) -> None:
super(PointSAModule, self).__init__( super(PointSAModule, self).__init__(
mlp_channels=[mlp_channels], mlp_channels=[mlp_channels],
num_point=num_point, num_point=num_point,
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple, Union
from mmcv.cnn import build_conv_layer, build_norm_layer from mmcv.cnn import build_conv_layer, build_norm_layer
from mmdet.models.backbones.resnet import BasicBlock, Bottleneck from mmdet.models.backbones.resnet import BasicBlock, Bottleneck
from torch import nn from torch import nn
from mmdet3d.utils import OptConfigType
from .spconv import IS_SPCONV2_AVAILABLE from .spconv import IS_SPCONV2_AVAILABLE
if IS_SPCONV2_AVAILABLE: if IS_SPCONV2_AVAILABLE:
from spconv.pytorch import SparseModule, SparseSequential from spconv.pytorch import SparseConvTensor, SparseModule, SparseSequential
else: else:
from mmcv.ops import SparseModule, SparseSequential from mmcv.ops import SparseConvTensor, SparseModule, SparseSequential
def replace_feature(out, new_features): def replace_feature(out: SparseConvTensor,
new_features: SparseConvTensor) -> SparseConvTensor:
if 'replace_feature' in out.__dir__(): if 'replace_feature' in out.__dir__():
# spconv 2.x behaviour # spconv 2.x behaviour
return out.replace_feature(new_features) return out.replace_feature(new_features)
...@@ -26,25 +30,26 @@ class SparseBottleneck(Bottleneck, SparseModule): ...@@ -26,25 +30,26 @@ class SparseBottleneck(Bottleneck, SparseModule):
Bottleneck block implemented with submanifold sparse convolution. Bottleneck block implemented with submanifold sparse convolution.
Args: Args:
inplanes (int): inplanes of block. inplanes (int): Inplanes of block.
planes (int): planes of block. planes (int): Planes of block.
stride (int, optional): stride of the first block. Default: 1. stride (int or Tuple[int]): Stride of the first block. Defaults to 1.
downsample (Module, optional): down sample module for block. downsample (Module, optional): Down sample module for block.
conv_cfg (dict, optional): dictionary to construct and config conv Defaults to None.
layer. Default: None. conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
norm_cfg (dict, optional): dictionary to construct and config norm convolution layer. Defaults to None.
layer. Default: dict(type='BN'). norm_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
normalization layer. Defaults to None.
""" """
expansion = 4 expansion = 4
def __init__(self, def __init__(self,
inplanes, inplanes: int,
planes, planes: int,
stride=1, stride: Union[int, Tuple[int]] = 1,
downsample=None, downsample: nn.Module = None,
conv_cfg=None, conv_cfg: OptConfigType = None,
norm_cfg=None): norm_cfg: OptConfigType = None) -> None:
SparseModule.__init__(self) SparseModule.__init__(self)
Bottleneck.__init__( Bottleneck.__init__(
...@@ -56,7 +61,7 @@ class SparseBottleneck(Bottleneck, SparseModule): ...@@ -56,7 +61,7 @@ class SparseBottleneck(Bottleneck, SparseModule):
conv_cfg=conv_cfg, conv_cfg=conv_cfg,
norm_cfg=norm_cfg) norm_cfg=norm_cfg)
def forward(self, x): def forward(self, x: SparseConvTensor) -> SparseConvTensor:
identity = x.features identity = x.features
out = self.conv1(x) out = self.conv1(x)
...@@ -85,25 +90,26 @@ class SparseBasicBlock(BasicBlock, SparseModule): ...@@ -85,25 +90,26 @@ class SparseBasicBlock(BasicBlock, SparseModule):
Sparse basic block implemented with submanifold sparse convolution. Sparse basic block implemented with submanifold sparse convolution.
Args: Args:
inplanes (int): inplanes of block. inplanes (int): Inplanes of block.
planes (int): planes of block. planes (int): Planes of block.
stride (int, optional): stride of the first block. Default: 1. stride (int or Tuple[int]): Stride of the first block. Defaults to 1.
downsample (Module, optional): down sample module for block. downsample (Module, optional): Down sample module for block.
conv_cfg (dict, optional): dictionary to construct and config conv Defaults to None.
layer. Default: None. conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
norm_cfg (dict, optional): dictionary to construct and config norm convolution layer. Defaults to None.
layer. Default: dict(type='BN'). norm_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
normalization layer. Defaults to None.
""" """
expansion = 1 expansion = 1
def __init__(self, def __init__(self,
inplanes, inplanes: int,
planes, planes: int,
stride=1, stride: Union[int, Tuple[int]] = 1,
downsample=None, downsample: nn.Module = None,
conv_cfg=None, conv_cfg: OptConfigType = None,
norm_cfg=None): norm_cfg: OptConfigType = None) -> None:
SparseModule.__init__(self) SparseModule.__init__(self)
BasicBlock.__init__( BasicBlock.__init__(
self, self,
...@@ -114,7 +120,7 @@ class SparseBasicBlock(BasicBlock, SparseModule): ...@@ -114,7 +120,7 @@ class SparseBasicBlock(BasicBlock, SparseModule):
conv_cfg=conv_cfg, conv_cfg=conv_cfg,
norm_cfg=norm_cfg) norm_cfg=norm_cfg)
def forward(self, x): def forward(self, x: SparseConvTensor) -> SparseConvTensor:
identity = x.features identity = x.features
assert x.features.dim() == 2, f'x.features.dim()={x.features.dim()}' assert x.features.dim() == 2, f'x.features.dim()={x.features.dim()}'
...@@ -134,29 +140,33 @@ class SparseBasicBlock(BasicBlock, SparseModule): ...@@ -134,29 +140,33 @@ class SparseBasicBlock(BasicBlock, SparseModule):
return out return out
def make_sparse_convmodule(in_channels, def make_sparse_convmodule(
out_channels, in_channels: int,
kernel_size, out_channels: int,
indice_key, kernel_size: Union[int, Tuple[int]],
stride=1, indice_key: str,
padding=0, stride: Union[int, Tuple[int]] = 1,
conv_type='SubMConv3d', padding: Union[int, Tuple[int]] = 0,
norm_cfg=None, conv_type: str = 'SubMConv3d',
order=('conv', 'norm', 'act')): norm_cfg: OptConfigType = None,
order: Tuple[str] = ('conv', 'norm', 'act')
) -> SparseSequential:
"""Make sparse convolution module. """Make sparse convolution module.
Args: Args:
in_channels (int): the number of input channels in_channels (int): The number of input channels.
out_channels (int): the number of out channels out_channels (int): The number of out channels.
kernel_size (int|tuple(int)): kernel size of convolution kernel_size (int | Tuple[int]): Kernel size of convolution.
indice_key (str): the indice key used for sparse tensor indice_key (str): The indice key used for sparse tensor.
stride (int|tuple(int)): the stride of convolution stride (int or tuple[int]): The stride of convolution.
padding (int or list[int]): the padding number of input padding (int or tuple[int]): The padding number of input.
conv_type (str): sparse conv type in spconv conv_type (str): Sparse conv type in spconv. Defaults to 'SubMConv3d'.
norm_cfg (dict[str]): config of normalization layer norm_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
order (tuple[str]): The order of conv/norm/activation layers. It is a normalization layer. Defaults to None.
order (Tuple[str]): The order of conv/norm/activation layers. It is a
sequence of "conv", "norm" and "act". Common examples are sequence of "conv", "norm" and "act". Common examples are
("conv", "norm", "act") and ("act", "conv", "norm"). ("conv", "norm", "act") and ("act", "conv", "norm").
Defaults to ('conv', 'norm', 'act').
Returns: Returns:
spconv.SparseSequential: sparse convolution module. spconv.SparseSequential: sparse convolution module.
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import itertools import itertools
from typing import List, OrderedDict
from mmengine.registry import MODELS from mmengine.registry import MODELS
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
def register_spconv2(): def register_spconv2() -> bool:
"""This func registers spconv2.0 spconv ops to overwrite the default mmcv """This func registers spconv2.0 spconv ops to overwrite the default mmcv
spconv ops.""" spconv ops."""
try: try:
...@@ -39,8 +40,10 @@ def register_spconv2(): ...@@ -39,8 +40,10 @@ def register_spconv2():
return True return True
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, def _load_from_state_dict(self, state_dict: OrderedDict, prefix: str,
missing_keys, unexpected_keys, error_msgs): local_metadata: dict, strict: bool,
missing_keys: List[str], unexpected_keys: List[str],
error_msgs: List[str]) -> None:
"""Rewrite this func to compat the convolutional kernel weights between """Rewrite this func to compat the convolutional kernel weights between
spconv 1.x in MMCV and 2.x in spconv2.x. spconv 1.x in MMCV and 2.x in spconv2.x.
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional
from mmcv.cnn.bricks.transformer import MultiheadAttention from mmcv.cnn.bricks.transformer import MultiheadAttention
from mmengine.registry import MODELS from mmengine.registry import MODELS
from torch import Tensor
from torch import nn as nn from torch import nn as nn
from mmdet3d.utils import ConfigType, OptMultiConfig
@MODELS.register_module() @MODELS.register_module()
class GroupFree3DMHA(MultiheadAttention): class GroupFree3DMHA(MultiheadAttention):
...@@ -15,40 +20,42 @@ class GroupFree3DMHA(MultiheadAttention): ...@@ -15,40 +20,42 @@ class GroupFree3DMHA(MultiheadAttention):
embed_dims (int): The embedding dimension. embed_dims (int): The embedding dimension.
num_heads (int): Parallel attention heads. Same as num_heads (int): Parallel attention heads. Same as
`nn.MultiheadAttention`. `nn.MultiheadAttention`.
attn_drop (float, optional): A Dropout layer on attn_output_weights. attn_drop (float): A Dropout layer on attn_output_weights.
Defaults to 0.0. Defaults to 0.0.
proj_drop (float, optional): A Dropout layer. Defaults to 0.0. proj_drop (float): A Dropout layer. Defaults to 0.0.
dropout_layer (obj:`ConfigDict`, optional): The dropout_layer used dropout_layer (ConfigType): The dropout_layer used when adding
when adding the shortcut. the shortcut. Defaults to dict(type='DropOut', drop_prob=0.).
init_cfg (obj:`mmengine.ConfigDict`, optional): The Config for init_cfg (:obj:`ConfigDict` or dict or List[:obj:`ConfigDict` or dict],
initialization. Default: None. optional): Initialization config dict. Defaults to None.
batch_first (bool, optional): Key, Query and Value are shape of batch_first (bool): Key, Query and Value are shape of
(batch, n, embed_dim) (batch, n, embed_dim) or (n, batch, embed_dim).
or (n, batch, embed_dim). Defaults to False. Defaults to False.
""" """
def __init__(self, def __init__(self,
embed_dims, embed_dims: int,
num_heads, num_heads: int,
attn_drop=0., attn_drop: float = 0.,
proj_drop=0., proj_drop: float = 0.,
dropout_layer=dict(type='DropOut', drop_prob=0.), dropout_layer: ConfigType = dict(
init_cfg=None, type='DropOut', drop_prob=0.),
batch_first=False, init_cfg: OptMultiConfig = None,
**kwargs): batch_first: bool = False,
super().__init__(embed_dims, num_heads, attn_drop, proj_drop, **kwargs) -> None:
dropout_layer, init_cfg, batch_first, **kwargs) super(GroupFree3DMHA,
self).__init__(embed_dims, num_heads, attn_drop, proj_drop,
dropout_layer, init_cfg, batch_first, **kwargs)
def forward(self, def forward(self,
query, query: Tensor,
key, key: Tensor,
value, value: Tensor,
identity, identity: Tensor,
query_pos=None, query_pos: Optional[Tensor] = None,
key_pos=None, key_pos: Optional[Tensor] = None,
attn_mask=None, attn_mask: Optional[Tensor] = None,
key_padding_mask=None, key_padding_mask: Optional[Tensor] = None,
**kwargs): **kwargs) -> Tensor:
"""Forward function for `GroupFree3DMHA`. """Forward function for `GroupFree3DMHA`.
**kwargs allow passing a more general data flow when combining **kwargs allow passing a more general data flow when combining
...@@ -81,7 +88,7 @@ class GroupFree3DMHA(MultiheadAttention): ...@@ -81,7 +88,7 @@ class GroupFree3DMHA(MultiheadAttention):
Defaults to None. Defaults to None.
Returns: Returns:
Tensor: forwarded results with shape [num_queries, bs, embed_dims]. Tensor: Forwarded results with shape [num_queries, bs, embed_dims].
""" """
if hasattr(self, 'operation_name'): if hasattr(self, 'operation_name'):
...@@ -113,26 +120,26 @@ class ConvBNPositionalEncoding(nn.Module): ...@@ -113,26 +120,26 @@ class ConvBNPositionalEncoding(nn.Module):
"""Absolute position embedding with Conv learning. """Absolute position embedding with Conv learning.
Args: Args:
input_channel (int): input features dim. input_channel (int): Input features dim.
num_pos_feats (int, optional): output position features dim. num_pos_feats (int): Output position features dim.
Defaults to 288 to be consistent with seed features dim. Defaults to 288 to be consistent with seed features dim.
""" """
def __init__(self, input_channel, num_pos_feats=288): def __init__(self, input_channel: int, num_pos_feats: int = 288) -> None:
super().__init__() super(ConvBNPositionalEncoding, self).__init__()
self.position_embedding_head = nn.Sequential( self.position_embedding_head = nn.Sequential(
nn.Conv1d(input_channel, num_pos_feats, kernel_size=1), nn.Conv1d(input_channel, num_pos_feats, kernel_size=1),
nn.BatchNorm1d(num_pos_feats), nn.ReLU(inplace=True), nn.BatchNorm1d(num_pos_feats), nn.ReLU(inplace=True),
nn.Conv1d(num_pos_feats, num_pos_feats, kernel_size=1)) nn.Conv1d(num_pos_feats, num_pos_feats, kernel_size=1))
def forward(self, xyz): def forward(self, xyz: Tensor) -> Tensor:
"""Forward pass. """Forward pass.
Args: Args:
xyz (Tensor) (B, N, 3) the coordinates to embed. xyz (Tensor): (B, N, 3) The coordinates to embed.
Returns: Returns:
Tensor: (B, num_pos_feats, N) the embedded position features. Tensor: (B, num_pos_feats, N) The embedded position features.
""" """
xyz = xyz.permute(0, 2, 1) xyz = xyz.permute(0, 2, 1)
position_embedding = self.position_embedding_head(xyz) position_embedding = self.position_embedding_head(xyz)
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple
import torch import torch
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmengine import is_tuple_of from mmengine import is_tuple_of
from torch import Tensor
from torch import nn as nn from torch import nn as nn
from mmdet3d.models.builder import build_loss from mmdet3d.registry import MODELS
from mmdet3d.utils import ConfigType, OptConfigType
class VoteModule(nn.Module): class VoteModule(nn.Module):
...@@ -14,41 +18,41 @@ class VoteModule(nn.Module): ...@@ -14,41 +18,41 @@ class VoteModule(nn.Module):
Args: Args:
in_channels (int): Number of channels of seed point features. in_channels (int): Number of channels of seed point features.
vote_per_seed (int, optional): Number of votes generated from vote_per_seed (int): Number of votes generated from each seed point.
each seed point. Default: 1. Defaults to 1.
gt_per_seed (int, optional): Number of ground truth votes generated gt_per_seed (int): Number of ground truth votes generated from each
from each seed point. Default: 3. seed point. Defaults to 3.
num_points (int, optional): Number of points to be used for voting. num_points (int): Number of points to be used for voting.
Default: 1. Defaults to -1.
conv_channels (tuple[int], optional): Out channels of vote conv_channels (tuple[int]): Out channels of vote generating
generating convolution. Default: (16, 16). convolution. Defaults to (16, 16).
conv_cfg (dict, optional): Config of convolution. conv_cfg (:obj:`ConfigDict` or dict): Config dict for convolution
Default: dict(type='Conv1d'). layer. Defaults to dict(type='Conv1d').
norm_cfg (dict, optional): Config of normalization. norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
Default: dict(type='BN1d'). layer. Defaults to dict(type='BN1d').
norm_feats (bool, optional): Whether to normalize features. norm_feats (bool): Whether to normalize features. Defaults to True.
Default: True. with_res_feat (bool): Whether to predict residual features.
with_res_feat (bool, optional): Whether to predict residual features. Defaults to True.
Default: True. vote_xyz_range (List[float], optional): The range of points
vote_xyz_range (list[float], optional): translation. Defaults to None.
The range of points translation. Default: None. vote_loss (:obj:`ConfigDict` or dict, optional): Config of vote loss.
vote_loss (dict, optional): Config of vote loss. Default: None. Defaults to None.
""" """
def __init__(self, def __init__(self,
in_channels, in_channels: int,
vote_per_seed=1, vote_per_seed: int = 1,
gt_per_seed=3, gt_per_seed: int = 3,
num_points=-1, num_points: int = -1,
conv_channels=(16, 16), conv_channels: Tuple[int] = (16, 16),
conv_cfg=dict(type='Conv1d'), conv_cfg: ConfigType = dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'), norm_cfg: ConfigType = dict(type='BN1d'),
act_cfg=dict(type='ReLU'), act_cfg: ConfigType = dict(type='ReLU'),
norm_feats=True, norm_feats: bool = True,
with_res_feat=True, with_res_feat: bool = True,
vote_xyz_range=None, vote_xyz_range: List[float] = None,
vote_loss=None): vote_loss: OptConfigType = None) -> None:
super().__init__() super(VoteModule, self).__init__()
self.in_channels = in_channels self.in_channels = in_channels
self.vote_per_seed = vote_per_seed self.vote_per_seed = vote_per_seed
self.gt_per_seed = gt_per_seed self.gt_per_seed = gt_per_seed
...@@ -60,7 +64,7 @@ class VoteModule(nn.Module): ...@@ -60,7 +64,7 @@ class VoteModule(nn.Module):
self.vote_xyz_range = vote_xyz_range self.vote_xyz_range = vote_xyz_range
if vote_loss is not None: if vote_loss is not None:
self.vote_loss = build_loss(vote_loss) self.vote_loss = MODELS.build(vote_loss)
prev_channels = in_channels prev_channels = in_channels
vote_conv_list = list() vote_conv_list = list()
...@@ -86,23 +90,24 @@ class VoteModule(nn.Module): ...@@ -86,23 +90,24 @@ class VoteModule(nn.Module):
out_channel = 3 * self.vote_per_seed out_channel = 3 * self.vote_per_seed
self.conv_out = nn.Conv1d(prev_channels, out_channel, 1) self.conv_out = nn.Conv1d(prev_channels, out_channel, 1)
def forward(self, seed_points, seed_feats): def forward(self, seed_points: Tensor,
"""forward. seed_feats: Tensor) -> Tuple[Tensor]:
"""Forward.
Args: Args:
seed_points (torch.Tensor): Coordinate of the seed seed_points (Tensor): Coordinate of the seed points in shape
points in shape (B, N, 3). (B, N, 3).
seed_feats (torch.Tensor): Features of the seed points in shape seed_feats (Tensor): Features of the seed points in shape
(B, C, N). (B, C, N).
Returns: Returns:
tuple[torch.Tensor]: Tuple[torch.Tensor]:
- vote_points: Voted xyz based on the seed points - vote_points: Voted xyz based on the seed points
with shape (B, M, 3), ``M=num_seed*vote_per_seed``. with shape (B, M, 3), ``M=num_seed*vote_per_seed``.
- vote_features: Voted features based on the seed points with - vote_features: Voted features based on the seed points with
shape (B, C, M) where ``M=num_seed*vote_per_seed``, shape (B, C, M) where ``M=num_seed*vote_per_seed``,
``C=vote_feature_dim``. ``C=vote_feature_dim``.
""" """
if self.num_points != -1: if self.num_points != -1:
assert self.num_points < seed_points.shape[1], \ assert self.num_points < seed_points.shape[1], \
...@@ -150,19 +155,20 @@ class VoteModule(nn.Module): ...@@ -150,19 +155,20 @@ class VoteModule(nn.Module):
vote_feats = seed_feats vote_feats = seed_feats
return vote_points, vote_feats, offset return vote_points, vote_feats, offset
def get_loss(self, seed_points, vote_points, seed_indices, def get_loss(self, seed_points: Tensor, vote_points: Tensor,
vote_targets_mask, vote_targets): seed_indices: Tensor, vote_targets_mask: Tensor,
vote_targets: Tensor) -> Tensor:
"""Calculate loss of voting module. """Calculate loss of voting module.
Args: Args:
seed_points (torch.Tensor): Coordinate of the seed points. seed_points (Tensor): Coordinate of the seed points.
vote_points (torch.Tensor): Coordinate of the vote points. vote_points (Tensor): Coordinate of the vote points.
seed_indices (torch.Tensor): Indices of seed points in raw points. seed_indices (Tensor): Indices of seed points in raw points.
vote_targets_mask (torch.Tensor): Mask of valid vote targets. vote_targets_mask (Tensor): Mask of valid vote targets.
vote_targets (torch.Tensor): Targets of votes. vote_targets (Tensor): Targets of votes.
Returns: Returns:
torch.Tensor: Weighted vote loss. Tensor: Weighted vote loss.
""" """
batch_size, num_seed = seed_points.shape[:2] batch_size, num_seed = seed_points.shape[:2]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment