import numpy as np import gluoncv as gcv from matplotlib import pyplot as plt def plot_sg(img, preds, obj_classes, rel_classes, topk=1): '''visualization of generated scene graph''' size = img.shape[0:2] box_scale = np.array([size[1], size[0], size[1], size[0]]) topk = min(topk, preds.shape[0]) ax = gcv.utils.viz.plot_image(img) for i in range(topk): rel = int(preds[i, 2]) src = int(preds[i, 3]) dst = int(preds[i, 4]) src_name = obj_classes[src] dst_name = obj_classes[dst] rel_name = rel_classes[rel] src_bbox = preds[i, 5:9] * box_scale dst_bbox = preds[i, 9:13] * box_scale src_center = np.array([(src_bbox[0] + src_bbox[2]) / 2, (src_bbox[1] + src_bbox[3]) / 2]) dst_center = np.array([(dst_bbox[0] + dst_bbox[2]) / 2, (dst_bbox[1] + dst_bbox[3]) / 2]) rel_center = (src_center + dst_center) / 2 line_x = np.array([(src_bbox[0] + src_bbox[2]) / 2, (dst_bbox[0] + dst_bbox[2]) / 2]) line_y = np.array([(src_bbox[1] + src_bbox[3]) / 2, (dst_bbox[1] + dst_bbox[3]) / 2]) ax.plot(line_x, line_y, linewidth=3.0, alpha=0.7, color=plt.cm.cool(rel)) ax.text(src_center[0], src_center[1], '{:s}'.format(src_name), bbox=dict(alpha=0.5), fontsize=12, color='white') ax.text(dst_center[0], dst_center[1], '{:s}'.format(dst_name), bbox=dict(alpha=0.5), fontsize=12, color='white') ax.text(rel_center[0], rel_center[1], '{:s}'.format(rel_name), bbox=dict(alpha=0.5), fontsize=12, color='white') return ax plot_sg(img, preds, 2)