Unverified Commit bcbab523 authored by Sun Jiahao, committed by GitHub

[Feature] Cylindrical voxelization & voxel feature encoder (#2228)

* add cylindrical voxelization & voxel feature encoder

* add cylindrical voxelization & voxel feature encoder

* add voxel-wise label & voxelization UT

* fix vfe

* fix vfe UT

* rename voxel encoder & add more test cases

* fix type hint

* temporarily refactor mmcv's voxelize and dynamic scatter in mmdet3d for data_preprocessor

* fix vfe init bug & fix UT

* add grid_size & move voxelization code

* fix import bug

* keep radians to follow the original implementation

* add doc string

* fix type hint

* rename grid shape & add comments

* fix UT

* rename VoxelizationByGridShape & fix docstring
parent 92b24d97
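For orientation before the diff: a minimal config sketch (not itself part of this commit) showing how the new cylindrical voxelization is enabled through the data preprocessor. The values mirror the unit test at the bottom of this diff; they are illustrative, not canonical defaults.

```python
# Hedged sketch: enabling the cylindrical voxelization added in this PR.
data_preprocessor = dict(
    type='Det3DDataPreprocessor',
    voxel=True,
    voxel_type='cylindrical',
    voxel_layer=dict(
        grid_shape=[480, 360, 32],                    # (rho, phi, z) bins
        point_cloud_range=[0, -180, -4, 50, 180, 2],  # min/max per axis
        max_num_points=-1,                            # -1 -> dynamic voxelization
        max_voxels=-1))
```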
@@ -5,15 +5,16 @@ from typing import Dict, List, Optional, Sequence, Union
import numpy as np
import torch
from mmcv.ops import Voxelization
from mmdet.models import DetDataPreprocessor
from mmengine.model import stack_batch
from mmengine.utils import is_list_of
from torch.nn import functional as F
from mmdet3d.registry import MODELS
from mmdet3d.structures.det3d_data_sample import SampleList
from mmdet3d.utils import OptConfigType
from .utils import multiview_img_stack_batch
from .voxelize import VoxelizationByGridShape, dynamic_scatter_3d
@MODELS.register_module()
@@ -103,7 +104,7 @@ class Det3DDataPreprocessor(DetDataPreprocessor):
self.voxel = voxel
self.voxel_type = voxel_type
if voxel:
self.voxel_layer = Voxelization(**voxel_layer)
self.voxel_layer = VoxelizationByGridShape(**voxel_layer)
def forward(self,
data: Union[dict, List[dict]],
@@ -157,7 +158,7 @@ class Det3DDataPreprocessor(DetDataPreprocessor):
batch_inputs['points'] = inputs['points']
if self.voxel:
voxel_dict = self.voxelize(inputs['points'])
voxel_dict = self.voxelize(inputs['points'], data_samples)
batch_inputs['voxels'] = voxel_dict
if 'imgs' in inputs:
@@ -329,11 +330,14 @@ class Det3DDataPreprocessor(DetDataPreprocessor):
return batch_pad_shape
@torch.no_grad()
def voxelize(self, points: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
def voxelize(self, points: List[torch.Tensor],
data_samples: SampleList) -> Dict[str, torch.Tensor]:
"""Apply voxelization to point cloud.
Args:
points (List[Tensor]): Point cloud in one data batch.
data_samples (list[:obj:`Det3DDataSample`]): The annotation data of
every sample; voxel-wise annotations for segmentation are added to it.
Returns:
Dict[str, Tensor]: Voxelization information.
@@ -378,6 +382,31 @@ class Det3DDataPreprocessor(DetDataPreprocessor):
coors.append(res_coors)
voxels = torch.cat(points, dim=0)
coors = torch.cat(coors, dim=0)
elif self.voxel_type == 'cylindrical':
voxels, coors = [], []
for i, (res, data_sample) in enumerate(zip(points, data_samples)):
rho = torch.sqrt(res[:, 0]**2 + res[:, 1]**2)
phi = torch.atan2(res[:, 1], res[:, 0])
polar_res = torch.stack((rho, phi, res[:, 2]), dim=-1)
# torch.clamp with tensor bounds requires PyTorch >= 1.9.0; this will
# be implemented in voxel_layer soon for better compatibility
min_bound = polar_res.new_tensor(
self.voxel_layer.point_cloud_range[:3])
max_bound = polar_res.new_tensor(
self.voxel_layer.point_cloud_range[3:])
polar_res = torch.clamp(polar_res, min_bound, max_bound)
res_coors = torch.floor(
(polar_res - min_bound) /
polar_res.new_tensor(self.voxel_layer.voxel_size)).int()
if self.training:
self.get_voxel_seg(res_coors, data_sample)
res_coors = F.pad(res_coors, (1, 0), mode='constant', value=i)
res_voxels = torch.cat((polar_res, res[:, :2], res[:, 3:]),
dim=-1)
voxels.append(res_voxels)
coors.append(res_coors)
voxels = torch.cat(voxels, dim=0)
coors = torch.cat(coors, dim=0)
else:
raise ValueError(f'Invalid voxelization type {self.voxel_type}')
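For reference, a self-contained, CPU-runnable sketch of the Cartesian-to-cylindrical binning performed in the branch above. The phi range is given in radians here (an assumption for illustration; the unit test below uses [-180, 180]), and the voxel size follows the span / (grid_shape - 1) rule used by `VoxelizationByGridShape`.

```python
import torch

# One point at (x, y, z) = (3, 4, 0.5): rho = 5, phi = atan2(4, 3) ~ 0.927 rad.
points = torch.tensor([[3.0, 4.0, 0.5, 0.9]])  # x, y, z, reflectivity
rho = torch.sqrt(points[:, 0]**2 + points[:, 1]**2)
phi = torch.atan2(points[:, 1], points[:, 0])
polar = torch.stack((rho, phi, points[:, 2]), dim=-1)

pc_range = torch.tensor([0.0, -3.1416, -4.0, 50.0, 3.1416, 2.0])
grid_shape = torch.tensor([480.0, 360.0, 32.0])
voxel_size = (pc_range[3:] - pc_range[:3]) / (grid_shape - 1)

# torch.clamp with tensor bounds needs PyTorch >= 1.9, as noted in the code.
polar = torch.clamp(polar, pc_range[:3], pc_range[3:])
coors = torch.floor((polar - pc_range[:3]) / voxel_size).int()
print(coors)  # bin indices along (rho, phi, z)
```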
@@ -385,3 +414,19 @@ class Det3DDataPreprocessor(DetDataPreprocessor):
voxel_dict['coors'] = coors
return voxel_dict
def get_voxel_seg(self, res_coors: torch.Tensor, data_sample: SampleList):
"""Get voxel-wise segmentation label and point2voxel map.
Args:
res_coors (Tensor): The voxel coordinates of points, Nx3.
data_sample (:obj:`Det3DDataSample`): The annotation data of one
sample; voxel-wise annotations for segmentation are added to it.
"""
pts_semantic_mask = data_sample.gt_pts_seg.pts_semantic_mask
voxel_semantic_mask, _, point2voxel_map = dynamic_scatter_3d(
F.one_hot(pts_semantic_mask.long()).float(), res_coors, 'mean',
True)
voxel_semantic_mask = torch.argmax(voxel_semantic_mask, dim=-1)
data_sample.gt_pts_seg.voxel_semantic_mask = voxel_semantic_mask
data_sample.gt_pts_seg.point2voxel_map = point2voxel_map
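The one-hot / mean / argmax sequence in `get_voxel_seg` amounts to a per-voxel majority vote over point labels. A CPU-only sketch of the same idea (the real code uses the CUDA `dynamic_scatter_3d` extension; `index_add_` stands in for it here):

```python
import torch
import torch.nn.functional as F

labels = torch.tensor([0, 0, 1, 2, 2, 2])     # per-point semantic labels
voxel_ids = torch.tensor([0, 0, 0, 1, 1, 1])  # voxel each point falls into
one_hot = F.one_hot(labels).float()           # (6, 3)
# Sum one-hot vectors per voxel; argmax of the sum == argmax of the mean.
counts = torch.zeros(2, 3).index_add_(0, voxel_ids, one_hot)
voxel_semantic_mask = counts.argmax(dim=-1)   # tensor([0, 2])
```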
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, List, Optional, Tuple, Union
import torch
from mmcv.utils import ext_loader
from torch import nn
from torch.autograd import Function
from torch.nn import functional as F
from torch.nn.modules.utils import _pair
ext_module = ext_loader.load_ext('_ext', [
'dynamic_voxelize_forward', 'hard_voxelize_forward',
'dynamic_point_to_voxel_forward', 'dynamic_point_to_voxel_backward'
])
class _Voxelization(Function):
@staticmethod
def forward(
ctx: Any,
points: torch.Tensor,
voxel_size: Union[tuple, float],
coors_range: Union[tuple, float],
max_points: int = 35,
max_voxels: int = 20000,
deterministic: bool = True) -> Union[Tuple[torch.Tensor], Tuple]:
"""Convert kitti points(N, >=3) to voxels.
Args:
points (torch.Tensor): [N, ndim]. Points[:, :3] contain xyz points
and points[:, 3:] contain other information like reflectivity.
voxel_size (tuple or float): The size of voxel with the shape of
[3].
coors_range (tuple or float): The coordinate range of voxel with
the shape of [6].
max_points (int, optional): Maximum number of points contained in a
voxel. If max_points=-1, dynamic voxelization is used. Default: 35.
max_voxels (int, optional): Maximum number of voxels this function
creates. For SECOND, 20000 is a good choice. Users should shuffle
points before calling this function because max_voxels may drop
points. Default: 20000.
deterministic (bool, optional): Whether to use the deterministic
version of hard voxelization. The non-deterministic version is
considerably faster but not reproducible; this only affects hard
voxelization. Default: True. For more information on this argument
and the implementation insights, please refer to the following
links:
https://github.com/open-mmlab/mmdetection3d/issues/894
https://github.com/open-mmlab/mmdetection3d/pull/904
It is an experimental feature and we will appreciate it if you
could share with us the failing cases.
Returns:
tuple[torch.Tensor]: A tuple containing three elements: the output
voxels with shape [M, max_points, n_dim] (only returned when
max_points != -1), the voxel coordinates with shape [M, 3], and the
number of points per voxel with shape [M] (only returned when
max_points != -1).
"""
if max_points == -1 or max_voxels == -1:
coors = points.new_zeros(size=(points.size(0), 3), dtype=torch.int)
ext_module.dynamic_voxelize_forward(
points,
torch.tensor(voxel_size, dtype=torch.float),
torch.tensor(coors_range, dtype=torch.float),
coors,
NDim=3)
return coors
else:
voxels = points.new_zeros(
size=(max_voxels, max_points, points.size(1)))
coors = points.new_zeros(size=(max_voxels, 3), dtype=torch.int)
num_points_per_voxel = points.new_zeros(
size=(max_voxels, ), dtype=torch.int)
voxel_num = torch.zeros(size=(), dtype=torch.long)
ext_module.hard_voxelize_forward(
points,
torch.tensor(voxel_size, dtype=torch.float),
torch.tensor(coors_range, dtype=torch.float),
voxels,
coors,
num_points_per_voxel,
voxel_num,
max_points=max_points,
max_voxels=max_voxels,
NDim=3,
deterministic=deterministic)
# select the valid voxels
voxels_out = voxels[:voxel_num]
coors_out = coors[:voxel_num]
num_points_per_voxel_out = num_points_per_voxel[:voxel_num]
return voxels_out, coors_out, num_points_per_voxel_out
voxelization = _Voxelization.apply
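A hedged usage sketch of the exported function (requires mmcv built with CUDA ops; values are illustrative). With max_points == -1 it falls through to dynamic voxelization and returns only the per-point voxel coordinates:

```python
import torch

points = torch.rand(1000, 4).cuda()  # x, y, z, reflectivity
coors = voxelization(
    points,
    [0.1, 0.1, 0.2],            # voxel_size
    [0, -40, -3, 70.4, 40, 1],  # point_cloud_range
    -1,                         # max_points: -1 -> dynamic voxelization
    -1)                         # max_voxels (unused on the dynamic path)
assert coors.shape == (1000, 3)  # one voxel index triple per input point
```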
class VoxelizationByGridShape(nn.Module):
"""Voxelization that allows inferring voxel size automatically based on
grid shape.
Please refer to `Point-Voxel CNN for Efficient 3D Deep Learning
<https://arxiv.org/abs/1907.03739>`_ for more details.
Args:
point_cloud_range (list):
[x_min, y_min, z_min, x_max, y_max, z_max]
max_num_points (int): Max number of points per voxel.
voxel_size (list): List [x, y, z] or [rho, phi, z], the size of a
single voxel.
grid_shape (list): [L, W, H], grid shape of the voxelization.
max_voxels (tuple or int): Max number of voxels at
(training, testing) time.
deterministic (bool): Whether to use the deterministic version of
hard voxelization. The non-deterministic version is considerably
faster but not reproducible; this only affects hard voxelization.
Default: True. For more information on this argument and the
implementation insights, please refer to the following links:
https://github.com/open-mmlab/mmdetection3d/issues/894
https://github.com/open-mmlab/mmdetection3d/pull/904
It is an experimental feature and we will appreciate it if you
could share with us the failing cases.
"""
def __init__(self,
point_cloud_range: List,
max_num_points: int,
voxel_size: List = [],
grid_shape: List[int] = [],
max_voxels: Union[tuple, int] = 20000,
deterministic: bool = True):
super().__init__()
if voxel_size and grid_shape:
raise ValueError('voxel_size and grid_shape are mutually exclusive')
self.point_cloud_range = point_cloud_range
self.max_num_points = max_num_points
if isinstance(max_voxels, tuple):
self.max_voxels = max_voxels
else:
self.max_voxels = _pair(max_voxels)
self.deterministic = deterministic
point_cloud_range = torch.tensor(
point_cloud_range, dtype=torch.float32)
if voxel_size:
self.voxel_size = voxel_size
voxel_size = torch.tensor(voxel_size, dtype=torch.float32)
grid_shape = (point_cloud_range[3:] -
point_cloud_range[:3]) / voxel_size
grid_shape = torch.round(grid_shape).long().tolist()
self.grid_shape = grid_shape
elif grid_shape:
self.grid_shape = grid_shape
grid_shape = torch.tensor(grid_shape, dtype=torch.float32)
voxel_size = (point_cloud_range[3:] - point_cloud_range[:3]) / (
grid_shape - 1)
voxel_size = voxel_size.tolist()
self.voxel_size = voxel_size
else:
raise ValueError('must assign a value to voxel_size or grid_shape')
def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.training:
max_voxels = self.max_voxels[0]
else:
max_voxels = self.max_voxels[1]
return voxelization(input, self.voxel_size, self.point_cloud_range,
self.max_num_points, max_voxels,
self.deterministic)
def __repr__(self):
s = self.__class__.__name__ + '('
s += 'voxel_size=' + str(self.voxel_size)
s += ', grid_shape=' + str(self.grid_shape)
s += ', point_cloud_range=' + str(self.point_cloud_range)
s += ', max_num_points=' + str(self.max_num_points)
s += ', max_voxels=' + str(self.max_voxels)
s += ', deterministic=' + str(self.deterministic)
s += ')'
return s
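A quick check of the size inference implemented above: given a grid shape, voxel_size is derived as span / (grid_shape - 1). Constructing the module is pure Python (the compiled ops are only needed at forward time), so this runs anywhere with torch installed:

```python
layer = VoxelizationByGridShape(
    point_cloud_range=[0, -180, -4, 50, 180, 2],
    max_num_points=-1,
    grid_shape=[480, 360, 32],
    max_voxels=-1)
print(layer.voxel_size)
# [50/479, 360/359, 6/31] ~= [0.1044, 1.0028, 0.1935]
```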
class _DynamicScatter(Function):
"""Different from the mmcv implementation, here it is allowed to return
point2voxel_map."""
@staticmethod
def forward(ctx: Any,
feats: torch.Tensor,
coors: torch.Tensor,
reduce_type: str = 'max',
return_map: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
"""convert kitti points(N, >=3) to voxels.
Args:
feats (torch.Tensor): [N, C]. Points features to be reduced
into voxels.
coors (torch.Tensor): [N, ndim]. Corresponding voxel coordinates
(specifically, multi-dim voxel indices) of each point.
reduce_type (str, optional): Reduce op. Supports 'max', 'sum' and
'mean'. Default: 'max'.
return_map (bool, optional): Whether to return the point2voxel_map.
Default: False.
Returns:
tuple[torch.Tensor]: A tuple containing two elements: the voxel
features with shape [M, C], reduced from the input features that
share the same voxel coordinates, and the voxel coordinates with
shape [M, ndim]. If `return_map` is True, the point2voxel_map with
shape [N] is also returned.
"""
results = ext_module.dynamic_point_to_voxel_forward(
feats, coors, reduce_type)
(voxel_feats, voxel_coors, point2voxel_map,
voxel_points_count) = results
ctx.reduce_type = reduce_type
ctx.save_for_backward(feats, voxel_feats, point2voxel_map,
voxel_points_count)
ctx.mark_non_differentiable(voxel_coors)
if return_map:
return voxel_feats, voxel_coors, point2voxel_map
else:
return voxel_feats, voxel_coors
@staticmethod
def backward(ctx: Any,
grad_voxel_feats: torch.Tensor,
grad_voxel_coors: Optional[torch.Tensor] = None) -> tuple:
(feats, voxel_feats, point2voxel_map,
voxel_points_count) = ctx.saved_tensors
grad_feats = torch.zeros_like(feats)
# TODO: whether to use index put or use cuda_backward
# To use index put, need point to voxel index
ext_module.dynamic_point_to_voxel_backward(
grad_feats, grad_voxel_feats.contiguous(), feats, voxel_feats,
point2voxel_map, voxel_points_count, ctx.reduce_type)
return grad_feats, None, None
dynamic_scatter_3d = _DynamicScatter.apply
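Usage sketch for the extended scatter op (CUDA-only, since it calls the compiled extension; shapes illustrative). The optional third output is exactly what `get_voxel_seg` in the preprocessor consumes as `point2voxel_map`:

```python
import torch

feats = torch.rand(100, 16).cuda()
coors = torch.randint(0, 4, (100, 3), dtype=torch.int32).cuda()
voxel_feats, voxel_coors, point2voxel_map = dynamic_scatter_3d(
    feats, coors, 'mean', True)  # return_map=True
assert point2voxel_map.shape[0] == 100  # maps each point to its voxel row
```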
class DynamicScatter3D(nn.Module):
"""Scatters points into voxels, used in the voxel encoder with dynamic
voxelization.
Note:
The CPU and GPU implementations produce the same output, but may
differ numerically after summation and division (e.g., by ~5e-7).
Args:
voxel_size (list): List [x, y, z], the size of a single voxel in each
dimension.
point_cloud_range (list): The coordinate range of points, [x_min,
y_min, z_min, x_max, y_max, z_max].
average_points (bool): Whether to use average pooling (instead of max
pooling) when scattering points into a voxel.
"""
def __init__(self, voxel_size: List, point_cloud_range: List,
average_points: bool):
super().__init__()
self.voxel_size = voxel_size
self.point_cloud_range = point_cloud_range
self.average_points = average_points
def forward_single(
self, points: torch.Tensor,
coors: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Scatters points into voxels.
Args:
points (torch.Tensor): Points to be reduced into voxels.
coors (torch.Tensor): Corresponding voxel coordinates (specifically,
multi-dim voxel indices) of each point.
Returns:
tuple[torch.Tensor]: A tuple containing two elements: the voxel
features with shape [M, C], reduced from the input features that
share the same voxel coordinates, and the voxel coordinates with
shape [M, ndim].
"""
reduce = 'mean' if self.average_points else 'max'
return dynamic_scatter_3d(points.contiguous(), coors.contiguous(),
reduce)
def forward(self, points: torch.Tensor,
coors: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Scatters points/features into voxels.
Args:
points (torch.Tensor): Points to be reduced into voxels.
coors (torch.Tensor): Corresponding voxel coordinates (specifically,
multi-dim voxel indices) of each point.
Returns:
tuple[torch.Tensor]: A tuple containing two elements: the voxel
features with shape [M, C], reduced from the input features that
share the same voxel coordinates, and the voxel coordinates with
shape [M, ndim].
"""
if coors.size(-1) == 3:
return self.forward_single(points, coors)
else:
batch_size = coors[-1, 0] + 1
voxels, voxel_coors = [], []
for i in range(batch_size):
inds = torch.where(coors[:, 0] == i)
voxel, voxel_coor = self.forward_single(
points[inds], coors[inds][:, 1:])
coor_pad = F.pad(voxel_coor, (1, 0), mode='constant', value=i)
voxel_coors.append(coor_pad)
voxels.append(voxel)
features = torch.cat(voxels, dim=0)
feature_coors = torch.cat(voxel_coors, dim=0)
return features, feature_coors
def __repr__(self):
s = self.__class__.__name__ + '('
s += 'voxel_size=' + str(self.voxel_size)
s += ', point_cloud_range=' + str(self.point_cloud_range)
s += ', average_points=' + str(self.average_points)
s += ')'
return s
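The batched path of `forward` above dispatches on the width of `coors`: 3 columns mean a single sample, 4 mean a leading batch index. A CPU sketch of building batched coordinates the way the preprocessor does:

```python
import torch
import torch.nn.functional as F

coors_single = torch.randint(0, 8, (50, 3), dtype=torch.int32)
coors_batched = torch.cat([
    F.pad(coors_single, (1, 0), mode='constant', value=0),  # sample 0
    F.pad(coors_single, (1, 0), mode='constant', value=1),  # sample 1
])
assert coors_batched.shape == (100, 4)  # (batch_idx, *voxel_index)
```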
# Copyright (c) OpenMMLab. All rights reserved.
from .pillar_encoder import DynamicPillarFeatureNet, PillarFeatureNet
from .voxel_encoder import DynamicSimpleVFE, DynamicVFE, HardSimpleVFE, HardVFE
from .voxel_encoder import (DynamicSimpleVFE, DynamicVFE, HardSimpleVFE,
HardVFE, SegVFE)
__all__ = [
'PillarFeatureNet', 'DynamicPillarFeatureNet', 'HardVFE', 'DynamicVFE',
'HardSimpleVFE', 'DynamicSimpleVFE'
'HardSimpleVFE', 'DynamicSimpleVFE', 'SegVFE'
]
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence, Tuple
import torch
from mmcv.cnn import build_norm_layer
from mmcv.ops import DynamicScatter
@@ -486,3 +488,158 @@ class HardVFE(nn.Module):
out = torch.max(voxel_canvas, dim=1)[0]
return out
@MODELS.register_module()
class SegVFE(nn.Module):
"""Voxel feature encoder used in segmentation task.
It encodes features of voxels and their points. It could also fuse
image feature into voxel features in a point-wise manner.
The number of points inside the voxel varies.
Args:
in_channels (int): Input channels of VFE. Defaults to 6.
feat_channels (list(int)): Channels of features in VFE.
with_voxel_center (bool): Whether to use the distance to the center of
the voxel for each point inside a voxel.
Defaults to False.
voxel_size (tuple[float]): Size of a single voxel (rho, phi, z).
Defaults to None.
grid_shape (tuple[float]): The grid shape of voxelization.
Defaults to (480, 360, 32).
point_cloud_range (tuple[float]): The range of points
or voxels. Defaults to (0, -180, -4, 50, 180, 2).
norm_cfg (dict): Config dict of normalization layers.
mode (str): The mode when pooling features of points
inside a voxel. Available options include 'max' and 'avg'.
Defaults to 'max'.
with_pre_norm (bool): Whether to apply a norm layer to the input
features before the first VFE layer. Defaults to True.
feat_compression (int, optional): The voxel feature compression
channels. Defaults to None.
return_point_feats (bool): Whether to return the features of each
point. Defaults to False.
"""
def __init__(self,
in_channels: int = 6,
feat_channels: Sequence[int] = [],
with_voxel_center: bool = False,
voxel_size: Optional[Sequence[float]] = None,
grid_shape: Sequence[float] = (480, 360, 32),
point_cloud_range: Sequence[float] = (0, -180, -4, 50, 180,
2),
norm_cfg: dict = dict(type='BN1d', eps=1e-5, momentum=0.1),
mode: str = 'max',
with_pre_norm: bool = True,
feat_compression: Optional[int] = None,
return_point_feats: bool = False) -> None:
super(SegVFE, self).__init__()
assert mode in ['avg', 'max']
assert len(feat_channels) > 0
assert not (voxel_size and grid_shape), \
'voxel_size and grid_shape cannot be set at the same time'
if with_voxel_center:
in_channels += 3
self.in_channels = in_channels
self._with_voxel_center = with_voxel_center
self.return_point_feats = return_point_feats
self.point_cloud_range = point_cloud_range
point_cloud_range = torch.tensor(
point_cloud_range, dtype=torch.float32)
if voxel_size:
self.voxel_size = voxel_size
voxel_size = torch.tensor(voxel_size, dtype=torch.float32)
grid_shape = (point_cloud_range[3:] -
point_cloud_range[:3]) / voxel_size
grid_shape = torch.round(grid_shape).long().tolist()
self.grid_shape = grid_shape
elif grid_shape:
self.grid_shape = grid_shape
grid_shape = torch.tensor(grid_shape, dtype=torch.float32)
voxel_size = (point_cloud_range[3:] - point_cloud_range[:3]) / (
grid_shape - 1)
voxel_size = voxel_size.tolist()
self.voxel_size = voxel_size
else:
raise ValueError('must assign a value to voxel_size or grid_shape')
# Need the voxel size and per-axis offsets in order to compute each
# point's offset from its voxel center
self.vx = self.voxel_size[0]
self.vy = self.voxel_size[1]
self.vz = self.voxel_size[2]
self.x_offset = self.vx / 2 + point_cloud_range[0]
self.y_offset = self.vy / 2 + point_cloud_range[1]
self.z_offset = self.vz / 2 + point_cloud_range[2]
feat_channels = [self.in_channels] + list(feat_channels)
if with_pre_norm:
self.pre_norm = build_norm_layer(norm_cfg, self.in_channels)[1]
else:
self.pre_norm = None
vfe_layers = []
for i in range(len(feat_channels) - 1):
in_filters = feat_channels[i]
out_filters = feat_channels[i + 1]
norm_layer = build_norm_layer(norm_cfg, out_filters)[1]
if i == len(feat_channels) - 2:
vfe_layers.append(nn.Linear(in_filters, out_filters))
else:
vfe_layers.append(
nn.Sequential(
nn.Linear(in_filters, out_filters), norm_layer,
nn.ReLU(inplace=True)))
self.vfe_layers = nn.ModuleList(vfe_layers)
self.num_vfe = len(vfe_layers)
self.vfe_scatter = DynamicScatter(self.voxel_size,
self.point_cloud_range,
(mode != 'max'))
self.compression_layers = None
if feat_compression is not None:
self.compression_layers = nn.Linear(feat_channels[-1],
feat_compression)
def forward(self, features: Tensor, coors: Tensor, *args,
**kwargs) -> Tuple[Tensor]:
"""Forward functions.
Args:
features (Tensor): Features of voxels, shape is NxC.
coors (Tensor): Coordinates of voxels, shape is Nx(1+NDim).
Returns:
tuple: If `return_point_feats` is False, returns the voxel features
and their coordinates. If `return_point_feats` is True, additionally
returns the features of each point inside the voxels.
"""
features_ls = [features]
# Find distance of x, y, and z from voxel center
if self._with_voxel_center:
f_center = features.new_zeros(size=(features.size(0), 3))
f_center[:, 0] = features[:, 0] - (
coors[:, 3].type_as(features) * self.vx + self.x_offset)
f_center[:, 1] = features[:, 1] - (
coors[:, 2].type_as(features) * self.vy + self.y_offset)
f_center[:, 2] = features[:, 2] - (
coors[:, 1].type_as(features) * self.vz + self.z_offset)
features_ls.append(f_center)
# Combine together feature decorations
features = torch.cat(features_ls[::-1], dim=-1)
if self.pre_norm is not None:
features = self.pre_norm(features)
point_feats = []
for i, vfe in enumerate(self.vfe_layers):
features = vfe(features)
point_feats.append(features)
if i == self.num_vfe - 1:
voxel_feats, voxel_coors = self.vfe_scatter(features, coors)
if self.compression_layers is not None:
voxel_feats = self.compression_layers(voxel_feats)
if self.return_point_feats:
return voxel_feats, voxel_coors, point_feats
return voxel_feats, voxel_coors
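An end-to-end sketch of SegVFE in the dynamic pipeline (it mirrors `test_seg_VFE` further down; CUDA is required by `DynamicScatter`):

```python
import torch
import torch.nn.functional as F
from mmdet3d.registry import MODELS

vfe = MODELS.build(
    dict(
        type='SegVFE',
        feat_channels=[64, 128, 256, 256],
        grid_shape=[480, 360, 32],
        with_voxel_center=True,
        feat_compression=16)).cuda()
feats = torch.rand(1000, 6).cuda()  # rho, phi, z, x, y, reflectivity
coors = F.pad(torch.randint(0, 10, (1000, 3)), (1, 0), value=0).cuda()
voxel_feats, voxel_coors = vfe(feats, coors)
assert voxel_feats.shape[-1] == 16  # compressed by feat_compression
```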
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import pytest
import torch
from mmdet3d.models.data_preprocessors import Det3DDataPreprocessor
from mmdet3d.structures import Det3DDataSample
from mmdet3d.structures import Det3DDataSample, PointData
class TestDet3DDataPreprocessor(TestCase):
@@ -95,3 +96,33 @@ class TestDet3DDataPreprocessor(TestCase):
for data_sample, expected_shape in zip(batch_data_samples, [(10, 15),
(10, 25)]):
self.assertEqual(data_sample.pad_shape, expected_shape)
# test cylindrical voxelization
if not torch.cuda.is_available():
pytest.skip('test requires GPU and CUDA')
point_cloud_range = [0, -180, -4, 50, 180, 2]
grid_shape = [480, 360, 32]
voxel_layer = dict(
grid_shape=grid_shape,
point_cloud_range=point_cloud_range,
max_num_points=-1,
max_voxels=-1)
processor = Det3DDataPreprocessor(
voxel=True, voxel_type='cylindrical',
voxel_layer=voxel_layer).cuda()
num_points = 5000
xy = torch.rand(num_points, 2) * 140 - 70
z = torch.rand(num_points, 1) * 9 - 6
ref = torch.rand(num_points, 1)
points = [torch.cat([xy, z, ref], dim=-1)] * 2
data_sample = Det3DDataSample()
gt_pts_seg = PointData()
gt_pts_seg.pts_semantic_mask = torch.randint(0, 10, (num_points, ))
data_sample.gt_pts_seg = gt_pts_seg
data_samples = [data_sample] * 2
inputs = dict(inputs=dict(points=points), data_samples=data_samples)
out_data = processor(inputs)
batch_inputs, batch_data_samples = out_data['inputs'], out_data[
'data_samples']
self.assertEqual(batch_inputs['voxels']['voxels'].shape, (10000, 6))
self.assertEqual(batch_inputs['voxels']['coors'].shape, (10000, 4))
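Shape arithmetic behind the two assertions above, as a comment-level check:

```python
# 2 samples x 5000 points -> 10000 rows: dynamic voxelization keeps every point.
# Features per row: (rho, phi, z) + (x, y) + (reflectivity,) = 6 channels.
# Coors per row: (batch_idx, rho_bin, phi_bin, z_bin) = 4 columns.
assert 2 * 5000 == 10000 and 3 + 2 + 1 == 6 and 1 + 3 == 4
```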
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
import torch.nn.functional as F
from mmdet3d.registry import MODELS
@@ -15,3 +16,31 @@ def test_hard_simple_VFE():
outputs = hard_simple_VFE(features, num_voxels, None)
assert outputs.shape == torch.Size([240000, 5])
def test_seg_VFE():
if not torch.cuda.is_available():
pytest.skip('test requires GPU and torch+cuda')
seg_VFE_cfg = dict(
type='SegVFE',
feat_channels=[64, 128, 256, 256],
grid_shape=[480, 360, 32],
with_voxel_center=True,
feat_compression=16,
return_point_feats=True)
seg_VFE = MODELS.build(seg_VFE_cfg)
seg_VFE = seg_VFE.cuda()
features = torch.rand([240000, 6]).cuda()
coors = []
for i in range(4):
coor = torch.randint(0, 10, (60000, 3))
coor = F.pad(coor, (1, 0), mode='constant', value=i)
coors.append(coor)
coors = torch.cat(coors, dim=0).cuda()
out_features, out_coors, out_point_features = seg_VFE(features, coors)
assert out_features.shape[0] == out_coors.shape[0]
assert len(out_point_features) == 4
assert out_point_features[0].shape == torch.Size([240000, 64])
assert out_point_features[1].shape == torch.Size([240000, 128])
assert out_point_features[2].shape == torch.Size([240000, 256])
assert out_point_features[3].shape == torch.Size([240000, 256])