misc.py 3.6 KB
Newer Older
pangjm's avatar
pangjm committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import subprocess

import mmcv
import numpy as np
import torch

# Names exported by a star-import of this module. The JSON helpers
# (xyxy2xywh, det2json, segm2json, proposal2json) are not listed here; of the
# conversion functions only results2json is exported.
__all__ = ['tensor2imgs', 'unique', 'unmap', 'results2json']


def tensor2imgs(tensor,
                color_order='RGB',
                color_mean=(0.485, 0.456, 0.406),
                color_std=(0.229, 0.224, 0.225)):
    """Convert a batch of normalized image tensors into BGR numpy images.

    Args:
        tensor (torch.Tensor): Image batch indexed as (N, C, H, W). Each
            image is moved to the CPU and transposed to (H, W, C).
        color_order (str): Channel order of ``tensor``; 'RGB' or 'BGR'.
        color_mean (Sequence[float]): Per-channel mean added back to the
            image, given in the same channel order as ``tensor``.
        color_std (Sequence[float]): Per-channel std the image is multiplied
            by, given in the same channel order as ``tensor``.

    Returns:
        list[np.ndarray]: One C-contiguous (H, W, C) array per image with
            the channels in BGR order.
    """
    assert color_order in ['RGB', 'BGR']
    color_mean = np.array(color_mean, dtype=np.float32)
    color_std = np.array(color_std, dtype=np.float32)
    imgs = []
    for img_id in range(tensor.size(0)):
        # detach() so tensors that still track gradients can be exported.
        img = tensor[img_id, ...].detach().cpu().numpy().transpose(1, 2, 0)
        # Undo the normalization while the channels are still in the order
        # that color_mean / color_std are specified in. Swapping channels
        # first would apply each statistic to the wrong channel.
        img = img * color_std + color_mean
        if color_order == 'RGB':
            # Reverse the channel axis: RGB -> BGR.
            img = img[..., ::-1]
        imgs.append(np.ascontiguousarray(img))
    return imgs


def unique(tensor):
    """Return the unique values contained in ``tensor``.

    CPU tensors are handled by ``torch.unique``. CUDA tensors are copied to
    the host, de-duplicated with ``np.unique`` and wrapped back into a tensor
    via ``tensor.new_tensor``.
    """
    if not tensor.is_cuda:
        return torch.unique(tensor)
    host_values = tensor.cpu().numpy()
    return tensor.new_tensor(np.unique(host_values))


def unmap(data, count, inds, fill=0):
    """Scatter a subset of items (``data``) back into the full set.

    A tensor with ``count`` entries along the first dimension is created
    with every element set to ``fill``; the entries selected by ``inds`` are
    then overwritten with ``data``. Trailing dimensions of ``data`` are kept.
    """
    is_flat = data.dim() == 1
    full_shape = (count, ) if is_flat else (count, ) + data.size()[1:]
    out = data.new_full(full_shape, fill)
    if is_flat:
        out[inds] = data
    else:
        out[inds, :] = data
    return out

def xyxy2xywh(bbox):
    """Convert a box from (x1, y1, x2, y2) to a [x, y, w, h] list.

    Only the first four entries of ``bbox`` are used; width and height are
    computed as ``max - min + 1``.
    """
    coords = bbox.tolist()
    left, top = coords[0], coords[1]
    right, bottom = coords[2], coords[3]
    width = right - left + 1
    height = bottom - top + 1
    return [left, top, width, height]

def det2json(dataset, results):
    """Flatten per-image, per-class detections into a list of result dicts.

    ``results[idx]`` is a sequence indexed by label; each entry is an array
    whose rows hold a box in the first four columns and a score in column 4.
    Every row becomes one dict with the keys ``image_id``, ``bbox``,
    ``score`` and ``category_id``; the category comes from
    ``dataset.cat_ids[label]``.
    """
    records = []
    for idx in range(len(dataset)):
        image_id = dataset.img_ids[idx]
        for label, class_dets in enumerate(results[idx]):
            for row in range(class_dets.shape[0]):
                det = class_dets[row]
                records.append({
                    'image_id': image_id,
                    'bbox': xyxy2xywh(det),
                    'score': float(det[4]),
                    'category_id': dataset.cat_ids[label],
                })
    return records


def segm2json(dataset, results):
    """Flatten per-image detection + mask results into a list of dicts.

    ``results[idx]`` is a ``(det, seg)`` pair, both indexed by label.
    ``det[label]`` is an array whose rows hold a box in the first four
    columns and a score in column 4; ``seg[label][i]`` is the mask dict
    that belongs to row ``i``.

    Each row becomes one dict with the keys ``image_id``, ``bbox``,
    ``score``, ``category_id`` and ``segmentation``. A bytes value stored
    under the mask's ``'counts'`` key is decoded to ``str``, because bytes
    cannot be serialized to JSON.

    The mask dicts in ``results`` are not modified: each one is copied
    before decoding, so the function can be called repeatedly on the same
    ``results``.
    """
    json_results = []
    for idx in range(len(dataset)):
        img_id = dataset.img_ids[idx]
        det, seg = results[idx]
        for label in range(len(det)):
            bboxes = det[label]
            segms = seg[label]
            for i in range(bboxes.shape[0]):
                data = dict()
                data['image_id'] = img_id
                data['bbox'] = xyxy2xywh(bboxes[i])
                data['score'] = float(bboxes[i][4])
                data['category_id'] = dataset.cat_ids[label]
                # Shallow copy so the caller's mask dict keeps its original
                # 'counts' value.
                segm = dict(segms[i])
                # Already-decoded (str) counts are passed through unchanged.
                if isinstance(segm['counts'], bytes):
                    segm['counts'] = segm['counts'].decode()
                data['segmentation'] = segm
                json_results.append(data)
    return json_results


def proposal2json(dataset, results):
    """Flatten per-image proposals into a list of result dicts.

    ``results[idx]`` is an array whose rows hold a box in the first four
    columns and a score in column 4. Every row becomes one dict with the
    keys ``image_id``, ``bbox``, ``score`` and ``category_id``; the category
    id is always 1.
    """
    records = []
    for idx in range(len(dataset)):
        image_id = dataset.img_ids[idx]
        proposals = results[idx]
        for row in range(proposals.shape[0]):
            proposal = proposals[row]
            records.append({
                'image_id': image_id,
                'bbox': xyxy2xywh(proposal),
                'score': float(proposal[4]),
                'category_id': 1,
            })
    return records


def results2json(dataset, results, out_file):
    """Convert ``results`` to a list of dicts and dump it to ``out_file``.

    The converter is picked from the type of ``results[0]``:
    ``list`` -> ``det2json``, ``tuple`` -> ``segm2json``,
    ``np.ndarray`` -> ``proposal2json``. The converted list is written with
    ``mmcv.dump``.

    Raises:
        TypeError: If ``results[0]`` is none of the types above.
    """
    # Checked in order; the first matching type wins.
    converters = (
        (list, det2json),
        (tuple, segm2json),
        (np.ndarray, proposal2json),
    )
    first = results[0]
    for result_type, convert in converters:
        if isinstance(first, result_type):
            json_results = convert(dataset, results)
            break
    else:
        raise TypeError('invalid type of results')
    mmcv.dump(json_results, out_file)