utils.py 3.6 KB
Newer Older
wangg12's avatar
wangg12 committed
1
import copy
2
3
try:
    from collections.abc import Sequence
except ImportError:  # fallback for interpreters without collections.abc
    from collections import Sequence

Kai Chen's avatar
Kai Chen committed
4
import mmcv
wangg12's avatar
wangg12 committed
5
from mmcv.runner import obj_from_dict
Kai Chen's avatar
Kai Chen committed
6
import torch
Kai Chen's avatar
Kai Chen committed
7
8
9

import matplotlib.pyplot as plt
import numpy as np
wangg12's avatar
wangg12 committed
10
from .concat_dataset import ConcatDataset
yhcao6's avatar
yhcao6 committed
11
from .repeat_dataset import RepeatDataset
wangg12's avatar
wangg12 committed
12
from .. import datasets
Kai Chen's avatar
Kai Chen committed
13

wangg12's avatar
wangg12 committed
14

Kai Chen's avatar
Kai Chen committed
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
def to_tensor(data):
    """Turn ``data`` into a :obj:`torch.Tensor`.

    Inputs are checked in this order:

    - :class:`torch.Tensor`: returned unchanged (the very same object).
    - :class:`numpy.ndarray`: wrapped via :func:`torch.from_numpy`.
    - non-string :class:`Sequence`: converted via :func:`torch.tensor`.
    - :class:`int` (``bool`` included, being an ``int`` subclass): a
      one-element ``torch.LongTensor``.
    - :class:`float`: a one-element ``torch.FloatTensor``.

    Raises:
        TypeError: If ``data`` matches none of the types above.
    """
    if isinstance(data, torch.Tensor):
        return data
    if isinstance(data, np.ndarray):
        return torch.from_numpy(data)
    if isinstance(data, Sequence) and not mmcv.is_str(data):
        return torch.tensor(data)
    # Python scalars are wrapped in a one-element tensor of matching kind.
    scalar_factories = ((int, torch.LongTensor), (float, torch.FloatTensor))
    for scalar_type, tensor_cls in scalar_factories:
        if isinstance(data, scalar_type):
            return tensor_cls([data])
    raise TypeError('type {} cannot be converted to tensor.'.format(
        type(data)))


Kai Chen's avatar
Kai Chen committed
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
def random_scale(img_scales, mode='range'):
    """Randomly select a scale from a list of scales or scale ranges.

    Args:
        img_scales (list[tuple]): Candidate image scales.

            - One entry: it is returned as-is and ``mode`` is ignored.
            - Two entries with ``mode='range'``: the entries are the two ends
              of a range. The long edge is drawn uniformly (inclusive)
              between the smallest and largest per-entry ``max`` values. The
              short edge is drawn the same way from the per-entry ``min``
              values.
            - Any number of entries with ``mode='value'``: one entry is
              picked uniformly at random. This is the only mode allowed for
              more than two entries.
        mode (str): "range" or "value".

    Returns:
        tuple: Sampled image scale. In "range" mode this is
        ``(long_edge, short_edge)``.

    Raises:
        ValueError: If ``img_scales`` is empty, if two scales are given with
            a mode other than "range"/"value", or if more than two scales
            are given with a mode other than "value".
    """
    num_scales = len(img_scales)
    if num_scales == 0:
        raise ValueError('img_scales must contain at least one scale')
    if num_scales == 1:  # fixed scale is specified
        return img_scales[0]

    if num_scales == 2 and mode == 'range':
        img_scale_long = [max(s) for s in img_scales]
        img_scale_short = [min(s) for s in img_scales]
        # np.random.randint's upper bound is exclusive, hence the +1.
        long_edge = np.random.randint(
            min(img_scale_long),
            max(img_scale_long) + 1)
        short_edge = np.random.randint(
            min(img_scale_short),
            max(img_scale_short) + 1)
        return (long_edge, short_edge)

    if mode == 'value':
        return img_scales[np.random.randint(num_scales)]

    if num_scales == 2:
        # Previously this case fell through with `img_scale` unbound and
        # raised UnboundLocalError.
        raise ValueError(
            'mode must be "range" or "value", got {!r}'.format(mode))
    raise ValueError('Only "value" mode supports more than 2 image scales')


def show_ann(coco, img, ann_info):
    """Display an image with its annotations overlaid, using matplotlib.

    Args:
        coco: Object exposing ``showAnns`` (presumably a pycocotools ``COCO``
            instance — confirm against callers).
        img: Image passed through ``mmcv.bgr2rgb`` before display, so it is
            presumably in BGR channel order.
        ann_info: Annotations forwarded unchanged to ``coco.showAnns``.
    """
    rgb_img = mmcv.bgr2rgb(img)
    plt.imshow(rgb_img)
    plt.axis('off')
    coco.showAnns(ann_info)
    plt.show()
wangg12's avatar
wangg12 committed
75
76
77


def get_dataset(data_cfg):
    """Build a dataset, or a concatenation of datasets, from a config dict.

    If ``data_cfg['type']`` is ``'RepeatDataset'``, the nested config
    ``data_cfg['dataset']`` is built recursively and wrapped in
    :class:`RepeatDataset` with ``data_cfg['times']``.

    Otherwise ``data_cfg['ann_file']`` may be a single value or a list/tuple.
    One dataset is built per annotation file with :func:`obj_from_dict`
    (looked up in the ``datasets`` package). When more than one is built,
    they are combined with :class:`ConcatDataset`.

    Per-annotation-file options:

    - ``proposal_file`` (optional): a single value or a list/tuple with one
      entry per annotation file. A single value is wrapped as a one-element
      list, not broadcast, so it only passes the length check when there is
      exactly one annotation file. When the key is absent, every dataset
      gets ``proposal_file=None``.
    - ``img_prefix``: a single value (broadcast to every dataset) or a
      list/tuple with one entry per annotation file.

    Args:
        data_cfg (dict): Dataset config. It is not mutated; each sub-dataset
            is built from a deep copy with its own ``ann_file``,
            ``proposal_file`` and ``img_prefix``.

    Returns:
        The built dataset object.

    Raises:
        ValueError: If ``ann_file`` is an empty list/tuple.
        AssertionError: If the number of proposal files or image prefixes
            does not match the number of annotation files.
    """
    if data_cfg['type'] == 'RepeatDataset':
        return RepeatDataset(
            get_dataset(data_cfg['dataset']), data_cfg['times'])

    ann_file_cfg = data_cfg['ann_file']
    if isinstance(ann_file_cfg, (list, tuple)):
        ann_files = list(ann_file_cfg)
    else:
        ann_files = [ann_file_cfg]
    num_dset = len(ann_files)
    if num_dset == 0:
        raise ValueError('"ann_file" must not be an empty list or tuple')

    if 'proposal_file' in data_cfg:
        proposal_cfg = data_cfg['proposal_file']
        if isinstance(proposal_cfg, (list, tuple)):
            proposal_files = list(proposal_cfg)
        else:
            # Not broadcast (unlike img_prefix): the assert below rejects a
            # single proposal file paired with several annotation files.
            proposal_files = [proposal_cfg]
    else:
        proposal_files = [None] * num_dset
    assert len(proposal_files) == num_dset, (
        'expected {} proposal files (one per ann_file), got {}'.format(
            num_dset, len(proposal_files)))

    prefix_cfg = data_cfg['img_prefix']
    if isinstance(prefix_cfg, (list, tuple)):
        img_prefixes = list(prefix_cfg)
    else:
        img_prefixes = [prefix_cfg] * num_dset
    assert len(img_prefixes) == num_dset, (
        'expected {} image prefixes (one per ann_file), got {}'.format(
            num_dset, len(img_prefixes)))

    dsets = []
    for ann_file, proposal_file, img_prefix in zip(
            ann_files, proposal_files, img_prefixes):
        # Deep copy so per-dataset overrides never leak into data_cfg.
        data_info = copy.deepcopy(data_cfg)
        data_info['ann_file'] = ann_file
        data_info['proposal_file'] = proposal_file
        data_info['img_prefix'] = img_prefix
        dsets.append(obj_from_dict(data_info, datasets))

    if len(dsets) > 1:
        return ConcatDataset(dsets)
    return dsets[0]