Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmdetection3d
Commits
d7ade147
Commit
d7ade147
authored
Jun 30, 2020
by
zhangwenwei
Browse files
Merge branch 'add-flop-counter' into 'master'
update docstrings in core See merge request open-mmlab/mmdet.3d!103
parents
398f541e
9732a488
Changes
29
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
256 additions
and
96 deletions
+256
-96
mmdet3d/models/necks/second_fpn.py
mmdet3d/models/necks/second_fpn.py
+7
-58
mmdet3d/models/roi_heads/base_3droi_head.py
mmdet3d/models/roi_heads/base_3droi_head.py
+6
-0
mmdet3d/models/roi_heads/bbox_heads/parta2_bbox_head.py
mmdet3d/models/roi_heads/bbox_heads/parta2_bbox_head.py
+39
-3
mmdet3d/models/roi_heads/part_aggregation_roi_head.py
mmdet3d/models/roi_heads/part_aggregation_roi_head.py
+45
-1
mmdet3d/models/roi_heads/roi_extractors/single_roiaware_extractor.py
...els/roi_heads/roi_extractors/single_roiaware_extractor.py
+1
-0
mmdet3d/models/voxel_encoders/pillar_encoder.py
mmdet3d/models/voxel_encoders/pillar_encoder.py
+74
-18
mmdet3d/models/voxel_encoders/utils.py
mmdet3d/models/voxel_encoders/utils.py
+59
-14
mmdet3d/models/voxel_encoders/voxel_encoder.py
mmdet3d/models/voxel_encoders/voxel_encoder.py
+24
-2
mmdet3d/utils/collect_env.py
mmdet3d/utils/collect_env.py
+1
-0
No files found.
mmdet3d/models/necks/second_fpn.py
View file @
d7ade147
...
...
@@ -4,7 +4,6 @@ from mmcv.cnn import (build_norm_layer, build_upsample_layer, constant_init,
is_norm
,
kaiming_init
)
from
mmdet.models
import
NECKS
from
..
import
builder
@
NECKS
.
register_module
()
...
...
@@ -47,6 +46,7 @@ class SECONDFPN(nn.Module):
self
.
deblocks
=
nn
.
ModuleList
(
deblocks
)
def
init_weights
(
self
):
"""Initialize weights of FPN"""
for
m
in
self
.
modules
():
if
isinstance
(
m
,
nn
.
Conv2d
):
kaiming_init
(
m
)
...
...
@@ -54,53 +54,14 @@ class SECONDFPN(nn.Module):
constant_init
(
m
,
1
)
def
forward
(
self
,
x
):
assert
len
(
x
)
==
len
(
self
.
in_channels
)
ups
=
[
deblock
(
x
[
i
])
for
i
,
deblock
in
enumerate
(
self
.
deblocks
)]
if
len
(
ups
)
>
1
:
out
=
torch
.
cat
(
ups
,
dim
=
1
)
else
:
out
=
ups
[
0
]
return
[
out
]
@
NECKS
.
register_module
()
class
SECONDFusionFPN
(
SECONDFPN
):
"""FPN used in multi-modality SECOND/PointPillars
Args:
in_channels (list[int]): Input channels of multi-scale feature maps
out_channels (list[int]): Output channels of feature maps
upsample_strides (list[int]): Strides used to upsample the feature maps
norm_cfg (dict): Config dict of normalization layers
upsample_cfg (dict): Config dict of upsample layers
downsample_rates (list[int]): The downsample rate of feature map in
comparison to the original voxelization input
fusion_layer (dict): Config dict of fusion layers
"""
"""Forward function
def
__init__
(
self
,
in_channels
=
[
128
,
128
,
256
],
out_channels
=
[
256
,
256
,
256
],
upsample_strides
=
[
1
,
2
,
4
],
norm_cfg
=
dict
(
type
=
'BN'
,
eps
=
1e-3
,
momentum
=
0.01
),
upsample_cfg
=
dict
(
type
=
'deconv'
,
bias
=
False
),
downsample_rates
=
[
40
,
8
,
8
],
fusion_layer
=
None
):
super
(
SECONDFusionFPN
,
self
).
__init__
(
in_channels
,
out_channels
,
upsample_strides
,
norm_cfg
,
upsample_cfg
)
self
.
fusion_layer
=
None
if
fusion_layer
is
not
None
:
self
.
fusion_layer
=
builder
.
build_fusion_layer
(
fusion_layer
)
self
.
downsample_rates
=
downsample_rates
Args:
x (torch.Tensor): 4D Tensor in (N, C, H, W) shape.
def
forward
(
self
,
x
,
coors
=
None
,
points
=
None
,
img_feats
=
None
,
img_metas
=
None
):
Returns:
list[torch.Tensor]: Multi-level feature maps.
"""
assert
len
(
x
)
==
len
(
self
.
in_channels
)
ups
=
[
deblock
(
x
[
i
])
for
i
,
deblock
in
enumerate
(
self
.
deblocks
)]
...
...
@@ -108,16 +69,4 @@ class SECONDFusionFPN(SECONDFPN):
out
=
torch
.
cat
(
ups
,
dim
=
1
)
else
:
out
=
ups
[
0
]
if
(
self
.
fusion_layer
is
not
None
and
img_feats
is
not
None
):
downsample_pts_coors
=
torch
.
zeros_like
(
coors
)
downsample_pts_coors
[:,
0
]
=
coors
[:,
0
]
downsample_pts_coors
[:,
1
]
=
(
coors
[:,
1
]
/
self
.
downsample_rates
[
0
])
downsample_pts_coors
[:,
2
]
=
(
coors
[:,
2
]
/
self
.
downsample_rates
[
1
])
downsample_pts_coors
[:,
3
]
=
(
coors
[:,
3
]
/
self
.
downsample_rates
[
2
])
# fusion for each point
out
=
self
.
fusion_layer
(
img_feats
,
points
,
out
,
downsample_pts_coors
,
img_metas
)
return
[
out
]
mmdet3d/models/roi_heads/base_3droi_head.py
View file @
d7ade147
...
...
@@ -26,26 +26,32 @@ class Base3DRoIHead(nn.Module, metaclass=ABCMeta):
@
property
def
with_bbox
(
self
):
"""bool: whether the RoIHead has box head"""
return
hasattr
(
self
,
'bbox_head'
)
and
self
.
bbox_head
is
not
None
@
property
def
with_mask
(
self
):
"""bool: whether the RoIHead has mask head"""
return
hasattr
(
self
,
'mask_head'
)
and
self
.
mask_head
is
not
None
@
abstractmethod
def
init_weights
(
self
,
pretrained
):
"""Initialize the module with pre-trained weights."""
pass
@
abstractmethod
def
init_bbox_head
(
self
):
"""Initialize the box head."""
pass
@
abstractmethod
def
init_mask_head
(
self
):
"""Initialize maek head."""
pass
@
abstractmethod
def
init_assigner_sampler
(
self
):
"""Initialize assigner and sampler"""
pass
@
abstractmethod
...
...
mmdet3d/models/roi_heads/bbox_heads/parta2_bbox_head.py
View file @
d7ade147
...
...
@@ -4,13 +4,12 @@ import torch.nn as nn
from
mmcv.cnn
import
ConvModule
,
normal_init
,
xavier_init
import
mmdet3d.ops.spconv
as
spconv
from
mmdet3d.core
import
build_bbox_coder
,
xywhr2xyxyr
from
mmdet3d.core.bbox.structures
import
(
LiDARInstance3DBoxes
,
rotation_3d_in_axis
)
rotation_3d_in_axis
,
xywhr2xyxyr
)
from
mmdet3d.models.builder
import
build_loss
from
mmdet3d.ops
import
make_sparse_convmodule
from
mmdet3d.ops.iou3d.iou3d_utils
import
nms_gpu
,
nms_normal_gpu
from
mmdet.core
import
multi_apply
from
mmdet.core
import
build_bbox_coder
,
multi_apply
from
mmdet.models
import
HEADS
...
...
@@ -224,6 +223,7 @@ class PartA2BboxHead(nn.Module):
self
.
init_weights
()
def
init_weights
(
self
):
"""Initialize weights of the bbox head"""
for
m
in
self
.
modules
():
if
isinstance
(
m
,
(
nn
.
Conv2d
,
nn
.
Conv1d
)):
xavier_init
(
m
,
distribution
=
'uniform'
)
...
...
@@ -390,6 +390,22 @@ class PartA2BboxHead(nn.Module):
bbox_weights
)
def
_get_target_single
(
self
,
pos_bboxes
,
pos_gt_bboxes
,
ious
,
cfg
):
"""Generate training targets for a single sample.
Args:
pos_bboxes (torch.Tensor): Positive boxes with shape
(N, 7).
pos_gt_bboxes (torch.Tensor): Ground truth boxes with shape
(M, 7).
ious (torch.Tensor): IoU between `pos_bboxes` and `pos_gt_bboxes`
in shape (N, M).
cfg (dict): Training configs.
Returns:
tuple: Target for positive boxes.
(label, bbox_targets, pos_gt_bboxes, reg_mask, label_weights,
bbox_weights)
"""
cls_pos_mask
=
ious
>
cfg
.
cls_pos_thr
cls_neg_mask
=
ious
<
cfg
.
cls_neg_thr
interval_mask
=
(
cls_pos_mask
==
0
)
&
(
cls_neg_mask
==
0
)
...
...
@@ -540,6 +556,26 @@ class PartA2BboxHead(nn.Module):
nms_thr
,
input_meta
,
use_rotate_nms
=
True
):
"""Multi-class NMS for box head
Note:
This function has large overlap with the `box3d_multiclass_nms`
implemented in `mmdet3d.core.post_processing`. We are considering
merging these two functions in the future.
Args:
box_probs (torch.Tensor): Predicted boxes probabitilies in
shape (N,).
box_preds (torch.Tensor): Predicted boxes in shape (N, 7+C).
score_thr (float): Threshold of scores.
nms_thr (float): Threshold for NMS.
input_meta (dict): Meta informations of the current sample.
use_rotate_nms (bool, optional): Whether to use rotated nms.
Defaults to True.
Returns:
torch.Tensor: Selected indices.
"""
if
use_rotate_nms
:
nms_func
=
nms_gpu
else
:
...
...
mmdet3d/models/roi_heads/part_aggregation_roi_head.py
View file @
d7ade147
...
...
@@ -44,15 +44,21 @@ class PartAggregationROIHead(Base3DRoIHead):
self
.
init_assigner_sampler
()
def
init_weights
(
self
,
pretrained
):
"""Initialize weights, skip since ``PartAggregationROIHead``
does not need to initialize weights"""
pass
def
init_mask_head
(
self
):
"""Initialize mask head, skip since ``PartAggregationROIHead``
does not have one."""
pass
def
init_bbox_head
(
self
,
bbox_head
):
"""Initialize box head"""
self
.
bbox_head
=
build_head
(
bbox_head
)
def
init_assigner_sampler
(
self
):
"""Initialize assigner and sampler"""
self
.
bbox_assigner
=
None
self
.
bbox_sampler
=
None
if
self
.
train_cfg
:
...
...
@@ -66,6 +72,7 @@ class PartAggregationROIHead(Base3DRoIHead):
@
property
def
with_semantic
(
self
):
"""bool: whether the head has semantic branch"""
return
hasattr
(
self
,
'semantic_head'
)
and
self
.
semantic_head
is
not
None
...
...
@@ -152,6 +159,18 @@ class PartAggregationROIHead(Base3DRoIHead):
def
_bbox_forward_train
(
self
,
seg_feats
,
part_feats
,
voxels_dict
,
sampling_results
):
"""Forward training function of roi_extractor and bbox_head.
Args:
seg_feats (torch.Tensor): Point-wise semantic features.
part_feats (torch.Tensor): Point-wise part prediction features.
voxels_dict (dict): Contains information of voxels.
sampling_results (:obj:`SamplingResult`): Sampled results used
for training.
Returns:
dict: Forward results including losses and predictions.
"""
rois
=
bbox3d2roi
([
res
.
bboxes
for
res
in
sampling_results
])
bbox_results
=
self
.
_bbox_forward
(
seg_feats
,
part_feats
,
voxels_dict
,
rois
)
...
...
@@ -166,7 +185,8 @@ class PartAggregationROIHead(Base3DRoIHead):
return
bbox_results
def
_bbox_forward
(
self
,
seg_feats
,
part_feats
,
voxels_dict
,
rois
):
"""Forward function of roi_extractor and bbox_head.
"""Forward function of roi_extractor and bbox_head used in both
training and testing.
Args:
seg_feats (torch.Tensor): Point-wise semantic features.
...
...
@@ -196,6 +216,18 @@ class PartAggregationROIHead(Base3DRoIHead):
return
bbox_results
def
_assign_and_sample
(
self
,
proposal_list
,
gt_bboxes_3d
,
gt_labels_3d
):
"""Assign and sample proposals for training
Args:
proposal_list (list[dict]): Proposals produced by RPN.
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
boxes.
gt_labels_3d (list[torch.Tensor]): Ground truth labels
Returns:
list[:obj:`SamplingResult`]: Sampled results of each training
sample.
"""
sampling_results
=
[]
# bbox assign
for
batch_idx
in
range
(
len
(
proposal_list
)):
...
...
@@ -258,6 +290,18 @@ class PartAggregationROIHead(Base3DRoIHead):
def
_semantic_forward_train
(
self
,
x
,
voxels_dict
,
gt_bboxes_3d
,
gt_labels_3d
):
"""Train semantic head
Args:
x (torch.Tensor): Point-wise semantic features for segmentation
voxels_dict (dict): Contains information of voxels.
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
boxes.
gt_labels_3d (list[torch.Tensor]): Ground truth labels
Returns:
dict: Segmentation results including losses
"""
semantic_results
=
self
.
semantic_head
(
x
)
semantic_targets
=
self
.
semantic_head
.
get_targets
(
voxels_dict
,
gt_bboxes_3d
,
gt_labels_3d
)
...
...
mmdet3d/models/roi_heads/roi_extractors/single_roiaware_extractor.py
View file @
d7ade147
...
...
@@ -20,6 +20,7 @@ class Single3DRoIAwareExtractor(nn.Module):
self
.
roi_layer
=
self
.
build_roi_layers
(
roi_layer
)
def
build_roi_layers
(
self
,
layer_cfg
):
"""Build roi layers using `layer_cfg`"""
cfg
=
layer_cfg
.
copy
()
layer_type
=
cfg
.
pop
(
'type'
)
assert
hasattr
(
ops
,
layer_type
)
...
...
mmdet3d/models/voxel_encoders/pillar_encoder.py
View file @
d7ade147
...
...
@@ -15,16 +15,22 @@ class PillarFeatureNet(nn.Module):
through PFNLayers.
Args:
in_channels (int). Number of input features,
either x, y, z or x, y, z, r.
feat_channels (list[int]). Number of features in each of the
N PFNLayers.
with_distance (bool). Whether to include Euclidean distance
to points.
voxel_size (list[float]). Size of voxels, only utilize x and y
size.
point_cloud_range (list[float]). Point cloud range, only
utilizes x and y min.
in_channels (int, optional): Number of input features,
either x, y, z or x, y, z, r. Defaults to 4.
feat_channels (tuple, optional): Number of features in each of the
N PFNLayers. Defaults to (64, ).
with_distance (bool, optional): Whether to include Euclidean distance
to points. Defaults to False.
with_cluster_center (bool, optional): [description]. Defaults to True.
with_voxel_center (bool, optional): [description]. Defaults to True.
voxel_size (tuple[float], optional): Size of voxels, only utilize x
and y size. Defaults to (0.2, 0.2, 4).
point_cloud_range (tuple[float], optional): Point cloud range, only
utilizes x and y min. Defaults to (0, -40, -3, 70.4, 40, 1).
norm_cfg ([type], optional): [description].
Defaults to dict(type='BN1d', eps=1e-3, momentum=0.01).
mode (str, optional): The mode to gather point features. Options are
'max' or 'avg'. Defaults to 'max'.
"""
def
__init__
(
self
,
...
...
@@ -77,6 +83,17 @@ class PillarFeatureNet(nn.Module):
self
.
point_cloud_range
=
point_cloud_range
def
forward
(
self
,
features
,
num_points
,
coors
):
"""Forward function
Args:
features (torch.Tensor): Point features or raw points in shape
(N, M, C).
num_points (torch.Tensor): Number of points in each pillar.
coors (torch.Tensor): Coordinates of each voxel
Returns:
torch.Tensor: Features of pillars.
"""
features_ls
=
[
features
]
# Find distance of x, y, and z from cluster center
if
self
.
_with_cluster_center
:
...
...
@@ -119,6 +136,31 @@ class PillarFeatureNet(nn.Module):
@
VOXEL_ENCODERS
.
register_module
()
class
DynamicPillarFeatureNet
(
PillarFeatureNet
):
"""Pillar Feature Net using dynamic voxelization
The network prepares the pillar features and performs forward pass
through PFNLayers. The main difference is that it is used for
dynamic voxels, which contains different number of points inside a voxel
without limits.
Args:
in_channels (int, optional): Number of input features,
either x, y, z or x, y, z, r. Defaults to 4.
feat_channels (tuple, optional): Number of features in each of the
N PFNLayers. Defaults to (64, ).
with_distance (bool, optional): Whether to include Euclidean distance
to points. Defaults to False.
with_cluster_center (bool, optional): [description]. Defaults to True.
with_voxel_center (bool, optional): [description]. Defaults to True.
voxel_size (tuple[float], optional): Size of voxels, only utilize x
and y size. Defaults to (0.2, 0.2, 4).
point_cloud_range (tuple[float], optional): Point cloud range, only
utilizes x and y min. Defaults to (0, -40, -3, 70.4, 40, 1).
norm_cfg ([type], optional): [description].
Defaults to dict(type='BN1d', eps=1e-3, momentum=0.01).
mode (str, optional): The mode to gather point features. Options are
'max' or 'avg'. Defaults to 'max'.
"""
def
__init__
(
self
,
in_channels
=
4
,
...
...
@@ -130,11 +172,6 @@ class DynamicPillarFeatureNet(PillarFeatureNet):
point_cloud_range
=
(
0
,
-
40
,
-
3
,
70.4
,
40
,
1
),
norm_cfg
=
dict
(
type
=
'BN1d'
,
eps
=
1e-3
,
momentum
=
0.01
),
mode
=
'max'
):
"""
Dynamic Pillar Feature Net for Dynamic Voxelization.
The difference is in the forward part
"""
super
(
DynamicPillarFeatureNet
,
self
).
__init__
(
in_channels
,
feat_channels
,
...
...
@@ -168,6 +205,19 @@ class DynamicPillarFeatureNet(PillarFeatureNet):
voxel_size
,
point_cloud_range
,
average_points
=
True
)
def
map_voxel_center_to_point
(
self
,
pts_coors
,
voxel_mean
,
voxel_coors
):
"""Map the centers of voxels to its corresponding points
Args:
pts_coors (torch.Tensor): The coordinates of each points, shape
(M, 3), where M is the number of points.
voxel_mean (torch.Tensor): The mean or aggreagated features of a
voxel, shape (N, C), where N is the number of voxels.
voxel_coors (torch.Tensor): The coordinates of each voxel.
Returns:
torch.Tensor: Corresponding voxel centers of each points, shape
(M, C), where M is the numver of points.
"""
# Step 1: scatter voxel into canvas
# Calculate necessary things for canvas creation
canvas_y
=
int
(
...
...
@@ -194,9 +244,15 @@ class DynamicPillarFeatureNet(PillarFeatureNet):
return
center_per_point
def
forward
(
self
,
features
,
coors
):
"""
features (torch.Tensor): NxC
coors (torch.Tensor): Nx(1+NDim)
"""Forward function
Args:
features (torch.Tensor): Point features or raw points in shape
(N, M, C).
coors (torch.Tensor): Coordinates of each voxel
Returns:
torch.Tensor: Features of pillars.
"""
features_ls
=
[
features
]
# Find distance of x, y, and z from cluster center
...
...
mmdet3d/models/voxel_encoders/utils.py
View file @
d7ade147
...
...
@@ -28,6 +28,21 @@ def get_paddings_indicator(actual_num, max_num, axis=0):
class
VFELayer
(
nn
.
Module
):
""" Voxel Feature Encoder layer.
The voxel encoder is composed of a series of these layers.
This module do not support average pooling and only support to use
max pooling to gather features inside a VFE.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
norm_cfg (dict): Config dict of normalization layers
max_out (bool): Whether aggregate the features of points inside
each voxel and only return voxel features.
cat_max (bool): Whether concatenate the aggregated features
and pointwise features.
"""
def
__init__
(
self
,
in_channels
,
...
...
@@ -44,6 +59,23 @@ class VFELayer(nn.Module):
self
.
linear
=
nn
.
Linear
(
in_channels
,
out_channels
,
bias
=
False
)
def
forward
(
self
,
inputs
):
"""Forward function
Args:
inputs (torch.Tensor): Voxels features of shape (N, M, C).
N is the number of voxels, M is the number of points in
voxels, C is the number of channels of point features.
Returns:
torch.Tensor: Voxel features. There are three mode under which the
features have different meaning.
- `max_out=False`: Return point-wise features in
shape (N, M, C).
- `max_out=True` and `cat_max=False`: Return aggregated
voxel features in shape (N, C)
- `max_out=True` and `cat_max=True`: Return concatenated
point-wise features in shape (N, M, C).
"""
# [K, T, 7] tensordot [7, units] = [K, T, units]
voxel_count
=
inputs
.
shape
[
1
]
x
=
self
.
linear
(
inputs
)
...
...
@@ -68,6 +100,20 @@ class VFELayer(nn.Module):
class
PFNLayer
(
nn
.
Module
):
""" Pillar Feature Net Layer.
The Pillar Feature Net is composed of a series of these layers, but the
PointPillars paper results only used a single PFNLayer.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
norm_cfg (dict): Config dict of normalization layers
last_layer (bool): If last_layer, there is no concatenation of
features.
mode (str): Pooling model to gather features inside voxels.
Default to 'max'.
"""
def
__init__
(
self
,
in_channels
,
...
...
@@ -75,20 +121,6 @@ class PFNLayer(nn.Module):
norm_cfg
=
dict
(
type
=
'BN1d'
,
eps
=
1e-3
,
momentum
=
0.01
),
last_layer
=
False
,
mode
=
'max'
):
""" Pillar Feature Net Layer.
The Pillar Feature Net is composed of a series of these layers, but the
PointPillars paper results only used a single PFNLayer.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
norm_cfg (dict): Config dict of normalization layers
last_layer (bool): If last_layer, there is no concatenation of
features.
mode (str): Pooling model to gather features inside voxels.
Default to 'max'.
"""
super
().
__init__
()
self
.
name
=
'PFNLayer'
...
...
@@ -104,7 +136,20 @@ class PFNLayer(nn.Module):
self
.
mode
=
mode
def
forward
(
self
,
inputs
,
num_voxels
=
None
,
aligned_distance
=
None
):
"""Forward function
Args:
inputs (torch.Tensor): Pillar/Voxel inputs with shape (N, M, C).
N is the number of voxels, M is the number of points in
voxels, C is the number of channels of point features.
num_voxels (torch.Tensor, optional): Number of points in each
voxel. Defaults to None.
aligned_distance (torch.Tensor, optional): The distance of
each points to the voxel center. Defaults to None.
Returns:
torch.Tensor: Features of Pillars.
"""
x
=
self
.
linear
(
inputs
)
x
=
self
.
norm
(
x
.
permute
(
0
,
2
,
1
).
contiguous
()).
permute
(
0
,
2
,
1
).
contiguous
()
...
...
mmdet3d/models/voxel_encoders/voxel_encoder.py
View file @
d7ade147
...
...
@@ -19,8 +19,19 @@ class HardSimpleVFE(nn.Module):
super
(
HardSimpleVFE
,
self
).
__init__
()
def
forward
(
self
,
features
,
num_points
,
coors
):
# features: [concated_num_points, num_voxel_size, 3(4)]
# num_points: [concated_num_points]
"""Forward function
Args:
features (torch.Tensor): point features in shape
(N, M, 3(4)). N is the number of voxels and M is the maximum
number of points inside a single voxel.
num_points (torch.Tensor): Number of points in each voxel,
shape (N, ).
coors (torch.Tensor): Coordinates of voxels.
Returns:
torch.Tensor: Mean of points inside each voxel in shape (N, 3(4))
"""
points_mean
=
features
[:,
:,
:
4
].
sum
(
dim
=
1
,
keepdim
=
False
)
/
num_points
.
type_as
(
features
).
view
(
-
1
,
1
)
return
points_mean
.
contiguous
()
...
...
@@ -46,6 +57,17 @@ class DynamicSimpleVFE(nn.Module):
@
torch
.
no_grad
()
def
forward
(
self
,
features
,
coors
):
"""Forward function
Args:
features (torch.Tensor): point features in shape
(N, 3(4)). N is the number of points.
coors (torch.Tensor): Coordinates of voxels.
Returns:
torch.Tensor: Mean of points inside each voxel in shape (M, 3(4)).
M is the number of voxels.
"""
# This function is used from the start of the voxelnet
# num_points: [concated_num_points]
features
,
features_coors
=
self
.
scatter
(
features
,
coors
)
...
...
mmdet3d/utils/collect_env.py
View file @
d7ade147
...
...
@@ -13,6 +13,7 @@ import mmdet3d
def
collect_env
():
"""Collect and print the information of running enviroments."""
env_info
=
{}
env_info
[
'sys.platform'
]
=
sys
.
platform
env_info
[
'Python'
]
=
sys
.
version
.
replace
(
'
\n
'
,
''
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment