from dataloader import EvalDataset, TrainDataset, NewBidirectionalOneShotIterator
from dataloader import get_dataset

import argparse
import os
import logging
import time
import json
# Pick the deep-learning backend. DGL defaults to PyTorch unless the
# DGLBACKEND environment variable says otherwise; each backend ships its
# own multiprocessing module and train/test entry points.
backend = os.environ.get('DGLBACKEND', 'pytorch')
if backend.lower() != 'mxnet':
    import torch.multiprocessing as mp

    from train_pytorch import load_model, train, train_mp, test, test_mp
else:
    import multiprocessing as mp

    from train_mxnet import load_model, train, test

class ArgParser(argparse.ArgumentParser):
    """Command-line argument parser for knowledge-graph embedding training.

    Groups every option consumed by ``run()``: model/data selection,
    negative sampling, optimization hyper-parameters, and the
    multi-process / multi-GPU execution knobs.
    """

    def __init__(self):
        super(ArgParser, self).__init__()

        # --- model and dataset selection ---
        self.add_argument('--model_name', default='TransE',
                          choices=['TransE', 'TransE_l1', 'TransE_l2', 'TransR',
                                   'RESCAL', 'DistMult', 'ComplEx', 'RotatE'],
                          help='model to use')
        self.add_argument('--data_path', type=str, default='data',
                          help='root path of all dataset')
        self.add_argument('--dataset', type=str, default='FB15k',
                          help='dataset name, under data_path')
        # FIX: the two concatenated help fragments were missing a separating
        # space ("built_in,'raw_udd_..." rendered as one word).
        self.add_argument('--format', type=str, default='built_in',
                          help='the format of the dataset, it can be built_in, '\
                                'raw_udd_{htr} and udd_{htr}')
        self.add_argument('--data_files', type=str, default=None, nargs='+',
                          help='a list of data files, e.g. entity relation train valid test')

        # --- checkpointing / output ---
        self.add_argument('--save_path', type=str, default='ckpts',
                          help='place to save models and logs')
        self.add_argument('--save_emb', type=str, default=None,
                          help='save the embeddings in the specific location.')

        # --- training schedule and sampling ---
        self.add_argument('--max_step', type=int, default=80000,
                          help='train xx steps')
        self.add_argument('--batch_size', type=int, default=1024,
                          help='batch size')
        self.add_argument('--batch_size_eval', type=int, default=8,
                          help='batch size used for eval and test')
        self.add_argument('--neg_sample_size', type=int, default=128,
                          help='negative sampling size')
        self.add_argument('--neg_deg_sample', action='store_true',
                          help='negative sample proportional to vertex degree in the training')
        self.add_argument('--neg_deg_sample_eval', action='store_true',
                          help='negative sampling proportional to vertex degree in the evaluation')
        self.add_argument('--neg_sample_size_valid', type=int, default=1000,
                          help='negative sampling size for validation')
        # -1 is a sentinel: run() replaces it with the full entity count.
        self.add_argument('--neg_sample_size_test', type=int, default=-1,
                          help='negative sampling size for testing')
        self.add_argument('--eval_percent', type=float, default=1,
                          help='sample some percentage for evaluation.')

        # --- model hyper-parameters ---
        self.add_argument('--hidden_dim', type=int, default=256,
                          help='hidden dim used by relation and entity')
        self.add_argument('--lr', type=float, default=0.0001,
                          help='learning rate')
        self.add_argument('-g', '--gamma', type=float, default=12.0,
                          help='margin value')
        self.add_argument('--no_eval_filter', action='store_true',
                          help='do not filter positive edges among negative edges for evaluation')

        # --- hardware placement ---
        self.add_argument('--gpu', type=int, default=[-1], nargs='+',
                          help='a list of active gpu ids, e.g. 0 1 2 4')
        self.add_argument('--mix_cpu_gpu', action='store_true',
                          help='mix CPU and GPU training')
        self.add_argument('-de', '--double_ent', action='store_true',
                          help='double entitiy dim for complex number')
        self.add_argument('-dr', '--double_rel', action='store_true',
                          help='double relation dim for complex number')
        # FIX: help text previously duplicated --eval_interval's
        # ("do evaluation after every x steps"); this flag controls logging.
        self.add_argument('-log', '--log_interval', type=int, default=1000,
                          help='print training log after every x steps')
        self.add_argument('--eval_interval', type=int, default=10000,
                          help='do evaluation after every x steps')
        self.add_argument('-adv', '--neg_adversarial_sampling', action='store_true',
                          help='if use negative adversarial sampling')
        self.add_argument('-a', '--adversarial_temperature', default=1.0, type=float,
                          help='adversarial_temperature')

        # --- evaluation switches and regularization ---
        self.add_argument('--valid', action='store_true',
                          help='if valid a model')
        self.add_argument('--test', action='store_true',
                          help='if test a model')
        self.add_argument('-rc', '--regularization_coef', type=float, default=0.000002,
                          help='set value > 0.0 if regularization is used')
        self.add_argument('-rn', '--regularization_norm', type=int, default=3,
                          help='norm used in regularization')
        # FIX: help said "uniform" but the flag turns ON non-uniform weighting.
        self.add_argument('--non_uni_weight', action='store_true',
                          help='if use non-uniform weight when computing loss')
        self.add_argument('--pickle_graph', action='store_true',
                          help='pickle built graph, building a huge graph is slow.')

        # --- multi-process / distributed execution ---
        self.add_argument('--num_proc', type=int, default=1,
                          help='number of process used')
        self.add_argument('--num_thread', type=int, default=1,
                          help='number of thread used')
        self.add_argument('--rel_part', action='store_true',
                          help='enable relation partitioning')
        self.add_argument('--soft_rel_part', action='store_true',
                          help='enable soft relation partition')
        self.add_argument('--async_update', action='store_true',
                          help='allow async_update on node embedding')
        self.add_argument('--force_sync_interval', type=int, default=-1,
                          help='We force a synchronization between processes every x steps')


def get_logger(args):
    """Create a fresh run directory under ``args.save_path`` and a file logger.

    A unique sub-folder named ``<model_name>_<dataset>_<n>`` is appended to
    ``args.save_path`` (NOTE: ``args.save_path`` is mutated in place), where
    ``n`` is the count of existing folders with the same prefix.  The root
    logging module is configured to write INFO-level records to
    ``train.log`` inside that folder.

    Returns:
        logging.Logger: a logger bound to this module's name.
    """
    # makedirs (vs. mkdir) supports nested save paths like "out/ckpts";
    # exist_ok avoids the check-then-create race of the original code.
    os.makedirs(args.save_path, exist_ok=True)

    # Pick the next unused numeric suffix so repeated runs do not clobber
    # each other's logs.
    folder = '{}_{}_'.format(args.model_name, args.dataset)
    n = len([x for x in os.listdir(args.save_path) if x.startswith(folder)])
    folder += str(n)
    args.save_path = os.path.join(args.save_path, folder)

    os.makedirs(args.save_path, exist_ok=True)
    log_file = os.path.join(args.save_path, 'train.log')

    logging.basicConfig(
        format='%(asctime)s %(levelname)-8s %(message)s',
        level=logging.INFO,
        datefmt='%Y-%m-%d %H:%M:%S',
        filename=log_file,
        filemode='w'
    )

    logger = logging.getLogger(__name__)
    print("Logs are being recorded at: {}".format(log_file))
    return logger


