viz.py 1.74 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import numpy as np
import gluoncv as gcv
from matplotlib import pyplot as plt

def _box_center(box):
    '''Return the (x, y) midpoint of a box given as (x1, y1, x2, y2).'''
    return np.array([(box[0] + box[2]) / 2, (box[1] + box[3]) / 2])


def plot_sg(img, preds, obj_classes, rel_classes, topk=1):
    '''Visualize the first ``topk`` relations of a predicted scene graph.

    For each of the first ``topk`` rows of ``preds`` a line is drawn between
    the centers of the subject and object boxes. The subject name, object
    name and relation name are written at the subject center, object center
    and the midpoint of that line, respectively.

    Parameters
    ----------
    img : array-like
        Image handed to ``gluoncv.utils.viz.plot_image``; ``img.shape[0:2]``
        is read as (height, width) to rescale the boxes.
    preds : array-like, 2-D
        One relation per row. Column 2 is the relation class index, columns
        3 and 4 are the subject / object class indices, columns 5:9 and 9:13
        are the subject / object boxes as (x1, y1, x2, y2) relative to the
        image size (they are multiplied by (width, height, width, height)).
        Columns 0-1 are not read here. Rows are drawn in the given order;
        presumably the caller sorts them best-first -- not checked here.
    obj_classes : sequence of str
        Object class names indexed by the ids in columns 3 and 4.
    rel_classes : sequence of str
        Relation class names indexed by the id in column 2.
    topk : int, optional
        Maximum number of rows to draw (clipped to ``preds.shape[0]``).

    Returns
    -------
    The axes returned by ``gluoncv.utils.viz.plot_image``, with the scene
    graph drawn on it.
    '''
    height, width = img.shape[0:2]
    box_scale = np.array([width, height, width, height])
    topk = min(topk, preds.shape[0])
    # Colormaps index their lookup table directly when given an int, so map
    # relation ids onto [0, 1] to spread them over the whole colormap.
    color_norm = max(len(rel_classes) - 1, 1)
    ax = gcv.utils.viz.plot_image(img)
    for i in range(topk):
        rel = int(preds[i, 2])
        src = int(preds[i, 3])
        dst = int(preds[i, 4])
        src_center = _box_center(preds[i, 5:9] * box_scale)
        dst_center = _box_center(preds[i, 9:13] * box_scale)
        rel_center = (src_center + dst_center) / 2

        ax.plot([src_center[0], dst_center[0]],
                [src_center[1], dst_center[1]],
                linewidth=3.0, alpha=0.7,
                color=plt.cm.cool(rel / color_norm))

        # Same draw order as before: subject, object, then relation label.
        labels = ((src_center, obj_classes[src]),
                  (dst_center, obj_classes[dst]),
                  (rel_center, rel_classes[rel]))
        for (x, y), name in labels:
            ax.text(x, y, str(name),
                    bbox=dict(alpha=0.5),
                    fontsize=12, color='white')
    return ax

if __name__ == '__main__':
    # Small self-contained demo on synthetic inputs. Previously this line
    # called plot_sg(img, preds, 2) at import time with undefined names and
    # a missing ``rel_classes`` argument, which crashed on import.
    demo_img = np.zeros((240, 320, 3), dtype=np.uint8)
    demo_obj_classes = ['person', 'horse']
    demo_rel_classes = ['riding', 'next to']
    # Row layout used by plot_sg: [unused, unused, rel, src, dst,
    # src box (x1, y1, x2, y2), dst box (x1, y1, x2, y2)], boxes relative
    # to the image size.
    demo_preds = np.array([
        [0.0, 0.0, 0, 0, 1, 0.10, 0.20, 0.40, 0.90, 0.50, 0.30, 0.90, 0.80],
        [0.0, 0.0, 1, 1, 0, 0.50, 0.30, 0.90, 0.80, 0.10, 0.20, 0.40, 0.90],
    ])
    plot_sg(demo_img, demo_preds, demo_obj_classes, demo_rel_classes, topk=2)
    plt.show()