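"""Train the RelDN relation-prediction head for scene graph generation.

Loads a Faster R-CNN detector from --pretrained-faster-rcnn-params, builds
candidate relation graphs with DGL for the Visual Genome relation data, and
trains the RelDN head together with the detector's feature backbone.
"""
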
import argparse
import logging
import time

import mxnet as mx
import numpy as np
from data import *
from gluoncv.data.batchify import Pad
from gluoncv.utils import makedirs
from model import faster_rcnn_resnet101_v1d_custom, RelDN
from mxnet import gluon, nd
from utils import *

import dgl


def parse_args():
    parser = argparse.ArgumentParser(description="Train RelDN Model.")
    parser.add_argument(
        "--gpus",
        type=str,
        default="0",
        help="Training with GPUs, you can specify 1,3 for example.",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=8,
        help="Total batch-size for training.",
    )
    parser.add_argument(
        "--epochs", type=int, default=9, help="Training epochs."
    )
    parser.add_argument(
        "--lr-reldn",
        type=float,
        default=0.01,
        help="Learning rate for RelDN module.",
    )
    parser.add_argument(
        "--wd-reldn",
        type=float,
        default=0.0001,
        help="Weight decay for RelDN module.",
    )
    parser.add_argument(
        "--lr-faster-rcnn",
        type=float,
        default=0.01,
        help="Learning rate for Faster R-CNN module.",
    )
    parser.add_argument(
        "--wd-faster-rcnn",
        type=float,
        default=0.0001,
        help="Weight decay for RelDN module.",
    )
    parser.add_argument(
        "--lr-decay-epochs",
        type=str,
        default="5,8",
        help="Learning rate decay points.",
    )
    parser.add_argument(
        "--lr-warmup-iters",
        type=int,
        default=4000,
        help="Learning rate warm-up iterations.",
    )
    parser.add_argument(
        "--save-dir",
        type=str,
        default="params_resnet101_v1d_reldn",
        help="Path to save model parameters.",
    )
    parser.add_argument(
        "--log-dir",
        type=str,
        default="reldn_output.log",
        help="Path to save training logs.",
    )
    parser.add_argument(
        "--pretrained-faster-rcnn-params",
        type=str,
        required=True,
        help="Path to saved Faster R-CNN model parameters.",
    )
    parser.add_argument(
        "--freq-prior",
        type=str,
        default="freq_prior.pkl",
        help="Path to saved frequency prior data.",
    )
    parser.add_argument(
        "--verbose-freq",
        type=int,
        default=100,
        help="Frequency of log printing in number of iterations.",
    )

    args = parser.parse_args()
    return args


args = parse_args()

filehandler = logging.FileHandler(args.log_dir)
streamhandler = logging.StreamHandler()
logger = logging.getLogger("")
logger.setLevel(logging.INFO)
logger.addHandler(filehandler)
logger.addHandler(streamhandler)

# Hyperparams
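# One MXNet context per requested GPU id; training falls back to a single CPU
# context when no GPU is given. When each device processes more than one sample
# per optimizer step, gradients are accumulated (grad_req="add") and cleared
# manually after every trainer step.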
ctx = [mx.gpu(int(i)) for i in args.gpus.split(",") if i.strip()]
if ctx:
    num_gpus = len(ctx)
    assert args.batch_size % num_gpus == 0
    per_device_batch_size = int(args.batch_size / num_gpus)
else:
    ctx = [mx.cpu()]
    num_gpus = 1
    per_device_batch_size = args.batch_size

aggregate_grad = per_device_batch_size > 1

nepoch = args.epochs
N_relations = 50
N_objects = 150
save_dir = args.save_dir
makedirs(save_dir)
batch_verbose_freq = args.verbose_freq
lr_decay_epochs = [int(i) for i in args.lr_decay_epochs.split(",")]

# Dataset and dataloader
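# VGRelation yields (scene graph, image) pairs; the loader draws len(ctx) samples
# per iteration (one per device) and collates them with dgl_mp_batchify_fn.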
vg_train = VGRelation(split="train")
logger.info("data loaded!")
train_data = gluon.data.DataLoader(
    vg_train,
    batch_size=len(ctx),
    shuffle=True,
    num_workers=8 * num_gpus,
    batchify_fn=dgl_mp_batchify_fn,
)
n_batches = len(train_data)

# Network definition
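# The RelDN head scores each candidate edge over N_relations classes; its visual
# and spatial branches are freshly initialized and the class-frequency prior is
# loaded from --freq-prior.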
net = RelDN(n_classes=N_relations, prior_pkl=args.freq_prior)
net.spatial.initialize(mx.init.Normal(1e-4), ctx=ctx)
net.visual.initialize(mx.init.Normal(1e-4), ctx=ctx)
for k, v in net.collect_params().items():
    v.grad_req = "add" if aggregate_grad else "write"
net_params = net.collect_params()
net_trainer = gluon.Trainer(
    net.collect_params(),
    "adam",
    {"learning_rate": args.lr_reldn, "wd": args.wd_reldn},
)

det_params_path = args.pretrained_faster_rcnn_params
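# Two copies of the pretrained Faster R-CNN are loaded from the same parameter
# file: `detector` stays fully frozen and only produces object proposals (it is
# run under autograd.pause() below), while `detector_feat` extracts ROI features
# for the relation boxes and has its feature backbone fine-tuned by det_trainer.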
detector = faster_rcnn_resnet101_v1d_custom(
    classes=vg_train.obj_classes,
    pretrained_base=False,
    pretrained=False,
    additional_output=True,
)
detector.load_parameters(
    det_params_path, ctx=ctx, ignore_extra=True, allow_missing=True
)
for k, v in detector.collect_params().items():
    v.grad_req = "null"

detector_feat = faster_rcnn_resnet101_v1d_custom(
    classes=vg_train.obj_classes,
    pretrained_base=False,
    pretrained=False,
    additional_output=True,
)
detector_feat.load_parameters(
    det_params_path, ctx=ctx, ignore_extra=True, allow_missing=True
)
for k, v in detector_feat.collect_params().items():
    v.grad_req = "null"
for k, v in detector_feat.features.collect_params().items():
    v.grad_req = "add" if aggregate_grad else "write"
det_params = detector_feat.features.collect_params()
det_trainer = gluon.Trainer(
    detector_feat.features.collect_params(),
    "adam",
    {"learning_rate": args.lr_faster_rcnn, "wd": args.wd_faster_rcnn},
)


def get_data_batch(g_list, img_list, ctx_list):
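    """Split one loader batch of graphs and images across the available devices.

    Graphs and image slices are distributed evenly over ctx_list, and the node and
    edge features of every graph are copied onto the device that will process it.
    """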
    if g_list is None or len(g_list) == 0:
        return None, None
    n_gpu = len(ctx_list)
    size = len(g_list)
    if size < n_gpu:
        raise Exception("Batch size is smaller than the number of devices.")
    step = size // n_gpu
    G_list = [
        g_list[i * step : (i + 1) * step]
        if i < n_gpu - 1
        else g_list[i * step : size]
        for i in range(n_gpu)
    ]
    img_list = [
        img_list[i * step : (i + 1) * step]
        if i < n_gpu - 1
        else img_list[i * step : size]
        for i in range(n_gpu)
    ]

    for G_slice, ctx in zip(G_list, ctx_list):
        for G in G_slice:
            G.ndata["bbox"] = G.ndata["bbox"].as_in_context(ctx)
            G.ndata["node_class"] = G.ndata["node_class"].as_in_context(ctx)
            G.ndata["node_class_vec"] = G.ndata["node_class_vec"].as_in_context(
                ctx
            )
            G.edata["rel_class"] = G.edata["rel_class"].as_in_context(ctx)
    img_list = [img.as_in_context(ctx) for img, ctx in zip(img_list, ctx_list)]
    return G_list, img_list


L_rel = gluon.loss.SoftmaxCELoss()

train_metric = mx.metric.Accuracy(name="rel_acc")
train_metric_top5 = mx.metric.TopKAccuracy(5, name="rel_acc_top5")
metric_list = [train_metric, train_metric_top5]


def batch_print(
    epoch, i, batch_verbose_freq, n_batches, btic, loss_rel_val, metric_list
):
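    """Log the running loss and metrics every `batch_verbose_freq` iterations and reset the loss window."""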
    if (i + 1) % batch_verbose_freq == 0:
        print_txt = "Epoch[%d] Batch[%d/%d], time: %d, loss_rel=%.4f " % (
            epoch,
            i,
            n_batches,
            int(time.time() - btic),
            loss_rel_val / (i + 1),
        )
        for metric in metric_list:
            metric_name, metric_val = metric.get()
            print_txt += "%s=%.4f " % (metric_name, metric_val)
        logger.info(print_txt)
        btic = time.time()
        loss_rel_val = 0
    return btic, loss_rel_val


for epoch in range(nepoch):
    loss_rel_val = 0
    tic = time.time()
    btic = time.time()
    for metric in metric_list:
        metric.reset()
    if epoch == 0:
        net_trainer_base_lr = net_trainer.learning_rate
        det_trainer_base_lr = det_trainer.learning_rate
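    # Decay both learning rates by 10x at the epochs given by --lr-decay-epochs.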
    if epoch in lr_decay_epochs:
        net_trainer.set_learning_rate(net_trainer.learning_rate * 0.1)
        det_trainer.set_learning_rate(det_trainer.learning_rate * 0.1)
    for i, (G_list, img_list) in enumerate(train_data):
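        # Linear warm-up during epoch 0: ramp both learning rates from one third of
        # their base value to the full value over the first --lr-warmup-iters iterations.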
        if epoch == 0 and i < args.lr_warmup_iters:
            alpha = i / args.lr_warmup_iters
            warmup_factor = 1 / 3 * (1 - alpha) + alpha
            net_trainer.set_learning_rate(net_trainer_base_lr * warmup_factor)
            det_trainer.set_learning_rate(det_trainer_base_lr * warmup_factor)
        G_list, img_list = get_data_batch(G_list, img_list, ctx)
        if G_list is None or img_list is None:
            btic, loss_rel_val = batch_print(
                epoch,
                i,
                batch_verbose_freq,
                n_batches,
                btic,
                loss_rel_val,
                metric_list,
            )
            continue

        loss = []
        detector_res_list = []
        G_batch = []
        bbox_pad = Pad(axis=0)
        with mx.autograd.record():
            for G_slice, img in zip(G_list, img_list):
                cur_ctx = img.context
                bbox_list = [G.ndata["bbox"] for G in G_slice]
                bbox_stack = bbox_pad(bbox_list).as_in_context(cur_ctx)
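                # Run the frozen detector to obtain object proposals (ids, scores,
                # boxes) and their ROI features, then pass the ground-truth graphs and
                # the top-scoring proposals to build_graph_train to assemble the
                # training graph for this image.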
                with mx.autograd.pause():
                    ids, scores, bbox, feat, feat_ind, spatial_feat = detector(
                        img
                    )
                g_pred_batch = build_graph_train(
                    G_slice,
                    bbox_stack,
                    img,
                    ids,
                    scores,
                    bbox,
                    feat_ind,
                    spatial_feat,
                    scores_top_k=300,
                    overlap=False,
                )
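                # Sub-sample edges for training with l0_sample; skip this image if no
                # edges are returned.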
                g_batch = l0_sample(g_pred_batch)
                if g_batch is None:
                    continue
                rel_bbox = g_batch.edata["rel_bbox"]
                batch_id = g_batch.edata["batch_id"].asnumpy()
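                # Group the relation boxes by image, pad them into a single batch,
                # and rescale their normalized coordinates to pixel coordinates.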
                n_sample_edges = g_batch.number_of_edges()
                n_graph = len(G_slice)
                bbox_rel_list = []
                for j in range(n_graph):
                    eids = np.where(batch_id == j)[0]
                    if len(eids) > 0:
                        bbox_rel_list.append(rel_bbox[eids])
                bbox_rel_stack = bbox_pad(bbox_rel_list).as_in_context(cur_ctx)
                img_size = img.shape[2:4]
                bbox_rel_stack[:, :, 0] *= img_size[1]
                bbox_rel_stack[:, :, 1] *= img_size[0]
                bbox_rel_stack[:, :, 2] *= img_size[1]
                bbox_rel_stack[:, :, 3] *= img_size[0]
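                # Extract ROI features for the relation boxes with the trainable
                # backbone copy and attach them to the corresponding edges.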
                _, _, _, spatial_feat_rel = detector_feat(
                    img, None, None, bbox_rel_stack
                )
                spatial_feat_rel_list = []
                for j in range(n_graph):
                    eids = np.where(batch_id == j)[0]
                    if len(eids) > 0:
                        spatial_feat_rel_list.append(
                            spatial_feat_rel[j, 0 : len(eids)]
                        )
                g_batch.edata["edge_feat"] = nd.concat(
                    *spatial_feat_rel_list, dim=0
                )

                G_batch.append(g_batch)
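            # RelDN forward pass on every per-image graph: the per-edge relation
            # logits are written to edata["preds"].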

            G_batch = [net(G) for G in G_batch]

            for G_pred, img in zip(G_batch, img_list):
                if G_pred is None or G_pred.number_of_nodes() == 0:
                    continue
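                # Softmax cross-entropy over relation classes, weighted by the
                # per-edge sample weights (edata["sample_weights"]).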
                loss_rel = L_rel(
                    G_pred.edata["preds"],
                    G_pred.edata["rel_class"],
                    G_pred.edata["sample_weights"],
                )
                loss.append(loss_rel.sum())
                loss_rel_val += loss_rel.mean().asscalar() / num_gpus

        if len(loss) == 0:
            btic, loss_rel_val = batch_print(
                epoch,
                i,
                batch_verbose_freq,
                n_batches,
                btic,
                loss_rel_val,
                metric_list,
            )
            continue
        for l in loss:
            l.backward()
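        # Step the optimizers once every per_device_batch_size iterations so the
        # effective batch size matches --batch-size; with gradient aggregation the
        # accumulated gradients are then cleared manually.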
        if (i + 1) % per_device_batch_size == 0 or i == n_batches - 1:
            net_trainer.step(args.batch_size)
            det_trainer.step(args.batch_size)
            if aggregate_grad:
                for k, v in net_params.items():
                    v.zero_grad()
                for k, v in det_params.items():
                    v.zero_grad()
        for G_pred, img_slice in zip(G_batch, img_list):
            if G_pred is None or G_pred.number_of_nodes() == 0:
                continue
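            # Update the accuracy metrics only on edges that carry a ground-truth
            # relation (rel_class > 0).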
            link_ind = np.where(G_pred.edata["rel_class"].asnumpy() > 0)[0]
            if len(link_ind) == 0:
                continue
            train_metric.update(
                [G_pred.edata["rel_class"][link_ind]],
                [G_pred.edata["preds"][link_ind]],
            )
            train_metric_top5.update(
                [G_pred.edata["rel_class"][link_ind]],
                [G_pred.edata["preds"][link_ind]],
            )
        btic, loss_rel_val = batch_print(
            epoch,
            i,
            batch_verbose_freq,
            n_batches,
            btic,
            loss_rel_val,
            metric_list,
        )
        if (i + 1) % batch_verbose_freq == 0:
            net.save_parameters("%s/model-%d.params" % (save_dir, epoch))
            detector_feat.features.save_parameters(
                "%s/detector_feat.features-%d.params" % (save_dir, epoch)
            )
    print_txt = "Epoch[%d], time: %d, loss_rel=%.4f," % (
        epoch,
        int(time.time() - tic),
        loss_rel_val / (i + 1),
    )
    for metric in metric_list:
        metric_name, metric_val = metric.get()
        print_txt += "%s=%.4f " % (metric_name, metric_val)
    logger.info(print_txt)
    net.save_parameters("%s/model-%d.params" % (save_dir, epoch))
    detector_feat.features.save_parameters(
        "%s/detector_feat.features-%d.params" % (save_dir, epoch)
    )