# demo_reldn.py: demo of scene graph extraction.
import argparse

import dgl
import mxnet as mx
import gluoncv as gcv
from gluoncv.utils import download
from gluoncv.data.transforms import presets

from model import faster_rcnn_resnet101_v1d_custom, RelDN
from utils import *
from data import *

def parse_args(argv=None):
    """Parse the command-line options of the scene graph extraction demo.

    Parameters
    ----------
    argv : list of str, optional
        Argument strings to parse. When ``None`` (the default) the arguments
        are taken from ``sys.argv[1:]``, which is the behaviour callers that
        invoke ``parse_args()`` without arguments already rely on.

    Returns
    -------
    argparse.Namespace
        Namespace with the attributes ``image``, ``gpu``,
        ``pretrained_faster_rcnn_params``, ``reldn_params``,
        ``faster_rcnn_params`` and ``freq_prior``. All are strings; every
        option except ``freq_prior`` defaults to the empty string.
    """
    parser = argparse.ArgumentParser(description='Demo of Scene Graph Extraction.')
    parser.add_argument('--image', type=str, default='',
                        help="The image for scene graph extraction.")
    parser.add_argument('--gpu', type=str, default='',
                        help="GPU id to use for inference, default is not using GPU.")
    parser.add_argument('--pretrained-faster-rcnn-params', type=str, default='',
                        help="Path to saved Faster R-CNN model parameters "
                             "loaded into the object detector.")
    # This option points at the RelDN weights, not at a Faster R-CNN file.
    parser.add_argument('--reldn-params', type=str, default='',
                        help="Path to saved RelDN model parameters.")
    parser.add_argument('--faster-rcnn-params', type=str, default='',
                        help="Path to saved Faster R-CNN model parameters "
                             "loaded into the feature extractor used for "
                             "relation features.")
    parser.add_argument('--freq-prior', type=str, default='freq_prior.pkl',
                        help="Path to saved frequency prior data.")
    return parser.parse_args(argv)

args = parse_args()
# Run on the requested GPU when --gpu is given, otherwise on the CPU.
if args.gpu:
    ctx = mx.gpu(int(args.gpu))
else:
    ctx = mx.cpu()

# Relation prediction network.
net = RelDN(n_classes=50, prior_pkl=args.freq_prior, semantic_only=False)
if args.reldn_params == '':
    download('http://data.dgl.ai/models/SceneGraph/reldn.params')
    # The local file name has to match the last component of the URL above,
    # the same convention the Faster R-CNN download below relies on.
    net.load_parameters('reldn.params', ctx=ctx)
else:
    net.load_parameters(args.reldn_params, ctx=ctx)

# dataset and dataloader
# The validation split is only used here for its object / relation class lists.
vg_val = VGRelation(split='val')
detector = faster_rcnn_resnet101_v1d_custom(classes=vg_val.obj_classes,
                                            pretrained_base=False, pretrained=False,
                                            additional_output=True)
if args.pretrained_faster_rcnn_params == '':
    download('http://data.dgl.ai/models/SceneGraph/faster_rcnn_resnet101_v1d_visualgenome.params')
    params_path = 'faster_rcnn_resnet101_v1d_visualgenome.params'
else:
    params_path = args.pretrained_faster_rcnn_params
detector.load_parameters(params_path, ctx=ctx, ignore_extra=True, allow_missing=True)

# Second detector instance, called further down with the relation boxes to
# produce the per-edge features.
detector_feat = faster_rcnn_resnet101_v1d_custom(classes=vg_val.obj_classes,
                                                 pretrained_base=False, pretrained=False,
                                                 additional_output=True)
detector_feat.load_parameters(params_path, ctx=ctx, ignore_extra=True, allow_missing=True)
# Its `features` sub-block is then (re)loaded from --faster-rcnn-params.
if args.faster_rcnn_params == '':
    download('http://data.dgl.ai/models/SceneGraph/faster_rcnn_resnet101_v1d_visualgenome.params')
    detector_feat.features.load_parameters('faster_rcnn_resnet101_v1d_visualgenome.params', ctx=ctx)
else:
    detector_feat.features.load_parameters(args.faster_rcnn_params, ctx=ctx)

# image input
if args.image:
    image_path = args.image
else:
    # No image given: fetch the sample picture and use that instead.
    gcv.utils.download('https://raw.githubusercontent.com/dmlc/web-data/master/' +
                       'dgl/examples/mxnet/scenegraph/old-couple.png',
                       'old-couple.png')
    image_path = 'old-couple.png'
# Load from image_path (not args.image) so the downloaded sample is actually
# used when --image is omitted.
x, img = presets.rcnn.load_test(image_path, short=detector.short, max_size=detector.max_size)
x = x.as_in_context(ctx)
# detector prediction
ids, scores, bboxes, feat, feat_ind, spatial_feat = detector(x)
# build graph, extract edge features
g = build_graph_validate_pred(x, ids, scores, bboxes, feat_ind, spatial_feat,
                              bbox_improvement=True, scores_top_k=75, overlap=False)
rel_bbox = g.edata['rel_bbox'].expand_dims(0).as_in_context(ctx)
_, _, _, spatial_feat_rel = detector_feat(x, None, None, rel_bbox)
g.edata['edge_feat'] = spatial_feat_rel[0]
# graph prediction
g = net(g)

_, preds = extract_pred(g, joint_preds=True)
# Order rows by column 1 in descending order (presumably the prediction
# score; confirm against extract_pred).
preds = preds[preds[:,1].argsort()[::-1]]

plot_sg(img, preds, detector.classes, vg_val.rel_classes, 10)