Unverified Commit be8f6538 authored by Wenbo Yu, committed by GitHub

[Feature] Support SA-SSD (#1337)



* sassd origin

* sassd for merge

* Fix flake8 coding style warnings and errors

* Fix flake8 indent style

* Fix no newline at end of line issue

* fix mmdet3d.ops error

* Fix flake8 over-indented

* Change points_ops format

* import format updated

* Fix import isort issue

* Fix yapf format

* classes modification after first review

* tmp

* solved mmdet3d update issue

* fixed comment style and lint errors

* fix isort error

* fix yapf error

* Modify comments

* solve merge conflict

* unit test added

* Fix yapf error

* fix lint format

* fix loss format error
Co-authored-by: Yi-Chen Zhang <yi-chen.zhang@isza.com>
parent c232b7a4
_base_ = [
'../_base_/datasets/kitti-3d-3class.py',
'../_base_/schedules/cyclic_40e.py', '../_base_/default_runtime.py'
]
voxel_size = [0.05, 0.05, 0.1]
model = dict(
type='SASSD',
voxel_layer=dict(
max_num_points=5,
point_cloud_range=[0, -40, -3, 70.4, 40, 1],
voxel_size=voxel_size,
max_voxels=(16000, 40000)),
voxel_encoder=dict(type='HardSimpleVFE'),
middle_encoder=dict(
type='SparseEncoderSASSD',
in_channels=4,
sparse_shape=[41, 1600, 1408],
order=('conv', 'norm', 'act')),
backbone=dict(
type='SECOND',
in_channels=256,
layer_nums=[5, 5],
layer_strides=[1, 2],
out_channels=[128, 256]),
neck=dict(
type='SECONDFPN',
in_channels=[128, 256],
upsample_strides=[1, 2],
out_channels=[256, 256]),
bbox_head=dict(
type='Anchor3DHead',
num_classes=3,
in_channels=512,
feat_channels=512,
use_direction_classifier=True,
anchor_generator=dict(
type='Anchor3DRangeGenerator',
ranges=[
[0, -40.0, -0.6, 70.4, 40.0, -0.6],
[0, -40.0, -0.6, 70.4, 40.0, -0.6],
[0, -40.0, -1.78, 70.4, 40.0, -1.78],
],
sizes=[[0.6, 0.8, 1.73], [0.6, 1.76, 1.73], [1.6, 3.9, 1.56]],
rotations=[0, 1.57],
reshape_out=False),
diff_rad_by_sin=True,
bbox_coder=dict(type='DeltaXYZWLHRBBoxCoder'),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=2.0),
loss_dir=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.2)),
# model training and testing settings
train_cfg=dict(
assigner=[
dict( # for Pedestrian
type='MaxIoUAssigner',
iou_calculator=dict(type='BboxOverlapsNearest3D'),
pos_iou_thr=0.35,
neg_iou_thr=0.2,
min_pos_iou=0.2,
ignore_iof_thr=-1),
dict( # for Cyclist
type='MaxIoUAssigner',
iou_calculator=dict(type='BboxOverlapsNearest3D'),
pos_iou_thr=0.35,
neg_iou_thr=0.2,
min_pos_iou=0.2,
ignore_iof_thr=-1),
dict( # for Car
type='MaxIoUAssigner',
iou_calculator=dict(type='BboxOverlapsNearest3D'),
pos_iou_thr=0.6,
neg_iou_thr=0.45,
min_pos_iou=0.45,
ignore_iof_thr=-1),
],
allowed_border=0,
pos_weight=-1,
debug=False),
test_cfg=dict(
use_rotate_nms=True,
nms_across_levels=False,
nms_thr=0.01,
score_thr=0.1,
min_bbox_size=0,
nms_pre=100,
max_num=50))
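
# A minimal usage sketch (assuming this file is saved as
# configs/sassd/sassd_6x8_80e_kitti-3d-3class.py, the path the unit tests
# below load, and that KITTI is prepared under data/kitti as in the
# MMDetection3D docs), training would be launched with the standard tool:
#   python tools/train.py configs/sassd/sassd_6x8_80e_kitti-3d-3class.py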
@@ -11,6 +1,7 @@ from .mvx_faster_rcnn import DynamicMVXFasterRCNN, MVXFasterRCNN
from .mvx_two_stage import MVXTwoStageDetector
from .parta2 import PartA2
from .point_rcnn import PointRCNN
from .sassd import SASSD
from .single_stage_mono3d import SingleStageMono3DDetector
from .smoke_mono3d import SMOKEMono3D
from .ssd3dnet import SSD3DNet
@@ -21,5 +22,6 @@ __all__ = [
'Base3DDetector', 'VoxelNet', 'DynamicVoxelNet', 'MVXTwoStageDetector',
'DynamicMVXFasterRCNN', 'MVXFasterRCNN', 'PartA2', 'VoteNet', 'H3DNet',
'CenterPoint', 'SSD3DNet', 'ImVoteNet', 'SingleStageMono3DDetector',
'FCOSMono3D', 'ImVoxelNet', 'GroupFree3DNet', 'PointRCNN', 'SMOKEMono3D',
'SASSD'
]
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.ops import Voxelization
from mmcv.runner import force_fp32
from torch.nn import functional as F
from mmdet3d.core import bbox3d2result, merge_aug_bboxes_3d
from mmdet.models.builder import DETECTORS
from .. import builder
from .single_stage import SingleStage3DDetector
@DETECTORS.register_module()
class SASSD(SingleStage3DDetector):
r"""`SASSD <https://github.com/skyhehe123/SA-SSD>` _ for 3D detection."""
def __init__(self,
voxel_layer,
voxel_encoder,
middle_encoder,
backbone,
neck=None,
bbox_head=None,
train_cfg=None,
test_cfg=None,
init_cfg=None,
pretrained=None):
super(SASSD, self).__init__(
backbone=backbone,
neck=neck,
bbox_head=bbox_head,
train_cfg=train_cfg,
test_cfg=test_cfg,
init_cfg=init_cfg,
pretrained=pretrained)
self.voxel_layer = Voxelization(**voxel_layer)
self.voxel_encoder = builder.build_voxel_encoder(voxel_encoder)
self.middle_encoder = builder.build_middle_encoder(middle_encoder)
def extract_feat(self, points, img_metas=None, test_mode=False):
"""Extract features from points."""
voxels, num_points, coors = self.voxelize(points)
voxel_features = self.voxel_encoder(voxels, num_points, coors)
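        # coors[:, 0] holds the sample (batch) index; samples are concatenated
        # in order, so the last index + 1 gives the batch size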
batch_size = coors[-1, 0].item() + 1
x, point_misc = self.middle_encoder(voxel_features, coors, batch_size,
test_mode)
x = self.backbone(x)
if self.with_neck:
x = self.neck(x)
return x, point_misc
@torch.no_grad()
@force_fp32()
def voxelize(self, points):
"""Apply hard voxelization to points."""
voxels, coors, num_points = [], [], []
for res in points:
res_voxels, res_coors, res_num_points = self.voxel_layer(res)
voxels.append(res_voxels)
coors.append(res_coors)
num_points.append(res_num_points)
voxels = torch.cat(voxels, dim=0)
num_points = torch.cat(num_points, dim=0)
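        # prepend each sample's batch index to its (z, y, x) voxel coordinates
        # so voxels from all samples can share a single (N, 4) tensor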
coors_batch = []
for i, coor in enumerate(coors):
coor_pad = F.pad(coor, (1, 0), mode='constant', value=i)
coors_batch.append(coor_pad)
coors_batch = torch.cat(coors_batch, dim=0)
return voxels, num_points, coors_batch
def forward_train(self,
points,
img_metas,
gt_bboxes_3d,
gt_labels_3d,
gt_bboxes_ignore=None):
"""Training forward function.
Args:
points (list[torch.Tensor]): Point cloud of each sample.
            img_metas (list[dict]): Meta information of each sample.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
                boxes for each sample.
            gt_labels_3d (list[torch.Tensor]): Ground truth labels for
                boxes of each sample.
gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
boxes to be ignored. Defaults to None.
Returns:
dict: Losses of each branch.
"""
x, point_misc = self.extract_feat(points, img_metas, test_mode=False)
aux_loss = self.middle_encoder.aux_loss(*point_misc, gt_bboxes_3d)
outs = self.bbox_head(x)
loss_inputs = outs + (gt_bboxes_3d, gt_labels_3d, img_metas)
losses = self.bbox_head.loss(
*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
losses.update(aux_loss)
return losses
def simple_test(self, points, img_metas, imgs=None, rescale=False):
"""Test function without augmentaiton."""
x, _ = self.extract_feat(points, img_metas, test_mode=True)
outs = self.bbox_head(x)
bbox_list = self.bbox_head.get_bboxes(
*outs, img_metas, rescale=rescale)
bbox_results = [
bbox3d2result(bboxes, scores, labels)
for bboxes, scores, labels in bbox_list
]
return bbox_results
def aug_test(self, points, img_metas, imgs=None, rescale=False):
"""Test function with augmentaiton."""
feats = self.extract_feats(points, img_metas, test_mode=True)
# only support aug_test for one sample
aug_bboxes = []
for x, img_meta in zip(feats, img_metas):
outs = self.bbox_head(x)
bbox_list = self.bbox_head.get_bboxes(
*outs, img_meta, rescale=rescale)
bbox_list = [
dict(boxes_3d=bboxes, scores_3d=scores, labels_3d=labels)
for bboxes, scores, labels in bbox_list
]
aug_bboxes.append(bbox_list[0])
# after merging, bboxes will be rescaled to the original image size
merged_bboxes = merge_aug_bboxes_3d(aug_bboxes, img_metas,
self.bbox_head.test_cfg)
return [merged_bboxes]
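
# A minimal inference sketch (hypothetical checkpoint path; assumes the
# config added above and MMDetection3D's high-level APIs):
#   from mmdet3d.apis import init_model, inference_detector
#   model = init_model('configs/sassd/sassd_6x8_80e_kitti-3d-3class.py',
#                      'checkpoints/sassd.pth', device='cuda:0')
#   result, data = inference_detector(model, 'demo/data/kitti/kitti_000008.bin')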
# Copyright (c) OpenMMLab. All rights reserved.
from .pillar_scatter import PointPillarsScatter
from .sparse_encoder import SparseEncoder, SparseEncoderSASSD
from .sparse_unet import SparseUNet
__all__ = [
'PointPillarsScatter', 'SparseEncoder', 'SparseEncoderSASSD', 'SparseUNet'
]
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.ops import points_in_boxes_all, three_interpolate, three_nn
from mmcv.runner import auto_fp16
from torch import nn as nn
from mmdet3d.ops import SparseBasicBlock, make_sparse_convmodule
from mmdet3d.ops.spconv import IS_SPCONV2_AVAILABLE
from mmdet.models.losses import sigmoid_focal_loss, smooth_l1_loss
from ..builder import MIDDLE_ENCODERS
if IS_SPCONV2_AVAILABLE:
@@ -30,9 +32,10 @@ class SparseEncoder(nn.Module):
Defaults to 128.
encoder_channels (tuple[tuple[int]], optional):
Convolutional channels of each encode block.
Defaults to ((16, ), (32, 32, 32), (64, 64, 64), (64, 64, 64)).
encoder_paddings (tuple[tuple[int]], optional):
Paddings of each encode block.
Defaults to ((1, ), (1, 1, 1), (1, 1, 1), ((0, 1, 1), 1, 1)).
block_type (str, optional): Type of the block to use.
Defaults to 'conv_module'.
"""
@@ -106,8 +109,8 @@
"""Forward of SparseEncoder.
Args:
voxel_features (torch.Tensor): Voxel features in shape (N, C).
coors (torch.Tensor): Coordinates in shape (N, 4),
the columns in the order of (batch_idx, z_idx, y_idx, x_idx).
batch_size (int): Batch size.
@@ -209,3 +212,280 @@ class SparseEncoder(nn.Module):
stage_layers = SparseSequential(*blocks_list)
self.encoder_layers.add_module(stage_name, stage_layers)
return out_channels
@MIDDLE_ENCODERS.register_module()
class SparseEncoderSASSD(SparseEncoder):
r"""Sparse encoder for `SASSD <https://github.com/skyhehe123/SA-SSD>`_
Args:
in_channels (int): The number of input channels.
sparse_shape (list[int]): The sparse shape of input tensor.
order (list[str], optional): Order of conv module.
Defaults to ('conv', 'norm', 'act').
norm_cfg (dict, optional): Config of normalization layer. Defaults to
dict(type='BN1d', eps=1e-3, momentum=0.01).
base_channels (int, optional): Out channels for conv_input layer.
Defaults to 16.
output_channels (int, optional): Out channels for conv_out layer.
Defaults to 128.
encoder_channels (tuple[tuple[int]], optional):
Convolutional channels of each encode block.
Defaults to ((16, ), (32, 32, 32), (64, 64, 64), (64, 64, 64)).
encoder_paddings (tuple[tuple[int]], optional):
Paddings of each encode block.
Defaults to ((1, ), (1, 1, 1), (1, 1, 1), ((0, 1, 1), 1, 1)).
block_type (str, optional): Type of the block to use.
Defaults to 'conv_module'.
"""
def __init__(self,
in_channels,
sparse_shape,
order=('conv', 'norm', 'act'),
norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01),
base_channels=16,
output_channels=128,
encoder_channels=((16, ), (32, 32, 32), (64, 64, 64), (64, 64,
64)),
encoder_paddings=((1, ), (1, 1, 1), (1, 1, 1), ((0, 1, 1), 1,
1)),
block_type='conv_module'):
super(SparseEncoderSASSD, self).__init__(
in_channels=in_channels,
sparse_shape=sparse_shape,
order=order,
norm_cfg=norm_cfg,
base_channels=base_channels,
output_channels=output_channels,
encoder_channels=encoder_channels,
encoder_paddings=encoder_paddings,
block_type=block_type)
self.point_fc = nn.Linear(112, 64, bias=False)
self.point_cls = nn.Linear(64, 1, bias=False)
self.point_reg = nn.Linear(64, 3, bias=False)
@auto_fp16(apply_to=('voxel_features', ))
def forward(self, voxel_features, coors, batch_size, test_mode=False):
"""Forward of SparseEncoder.
Args:
voxel_features (torch.Tensor): Voxel features in shape (N, C).
coors (torch.Tensor): Coordinates in shape (N, 4),
the columns in the order of (batch_idx, z_idx, y_idx, x_idx).
batch_size (int): Batch size.
test_mode (bool, optional): Whether in test mode.
Defaults to False.
        Returns:
            tuple: Spatial features in shape (N, C * D, H, W) and a tuple
                ``point_misc`` holding the mean coordinates of the points,
                the point-wise classification results and the point-wise
                regression offsets (``None`` in test mode).
"""
coors = coors.int()
input_sp_tensor = SparseConvTensor(voxel_features, coors,
self.sparse_shape, batch_size)
x = self.conv_input(input_sp_tensor)
encode_features = []
for encoder_layer in self.encoder_layers:
x = encoder_layer(x)
encode_features.append(x)
# for detection head
# [200, 176, 5] -> [200, 176, 2]
out = self.conv_out(encode_features[-1])
spatial_features = out.dense()
N, C, D, H, W = spatial_features.shape
spatial_features = spatial_features.view(N, C * D, H, W)
if test_mode:
return spatial_features, None
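        # keep one (batch_idx, x, y, z) row per voxel; the voxel encoder
        # (HardSimpleVFE in the config above) stores the mean xyz of each
        # voxel's points in the first three feature channels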
points_mean = torch.zeros_like(voxel_features)
points_mean[:, 0] = coors[:, 0]
points_mean[:, 1:] = voxel_features[:, :3]
# auxiliary network
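        # interpolate features from three encoder stages back onto the voxel
        # centers; each downsampling stage doubles the effective voxel size,
        # hence the growing voxel_size arguments below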
p0 = self.make_auxiliary_points(
encode_features[0],
points_mean,
offset=(0, -40., -3.),
voxel_size=(.1, .1, .2))
p1 = self.make_auxiliary_points(
encode_features[1],
points_mean,
offset=(0, -40., -3.),
voxel_size=(.2, .2, .4))
p2 = self.make_auxiliary_points(
encode_features[2],
points_mean,
offset=(0, -40., -3.),
voxel_size=(.4, .4, .8))
pointwise = torch.cat([p0, p1, p2], dim=-1)
pointwise = self.point_fc(pointwise)
point_cls = self.point_cls(pointwise)
point_reg = self.point_reg(pointwise)
point_misc = (points_mean, point_cls, point_reg)
return spatial_features, point_misc
def get_auxiliary_targets(self, nxyz, gt_boxes3d, enlarge=1.0):
"""Get auxiliary target.
Args:
            nxyz (torch.Tensor): (N, 4) coordinates of the points,
                the columns in the order of (batch_idx, x, y, z).
            gt_boxes3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
                boxes for each sample.
            enlarge (float, optional): Enlarge scale of the boxes.
                Defaults to 1.0.
Returns:
tuple[torch.Tensor]: Label of the points and
center offsets of the points.
"""
center_offsets = list()
pts_labels = list()
for i in range(len(gt_boxes3d)):
boxes3d = gt_boxes3d[i].tensor.cpu()
idx = torch.nonzero(nxyz[:, 0] == i).view(-1)
new_xyz = nxyz[idx, 1:].cpu()
boxes3d[:, 3:6] *= enlarge
pts_in_flag, center_offset = self.calculate_pts_offsets(
new_xyz, boxes3d)
pts_label = pts_in_flag.max(0)[0].byte()
pts_labels.append(pts_label)
center_offsets.append(center_offset)
center_offsets = torch.cat(center_offsets).cuda()
pts_labels = torch.cat(pts_labels).to(center_offsets.device)
return pts_labels, center_offsets
def calculate_pts_offsets(self, points, boxes):
"""Find all boxes in which each point is, as well as the offsets from
the box centers.
Args:
points (torch.Tensor): [M, 3], [x, y, z] in LiDAR/DEPTH coordinate
boxes (torch.Tensor): [T, 7],
num_valid_boxes <= T, [x, y, z, x_size, y_size, z_size, rz],
(x, y, z) is the bottom center.
Returns:
            tuple[torch.Tensor]: Point indices of the boxes with the shape
                of (T, M), where background points default to 0, and the
                offsets of the points from the centers of the boxes they
                belong to, with the shape of (M, 3), also defaulting to 0
                for background points.
"""
boxes_num = len(boxes)
pts_num = len(points)
points = points.cuda()
boxes = boxes.to(points.device)
box_idxs_of_pts = points_in_boxes_all(points[None, ...], boxes[None,
...])
pts_indices = box_idxs_of_pts.squeeze(0).transpose(0, 1)
center_offsets = torch.zeros_like(points).to(points.device)
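        # a box's (x, y, z) is its bottom center (see the docstring), so the
        # true center along z is z + z_size / 2; offsets stay 0 for points
        # outside every box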
for i in range(boxes_num):
for j in range(pts_num):
if pts_indices[i][j] == 1:
center_offsets[j][0] = points[j][0] - boxes[i][0]
center_offsets[j][1] = points[j][1] - boxes[i][1]
                    center_offsets[j][2] = (
                        points[j][2] - (boxes[i][2] + boxes[i][5] / 2.0))
return pts_indices.cpu(), center_offsets.cpu()
def aux_loss(self, points, point_cls, point_reg, gt_bboxes):
"""Calculate auxiliary loss.
Args:
points (torch.Tensor): Mean feature value of the points.
            point_cls (torch.Tensor): Classification result of the points.
point_reg (torch.Tensor): Regression offsets of the points.
gt_bboxes (list[:obj:`BaseInstance3DBoxes`]): Ground truth
boxes for each sample.
Returns:
            dict: Auxiliary classification and regression losses of the
                points.
"""
num_boxes = len(gt_bboxes)
pts_labels, center_targets = self.get_auxiliary_targets(
points, gt_bboxes)
rpn_cls_target = pts_labels.long()
pos = (pts_labels > 0).float()
neg = (pts_labels == 0).float()
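        # normalize by the number of foreground points (clamped to avoid
        # division by zero) so the regression loss becomes a mean over
        # positive points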
pos_normalizer = pos.sum().clamp(min=1.0)
cls_weights = pos + neg
reg_weights = pos
reg_weights = reg_weights / pos_normalizer
aux_loss_cls = sigmoid_focal_loss(
point_cls,
rpn_cls_target,
weight=cls_weights,
avg_factor=pos_normalizer)
aux_loss_cls /= num_boxes
weight = reg_weights[..., None]
aux_loss_reg = smooth_l1_loss(point_reg, center_targets, beta=1 / 9.)
aux_loss_reg = torch.sum(aux_loss_reg * weight)[None]
aux_loss_reg /= num_boxes
aux_loss_cls, aux_loss_reg = [aux_loss_cls], [aux_loss_reg]
return dict(aux_loss_cls=aux_loss_cls, aux_loss_reg=aux_loss_reg)
def make_auxiliary_points(self,
source_tensor,
target,
offset=(0., -40., -3.),
voxel_size=(.05, .05, .1)):
"""Make auxiliary points for loss computation.
Args:
            source_tensor (torch.Tensor): (M, C) features to be propagated.
            target (torch.Tensor): (N, 4) bxyz positions of the
                target features.
            offset (tuple[float], optional): Voxelization offset.
                Defaults to (0., -40., -3.).
            voxel_size (tuple[float], optional): Voxelization size.
                Defaults to (.05, .05, .1).
        Returns:
            torch.Tensor: (N, C) features interpolated at the target
                positions.
"""
        # Transfer sparse voxel indices to point coordinates
source = source_tensor.indices.float()
offset = torch.Tensor(offset).to(source.device)
voxel_size = torch.Tensor(voxel_size).to(source.device)
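        # sparse indices are stored as (batch_idx, z, y, x); reorder them to
        # (x, y, z) and convert voxel indices to metric coordinates at the
        # voxel centers (hence the extra half voxel size)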
source[:, 1:] = (
source[:, [3, 2, 1]] * voxel_size + offset + .5 * voxel_size)
source_feats = source_tensor.features[None, ...].transpose(1, 2)
        # Interpolate auxiliary points
dist, idx = three_nn(target[None, ...], source[None, ...])
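        # inverse-distance weights over the three nearest source points:
        # w_i = (1 / d_i) / sum_j (1 / d_j)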
dist_recip = 1.0 / (dist + 1e-8)
norm = torch.sum(dist_recip, dim=2, keepdim=True)
weight = dist_recip / norm
new_features = three_interpolate(source_feats.contiguous(), idx,
weight)
return new_features.squeeze(0).transpose(0, 1)
@@ -25,3 +25,25 @@ def test_sparse_encoder():
ret = sparse_encoder(voxel_features, coors, 4)
assert ret.shape == torch.Size([4, 256, 128, 128])
def test_sparse_encoder_for_ssd():
if not torch.cuda.is_available():
pytest.skip('test requires GPU and torch+cuda')
sparse_encoder_for_ssd_cfg = dict(
type='SparseEncoderSASSD',
in_channels=5,
sparse_shape=[40, 1024, 1024],
order=('conv', 'norm', 'act'),
encoder_channels=((16, 16, 32), (32, 32, 64), (64, 64, 128), (128,
128)),
encoder_paddings=((1, 1, 1), (1, 1, 1), (1, 1, 1), (1, 1, 1), (1, 1,
1)),
block_type='basicblock')
sparse_encoder = build_middle_encoder(sparse_encoder_for_ssd_cfg).cuda()
voxel_features = torch.rand([207842, 5]).cuda()
coors = torch.randint(0, 4, [207842, 4]).cuda()
ret, _ = sparse_encoder(voxel_features, coors, 4, True)
assert ret.shape == torch.Size([4, 256, 128, 128])
@@ -567,3 +567,42 @@ def test_smoke():
assert boxes_3d.tensor.shape[1] == 7
assert scores_3d.shape[0] >= 0
assert labels_3d.shape[0] >= 0
def test_sassd():
if not torch.cuda.is_available():
pytest.skip('test requires GPU and torch+cuda')
_setup_seed(0)
sassd_cfg = _get_detector_cfg('sassd/sassd_6x8_80e_kitti-3d-3class.py')
self = build_detector(sassd_cfg).cuda()
points_0 = torch.rand([2010, 4], device='cuda')
points_1 = torch.rand([2020, 4], device='cuda')
points = [points_0, points_1]
gt_bbox_0 = LiDARInstance3DBoxes(torch.rand([10, 7], device='cuda'))
gt_bbox_1 = LiDARInstance3DBoxes(torch.rand([10, 7], device='cuda'))
gt_bboxes = [gt_bbox_0, gt_bbox_1]
gt_labels_0 = torch.randint(0, 3, [10], device='cuda')
gt_labels_1 = torch.randint(0, 3, [10], device='cuda')
gt_labels = [gt_labels_0, gt_labels_1]
img_meta_0 = dict(box_type_3d=LiDARInstance3DBoxes)
img_meta_1 = dict(box_type_3d=LiDARInstance3DBoxes)
img_metas = [img_meta_0, img_meta_1]
# test forward_train
losses = self.forward_train(points, img_metas, gt_bboxes, gt_labels)
assert losses['loss_cls'][0] >= 0
assert losses['loss_bbox'][0] >= 0
assert losses['loss_dir'][0] >= 0
assert losses['aux_loss_cls'][0] >= 0
assert losses['aux_loss_reg'][0] >= 0
# test simple_test
with torch.no_grad():
results = self.simple_test(points, img_metas)
boxes_3d = results[0]['boxes_3d']
scores_3d = results[0]['scores_3d']
labels_3d = results[0]['labels_3d']
assert boxes_3d.tensor.shape == (50, 7)
assert scores_3d.shape == torch.Size([50])
assert labels_3d.shape == torch.Size([50])
@@ -7,7 +7,7 @@ from tools.data_converter import kitti_converter as kitti
from tools.data_converter import lyft_converter as lyft_converter
from tools.data_converter import nuscenes_converter as nuscenes_converter
from tools.data_converter.create_gt_database import (
GTDatabaseCreater, create_groundtruth_database)
def kitti_data_prep(root_path,