Unverified Commit d5bddd25 authored by Ziyi Wu's avatar Ziyi Wu Committed by GitHub
Browse files

[Enhance] Reuse some functions in `Datasets` loading data (#583)

* move _get_data to utils

* add comment
parent 25a736f7
......@@ -8,7 +8,7 @@ from torch.utils.data import Dataset
from mmdet.datasets import DATASETS
from ..core.bbox import get_box_type
from .pipelines import Compose
from .utils import get_loading_pipeline
from .utils import extract_result_dict, get_loading_pipeline
@DATASETS.register_module()
......@@ -293,28 +293,6 @@ class Custom3DDataset(Dataset):
return Compose(loading_pipeline)
return Compose(pipeline)
@staticmethod
def _get_data(results, key):
"""Extract and return the data corresponding to key in result dict.
Args:
results (dict): Data loaded using pipeline.
key (str): Key of the desired data.
Returns:
np.ndarray | torch.Tensor | None: Data term.
"""
if key not in results.keys():
return None
# results[key] may be data or list[data]
# data may be wrapped inside DataContainer
data = results[key]
if isinstance(data, list) or isinstance(data, tuple):
data = data[0]
if isinstance(data, mmcv.parallel.DataContainer):
data = data._data
return data
def _extract_data(self, index, pipeline, key, load_annos=False):
"""Load data using input pipeline and extract data according to key.
......@@ -341,9 +319,9 @@ class Custom3DDataset(Dataset):
# extract data items according to keys
if isinstance(key, str):
data = self._get_data(example, key)
data = extract_result_dict(example, key)
else:
data = [self._get_data(example, k) for k in key]
data = [extract_result_dict(example, k) for k in key]
if load_annos:
self.test_mode = original_test_mode
......
......@@ -8,7 +8,7 @@ from torch.utils.data import Dataset
from mmdet.datasets import DATASETS
from mmseg.datasets import DATASETS as SEG_DATASETS
from .pipelines import Compose
from .utils import get_loading_pipeline
from .utils import extract_result_dict, get_loading_pipeline
@DATASETS.register_module()
......@@ -399,28 +399,6 @@ class Custom3DSegDataset(Dataset):
return Compose(loading_pipeline)
return Compose(pipeline)
@staticmethod
def _get_data(results, key):
"""Extract and return the data corresponding to key in result dict.
Args:
results (dict): Data loaded using pipeline.
key (str): Key of the desired data.
Returns:
np.ndarray | torch.Tensor | None: Data term.
"""
if key not in results.keys():
return None
# results[key] may be data or list[data]
# data may be wrapped inside DataContainer
data = results[key]
if isinstance(data, list) or isinstance(data, tuple):
data = data[0]
if isinstance(data, mmcv.parallel.DataContainer):
data = data._data
return data
def _extract_data(self, index, pipeline, key, load_annos=False):
"""Load data using input pipeline and extract data according to key.
......@@ -447,9 +425,9 @@ class Custom3DSegDataset(Dataset):
# extract data items according to keys
if isinstance(key, str):
data = self._get_data(example, key)
data = extract_result_dict(example, key)
else:
data = [self._get_data(example, k) for k in key]
data = [extract_result_dict(example, k) for k in key]
if load_annos:
self.test_mode = original_test_mode
......
......@@ -13,7 +13,7 @@ from mmdet.datasets import DATASETS, CocoDataset
from ..core import show_multi_modality_result
from ..core.bbox import CameraInstance3DBoxes, get_box_type, mono_cam_box2vis
from .pipelines import Compose
from .utils import get_loading_pipeline
from .utils import extract_result_dict, get_loading_pipeline
@DATASETS.register_module()
......@@ -541,28 +541,6 @@ class NuScenesMonoDataset(CocoDataset):
self.show(results, out_dir, pipeline=pipeline)
return results_dict
@staticmethod
def _get_data(results, key):
"""Extract and return the data corresponding to key in result dict.
Args:
results (dict): Data loaded using pipeline.
key (str): Key of the desired data.
Returns:
np.ndarray | torch.Tensor | None: Data term.
"""
if key not in results.keys():
return None
# results[key] may be data or list[data]
# data may be wrapped inside DataContainer
data = results[key]
if isinstance(data, list) or isinstance(data, tuple):
data = data[0]
if isinstance(data, mmcv.parallel.DataContainer):
data = data._data
return data
def _extract_data(self, index, pipeline, key, load_annos=False):
"""Load data using input pipeline and extract data according to key.
......@@ -590,9 +568,9 @@ class NuScenesMonoDataset(CocoDataset):
# extract data items according to keys
if isinstance(key, str):
data = self._get_data(example, key)
data = extract_result_dict(example, key)
else:
data = [self._get_data(example, k) for k in key]
data = [extract_result_dict(example, k) for k in key]
return data
......
import mmcv
# yapf: disable
from mmdet3d.datasets.pipelines import (Collect3D, DefaultFormatBundle3D,
LoadAnnotations3D,
......@@ -108,3 +110,30 @@ def get_loading_pipeline(pipeline):
'The data pipeline in your config file must include ' \
'loading step.'
return loading_pipeline
def extract_result_dict(results, key):
"""Extract and return the data corresponding to key in result dict.
``results`` is a dict output from `pipeline(input_dict)`, which is the
loaded data from ``Dataset`` class.
The data terms inside may be wrapped in list, tuple and DataContainer, so
this function essentially extracts data from these wrappers.
Args:
results (dict): Data loaded using pipeline.
key (str): Key of the desired data.
Returns:
np.ndarray | torch.Tensor | None: Data term.
"""
if key not in results.keys():
return None
# results[key] may be data or list[data] or tuple[data]
# data may be wrapped inside DataContainer
data = results[key]
if isinstance(data, (list, tuple)):
data = data[0]
if isinstance(data, mmcv.parallel.DataContainer):
data = data._data
return data
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment