viz.py 1.91 KB
Newer Older
1
import gluoncv as gcv
2
import numpy as np
3
4
from matplotlib import pyplot as plt

5

6
def plot_sg(img, preds, obj_classes, rel_classes, topk=1):
    """Draw the first ``topk`` predicted relations of a scene graph on ``img``.

    For each of the first ``topk`` rows of ``preds`` a line is drawn between
    the centres of the source and destination boxes, and three text labels
    are placed: the source class name at the source centre, the destination
    class name at the destination centre, and the relation name at the
    midpoint between the two centres.

    Parameters
    ----------
    img
        Image passed to ``gcv.utils.viz.plot_image``; its ``shape[0:2]`` is
        used to scale the boxes (presumably ``(height, width)`` -- confirm
        against the caller).
    preds
        2-D array indexed as ``preds[i, col]``. Columns read here:
        2 = index into ``rel_classes``; 3 and 4 = indices into
        ``obj_classes`` for the source and destination objects;
        5:9 = source box; 9:13 = destination box. Boxes are multiplied by
        ``[shape[1], shape[0], shape[1], shape[0]]``, so they are presumably
        normalised ``(x1, y1, x2, y2)`` -- TODO confirm. Other columns are
        not used by this function.
    obj_classes
        Sequence of object class names (must be ``str``; they are formatted
        with ``"{:s}"``).
    rel_classes
        Sequence of relation class names (``str``).
    topk
        Maximum number of rows to draw; clamped to ``preds.shape[0]``.

    Returns
    -------
    The axes object returned by ``gcv.utils.viz.plot_image`` with the
    relations drawn on it.
    """
    size = img.shape[0:2]
    # Scale factors applied element-wise to each 4-value box.
    box_scale = np.array([size[1], size[0], size[1], size[0]])
    topk = min(topk, preds.shape[0])
    ax = gcv.utils.viz.plot_image(img)

    for i in range(topk):
        rel = int(preds[i, 2])
        src = int(preds[i, 3])
        dst = int(preds[i, 4])
        src_name = obj_classes[src]
        dst_name = obj_classes[dst]
        rel_name = rel_classes[rel]
        src_bbox = preds[i, 5:9] * box_scale
        dst_bbox = preds[i, 9:13] * box_scale

        src_center = np.array(
            [(src_bbox[0] + src_bbox[2]) / 2, (src_bbox[1] + src_bbox[3]) / 2]
        )
        dst_center = np.array(
            [(dst_bbox[0] + dst_bbox[2]) / 2, (dst_bbox[1] + dst_bbox[3]) / 2]
        )
        rel_center = (src_center + dst_center) / 2

        # The connecting line runs between the two box centres, so reuse the
        # centres computed above instead of recomputing the same midpoints.
        line_x = np.array([src_center[0], dst_center[0]])
        line_y = np.array([src_center[1], dst_center[1]])
        ax.plot(
            line_x, line_y, linewidth=3.0, alpha=0.7, color=plt.cm.cool(rel)
        )

        # Same text styling for all three labels; only position and text vary.
        labels = (
            (src_center, src_name),
            (dst_center, dst_name),
            (rel_center, rel_name),
        )
        for center, name in labels:
            ax.text(
                center[0],
                center[1],
                "{:s}".format(name),
                bbox=dict(alpha=0.5),
                fontsize=12,
                color="white",
            )

    return ax


68
# NOTE(review): this call does not match the signature
# plot_sg(img, preds, obj_classes, rel_classes, topk=1): the literal 2 lands
# in the obj_classes position and rel_classes is not supplied. img and preds
# are also not defined in the visible part of this file. This looks like a
# leftover usage snippet -- confirm the intended arguments or remove it.
plot_sg(img, preds, 2)