inference.py 2.62 KB
Newer Older
myownskyW7's avatar
myownskyW7 committed
1
2
import mmcv
import numpy as np
3
import pycocotools.mask as maskUtils
myownskyW7's avatar
myownskyW7 committed
4
5
import torch

6
from mmdet.core import get_classes
myownskyW7's avatar
myownskyW7 committed
7
8
9
10
11
12
13
from mmdet.datasets import to_tensor
from mmdet.datasets.transforms import ImageTransform


def _prepare_data(img, img_transform, cfg, device):
    """Build the batched input dict that the detector's forward expects.

    Args:
        img (ndarray): Loaded image ``(H, W, C)``; its shape is recorded
            as ``ori_shape`` before any resizing.
        img_transform (ImageTransform): Callable that resizes, normalizes
            and pads the image, returning the transformed image plus shapes.
        cfg: Config providing ``data.test.img_scale`` and the optional
            ``data.test.resize_keep_ratio`` flag (defaults to True).
        device (str or torch.device): Target device for the image tensor.

    Returns:
        dict: ``{'img': [tensor], 'img_meta': [[meta]]}`` ready to be
        unpacked into ``model(**data)``.
    """
    original_shape = img.shape
    test_cfg = cfg.data.test
    img, resized_shape, padded_shape, scale = img_transform(
        img,
        scale=test_cfg.img_scale,
        keep_ratio=test_cfg.get('resize_keep_ratio', True))
    # Add a leading batch dimension of 1 after moving to the target device.
    batched_img = to_tensor(img).to(device).unsqueeze(0)
    meta = dict(
        ori_shape=original_shape,
        img_shape=resized_shape,
        pad_shape=padded_shape,
        scale_factor=scale,
        flip=False)
    return dict(img=[batched_img], img_meta=[[meta]])


Kai Chen's avatar
Kai Chen committed
29
30
31
32
33
34
35
36
37
38
39
def _inference_single(model, img, img_transform, cfg, device):
    """Run the detector on a single image with gradients disabled.

    Args:
        model: Detector module, already on ``device`` and in eval mode.
        img (str or ndarray): Image path or loaded array; ``mmcv.imread``
            passes arrays through unchanged.
        img_transform (ImageTransform): Preprocessing transform.
        cfg: Test configuration.
        device (str or torch.device): Device for the input tensor.

    Returns:
        The detector's test-mode output (rescaled to the original image).
    """
    loaded = mmcv.imread(img)
    inputs = _prepare_data(loaded, img_transform, cfg, device)
    with torch.no_grad():
        return model(return_loss=False, rescale=True, **inputs)


def _inference_generator(model, imgs, img_transform, cfg, device):
    """Lazily yield one detection result per image in ``imgs``."""
    yield from (
        _inference_single(model, one_img, img_transform, cfg, device)
        for one_img in imgs)
myownskyW7's avatar
myownskyW7 committed
40

Kai Chen's avatar
Kai Chen committed
41
42

def inference_detector(model, imgs, cfg, device='cuda:0'):
    """Run inference with a detector on one image or a collection of images.

    Args:
        model: Detector module; it is moved to ``device`` and switched to
            eval mode here.
        imgs (str/ndarray or list): A single image (path or array) or a
            list of them.
        cfg: Config providing ``data.test`` options and ``img_norm_cfg``.
        device (str): Device to run on. Defaults to ``'cuda:0'``.

    Returns:
        A single detection result when ``imgs`` is not a list, otherwise a
        generator yielding one result per image.
    """
    img_transform = ImageTransform(
        size_divisor=cfg.data.test.size_divisor, **cfg.img_norm_cfg)
    model = model.to(device)
    model.eval()

    if isinstance(imgs, list):
        return _inference_generator(model, imgs, img_transform, cfg, device)
    return _inference_single(model, imgs, img_transform, cfg, device)
myownskyW7's avatar
myownskyW7 committed
52
53


54
55
def show_result(img, result, dataset='coco', score_thr=0.3, out_file=None):
    """Visualize detection (and optional segmentation) results on an image.

    Args:
        img (str or ndarray): Image path or loaded image.
        result (tuple[list] or list): Either ``(bbox_result, segm_result)``
            for detectors that predict masks, or ``bbox_result`` alone.
            Each is a per-class list (bbox arrays / RLE-encoded masks).
        dataset (str): Dataset name used to look up class names.
            Defaults to ``'coco'``.
        score_thr (float): Minimum score for a box/mask to be drawn.
        out_file (str, optional): If given, write the visualization to this
            path; otherwise display it in a window.
    """
    img = mmcv.imread(img)
    class_names = get_classes(dataset)
    if isinstance(result, tuple):
        bbox_result, segm_result = result
    else:
        bbox_result, segm_result = result, None
    bboxes = np.vstack(bbox_result)
    # draw segmentation masks
    if segm_result is not None:
        segms = mmcv.concat_list(segm_result)
        inds = np.where(bboxes[:, -1] > score_thr)[0]
        for i in inds:
            # random color per instance, blended 50/50 with the image
            color_mask = np.random.randint(
                0, 256, (1, 3), dtype=np.uint8)
            # BUGFIX: `np.bool` was deprecated in NumPy 1.20 and removed in
            # 1.24; the builtin `bool` is the supported, equivalent dtype.
            mask = maskUtils.decode(segms[i]).astype(bool)
            img[mask] = img[mask] * 0.5 + color_mask * 0.5
    # draw bounding boxes
    labels = [
        np.full(bbox.shape[0], i, dtype=np.int32)
        for i, bbox in enumerate(bbox_result)
    ]
    labels = np.concatenate(labels)
    mmcv.imshow_det_bboxes(
        img.copy(),
        bboxes,
        labels,
        class_names=class_names,
        score_thr=score_thr,
        show=out_file is None,
        out_file=out_file)