from dataloader import EvalDataset, TrainDataset, NewBidirectionalOneShotIterator
from dataloader import get_dataset

import argparse
import os
import logging
import time
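
# Choose the multiprocessing module and the backend-specific load_model/train/test
# implementations based on the DGLBACKEND environment variable (PyTorch by default).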

backend = os.environ.get('DGLBACKEND', 'pytorch')
if backend.lower() == 'mxnet':
    import multiprocessing as mp

    from train_mxnet import load_model
    from train_mxnet import train
    from train_mxnet import test
else:
    import torch.multiprocessing as mp

    from train_pytorch import load_model
    from train_pytorch import train
    from train_pytorch import test

class ArgParser(argparse.ArgumentParser):
    def __init__(self):
        super(ArgParser, self).__init__()

        self.add_argument('--model_name', default='TransE',
                          choices=['TransE', 'TransE_l1', 'TransE_l2', 'TransR',
                                   'RESCAL', 'DistMult', 'ComplEx', 'RotatE'],
                          help='model to use')
        self.add_argument('--data_path', type=str, default='data',
                          help='root path of all dataset')
        self.add_argument('--dataset', type=str, default='FB15k',
                          help='dataset name, under data_path')
        self.add_argument('--format', type=str, default='1',
                          help='the format of the dataset.')
        self.add_argument('--save_path', type=str, default='ckpts',
                          help='place to save models and logs')
        self.add_argument('--save_emb', type=str, default=None,
                          help='save the embeddings to the specified location.')

        self.add_argument('--max_step', type=int, default=80000,
                          help='maximum number of training steps')
        self.add_argument('--warm_up_step', type=int, default=None,
                          help='step at which learning rate decay starts')
        self.add_argument('--batch_size', type=int, default=1024,
                          help='batch size')
        self.add_argument('--batch_size_eval', type=int, default=8,
                          help='batch size used for eval and test')
        self.add_argument('--neg_sample_size', type=int, default=128,
                          help='negative sampling size')
        self.add_argument('--neg_chunk_size', type=int, default=-1,
                          help='chunk size of the negative edges.')
        self.add_argument('--neg_deg_sample', action='store_true',
                          help='sample negative edges proportionally to vertex degree during training')
        self.add_argument('--neg_deg_sample_eval', action='store_true',
                          help='sample negative edges proportionally to vertex degree during evaluation')
        self.add_argument('--neg_sample_size_valid', type=int, default=1000,
                          help='negative sampling size for validation')
        self.add_argument('--neg_chunk_size_valid', type=int, default=-1,
                          help='chunk size of the negative edges for validation.')
        self.add_argument('--neg_sample_size_test', type=int, default=-1,
                          help='negative sampling size for testing (-1 uses all entities)')
        self.add_argument('--neg_chunk_size_test', type=int, default=-1,
                          help='chunk size of the negative edges for testing.')
        self.add_argument('--hidden_dim', type=int, default=256,
                          help='hidden dim used by relation and entity')
        self.add_argument('--lr', type=float, default=0.0001,
                          help='learning rate')
        self.add_argument('-g', '--gamma', type=float, default=12.0,
                          help='margin value')
        self.add_argument('--eval_percent', type=float, default=1,
                          help='percentage of edges sampled for evaluation.')
        self.add_argument('--no_eval_filter', action='store_true',
                          help='do not filter positive edges among negative edges for evaluation')

        self.add_argument('--gpu', type=int, default=[-1], nargs='+',
                          help='a list of active gpu ids, e.g. 0 1 2 4; -1 means CPU')
        self.add_argument('--mix_cpu_gpu', action='store_true',
                          help='mix CPU and GPU training')
        self.add_argument('-de', '--double_ent', action='store_true',
                          help='double the entity dim for complex numbers')
        self.add_argument('-dr', '--double_rel', action='store_true',
                          help='double the relation dim for complex numbers')
        self.add_argument('--seed', type=int, default=0,
                          help='set random seed for reproducibility')
        self.add_argument('-log', '--log_interval', type=int, default=1000,
                          help='print training logs every x steps')
        self.add_argument('--eval_interval', type=int, default=10000,
                          help='do evaluation after every x steps')
        self.add_argument('-adv', '--neg_adversarial_sampling', action='store_true',
                          help='use negative adversarial sampling')
        self.add_argument('-a', '--adversarial_temperature', default=1.0, type=float,
                          help='temperature used by negative adversarial sampling')

        self.add_argument('--valid', action='store_true',
                          help='evaluate the model on the validation set during training')
        self.add_argument('--test', action='store_true',
                          help='evaluate the model on the test set after training')
        self.add_argument('-rc', '--regularization_coef', type=float, default=0.000002,
                          help='set a value > 0.0 to enable regularization')
        self.add_argument('-rn', '--regularization_norm', type=int, default=3,
                          help='norm used in regularization')
        self.add_argument('--num_worker', type=int, default=16,
                          help='number of workers used for loading data')
        self.add_argument('--non_uni_weight', action='store_true',
                          help='use non-uniform edge weights when computing loss')
        self.add_argument('--init_step', type=int, default=0,
                          help='DO NOT SET MANUALLY; used when resuming training')
        self.add_argument('--step', type=int, default=0,
                          help='DO NOT SET MANUALLY; tracks the current step')
        self.add_argument('--pickle_graph', action='store_true',
                          help='pickle the built graph; building a huge graph can be slow')
        self.add_argument('--num_proc', type=int, default=1,
                          help='number of processes used')
        self.add_argument('--rel_part', action='store_true',
                          help='enable relation partitioning')
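
# Example invocation (illustrative values only; all flags are defined above):
#   DGLBACKEND=pytorch python train.py --model_name TransE_l2 --dataset FB15k \
#       --batch_size 1024 --neg_sample_size 256 --hidden_dim 400 --gamma 19.9 \
#       --lr 0.25 --max_step 32000 --gpu 0 --valid --test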


def get_logger(args):
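    """Create a unique save directory ({model}_{dataset}_{n}) and set up file logging in it."""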
    if not os.path.exists(args.save_path):
        os.mkdir(args.save_path)

    folder = '{}_{}_'.format(args.model_name, args.dataset)
    n = len([x for x in os.listdir(args.save_path) if x.startswith(folder)])
    folder += str(n)
    args.save_path = os.path.join(args.save_path, folder)

    if not os.path.exists(args.save_path):
        os.makedirs(args.save_path)
    log_file = os.path.join(args.save_path, 'train.log')

    logging.basicConfig(
        format='%(asctime)s %(levelname)-8s %(message)s',
        level=logging.INFO,
        datefmt='%Y-%m-%d %H:%M:%S',
        filename=log_file,
        filemode='w'
    )

    logger = logging.getLogger(__name__)
    print("Logs are being recorded at: {}".format(log_file))
    return logger


def run(args, logger):
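    """Build the dataset, samplers, and model, then run training and optional evaluation."""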
    # load dataset and samplers
    dataset = get_dataset(args.data_path, args.dataset, args.format)
    n_entities = dataset.n_entities
    n_relations = dataset.n_relations
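    # A negative sample size of -1 for testing means ranking against all entities.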
    if args.neg_sample_size_test < 0:
        args.neg_sample_size_test = n_entities
    args.eval_filter = not args.no_eval_filter
    if args.neg_deg_sample_eval:
        assert not args.eval_filter, "if negative sampling is based on degree, we cannot filter positive edges."

    # When we generate a batch of negative edges from a set of positive edges,
    # we first divide the positive edges into chunks and corrupt the edges in a chunk
    # together. By default, the chunk size is equal to the negative sample size.
    # Usually, this works well. But we also allow users to specify the chunk size themselves.
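    # For example, with the defaults batch_size=1024 and neg_sample_size=128,
    # each batch is split into 1024 / 128 = 8 chunks, and the 128 positive
    # edges in a chunk share the same set of corrupted entities.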
    if args.neg_chunk_size < 0:
        args.neg_chunk_size = args.neg_sample_size
    if args.neg_chunk_size_valid < 0:
        args.neg_chunk_size_valid = args.neg_sample_size_valid
    if args.neg_chunk_size_test < 0:
        args.neg_chunk_size_test = args.neg_sample_size_test

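    # Each training process gets its own bidirectional iterator that alternates
    # between batches with corrupted heads and batches with corrupted tails.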
    train_data = TrainDataset(dataset, args, ranks=args.num_proc)
    if args.num_proc > 1:
        train_samplers = []
        for i in range(args.num_proc):
            train_sampler_head = train_data.create_sampler(args.batch_size, args.neg_sample_size,
                                                           args.neg_chunk_size,
                                                           mode='chunk-head',
                                                           num_workers=args.num_worker,
                                                           shuffle=True,
                                                           exclude_positive=True,
                                                           rank=i)
            train_sampler_tail = train_data.create_sampler(args.batch_size, args.neg_sample_size,
                                                           args.neg_chunk_size,
                                                           mode='chunk-tail',
                                                           num_workers=args.num_worker,
                                                           shuffle=True,
                                                           exclude_positive=True,
                                                           rank=i)
            train_samplers.append(NewBidirectionalOneShotIterator(train_sampler_head, train_sampler_tail,
                                                                  args.neg_chunk_size,
                                                                  True, n_entities))
    else:
        train_sampler_head = train_data.create_sampler(args.batch_size, args.neg_sample_size,
                                                       args.neg_chunk_size,
                                                       mode='chunk-head',
                                                       num_workers=args.num_worker,
                                                       shuffle=True,
                                                       exclude_positive=True)
        train_sampler_tail = train_data.create_sampler(args.batch_size, args.neg_sample_size,
                                                       args.neg_chunk_size,
                                                       mode='chunk-tail',
                                                       num_workers=args.num_worker,
                                                       shuffle=True,
                                                       exclude_positive=True)
        train_sampler = NewBidirectionalOneShotIterator(train_sampler_head, train_sampler_tail,
                                                        args.neg_chunk_size,
                                                        True, n_entities)

    # for multiprocessing evaluation, we don't need to sample multiple batches at a time
    # in each process.
    num_workers = args.num_worker
    if args.num_proc > 1:
        num_workers = 1
    if args.valid or args.test:
        eval_dataset = EvalDataset(dataset, args)
    if args.valid:
        # Here we use the regular negative sampler because we need to ensure that
        # all positive edges are excluded.
        if args.num_proc > 1:
            valid_sampler_heads = []
            valid_sampler_tails = []
            for i in range(args.num_proc):
                valid_sampler_head = eval_dataset.create_sampler('valid', args.batch_size_eval,
                                                                 args.neg_sample_size_valid,
                                                                 args.neg_chunk_size_valid,
                                                                 args.eval_filter,
                                                                 mode='chunk-head',
                                                                 num_workers=num_workers,
                                                                 rank=i, ranks=args.num_proc)
                valid_sampler_tail = eval_dataset.create_sampler('valid', args.batch_size_eval,
                                                                 args.neg_sample_size_valid,
                                                                 args.neg_chunk_size_valid,
                                                                 args.eval_filter,
                                                                 mode='chunk-tail',
                                                                 num_workers=num_workers,
                                                                 rank=i, ranks=args.num_proc)
                valid_sampler_heads.append(valid_sampler_head)
                valid_sampler_tails.append(valid_sampler_tail)
        else:
            valid_sampler_head = eval_dataset.create_sampler('valid', args.batch_size_eval,
                                                             args.neg_sample_size_valid,
                                                             args.neg_chunk_size_valid,
                                                             args.eval_filter,
                                                             mode='chunk-head',
                                                             num_workers=num_workers,
                                                             rank=0, ranks=1)
            valid_sampler_tail = eval_dataset.create_sampler('valid', args.batch_size_eval,
                                                             args.neg_sample_size_valid,
                                                             args.neg_chunk_size_valid,
                                                             args.eval_filter,
                                                             mode='chunk-tail',
                                                             num_workers=num_workers,
                                                             rank=0, ranks=1)
    if args.test:
        # Here we use the regular negative sampler because we need to ensure that
        # all positive edges are excluded.
        if args.num_proc > 1:
            test_sampler_heads = []
            test_sampler_tails = []
            for i in range(args.num_proc):
                test_sampler_head = eval_dataset.create_sampler('test', args.batch_size_eval,
                                                                args.neg_sample_size_test,
                                                                args.neg_chunk_size_test,
                                                                args.eval_filter,
                                                                mode='chunk-head',
                                                                num_workers=num_workers,
                                                                rank=i, ranks=args.num_proc)
                test_sampler_tail = eval_dataset.create_sampler('test', args.batch_size_eval,
                                                                args.neg_sample_size_test,
                                                                args.neg_chunk_size_test,
                                                                args.eval_filter,
                                                                mode='chunk-tail',
                                                                num_workers=num_workers,
                                                                rank=i, ranks=args.num_proc)
                test_sampler_heads.append(test_sampler_head)
                test_sampler_tails.append(test_sampler_tail)
        else:
            test_sampler_head = eval_dataset.create_sampler('test', args.batch_size_eval,
                                                            args.neg_sample_size_test,
                                                            args.neg_chunk_size_test,
                                                            args.eval_filter,
                                                            mode='chunk-head',
                                                            num_workers=num_workers,
                                                            rank=0, ranks=1)
            test_sampler_tail = eval_dataset.create_sampler('test', args.batch_size_eval,
                                                            args.neg_sample_size_test,
                                                            args.neg_chunk_size_test,
                                                            args.eval_filter,
                                                            mode='chunk-tail',
                                                            num_workers=num_workers,
                                                            rank=0, ranks=1)

    # Release the dataset references so the memory they hold can be freed.
    eval_dataset = None
    dataset = None
    # load model
    model = load_model(logger, args, n_entities, n_relations)

    if args.num_proc > 1:
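        # Move the model parameters into shared memory so that all training
        # processes update a single copy of the embeddings.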
        model.share_memory()

    # train
    start = time.time()
    if args.num_proc > 1:
        procs = []
        for i in range(args.num_proc):
            rel_parts = train_data.rel_parts if args.rel_part else None
            valid_samplers = [valid_sampler_heads[i], valid_sampler_tails[i]] if args.valid else None
            proc = mp.Process(target=train, args=(args, model, train_samplers[i], i, rel_parts, valid_samplers))
            procs.append(proc)
            proc.start()
        for proc in procs:
            proc.join()
    else:
        valid_samplers = [valid_sampler_head, valid_sampler_tail] if args.valid else None
        train(args, model, train_sampler, valid_samplers=valid_samplers)
    print('training takes {} seconds'.format(time.time() - start))

    if args.save_emb is not None:
        if not os.path.exists(args.save_emb):
            os.mkdir(args.save_emb)
        model.save_emb(args.save_emb, args.dataset)

    # test
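    # In the multiprocess case, each test process reports its metrics through a
    # queue and the parent averages them across processes.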
    if args.test:
        start = time.time()
        if args.num_proc > 1:
            queue = mp.Queue(args.num_proc)
            procs = []
            for i in range(args.num_proc):
                proc = mp.Process(target=test, args=(args, model, [test_sampler_heads[i], test_sampler_tails[i]],
                                                     i, 'Test', queue))
                procs.append(proc)
                proc.start()

            total_metrics = {}
            for i in range(args.num_proc):
                metrics = queue.get()
                for k, v in metrics.items():
                    if i == 0:
                        total_metrics[k] = v / args.num_proc
                    else:
                        total_metrics[k] += v / args.num_proc
            for k, v in total_metrics.items():
                print('Test average {} at [{}/{}]: {}'.format(k, args.step, args.max_step, v))

            for proc in procs:
                proc.join()
        else:
            test(args, model, [test_sampler_head, test_sampler_tail])
        print('testing takes {} seconds'.format(time.time() - start))

if __name__ == '__main__':
    args = ArgParser().parse_args()
    logger = get_logger(args)
    run(args, logger)