Unverified Commit d7067e44 authored by Wenwei Zhang's avatar Wenwei Zhang Committed by GitHub
Browse files

Bump version to v1.1.0rc2

Bump to v1.1.0rc2
parents 28fe73d2 fb0e57e5
...@@ -9,6 +9,7 @@ from mmcv.cnn.bricks.transformer import (build_positional_encoding, ...@@ -9,6 +9,7 @@ from mmcv.cnn.bricks.transformer import (build_positional_encoding,
build_transformer_layer) build_transformer_layer)
from mmcv.ops import PointsSampler as Points_Sampler from mmcv.ops import PointsSampler as Points_Sampler
from mmcv.ops import gather_points from mmcv.ops import gather_points
from mmdet.models.utils import multi_apply
from mmengine.model import BaseModule, xavier_init from mmengine.model import BaseModule, xavier_init
from mmengine.structures import InstanceData from mmengine.structures import InstanceData
from torch import Tensor from torch import Tensor
...@@ -19,7 +20,6 @@ from mmdet3d.models.layers import aligned_3d_nms ...@@ -19,7 +20,6 @@ from mmdet3d.models.layers import aligned_3d_nms
from mmdet3d.registry import MODELS, TASK_UTILS from mmdet3d.registry import MODELS, TASK_UTILS
from mmdet3d.structures import BaseInstance3DBoxes, Det3DDataSample from mmdet3d.structures import BaseInstance3DBoxes, Det3DDataSample
from mmdet3d.structures.det3d_data_sample import SampleList from mmdet3d.structures.det3d_data_sample import SampleList
from mmdet.models.utils import multi_apply
from .base_conv_bbox_head import BaseConvBboxHead from .base_conv_bbox_head import BaseConvBboxHead
EPS = 1e-6 EPS = 1e-6
......
...@@ -2,6 +2,11 @@ ...@@ -2,6 +2,11 @@
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import torch import torch
from mmdet.models.utils import (gaussian_radius, gen_gaussian_target,
multi_apply)
from mmdet.models.utils.gaussian_target import (get_local_maximum,
get_topk_from_heatmap,
transpose_and_gather_feat)
from mmengine.config import ConfigDict from mmengine.config import ConfigDict
from mmengine.model import xavier_init from mmengine.model import xavier_init
from mmengine.structures import InstanceData from mmengine.structures import InstanceData
...@@ -15,11 +20,6 @@ from mmdet3d.models.utils import (filter_outside_objs, get_edge_indices, ...@@ -15,11 +20,6 @@ from mmdet3d.models.utils import (filter_outside_objs, get_edge_indices,
handle_proj_objs) handle_proj_objs)
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS
from mmdet3d.structures import Det3DDataSample from mmdet3d.structures import Det3DDataSample
from mmdet.models.utils import (gaussian_radius, gen_gaussian_target,
multi_apply)
from mmdet.models.utils.gaussian_target import (get_local_maximum,
get_topk_from_heatmap,
transpose_and_gather_feat)
from .anchor_free_mono3d_head import AnchorFreeMono3DHead from .anchor_free_mono3d_head import AnchorFreeMono3DHead
......
...@@ -4,6 +4,8 @@ from typing import List, Optional, Tuple ...@@ -4,6 +4,8 @@ from typing import List, Optional, Tuple
import numpy as np import numpy as np
import torch import torch
from mmcv.cnn import Scale from mmcv.cnn import Scale
from mmdet.models.utils import multi_apply
from mmdet.structures.bbox import distance2bbox
from mmengine.model import bias_init_with_prob, normal_init from mmengine.model import bias_init_with_prob, normal_init
from mmengine.structures import InstanceData from mmengine.structures import InstanceData
from torch import Tensor from torch import Tensor
...@@ -15,8 +17,6 @@ from mmdet3d.registry import MODELS ...@@ -15,8 +17,6 @@ from mmdet3d.registry import MODELS
from mmdet3d.structures import points_cam2img, points_img2cam, xywhr2xyxyr from mmdet3d.structures import points_cam2img, points_img2cam, xywhr2xyxyr
from mmdet3d.utils.typing import (ConfigType, InstanceList, OptConfigType, from mmdet3d.utils.typing import (ConfigType, InstanceList, OptConfigType,
OptInstanceList) OptInstanceList)
from mmdet.models.utils import multi_apply
from mmdet.structures.bbox import distance2bbox
from .fcos_mono3d_head import FCOSMono3DHead from .fcos_mono3d_head import FCOSMono3DHead
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
import torch import torch
from mmdet.models.utils import multi_apply
from mmengine.model import BaseModule from mmengine.model import BaseModule
from mmengine.structures import InstanceData from mmengine.structures import InstanceData
from torch import Tensor from torch import Tensor
...@@ -15,7 +16,6 @@ from mmdet3d.structures.bbox_3d import (BaseInstance3DBoxes, ...@@ -15,7 +16,6 @@ from mmdet3d.structures.bbox_3d import (BaseInstance3DBoxes,
LiDARInstance3DBoxes) LiDARInstance3DBoxes)
from mmdet3d.structures.det3d_data_sample import SampleList from mmdet3d.structures.det3d_data_sample import SampleList
from mmdet3d.utils.typing import InstanceList from mmdet3d.utils.typing import InstanceList
from mmdet.models.utils import multi_apply
@MODELS.register_module() @MODELS.register_module()
......
...@@ -5,6 +5,7 @@ from typing import Dict, List, Optional, Tuple ...@@ -5,6 +5,7 @@ from typing import Dict, List, Optional, Tuple
import numpy as np import numpy as np
import torch import torch
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmdet.models.utils import multi_apply
from mmengine.model import BaseModule from mmengine.model import BaseModule
from mmengine.structures import InstanceData from mmengine.structures import InstanceData
from torch import Tensor from torch import Tensor
...@@ -14,7 +15,6 @@ from mmdet3d.models.layers import box3d_multiclass_nms ...@@ -14,7 +15,6 @@ from mmdet3d.models.layers import box3d_multiclass_nms
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS
from mmdet3d.structures import limit_period, xywhr2xyxyr from mmdet3d.structures import limit_period, xywhr2xyxyr
from mmdet3d.utils import InstanceList, OptInstanceList from mmdet3d.utils import InstanceList, OptInstanceList
from mmdet.models.utils import multi_apply
from ..builder import build_head from ..builder import build_head
from .anchor3d_head import Anchor3DHead from .anchor3d_head import Anchor3DHead
......
...@@ -2,6 +2,11 @@ ...@@ -2,6 +2,11 @@
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
from mmdet.models.utils import (gaussian_radius, gen_gaussian_target,
multi_apply)
from mmdet.models.utils.gaussian_target import (get_local_maximum,
get_topk_from_heatmap,
transpose_and_gather_feat)
from mmengine.structures import InstanceData from mmengine.structures import InstanceData
from torch import Tensor from torch import Tensor
from torch.nn import functional as F from torch.nn import functional as F
...@@ -9,11 +14,6 @@ from torch.nn import functional as F ...@@ -9,11 +14,6 @@ from torch.nn import functional as F
from mmdet3d.registry import MODELS, TASK_UTILS from mmdet3d.registry import MODELS, TASK_UTILS
from mmdet3d.utils import (ConfigType, InstanceList, OptConfigType, from mmdet3d.utils import (ConfigType, InstanceList, OptConfigType,
OptInstanceList, OptMultiConfig) OptInstanceList, OptMultiConfig)
from mmdet.models.utils import (gaussian_radius, gen_gaussian_target,
multi_apply)
from mmdet.models.utils.gaussian_target import (get_local_maximum,
get_topk_from_heatmap,
transpose_and_gather_feat)
from .anchor_free_mono3d_head import AnchorFreeMono3DHead from .anchor_free_mono3d_head import AnchorFreeMono3DHead
......
...@@ -3,6 +3,7 @@ from typing import List, Optional, Tuple, Union ...@@ -3,6 +3,7 @@ from typing import List, Optional, Tuple, Union
import torch import torch
from mmcv.ops.nms import batched_nms from mmcv.ops.nms import batched_nms
from mmdet.models.utils import multi_apply
from mmengine import ConfigDict from mmengine import ConfigDict
from mmengine.structures import InstanceData from mmengine.structures import InstanceData
from torch import Tensor from torch import Tensor
...@@ -13,7 +14,6 @@ from mmdet3d.structures import BaseInstance3DBoxes ...@@ -13,7 +14,6 @@ from mmdet3d.structures import BaseInstance3DBoxes
from mmdet3d.structures.bbox_3d import (DepthInstance3DBoxes, from mmdet3d.structures.bbox_3d import (DepthInstance3DBoxes,
LiDARInstance3DBoxes, LiDARInstance3DBoxes,
rotation_3d_in_axis) rotation_3d_in_axis)
from mmdet.models.utils import multi_apply
from .vote_head import VoteHead from .vote_head import VoteHead
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import numpy as np import numpy as np
import torch import torch
from mmdet.models.utils import images_to_levels, multi_apply
from mmengine.structures import InstanceData from mmengine.structures import InstanceData
from mmdet3d.structures import limit_period from mmdet3d.structures import limit_period
from mmdet.models.utils import images_to_levels, multi_apply
class AnchorTrainMixin(object): class AnchorTrainMixin(object):
......
...@@ -4,6 +4,7 @@ from typing import Dict, List, Optional, Tuple, Union ...@@ -4,6 +4,7 @@ from typing import Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
from mmcv.ops import furthest_point_sample from mmcv.ops import furthest_point_sample
from mmdet.models.utils import multi_apply
from mmengine import ConfigDict from mmengine import ConfigDict
from mmengine.model import BaseModule from mmengine.model import BaseModule
from mmengine.structures import InstanceData from mmengine.structures import InstanceData
...@@ -14,7 +15,6 @@ from mmdet3d.models.layers import VoteModule, aligned_3d_nms, build_sa_module ...@@ -14,7 +15,6 @@ from mmdet3d.models.layers import VoteModule, aligned_3d_nms, build_sa_module
from mmdet3d.models.losses import chamfer_distance from mmdet3d.models.losses import chamfer_distance
from mmdet3d.registry import MODELS, TASK_UTILS from mmdet3d.registry import MODELS, TASK_UTILS
from mmdet3d.structures import Det3DDataSample from mmdet3d.structures import Det3DDataSample
from mmdet.models.utils import multi_apply
from .base_conv_bbox_head import BaseConvBboxHead from .base_conv_bbox_head import BaseConvBboxHead
......
...@@ -8,11 +8,13 @@ from .groupfree3dnet import GroupFree3DNet ...@@ -8,11 +8,13 @@ from .groupfree3dnet import GroupFree3DNet
from .h3dnet import H3DNet from .h3dnet import H3DNet
from .imvotenet import ImVoteNet from .imvotenet import ImVoteNet
from .imvoxelnet import ImVoxelNet from .imvoxelnet import ImVoxelNet
from .mink_single_stage import MinkSingleStage3DDetector
from .multiview_dfm import MultiViewDfM from .multiview_dfm import MultiViewDfM
from .mvx_faster_rcnn import DynamicMVXFasterRCNN, MVXFasterRCNN from .mvx_faster_rcnn import DynamicMVXFasterRCNN, MVXFasterRCNN
from .mvx_two_stage import MVXTwoStageDetector from .mvx_two_stage import MVXTwoStageDetector
from .parta2 import PartA2 from .parta2 import PartA2
from .point_rcnn import PointRCNN from .point_rcnn import PointRCNN
from .pv_rcnn import PointVoxelRCNN
from .sassd import SASSD from .sassd import SASSD
from .single_stage_mono3d import SingleStageMono3DDetector from .single_stage_mono3d import SingleStageMono3DDetector
from .smoke_mono3d import SMOKEMono3D from .smoke_mono3d import SMOKEMono3D
...@@ -21,25 +23,10 @@ from .votenet import VoteNet ...@@ -21,25 +23,10 @@ from .votenet import VoteNet
from .voxelnet import VoxelNet from .voxelnet import VoxelNet
__all__ = [ __all__ = [
'Base3DDetector', 'Base3DDetector', 'VoxelNet', 'DynamicVoxelNet', 'MVXTwoStageDetector',
'DfM', 'DynamicMVXFasterRCNN', 'MVXFasterRCNN', 'PartA2', 'VoteNet', 'H3DNet',
'VoxelNet', 'CenterPoint', 'SSD3DNet', 'ImVoteNet', 'SingleStageMono3DDetector',
'DynamicVoxelNet', 'FCOSMono3D', 'ImVoxelNet', 'GroupFree3DNet', 'PointRCNN', 'SMOKEMono3D',
'MVXTwoStageDetector', 'SASSD', 'MinkSingleStage3DDetector', 'MultiViewDfM', 'DfM',
'DynamicMVXFasterRCNN', 'PointVoxelRCNN'
'MVXFasterRCNN',
'MultiViewDfM',
'PartA2',
'VoteNet',
'H3DNet',
'CenterPoint',
'SSD3DNet',
'ImVoteNet',
'SingleStageMono3DDetector',
'FCOSMono3D',
'ImVoxelNet',
'GroupFree3DNet',
'PointRCNN',
'SMOKEMono3D',
'SASSD',
] ]
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Union from typing import List, Union
from mmdet.models import BaseDetector
from mmengine.structures import InstanceData from mmengine.structures import InstanceData
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS
from mmdet3d.structures.det3d_data_sample import (ForwardResults, from mmdet3d.structures.det3d_data_sample import (ForwardResults,
OptSampleList, SampleList) OptSampleList, SampleList)
from mmdet3d.utils.typing import OptConfigType, OptInstanceList, OptMultiConfig from mmdet3d.utils.typing import OptConfigType, OptInstanceList, OptMultiConfig
from mmdet.models import BaseDetector
@MODELS.register_module() @MODELS.register_module()
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import torch import torch
from mmdet.models.detectors import BaseDetector
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS
from mmdet3d.structures.ops import bbox3d2result from mmdet3d.structures.ops import bbox3d2result
from mmdet3d.utils import ConfigType from mmdet3d.utils import ConfigType
from mmdet.models.detectors import BaseDetector
from ..builder import build_backbone, build_head, build_neck from ..builder import build_backbone, build_head, build_neck
......
# Copyright (c) OpenMMLab. All rights reserved.
# Adapted from https://github.com/SamsungLabs/fcaf3d/blob/master/mmdet3d/models/detectors/single_stage_sparse.py # noqa
from typing import Dict, List, OrderedDict, Tuple, Union
import torch
from torch import Tensor
try:
import MinkowskiEngine as ME
except ImportError:
# Please follow getting_started.md to install MinkowskiEngine.
ME = None
pass
from mmdet3d.registry import MODELS
from mmdet3d.utils import ConfigType, OptConfigType, OptMultiConfig
from .single_stage import SingleStage3DDetector
@MODELS.register_module()
class MinkSingleStage3DDetector(SingleStage3DDetector):
r"""MinkSingleStage3DDetector.
This class serves as a base class for single-stage 3D detectors based on
MinkowskiEngine `GSDN <https://arxiv.org/abs/2006.12356>`_.
Args:
backbone (dict): Config dict of detector's backbone.
neck (dict, optional): Config dict of neck. Defaults to None.
bbox_head (dict, optional): Config dict of box head. Defaults to None.
train_cfg (dict, optional): Config dict of training hyper-parameters.
Defaults to None.
test_cfg (dict, optional): Config dict of test hyper-parameters.
Defaults to None.
data_preprocessor (dict or ConfigDict, optional): The pre-process
config of :class:`BaseDataPreprocessor`. it usually includes,
``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``.
init_cfg (dict or ConfigDict, optional): the config to control the
initialization. Defaults to None.
"""
_version = 2
def __init__(self,
backbone: ConfigType,
neck: OptConfigType = None,
bbox_head: OptConfigType = None,
train_cfg: OptConfigType = None,
test_cfg: OptConfigType = None,
data_preprocessor: OptConfigType = None,
init_cfg: OptMultiConfig = None) -> None:
super().__init__(
backbone=backbone,
neck=neck,
bbox_head=bbox_head,
train_cfg=train_cfg,
test_cfg=test_cfg,
data_preprocessor=data_preprocessor,
init_cfg=init_cfg)
if ME is None:
raise ImportError(
'Please follow `getting_started.md` to install MinkowskiEngine.`' # noqa: E501
)
self.voxel_size = bbox_head['voxel_size']
def extract_feat(
self, batch_inputs_dict: Dict[str, Tensor]
) -> Union[Tuple[torch.Tensor], Dict[str, Tensor]]:
"""Directly extract features from the backbone+neck.
Args:
batch_inputs_dict (dict): The model input dict which includes
'points' keys.
- points (list[torch.Tensor]): Point cloud of each sample.
Returns:
tuple[Tensor] | dict: For outside 3D object detection, we
typically obtain a tuple of features from the backbone + neck,
and for inside 3D object detection, usually a dict containing
features will be obtained.
"""
points = batch_inputs_dict['points']
coordinates, features = ME.utils.batch_sparse_collate(
[(p[:, :3] / self.voxel_size, p[:, 3:]) for p in points],
device=points[0].device)
x = ME.SparseTensor(coordinates=coordinates, features=features)
x = self.backbone(x)
if self.with_neck:
x = self.neck(x)
return x
def _load_from_state_dict(self, state_dict: OrderedDict, prefix: str,
local_metadata: Dict, strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str]) -> None:
"""Load checkpoint.
Args:
state_dict (dict): a dict containing parameters and
persistent buffers.
prefix (str): the prefix for parameters and buffers used in this
module
local_metadata (dict): a dict containing the metadata for this
module.
strict (bool): whether to strictly enforce that the keys in
:attr:`state_dict` with :attr:`prefix` match the names of
parameters and buffers in this module
missing_keys (list of str): if ``strict=True``, add missing keys to
this list
unexpected_keys (list of str): if ``strict=True``, add unexpected
keys to this list
error_msgs (list of str): error messages should be added to this
list, and will be reported together in
:meth:`~torch.nn.Module.load_state_dict`
"""
# The names of some parameters in FCAF3D has been changed
# since 2022.10.
version = local_metadata.get('version', None)
if (version is None or
version < 2) and self.__class__ is MinkSingleStage3DDetector:
convert_dict = {'head.': 'bbox_head.'}
state_dict_keys = list(state_dict.keys())
for k in state_dict_keys:
for ori_key, convert_key in convert_dict.items():
if ori_key in k:
convert_key = k.replace(ori_key, convert_key)
state_dict[convert_key] = state_dict[k]
del state_dict[k]
super(MinkSingleStage3DDetector,
self)._load_from_state_dict(state_dict, prefix, local_metadata,
strict, missing_keys,
unexpected_keys, error_msgs)
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import numpy as np import numpy as np
import torch import torch
from mmdet.models.detectors import BaseDetector
from mmdet3d.models.layers.fusion_layers.point_fusion import (point_sample, from mmdet3d.models.layers.fusion_layers.point_fusion import (point_sample,
voxel_sample) voxel_sample)
...@@ -8,7 +9,6 @@ from mmdet3d.registry import MODELS, TASK_UTILS ...@@ -8,7 +9,6 @@ from mmdet3d.registry import MODELS, TASK_UTILS
from mmdet3d.structures.bbox_3d.utils import get_lidar2img from mmdet3d.structures.bbox_3d.utils import get_lidar2img
from mmdet3d.structures.det3d_data_sample import SampleList from mmdet3d.structures.det3d_data_sample import SampleList
from mmdet3d.utils import ConfigType, OptConfigType from mmdet3d.utils import ConfigType, OptConfigType
from mmdet.models.detectors import BaseDetector
from .dfm import DfM from .dfm import DfM
from .imvoxelnet import ImVoxelNet from .imvoxelnet import ImVoxelNet
......
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Optional
from mmdet3d.registry import MODELS
from mmdet3d.structures.det3d_data_sample import SampleList
from mmdet3d.utils import InstanceList
from .two_stage import TwoStage3DDetector
@MODELS.register_module()
class PointVoxelRCNN(TwoStage3DDetector):
r"""PointVoxelRCNN detector.
Please refer to the `PointVoxelRCNN <https://arxiv.org/abs/1912.13192>`_.
Args:
voxel_encoder (dict): Point voxelization encoder layer.
middle_encoder (dict): Middle encoder layer
of points cloud modality.
backbone (dict): Backbone of extracting points features.
neck (dict, optional): Neck of extracting points features.
Defaults to None.
rpn_head (dict, optional): Config of RPN head. Defaults to None.
points_encoder (dict, optional): Points encoder to extract point-wise
features. Defaults to None.
roi_head (dict, optional): Config of ROI head. Defaults to None.
train_cfg (dict, optional): Train config of model.
Defaults to None.
test_cfg (dict, optional): Train config of model.
Defaults to None.
init_cfg (dict, optional): Initialize config of
model. Defaults to None.
data_preprocessor (dict or ConfigDict, optional): The pre-process
config of :class:`Det3DDataPreprocessor`. Defaults to None.
"""
def __init__(self,
voxel_encoder: dict,
middle_encoder: dict,
backbone: dict,
neck: Optional[dict] = None,
rpn_head: Optional[dict] = None,
points_encoder: Optional[dict] = None,
roi_head: Optional[dict] = None,
train_cfg: Optional[dict] = None,
test_cfg: Optional[dict] = None,
init_cfg: Optional[dict] = None,
data_preprocessor: Optional[dict] = None) -> None:
super().__init__(
backbone=backbone,
neck=neck,
rpn_head=rpn_head,
roi_head=roi_head,
train_cfg=train_cfg,
test_cfg=test_cfg,
init_cfg=init_cfg,
data_preprocessor=data_preprocessor)
self.voxel_encoder = MODELS.build(voxel_encoder)
self.middle_encoder = MODELS.build(middle_encoder)
self.points_encoder = MODELS.build(points_encoder)
def predict(self, batch_inputs_dict: dict, batch_data_samples: SampleList,
**kwargs) -> SampleList:
"""Predict results from a batch of inputs and data samples with post-
processing.
Args:
batch_inputs_dict (dict): The model input dict which include
'points', 'voxels' keys.
- points (list[torch.Tensor]): Point cloud of each sample.
- voxels (dict[torch.Tensor]): Voxels of the batch sample.
batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
samples. It usually includes information such as
`gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`.
Returns:
list[:obj:`Det3DDataSample`]: Detection results of the
input samples. Each Det3DDataSample usually contain
'pred_instances_3d'. And the ``pred_instances_3d`` usually
contains following keys.
- scores_3d (Tensor): Classification scores, has a shape
(num_instance, )
- labels_3d (Tensor): Labels of bboxes, has a shape
(num_instances, ).
- bboxes_3d (Tensor): Contains a tensor with shape
(num_instances, C) where C >=7.
"""
feats_dict = self.extract_feat(batch_inputs_dict)
if self.with_rpn:
rpn_results_list = self.rpn_head.predict(feats_dict,
batch_data_samples)
else:
rpn_results_list = [
data_sample.proposals for data_sample in batch_data_samples
]
# extrack points feats by points_encoder
points_feats_dict = self.extract_points_feat(batch_inputs_dict,
feats_dict,
rpn_results_list)
results_list_3d = self.roi_head.predict(points_feats_dict,
rpn_results_list,
batch_data_samples)
# connvert to Det3DDataSample
results_list = self.add_pred_to_datasample(batch_data_samples,
results_list_3d)
return results_list
def extract_feat(self, batch_inputs_dict: dict) -> dict:
"""Extract features from the input voxels.
Args:
batch_inputs_dict (dict): The model input dict which include
'points', 'voxels' keys.
- points (list[torch.Tensor]): Point cloud of each sample.
- voxels (dict[torch.Tensor]): Voxels of the batch sample.
Returns:
dict: We typically obtain a dict of features from the backbone +
neck, it includes:
- spatial_feats (torch.Tensor): Spatial feats from middle
encoder.
- multi_scale_3d_feats (list[torch.Tensor]): Multi scale
middle feats from middle encoder.
- neck_feats (torch.Tensor): Neck feats from neck.
"""
feats_dict = dict()
voxel_dict = batch_inputs_dict['voxels']
voxel_features = self.voxel_encoder(voxel_dict['voxels'],
voxel_dict['num_points'],
voxel_dict['coors'])
batch_size = voxel_dict['coors'][-1, 0].item() + 1
feats_dict['spatial_feats'], feats_dict[
'multi_scale_3d_feats'] = self.middle_encoder(
voxel_features, voxel_dict['coors'], batch_size)
x = self.backbone(feats_dict['spatial_feats'])
if self.with_neck:
neck_feats = self.neck(x)
feats_dict['neck_feats'] = neck_feats
return feats_dict
def extract_points_feat(self, batch_inputs_dict: dict, feats_dict: dict,
rpn_results_list: InstanceList) -> dict:
"""Extract point-wise features from the raw points and voxel features.
Args:
batch_inputs_dict (dict): The model input dict which include
'points', 'voxels' keys.
- points (list[torch.Tensor]): Point cloud of each sample.
- voxels (dict[torch.Tensor]): Voxels of the batch sample.
feats_dict (dict): Contains features from the first stage.
rpn_results_list (List[:obj:`InstanceData`]): Detection results
of rpn head.
Returns:
dict: Contain Point-wise features, include:
- keypoints (torch.Tensor): Sampled key points.
- keypoint_features (torch.Tensor): Gather key points features
from multi input.
- fusion_keypoint_features (torch.Tensor): Fusion
keypoint_features by point_feature_fusion_layer.
"""
return self.points_encoder(batch_inputs_dict, feats_dict,
rpn_results_list)
def loss(self, batch_inputs_dict: dict, batch_data_samples: SampleList,
**kwargs):
"""Calculate losses from a batch of inputs and data samples.
Args:
batch_inputs_dict (dict): The model input dict which include
'points', 'voxels' keys.
- points (list[torch.Tensor]): Point cloud of each sample.
- voxels (dict[torch.Tensor]): Voxels of the batch sample.
batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
samples. It usually includes information such as
`gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`.
Returns:
dict: A dictionary of loss components.
"""
feats_dict = self.extract_feat(batch_inputs_dict)
losses = dict()
# RPN forward and loss
if self.with_rpn:
proposal_cfg = self.train_cfg.get('rpn_proposal',
self.test_cfg.rpn)
rpn_data_samples = copy.deepcopy(batch_data_samples)
rpn_losses, rpn_results_list = self.rpn_head.loss_and_predict(
feats_dict,
rpn_data_samples,
proposal_cfg=proposal_cfg,
**kwargs)
# avoid get same name with roi_head loss
keys = rpn_losses.keys()
for key in keys:
if 'loss' in key and 'rpn' not in key:
rpn_losses[f'rpn_{key}'] = rpn_losses.pop(key)
losses.update(rpn_losses)
else:
# TODO: Not support currently, should have a check at Fast R-CNN
assert batch_data_samples[0].get('proposals', None) is not None
# use pre-defined proposals in InstanceData for the second stage
# to extract ROI features.
rpn_results_list = [
data_sample.proposals for data_sample in batch_data_samples
]
points_feats_dict = self.extract_points_feat(batch_inputs_dict,
feats_dict,
rpn_results_list)
roi_losses = self.roi_head.loss(points_feats_dict, rpn_results_list,
batch_data_samples)
losses.update(roi_losses)
return losses
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import torch from typing import Tuple, Union
from mmcv.ops import Voxelization
from torch.nn import functional as F
from mmdet3d.models.test_time_augs import merge_aug_bboxes_3d from torch import Tensor
from mmdet3d.structures.ops import bbox3d2result
from mmdet.models.builder import DETECTORS from mmdet3d.registry import MODELS
from mmdet.registry import MODELS from mmdet3d.utils import ConfigType, OptConfigType, OptMultiConfig
from ...structures.det3d_data_sample import SampleList
from .single_stage import SingleStage3DDetector from .single_stage import SingleStage3DDetector
@DETECTORS.register_module() @MODELS.register_module()
class SASSD(SingleStage3DDetector): class SASSD(SingleStage3DDetector):
r"""`SASSD <https://github.com/skyhehe123/SA-SSD>` _ for 3D detection.""" r"""`SASSD <https://github.com/skyhehe123/SA-SSD>` _ for 3D detection."""
def __init__(self, def __init__(self,
voxel_layer, voxel_encoder: ConfigType,
voxel_encoder, middle_encoder: ConfigType,
middle_encoder, backbone: ConfigType,
backbone, neck: OptConfigType = None,
neck=None, bbox_head: OptConfigType = None,
bbox_head=None, train_cfg: OptConfigType = None,
train_cfg=None, test_cfg: OptConfigType = None,
test_cfg=None, data_preprocessor: OptConfigType = None,
init_cfg=None, init_cfg: OptMultiConfig = None):
pretrained=None):
super(SASSD, self).__init__( super(SASSD, self).__init__(
backbone=backbone, backbone=backbone,
neck=neck, neck=neck,
bbox_head=bbox_head, bbox_head=bbox_head,
train_cfg=train_cfg, train_cfg=train_cfg,
test_cfg=test_cfg, test_cfg=test_cfg,
init_cfg=init_cfg, data_preprocessor=data_preprocessor,
pretrained=pretrained) init_cfg=init_cfg)
self.voxel_layer = Voxelization(**voxel_layer)
self.voxel_encoder = MODELS.build(voxel_encoder) self.voxel_encoder = MODELS.build(voxel_encoder)
self.middle_encoder = MODELS.build(middle_encoder) self.middle_encoder = MODELS.build(middle_encoder)
def extract_feat(self, points, img_metas=None, test_mode=False): def extract_feat(
"""Extract features from points.""" self,
voxels, num_points, coors = self.voxelize(points) batch_inputs_dict: dict,
voxel_features = self.voxel_encoder(voxels, num_points, coors) test_mode: bool = True
batch_size = coors[-1, 0].item() + 1 ) -> Union[Tuple[Tuple[Tensor], Tuple], Tuple[Tensor]]:
x, point_misc = self.middle_encoder(voxel_features, coors, batch_size, """Extract features from points.
Args:
batch_inputs_dict (dict): The batch inputs.
test_mode (bool, optional): Whether test mode. Defaults to True.
Returns:
Union[Tuple[Tuple[Tensor], Tuple], Tuple[Tensor]]: In test mode, it
returns the features of points from multiple levels. In training
mode, it returns the features of points from multiple levels and a
tuple containing the mean features of points and the targets of
clssification and regression.
"""
voxel_dict = batch_inputs_dict['voxels']
voxel_features = self.voxel_encoder(voxel_dict['voxels'],
voxel_dict['num_points'],
voxel_dict['coors'])
batch_size = voxel_dict['coors'][-1, 0].item() + 1
# `point_misc` is a tuple containing the mean features of points and
# the targets of clssification and regression. It's only used for
# calculating auxiliary loss in training mode.
x, point_misc = self.middle_encoder(voxel_features,
voxel_dict['coors'], batch_size,
test_mode) test_mode)
x = self.backbone(x) x = self.backbone(x)
if self.with_neck: if self.with_neck:
x = self.neck(x) x = self.neck(x)
return x, point_misc
@torch.no_grad() return (x, point_misc) if not test_mode else x
def voxelize(self, points):
"""Apply hard voxelization to points."""
voxels, coors, num_points = [], [], []
for res in points:
res_voxels, res_coors, res_num_points = self.voxel_layer(res)
voxels.append(res_voxels)
coors.append(res_coors)
num_points.append(res_num_points)
voxels = torch.cat(voxels, dim=0)
num_points = torch.cat(num_points, dim=0)
coors_batch = []
for i, coor in enumerate(coors):
coor_pad = F.pad(coor, (1, 0), mode='constant', value=i)
coors_batch.append(coor_pad)
coors_batch = torch.cat(coors_batch, dim=0)
return voxels, num_points, coors_batch
def forward_train(self, def loss(self, batch_inputs_dict: dict, batch_data_samples: SampleList,
points, **kwargs) -> dict:
img_metas, """Calculate losses from a batch of inputs dict and data samples.
gt_bboxes_3d,
gt_labels_3d,
gt_bboxes_ignore=None):
"""Training forward function.
Args: Args:
points (list[torch.Tensor]): Point cloud of each sample. batch_inputs_dict (dict): The model input dict which include
img_metas (list[dict]): Meta information of each sample 'points' keys.
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth - points (list[torch.Tensor]): Point cloud of each sample.
boxes for each sample.
gt_labels_3d (list[torch.Tensor]): Ground truth labels for batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
boxes of each sampole Samples. It usually includes information such as
gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth `gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`.
boxes to be ignored. Defaults to None.
Returns: Returns:
dict: Losses of each branch. dict: A dictionary of loss components.
""" """
x, point_misc = self.extract_feat(batch_inputs_dict, test_mode=False)
x, point_misc = self.extract_feat(points, img_metas, test_mode=False) batch_gt_bboxes_3d = [
aux_loss = self.middle_encoder.aux_loss(*point_misc, gt_bboxes_3d) data_sample.gt_instances_3d.bboxes_3d
for data_sample in batch_data_samples
outs = self.bbox_head(x) ]
loss_inputs = outs + (gt_bboxes_3d, gt_labels_3d, img_metas) aux_loss = self.middle_encoder.aux_loss(*point_misc,
losses = self.bbox_head.loss( batch_gt_bboxes_3d)
*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore) losses = self.bbox_head.loss(x, batch_data_samples)
losses.update(aux_loss) losses.update(aux_loss)
return losses return losses
def simple_test(self, points, img_metas, imgs=None, rescale=False):
"""Test function without augmentaiton."""
x, _ = self.extract_feat(points, img_metas, test_mode=True)
outs = self.bbox_head(x)
bbox_list = self.bbox_head.get_bboxes(
*outs, img_metas, rescale=rescale)
bbox_results = [
bbox3d2result(bboxes, scores, labels)
for bboxes, scores, labels in bbox_list
]
return bbox_results
def aug_test(self, points, img_metas, imgs=None, rescale=False):
"""Test function with augmentaiton."""
feats = self.extract_feats(points, img_metas, test_mode=True)
# only support aug_test for one sample
aug_bboxes = []
for x, img_meta in zip(feats, img_metas):
outs = self.bbox_head(x)
bbox_list = self.bbox_head.get_bboxes(
*outs, img_meta, rescale=rescale)
bbox_list = [
dict(boxes_3d=bboxes, scores_3d=scores, labels_3d=labels)
for bboxes, scores, labels in bbox_list
]
aug_bboxes.append(bbox_list[0])
# after merging, bboxes will be rescaled to the original image size
merged_bboxes = merge_aug_bboxes_3d(aug_bboxes, img_metas,
self.bbox_head.test_cfg)
return [merged_bboxes]
...@@ -143,7 +143,11 @@ class SingleStage3DDetector(Base3DDetector): ...@@ -143,7 +143,11 @@ class SingleStage3DDetector(Base3DDetector):
"""Directly extract features from the backbone+neck. """Directly extract features from the backbone+neck.
Args: Args:
points (torch.Tensor): Input points. batch_inputs_dict (dict): The model input dict which include
'points', 'img' keys.
- points (list[torch.Tensor]): Point cloud of each sample.
- imgs (torch.Tensor, optional): Image of each sample.
Returns: Returns:
tuple[Tensor] | dict: For outside 3D object detection, we tuple[Tensor] | dict: For outside 3D object detection, we
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple from typing import Tuple
from mmdet.models.detectors.single_stage import SingleStageDetector
from mmengine.structures import InstanceData from mmengine.structures import InstanceData
from torch import Tensor from torch import Tensor
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS
from mmdet3d.structures.det3d_data_sample import SampleList from mmdet3d.structures.det3d_data_sample import SampleList
from mmdet3d.utils import OptInstanceList from mmdet3d.utils import OptInstanceList
from mmdet.models.detectors.single_stage import SingleStageDetector
@MODELS.register_module() @MODELS.register_module()
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple
import numba import numba
import numpy as np import numpy as np
import torch import torch
from mmcv.ops import nms, nms_rotated from mmcv.ops import nms, nms_rotated
from torch import Tensor
def box3d_multiclass_nms(mlvl_bboxes,
mlvl_bboxes_for_nms, def box3d_multiclass_nms(
mlvl_scores, mlvl_bboxes: Tensor,
score_thr, mlvl_bboxes_for_nms: Tensor,
max_num, mlvl_scores: Tensor,
cfg, score_thr: float,
mlvl_dir_scores=None, max_num: int,
mlvl_attr_scores=None, cfg: dict,
mlvl_bboxes2d=None): mlvl_dir_scores: Optional[Tensor] = None,
mlvl_attr_scores: Optional[Tensor] = None,
mlvl_bboxes2d: Optional[Tensor] = None) -> Tuple[Tensor]:
"""Multi-class NMS for 3D boxes. The IoU used for NMS is defined as the 2D """Multi-class NMS for 3D boxes. The IoU used for NMS is defined as the 2D
IoU between BEV boxes. IoU between BEV boxes.
Args: Args:
mlvl_bboxes (torch.Tensor): Multi-level boxes with shape (N, M). mlvl_bboxes (Tensor): Multi-level boxes with shape (N, M).
M is the dimensions of boxes. M is the dimensions of boxes.
mlvl_bboxes_for_nms (torch.Tensor): Multi-level boxes with shape mlvl_bboxes_for_nms (Tensor): Multi-level boxes with shape (N, 5)
(N, 5) ([x1, y1, x2, y2, ry]). N is the number of boxes. ([x1, y1, x2, y2, ry]). N is the number of boxes.
The coordinate system of the BEV boxes is counterclockwise. The coordinate system of the BEV boxes is counterclockwise.
mlvl_scores (torch.Tensor): Multi-level boxes with shape mlvl_scores (Tensor): Multi-level boxes with shape (N, C + 1).
(N, C + 1). N is the number of boxes. C is the number of classes. N is the number of boxes. C is the number of classes.
score_thr (float): Score threshold to filter boxes with low score_thr (float): Score threshold to filter boxes with low confidence.
confidence.
max_num (int): Maximum number of boxes will be kept. max_num (int): Maximum number of boxes will be kept.
cfg (dict): Configuration dict of NMS. cfg (dict): Configuration dict of NMS.
mlvl_dir_scores (torch.Tensor, optional): Multi-level scores mlvl_dir_scores (Tensor, optional): Multi-level scores of direction
of direction classifier. Defaults to None. classifier. Defaults to None.
mlvl_attr_scores (torch.Tensor, optional): Multi-level scores mlvl_attr_scores (Tensor, optional): Multi-level scores of attribute
of attribute classifier. Defaults to None. classifier. Defaults to None.
mlvl_bboxes2d (torch.Tensor, optional): Multi-level 2D bounding mlvl_bboxes2d (Tensor, optional): Multi-level 2D bounding boxes.
boxes. Defaults to None. Defaults to None.
Returns: Returns:
tuple[torch.Tensor]: Return results after nms, including 3D Tuple[Tensor]: Return results after nms, including 3D bounding boxes,
bounding boxes, scores, labels, direction scores, attribute scores, labels, direction scores, attribute scores (optional) and
scores (optional) and 2D bounding boxes (optional). 2D bounding boxes (optional).
""" """
# do multi class nms # do multi class nms
# the fg class id range: [0, num_classes-1] # the fg class id range: [0, num_classes-1]
...@@ -128,17 +131,18 @@ def box3d_multiclass_nms(mlvl_bboxes, ...@@ -128,17 +131,18 @@ def box3d_multiclass_nms(mlvl_bboxes,
return results return results
def aligned_3d_nms(boxes, scores, classes, thresh): def aligned_3d_nms(boxes: Tensor, scores: Tensor, classes: Tensor,
thresh: float) -> Tensor:
"""3D NMS for aligned boxes. """3D NMS for aligned boxes.
Args: Args:
boxes (torch.Tensor): Aligned box with shape [n, 6]. boxes (Tensor): Aligned box with shape [N, 6].
scores (torch.Tensor): Scores of each box. scores (Tensor): Scores of each box.
classes (torch.Tensor): Class of each box. classes (Tensor): Class of each box.
thresh (float): IoU threshold for nms. thresh (float): IoU threshold for nms.
Returns: Returns:
torch.Tensor: Indices of selected boxes. Tensor: Indices of selected boxes.
""" """
x1 = boxes[:, 0] x1 = boxes[:, 0]
y1 = boxes[:, 1] y1 = boxes[:, 1]
...@@ -179,21 +183,20 @@ def aligned_3d_nms(boxes, scores, classes, thresh): ...@@ -179,21 +183,20 @@ def aligned_3d_nms(boxes, scores, classes, thresh):
@numba.jit(nopython=True) @numba.jit(nopython=True)
def circle_nms(dets, thresh, post_max_size=83): def circle_nms(dets: Tensor, thresh: float, post_max_size: int = 83) -> Tensor:
"""Circular NMS. """Circular NMS.
An object is only counted as positive if no other center An object is only counted as positive if no other center with a higher
with a higher confidence exists within a radius r using a confidence exists within a radius r using a bird-eye view distance metric.
bird-eye view distance metric.
Args: Args:
dets (torch.Tensor): Detection results with the shape of [N, 3]. dets (Tensor): Detection results with the shape of [N, 3].
thresh (float): Value of threshold. thresh (float): Value of threshold.
post_max_size (int, optional): Max number of prediction to be kept. post_max_size (int): Max number of prediction to be kept.
Defaults to 83. Defaults to 83.
Returns: Returns:
torch.Tensor: Indexes of the detections to be kept. Tensor: Indexes of the detections to be kept.
""" """
x1 = dets[:, 0] x1 = dets[:, 0]
y1 = dets[:, 1] y1 = dets[:, 1]
...@@ -228,24 +231,28 @@ def circle_nms(dets, thresh, post_max_size=83): ...@@ -228,24 +231,28 @@ def circle_nms(dets, thresh, post_max_size=83):
# This function duplicates functionality of mmcv.ops.iou_3d.nms_bev # This function duplicates functionality of mmcv.ops.iou_3d.nms_bev
# from mmcv<=1.5, but using cuda ops from mmcv.ops.nms.nms_rotated. # from mmcv<=1.5, but using cuda ops from mmcv.ops.nms.nms_rotated.
# Nms api will be unified in mmdetection3d one day. # Nms api will be unified in mmdetection3d one day.
def nms_bev(boxes, scores, thresh, pre_max_size=None, post_max_size=None): def nms_bev(boxes: Tensor,
scores: Tensor,
thresh: float,
pre_max_size: Optional[int] = None,
post_max_size: Optional[int] = None) -> Tensor:
"""NMS function GPU implementation (for BEV boxes). The overlap of two """NMS function GPU implementation (for BEV boxes). The overlap of two
boxes for IoU calculation is defined as the exact overlapping area of the boxes for IoU calculation is defined as the exact overlapping area of the
two boxes. In this function, one can also set ``pre_max_size`` and two boxes. In this function, one can also set ``pre_max_size`` and
``post_max_size``. ``post_max_size``.
Args: Args:
boxes (torch.Tensor): Input boxes with the shape of [N, 5] boxes (Tensor): Input boxes with the shape of [N, 5]
([x1, y1, x2, y2, ry]). ([x1, y1, x2, y2, ry]).
scores (torch.Tensor): Scores of boxes with the shape of [N]. scores (Tensor): Scores of boxes with the shape of [N].
thresh (float): Overlap threshold of NMS. thresh (float): Overlap threshold of NMS.
pre_max_size (int, optional): Max size of boxes before NMS. pre_max_size (int, optional): Max size of boxes before NMS.
Default: None. Defaults to None.
post_max_size (int, optional): Max size of boxes after NMS. post_max_size (int, optional): Max size of boxes after NMS.
Default: None. Defaults to None.
Returns: Returns:
torch.Tensor: Indexes after NMS. Tensor: Indexes after NMS.
""" """
assert boxes.size(1) == 5, 'Input boxes shape should be [N, 5]' assert boxes.size(1) == 5, 'Input boxes shape should be [N, 5]'
order = scores.sort(0, descending=True)[1] order = scores.sort(0, descending=True)[1]
...@@ -271,18 +278,18 @@ def nms_bev(boxes, scores, thresh, pre_max_size=None, post_max_size=None): ...@@ -271,18 +278,18 @@ def nms_bev(boxes, scores, thresh, pre_max_size=None, post_max_size=None):
# This function duplicates functionality of mmcv.ops.iou_3d.nms_normal_bev # This function duplicates functionality of mmcv.ops.iou_3d.nms_normal_bev
# from mmcv<=1.5, but using cuda ops from mmcv.ops.nms.nms. # from mmcv<=1.5, but using cuda ops from mmcv.ops.nms.nms.
# Nms api will be unified in mmdetection3d one day. # Nms api will be unified in mmdetection3d one day.
def nms_normal_bev(boxes, scores, thresh): def nms_normal_bev(boxes: Tensor, scores: Tensor, thresh: float) -> Tensor:
"""Normal NMS function GPU implementation (for BEV boxes). The overlap of """Normal NMS function GPU implementation (for BEV boxes). The overlap of
two boxes for IoU calculation is defined as the exact overlapping area of two boxes for IoU calculation is defined as the exact overlapping area of
the two boxes WITH their yaw angle set to 0. the two boxes WITH their yaw angle set to 0.
Args: Args:
boxes (torch.Tensor): Input boxes with shape (N, 5). boxes (Tensor): Input boxes with shape (N, 5).
scores (torch.Tensor): Scores of predicted boxes with shape (N). scores (Tensor): Scores of predicted boxes with shape (N).
thresh (float): Overlap threshold of NMS. thresh (float): Overlap threshold of NMS.
Returns: Returns:
torch.Tensor: Remaining indices with scores in descending order. Tensor: Remaining indices with scores in descending order.
""" """
assert boxes.shape[1] == 5, 'Input boxes shape should be [N, 5]' assert boxes.shape[1] == 5, 'Input boxes shape should be [N, 5]'
return nms(boxes[:, :-1], scores, thresh)[1] return nms(boxes[:, :-1], scores, thresh)[1]
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import List
import torch import torch
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmengine.model import BaseModule from mmengine.model import BaseModule
from torch import Tensor
from torch import nn as nn from torch import nn as nn
from mmdet3d.utils import ConfigType, OptMultiConfig
class DGCNNFAModule(BaseModule): class DGCNNFAModule(BaseModule):
"""Point feature aggregation module used in DGCNN. """Point feature aggregation module used in DGCNN.
...@@ -11,21 +16,21 @@ class DGCNNFAModule(BaseModule): ...@@ -11,21 +16,21 @@ class DGCNNFAModule(BaseModule):
Aggregate all the features of points. Aggregate all the features of points.
Args: Args:
mlp_channels (list[int]): List of mlp channels. mlp_channels (List[int]): List of mlp channels.
norm_cfg (dict, optional): Type of normalization method. norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
Defaults to dict(type='BN1d'). layer. Defaults to dict(type='BN1d').
act_cfg (dict, optional): Type of activation method. act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer.
Defaults to dict(type='ReLU'). Defaults to dict(type='ReLU').
init_cfg (dict, optional): Initialization config. Defaults to None. init_cfg (:obj:`ConfigDict` or dict or List[:obj:`Contigdict` or dict],
optional): Initialization config dict. Defaults to None.
""" """
def __init__(self, def __init__(self,
mlp_channels, mlp_channels: List[int],
norm_cfg=dict(type='BN1d'), norm_cfg: ConfigType = dict(type='BN1d'),
act_cfg=dict(type='ReLU'), act_cfg: ConfigType = dict(type='ReLU'),
init_cfg=None): init_cfg: OptMultiConfig = None) -> None:
super().__init__(init_cfg=init_cfg) super(DGCNNFAModule, self).__init__(init_cfg=init_cfg)
self.fp16_enabled = False
self.mlps = nn.Sequential() self.mlps = nn.Sequential()
for i in range(len(mlp_channels) - 1): for i in range(len(mlp_channels) - 1):
self.mlps.add_module( self.mlps.add_module(
...@@ -39,14 +44,14 @@ class DGCNNFAModule(BaseModule): ...@@ -39,14 +44,14 @@ class DGCNNFAModule(BaseModule):
norm_cfg=norm_cfg, norm_cfg=norm_cfg,
act_cfg=act_cfg)) act_cfg=act_cfg))
def forward(self, points): def forward(self, points: List[Tensor]) -> Tensor:
"""forward. """forward.
Args: Args:
points (List[Tensor]): tensor of the features to be aggregated. points (List[Tensor]): Tensor of the features to be aggregated.
Returns: Returns:
Tensor: (B, N, M) M = mlp[-1], tensor of the output points. Tensor: (B, N, M) M = mlp[-1]. Tensor of the output points.
""" """
if len(points) > 1: if len(points) > 1:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment