utils.py 2.81 KB
Newer Older
Kai Chen's avatar
Kai Chen committed
1
try:
    from collections.abc import Sequence
except ImportError:  # Python 2 fallback; the alias below was removed in 3.10
    from collections import Sequence
wangg12's avatar
wangg12 committed
2
import copy
Kai Chen's avatar
Kai Chen committed
3
import mmcv
wangg12's avatar
wangg12 committed
4
from mmcv.runner import obj_from_dict
Kai Chen's avatar
Kai Chen committed
5
import torch
Kai Chen's avatar
Kai Chen committed
6
7
8

import matplotlib.pyplot as plt
import numpy as np
wangg12's avatar
wangg12 committed
9
10
from .concat_dataset import ConcatDataset
from .. import datasets
Kai Chen's avatar
Kai Chen committed
11

Kai Chen's avatar
Kai Chen committed
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
def to_tensor(data):
    """Turn ``data`` into a :obj:`torch.Tensor`.

    Conversion rules, tried in this order:

    * :class:`torch.Tensor` -- returned as-is (same object);
    * :class:`numpy.ndarray` -- converted with :func:`torch.from_numpy`;
    * any :class:`Sequence` that is not a string -- :func:`torch.tensor`;
    * :class:`int` -- one-element :class:`torch.LongTensor`;
    * :class:`float` -- one-element :class:`torch.FloatTensor`.

    Raises:
        TypeError: If ``data`` matches none of the rules above
            (strings included).
    """
    if isinstance(data, torch.Tensor):
        return data
    if isinstance(data, np.ndarray):
        return torch.from_numpy(data)
    # Strings are Sequences too, but must not be turned into tensors.
    if isinstance(data, Sequence) and not mmcv.is_str(data):
        return torch.tensor(data)

    # Python scalars become single-element tensors; int is checked first.
    scalar_builders = (
        (int, torch.LongTensor),
        (float, torch.FloatTensor),
    )
    for scalar_type, build in scalar_builders:
        if isinstance(data, scalar_type):
            return build([data])

    raise TypeError('type {} cannot be converted to tensor.'.format(
        type(data)))


Kai Chen's avatar
Kai Chen committed
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
def random_scale(img_scales, mode='range'):
    """Randomly select a scale from a list of scales or scale ranges.

    Args:
        img_scales (list[tuple]): Image scale(s) or scale range.
            - One entry: that entry is returned unchanged (``mode`` ignored).
            - Two entries, ``mode="range"``: the entries are the bounds of a
              range. The long edge is drawn uniformly from the integers
              between the entries' long edges, and the short edge likewise
              from their short edges (bounds inclusive).
            - Two or more entries, ``mode="value"``: one entry is picked
              uniformly at random.
        mode (str): "range" or "value".

    Returns:
        tuple: Sampled image scale; ``(long_edge, short_edge)`` in "range"
        mode.

    Raises:
        ValueError: If ``img_scales`` is empty, if ``mode`` is neither
            "range" nor "value" for two scales, or if ``mode`` is not
            "value" for more than two scales.
    """
    num_scales = len(img_scales)
    if num_scales == 0:
        raise ValueError('img_scales must contain at least one scale')
    if num_scales == 1:  # fixed scale is specified
        return img_scales[0]

    if num_scales == 2:
        if mode == 'range':
            img_scale_long = [max(s) for s in img_scales]
            img_scale_short = [min(s) for s in img_scales]
            # np.random.randint's upper bound is exclusive, hence the +1.
            long_edge = np.random.randint(
                min(img_scale_long),
                max(img_scale_long) + 1)
            short_edge = np.random.randint(
                min(img_scale_short),
                max(img_scale_short) + 1)
            return (long_edge, short_edge)
        if mode != 'value':
            # Previously fell through with `img_scale` unbound and crashed
            # with UnboundLocalError.
            raise ValueError(
                'mode must be "range" or "value", got {!r}'.format(mode))
    elif mode != 'value':
        raise ValueError(
            'Only "value" mode supports more than 2 image scales')

    return img_scales[np.random.randint(num_scales)]


def show_ann(coco, img, ann_info):
    """Display ``img`` with ``ann_info`` overlaid, then show the figure.

    Args:
        coco: Object providing ``showAnns`` (presumably a pycocotools
            ``COCO`` instance -- confirm against callers).
        img: Image in BGR channel order; converted to RGB for matplotlib.
        ann_info: Annotations passed unchanged to ``coco.showAnns``.
    """
    rgb_img = mmcv.bgr2rgb(img)
    plt.imshow(rgb_img)
    plt.axis('off')  # hide ticks/frame around the image
    coco.showAnns(ann_info)
    plt.show()
wangg12's avatar
wangg12 committed
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89


def get_dataset(data_cfg):
    """Build a dataset, or a concatenation of datasets, from a config dict.

    Args:
        data_cfg (dict): Config passed to ``obj_from_dict`` with the
            ``datasets`` module as parent. ``data_cfg['ann_file']`` is
            either a single annotation file or a list/tuple of them.

    Returns:
        The object built by ``obj_from_dict``. When ``ann_file`` is a
        list/tuple, one dataset is built per entry. Each is built from a
        deep copy of ``data_cfg`` with ``ann_file`` replaced by that entry.
        If there is more than one entry, the datasets are wrapped in a
        ``ConcatDataset``.

    Raises:
        ValueError: If ``ann_file`` is an empty list/tuple.
    """
    ann_files = data_cfg['ann_file']
    if not isinstance(ann_files, (list, tuple)):
        return obj_from_dict(data_cfg, datasets)

    if not ann_files:
        # Previously this crashed with IndexError on `dsets[0]`.
        raise ValueError('data_cfg["ann_file"] must not be an empty list/tuple')

    dsets = []
    for ann_file in ann_files:
        # Copy so overwriting 'ann_file' leaves the caller's config intact;
        # a deep copy also keeps nested values independent per dataset.
        data_info = copy.deepcopy(data_cfg)
        data_info['ann_file'] = ann_file
        dsets.append(obj_from_dict(data_info, datasets))

    if len(dsets) == 1:
        return dsets[0]
    return ConcatDataset(dsets)