import functools
# ``Sequence`` lives in ``collections.abc``; the old ``collections.Sequence``
# alias was deprecated in Python 3.3 and removed in Python 3.10.
from collections.abc import Sequence

import mmcv
import numpy as np
import torch


def to_tensor(data):
    """Convert objects of various python types to :obj:`torch.Tensor`.

    Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
    :class:`Sequence`, :class:`int` and :class:`float`.

    Args:
        data: The object to convert.

    Returns:
        torch.Tensor: ``data`` itself if it is already a tensor; a tensor
        created with ``torch.from_numpy`` (which shares memory with the
        array) for a numpy array; a new tensor built by ``torch.tensor`` for
        a non-string sequence; a one-element ``LongTensor`` for an int or a
        one-element ``FloatTensor`` for a float.

    Raises:
        TypeError: If ``data`` is not one of the supported types. Objects
            for which ``mmcv.is_str`` is true are excluded from the sequence
            branch and therefore end up here.
    """
    if isinstance(data, np.ndarray):
        return torch.from_numpy(data)
    elif isinstance(data, torch.Tensor):
        return data
    elif isinstance(data, Sequence) and not mmcv.is_str(data):
        return torch.tensor(data)
    elif isinstance(data, int):
        # Note: ``bool`` is a subclass of ``int``, so booleans also land here.
        return torch.LongTensor([data])
    elif isinstance(data, float):
        return torch.FloatTensor([data])
    else:
        raise TypeError('type {} cannot be converted to tensor.'.format(
            type(data)))


def assert_tensor_type(func):
    """Restrict a ``DataContainer`` method to containers holding a tensor.

    The wrapped method raises :class:`AttributeError` when the instance's
    ``data`` is not a :class:`torch.Tensor` (for example, when it holds a
    plain list). Tensor-only methods then behave as if they were missing.

    Args:
        func: The method to wrap; its first positional argument is ``self``.

    Returns:
        The wrapping function, with ``func``'s metadata preserved.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # args[0] is the instance (``self``) the decorated method is bound to.
        if not isinstance(args[0].data, torch.Tensor):
            raise AttributeError('{} has no attribute {} for type {}'.format(
                args[0].__class__.__name__, func.__name__, args[0].datatype))
        return func(*args, **kwargs)

    return wrapper


class DataContainer(object):
    """Wrap a piece of data together with two metadata flags.

    Lists are stored unchanged; any other input is converted with
    :func:`to_tensor`. The ``stack`` and ``padding_value`` settings are only
    stored and exposed here. Their consumer (presumably a collate function)
    is not visible in this module.

    Args:
        data: A list, or anything accepted by :func:`to_tensor`.
        stack (bool): Flag exposed via :attr:`stack`. Defaults to False.
        padding_value: Value exposed via :attr:`padding_value`.
            Defaults to 0.
    """

    def __init__(self, data, stack=False, padding_value=0):
        # Lists bypass tensor conversion and are kept as Python lists.
        if isinstance(data, list):
            self._data = data
        else:
            self._data = to_tensor(data)
        self._stack = stack
        self._padding_value = padding_value

    def __repr__(self):
        return '{}({})'.format(self.__class__.__name__, repr(self.data))

    @property
    def data(self):
        """The wrapped object: a list or a :class:`torch.Tensor`."""
        return self._data

    @property
    def datatype(self):
        """The tensor type string for tensors, otherwise the Python type."""
        if isinstance(self.data, torch.Tensor):
            return self.data.type()
        else:
            return type(self.data)

    @property
    def stack(self):
        """The ``stack`` flag given at construction."""
        return self._stack

    @property
    def padding_value(self):
        """The ``padding_value`` given at construction."""
        return self._padding_value

    @assert_tensor_type
    def size(self, *args, **kwargs):
        """Forward to ``Tensor.size``; raises AttributeError for non-tensors."""
        return self.data.size(*args, **kwargs)

    @assert_tensor_type
    def dim(self):
        """Forward to ``Tensor.dim``; raises AttributeError for non-tensors."""
        return self.data.dim()