from collections.abc import Sequence  # `collections.Sequence` was removed in Python 3.10
import copy

import matplotlib.pyplot as plt
import mmcv
import numpy as np
import torch
from mmcv.runner import obj_from_dict

from .concat_dataset import ConcatDataset
from .. import datasets


def to_tensor(data):
    """Convert objects of various python types to :obj:`torch.Tensor`.

    Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
    :class:`Sequence`, :class:`int` and :class:`float`.
    """
    if isinstance(data, torch.Tensor):
        return data
    elif isinstance(data, np.ndarray):
        return torch.from_numpy(data)
    elif isinstance(data, Sequence) and not mmcv.is_str(data):
        return torch.tensor(data)
    elif isinstance(data, int):
        return torch.LongTensor([data])
    elif isinstance(data, float):
        return torch.FloatTensor([data])
    else:
        raise TypeError('type {} cannot be converted to tensor.'.format(
            type(data)))


def random_scale(img_scales, mode='range'):
    """Randomly select a scale from a list of scales or scale ranges.

    Args:
        img_scales (list[tuple]): Image scale or scale range.
        mode (str): "range" or "value".

    Returns:
        tuple: Sampled image scale.
    """
    num_scales = len(img_scales)
    if num_scales == 1:
        # a single fixed scale is specified
        img_scale = img_scales[0]
    elif num_scales == 2:
        # randomly sample a scale between the two given scales
        if mode == 'range':
            # sample the long and short edges independently within the
            # ranges spanned by the two scale tuples
            img_scale_long = [max(s) for s in img_scales]
            img_scale_short = [min(s) for s in img_scales]
            long_edge = np.random.randint(
                min(img_scale_long),
                max(img_scale_long) + 1)
            short_edge = np.random.randint(
                min(img_scale_short),
                max(img_scale_short) + 1)
            img_scale = (long_edge, short_edge)
        elif mode == 'value':
            img_scale = img_scales[np.random.randint(num_scales)]
    else:
        # more than 2 scales: only "value" mode (pick one scale verbatim)
        if mode != 'value':
            raise ValueError(
                'Only "value" mode supports more than 2 image scales')
        img_scale = img_scales[np.random.randint(num_scales)]
    return img_scale


def show_ann(coco, img, ann_info):
    """Visualize a BGR image together with its COCO annotations."""
    plt.imshow(mmcv.bgr2rgb(img))
    plt.axis('off')
    coco.showAnns(ann_info)
    plt.show()


def get_dataset(data_cfg):
    """Build a dataset from a config dict.

    If ``ann_file`` is a list/tuple, one dataset is built per annotation
    file and the results are concatenated.
    """
    if isinstance(data_cfg['ann_file'], (list, tuple)):
        ann_files = data_cfg['ann_file']
        dsets = []
        for ann_file in ann_files:
            data_info = copy.deepcopy(data_cfg)
            data_info['ann_file'] = ann_file
            dset = obj_from_dict(data_info, datasets)
            dsets.append(dset)
        if len(dsets) > 1:
            dset = ConcatDataset(dsets)
        else:
            dset = dsets[0]
    else:
        dset = obj_from_dict(data_cfg, datasets)
    return dset
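

# --- Usage sketch (not part of the original module) --------------------------
# A minimal, hedged illustration of how `to_tensor` and `random_scale` behave;
# the scale values below are arbitrary examples, not defaults prescribed by
# this module. Run via ``python -m <package>.utils`` so the relative imports
# above resolve. `get_dataset` additionally needs a config dict whose ``type``
# names a class registered in the ``datasets`` package, e.g. (hypothetical):
#   data_cfg = dict(type='CocoDataset', ann_file='train.json', img_prefix='.')
#   dset = get_dataset(data_cfg)
if __name__ == '__main__':
    # to_tensor accepts ndarrays, sequences and python scalars.
    print(to_tensor(np.ones((2, 3))).shape)   # torch.Size([2, 3])
    print(to_tensor([1, 2, 3]))               # tensor([1, 2, 3])
    print(to_tensor(0.5))                     # tensor([0.5000])

    # "range" mode samples long/short edges independently from the two scale
    # tuples; "value" mode picks one of the given scales verbatim.
    print(random_scale([(1333, 800), (1333, 640)], mode='range'))
    print(random_scale([(1333, 800), (666, 400)], mode='value'))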