Unverified Commit cccbaf7a authored by yinchimaoliang, committed by GitHub

[enhance]: Use points loader to load point cloud data in db_infos (#87)

* Change gt_database and dbsampler.

* Add int condition for use_dim.

* Change default sweep_num.

* Change docstring.

* Change to pipeline in db_sampler.

* Remove comments.
parent 136dd481
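With this change, `DataBaseSampler` reads per-object point clouds through a configurable pipeline transform instead of a hard-coded `np.fromfile` call (see the diff below). A minimal sketch of a KITTI-style `db_sampler` config passing the new `points_loader` argument; the data root, `prepare` filters and class names here are illustrative and not taken from this commit:

```python
# Illustrative db_sampler config; only points_loader is introduced by this
# commit, the other fields follow the usual KITTI GT-sampling setup.
db_sampler = dict(
    data_root='data/kitti/',
    info_path='data/kitti/kitti_dbinfos_train.pkl',
    rate=1.0,
    prepare=dict(
        filter_by_difficulty=[-1],
        filter_by_min_points=dict(Car=5)),
    classes=['Car'],
    sample_groups=dict(Car=15),
    # New in this commit: how the per-object .bin files in the database are read.
    points_loader=dict(
        type='LoadPointsFromFile',
        load_dim=4,
        use_dim=[0, 1, 2, 3]))
```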
import copy
import mmcv
import numpy as np
import os
import pickle
from mmdet3d.core.bbox import box_np_ops
from mmdet3d.datasets.pipelines import data_augment_utils
from mmdet.datasets import PIPELINES
from ..registry import OBJECTSAMPLERS
@@ -86,6 +87,8 @@ class DataBaseSampler(object):
prepare (dict): Name of preparation functions and the input value.
sample_groups (dict): Sampled classes and numbers.
classes (list[str]): List of classes. Default: None.
points_loader (dict): Config of points loader. Default: dict(
type='LoadPointsFromFile', load_dim=4, use_dim=[0, 1, 2, 3]).
"""
def __init__(self,
@@ -94,7 +97,11 @@ class DataBaseSampler(object):
rate,
prepare,
sample_groups,
classes=None):
classes=None,
points_loader=dict(
type='LoadPointsFromFile',
load_dim=4,
use_dim=[0, 1, 2, 3])):
super().__init__()
self.data_root = data_root
self.info_path = info_path
@@ -103,9 +110,9 @@ class DataBaseSampler(object):
self.classes = classes
self.cat2label = {name: i for i, name in enumerate(classes)}
self.label2cat = {i: name for i, name in enumerate(classes)}
self.points_loader = mmcv.build_from_cfg(points_loader, PIPELINES)
with open(info_path, 'rb') as f:
db_infos = pickle.load(f)
db_infos = mmcv.load(info_path)
# filter database infos
from mmdet3d.utils import get_root_logger
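The switch from `pickle.load` to `mmcv.load` above works because `mmcv.load` dispatches on the file suffix, so a `.pkl` info file is read with the same one-liner as a `.json` or `.yaml` file. A quick illustration (the path is hypothetical):

```python
import mmcv

# mmcv.load picks its handler from the suffix (.pkl -> pickle), so this is
# equivalent to opening the file in binary mode and calling pickle.load on it.
db_infos = mmcv.load('data/kitti/kitti_dbinfos_train.pkl')  # hypothetical path
print(sorted(db_infos.keys()))  # per-class lists of GT-database entries
```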
@@ -244,8 +251,9 @@ class DataBaseSampler(object):
file_path = os.path.join(
self.data_root,
info['path']) if self.data_root else info['path']
s_points = np.fromfile(
file_path, dtype=np.float32).reshape([-1, 4])
results = dict(pts_filename=file_path)
s_points = self.points_loader(results)['points']
s_points[:, :3] += info['box3d_lidar'][:3]
count += 1
......
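Because the loader is now built from the `PIPELINES` registry, the same transform used in training pipelines reads the per-object `.bin` files of the GT database. A rough standalone sketch of that pattern; the database file path is hypothetical, and importing `mmdet3d.datasets` is assumed to register `LoadPointsFromFile`:

```python
import mmcv
import mmdet3d.datasets  # noqa: F401, assumed to register the 3D loading transforms
from mmdet.datasets import PIPELINES

# Build the loader the same way DataBaseSampler now does in __init__ ...
points_loader = mmcv.build_from_cfg(
    dict(type='LoadPointsFromFile', load_dim=4, use_dim=[0, 1, 2, 3]),
    PIPELINES)

# ... and call it per sampled object, as in the sampling loop above.
results = dict(pts_filename='data/kitti/kitti_gt_database/0_Car_0.bin')  # hypothetical path
s_points = points_loader(results)['points']  # (N, 4): typically x, y, z, intensity for KITTI
```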
@@ -145,7 +145,7 @@ def create_groundtruth_database(dataset_class_name,
type=dataset_class_name,
data_root=data_path,
ann_file=info_path,
)
use_valid_flag=True)
if dataset_class_name == 'KittiDataset':
file_client_args = dict(backend='disk')
dataset_cfg.update(
@@ -176,7 +176,9 @@
dict(
type='LoadPointsFromMultiSweeps',
sweeps_num=10,
),
use_dim=[0, 1, 2, 3, 4],
pad_empty_sweeps=True,
remove_close=True),
dict(
type='LoadAnnotations3D',
with_bbox_3d=True,
@@ -191,7 +193,6 @@
f'{info_prefix}_dbinfos_train.pkl')
mmcv.mkdir_or_exist(database_save_path)
all_db_infos = dict()
if with_mask:
coco = COCO(osp.join(data_path, mask_anno_path))
imgIds = coco.getImgIds()
......
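With the changes above, regenerating the ground-truth database builds the dataset with `use_valid_flag=True` and loads the ten sweeps with all five point dims, padding empty sweeps and removing points close to the sensor. A rough sketch of calling the converter directly; the module path and data paths are assumptions based on the usual mmdetection3d layout, not taken from this diff:

```python
# Sketch only: the import path and data paths below are assumptions; in
# practice the converter is driven from the project's data-preparation script.
from tools.data_converter.create_gt_database import create_groundtruth_database

create_groundtruth_database(
    dataset_class_name='NuScenesDataset',
    data_path='./data/nuscenes',
    info_prefix='nuscenes',
    info_path='./data/nuscenes/nuscenes_infos_train.pkl')
```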