def run(args, logger):
    """End-to-end driver: load data, build samplers, train, save, test.

    Pipeline:
      1. Load the dataset and derive runtime options on ``args``
         (``eval_filter``, ``neg_sample_size_test``, relation partitioning).
      2. Build train/valid/test samplers — one per process when
         ``args.num_proc > 1``.
      3. Load the model, train (optionally multi-process via ``train_mp``),
         optionally save embeddings + a JSON config, then run the test phase.

    Args:
        args: parsed ``ArgParser`` namespace; mutated in place with derived
            fields (``eval_filter``, ``num_workers``, ``num_test_proc``, ...).
        logger: logger from ``get_logger``; passed through to ``load_model``.
    """
    init_time_start = time.time()
    # load dataset and samplers
    dataset = get_dataset(args.data_path, args.dataset, args.format, args.data_files)

    # -1 is a sentinel meaning "rank against every entity".
    if args.neg_sample_size_test < 0:
        args.neg_sample_size_test = dataset.n_entities
    args.eval_filter = not args.no_eval_filter
    if args.neg_deg_sample_eval:
        assert not args.eval_filter, "if negative sampling based on degree, we can't filter positive edges."

    train_data = TrainDataset(dataset, args, ranks=args.num_proc)
    # if there is no cross partition relation, we fall back to strict_rel_part
    args.strict_rel_part = args.mix_cpu_gpu and (train_data.cross_part == False)
    args.soft_rel_part = args.mix_cpu_gpu and args.soft_rel_part and train_data.cross_part
    args.num_workers = 8 # fix num_worker to 8

    if args.num_proc > 1:
        # One bidirectional (head+tail corruption) iterator per worker process.
        train_samplers = []
        for i in range(args.num_proc):
            train_sampler_head = train_data.create_sampler(args.batch_size,
                                                           args.neg_sample_size,
                                                           args.neg_sample_size,
                                                           mode='head',
                                                           num_workers=args.num_workers,
                                                           shuffle=True,
                                                           exclude_positive=False,
                                                           rank=i)
            train_sampler_tail = train_data.create_sampler(args.batch_size,
                                                           args.neg_sample_size,
                                                           args.neg_sample_size,
                                                           mode='tail',
                                                           num_workers=args.num_workers,
                                                           shuffle=True,
                                                           exclude_positive=False,
                                                           rank=i)
            train_samplers.append(NewBidirectionalOneShotIterator(train_sampler_head, train_sampler_tail,
                                                                  args.neg_sample_size, args.neg_sample_size,
                                                                  True, dataset.n_entities))
        # FIX: the original additionally built an unused `train_sampler` from
        # the last rank's samplers here; the multi-process path only consumes
        # train_samplers[i], so that dead construction has been removed.
    else: # This is used for debug
        train_sampler_head = train_data.create_sampler(args.batch_size,
                                                       args.neg_sample_size,
                                                       args.neg_sample_size,
                                                       mode='head',
                                                       num_workers=args.num_workers,
                                                       shuffle=True,
                                                       exclude_positive=False)
        train_sampler_tail = train_data.create_sampler(args.batch_size,
                                                       args.neg_sample_size,
                                                       args.neg_sample_size,
                                                       mode='tail',
                                                       num_workers=args.num_workers,
                                                       shuffle=True,
                                                       exclude_positive=False)
        train_sampler = NewBidirectionalOneShotIterator(train_sampler_head, train_sampler_tail,
                                                        args.neg_sample_size, args.neg_sample_size,
                                                        True, dataset.n_entities)

    if args.valid or args.test:
        # Cap evaluation processes at the number of GPUs when several are given.
        if len(args.gpu) > 1:
            args.num_test_proc = args.num_proc if args.num_proc < len(args.gpu) else len(args.gpu)
        else:
            args.num_test_proc = args.num_proc
        eval_dataset = EvalDataset(dataset, args)

    if args.valid:
        if args.num_proc > 1:
            valid_sampler_heads = []
            valid_sampler_tails = []
            for i in range(args.num_proc):
                valid_sampler_head = eval_dataset.create_sampler('valid', args.batch_size_eval,
                                                                 args.neg_sample_size_valid,
                                                                 args.neg_sample_size_valid,
                                                                 args.eval_filter,
                                                                 mode='chunk-head',
                                                                 num_workers=args.num_workers,
                                                                 rank=i, ranks=args.num_proc)
                valid_sampler_tail = eval_dataset.create_sampler('valid', args.batch_size_eval,
                                                                 args.neg_sample_size_valid,
                                                                 args.neg_sample_size_valid,
                                                                 args.eval_filter,
                                                                 mode='chunk-tail',
                                                                 num_workers=args.num_workers,
                                                                 rank=i, ranks=args.num_proc)
                valid_sampler_heads.append(valid_sampler_head)
                valid_sampler_tails.append(valid_sampler_tail)
        else: # This is used for debug
            valid_sampler_head = eval_dataset.create_sampler('valid', args.batch_size_eval,
                                                             args.neg_sample_size_valid,
                                                             args.neg_sample_size_valid,
                                                             args.eval_filter,
                                                             mode='chunk-head',
                                                             num_workers=args.num_workers,
                                                             rank=0, ranks=1)
            valid_sampler_tail = eval_dataset.create_sampler('valid', args.batch_size_eval,
                                                             args.neg_sample_size_valid,
                                                             args.neg_sample_size_valid,
                                                             args.eval_filter,
                                                             mode='chunk-tail',
                                                             num_workers=args.num_workers,
                                                             rank=0, ranks=1)
    if args.test:
        if args.num_test_proc > 1:
            test_sampler_tails = []
            test_sampler_heads = []
            for i in range(args.num_test_proc):
                test_sampler_head = eval_dataset.create_sampler('test', args.batch_size_eval,
                                                                args.neg_sample_size_test,
                                                                args.neg_sample_size_test,
                                                                args.eval_filter,
                                                                mode='chunk-head',
                                                                num_workers=args.num_workers,
                                                                rank=i, ranks=args.num_test_proc)
                test_sampler_tail = eval_dataset.create_sampler('test', args.batch_size_eval,
                                                                args.neg_sample_size_test,
                                                                args.neg_sample_size_test,
                                                                args.eval_filter,
                                                                mode='chunk-tail',
                                                                num_workers=args.num_workers,
                                                                rank=i, ranks=args.num_test_proc)
                test_sampler_heads.append(test_sampler_head)
                test_sampler_tails.append(test_sampler_tail)
        else:
            test_sampler_head = eval_dataset.create_sampler('test', args.batch_size_eval,
                                                            args.neg_sample_size_test,
                                                            args.neg_sample_size_test,
                                                            args.eval_filter,
                                                            mode='chunk-head',
                                                            num_workers=args.num_workers,
                                                            rank=0, ranks=1)
            test_sampler_tail = eval_dataset.create_sampler('test', args.batch_size_eval,
                                                            args.neg_sample_size_test,
                                                            args.neg_sample_size_test,
                                                            args.eval_filter,
                                                            mode='chunk-tail',
                                                            num_workers=args.num_workers,
                                                            rank=0, ranks=1)

    # load model
    model = load_model(logger, args, dataset.n_entities, dataset.n_relations)
    if args.num_proc > 1 or args.async_update:
        # Embeddings must live in shared memory so worker processes see updates.
        model.share_memory()

    # We need to free all memory referenced by dataset.
    eval_dataset = None
    dataset = None

    print('Total initialize time {:.3f} seconds'.format(time.time() - init_time_start))

    # train
    start = time.time()
    rel_parts = train_data.rel_parts if args.strict_rel_part or args.soft_rel_part else None
    cross_rels = train_data.cross_rels if args.soft_rel_part else None
    if args.num_proc > 1:
        procs = []
        barrier = mp.Barrier(args.num_proc)
        for i in range(args.num_proc):
            valid_sampler = [valid_sampler_heads[i], valid_sampler_tails[i]] if args.valid else None
            proc = mp.Process(target=train_mp, args=(args,
                                                     model,
                                                     train_samplers[i],
                                                     valid_sampler,
                                                     i,
                                                     rel_parts,
                                                     cross_rels,
                                                     barrier))
            procs.append(proc)
            proc.start()
        for proc in procs:
            proc.join()
    else:
        valid_samplers = [valid_sampler_head, valid_sampler_tail] if args.valid else None
        train(args, model, train_sampler, valid_samplers, rel_parts=rel_parts)

    print('training takes {} seconds'.format(time.time() - start))

    if args.save_emb is not None:
        if not os.path.exists(args.save_emb):
            os.mkdir(args.save_emb)
        model.save_emb(args.save_emb, args.dataset)

        # We need to save the model configurations as well.
        conf_file = os.path.join(args.save_emb, 'config.json')
        with open(conf_file, 'w') as outfile:
            json.dump({'dataset': args.dataset,
                       'model': args.model_name,
                       'emb_size': args.hidden_dim,
                       'max_train_step': args.max_step,
                       'batch_size': args.batch_size,
                       'neg_sample_size': args.neg_sample_size,
                       'lr': args.lr,
                       'gamma': args.gamma,
                       'double_ent': args.double_ent,
                       'double_rel': args.double_rel,
                       'neg_adversarial_sampling': args.neg_adversarial_sampling,
                       'adversarial_temperature': args.adversarial_temperature,
                       'regularization_coef': args.regularization_coef,
                       'regularization_norm': args.regularization_norm},
                       outfile, indent=4)

    # test
    if args.test:
        start = time.time()
        if args.num_test_proc > 1:
            queue = mp.Queue(args.num_test_proc)
            procs = []
            for i in range(args.num_test_proc):
                proc = mp.Process(target=test_mp, args=(args,
                                                        model,
                                                        [test_sampler_heads[i], test_sampler_tails[i]],
                                                        i,
                                                        'Test',
                                                        queue))
                procs.append(proc)
                proc.start()

            # Drain one log list per worker from the queue, then average every
            # metric across all collected entries.  (FIX: dropped an unused
            # `total_metrics = {}` local from the original.)
            metrics = {}
            logs = []
            for i in range(args.num_test_proc):
                log = queue.get()
                logs = logs + log

            for metric in logs[0].keys():
                metrics[metric] = sum([log[metric] for log in logs]) / len(logs)
            for k, v in metrics.items():
                print('Test average {} : {}'.format(k, v))

            for proc in procs:
                proc.join()
        else:
            test(args, model, [test_sampler_head, test_sampler_tail])
        print('test:', time.time() - start)

if __name__ == '__main__':
    # Parse the command line, create the per-run log directory (mutates
    # args.save_path), then launch the full train/eval pipeline.
    args = ArgParser().parse_args()
    logger = get_logger(args)
    run(args, logger)