Unverified Commit cccbaf7a authored by yinchimaoliang's avatar yinchimaoliang Committed by GitHub
Browse files

[enhance]: Use points loader to load point cloud data in db_infos (#87)

* Change gt_database and dbsampler.

* Add int condition for use_dim.

* Change default sweep_num.

* Change docstring.

* Change to pipeline in db_sampler.

* remove comments.
parent 136dd481
import copy import copy
import mmcv
import numpy as np import numpy as np
import os import os
import pickle
from mmdet3d.core.bbox import box_np_ops from mmdet3d.core.bbox import box_np_ops
from mmdet3d.datasets.pipelines import data_augment_utils from mmdet3d.datasets.pipelines import data_augment_utils
from mmdet.datasets import PIPELINES
from ..registry import OBJECTSAMPLERS from ..registry import OBJECTSAMPLERS
...@@ -86,6 +87,8 @@ class DataBaseSampler(object): ...@@ -86,6 +87,8 @@ class DataBaseSampler(object):
prepare (dict): Name of preparation functions and the input value. prepare (dict): Name of preparation functions and the input value.
sample_groups (dict): Sampled classes and numbers. sample_groups (dict): Sampled classes and numbers.
classes (list[str]): List of classes. Default: None. classes (list[str]): List of classes. Default: None.
points_loader(dict): Config of points loader. Default: dict(
type='LoadPointsFromFile', load_dim=4, use_dim=[0,1,2,3])
""" """
def __init__(self, def __init__(self,
...@@ -94,7 +97,11 @@ class DataBaseSampler(object): ...@@ -94,7 +97,11 @@ class DataBaseSampler(object):
rate, rate,
prepare, prepare,
sample_groups, sample_groups,
classes=None): classes=None,
points_loader=dict(
type='LoadPointsFromFile',
load_dim=4,
use_dim=[0, 1, 2, 3])):
super().__init__() super().__init__()
self.data_root = data_root self.data_root = data_root
self.info_path = info_path self.info_path = info_path
...@@ -103,9 +110,9 @@ class DataBaseSampler(object): ...@@ -103,9 +110,9 @@ class DataBaseSampler(object):
self.classes = classes self.classes = classes
self.cat2label = {name: i for i, name in enumerate(classes)} self.cat2label = {name: i for i, name in enumerate(classes)}
self.label2cat = {i: name for i, name in enumerate(classes)} self.label2cat = {i: name for i, name in enumerate(classes)}
self.points_loader = mmcv.build_from_cfg(points_loader, PIPELINES)
with open(info_path, 'rb') as f: db_infos = mmcv.load(info_path)
db_infos = pickle.load(f)
# filter database infos # filter database infos
from mmdet3d.utils import get_root_logger from mmdet3d.utils import get_root_logger
...@@ -244,8 +251,9 @@ class DataBaseSampler(object): ...@@ -244,8 +251,9 @@ class DataBaseSampler(object):
file_path = os.path.join( file_path = os.path.join(
self.data_root, self.data_root,
info['path']) if self.data_root else info['path'] info['path']) if self.data_root else info['path']
s_points = np.fromfile(
file_path, dtype=np.float32).reshape([-1, 4]) results = dict(pts_filename=file_path)
s_points = self.points_loader(results)['points']
s_points[:, :3] += info['box3d_lidar'][:3] s_points[:, :3] += info['box3d_lidar'][:3]
count += 1 count += 1
......
...@@ -145,7 +145,7 @@ def create_groundtruth_database(dataset_class_name, ...@@ -145,7 +145,7 @@ def create_groundtruth_database(dataset_class_name,
type=dataset_class_name, type=dataset_class_name,
data_root=data_path, data_root=data_path,
ann_file=info_path, ann_file=info_path,
) use_valid_flag=True)
if dataset_class_name == 'KittiDataset': if dataset_class_name == 'KittiDataset':
file_client_args = dict(backend='disk') file_client_args = dict(backend='disk')
dataset_cfg.update( dataset_cfg.update(
...@@ -176,7 +176,9 @@ def create_groundtruth_database(dataset_class_name, ...@@ -176,7 +176,9 @@ def create_groundtruth_database(dataset_class_name,
dict( dict(
type='LoadPointsFromMultiSweeps', type='LoadPointsFromMultiSweeps',
sweeps_num=10, sweeps_num=10,
), use_dim=[0, 1, 2, 3, 4],
pad_empty_sweeps=True,
remove_close=True),
dict( dict(
type='LoadAnnotations3D', type='LoadAnnotations3D',
with_bbox_3d=True, with_bbox_3d=True,
...@@ -191,7 +193,6 @@ def create_groundtruth_database(dataset_class_name, ...@@ -191,7 +193,6 @@ def create_groundtruth_database(dataset_class_name,
f'{info_prefix}_dbinfos_train.pkl') f'{info_prefix}_dbinfos_train.pkl')
mmcv.mkdir_or_exist(database_save_path) mmcv.mkdir_or_exist(database_save_path)
all_db_infos = dict() all_db_infos = dict()
if with_mask: if with_mask:
coco = COCO(osp.join(data_path, mask_anno_path)) coco = COCO(osp.join(data_path, mask_anno_path))
imgIds = coco.getImgIds() imgIds = coco.getImgIds()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment