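"""Training entry point for knowledge graph embedding models.

This script parses the command-line arguments defined below, loads a triple
dataset (e.g. FB15k), builds positive/negative edge samplers, trains one of the
supported KGE models (TransE, TransR, RESCAL, DistMult, ComplEx, RotatE, ...)
on CPU and/or GPU with one or more processes, and optionally evaluates on the
validation and test sets.

Example invocation (hypothetical values; adjust paths and hyperparameters):

    DGLBACKEND=pytorch python train.py --model_name TransE_l2 --dataset FB15k \
        --batch_size 1024 --neg_sample_size 256 --hidden_dim 400 --gamma 12.0 \
        --lr 0.01 --max_step 10000 --gpu 0 --valid --test
"""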
from dataloader import EvalDataset, TrainDataset, NewBidirectionalOneShotIterator
from dataloader import get_dataset

import argparse
import os
import logging
import time

backend = os.environ.get('DGLBACKEND', 'pytorch')
if backend.lower() == 'mxnet':
    import multiprocessing as mp
    from train_mxnet import load_model
    from train_mxnet import train
    from train_mxnet import test
else:
    import torch.multiprocessing as mp
    from train_pytorch import load_model
    from train_pytorch import train, train_mp
    from train_pytorch import test, test_mp

class ArgParser(argparse.ArgumentParser):
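    """Command-line arguments for data loading, model selection, training,
    evaluation and checkpointing."""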
    def __init__(self):
        super(ArgParser, self).__init__()

        self.add_argument('--model_name', default='TransE',
                          choices=['TransE', 'TransE_l1', 'TransE_l2', 'TransR',
                                   'RESCAL', 'DistMult', 'ComplEx', 'RotatE'],
                          help='model to use')
        self.add_argument('--data_path', type=str, default='data',
                          help='root path of all dataset')
        self.add_argument('--dataset', type=str, default='FB15k',
                          help='dataset name, under data_path')
        self.add_argument('--format', type=str, default='1',
                          help='the format of the dataset.')
        self.add_argument('--save_path', type=str, default='ckpts',
                          help='place to save models and logs')
        self.add_argument('--save_emb', type=str, default=None,
                          help='the location to save the trained embeddings.')

        self.add_argument('--max_step', type=int, default=80000,
                          help='the maximum number of training steps')
        self.add_argument('--warm_up_step', type=int, default=None,
                          help='for learning rate decay')
        self.add_argument('--batch_size', type=int, default=1024,
                          help='batch size')
        self.add_argument('--batch_size_eval', type=int, default=8,
                          help='batch size used for eval and test')
        self.add_argument('--neg_sample_size', type=int, default=128,
                          help='negative sampling size')
        self.add_argument('--neg_chunk_size', type=int, default=-1,
                          help='chunk size of the negative edges used in training.')
        self.add_argument('--neg_deg_sample', action='store_true',
                          help='sample negative edges proportionally to vertex degree during training')
        self.add_argument('--neg_deg_sample_eval', action='store_true',
                          help='sample negative edges proportionally to vertex degree during evaluation')
        self.add_argument('--neg_sample_size_valid', type=int, default=1000,
                          help='negative sampling size for validation')
        self.add_argument('--neg_chunk_size_valid', type=int, default=-1,
                          help='chunk size of the negative edges used in validation.')
        self.add_argument('--neg_sample_size_test', type=int, default=-1,
                          help='negative sampling size for testing')
        self.add_argument('--neg_chunk_size_test', type=int, default=-1,
                          help='chunk size of the negative edges used in testing.')
        self.add_argument('--hidden_dim', type=int, default=256,
                          help='hidden dim used by relation and entity')
        self.add_argument('--lr', type=float, default=0.0001,
                          help='learning rate')
        self.add_argument('-g', '--gamma', type=float, default=12.0,
                          help='margin value')
        self.add_argument('--eval_percent', type=float, default=1,
                          help='sample some percentage for evaluation.')
        self.add_argument('--no_eval_filter', action='store_true',
                          help='do not filter positive edges among negative edges for evaluation')

        self.add_argument('--gpu', type=int, default=[-1], nargs='+',
                          help='a list of active gpu ids, e.g. 0 1 2 4')
        self.add_argument('--mix_cpu_gpu', action='store_true',
                          help='mix CPU and GPU training')
        self.add_argument('-de', '--double_ent', action='store_true',
                          help='double entity dim for complex number')
        self.add_argument('-dr', '--double_rel', action='store_true',
                          help='double relation dim for complex number')
        self.add_argument('--seed', type=int, default=0,
                          help='set random seed for reproducibility')
        self.add_argument('-log', '--log_interval', type=int, default=1000,
                          help='print the training log every x steps')
        self.add_argument('--eval_interval', type=int, default=10000,
                          help='do evaluation after every x steps')
        self.add_argument('-adv', '--neg_adversarial_sampling', action='store_true',
                          help='use negative adversarial sampling')
        self.add_argument('-a', '--adversarial_temperature', default=1.0, type=float)

        self.add_argument('--valid', action='store_true',
                          help='evaluate the model on the validation set during training')
        self.add_argument('--test', action='store_true',
                          help='evaluate the model on the test set after training')
        self.add_argument('-rc', '--regularization_coef', type=float, default=0.000002,
                          help='set value > 0.0 if regularization is used')
        self.add_argument('-rn', '--regularization_norm', type=int, default=3,
                          help='norm used in regularization')
        self.add_argument('--num_worker', type=int, default=32,
                          help='number of workers used for loading data')
        self.add_argument('--non_uni_weight', action='store_true',
                          help='use non-uniform weight when computing the loss')
        self.add_argument('--init_step', type=int, default=0,
                          help='DONT SET MANUALLY, used for resume')
        self.add_argument('--step', type=int, default=0,
                          help='DONT SET MANUALLY, track current step')
        self.add_argument('--pickle_graph', action='store_true',
                          help='pickle the built graph; building a huge graph is slow.')
        self.add_argument('--num_proc', type=int, default=1,
                          help='number of processes used')
        self.add_argument('--num_test_proc', type=int, default=1,
                          help='number of processes used for testing')
        self.add_argument('--num_thread', type=int, default=1,
                          help='number of threads used per process')
        self.add_argument('--rel_part', action='store_true',
                          help='enable relation partitioning')
        self.add_argument('--soft_rel_part', action='store_true',
                          help='enable soft relation partitioning')
        self.add_argument('--nomp_thread_per_process', type=int, default=-1,
                          help='num of omp threads used per process in multi-process training')
        self.add_argument('--async_update', action='store_true',
                          help='allow async_update on node embedding')
        self.add_argument('--force_sync_interval', type=int, default=-1,
                          help='force a synchronization between processes every x steps')


def get_logger(args):
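    """Create a numbered run directory '<model_name>_<dataset>_<n>' under
    args.save_path and return a logger that writes to train.log inside it."""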
    if not os.path.exists(args.save_path):
        os.mkdir(args.save_path)

    folder = '{}_{}_'.format(args.model_name, args.dataset)
    n = len([x for x in os.listdir(args.save_path) if x.startswith(folder)])
    folder += str(n)
    args.save_path = os.path.join(args.save_path, folder)

    if not os.path.exists(args.save_path):
        os.makedirs(args.save_path)
    log_file = os.path.join(args.save_path, 'train.log')

    logging.basicConfig(
        format='%(asctime)s %(levelname)-8s %(message)s',
        level=logging.INFO,
        datefmt='%Y-%m-%d %H:%M:%S',
        filename=log_file,
        filemode='w'
    )

    logger = logging.getLogger(__name__)
    print("Logs are being recorded at: {}".format(log_file))
    return logger


def run(args, logger):
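    """Load the dataset, build the training/validation/test samplers, train the
    model (optionally with multiple processes) and evaluate it if requested."""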
    train_time_start = time.time()
    # load dataset and samplers
    dataset = get_dataset(args.data_path, args.dataset, args.format)
    n_entities = dataset.n_entities
    n_relations = dataset.n_relations
    if args.neg_sample_size_test < 0:
        args.neg_sample_size_test = n_entities
    args.eval_filter = not args.no_eval_filter
    if args.neg_deg_sample_eval:
        assert not args.eval_filter, "if negative sampling based on degree, we can't filter positive edges."

    # When we generate a batch of negative edges from a set of positive edges,
    # we first divide the positive edges into chunks and corrupt the edges in a chunk
    # together. By default, the chunk size is equal to the negative sample size.
    # Usually, this works well. But we also allow users to specify the chunk size themselves.
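    # For example, with the defaults batch_size=1024 and neg_sample_size=128, each
    # batch of positive edges is split into 1024 / 128 = 8 chunks, and every chunk
    # of 128 positive edges is corrupted with the same 128 sampled entities.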
    if args.neg_chunk_size < 0:
        args.neg_chunk_size = args.neg_sample_size
    if args.neg_chunk_size_valid < 0:
        args.neg_chunk_size_valid = args.neg_sample_size_valid
    if args.neg_chunk_size_test < 0:
        args.neg_chunk_size_test = args.neg_sample_size_test

    num_workers = args.num_worker
    train_data = TrainDataset(dataset, args, ranks=args.num_proc)
    # if there is no cross-partition relation, we fall back to strict_rel_part
    args.strict_rel_part = args.mix_cpu_gpu and (train_data.cross_part == False)
    args.soft_rel_part = args.mix_cpu_gpu and args.soft_rel_part and train_data.cross_part
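    # Both flags only take effect in mix_cpu_gpu mode: strict relation partitioning
    # is used when no relation crosses partitions, soft partitioning only when some do.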

    # Automatically set number of OMP threads for each process if it is not provided
    # The value for GPU is evaluated in AWS p3.16xlarge
    # The value for CPU is evaluated in AWS x1.32xlarge
    if args.nomp_thread_per_process == -1:
        if len(args.gpu) > 0:
            # GPU training
            args.num_thread = 4
        else:
            # CPU training
            args.num_thread = 1
    else:
        args.num_thread = args.nomp_thread_per_process

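    # Build the training samplers: each process gets a head-corrupting and a
    # tail-corrupting sampler, combined by NewBidirectionalOneShotIterator into a
    # single stream of training batches.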
    if args.num_proc > 1:
        train_samplers = []
        for i in range(args.num_proc):
            train_sampler_head = train_data.create_sampler(args.batch_size,
                                                           args.neg_sample_size,
                                                           args.neg_chunk_size,
                                                           mode='head',
                                                           num_workers=num_workers,
                                                           shuffle=True,
                                                           exclude_positive=False,
                                                           rank=i)
            train_sampler_tail = train_data.create_sampler(args.batch_size,
                                                           args.neg_sample_size,
                                                           args.neg_chunk_size,
                                                           mode='tail',
                                                           num_workers=num_workers,
                                                           shuffle=True,
                                                           exclude_positive=False,
                                                           rank=i)
            train_samplers.append(NewBidirectionalOneShotIterator(train_sampler_head, train_sampler_tail,
                                                                  args.neg_chunk_size, args.neg_sample_size,
                                                                  True, n_entities))
    else:
        train_sampler_head = train_data.create_sampler(args.batch_size,
                                                       args.neg_sample_size,
                                                       args.neg_chunk_size,
                                                       mode='head',
                                                       num_workers=num_workers,
                                                       shuffle=True,
                                                       exclude_positive=False)
        train_sampler_tail = train_data.create_sampler(args.batch_size,
                                                       args.neg_sample_size,
                                                       args.neg_chunk_size,
                                                       mode='tail',
                                                       num_workers=num_workers,
                                                       shuffle=True,
                                                       exclude_positive=False)
        train_sampler = NewBidirectionalOneShotIterator(train_sampler_head, train_sampler_tail,
                                                        args.neg_chunk_size, args.neg_sample_size,
                                                        True, n_entities)

    # for multiprocessing evaluation, we don't need to sample multiple batches at a time
    # in each process.
    if args.num_proc > 1:
        num_workers = 1
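    # Use at most one evaluation process per GPU so that evaluation does not
    # oversubscribe GPU memory (see the corresponding note in the test section).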
    if args.valid or args.test:
        if len(args.gpu) > 1:
            args.num_test_proc = args.num_proc if args.num_proc < len(args.gpu) else len(args.gpu)
        else:
            args.num_test_proc = args.num_proc
        eval_dataset = EvalDataset(dataset, args)
    if args.valid:
        # Here we want to use the regular negative sampler because we need to ensure that
        # all positive edges are excluded.
        if args.num_proc > 1:
            valid_sampler_heads = []
            valid_sampler_tails = []
            for i in range(args.num_proc):
                valid_sampler_head = eval_dataset.create_sampler('valid', args.batch_size_eval,
                                                                 args.neg_sample_size_valid,
                                                                 args.neg_chunk_size_valid,
                                                                 args.eval_filter,
                                                                 mode='chunk-head',
                                                                 num_workers=num_workers,
                                                                 rank=i, ranks=args.num_proc)
                valid_sampler_tail = eval_dataset.create_sampler('valid', args.batch_size_eval,
                                                                 args.neg_sample_size_valid,
                                                                 args.neg_chunk_size_valid,
                                                                 args.eval_filter,
                                                                 mode='chunk-tail',
                                                                 num_workers=num_workers,
                                                                 rank=i, ranks=args.num_proc)
                valid_sampler_heads.append(valid_sampler_head)
                valid_sampler_tails.append(valid_sampler_tail)
        else:
            valid_sampler_head = eval_dataset.create_sampler('valid', args.batch_size_eval,
                                                             args.neg_sample_size_valid,
                                                             args.neg_chunk_size_valid,
                                                             args.eval_filter,
                                                             mode='chunk-head',
                                                             num_workers=num_workers,
                                                             rank=0, ranks=1)
            valid_sampler_tail = eval_dataset.create_sampler('valid', args.batch_size_eval,
                                                             args.neg_sample_size_valid,
                                                             args.neg_chunk_size_valid,
                                                             args.eval_filter,
                                                             mode='chunk-tail',
                                                             num_workers=num_workers,
                                                             rank=0, ranks=1)
    if args.test:
        # Here we want to use the regular negative sampler because we need to ensure that
        # all positive edges are excluded.
        # We use at most num_gpu processes in the test stage to save GPU memory.
        if args.num_test_proc > 1:
            test_sampler_tails = []
            test_sampler_heads = []
            for i in range(args.num_test_proc):
                test_sampler_head = eval_dataset.create_sampler('test', args.batch_size_eval,
                                                                args.neg_sample_size_test,
                                                                args.neg_chunk_size_test,
                                                                args.eval_filter,
                                                                mode='chunk-head',
                                                                num_workers=num_workers,
                                                                rank=i, ranks=args.num_test_proc)
                test_sampler_tail = eval_dataset.create_sampler('test', args.batch_size_eval,
                                                                args.neg_sample_size_test,
                                                                args.neg_chunk_size_test,
                                                                args.eval_filter,
                                                                mode='chunk-tail',
                                                                num_workers=num_workers,
                                                                rank=i, ranks=args.num_test_proc)
                test_sampler_heads.append(test_sampler_head)
                test_sampler_tails.append(test_sampler_tail)
        else:
            test_sampler_head = eval_dataset.create_sampler('test', args.batch_size_eval,
                                                            args.neg_sample_size_test,
                                                            args.neg_chunk_size_test,
                                                            args.eval_filter,
                                                            mode='chunk-head',
                                                            num_workers=num_workers,
                                                            rank=0, ranks=1)
            test_sampler_tail = eval_dataset.create_sampler('test', args.batch_size_eval,
                                                            args.neg_sample_size_test,
                                                            args.neg_chunk_size_test,
                                                            args.eval_filter,
                                                            mode='chunk-tail',
                                                            num_workers=num_workers,
                                                            rank=0, ranks=1)

    # We need to free all memory referenced by dataset.
    eval_dataset = None
    dataset = None
    # load model
    model = load_model(logger, args, n_entities, n_relations)

    # Put the model parameters in shared memory so that multiple worker processes
    # and asynchronous embedding updates operate on the same tensors.
    if args.num_proc > 1 or args.async_update:
        model.share_memory()

    print('Total data loading time {:.3f} seconds'.format(time.time() - train_time_start))
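    # Launch training: with num_proc > 1 each worker process trains on its own
    # sampler and train_mp is given a shared barrier for synchronization; otherwise
    # train() runs in the current process.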

    # train
    start = time.time()
    rel_parts = train_data.rel_parts if args.strict_rel_part or args.soft_rel_part else None
    cross_rels = train_data.cross_rels if args.soft_rel_part else None
    if args.num_proc > 1:
        procs = []
        barrier = mp.Barrier(args.num_proc)
        for i in range(args.num_proc):
            valid_sampler = [valid_sampler_heads[i], valid_sampler_tails[i]] if args.valid else None
            proc = mp.Process(target=train_mp, args=(args,
                                                     model,
                                                     train_samplers[i],
                                                     valid_sampler,
                                                     i,
                                                     rel_parts,
                                                     cross_rels,
                                                     barrier))
            procs.append(proc)
            proc.start()
        for proc in procs:
            proc.join()
    else:
        valid_samplers = [valid_sampler_head, valid_sampler_tail] if args.valid else None
        train(args, model, train_sampler, valid_samplers, rel_parts=rel_parts)
    print('training takes {} seconds'.format(time.time() - start))

    if args.save_emb is not None:
        if not os.path.exists(args.save_emb):
            os.mkdir(args.save_emb)
        model.save_emb(args.save_emb, args.dataset)

    # test
    if args.test:
        start = time.time()
        if args.num_test_proc > 1:
            queue = mp.Queue(args.num_test_proc)
            procs = []
            for i in range(args.num_test_proc):
                proc = mp.Process(target=test_mp, args=(args,
                                                        model,
                                                        [test_sampler_heads[i], test_sampler_tails[i]],
                                                        i,
                                                        'Test',
                                                        queue))
                procs.append(proc)
                proc.start()

            metrics = {}
            logs = []
            for i in range(args.num_test_proc):
                log = queue.get()
                logs = logs + log

            # Average each metric over the results gathered from all test processes.
            for metric in logs[0].keys():
                metrics[metric] = sum([log[metric] for log in logs]) / len(logs)
            for k, v in metrics.items():
                print('Test average {} at [{}/{}]: {}'.format(k, args.step, args.max_step, v))

            for proc in procs:
                proc.join()
        else:
            test(args, model, [test_sampler_head, test_sampler_tail])
        print('testing takes {} seconds'.format(time.time() - start))

if __name__ == '__main__':
    args = ArgParser().parse_args()
    logger = get_logger(args)
    run(args, logger)