inference.py 1.88 KB
Newer Older
myownskyW7's avatar
myownskyW7 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import mmcv
import numpy as np
import torch

from mmdet.datasets import to_tensor
from mmdet.datasets.transforms import ImageTransform
from mmdet.core import get_classes


def _prepare_data(img, img_transform, cfg, device):
    """Turn one loaded image into the keyword inputs passed to the model.

    Args:
        img: Image array (as produced by ``mmcv.imread`` in the caller).
        img_transform: Callable returning ``(img, img_shape, pad_shape,
            scale_factor)`` for the image and the requested ``scale``.
        cfg: Config object; ``cfg.data.test.img_scale`` is used as the scale.
        device: Device the image tensor is moved to.

    Returns:
        dict: ``img`` maps to a one-element list holding the image tensor
        with a leading batch dimension added; ``img_meta`` maps to a
        one-element list holding a list with a single meta dict.
    """
    original_shape = img.shape
    transformed = img_transform(img, scale=cfg.data.test.img_scale)
    img, img_shape, pad_shape, scale_factor = transformed
    # unsqueeze(0) adds the batch dimension for a batch of one image.
    batch = to_tensor(img).to(device).unsqueeze(0)
    meta = dict(
        ori_shape=original_shape,
        img_shape=img_shape,
        pad_shape=pad_shape,
        scale_factor=scale_factor,
        flip=False,
    )
    return dict(img=[batch], img_meta=[[meta]])


Kai Chen's avatar
Kai Chen committed
26
27
28
29
30
31
32
33
34
35
36
def _inference_single(model, img, img_transform, cfg, device):
    """Load one image, run the model on it, and return the raw result.

    ``img`` is anything ``mmcv.imread`` accepts. The model is called with
    ``return_loss=False`` and ``rescale=True``.
    """
    loaded = mmcv.imread(img)
    inputs = _prepare_data(loaded, img_transform, cfg, device)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        return model(return_loss=False, rescale=True, **inputs)


def _inference_generator(model, imgs, img_transform, cfg, device):
    """Yield one inference result per entry of ``imgs``, computed lazily."""
    for current in imgs:
        single_result = _inference_single(
            model, current, img_transform, cfg, device)
        yield single_result
myownskyW7's avatar
myownskyW7 committed
37

Kai Chen's avatar
Kai Chen committed
38
39

def inference_detector(model, imgs, cfg, device='cuda:0'):
    """Run a detector on a single image or on a list/tuple of images.

    Args:
        model: Detector module. It is moved to ``device`` and switched to
            eval mode before inference.
        imgs: A single image (anything ``mmcv.imread`` accepts) or a list or
            tuple of such images.
        cfg: Config object. ``cfg.data.test.size_divisor`` and
            ``cfg.img_norm_cfg`` configure the image transform;
            ``cfg.data.test.img_scale`` is used when preparing each image.
        device: Device to run inference on.

    Returns:
        For a single image, the model's result. For a list/tuple, a lazy
        generator yielding one result per image; images are processed only
        as the generator is iterated.
    """
    img_transform = ImageTransform(
        size_divisor=cfg.data.test.size_divisor, **cfg.img_norm_cfg)
    model = model.to(device)
    model.eval()

    # A tuple of images is a batch just like a list; handing it to
    # mmcv.imread as a single image would raise a TypeError.
    if isinstance(imgs, (list, tuple)):
        return _inference_generator(model, imgs, img_transform, cfg, device)
    return _inference_single(model, imgs, img_transform, cfg, device)
myownskyW7's avatar
myownskyW7 committed
49
50
51
52
53
54
55
56
57
58


def show_result(img, result, dataset='coco', score_thr=0.3):
    """Display detection boxes drawn on ``img``.

    Args:
        img: Image to draw on (anything ``mmcv.imread`` accepts). Drawing is
            done on a copy, so an array passed in is not modified.
        result: Per-class sequence of bbox arrays; the class index of each
            array is used as the label of its rows. A tuple whose first
            element is itself such a per-class list/tuple (e.g. a
            ``(bbox_result, segm_result)`` pair — presumably from
            mask-capable detectors, TODO confirm) is also accepted; only the
            boxes are drawn.
        dataset: Dataset name passed to ``get_classes`` for label names.
        score_thr: Score threshold forwarded to ``mmcv.imshow_det_bboxes``.
    """
    # Unwrap a (bbox_result, ...) pair. The first-element check keeps a
    # plain tuple of per-class arrays working as before.
    if (isinstance(result, tuple) and result
            and isinstance(result[0], (list, tuple))):
        result = result[0]
    class_names = get_classes(dataset)
    # Label every row of the i-th array with class index i.
    labels = np.concatenate([
        np.full(bbox.shape[0], i, dtype=np.int32)
        for i, bbox in enumerate(result)
    ])
    bboxes = np.vstack(result)
    img = mmcv.imread(img)
    mmcv.imshow_det_bboxes(
        img.copy(),
        bboxes,
        labels,
        class_names=class_names,
        score_thr=score_thr)