demo_reldn.py
import argparse

import dgl
import gluoncv as gcv
import mxnet as mx
from data import *
from gluoncv.data.transforms import presets
from gluoncv.utils import download
from model import faster_rcnn_resnet101_v1d_custom, RelDN
from utils import *


def parse_args():
    parser = argparse.ArgumentParser(
        description="Demo of Scene Graph Extraction."
    )
    parser.add_argument(
        "--image",
        type=str,
        default="",
        help="The image for scene graph extraction.",
    )
    parser.add_argument(
        "--gpu",
        type=str,
        default="",
        help="GPU id to use for inference, default is not using GPU.",
    )
    parser.add_argument(
        "--pretrained-faster-rcnn-params",
        type=str,
        default="",
        help="Path to saved Faster R-CNN model parameters.",
    )
    parser.add_argument(
        "--reldn-params",
        type=str,
        default="",
        help="Path to saved Faster R-CNN model parameters.",
    )
    parser.add_argument(
        "--faster-rcnn-params",
        type=str,
        default="",
        help="Path to saved Faster R-CNN model parameters.",
    )
    parser.add_argument(
        "--freq-prior",
        type=str,
        default="freq_prior.pkl",
        help="Path to saved frequency prior data.",
    )
    args = parser.parse_args()
    return args


args = parse_args()
if args.gpu:
    ctx = mx.gpu(int(args.gpu))
else:
    ctx = mx.cpu()

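# RelDN relation head: scores the 50 Visual Genome predicate classes; the
# frequency-prior pickle feeds its semantic branch, and semantic_only=False
# keeps the non-semantic (visual) branch enabled as well.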
net = RelDN(n_classes=50, prior_pkl=args.freq_prior, semantic_only=False)
if args.reldn_params == "":
    download("http://data.dgl.ai/models/SceneGraph/reldn.params")
    net.load_parameters("rendl.params", ctx=ctx)
else:
    net.load_parameters(args.reldn_params, ctx=ctx)

# dataset metadata (provides the object and relation class names)
vg_val = VGRelation(split="val")
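# Object detector: custom Faster R-CNN (ResNet-101 v1d) over the Visual Genome
# object classes; additional_output=True makes the forward pass also return
# the ROI outputs (feat, feat_ind, spatial_feat) used below to build the graph.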
detector = faster_rcnn_resnet101_v1d_custom(
    classes=vg_val.obj_classes,
    pretrained_base=False,
    pretrained=False,
    additional_output=True,
)
if args.pretrained_faster_rcnn_params == "":
    download(
        "http://data.dgl.ai/models/SceneGraph/faster_rcnn_resnet101_v1d_visualgenome.params"
    )
    params_path = "faster_rcnn_resnet101_v1d_visualgenome.params"
else:
    params_path = args.pretrained_faster_rcnn_params
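# ignore_extra/allow_missing tolerate any mismatch between the saved parameter
# file and this custom network definition.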
detector.load_parameters(
    params_path, ctx=ctx, ignore_extra=True, allow_missing=True
)

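# A second copy of the detector is kept purely for feature extraction: it is
# later fed the per-edge relation boxes to pool their visual features, and its
# backbone weights can be overridden separately via --faster-rcnn-params.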
detector_feat = faster_rcnn_resnet101_v1d_custom(
    classes=vg_val.obj_classes,
    pretrained_base=False,
    pretrained=False,
    additional_output=True,
)
detector_feat.load_parameters(
    params_path, ctx=ctx, ignore_extra=True, allow_missing=True
)
if args.faster_rcnn_params == "":
    download(
        "http://data.dgl.ai/models/SceneGraph/faster_rcnn_resnet101_v1d_visualgenome.params"
    )
    detector_feat.features.load_parameters(
        "faster_rcnn_resnet101_v1d_visualgenome.params", ctx=ctx
    )
else:
    detector_feat.features.load_parameters(args.faster_rcnn_params, ctx=ctx)

# image input
if args.image:
    image_path = args.image
else:
    gcv.utils.download(
        "https://raw.githubusercontent.com/dmlc/web-data/master/"
        + "dgl/examples/mxnet/scenegraph/old-couple.png",
        "old-couple.png",
    )
    image_path = "old-couple.png"
x, img = presets.rcnn.load_test(
    image_path, short=detector.short, max_size=detector.max_size
)
x = x.as_in_context(ctx)
# detector prediction
ids, scores, bboxes, feat, feat_ind, spatial_feat = detector(x)
# build graph, extract edge features
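# (nodes are the kept detections and edges are candidate subject-object pairs;
#  see build_graph_validate_pred in utils for the exact construction)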
g = build_graph_validate_pred(
    x,
    ids,
    scores,
    bboxes,
    feat_ind,
    spatial_feat,
    bbox_improvement=True,
    scores_top_k=75,
    overlap=False,
)
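# Pool visual features for every edge: run the relation boxes stored in
# g.edata["rel_bbox"] through the feature-extraction detector and attach the
# resulting spatial features to the edges as "edge_feat".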
rel_bbox = g.edata["rel_bbox"].expand_dims(0).as_in_context(ctx)
_, _, _, spatial_feat_rel = detector_feat(x, None, None, rel_bbox)
g.edata["edge_feat"] = spatial_feat_rel[0]
# graph prediction
g = net(g)

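# Extract joint (subject, predicate, object) predictions, sort them by the
# score in column 1 in descending order, and plot the highest-ranked relations
# (the final argument of plot_sg presumably caps how many are drawn).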
_, preds = extract_pred(g, joint_preds=True)
preds = preds[preds[:, 1].argsort()[::-1]]

plot_sg(img, preds, detector.classes, vg_val.rel_classes, 10)