# Copyright (c) OpenMMLab. All rights reserved.
from collections.abc import Sequence

import mmcv
import numpy as np
import torch
from mmcv.parallel import DataContainer as DC

from ..builder import PIPELINES


def to_tensor(data):
    """Convert objects of various python types to :obj:`torch.Tensor`.

    Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
    :class:`Sequence`, :class:`int` and :class:`float`.

    Args:
        data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to
            be converted.

    Returns:
        torch.Tensor: ``data`` itself if it is already a tensor; a tensor
        built with ``torch.from_numpy`` (sharing memory with the array) if it
        is a numpy array; otherwise a newly constructed tensor. Python
        ``int``/``float`` scalars become 1-element ``LongTensor`` /
        ``FloatTensor`` respectively.

    Raises:
        TypeError: If ``data`` is of an unsupported type (strings included).
    """
    if isinstance(data, torch.Tensor):
        return data
    if isinstance(data, np.ndarray):
        return torch.from_numpy(data)
    # Strings are Sequences too; exclude them so they fall through to the
    # TypeError below instead of being converted element by element.
    if isinstance(data, Sequence) and not mmcv.is_str(data):
        return torch.tensor(data)
    if isinstance(data, int):
        return torch.LongTensor([data])
    if isinstance(data, float):
        return torch.FloatTensor([data])
    raise TypeError(f'type {type(data)} cannot be converted to tensor.')


@PIPELINES.register_module()
class ToTensor:
    """Convert some values in results dict to `torch.Tensor` type in data
    loader pipeline.

    Args:
        keys (Sequence[str]): Required keys to be converted.
    """

    def __init__(self, keys):
        self.keys = keys

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: The same ``results`` dict, with every entry listed in
            ``self.keys`` replaced by the output of :func:`to_tensor`.
        """
        for key in self.keys:
            results[key] = to_tensor(results[key])
        return results

    def __repr__(self):
        return self.__class__.__name__ + f'(keys={self.keys})'


@PIPELINES.register_module()
class ImageToTensor:
    """Convert image type to `torch.Tensor` type.

    Each selected entry is expected to be a 2-D or 3-D numpy array. A 2-D
    (gray scale) array first gets a trailing channel axis; then the last axis
    is moved to the front (``transpose(2, 0, 1)``), i.e. a channel-last image
    becomes a channel-first tensor.

    Args:
        keys (Sequence[str]): Required keys to be converted.
        to_float32 (bool): Whether convert numpy image array to np.float32
            before converted to tensor. Default: True.
    """

    def __init__(self, keys, to_float32=True):
        self.keys = keys
        self.to_float32 = to_float32

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: The same ``results`` dict, with every entry listed in
            ``self.keys`` replaced by a channel-first tensor.
        """
        for key in self.keys:
            img = results[key]
            # deal with gray scale img: expand a color channel
            if len(img.shape) == 2:
                img = img[..., None]
            # Compare the array's dtype: the previous
            # `isinstance(img, np.float32)` check was always False for an
            # ndarray, so every image was copied even when already float32.
            if self.to_float32 and img.dtype != np.float32:
                img = img.astype(np.float32)
            # torch.from_numpy (inside to_tensor) shares memory with `img`.
            results[key] = to_tensor(img.transpose(2, 0, 1))
        return results

    def __repr__(self):
        return self.__class__.__name__ + (
            f'(keys={self.keys}, to_float32={self.to_float32})')


@PIPELINES.register_module()
class Collect:
    """Collect data from the loader relevant to the specific task.

    This is usually the last stage of the data loader pipeline. Typically
    keys is set to some subset of "img", "gt_labels".

    The "meta" item is always populated: it is a ``DataContainer``
    (``cpu_only=True``) wrapping a dict whose contents are the values of
    ``results`` for the keys listed in "meta_keys".

    Args:
        keys (Sequence[str]): Required keys to be collected.
        meta_keys (Sequence[str] | None): Required keys to be collected to
            "meta". ``None`` means no meta keys, giving an empty "meta" dict.
            Default: None.
    """

    def __init__(self, keys, meta_keys=None):
        self.keys = keys
        self.meta_keys = meta_keys

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A new dict with a ``'meta'`` entry plus one entry per key
            in ``self.keys``, each taken from ``results``.

        Raises:
            KeyError: If a requested key is missing from ``results``.
        """
        # The default meta_keys=None previously crashed when iterated.
        meta_keys = self.meta_keys if self.meta_keys is not None else ()
        img_meta = {key: results[key] for key in meta_keys}
        data = {'meta': DC(img_meta, cpu_only=True)}
        for key in self.keys:
            data[key] = results[key]
        return data

    def __repr__(self):
        return self.__class__.__name__ + (
            f'(keys={self.keys}, meta_keys={self.meta_keys})')