Commit 2155ff69 authored by VVsssssk's avatar VVsssssk Committed by ChaimZhu
Browse files

[Fix]fix dbsampler and object sampler

parent 14a18646
...@@ -114,7 +114,7 @@ class DataBaseSampler(object): ...@@ -114,7 +114,7 @@ class DataBaseSampler(object):
self.classes = classes self.classes = classes
self.cat2label = {name: i for i, name in enumerate(classes)} self.cat2label = {name: i for i, name in enumerate(classes)}
self.label2cat = {i: name for i, name in enumerate(classes)} self.label2cat = {i: name for i, name in enumerate(classes)}
self.points_loader = mmcv.build_from_cfg(points_loader, PIPELINES) self.points_loader = TRANSFORMS.build(points_loader)
self.file_client = mmcv.FileClient(**file_client_args) self.file_client = mmcv.FileClient(**file_client_args)
# load data base infos # load data base infos
...@@ -267,7 +267,7 @@ class DataBaseSampler(object): ...@@ -267,7 +267,7 @@ class DataBaseSampler(object):
file_path = os.path.join( file_path = os.path.join(
self.data_root, self.data_root,
info['path']) if self.data_root else info['path'] info['path']) if self.data_root else info['path']
results = dict(pts_filename=file_path) results = dict(lidar_points=dict(lidar_path=file_path))
s_points = self.points_loader(results)['points'] s_points = self.points_loader(results)['points']
s_points.translate(info['box3d_lidar'][:3]) s_points.translate(info['box3d_lidar'][:3])
......
...@@ -7,7 +7,6 @@ import cv2 ...@@ -7,7 +7,6 @@ import cv2
import numpy as np import numpy as np
from mmcv import is_tuple_of from mmcv import is_tuple_of
from mmcv.transforms import BaseTransform from mmcv.transforms import BaseTransform
from mmengine.registry import build_from_cfg
from mmdet3d.core import VoxelGenerator from mmdet3d.core import VoxelGenerator
from mmdet3d.core.bbox import (CameraInstance3DBoxes, DepthInstance3DBoxes, from mmdet3d.core.bbox import (CameraInstance3DBoxes, DepthInstance3DBoxes,
...@@ -334,7 +333,7 @@ class ObjectSample(BaseTransform): ...@@ -334,7 +333,7 @@ class ObjectSample(BaseTransform):
self.sample_2d = sample_2d self.sample_2d = sample_2d
if 'type' not in db_sampler.keys(): if 'type' not in db_sampler.keys():
db_sampler['type'] = 'DataBaseSampler' db_sampler['type'] = 'DataBaseSampler'
self.db_sampler = build_from_cfg(db_sampler, TRANSFORMS) self.db_sampler = TRANSFORMS.build(db_sampler)
self.use_ground_plane = use_ground_plane self.use_ground_plane = use_ground_plane
@staticmethod @staticmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment