Commit eca5a9f2 authored by VVsssssk's avatar VVsssssk Committed by ChaimZhu
Browse files

[Fix]Fix det3ddataset for new input version

parent 928aa884
...@@ -253,7 +253,8 @@ class Det3DDataset(BaseDataset): ...@@ -253,7 +253,8 @@ class Det3DDataset(BaseDataset):
if not self.test_mode and self.filter_empty_gt: if not self.test_mode and self.filter_empty_gt:
# after pipeline drop the example with empty annotations # after pipeline drop the example with empty annotations
# return None to random another in `__getitem__` # return None to random another in `__getitem__`
if example is None or len(example['gt_labels_3d']) == 0: if example is None or len(
example['data_sample'].gt_instances_3d.labels_3d) == 0:
return None return None
return example return example
......
...@@ -27,7 +27,6 @@ class KittiDataset(Det3DDataset): ...@@ -27,7 +27,6 @@ class KittiDataset(Det3DDataset):
Args: Args:
data_root (str): Path of dataset root. data_root (str): Path of dataset root.
ann_file (str): Path of annotation file. ann_file (str): Path of annotation file.
split (str): Split of input data.
pipeline (list[dict], optional): Pipeline used for data processing. pipeline (list[dict], optional): Pipeline used for data processing.
Defaults to None. Defaults to None.
modality (dict, optional): Modality to specify the sensor data used modality (dict, optional): Modality to specify the sensor data used
......
...@@ -3,9 +3,11 @@ ...@@ -3,9 +3,11 @@
import numpy as np import numpy as np
import torch import torch
from mmcv.transforms.base import BaseTransform from mmcv.transforms.base import BaseTransform
from mmengine.data import InstanceData
from mmengine.registry import TRANSFORMS from mmengine.registry import TRANSFORMS
from mmdet3d.core import LiDARInstance3DBoxes from mmdet3d.core import LiDARInstance3DBoxes
from mmdet3d.core.data_structures import Det3DDataSample
from mmdet3d.datasets import KittiDataset from mmdet3d.datasets import KittiDataset
...@@ -23,6 +25,11 @@ def _generate_kitti_dataset_config(): ...@@ -23,6 +25,11 @@ def _generate_kitti_dataset_config():
def transform(self, info): def transform(self, info):
if 'ann_info' in info: if 'ann_info' in info:
info['gt_labels_3d'] = info['ann_info']['gt_labels_3d'] info['gt_labels_3d'] = info['ann_info']['gt_labels_3d']
data_sample = Det3DDataSample()
gt_instances_3d = InstanceData()
gt_instances_3d.labels_3d = info['gt_labels_3d']
data_sample.gt_instances_3d = gt_instances_3d
info['data_sample'] = data_sample
return info return info
pipeline = [ pipeline = [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment