from collections.abc import Sequence

import mmcv
import numpy as np
import torch
from mmcv.parallel import DataContainer as DC

from ..registry import PIPELINES


def to_tensor(data):
    """Convert objects of various python types to :obj:`torch.Tensor`.

    Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
    :class:`Sequence`, :class:`int` and :class:`float`.

    Args:
        data: The object to convert.

    Returns:
        torch.Tensor: ``data`` itself when it already is a tensor, otherwise
        a tensor built from it. A single ``int`` becomes a one-element
        ``LongTensor`` and a single ``float`` a one-element ``FloatTensor``.

    Raises:
        TypeError: If ``data`` is of none of the supported types (strings
            are rejected even though they are sequences).
    """
    if isinstance(data, torch.Tensor):
        return data
    elif isinstance(data, np.ndarray):
        return torch.from_numpy(data)
    elif isinstance(data, Sequence) and not mmcv.is_str(data):
        return torch.tensor(data)
    elif isinstance(data, int):
        return torch.LongTensor([data])
    elif isinstance(data, float):
        return torch.FloatTensor([data])
    else:
        raise TypeError('type {} cannot be converted to tensor.'.format(
            type(data)))


def _image_to_chw(img):
    """Return ``img`` as a contiguous channel-first array.

    Args:
        img (numpy.ndarray): Image with the channel axis last. A 2-D array
            is treated as a single-channel image and gets a trailing
            channel axis of length 1 first.

    Returns:
        numpy.ndarray: C-contiguous array with axes reordered by
        ``transpose(2, 0, 1)``.
    """
    if img.ndim == 2:
        img = np.expand_dims(img, -1)
    # The transposed result is only a strided view; copying it into
    # contiguous memory also removes negative strides (e.g. from a flipped
    # view), which ``torch.from_numpy`` does not accept.
    return np.ascontiguousarray(img.transpose(2, 0, 1))


@PIPELINES.register_module
class ToTensor(object):
    """Convert the values stored under the given keys to tensors.

    Args:
        keys (Sequence[str]): Keys of ``results`` to convert with
            :func:`to_tensor`.
    """

    def __init__(self, keys):
        self.keys = keys

    def __call__(self, results):
        """Replace ``results[key]`` by its tensor form for every key.

        Args:
            results (dict): Pipeline result dict; modified in place.

        Returns:
            dict: The same ``results`` object.
        """
        for key in self.keys:
            results[key] = to_tensor(results[key])
        return results

    def __repr__(self):
        return self.__class__.__name__ + '(keys={})'.format(self.keys)


@PIPELINES.register_module
class ImageToTensor(object):
    """Convert channel-last images to channel-first tensors.

    Each selected value is reordered with ``transpose(2, 0, 1)``, made
    contiguous and converted with :func:`to_tensor`. 2-D images are treated
    as single-channel images.

    Args:
        keys (Sequence[str]): Keys of ``results`` holding the images.
    """

    def __init__(self, keys):
        self.keys = keys

    def __call__(self, results):
        """Convert every selected image in ``results``.

        Args:
            results (dict): Pipeline result dict; modified in place.

        Returns:
            dict: The same ``results`` object.
        """
        for key in self.keys:
            results[key] = to_tensor(_image_to_chw(results[key]))
        return results

    def __repr__(self):
        return self.__class__.__name__ + '(keys={})'.format(self.keys)


@PIPELINES.register_module
class Transpose(object):
    """Transpose the values stored under the given keys.

    Args:
        keys (Sequence[str]): Keys of ``results`` to transpose.
        order: Axis order passed unchanged to each value's ``transpose``
            method.
    """

    def __init__(self, keys, order):
        self.keys = keys
        self.order = order

    def __call__(self, results):
        """Replace ``results[key]`` by ``results[key].transpose(order)``.

        Args:
            results (dict): Pipeline result dict; modified in place.

        Returns:
            dict: The same ``results`` object.
        """
        for key in self.keys:
            results[key] = results[key].transpose(self.order)
        return results

    def __repr__(self):
        return self.__class__.__name__ + '(keys={}, order={})'.format(
            self.keys, self.order)


@PIPELINES.register_module
class ToDataContainer(object):
    """Wrap selected values of ``results`` in :class:`DataContainer`.

    Args:
        fields (Sequence[dict]): One dict per value to wrap. Its ``key``
            entry names the item of ``results``; all remaining entries are
            forwarded to the ``DataContainer`` constructor as keyword
            arguments.
    """

    def __init__(self,
                 fields=(dict(key='img', stack=True), dict(key='gt_bboxes'),
                         dict(key='gt_labels'))):
        self.fields = fields

    def __call__(self, results):
        """Wrap every configured field of ``results``.

        Args:
            results (dict): Pipeline result dict; modified in place.

        Returns:
            dict: The same ``results`` object.
        """
        for field in self.fields:
            # Work on a copy so that popping ``key`` does not alter the
            # configuration stored on the instance (or the shared default).
            field = field.copy()
            key = field.pop('key')
            results[key] = DC(results[key], **field)
        return results

    def __repr__(self):
        return self.__class__.__name__ + '(fields={})'.format(self.fields)


@PIPELINES.register_module
class DefaultFormatBundle(object):
    """Default formatting bundle.

    It simplifies the pipeline of formatting common fields, including "img",
    "proposals", "gt_bboxes", "gt_labels", "gt_masks" and "gt_semantic_seg".
    These fields are formatted as follows.

    - img: (1)transpose, (2)to tensor, (3)to DataContainer (stack=True)
    - proposals: (1)to tensor, (2)to DataContainer
    - gt_bboxes: (1)to tensor, (2)to DataContainer
    - gt_bboxes_ignore: (1)to tensor, (2)to DataContainer
    - gt_labels: (1)to tensor, (2)to DataContainer
    - gt_masks: (1)to tensor, (2)to DataContainer (cpu_only=True)
    - gt_semantic_seg: (1)unsqueeze dim-0 (2)to tensor,
                       (3)to DataContainer (stack=True)

    Fields that are absent from ``results`` are skipped. A 2-D "img" is
    treated as a single-channel image.
    """

    def __call__(self, results):
        """Format the common fields that are present in ``results``.

        Args:
            results (dict): Pipeline result dict; modified in place.

        Returns:
            dict: The same ``results`` object.
        """
        if 'img' in results:
            img = _image_to_chw(results['img'])
            results['img'] = DC(to_tensor(img), stack=True)
        for key in ['proposals', 'gt_bboxes', 'gt_bboxes_ignore', 'gt_labels']:
            if key not in results:
                continue
            results[key] = DC(to_tensor(results[key]))
        if 'gt_masks' in results:
            # Stored as-is (no tensor conversion here), flagged cpu_only.
            results['gt_masks'] = DC(results['gt_masks'], cpu_only=True)
        if 'gt_semantic_seg' in results:
            # ``[None, ...]`` prepends a leading axis of length 1.
            results['gt_semantic_seg'] = DC(
                to_tensor(results['gt_semantic_seg'][None, ...]), stack=True)
        return results

    def __repr__(self):
        return self.__class__.__name__


@PIPELINES.register_module
class Collect(object):
    """Build the final data dict from selected entries of ``results``.

    Args:
        keys (Sequence[str]): Keys of ``results`` copied into the output.
        meta_keys (Sequence[str]): Keys of ``results`` gathered into the
            ``img_meta`` entry of the output.
    """

    def __init__(self,
                 keys,
                 meta_keys=('filename', 'ori_shape', 'img_shape', 'pad_shape',
                            'scale_factor', 'flip', 'img_norm_cfg')):
        self.keys = keys
        self.meta_keys = meta_keys

    def __call__(self, results):
        """Collect the configured keys into a new dict.

        Args:
            results (dict): Pipeline result dict; left unmodified.

        Returns:
            dict: New dict holding ``img_meta`` (a ``DataContainer`` with
            ``cpu_only=True`` wrapping the meta dict) plus one entry per
            key in ``self.keys``.

        Raises:
            KeyError: If a requested key or meta key is missing from
                ``results``.
        """
        data = {}
        img_meta = {}
        for key in self.meta_keys:
            img_meta[key] = results[key]
        data['img_meta'] = DC(img_meta, cpu_only=True)
        for key in self.keys:
            data[key] = results[key]
        return data

    def __repr__(self):
        return self.__class__.__name__ + '(keys={}, meta_keys={})'.format(
            self.keys, self.meta_keys)