"""Training entry point for knowledge-graph-embedding models.

Dispatches to the MXNet or PyTorch implementation depending on the
``DGLBACKEND`` environment variable (defaults to ``pytorch``).
"""
from dataloader import EvalDataset, TrainDataset, NewBidirectionalOneShotIterator
from dataloader import get_dataset

import argparse
import os
import logging
import time
import json

from utils import get_compatible_batch_size

# Select the deep-learning backend. The PyTorch path additionally provides
# multi-process entry points (train_mp / test_mp); the MXNet path does not
# import them — presumably multi-process training is PyTorch-only. TODO confirm.
backend = os.environ.get('DGLBACKEND', 'pytorch')
if backend.lower() == 'mxnet':
    import multiprocessing as mp
    from train_mxnet import load_model
    from train_mxnet import train
    from train_mxnet import test
else:
    import torch.multiprocessing as mp
    from train_pytorch import load_model
    from train_pytorch import train, train_mp
    from train_pytorch import test, test_mp

class ArgParser(argparse.ArgumentParser):
    """Command-line parser declaring all training hyper-parameters.

    Covers model/dataset selection, sampling sizes, optimization settings,
    evaluation options and multi-process/partitioning switches.
    """

    def __init__(self):
        super(ArgParser, self).__init__()

        self.add_argument('--model_name', default='TransE',
                          choices=['TransE', 'TransE_l1', 'TransE_l2', 'TransR',
                                   'RESCAL', 'DistMult', 'ComplEx', 'RotatE'],
                          help='model to use')
        self.add_argument('--data_path', type=str, default='data',
                          help='root path of all dataset')
        self.add_argument('--dataset', type=str, default='FB15k',
                          help='dataset name, under data_path')
        self.add_argument('--format', type=str, default='built_in',
                          help='the format of the dataset, it can be built_in,'\
                                'raw_udd_{htr} and udd_{htr}')
        self.add_argument('--data_files', type=str, default=None, nargs='+',
                          help='a list of data files, e.g. entity relation train valid test')
        self.add_argument('--save_path', type=str, default='ckpts',
                          help='place to save models and logs')
        self.add_argument('--save_emb', type=str, default=None,
                          help='save the embeddings in the specific location.')
        self.add_argument('--max_step', type=int, default=80000,
                          help='train xx steps')
        self.add_argument('--batch_size', type=int, default=1024,
                          help='batch size')
        self.add_argument('--batch_size_eval', type=int, default=8,
                          help='batch size used for eval and test')
        self.add_argument('--neg_sample_size', type=int, default=128,
                          help='negative sampling size')
        self.add_argument('--neg_deg_sample', action='store_true',
                          help='negative sample proportional to vertex degree in the training')
        self.add_argument('--neg_deg_sample_eval', action='store_true',
                          help='negative sampling proportional to vertex degree in the evaluation')
        # -1 is a sentinel meaning "use all entities" (resolved in run()).
        self.add_argument('--neg_sample_size_eval', type=int, default=-1,
                          help='negative sampling size for evaluation')
        self.add_argument('--eval_percent', type=float, default=1,
                          help='sample some percentage for evaluation.')
        self.add_argument('--hidden_dim', type=int, default=256,
                          help='hidden dim used by relation and entity')
        self.add_argument('--lr', type=float, default=0.0001,
                          help='learning rate')
        self.add_argument('-g', '--gamma', type=float, default=12.0,
                          help='margin value')
        self.add_argument('--no_eval_filter', action='store_true',
                          help='do not filter positive edges among negative edges for evaluation')
        self.add_argument('--gpu', type=int, default=[-1], nargs='+',
                          help='a list of active gpu ids, e.g. 0 1 2 4')
        self.add_argument('--mix_cpu_gpu', action='store_true',
                          help='mix CPU and GPU training')
        self.add_argument('-de', '--double_ent', action='store_true',
                          help='double entity dim for complex number')
        self.add_argument('-dr', '--double_rel', action='store_true',
                          help='double relation dim for complex number')
        # BUG FIX: help text previously duplicated --eval_interval's
        # ("do evaluation after every x steps"), which was misleading.
        self.add_argument('-log', '--log_interval', type=int, default=1000,
                          help='print training log after every x steps')
        self.add_argument('--eval_interval', type=int, default=10000,
                          help='do evaluation after every x steps')
        self.add_argument('-adv', '--neg_adversarial_sampling', action='store_true',
                          help='if use negative adversarial sampling')
        self.add_argument('-a', '--adversarial_temperature', default=1.0, type=float,
                          help='adversarial_temperature')
        self.add_argument('--valid', action='store_true',
                          help='if valid a model')
        self.add_argument('--test', action='store_true',
                          help='if test a model')
        self.add_argument('-rc', '--regularization_coef', type=float, default=0.000002,
                          help='set value > 0.0 if regularization is used')
        self.add_argument('-rn', '--regularization_norm', type=int, default=3,
                          help='norm used in regularization')
        self.add_argument('--non_uni_weight', action='store_true',
                          help='if use uniform weight when computing loss')
        self.add_argument('--pickle_graph', action='store_true',
                          help='pickle built graph, building a huge graph is slow.')
        self.add_argument('--num_proc', type=int, default=1,
                          help='number of process used')
        self.add_argument('--num_thread', type=int, default=1,
                          help='number of thread used')
        self.add_argument('--rel_part', action='store_true',
                          help='enable relation partitioning')
        self.add_argument('--soft_rel_part', action='store_true',
                          help='enable soft relation partition')
        self.add_argument('--async_update', action='store_true',
                          help='allow async_update on node embedding')
        self.add_argument('--force_sync_interval', type=int, default=-1,
                          help='We force a synchronization between processes every x steps')


def get_logger(args):
    """Create a fresh run directory under ``args.save_path`` and return a logger.

    The run directory is named ``<model>_<dataset>_<n>`` where ``n`` counts
    previous runs with the same model/dataset prefix. ``args.save_path`` is
    rebound in place to the new directory; log records go to ``train.log``
    inside it.
    """
    if not os.path.exists(args.save_path):
        os.mkdir(args.save_path)

    # Choose the next free run index for this model/dataset prefix.
    prefix = '{}_{}_'.format(args.model_name, args.dataset)
    run_idx = sum(1 for entry in os.listdir(args.save_path) if entry.startswith(prefix))
    args.save_path = os.path.join(args.save_path, prefix + str(run_idx))

    if not os.path.exists(args.save_path):
        os.makedirs(args.save_path)
    log_file = os.path.join(args.save_path, 'train.log')

    logging.basicConfig(
        format='%(asctime)s %(levelname)-8s %(message)s',
        level=logging.INFO,
        datefmt='%Y-%m-%d %H:%M:%S',
        filename=log_file,
        filemode='w'
    )

    logger = logging.getLogger(__name__)
    print("Logs are being recorded at: {}".format(log_file))
    return logger


def run(args, logger):
    """End-to-end driver: build samplers, train the model, then optionally
    validate and test it, possibly across multiple processes.

    Parameters
    ----------
    args : argparse.Namespace
        Options from ``ArgParser``. Several fields are updated in place
        (batch sizes, ``eval_filter``, partition flags, ``num_workers``,
        ``num_test_proc``).
    logger : logging.Logger
        Logger from ``get_logger``; forwarded to ``load_model``.
    """
    init_time_start = time.time()
    # load dataset and samplers
    dataset = get_dataset(args.data_path, args.dataset, args.format, args.data_files)

    if args.neg_sample_size_eval < 0:
        # Sentinel -1: rank each positive edge against every entity.
        args.neg_sample_size_eval = dataset.n_entities
    # Batch sizes must be compatible with the negative sample sizes.
    args.batch_size = get_compatible_batch_size(args.batch_size, args.neg_sample_size)
    args.batch_size_eval = get_compatible_batch_size(args.batch_size_eval, args.neg_sample_size_eval)

    args.eval_filter = not args.no_eval_filter
    if args.neg_deg_sample_eval:
        assert not args.eval_filter, "if negative sampling based on degree, we can't filter positive edges."

    train_data = TrainDataset(dataset, args, ranks=args.num_proc)
    # if there is no cross partition relation, we fall back to strict_rel_part
    args.strict_rel_part = args.mix_cpu_gpu and (train_data.cross_part == False)
    args.soft_rel_part = args.mix_cpu_gpu and args.soft_rel_part and train_data.cross_part
    args.num_workers = 8 # fix num_worker to 8

    if args.num_proc > 1:
        # One bidirectional (head/tail corruption) iterator per training process.
        train_samplers = []
        for i in range(args.num_proc):
            train_sampler_head = train_data.create_sampler(args.batch_size,
                                                           args.neg_sample_size,
                                                           args.neg_sample_size,
                                                           mode='head',
                                                           num_workers=args.num_workers,
                                                           shuffle=True,
                                                           exclude_positive=False,
                                                           rank=i)
            train_sampler_tail = train_data.create_sampler(args.batch_size,
                                                           args.neg_sample_size,
                                                           args.neg_sample_size,
                                                           mode='tail',
                                                           num_workers=args.num_workers,
                                                           shuffle=True,
                                                           exclude_positive=False,
                                                           rank=i)
            train_samplers.append(NewBidirectionalOneShotIterator(train_sampler_head, train_sampler_tail,
                                                                  args.neg_sample_size, args.neg_sample_size,
                                                                  True, dataset.n_entities))

        # NOTE(review): this iterator reuses the last rank's samplers and is
        # never read on the multi-process path (train_mp gets
        # train_samplers[i]) — looks like dead code; kept as-is in case its
        # construction has side effects. TODO confirm and remove.
        train_sampler = NewBidirectionalOneShotIterator(train_sampler_head, train_sampler_tail,
                                                        args.neg_sample_size, args.neg_sample_size,
                                                        True, dataset.n_entities)
    else: # This is used for debug
        train_sampler_head = train_data.create_sampler(args.batch_size,
                                                       args.neg_sample_size,
                                                       args.neg_sample_size,
                                                       mode='head',
                                                       num_workers=args.num_workers,
                                                       shuffle=True,
                                                       exclude_positive=False)
        train_sampler_tail = train_data.create_sampler(args.batch_size,
                                                       args.neg_sample_size,
                                                       args.neg_sample_size,
                                                       mode='tail',
                                                       num_workers=args.num_workers,
                                                       shuffle=True,
                                                       exclude_positive=False)
        train_sampler = NewBidirectionalOneShotIterator(train_sampler_head, train_sampler_tail,
                                                        args.neg_sample_size, args.neg_sample_size,
                                                        True, dataset.n_entities)

    if args.valid or args.test:
        # Cap the number of test processes at the number of GPUs in use.
        if len(args.gpu) > 1:
            args.num_test_proc = args.num_proc if args.num_proc < len(args.gpu) else len(args.gpu)
        else:
            args.num_test_proc = args.num_proc
        eval_dataset = EvalDataset(dataset, args)

    if args.valid:
        if args.num_proc > 1:
            valid_sampler_heads = []
            valid_sampler_tails = []
            for i in range(args.num_proc):
                valid_sampler_head = eval_dataset.create_sampler('valid', args.batch_size_eval,
                                                                 args.neg_sample_size_eval,
                                                                 args.neg_sample_size_eval,
                                                                 args.eval_filter,
                                                                 mode='chunk-head',
                                                                 num_workers=args.num_workers,
                                                                 rank=i, ranks=args.num_proc)
                valid_sampler_tail = eval_dataset.create_sampler('valid', args.batch_size_eval,
                                                                 args.neg_sample_size_eval,
                                                                 args.neg_sample_size_eval,
                                                                 args.eval_filter,
                                                                 mode='chunk-tail',
                                                                 num_workers=args.num_workers,
                                                                 rank=i, ranks=args.num_proc)
                valid_sampler_heads.append(valid_sampler_head)
                valid_sampler_tails.append(valid_sampler_tail)
        else: # This is used for debug
            valid_sampler_head = eval_dataset.create_sampler('valid', args.batch_size_eval,
                                                             args.neg_sample_size_eval,
                                                             args.neg_sample_size_eval,
                                                             args.eval_filter,
                                                             mode='chunk-head',
                                                             num_workers=args.num_workers,
                                                             rank=0, ranks=1)
            valid_sampler_tail = eval_dataset.create_sampler('valid', args.batch_size_eval,
                                                             args.neg_sample_size_eval,
                                                             args.neg_sample_size_eval,
                                                             args.eval_filter,
                                                             mode='chunk-tail',
                                                             num_workers=args.num_workers,
                                                             rank=0, ranks=1)
    if args.test:
        if args.num_test_proc > 1:
            test_sampler_tails = []
            test_sampler_heads = []
            for i in range(args.num_test_proc):
                test_sampler_head = eval_dataset.create_sampler('test', args.batch_size_eval,
                                                                args.neg_sample_size_eval,
                                                                args.neg_sample_size_eval,
                                                                args.eval_filter,
                                                                mode='chunk-head',
                                                                num_workers=args.num_workers,
                                                                rank=i, ranks=args.num_test_proc)
                test_sampler_tail = eval_dataset.create_sampler('test', args.batch_size_eval,
                                                                args.neg_sample_size_eval,
                                                                args.neg_sample_size_eval,
                                                                args.eval_filter,
                                                                mode='chunk-tail',
                                                                num_workers=args.num_workers,
                                                                rank=i, ranks=args.num_test_proc)
                test_sampler_heads.append(test_sampler_head)
                test_sampler_tails.append(test_sampler_tail)
        else:
            test_sampler_head = eval_dataset.create_sampler('test', args.batch_size_eval,
                                                            args.neg_sample_size_eval,
                                                            args.neg_sample_size_eval,
                                                            args.eval_filter,
                                                            mode='chunk-head',
                                                            num_workers=args.num_workers,
                                                            rank=0, ranks=1)
            test_sampler_tail = eval_dataset.create_sampler('test', args.batch_size_eval,
                                                            args.neg_sample_size_eval,
                                                            args.neg_sample_size_eval,
                                                            args.eval_filter,
                                                            mode='chunk-tail',
                                                            num_workers=args.num_workers,
                                                            rank=0, ranks=1)

    # load model
    model = load_model(logger, args, dataset.n_entities, dataset.n_relations)
    if args.num_proc > 1 or args.async_update:
        # Worker processes update the same embeddings in shared memory.
        model.share_memory()

    # We need to free all memory referenced by dataset.
    eval_dataset = None
    dataset = None

    print('Total initialize time {:.3f} seconds'.format(time.time() - init_time_start))

    # train
    start = time.time()
    rel_parts = train_data.rel_parts if args.strict_rel_part or args.soft_rel_part else None
    cross_rels = train_data.cross_rels if args.soft_rel_part else None
    if args.num_proc > 1:
        procs = []
        barrier = mp.Barrier(args.num_proc)
        for i in range(args.num_proc):
            valid_sampler = [valid_sampler_heads[i], valid_sampler_tails[i]] if args.valid else None
            proc = mp.Process(target=train_mp, args=(args,
                                                     model,
                                                     train_samplers[i],
                                                     valid_sampler,
                                                     i,
                                                     rel_parts,
                                                     cross_rels,
                                                     barrier))
            procs.append(proc)
            proc.start()
        for proc in procs:
            proc.join()
    else:
        valid_samplers = [valid_sampler_head, valid_sampler_tail] if args.valid else None
        train(args, model, train_sampler, valid_samplers, rel_parts=rel_parts)

    print('training takes {} seconds'.format(time.time() - start))

    if args.save_emb is not None:
        if not os.path.exists(args.save_emb):
            os.mkdir(args.save_emb)
        model.save_emb(args.save_emb, args.dataset)

        # We need to save the model configurations as well.
        conf_file = os.path.join(args.save_emb, 'config.json')
        with open(conf_file, 'w') as outfile:
            json.dump({'dataset': args.dataset,
                       'model': args.model_name,
                       'emb_size': args.hidden_dim,
                       'max_train_step': args.max_step,
                       'batch_size': args.batch_size,
                       'neg_sample_size': args.neg_sample_size,
                       'lr': args.lr,
                       'gamma': args.gamma,
                       'double_ent': args.double_ent,
                       'double_rel': args.double_rel,
                       'neg_adversarial_sampling': args.neg_adversarial_sampling,
                       'adversarial_temperature': args.adversarial_temperature,
                       'regularization_coef': args.regularization_coef,
                       'regularization_norm': args.regularization_norm},
                       outfile, indent=4)

    # test
    if args.test:
        start = time.time()
        if args.num_test_proc > 1:
            queue = mp.Queue(args.num_test_proc)
            procs = []
            for i in range(args.num_test_proc):
                proc = mp.Process(target=test_mp, args=(args,
                                                        model,
                                                        [test_sampler_heads[i], test_sampler_tails[i]],
                                                        i,
                                                        'Test',
                                                        queue))
                procs.append(proc)
                proc.start()

            # Gather per-process logs and average each metric over all triples.
            metrics = {}
            logs = []
            for i in range(args.num_test_proc):
                log = queue.get()
                logs = logs + log

            for metric in logs[0].keys():
                metrics[metric] = sum([log[metric] for log in logs]) / len(logs)
            for k, v in metrics.items():
                print('Test average {} : {}'.format(k, v))

            for proc in procs:
                proc.join()
        else:
            test(args, model, [test_sampler_head, test_sampler_tail])
        print('testing takes {:.3f} seconds'.format(time.time() - start))

if __name__ == '__main__':
    # Script entry point: parse CLI options, set up the run directory and
    # file logger, then launch the training/evaluation pipeline.
    args = ArgParser().parse_args()
    logger = get_logger(args)
    run(args, logger)