"tests/vscode:/vscode.git/clone" did not exist on "2bfa5a61fb21e03cb3e70b0cdace7bd8466a2817"
Unverified Commit d5bddd25 authored by Ziyi Wu's avatar Ziyi Wu Committed by GitHub
Browse files

[Enhance] Reuse some functions in `Datasets` loading data (#583)

* move _get_data to utils

* add comment
parent 25a736f7
...@@ -8,7 +8,7 @@ from torch.utils.data import Dataset ...@@ -8,7 +8,7 @@ from torch.utils.data import Dataset
from mmdet.datasets import DATASETS from mmdet.datasets import DATASETS
from ..core.bbox import get_box_type from ..core.bbox import get_box_type
from .pipelines import Compose from .pipelines import Compose
from .utils import get_loading_pipeline from .utils import extract_result_dict, get_loading_pipeline
@DATASETS.register_module() @DATASETS.register_module()
...@@ -293,28 +293,6 @@ class Custom3DDataset(Dataset): ...@@ -293,28 +293,6 @@ class Custom3DDataset(Dataset):
return Compose(loading_pipeline) return Compose(loading_pipeline)
return Compose(pipeline) return Compose(pipeline)
@staticmethod
def _get_data(results, key):
"""Extract and return the data corresponding to key in result dict.
Args:
results (dict): Data loaded using pipeline.
key (str): Key of the desired data.
Returns:
np.ndarray | torch.Tensor | None: Data term.
"""
if key not in results.keys():
return None
# results[key] may be data or list[data]
# data may be wrapped inside DataContainer
data = results[key]
if isinstance(data, list) or isinstance(data, tuple):
data = data[0]
if isinstance(data, mmcv.parallel.DataContainer):
data = data._data
return data
def _extract_data(self, index, pipeline, key, load_annos=False): def _extract_data(self, index, pipeline, key, load_annos=False):
"""Load data using input pipeline and extract data according to key. """Load data using input pipeline and extract data according to key.
...@@ -341,9 +319,9 @@ class Custom3DDataset(Dataset): ...@@ -341,9 +319,9 @@ class Custom3DDataset(Dataset):
# extract data items according to keys # extract data items according to keys
if isinstance(key, str): if isinstance(key, str):
data = self._get_data(example, key) data = extract_result_dict(example, key)
else: else:
data = [self._get_data(example, k) for k in key] data = [extract_result_dict(example, k) for k in key]
if load_annos: if load_annos:
self.test_mode = original_test_mode self.test_mode = original_test_mode
......
...@@ -8,7 +8,7 @@ from torch.utils.data import Dataset ...@@ -8,7 +8,7 @@ from torch.utils.data import Dataset
from mmdet.datasets import DATASETS from mmdet.datasets import DATASETS
from mmseg.datasets import DATASETS as SEG_DATASETS from mmseg.datasets import DATASETS as SEG_DATASETS
from .pipelines import Compose from .pipelines import Compose
from .utils import get_loading_pipeline from .utils import extract_result_dict, get_loading_pipeline
@DATASETS.register_module() @DATASETS.register_module()
...@@ -399,28 +399,6 @@ class Custom3DSegDataset(Dataset): ...@@ -399,28 +399,6 @@ class Custom3DSegDataset(Dataset):
return Compose(loading_pipeline) return Compose(loading_pipeline)
return Compose(pipeline) return Compose(pipeline)
@staticmethod
def _get_data(results, key):
"""Extract and return the data corresponding to key in result dict.
Args:
results (dict): Data loaded using pipeline.
key (str): Key of the desired data.
Returns:
np.ndarray | torch.Tensor | None: Data term.
"""
if key not in results.keys():
return None
# results[key] may be data or list[data]
# data may be wrapped inside DataContainer
data = results[key]
if isinstance(data, list) or isinstance(data, tuple):
data = data[0]
if isinstance(data, mmcv.parallel.DataContainer):
data = data._data
return data
def _extract_data(self, index, pipeline, key, load_annos=False): def _extract_data(self, index, pipeline, key, load_annos=False):
"""Load data using input pipeline and extract data according to key. """Load data using input pipeline and extract data according to key.
...@@ -447,9 +425,9 @@ class Custom3DSegDataset(Dataset): ...@@ -447,9 +425,9 @@ class Custom3DSegDataset(Dataset):
# extract data items according to keys # extract data items according to keys
if isinstance(key, str): if isinstance(key, str):
data = self._get_data(example, key) data = extract_result_dict(example, key)
else: else:
data = [self._get_data(example, k) for k in key] data = [extract_result_dict(example, k) for k in key]
if load_annos: if load_annos:
self.test_mode = original_test_mode self.test_mode = original_test_mode
......
...@@ -13,7 +13,7 @@ from mmdet.datasets import DATASETS, CocoDataset ...@@ -13,7 +13,7 @@ from mmdet.datasets import DATASETS, CocoDataset
from ..core import show_multi_modality_result from ..core import show_multi_modality_result
from ..core.bbox import CameraInstance3DBoxes, get_box_type, mono_cam_box2vis from ..core.bbox import CameraInstance3DBoxes, get_box_type, mono_cam_box2vis
from .pipelines import Compose from .pipelines import Compose
from .utils import get_loading_pipeline from .utils import extract_result_dict, get_loading_pipeline
@DATASETS.register_module() @DATASETS.register_module()
...@@ -541,28 +541,6 @@ class NuScenesMonoDataset(CocoDataset): ...@@ -541,28 +541,6 @@ class NuScenesMonoDataset(CocoDataset):
self.show(results, out_dir, pipeline=pipeline) self.show(results, out_dir, pipeline=pipeline)
return results_dict return results_dict
@staticmethod
def _get_data(results, key):
"""Extract and return the data corresponding to key in result dict.
Args:
results (dict): Data loaded using pipeline.
key (str): Key of the desired data.
Returns:
np.ndarray | torch.Tensor | None: Data term.
"""
if key not in results.keys():
return None
# results[key] may be data or list[data]
# data may be wrapped inside DataContainer
data = results[key]
if isinstance(data, list) or isinstance(data, tuple):
data = data[0]
if isinstance(data, mmcv.parallel.DataContainer):
data = data._data
return data
def _extract_data(self, index, pipeline, key, load_annos=False): def _extract_data(self, index, pipeline, key, load_annos=False):
"""Load data using input pipeline and extract data according to key. """Load data using input pipeline and extract data according to key.
...@@ -590,9 +568,9 @@ class NuScenesMonoDataset(CocoDataset): ...@@ -590,9 +568,9 @@ class NuScenesMonoDataset(CocoDataset):
# extract data items according to keys # extract data items according to keys
if isinstance(key, str): if isinstance(key, str):
data = self._get_data(example, key) data = extract_result_dict(example, key)
else: else:
data = [self._get_data(example, k) for k in key] data = [extract_result_dict(example, k) for k in key]
return data return data
......
import mmcv
# yapf: disable # yapf: disable
from mmdet3d.datasets.pipelines import (Collect3D, DefaultFormatBundle3D, from mmdet3d.datasets.pipelines import (Collect3D, DefaultFormatBundle3D,
LoadAnnotations3D, LoadAnnotations3D,
...@@ -108,3 +110,30 @@ def get_loading_pipeline(pipeline): ...@@ -108,3 +110,30 @@ def get_loading_pipeline(pipeline):
'The data pipeline in your config file must include ' \ 'The data pipeline in your config file must include ' \
'loading step.' 'loading step.'
return loading_pipeline return loading_pipeline
def extract_result_dict(results, key):
"""Extract and return the data corresponding to key in result dict.
``results`` is a dict output from `pipeline(input_dict)`, which is the
loaded data from ``Dataset`` class.
The data terms inside may be wrapped in list, tuple and DataContainer, so
this function essentially extracts data from these wrappers.
Args:
results (dict): Data loaded using pipeline.
key (str): Key of the desired data.
Returns:
np.ndarray | torch.Tensor | None: Data term.
"""
if key not in results.keys():
return None
# results[key] may be data or list[data] or tuple[data]
# data may be wrapped inside DataContainer
data = results[key]
if isinstance(data, (list, tuple)):
data = data[0]
if isinstance(data, mmcv.parallel.DataContainer):
data = data._data
return data
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment