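"""Training entry point for knowledge graph embedding models.

This script builds the training, validation and test samplers, launches
single- or multi-process training (PyTorch or MXNet backend, selected via the
DGLBACKEND environment variable), and optionally evaluates and saves the
learned embeddings.

Illustrative invocation (values mirror the argument defaults; adjust to your
own data and hardware):

    DGLBACKEND=pytorch python train.py --model_name TransE --dataset FB15k \
        --batch_size 1024 --neg_sample_size 128 --hidden_dim 256 \
        --max_step 80000 --valid --test
"""
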
from dataloader import EvalDataset, TrainDataset, NewBidirectionalOneShotIterator
from dataloader import get_dataset

import argparse
import os
import logging
import time
import json

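# Select the multiprocessing module and the train/test implementations
# that match the configured DGL backend (MXNet or PyTorch).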
backend = os.environ.get('DGLBACKEND', 'pytorch')
if backend.lower() == 'mxnet':
    import multiprocessing as mp
    from train_mxnet import load_model
    from train_mxnet import train
    from train_mxnet import test
else:
    import torch.multiprocessing as mp
    from train_pytorch import load_model
    from train_pytorch import train, train_mp
    from train_pytorch import test, test_mp

class ArgParser(argparse.ArgumentParser):
    def __init__(self):
        super(ArgParser, self).__init__()

        self.add_argument('--model_name', default='TransE',
                          choices=['TransE', 'TransE_l1', 'TransE_l2', 'TransR',
                                   'RESCAL', 'DistMult', 'ComplEx', 'RotatE'],
                          help='model to use')
        self.add_argument('--data_path', type=str, default='data',
                          help='root path of all dataset')
        self.add_argument('--dataset', type=str, default='FB15k',
                          help='dataset name, under data_path')
        self.add_argument('--format', type=str, default='built_in',
                          help='the format of the dataset, it can be built_in, '\
                                'raw_udd_{htr} and udd_{htr}')
        self.add_argument('--data_files', type=str, default=None, nargs='+',
                          help='a list of data files, e.g. entity relation train valid test')
        self.add_argument('--save_path', type=str, default='ckpts',
                          help='place to save models and logs')
        self.add_argument('--save_emb', type=str, default=None,
                          help='save the embeddings in the specific location.')

        self.add_argument('--max_step', type=int, default=80000,
                          help='number of training steps')
        self.add_argument('--warm_up_step', type=int, default=None,
                          help='for learning rate decay')
        self.add_argument('--batch_size', type=int, default=1024,
                          help='batch size')
        self.add_argument('--batch_size_eval', type=int, default=8,
                          help='batch size used for eval and test')
        self.add_argument('--neg_sample_size', type=int, default=128,
                          help='negative sampling size')
        self.add_argument('--neg_chunk_size', type=int, default=-1,
                          help='chunk size of the negative edges used in training.')
        self.add_argument('--neg_deg_sample', action='store_true',
                          help='sample negative edges proportionally to vertex degree during training')
        self.add_argument('--neg_deg_sample_eval', action='store_true',
                          help='sample negative edges proportionally to vertex degree during evaluation')
        self.add_argument('--neg_sample_size_valid', type=int, default=1000,
                          help='negative sampling size for validation')
        self.add_argument('--neg_chunk_size_valid', type=int, default=-1,
                          help='chunk size of the negative edges used for validation.')
        self.add_argument('--neg_sample_size_test', type=int, default=-1,
                          help='negative sampling size for testing')
        self.add_argument('--neg_chunk_size_test', type=int, default=-1,
                          help='chunk size of the negative edges used for testing.')
        self.add_argument('--hidden_dim', type=int, default=256,
                          help='hidden dim used by relation and entity')
        self.add_argument('--lr', type=float, default=0.0001,
                          help='learning rate')
        self.add_argument('-g', '--gamma', type=float, default=12.0,
                          help='margin value')
        self.add_argument('--eval_percent', type=float, default=1,
                          help='sample some percentage for evaluation.')
        self.add_argument('--no_eval_filter', action='store_true',
                          help='do not filter positive edges among negative edges for evaluation')

        self.add_argument('--gpu', type=int, default=[-1], nargs='+', 
                          help='a list of active gpu ids, e.g. 0 1 2 4')
        self.add_argument('--mix_cpu_gpu', action='store_true',
                          help='mix CPU and GPU training')
        self.add_argument('-de', '--double_ent', action='store_true',
                          help='double entity dim for complex number')
        self.add_argument('-dr', '--double_rel', action='store_true',
                          help='double relation dim for complex number')
        self.add_argument('--seed', type=int, default=0,
                          help='set random seed for reproducibility')
        self.add_argument('-log', '--log_interval', type=int, default=1000,
                          help='print training log after every x steps')
        self.add_argument('--eval_interval', type=int, default=10000,
                          help='do evaluation after every x steps')
        self.add_argument('-adv', '--neg_adversarial_sampling', action='store_true',
                          help='if use negative adversarial sampling')
        self.add_argument('-a', '--adversarial_temperature', default=1.0, type=float,
                          help='temperature used for negative adversarial sampling')

        self.add_argument('--valid', action='store_true',
                          help='evaluate the model on the validation set during training')
        self.add_argument('--test', action='store_true',
                          help='evaluate the model on the test set after training')
        self.add_argument('-rc', '--regularization_coef', type=float, default=0.000002,
                          help='set value > 0.0 if regularization is used')
        self.add_argument('-rn', '--regularization_norm', type=int, default=3,
                          help='norm used in regularization')
        self.add_argument('--num_worker', type=int, default=32,
                          help='number of workers used for loading data')
        self.add_argument('--non_uni_weight', action='store_true',
                          help='use non-uniform edge weights when computing the loss')
        self.add_argument('--init_step', type=int, default=0,
                          help='DONT SET MANUALLY, used for resume')
        self.add_argument('--step', type=int, default=0,
                          help='DONT SET MANUALLY, track current step')
        self.add_argument('--pickle_graph', action='store_true',
                          help='pickle built graph, building a huge graph is slow.')
        self.add_argument('--num_proc', type=int, default=1,
                          help='number of processes used for training')
        self.add_argument('--num_test_proc', type=int, default=1,
                          help='number of processes used for testing')
        self.add_argument('--num_thread', type=int, default=1,
                          help='number of threads used')
        self.add_argument('--rel_part', action='store_true',
                          help='enable relation partitioning')
        self.add_argument('--soft_rel_part', action='store_true',
                          help='enable soft relation partition')
        self.add_argument('--nomp_thread_per_process', type=int, default=-1,
                          help='num of omp threads used per process in multi-process training')
        self.add_argument('--async_update', action='store_true',
                          help='allow async_update on node embedding')
        self.add_argument('--force_sync_interval', type=int, default=-1,
                          help='force a synchronization between processes every x steps')


def get_logger(args):
    if not os.path.exists(args.save_path):
        os.mkdir(args.save_path)

    folder = '{}_{}_'.format(args.model_name, args.dataset)
    n = len([x for x in os.listdir(args.save_path) if x.startswith(folder)])
    folder += str(n)
    args.save_path = os.path.join(args.save_path, folder)

    if not os.path.exists(args.save_path):
        os.makedirs(args.save_path)
    log_file = os.path.join(args.save_path, 'train.log')

    logging.basicConfig(
        format='%(asctime)s %(levelname)-8s %(message)s',
        level=logging.INFO,
        datefmt='%Y-%m-%d %H:%M:%S',
        filename=log_file,
        filemode='w'
    )

    logger = logging.getLogger(__name__)
    print("Logs are being recorded at: {}".format(log_file))
    return logger


def run(args, logger):
    train_time_start = time.time()
    # load dataset and samplers
    dataset = get_dataset(args.data_path, args.dataset, args.format, args.data_files)
    n_entities = dataset.n_entities
    n_relations = dataset.n_relations
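    # A negative test sample size means ranking against all entities during testing.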
    if args.neg_sample_size_test < 0:
        args.neg_sample_size_test = n_entities
    args.eval_filter = not args.no_eval_filter
    if args.neg_deg_sample_eval:
        assert not args.eval_filter, "if negative sampling based on degree, we can't filter positive edges."

    # When we generate a batch of negative edges from a set of positive edges,
    # we first divide the positive edges into chunks and corrupt the edges in a chunk
    # together. By default, the chunk size is equal to the negative sample size.
    # Usually, this works well. But we also allow users to specify the chunk size themselves.
    if args.neg_chunk_size < 0:
        args.neg_chunk_size = args.neg_sample_size
    if args.neg_chunk_size_valid < 0:
        args.neg_chunk_size_valid = args.neg_sample_size_valid
    if args.neg_chunk_size_test < 0:
        args.neg_chunk_size_test = args.neg_sample_size_test

    num_workers = args.num_worker
    train_data = TrainDataset(dataset, args, ranks=args.num_proc)
    # if there are no cross-partition relations, we fall back to strict_rel_part
    args.strict_rel_part = args.mix_cpu_gpu and not train_data.cross_part
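    # Soft relation partitioning only applies in mix_cpu_gpu mode when some relations cross partitions.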
    args.soft_rel_part = args.mix_cpu_gpu and args.soft_rel_part and train_data.cross_part

    # Automatically set number of OMP threads for each process if it is not provided
    # The value for GPU is evaluated in AWS p3.16xlarge
    # The value for CPU is evaluated in AWS x1.32xlarge
    if args.nomp_thread_per_process == -1:
        if args.gpu[0] >= 0:
            # GPU training
            args.num_thread = 4
        else:
            # CPU training
            args.num_thread = 1
    else:
        args.num_thread = args.nomp_thread_per_process

    if args.num_proc > 1:
        train_samplers = []
        for i in range(args.num_proc):
            train_sampler_head = train_data.create_sampler(args.batch_size,
                                                           args.neg_sample_size,
                                                           args.neg_chunk_size,
                                                           mode='head',
                                                           num_workers=num_workers,
                                                           shuffle=True,
                                                           exclude_positive=False,
                                                           rank=i)
            train_sampler_tail = train_data.create_sampler(args.batch_size,
                                                           args.neg_sample_size,
                                                           args.neg_chunk_size,
                                                           mode='tail',
                                                           num_workers=num_workers,
                                                           shuffle=True,
                                                           exclude_positive=False,
                                                           rank=i)
            train_samplers.append(NewBidirectionalOneShotIterator(train_sampler_head, train_sampler_tail,
                                                                  args.neg_chunk_size, args.neg_sample_size,
                                                                  True, n_entities))
    else:
        train_sampler_head = train_data.create_sampler(args.batch_size,
                                                       args.neg_sample_size,
                                                       args.neg_chunk_size,
                                                       mode='head',
                                                       num_workers=num_workers,
                                                       shuffle=True,
                                                       exclude_positive=False)
        train_sampler_tail = train_data.create_sampler(args.batch_size,
                                                       args.neg_sample_size,
                                                       args.neg_chunk_size,
                                                       mode='tail',
                                                       num_workers=num_workers,
                                                       shuffle=True,
                                                       exclude_positive=False)
        train_sampler = NewBidirectionalOneShotIterator(train_sampler_head, train_sampler_tail,
                                                        args.neg_chunk_size, args.neg_sample_size,
                                                        True, n_entities)

    # for multiprocessing evaluation, we don't need to sample multiple batches at a time
    # in each process.
    if args.num_proc > 1:
        num_workers = 1
    if args.valid or args.test:
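        # Use at most one evaluation process per GPU; otherwise reuse the training process count.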
        if len(args.gpu) > 1:
            args.num_test_proc = args.num_proc if args.num_proc < len(args.gpu) else len(args.gpu)
        else:
            args.num_test_proc = args.num_proc
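        # The full evaluation dataset is shared by the validation and test samplers.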
        eval_dataset = EvalDataset(dataset, args)
    if args.valid:
        # Here we want to use the regular negative sampler because we need to ensure that
        # all positive edges are excluded.
        if args.num_proc > 1:
            valid_sampler_heads = []
            valid_sampler_tails = []
            for i in range(args.num_proc):
                valid_sampler_head = eval_dataset.create_sampler('valid', args.batch_size_eval,
                                                                 args.neg_sample_size_valid,
                                                                 args.neg_chunk_size_valid,
                                                                 args.eval_filter,
                                                                 mode='chunk-head',
                                                                 num_workers=num_workers,
                                                                 rank=i, ranks=args.num_proc)
                valid_sampler_tail = eval_dataset.create_sampler('valid', args.batch_size_eval,
                                                                 args.neg_sample_size_valid,
                                                                 args.neg_chunk_size_valid,
                                                                 args.eval_filter,
                                                                 mode='chunk-tail',
                                                                 num_workers=num_workers,
                                                                 rank=i, ranks=args.num_proc)
                valid_sampler_heads.append(valid_sampler_head)
                valid_sampler_tails.append(valid_sampler_tail)
        else:
            valid_sampler_head = eval_dataset.create_sampler('valid', args.batch_size_eval,
                                                             args.neg_sample_size_valid,
                                                             args.neg_chunk_size_valid,
                                                             args.eval_filter,
                                                             mode='chunk-head',
                                                             num_workers=num_workers,
                                                             rank=0, ranks=1)
            valid_sampler_tail = eval_dataset.create_sampler('valid', args.batch_size_eval,
                                                             args.neg_sample_size_valid,
                                                             args.neg_chunk_size_valid,
                                                             args.eval_filter,
                                                             mode='chunk-tail',
                                                             num_workers=num_workers,
                                                             rank=0, ranks=1)
    if args.test:
        # Here we want to use the regular negative sampler because we need to ensure that
        # all positive edges are excluded.
        # We use at most num_gpu processes in the test stage to save GPU memory.
        if args.num_test_proc > 1:
            test_sampler_tails = []
            test_sampler_heads = []
            for i in range(args.num_test_proc):
                test_sampler_head = eval_dataset.create_sampler('test', args.batch_size_eval,
                                                                args.neg_sample_size_test,
                                                                args.neg_chunk_size_test,
                                                                args.eval_filter,
                                                                mode='chunk-head',
                                                                num_workers=num_workers,
                                                                rank=i, ranks=args.num_test_proc)
                test_sampler_tail = eval_dataset.create_sampler('test', args.batch_size_eval,
                                                                args.neg_sample_size_test,
                                                                args.neg_chunk_size_test,
                                                                args.eval_filter,
                                                                mode='chunk-tail',
                                                                num_workers=num_workers,
                                                                rank=i, ranks=args.num_test_proc)
                test_sampler_heads.append(test_sampler_head)
                test_sampler_tails.append(test_sampler_tail)
        else:
            test_sampler_head = eval_dataset.create_sampler('test', args.batch_size_eval,
                                                            args.neg_sample_size_test,
                                                            args.neg_chunk_size_test,
                                                            args.eval_filter,
                                                            mode='chunk-head',
                                                            num_workers=num_workers,
                                                            rank=0, ranks=1)
            test_sampler_tail = eval_dataset.create_sampler('test', args.batch_size_eval,
                                                            args.neg_sample_size_test,
                                                            args.neg_chunk_size_test,
                                                            args.eval_filter,
                                                            mode='chunk-tail',
                                                            num_workers=num_workers,
                                                            rank=0, ranks=1)

    # We need to free all memory referenced by dataset.
    eval_dataset = None
    dataset = None
    # load model
    model = load_model(logger, args, n_entities, n_relations)

    if args.num_proc > 1 or args.async_update:
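        # Put the model parameters in shared memory so worker processes
        # and async embedding updates can access the same tensors.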
        model.share_memory()

    print('Total data loading time {:.3f} seconds'.format(time.time() - train_time_start))

    # train
    start = time.time()
    rel_parts = train_data.rel_parts if args.strict_rel_part or args.soft_rel_part else None
    cross_rels = train_data.cross_rels if args.soft_rel_part else None
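    # Multi-process training: each process gets its own sampler and the processes
    # synchronize through a barrier; otherwise train in the current process.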
    if args.num_proc > 1:
        procs = []
        barrier = mp.Barrier(args.num_proc)
        for i in range(args.num_proc):
            valid_sampler = [valid_sampler_heads[i], valid_sampler_tails[i]] if args.valid else None
            proc = mp.Process(target=train_mp, args=(args,
                                                     model,
                                                     train_samplers[i],
                                                     valid_sampler,
                                                     i,
                                                     rel_parts,
                                                     cross_rels,
                                                     barrier))
            procs.append(proc)
            proc.start()
        for proc in procs:
            proc.join()
    else:
        valid_samplers = [valid_sampler_head, valid_sampler_tail] if args.valid else None
        train(args, model, train_sampler, valid_samplers, rel_parts=rel_parts)
    print('training takes {} seconds'.format(time.time() - start))

    if args.save_emb is not None:
        if not os.path.exists(args.save_emb):
            os.mkdir(args.save_emb)
        model.save_emb(args.save_emb, args.dataset)

        # We need to save the model configurations as well.
        conf_file = os.path.join(args.save_emb, 'config.json')
        with open(conf_file, 'w') as outfile:
            json.dump({'dataset': args.dataset,
                       'model': args.model_name,
                       'emb_size': args.hidden_dim,
                       'max_train_step': args.max_step,
                       'batch_size': args.batch_size,
                       'neg_sample_size': args.neg_sample_size,
                       'lr': args.lr,
                       'gamma': args.gamma,
                       'double_ent': args.double_ent,
                       'double_rel': args.double_rel,
                       'neg_adversarial_sampling': args.neg_adversarial_sampling,
                       'adversarial_temperature': args.adversarial_temperature,
                       'regularization_coef': args.regularization_coef,
                       'regularization_norm': args.regularization_norm},
                       outfile, indent=4)
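        # The saved config.json records the hyper-parameters of this run and can be
        # reloaded later, e.g. with json.load(), alongside the saved embeddings.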

    # test
    if args.test:
        start = time.time()
        if args.num_test_proc > 1:
            queue = mp.Queue(args.num_test_proc)
            procs = []
            for i in range(args.num_test_proc):
                proc = mp.Process(target=test_mp, args=(args,
                                                        model,
                                                        [test_sampler_heads[i], test_sampler_tails[i]],
                                                        i,
                                                        'Test',
                                                        queue))
                procs.append(proc)
                proc.start()

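            # Each test process pushes its per-example results into the queue;
            # collect them here and average every metric across all examples.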
            metrics = {}
            logs = []
            for i in range(args.num_test_proc):
                log = queue.get()
                logs = logs + log
            
            for metric in logs[0].keys():
                metrics[metric] = sum([log[metric] for log in logs]) / len(logs)
            for k, v in metrics.items():
                print('Test average {} at [{}/{}]: {}'.format(k, args.step, args.max_step, v))

            for proc in procs:
                proc.join()
        else:
            test(args, model, [test_sampler_head, test_sampler_tail])
        print('testing takes {:.3f} seconds'.format(time.time() - start))

if __name__ == '__main__':
    args = ArgParser().parse_args()
    logger = get_logger(args)
    run(args, logger)