from dataloader import EvalDataset, TrainDataset, NewBidirectionalOneShotIterator
from dataloader import get_dataset

import argparse
import os
import logging
import time

backend = os.environ.get('DGLBACKEND', 'pytorch')
if backend.lower() == 'mxnet':
    import multiprocessing as mp
    from train_mxnet import load_model
    from train_mxnet import train
    from train_mxnet import test
else:
    import torch.multiprocessing as mp
    from train_pytorch import load_model
    from train_pytorch import train, train_mp
    from train_pytorch import test, test_mp

class ArgParser(argparse.ArgumentParser):
    def __init__(self):
        super(ArgParser, self).__init__()

        self.add_argument('--model_name', default='TransE',
                          choices=['TransE', 'TransE_l1', 'TransE_l2', 'TransR',
                                   'RESCAL', 'DistMult', 'ComplEx', 'RotatE'],
                          help='model to use')
        self.add_argument('--data_path', type=str, default='data',
                          help='root path of all dataset')
        self.add_argument('--dataset', type=str, default='FB15k',
                          help='dataset name, under data_path')
        self.add_argument('--format', type=str, default='1',
                          help='the format of the dataset.')
        self.add_argument('--save_path', type=str, default='ckpts',
                          help='place to save models and logs')
        self.add_argument('--save_emb', type=str, default=None,
                          help='save the embeddings in the specific location.')

        self.add_argument('--max_step', type=int, default=80000,
                          help='number of training steps')
        self.add_argument('--warm_up_step', type=int, default=None,
                          help='for learning rate decay')
        self.add_argument('--batch_size', type=int, default=1024,
                          help='batch size')
        self.add_argument('--batch_size_eval', type=int, default=8,
                          help='batch size used for eval and test')
        self.add_argument('--neg_sample_size', type=int, default=128,
                          help='negative sampling size')
        self.add_argument('--neg_chunk_size', type=int, default=-1,
                          help='chunk size of the negative edges used in training.')
        self.add_argument('--neg_deg_sample', action='store_true',
                          help='sample negative edges proportionally to vertex degree during training')
        self.add_argument('--neg_deg_sample_eval', action='store_true',
                          help='sample negative edges proportionally to vertex degree during evaluation')
        self.add_argument('--neg_sample_size_valid', type=int, default=1000,
                          help='negative sampling size for validation')
        self.add_argument('--neg_chunk_size_valid', type=int, default=-1,
                          help='chunk size of the negative edges used for validation.')
        self.add_argument('--neg_sample_size_test', type=int, default=-1,
                          help='negative sampling size for testing')
        self.add_argument('--neg_chunk_size_test', type=int, default=-1,
                          help='chunk size of the negative edges used for testing.')
        self.add_argument('--hidden_dim', type=int, default=256,
                          help='hidden dim used by relation and entity')
        self.add_argument('--lr', type=float, default=0.0001,
                          help='learning rate')
        self.add_argument('-g', '--gamma', type=float, default=12.0,
                          help='margin value')
        self.add_argument('--eval_percent', type=float, default=1,
                          help='sample a percentage of edges for evaluation.')
        self.add_argument('--no_eval_filter', action='store_true',
                          help='do not filter positive edges among negative edges for evaluation')

        self.add_argument('--gpu', type=int, default=[-1], nargs='+',
                          help='a list of active gpu ids, e.g. 0 1 2 4; -1 means CPU')
        self.add_argument('--mix_cpu_gpu', action='store_true',
                          help='mix CPU and GPU training')
        self.add_argument('-de', '--double_ent', action='store_true',
                          help='double entity dim for complex number')
        self.add_argument('-dr', '--double_rel', action='store_true',
                          help='double relation dim for complex number')
        self.add_argument('--seed', type=int, default=0,
                          help='set random seed for reproducibility')
        self.add_argument('-log', '--log_interval', type=int, default=1000,
                          help='print training logs every x steps')
        self.add_argument('--eval_interval', type=int, default=10000,
                          help='do evaluation after every x steps')
        self.add_argument('-adv', '--neg_adversarial_sampling', action='store_true',
                          help='use negative adversarial sampling')
        self.add_argument('-a', '--adversarial_temperature', default=1.0, type=float,
                          help='temperature used in negative adversarial sampling')

        self.add_argument('--valid', action='store_true',
                          help='evaluate the model on the validation set during training')
        self.add_argument('--test', action='store_true',
                          help='test the model on the test set after training')
        self.add_argument('-rc', '--regularization_coef', type=float, default=0.000002,
                          help='set value > 0.0 if regularization is used')
        self.add_argument('-rn', '--regularization_norm', type=int, default=3,
                          help='norm used in regularization')
        self.add_argument('--num_worker', type=int, default=32,
                          help='number of workers used for loading data')
        self.add_argument('--non_uni_weight', action='store_true',
                          help='use non-uniform edge weights when computing the loss')
        self.add_argument('--init_step', type=int, default=0,
                          help='DO NOT SET MANUALLY, used for resuming training')
        self.add_argument('--step', type=int, default=0,
                          help='DO NOT SET MANUALLY, tracks the current step')
        self.add_argument('--pickle_graph', action='store_true',
                          help='pickle the built graph; building a huge graph is slow.')
        self.add_argument('--num_proc', type=int, default=1,
                          help='number of processes used')
        self.add_argument('--rel_part', action='store_true',
                          help='enable relation partitioning')
        self.add_argument('--nomp_thread_per_process', type=int, default=-1,
                          help='number of OMP threads used per process in multi-process training')
        self.add_argument('--async_update', action='store_true',
                          help='allow asynchronous updates of node embeddings')
        self.add_argument('--force_sync_interval', type=int, default=-1,
                          help='force a synchronization between processes every x steps')


def get_logger(args):
    if not os.path.exists(args.save_path):
        os.mkdir(args.save_path)

    folder = '{}_{}_'.format(args.model_name, args.dataset)
    n = len([x for x in os.listdir(args.save_path) if x.startswith(folder)])
    folder += str(n)
    args.save_path = os.path.join(args.save_path, folder)

    if not os.path.exists(args.save_path):
        os.makedirs(args.save_path)
    log_file = os.path.join(args.save_path, 'train.log')

    logging.basicConfig(
        format='%(asctime)s %(levelname)-8s %(message)s',
        level=logging.INFO,
        datefmt='%Y-%m-%d %H:%M:%S',
        filename=log_file,
        filemode='w'
    )

    logger = logging.getLogger(__name__)
    print("Logs are being recorded at: {}".format(log_file))
    return logger


def run(args, logger):
    # load dataset and samplers
    dataset = get_dataset(args.data_path, args.dataset, args.format)
    n_entities = dataset.n_entities
    n_relations = dataset.n_relations
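    # A test negative sample size below zero means ranking each test triple against all entities.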
    if args.neg_sample_size_test < 0:
        args.neg_sample_size_test = n_entities
    args.eval_filter = not args.no_eval_filter
    if args.neg_deg_sample_eval:
        assert not args.eval_filter, "if negative sampling based on degree, we can't filter positive edges."

    # When we generate a batch of negative edges from a set of positive edges,
    # we first divide the positive edges into chunks and corrupt the edges in a chunk
    # together. By default, the chunk size is equal to the negative sample size.
    # Usually, this works well. But we also allow users to specify the chunk size themselves.
    if args.neg_chunk_size < 0:
        args.neg_chunk_size = args.neg_sample_size
    if args.neg_chunk_size_valid < 0:
        args.neg_chunk_size_valid = args.neg_sample_size_valid
    if args.neg_chunk_size_test < 0:
        args.neg_chunk_size_test = args.neg_sample_size_test
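    # Illustrative example (hypothetical numbers, added for clarity): with batch_size=1024
    # and neg_sample_size=128, neg_chunk_size defaults to 128, so each batch of positive
    # edges is split into 1024 / 128 = 8 chunks and all 128 positive edges in a chunk share
    # the same set of corrupted entities.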

    num_workers = args.num_worker
    train_data = TrainDataset(dataset, args, ranks=args.num_proc)
    args.strict_rel_part = args.mix_cpu_gpu and (train_data.cross_part == False)

    # Automatically set number of OMP threads for each process if it is not provided
    # The value for GPU is evaluated in AWS p3.16xlarge
    # The value for CPU is evaluated in AWS x1.32xlarge
    if args.nomp_thread_per_process == -1:
        if len(args.gpu) > 0:
            # GPU training
            args.num_thread = 4
        else:
            # CPU training
            args.num_thread = mp.cpu_count() // args.num_proc + 1
    else:
        args.num_thread = args.nomp_thread_per_process
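    # For instance (hypothetical numbers): the CPU branch above assigns
    # mp.cpu_count() // args.num_proc + 1 OMP threads per process, e.g. 64 // 8 + 1 = 9
    # threads on a 64-core machine with --num_proc 8.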

    if args.num_proc > 1:
        train_samplers = []
        for i in range(args.num_proc):
            train_sampler_head = train_data.create_sampler(args.batch_size,
                                                           args.neg_sample_size,
                                                           args.neg_chunk_size,
                                                           mode='head',
                                                           num_workers=num_workers,
                                                           shuffle=True,
                                                           exclude_positive=False,
                                                           rank=i)
            train_sampler_tail = train_data.create_sampler(args.batch_size,
                                                           args.neg_sample_size,
                                                           args.neg_chunk_size,
                                                           mode='tail',
                                                           num_workers=num_workers,
                                                           shuffle=True,
                                                           exclude_positive=False,
                                                           rank=i)
            train_samplers.append(NewBidirectionalOneShotIterator(train_sampler_head, train_sampler_tail,
                                                                  args.neg_chunk_size, args.neg_sample_size,
                                                                  True, n_entities))
    else:
        train_sampler_head = train_data.create_sampler(args.batch_size,
                                                       args.neg_sample_size,
                                                       args.neg_chunk_size,
                                                       mode='head',
                                                       num_workers=num_workers,
                                                       shuffle=True,
                                                       exclude_positive=False)
        train_sampler_tail = train_data.create_sampler(args.batch_size,
                                                       args.neg_sample_size,
                                                       args.neg_chunk_size,
                                                       mode='tail',
                                                       num_workers=num_workers,
                                                       shuffle=True,
                                                       exclude_positive=False)
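        # Combine the head- and tail-corrupting samplers into a single training iterator.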
        train_sampler = NewBidirectionalOneShotIterator(train_sampler_head, train_sampler_tail,
                                                        args.neg_chunk_size, args.neg_sample_size,
                                                        True, n_entities)

    # for multiprocessing evaluation, we don't need to sample multiple batches at a time
    # in each process.
    if args.num_proc > 1:
        num_workers = 1
    if args.valid or args.test:
        args.num_test_proc = args.num_proc if args.num_proc < len(args.gpu) else len(args.gpu)
        eval_dataset = EvalDataset(dataset, args)
    if args.valid:
        # Here we want to use the regular negative sampler because we need to ensure that
        # all positive edges are excluded.
        if args.num_proc > 1:
            valid_sampler_heads = []
            valid_sampler_tails = []
            for i in range(args.num_proc):
                valid_sampler_head = eval_dataset.create_sampler('valid', args.batch_size_eval,
                                                                 args.neg_sample_size_valid,
                                                                 args.neg_chunk_size_valid,
                                                                 args.eval_filter,
                                                                 mode='chunk-head',
                                                                 num_workers=num_workers,
                                                                 rank=i, ranks=args.num_proc)
                valid_sampler_tail = eval_dataset.create_sampler('valid', args.batch_size_eval,
                                                                 args.neg_sample_size_valid,
                                                                 args.neg_chunk_size_valid,
                                                                 args.eval_filter,
                                                                 mode='chunk-tail',
                                                                 num_workers=num_workers,
                                                                 rank=i, ranks=args.num_proc)
                valid_sampler_heads.append(valid_sampler_head)
                valid_sampler_tails.append(valid_sampler_tail)
        else:
            valid_sampler_head = eval_dataset.create_sampler('valid', args.batch_size_eval,
                                                             args.neg_sample_size_valid,
                                                             args.neg_chunk_size_valid,
                                                             args.eval_filter,
                                                             mode='chunk-head',
                                                             num_workers=num_workers,
                                                             rank=0, ranks=1)
            valid_sampler_tail = eval_dataset.create_sampler('valid', args.batch_size_eval,
                                                             args.neg_sample_size_valid,
                                                             args.neg_chunk_size_valid,
                                                             args.eval_filter,
                                                             mode='chunk-tail',
                                                             num_workers=num_workers,
                                                             rank=0, ranks=1)
    if args.test:
        # Here we want to use the regular negative sampler because we need to ensure that
        # all positive edges are excluded.
        # We use at most num_gpu processes in the test stage to save GPU memory.
        if args.num_test_proc > 1:
            test_sampler_tails = []
            test_sampler_heads = []
            for i in range(args.num_test_proc):
                test_sampler_head = eval_dataset.create_sampler('test', args.batch_size_eval,
                                                                args.neg_sample_size_test,
                                                                args.neg_chunk_size_test,
                                                                args.eval_filter,
                                                                mode='chunk-head',
                                                                num_workers=num_workers,
                                                                rank=i, ranks=args.num_test_proc)
                test_sampler_tail = eval_dataset.create_sampler('test', args.batch_size_eval,
                                                                args.neg_sample_size_test,
                                                                args.neg_chunk_size_test,
                                                                args.eval_filter,
                                                                mode='chunk-tail',
                                                                num_workers=num_workers,
                                                                rank=i, ranks=args.num_test_proc)
                test_sampler_heads.append(test_sampler_head)
                test_sampler_tails.append(test_sampler_tail)
        else:
            test_sampler_head = eval_dataset.create_sampler('test', args.batch_size_eval,
                                                            args.neg_sample_size_test,
                                                            args.neg_chunk_size_test,
                                                            args.eval_filter,
                                                            mode='chunk-head',
                                                            num_workers=num_workers,
                                                            rank=0, ranks=1)
            test_sampler_tail = eval_dataset.create_sampler('test', args.batch_size_eval,
                                                            args.neg_sample_size_test,
                                                            args.neg_chunk_size_test,
                                                            args.eval_filter,
                                                            mode='chunk-tail',
                                                            num_workers=num_workers,
                                                            rank=0, ranks=1)

    # We need to free all memory referenced by dataset.
    eval_dataset = None
    dataset = None
    # load model
    model = load_model(logger, args, n_entities, n_relations)

    if args.num_proc > 1 or args.async_update:
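        # Move the model's embeddings into shared memory so that all processes
        # update the same underlying tensors.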
        model.share_memory()

    # train
    start = time.time()
    rel_parts = train_data.rel_parts if args.strict_rel_part else None
    if args.num_proc > 1:
        procs = []
        barrier = mp.Barrier(args.num_proc)
        for i in range(args.num_proc):
            valid_sampler = [valid_sampler_heads[i], valid_sampler_tails[i]] if args.valid else None
            proc = mp.Process(target=train_mp, args=(args,
                                                     model,
                                                     train_samplers[i],
                                                     valid_sampler,
                                                     i,
                                                     rel_parts,
                                                     barrier))
            procs.append(proc)
            proc.start()
        for proc in procs:
            proc.join()
    else:
        valid_samplers = [valid_sampler_head, valid_sampler_tail] if args.valid else None
        train(args, model, train_sampler, valid_samplers, rel_parts=rel_parts)
    print('training takes {} seconds'.format(time.time() - start))

    if args.save_emb is not None:
        if not os.path.exists(args.save_emb):
            os.mkdir(args.save_emb)
        model.save_emb(args.save_emb, args.dataset)

    # test
    if args.test:
        start = time.time()
        if args.num_test_proc > 1:
            queue = mp.Queue(args.num_test_proc)
            procs = []
            for i in range(args.num_test_proc):
                proc = mp.Process(target=test_mp, args=(args,
                                                        model,
                                                        [test_sampler_heads[i], test_sampler_tails[i]],
                                                        i,
                                                        'Test',
                                                        queue))
                procs.append(proc)
                proc.start()

            total_metrics = {}
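            # Each test process pushes its evaluation logs into the queue; gather them here
            # and average every metric over all collected entries.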
            metrics = {}
            logs = []
            for i in range(args.num_test_proc):
                log = queue.get()
                logs = logs + log
            
            for metric in logs[0].keys():
                metrics[metric] = sum([log[metric] for log in logs]) / len(logs)
            for k, v in metrics.items():
                print('Test average {} at [{}/{}]: {}'.format(k, args.step, args.max_step, v))

            for proc in procs:
                proc.join()
        else:
            test(args, model, [test_sampler_head, test_sampler_tail])
        print('test:', time.time() - start)

if __name__ == '__main__':
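    # Example invocation (illustrative only; the flags are defined in ArgParser above
    # and the values are placeholders, not tuned hyperparameters):
    #   python train.py --model_name TransE_l2 --dataset FB15k --batch_size 1024 \
    #       --neg_sample_size 256 --hidden_dim 400 --max_step 80000 --gpu 0 --valid --test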
    args = ArgParser().parse_args()
    logger = get_logger(args)
    run(args, logger)