Unverified Commit 23768cba authored by encore-zhou's avatar encore-zhou Committed by GitHub
Browse files

[Enhance] Use Points structure in augmentation and models (#204)

* add h3d backbone

* add h3d backbone

* add h3dnet

* modify scannet config

* fix bugs for proposal refine

* fix bugs for test backbone

* add primitive head test

* modify h3dhead

* modify h3d head

* update loss weight config

* fix bugs for h3d head loss

* modify h3d head get targets function

* update h3dnet base config

* modify weighted loss

* Revert "Merge branch 'h3d_u2' into 'master'"

This reverts merge request !5

* modify pipeline

* modify kitti pipeline

* fix bugs for points rotation

* modify multi sweeps

* modify multi sweep points

* fix bugs for points slicing

* modify BackgroundPointsFilter

* modify pipeline

* modify unittest

* modify unittest

* modify docstring

* modify config files

* update configs

* modify docstring
parent a97fc87b
......@@ -37,6 +37,7 @@ file_client_args = dict(
train_pipeline = [
dict(
type='LoadPointsFromFile',
coord_type='LIDAR',
load_dim=5,
use_dim=5,
file_client_args=file_client_args),
......
......@@ -37,6 +37,7 @@ file_client_args = dict(
train_pipeline = [
dict(
type='LoadPointsFromFile',
coord_type='LIDAR',
load_dim=5,
use_dim=5,
file_client_args=file_client_args),
......
......@@ -135,7 +135,7 @@ img_norm_cfg = dict(
mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False)
input_modality = dict(use_lidar=True, use_camera=True)
train_pipeline = [
dict(type='LoadPointsFromFile', load_dim=4, use_dim=4),
dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4),
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
dict(
......@@ -160,7 +160,7 @@ train_pipeline = [
keys=['points', 'img', 'gt_bboxes_3d', 'gt_labels_3d']),
]
test_pipeline = [
dict(type='LoadPointsFromFile', load_dim=4, use_dim=4),
dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4),
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug3D',
......
......@@ -210,7 +210,7 @@ db_sampler = dict(
classes=class_names,
sample_groups=dict(Car=12, Pedestrian=6, Cyclist=6))
train_pipeline = [
dict(type='LoadPointsFromFile', load_dim=4, use_dim=4),
dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4),
dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
dict(type='ObjectSample', db_sampler=db_sampler),
dict(
......@@ -232,7 +232,7 @@ train_pipeline = [
dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
]
test_pipeline = [
dict(type='LoadPointsFromFile', load_dim=4, use_dim=4),
dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4),
dict(
type='MultiScaleFlipAug3D',
img_scale=(1333, 800),
......
......@@ -82,7 +82,7 @@ db_sampler = dict(
classes=class_names,
sample_groups=dict(Car=15))
train_pipeline = [
dict(type='LoadPointsFromFile', load_dim=4, use_dim=4),
dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4),
dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
dict(type='ObjectSample', db_sampler=db_sampler),
dict(
......@@ -104,7 +104,7 @@ train_pipeline = [
dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
]
test_pipeline = [
dict(type='LoadPointsFromFile', load_dim=4, use_dim=4),
dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4),
dict(
type='MultiScaleFlipAug3D',
img_scale=(1333, 800),
......
......@@ -21,7 +21,7 @@ db_sampler = dict(
# PointPillars uses different augmentation hyper parameters
train_pipeline = [
dict(type='LoadPointsFromFile', load_dim=4, use_dim=4),
dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4),
dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
dict(type='ObjectSample', db_sampler=db_sampler),
dict(
......@@ -42,7 +42,7 @@ train_pipeline = [
dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
]
test_pipeline = [
dict(type='LoadPointsFromFile', load_dim=4, use_dim=4),
dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4),
dict(
type='MultiScaleFlipAug3D',
img_scale=(1333, 800),
......
......@@ -40,7 +40,7 @@ db_sampler = dict(
classes=class_names)
train_pipeline = [
dict(type='LoadPointsFromFile', load_dim=4, use_dim=4),
dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4),
dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
dict(type='ObjectSample', db_sampler=db_sampler),
dict(
......@@ -61,7 +61,7 @@ train_pipeline = [
dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
]
test_pipeline = [
dict(type='LoadPointsFromFile', load_dim=4, use_dim=4),
dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4),
dict(
type='MultiScaleFlipAug3D',
img_scale=(1333, 800),
......
......@@ -24,7 +24,7 @@ db_sampler = dict(
type='LoadPointsFromFile', load_dim=5, use_dim=[0, 1, 2, 3, 4]))
train_pipeline = [
dict(type='LoadPointsFromFile', load_dim=6, use_dim=5),
dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=6, use_dim=5),
dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
dict(type='ObjectSample', db_sampler=db_sampler),
dict(
......@@ -44,7 +44,7 @@ train_pipeline = [
]
test_pipeline = [
dict(type='LoadPointsFromFile', load_dim=6, use_dim=5),
dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=6, use_dim=5),
dict(
type='MultiScaleFlipAug3D',
img_scale=(1333, 800),
......
......@@ -13,7 +13,7 @@ class_names = [
]
train_pipeline = [
dict(type='LoadPointsFromFile', load_dim=5, use_dim=5),
dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=5, use_dim=5),
dict(type='LoadPointsFromMultiSweeps', sweeps_num=10),
dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
dict(
......@@ -33,7 +33,7 @@ train_pipeline = [
dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
]
test_pipeline = [
dict(type='LoadPointsFromFile', load_dim=5, use_dim=5),
dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=5, use_dim=5),
dict(type='LoadPointsFromMultiSweeps', sweeps_num=10),
dict(
type='MultiScaleFlipAug3D',
......
......@@ -13,7 +13,7 @@ class_names = [
]
train_pipeline = [
dict(type='LoadPointsFromFile', load_dim=5, use_dim=5),
dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=5, use_dim=5),
dict(type='LoadPointsFromMultiSweeps', sweeps_num=10),
dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
dict(
......@@ -33,7 +33,7 @@ train_pipeline = [
dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
]
test_pipeline = [
dict(type='LoadPointsFromFile', load_dim=5, use_dim=5),
dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=5, use_dim=5),
dict(type='LoadPointsFromMultiSweeps', sweeps_num=10),
dict(
type='MultiScaleFlipAug3D',
......
import numpy as np
import torch
from mmdet3d.core.points import BasePoints
from .base_box3d import BaseInstance3DBoxes
from .utils import limit_period, rotation_3d_in_axis
......@@ -96,7 +97,8 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes):
@property
def corners(self):
"""torch.Tensor: Coordinates of corners of all the boxes in shape (N, 8, 3).
"""torch.Tensor: Coordinates of corners of all the boxes in
shape (N, 8, 3).
Convert the boxes to in clockwise order, in the form of
(x0y0z0, x0y0z1, x0y1z1, x0y1z0, x1y0z0, x1y0z1, x1y1z1, x1y1z0)
......@@ -168,8 +170,8 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes):
Args:
angle (float, torch.Tensor): Rotation angle.
points (torch.Tensor, numpy.ndarray, optional): Points to rotate.
Defaults to None.
points (torch.Tensor, numpy.ndarray, :obj:`BasePoints`, optional):
Points to rotate. Defaults to None.
Returns:
tuple or None: When ``points`` is None, the function returns \
......@@ -192,6 +194,9 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes):
elif isinstance(points, np.ndarray):
rot_mat_T = rot_mat_T.numpy()
points[:, :3] = np.dot(points[:, :3], rot_mat_T)
elif isinstance(points, BasePoints):
# clockwise
points.rotate(-angle)
else:
raise ValueError
return points, rot_mat_T
......@@ -203,8 +208,8 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes):
Args:
bev_direction (str): Flip direction (horizontal or vertical).
points (torch.Tensor, numpy.ndarray, None): Points to flip.
Defaults to None.
points (torch.Tensor, numpy.ndarray, :obj:`BasePoints`, None):
Points to flip. Defaults to None.
Returns:
torch.Tensor, numpy.ndarray or None: Flipped points.
......@@ -220,11 +225,14 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes):
self.tensor[:, 6] = -self.tensor[:, 6]
if points is not None:
assert isinstance(points, (torch.Tensor, np.ndarray))
assert isinstance(points, (torch.Tensor, np.ndarray, BasePoints))
if isinstance(points, (torch.Tensor, np.ndarray)):
if bev_direction == 'horizontal':
points[:, 0] = -points[:, 0]
elif bev_direction == 'vertical':
points[:, 2] = -points[:, 2]
elif isinstance(points, BasePoints):
points.flip(bev_direction)
return points
def in_range_bev(self, box_range):
......
import numpy as np
import torch
from mmdet3d.core.points import BasePoints
from mmdet3d.ops import points_in_boxes_batch
from .base_box3d import BaseInstance3DBoxes
from .utils import limit_period, rotation_3d_in_axis
......@@ -114,8 +115,8 @@ class DepthInstance3DBoxes(BaseInstance3DBoxes):
Args:
angle (float, torch.Tensor): Rotation angle.
points (torch.Tensor, numpy.ndarray, optional): Points to rotate.
Defaults to None.
points (torch.Tensor, numpy.ndarray, :obj:`BasePoints`, optional):
Points to rotate. Defaults to None.
Returns:
tuple or None: When ``points`` is None, the function returns \
......@@ -148,6 +149,9 @@ class DepthInstance3DBoxes(BaseInstance3DBoxes):
elif isinstance(points, np.ndarray):
rot_mat_T = rot_mat_T.numpy()
points[:, :3] = np.dot(points[:, :3], rot_mat_T)
elif isinstance(points, BasePoints):
# anti-clockwise
points.rotate(angle)
else:
raise ValueError
return points, rot_mat_T
......@@ -159,8 +163,8 @@ class DepthInstance3DBoxes(BaseInstance3DBoxes):
Args:
bev_direction (str): Flip direction (horizontal or vertical).
points (torch.Tensor, numpy.ndarray, None): Points to flip.
Defaults to None.
points (torch.Tensor, numpy.ndarray, :obj:`BasePoints`, None):
Points to flip. Defaults to None.
Returns:
torch.Tensor, numpy.ndarray or None: Flipped points.
......@@ -176,11 +180,14 @@ class DepthInstance3DBoxes(BaseInstance3DBoxes):
self.tensor[:, 6] = -self.tensor[:, 6]
if points is not None:
assert isinstance(points, (torch.Tensor, np.ndarray))
assert isinstance(points, (torch.Tensor, np.ndarray, BasePoints))
if isinstance(points, (torch.Tensor, np.ndarray)):
if bev_direction == 'horizontal':
points[:, 0] = -points[:, 0]
elif bev_direction == 'vertical':
points[:, 1] = -points[:, 1]
elif isinstance(points, BasePoints):
points.flip(bev_direction)
return points
def in_range_bev(self, box_range):
......
import numpy as np
import torch
from mmdet3d.core.points import BasePoints
from mmdet3d.ops.roiaware_pool3d import points_in_boxes_gpu
from .base_box3d import BaseInstance3DBoxes
from .utils import limit_period, rotation_3d_in_axis
......@@ -114,8 +115,8 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes):
Args:
angle (float | torch.Tensor): Rotation angle.
points (torch.Tensor, numpy.ndarray, optional): Points to rotate.
Defaults to None.
points (torch.Tensor, numpy.ndarray, :obj:`BasePoints`, optional):
Points to rotate. Defaults to None.
Returns:
tuple or None: When ``points`` is None, the function returns \
......@@ -142,6 +143,9 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes):
elif isinstance(points, np.ndarray):
rot_mat_T = rot_mat_T.numpy()
points[:, :3] = np.dot(points[:, :3], rot_mat_T)
elif isinstance(points, BasePoints):
# clockwise
points.rotate(-angle)
else:
raise ValueError
return points, rot_mat_T
......@@ -153,8 +157,8 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes):
Args:
bev_direction (str): Flip direction (horizontal or vertical).
points (torch.Tensor, numpy.ndarray, None): Points to flip.
Defaults to None.
points (torch.Tensor, numpy.ndarray, :obj:`BasePoints`, None):
Points to flip. Defaults to None.
Returns:
torch.Tensor, numpy.ndarray or None: Flipped points.
......@@ -170,11 +174,14 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes):
self.tensor[:, 6] = -self.tensor[:, 6]
if points is not None:
assert isinstance(points, (torch.Tensor, np.ndarray))
assert isinstance(points, (torch.Tensor, np.ndarray, BasePoints))
if isinstance(points, (torch.Tensor, np.ndarray)):
if bev_direction == 'horizontal':
points[:, 1] = -points[:, 1]
elif bev_direction == 'vertical':
points[:, 0] = -points[:, 0]
elif isinstance(points, BasePoints):
points.flip(bev_direction)
return points
def in_range_bev(self, box_range):
......
......@@ -4,3 +4,26 @@ from .depth_points import DepthPoints
from .lidar_points import LiDARPoints
__all__ = ['BasePoints', 'CameraPoints', 'DepthPoints', 'LiDARPoints']
def get_points_type(points_type):
    """Get the points class that matches a coordinate type.

    Args:
        points_type (str): The type of points coordinate.
            Valid values are "CAMERA", "LIDAR" and "DEPTH".

    Returns:
        type: The corresponding points class.

    Raises:
        ValueError: If ``points_type`` is not one of the supported values.
    """
    # Dispatch with early returns; fall through to the error for
    # anything unrecognized.
    if points_type == 'CAMERA':
        return CameraPoints
    if points_type == 'LIDAR':
        return LiDARPoints
    if points_type == 'DEPTH':
        return DepthPoints
    raise ValueError('Only "points_type" of "CAMERA", "LIDAR", or "DEPTH"'
                     f' are supported, got {points_type}')
import numpy as np
import torch
from abc import abstractmethod
......@@ -18,6 +19,7 @@ class BasePoints(object):
Each row is (x, y, z, ...).
attribute_dims (dict): Dictionary to indicate the meaning of extra
dimension. Defaults to None.
rotation_axis (int): Default rotation axis for points rotation.
"""
def __init__(self, tensor, points_dim=3, attribute_dims=None):
......@@ -37,6 +39,7 @@ class BasePoints(object):
self.tensor = tensor
self.points_dim = points_dim
self.attribute_dims = attribute_dims
self.rotation_axis = 0
@property
def coord(self):
......@@ -61,24 +64,32 @@ class BasePoints(object):
else:
return None
@property
def shape(self):
    """torch.Size: Shape of the points tensor, i.e. (N, points_dim)."""
    # Delegate directly to the underlying storage tensor.
    tensor_shape = self.tensor.shape
    return tensor_shape
def shuffle(self):
    """Shuffle the points in-place along the first (point) dimension."""
    # Draw a random permutation on the same device as the point tensor
    # and reindex the rows with it.
    perm = torch.randperm(len(self), device=self.tensor.device)
    self.tensor = self.tensor[perm]
def rotate(self, rotation, axis=2):
def rotate(self, rotation, axis=None):
"""Rotate points with the given rotation matrix or angle.
Args:
rotation (float, np.ndarray, torch.Tensor): Rotation matrix
or angle.
axis (int): Axis to rotate at. Defaults to 2.
axis (int): Axis to rotate at. Defaults to None.
"""
if not isinstance(rotation, torch.Tensor):
rotation = self.tensor.new_tensor(rotation)
assert rotation.shape == torch.Size([3, 3]) or \
rotation.numel() == 1
if axis is None:
axis = self.rotation_axis
if rotation.numel() == 1:
rot_sin = torch.sin(rotation)
rot_cos = torch.cos(rotation)
......@@ -204,12 +215,14 @@ class BasePoints(object):
3. `new_points = points[vector]`:
where vector is a torch.BoolTensor with `length = len(points)`.
Nonzero elements in the vector will be selected.
4. `new_points = points[3:11, vector]`:
return a slice of points and attribute dims.
Note that the returned Points might share storage with this Points,
subject to Pytorch's indexing semantics.
Returns:
:obj:`BaseInstancesPints`: A new object of \
:class:`BaseInstancesPints` after indexing.
:obj:`BasePoints`: A new object of \
:class:`BasePoints` after indexing.
"""
original_type = type(self)
if isinstance(item, int):
......@@ -217,11 +230,43 @@ class BasePoints(object):
self.tensor[item].view(1, -1),
points_dim=self.points_dim,
attribute_dims=self.attribute_dims)
elif isinstance(item, tuple) and len(item) == 2:
if isinstance(item[1], slice):
start = 0 if item[1].start is None else item[1].start
stop = self.tensor.shape[1] + \
1 if item[1].stop is None else item[1].stop
step = 1 if item[1].step is None else item[1].step
item[1] = list(range(start, stop, step))
p = self.tensor[item[0], item[1]]
keep_dims = list(
set(item[1]).intersection(set(range(3, self.tensor.shape[1]))))
if self.attribute_dims is not None:
attribute_dims = self.attribute_dims.copy()
for key in self.attribute_dims.keys():
cur_attribute_dim = attribute_dims[key]
if isinstance(cur_attribute_dim, int):
cur_attribute_dims = [cur_attribute_dim]
intersect_attr = list(
set(cur_attribute_dims).intersection(set(keep_dims)))
if len(intersect_attr) == 1:
attribute_dims[key] = intersect_attr[0]
elif len(intersect_attr) > 1:
attribute_dims[key] = intersect_attr
else:
attribute_dims.pop(key)
else:
attribute_dims = None
elif isinstance(item, (slice, np.ndarray, torch.Tensor)):
p = self.tensor[item]
attribute_dims = self.attribute_dims
else:
raise NotImplementedError(f'Invalid slice {item}!')
assert p.dim() == 2, \
f'Indexing on Points with {item} failed to return a matrix!'
return original_type(
p, points_dim=self.points_dim, attribute_dims=self.attribute_dims)
p, points_dim=p.shape[1], attribute_dims=attribute_dims)
def __len__(self):
"""int: Number of points in the current object."""
......@@ -236,10 +281,10 @@ class BasePoints(object):
"""Concatenate a list of Points into a single Points.
Args:
points_list (list[:obj:`BaseInstancesPoints`]): List of points.
points_list (list[:obj:`BasePoints`]): List of points.
Returns:
:obj:`BaseInstancesPoints`: The concatenated Points.
:obj:`BasePoints`: The concatenated Points.
"""
assert isinstance(points_list, (list, tuple))
if len(points_list) == 0:
......
......@@ -17,11 +17,13 @@ class CameraPoints(BasePoints):
Each row is (x, y, z, ...).
attribute_dims (dict): Dictionary to indicate the meaning of extra
dimension. Defaults to None.
rotation_axis (int): Default rotation axis for points rotation.
"""
def __init__(self, tensor, points_dim=3, attribute_dims=None):
    """Initialize camera-coordinate points.

    All arguments are forwarded unchanged to :class:`BasePoints`.

    Args:
        tensor: Point data, one row per point.
        points_dim (int): Number of values per point row. Defaults to 3.
        attribute_dims (dict, optional): Mapping that names the extra
            (non-coordinate) dimensions. Defaults to None.
    """
    super().__init__(
        tensor, points_dim=points_dim, attribute_dims=attribute_dims)
    # In camera coordinates the default rotation is about axis 1
    # (the y axis of each (x, y, z, ...) row).
    self.rotation_axis = 1
def flip(self, bev_direction='horizontal'):
"""Flip the boxes in BEV along given BEV direction."""
......
......@@ -17,11 +17,13 @@ class DepthPoints(BasePoints):
Each row is (x, y, z, ...).
attribute_dims (dict): Dictionary to indicate the meaning of extra
dimension. Defaults to None.
rotation_axis (int): Default rotation axis for points rotation.
"""
def __init__(self, tensor, points_dim=3, attribute_dims=None):
    """Initialize depth-coordinate points.

    All arguments are forwarded unchanged to :class:`BasePoints`.

    Args:
        tensor: Point data, one row per point.
        points_dim (int): Number of values per point row. Defaults to 3.
        attribute_dims (dict, optional): Mapping that names the extra
            (non-coordinate) dimensions. Defaults to None.
    """
    super().__init__(
        tensor, points_dim=points_dim, attribute_dims=attribute_dims)
    # In depth coordinates the default rotation is about axis 2
    # (the z axis of each (x, y, z, ...) row).
    self.rotation_axis = 2
def flip(self, bev_direction='horizontal'):
"""Flip the boxes in BEV along given BEV direction."""
......
......@@ -17,11 +17,13 @@ class LiDARPoints(BasePoints):
Each row is (x, y, z, ...).
attribute_dims (dict): Dictionary to indicate the meaning of extra
dimension. Defaults to None.
rotation_axis (int): Default rotation axis for points rotation.
"""
def __init__(self, tensor, points_dim=3, attribute_dims=None):
    """Initialize LiDAR-coordinate points.

    All arguments are forwarded unchanged to :class:`BasePoints`.

    Args:
        tensor: Point data, one row per point.
        points_dim (int): Number of values per point row. Defaults to 3.
        attribute_dims (dict, optional): Mapping that names the extra
            (non-coordinate) dimensions. Defaults to None.
    """
    super().__init__(
        tensor, points_dim=points_dim, attribute_dims=attribute_dims)
    # In LiDAR coordinates the default rotation is about axis 2
    # (the z axis of each (x, y, z, ...) row).
    self.rotation_axis = 2
def flip(self, bev_direction='horizontal'):
"""Flip the boxes in BEV along given BEV direction."""
......
......@@ -100,6 +100,7 @@ class DataBaseSampler(object):
classes=None,
points_loader=dict(
type='LoadPointsFromFile',
coord_type='LIDAR',
load_dim=4,
use_dim=[0, 1, 2, 3])):
super().__init__()
......@@ -253,7 +254,7 @@ class DataBaseSampler(object):
info['path']) if self.data_root else info['path']
results = dict(pts_filename=file_path)
s_points = self.points_loader(results)['points']
s_points[:, :3] += info['box3d_lidar'][:3]
s_points.translate(info['box3d_lidar'][:3])
count += 1
......@@ -267,7 +268,7 @@ class DataBaseSampler(object):
'gt_bboxes_3d':
sampled_gt_bboxes,
'points':
np.concatenate(s_points_list, axis=0),
s_points_list[0].cat(s_points_list),
'group_ids':
np.arange(gt_bboxes.shape[0],
gt_bboxes.shape[0] + len(sampled))
......
......@@ -2,6 +2,7 @@ import numpy as np
from mmcv.parallel import DataContainer as DC
from mmdet3d.core.bbox import BaseInstance3DBoxes
from mmdet3d.core.points import BasePoints
from mmdet.datasets.builder import PIPELINES
from mmdet.datasets.pipelines import to_tensor
......@@ -71,6 +72,7 @@ class DefaultFormatBundle(object):
if 'gt_semantic_seg' in results:
results['gt_semantic_seg'] = DC(
to_tensor(results['gt_semantic_seg'][None, ...]), stack=True)
return results
def __repr__(self):
......@@ -202,9 +204,11 @@ class DefaultFormatBundle3D(DefaultFormatBundle):
default bundle.
"""
# Format 3D data
for key in [
'voxels', 'coors', 'voxel_centers', 'num_points', 'points'
]:
if 'points' in results:
assert isinstance(results['points'], BasePoints)
results['points'] = DC(results['points'].tensor)
for key in ['voxels', 'coors', 'voxel_centers', 'num_points']:
if key not in results:
continue
results[key] = DC(to_tensor(results[key]), stack=False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment