validate_reldn.py 9.75 KB
Newer Older
1
2
3
4
import argparse
import logging
import time

5
6
import mxnet as mx
import numpy as np
7
from data import *
8
from gluoncv.data.batchify import Pad
9
10
from model import RelDN, faster_rcnn_resnet101_v1d_custom
from mxnet import gluon, nd
11
from utils import *
12
13
14

import dgl

15
16

def parse_args():
    """Parse command-line arguments for RelDN validation.

    Returns
    -------
    argparse.Namespace
        Parsed arguments; the three ``*-params`` paths are required.
    """
    parser = argparse.ArgumentParser(
        description="Validate Pre-trained RelDN Model."
    )
    parser.add_argument(
        "--gpus",
        type=str,
        default="0",
        help="Training with GPUs, you can specify 1,3 for example.",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=8,
        help="Total batch-size for training.",
    )
    parser.add_argument(
        "--metric",
        type=str,
        default="sgdet",
        help="Evaluation metric, could be 'predcls', 'phrcls', 'sgdet' or 'sgdet+'.",
    )
    parser.add_argument(
        "--pretrained-faster-rcnn-params",
        type=str,
        required=True,
        help="Path to saved Faster R-CNN model parameters.",
    )
    parser.add_argument(
        "--reldn-params",
        type=str,
        required=True,
        # BUG FIX: help text was a copy-paste of the Faster R-CNN wording.
        help="Path to saved RelDN model parameters.",
    )
    parser.add_argument(
        "--faster-rcnn-params",
        type=str,
        required=True,
        # Differentiated from --pretrained-faster-rcnn-params: this file is
        # loaded into detector_feat.features only (fine-tuned extractor).
        help="Path to saved fine-tuned Faster R-CNN feature parameters.",
    )
    parser.add_argument(
        "--log-dir",
        type=str,
        default="reldn_output.log",
        help="Path to save training logs.",
    )
    parser.add_argument(
        "--freq-prior",
        type=str,
        default="freq_prior.pkl",
        help="Path to saved frequency prior data.",
    )
    parser.add_argument(
        "--verbose-freq",
        type=int,
        default=100,
        help="Frequency of log printing in number of iterations.",
    )

    args = parser.parse_args()
    return args

77

78
79
80
81
args = parse_args()

# Log both to the file given by --log-dir and to the console.
filehandler = logging.FileHandler(args.log_dir)
streamhandler = logging.StreamHandler()
logger = logging.getLogger("")
logger.setLevel(logging.INFO)
logger.addHandler(filehandler)
logger.addHandler(streamhandler)

# Hyperparams
ctx = [mx.gpu(int(i)) for i in args.gpus.split(",") if i.strip()]
if ctx:
    num_gpus = len(ctx)
    # Explicit raise instead of `assert` (asserts vanish under `python -O`).
    if args.batch_size % num_gpus != 0:
        raise ValueError("batch-size must be divisible by the number of GPUs")
    per_device_batch_size = args.batch_size // num_gpus
else:
    # CPU fallback. BUG FIX: `num_gpus` was previously left undefined on
    # this path, causing a NameError when it is used later (e.g. for the
    # DataLoader's num_workers). Treat the single CPU as one device.
    ctx = [mx.cpu()]
    num_gpus = 1
    per_device_batch_size = args.batch_size
batch_size = args.batch_size
N_relations = 50
N_objects = 150
batch_verbose_freq = args.verbose_freq

mode = args.metric
metric_list = []
topk_list = [20, 50, 100]
104
# One metric instance per top-k cutoff (R@20 / R@50 / R@100), chosen by the
# evaluation mode; an unknown mode leaves metric_list empty, as before.
_metric_by_mode = {
    "predcls": PredCls,
    "phrcls": PhrCls,
    "sgdet": SGDet,
    "sgdet+": SGDetPlus,
}
_metric_cls = _metric_by_mode.get(mode)
if _metric_cls is not None:
    metric_list = [_metric_cls(topk=topk) for topk in topk_list]
for metric in metric_list:
    metric.reset()

semantic_only = False
120
121
122
123
124
net = RelDN(
    n_classes=N_relations,
    prior_pkl=args.freq_prior,
    semantic_only=semantic_only,
)
125
126
127
net.load_parameters(args.reldn_params, ctx=ctx)

# dataset and dataloader
128
129
130
131
132
133
134
135
136
# Visual Genome relation dataset, validation split.
vg_val = VGRelation(split="val")
logger.info("data loaded!")
val_data = gluon.data.DataLoader(
    vg_val,
    # One sample per device per iteration.
    batch_size=len(ctx),
    shuffle=False,
    # BUG FIX: use len(ctx) instead of `num_gpus`, which is undefined when
    # no GPUs are configured (CPU path) and raised a NameError here.
    num_workers=16 * len(ctx),
    batchify_fn=dgl_mp_batchify_fn,
)
n_batches = len(val_data)

139
140
141
142
143
144
def _load_pretrained_detector(classes, path, ctx):
    """Build the custom Faster R-CNN detector and load saved parameters.

    Parameters
    ----------
    classes : list of str
        Object class names of the dataset.
    path : str
        Path to the saved Faster R-CNN parameter file.
    ctx : list
        Devices to load the parameters onto.

    Returns
    -------
    The detector network with parameters loaded.
    """
    det = faster_rcnn_resnet101_v1d_custom(
        classes=classes,
        pretrained_base=False,
        pretrained=False,
        additional_output=True,
    )
    # ignore_extra/allow_missing: tolerate partial overlap between the
    # checkpoint and this custom head.
    det.load_parameters(path, ctx=ctx, ignore_extra=True, allow_missing=True)
    return det


params_path = args.pretrained_faster_rcnn_params
# Two identically-initialized detectors (previously duplicated inline):
# `detector` produces object predictions; `detector_feat` extracts ROI
# spatial features for relation regions in the loop below.
detector = _load_pretrained_detector(vg_val.obj_classes, params_path, ctx)
detector_feat = _load_pretrained_detector(vg_val.obj_classes, params_path, ctx)
# Overwrite detector_feat's feature extractor with the separately saved
# (fine-tuned) feature weights.
detector_feat.features.load_parameters(args.faster_rcnn_params, ctx=ctx)

162

163
164
165
166
167
168
169
170
def get_data_batch(g_list, img_list, ctx_list):
    """Split one mini-batch of graphs and images across devices.

    Parameters
    ----------
    g_list : list or None
        Per-sample graphs from the DataLoader; may be None or empty.
    img_list : list
        Per-sample image tensors, parallel to ``g_list``.
    ctx_list : list
        Target contexts; the batch is split into ``len(ctx_list)``
        contiguous slices, one slice per context.

    Returns
    -------
    tuple
        ``(G_list, img_list)`` per-device slices with tensors moved to
        their device, or ``(None, None)`` for an empty batch.

    Raises
    ------
    ValueError
        If the batch holds fewer samples than there are devices.
    """
    if g_list is None or len(g_list) == 0:
        return None, None
    n_gpu = len(ctx_list)
    size = len(g_list)
    if size < n_gpu:
        raise ValueError("too small batch")
    step = size // n_gpu
    # The last slice absorbs the remainder when size % n_gpu != 0.
    G_list = [
        g_list[i * step : (i + 1) * step]
        if i < n_gpu - 1
        else g_list[i * step : size]
        for i in range(n_gpu)
    ]
    img_list = [
        img_list[i * step : (i + 1) * step]
        if i < n_gpu - 1
        else img_list[i * step : size]
        for i in range(n_gpu)
    ]

    for G_slice, g_ctx in zip(G_list, ctx_list):
        for G in G_slice:
            G.ndata["bbox"] = G.ndata["bbox"].as_in_context(g_ctx)
            G.ndata["node_class"] = G.ndata["node_class"].as_in_context(g_ctx)
            G.ndata["node_class_vec"] = G.ndata["node_class_vec"].as_in_context(
                g_ctx
            )
            G.edata["rel_class"] = G.edata["rel_class"].as_in_context(g_ctx)
    # BUG FIX: the original comprehension reused the loop variable `ctx`
    # leaked from the graph loop above, so EVERY image slice was moved to
    # the last device while its graphs went to the correct one (broken on
    # multi-GPU). Pair each slice with its own context explicitly.
    img_list = [
        img.as_in_context(img_ctx)
        for img, img_ctx in zip(img_list, ctx_list)
    ]
    return G_list, img_list

195

196
197
198
# Main validation loop: one DataLoader batch per iteration.
for i, (G_list, img_list) in enumerate(val_data):
    # Split the batch across devices and move tensors to their contexts.
    G_list, img_list = get_data_batch(G_list, img_list, ctx)
    if G_list is None or img_list is None:
        # Unusable batch: skip it, but keep the periodic logging cadence.
        if (i + 1) % batch_verbose_freq == 0:
            print_txt = "Batch[%d/%d] " % (i, n_batches)
            for metric in metric_list:
                metric_name, metric_val = metric.get()
                print_txt += "%s=%.4f " % (metric_name, metric_val)
            logger.info(print_txt)
        continue

    detector_res_list = []
    G_batch = []
    # Pads variable-length per-graph tensors so they stack into one batch.
    bbox_pad = Pad(axis=(0))
    # loss_cls_val = 0
    for G_slice, img in zip(G_list, img_list):
        cur_ctx = img.context
        if mode == "predcls":
            # predcls: feed ground-truth boxes to the detector and build the
            # prediction graph with ground-truth object classes.
            bbox_list = [G.ndata["bbox"] for G in G_slice]
            bbox_stack = bbox_pad(bbox_list).as_in_context(cur_ctx)
            ids, scores, bbox, spatial_feat = detector(
                img, None, None, bbox_stack
            )

            node_class_list = [G.ndata["node_class"] for G in G_slice]
            node_class_stack = bbox_pad(node_class_list).as_in_context(cur_ctx)
            g_pred_batch = build_graph_validate_gt_obj(
                img,
                node_class_stack,
                bbox,
                spatial_feat,
                bbox_improvement=True,
                overlap=False,
            )
        elif mode == "phrcls":
            # use ground truth bbox
            # phrcls: ground-truth boxes, detector-predicted classes/scores.
            bbox_list = [G.ndata["bbox"] for G in G_slice]
            bbox_stack = bbox_pad(bbox_list).as_in_context(cur_ctx)
            ids, scores, bbox, spatial_feat = detector(
                img, None, None, bbox_stack
            )

            g_pred_batch = build_graph_validate_gt_bbox(
                img,
                ids,
                scores,
                bbox,
                spatial_feat,
                bbox_improvement=True,
                overlap=False,
            )
        else:
            # use predicted bbox
            ids, scores, bbox, feat, feat_ind, spatial_feat = detector(img)
            g_pred_batch = build_graph_validate_pred(
                img,
                ids,
                scores,
                bbox,
                feat_ind,
                spatial_feat,
                bbox_improvement=True,
                scores_top_k=75,
                overlap=False,
            )
        if not semantic_only:
            # Visual branch: extract ROI features for every relation (edge)
            # region with detector_feat and attach them as edge features.
            rel_bbox = g_pred_batch.edata["rel_bbox"]
            batch_id = g_pred_batch.edata["batch_id"].asnumpy()
            n_sample_edges = g_pred_batch.number_of_edges()
            # g_pred_batch.edata['edge_feat'] = mx.nd.zeros((n_sample_edges, 49), ctx=cur_ctx)
            n_graph = len(G_slice)
            bbox_rel_list = []
            # Group relation boxes by the graph (batch_id) they belong to.
            for j in range(n_graph):
                eids = np.where(batch_id == j)[0]
                if len(eids) > 0:
                    bbox_rel_list.append(rel_bbox[eids])
            bbox_rel_stack = bbox_pad(bbox_rel_list).as_in_context(cur_ctx)
            _, _, _, spatial_feat_rel = detector_feat(
                img, None, None, bbox_rel_stack
            )
            # Un-pad: keep only the first len(eids) feature rows per graph,
            # in the same edge order used when building bbox_rel_list.
            spatial_feat_rel_list = []
            for j in range(n_graph):
                eids = np.where(batch_id == j)[0]
                if len(eids) > 0:
                    spatial_feat_rel_list.append(
                        spatial_feat_rel[j, 0 : len(eids)]
                    )
            g_pred_batch.edata["edge_feat"] = nd.concat(
                *spatial_feat_rel_list, dim=0
            )

        G_batch.append(g_pred_batch)

    # Score relations with RelDN on every per-device graph batch.
    G_batch = [net(G) for G in G_batch]

    # Compare predictions against ground truth and update all metrics.
    for G_slice, G_pred, img_slice in zip(G_list, G_batch, img_list):
        for G_gt, G_pred_one in zip(G_slice, [G_pred]):
            if G_pred_one is None or G_pred_one.number_of_nodes() == 0:
                continue
            gt_objects, gt_triplet = extract_gt(G_gt, img_slice.shape[2:4])
            pred_objects, pred_triplet = extract_pred(G_pred, joint_preds=True)
            for metric in metric_list:
                # PredCls/PhrCls/SGDet consume triplets only; the remaining
                # metrics (SGDetPlus) also take the object sets.
                if (
                    isinstance(metric, PredCls)
                    or isinstance(metric, PhrCls)
                    or isinstance(metric, SGDet)
                ):
                    metric.update(gt_triplet, pred_triplet)
                else:
                    metric.update(
                        (gt_objects, gt_triplet), (pred_objects, pred_triplet)
                    )
    # Periodic progress logging.
    if (i + 1) % batch_verbose_freq == 0:
        print_txt = "Batch[%d/%d] " % (i, n_batches)
        for metric in metric_list:
            metric_name, metric_val = metric.get()
            print_txt += "%s=%.4f " % (metric_name, metric_val)
        logger.info(print_txt)

315
# Final summary line over the whole validation set, e.g.
# "Batch[N/N] metric1=0.1234 metric2=0.5678 ".
_parts = ["Batch[%d/%d] " % (n_batches, n_batches)]
_parts.extend("%s=%.4f " % metric.get() for metric in metric_list)
print_txt = "".join(_parts)
logger.info(print_txt)