# train.py
# -*- coding: utf-8 -*-
#
# train.py
#
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

20
21
22
23
24
25
26
from dataloader import EvalDataset, TrainDataset, NewBidirectionalOneShotIterator
from dataloader import get_dataset

import argparse
import os
import logging
import time
27
import json
28

29
30
from utils import get_compatible_batch_size

Da Zheng's avatar
Da Zheng committed
31
backend = os.environ.get('DGLBACKEND', 'pytorch')
32
if backend.lower() == 'mxnet':
33
    import multiprocessing as mp
34
35
36
37
    from train_mxnet import load_model
    from train_mxnet import train
    from train_mxnet import test
else:
38
    import torch.multiprocessing as mp
39
    from train_pytorch import load_model
40
41
    from train_pytorch import train, train_mp
    from train_pytorch import test, test_mp
42
43
44
45
46
47

class ArgParser(argparse.ArgumentParser):
    """Command-line argument parser for knowledge-graph embedding training.

    Option groups: model/dataset selection, checkpointing, training
    hyper-parameters, negative sampling, evaluation, and device/
    multi-process partitioning controls.
    """

    def __init__(self):
        super(ArgParser, self).__init__()

        # --- model and dataset ---
        self.add_argument('--model_name', default='TransE',
                          choices=['TransE', 'TransE_l1', 'TransE_l2', 'TransR',
                                   'RESCAL', 'DistMult', 'ComplEx', 'RotatE'],
                          help='model to use')
        self.add_argument('--data_path', type=str, default='data',
                          help='root path of all dataset')
        self.add_argument('--dataset', type=str, default='FB15k',
                          help='dataset name, under data_path')
        self.add_argument('--format', type=str, default='built_in',
                          help='the format of the dataset, it can be built_in,'\
                                'raw_udd_{htr} and udd_{htr}')
        self.add_argument('--data_files', type=str, default=None, nargs='+',
                          help='a list of data files, e.g. entity relation train valid test')
        # --- checkpointing ---
        self.add_argument('--save_path', type=str, default='ckpts',
                          help='place to save models and logs')
        self.add_argument('--save_emb', type=str, default=None,
                          help='save the embeddings in the specific location.')
        # --- training hyper-parameters ---
        self.add_argument('--max_step', type=int, default=80000,
                          help='train xx steps')
        self.add_argument('--batch_size', type=int, default=1024,
                          help='batch size')
        self.add_argument('--batch_size_eval', type=int, default=8,
                          help='batch size used for eval and test')
        self.add_argument('--neg_sample_size', type=int, default=128,
                          help='negative sampling size')
        self.add_argument('--neg_deg_sample', action='store_true',
                          help='negative sample proportional to vertex degree in the training')
        self.add_argument('--neg_deg_sample_eval', action='store_true',
                          help='negative sampling proportional to vertex degree in the evaluation')
        self.add_argument('--neg_sample_size_eval', type=int, default=-1,
                          help='negative sampling size for evaluation')
        self.add_argument('--eval_percent', type=float, default=1,
                          help='sample some percentage for evaluation.')
        self.add_argument('--hidden_dim', type=int, default=256,
                          help='hidden dim used by relation and entity')
        self.add_argument('--lr', type=float, default=0.0001,
                          help='learning rate')
        self.add_argument('-g', '--gamma', type=float, default=12.0,
                          help='margin value')
        self.add_argument('--no_eval_filter', action='store_true',
                          help='do not filter positive edges among negative edges for evaluation')
        self.add_argument('--gpu', type=int, default=[-1], nargs='+',
                          help='a list of active gpu ids, e.g. 0 1 2 4')
        self.add_argument('--mix_cpu_gpu', action='store_true',
                          help='mix CPU and GPU training')
        self.add_argument('-de', '--double_ent', action='store_true',
                          help='double entity dim for complex number')
        self.add_argument('-dr', '--double_rel', action='store_true',
                          help='double relation dim for complex number')
        # BUG FIX: -log's help text was a copy-paste of --eval_interval's
        # ('do evaluation after every x steps'); it controls logging frequency.
        self.add_argument('-log', '--log_interval', type=int, default=1000,
                          help='print training progress every x steps')
        self.add_argument('--eval_interval', type=int, default=10000,
                          help='do evaluation after every x steps')
        self.add_argument('-adv', '--neg_adversarial_sampling', action='store_true',
                          help='if use negative adversarial sampling')
        self.add_argument('-a', '--adversarial_temperature', default=1.0, type=float,
                          help='adversarial_temperature')
        self.add_argument('--valid', action='store_true',
                          help='if valid a model')
        self.add_argument('--test', action='store_true',
                          help='if test a model')
        self.add_argument('-rc', '--regularization_coef', type=float, default=0.000002,
                          help='set value > 0.0 if regularization is used')
        self.add_argument('-rn', '--regularization_norm', type=int, default=3,
                          help='norm used in regularization')
        # BUG FIX: help said 'uniform weight', contradicting the flag name;
        # enabling the flag switches to non-uniform loss weighting.
        self.add_argument('--non_uni_weight', action='store_true',
                          help='if use non-uniform weight when computing loss')
        self.add_argument('--pickle_graph', action='store_true',
                          help='pickle built graph, building a huge graph is slow.')
        # --- multi-process / partitioning controls ---
        self.add_argument('--num_proc', type=int, default=1,
                          help='number of process used')
        self.add_argument('--num_thread', type=int, default=1,
                          help='number of thread used')
        self.add_argument('--rel_part', action='store_true',
                          help='enable relation partitioning')
        self.add_argument('--soft_rel_part', action='store_true',
                          help='enable soft relation partition')
        self.add_argument('--async_update', action='store_true',
                          help='allow async_update on node embedding')
        self.add_argument('--force_sync_interval', type=int, default=-1,
                          help='We force a synchronization between processes every x steps')


def get_logger(args):
    """Create a fresh run directory under ``args.save_path`` and return a
    logger writing to ``train.log`` inside it.

    ``args.save_path`` is mutated in place to point at the new run folder,
    named ``<model_name>_<dataset>_<n>`` where ``n`` is the count of existing
    folders with the same prefix.
    """
    if not os.path.exists(args.save_path):
        os.mkdir(args.save_path)

    prefix = '{}_{}_'.format(args.model_name, args.dataset)
    # Number this run after the existing runs sharing the same prefix.
    run_id = sum(1 for entry in os.listdir(args.save_path)
                 if entry.startswith(prefix))
    args.save_path = os.path.join(args.save_path, prefix + str(run_id))

    if not os.path.exists(args.save_path):
        os.makedirs(args.save_path)
    log_file = os.path.join(args.save_path, 'train.log')

    logging.basicConfig(
        format='%(asctime)s %(levelname)-8s %(message)s',
        level=logging.INFO,
        datefmt='%Y-%m-%d %H:%M:%S',
        filename=log_file,
        filemode='w'
    )

    print("Logs are being recorded at: {}".format(log_file))
    return logging.getLogger(__name__)


def run(args, logger):
    """Train (and optionally validate/test) a KGE model according to ``args``.

    Builds the dataset and samplers, loads the model, then either trains in
    the current process or fans out to ``args.num_proc`` worker processes.
    Optionally saves embeddings plus a ``config.json`` and runs the test set.
    """
    init_time_start = time.time()
    # load dataset and samplers
    dataset = get_dataset(args.data_path, args.dataset, args.format, args.data_files)

    if args.neg_sample_size_eval < 0:
        # -1 means "rank against all entities" during evaluation.
        args.neg_sample_size_eval = dataset.n_entities
    # Batch sizes must be compatible with the (chunked) negative sample sizes.
    args.batch_size = get_compatible_batch_size(args.batch_size, args.neg_sample_size)
    args.batch_size_eval = get_compatible_batch_size(args.batch_size_eval, args.neg_sample_size_eval)

    args.eval_filter = not args.no_eval_filter
    if args.neg_deg_sample_eval:
        assert not args.eval_filter, "if negative sampling based on degree, we can't filter positive edges."

    train_data = TrainDataset(dataset, args, ranks=args.num_proc)
    # if there is no cross partition relation, we fall back to strict_rel_part
    args.strict_rel_part = args.mix_cpu_gpu and not train_data.cross_part
    args.soft_rel_part = args.mix_cpu_gpu and args.soft_rel_part and train_data.cross_part
    args.num_workers = 8 # fix num_worker to 8

    if args.num_proc > 1:
        # One bidirectional (head+tail corruption) iterator per worker process.
        train_samplers = []
        for i in range(args.num_proc):
            train_sampler_head = train_data.create_sampler(args.batch_size,
                                                           args.neg_sample_size,
                                                           args.neg_sample_size,
                                                           mode='head',
                                                           num_workers=args.num_workers,
                                                           shuffle=True,
                                                           exclude_positive=False,
                                                           rank=i)
            train_sampler_tail = train_data.create_sampler(args.batch_size,
                                                           args.neg_sample_size,
                                                           args.neg_sample_size,
                                                           mode='tail',
                                                           num_workers=args.num_workers,
                                                           shuffle=True,
                                                           exclude_positive=False,
                                                           rank=i)
            train_samplers.append(NewBidirectionalOneShotIterator(train_sampler_head, train_sampler_tail,
                                                                  args.neg_sample_size, args.neg_sample_size,
                                                                  True, dataset.n_entities))
        # BUG FIX: the original also built an extra `train_sampler` here from
        # the last head/tail pair; it was dead code in the multi-process path
        # (train_mp only consumes train_samplers[i]) and has been removed.
    else: # This is used for debug
        train_sampler_head = train_data.create_sampler(args.batch_size,
                                                       args.neg_sample_size,
                                                       args.neg_sample_size,
                                                       mode='head',
                                                       num_workers=args.num_workers,
                                                       shuffle=True,
                                                       exclude_positive=False)
        train_sampler_tail = train_data.create_sampler(args.batch_size,
                                                       args.neg_sample_size,
                                                       args.neg_sample_size,
                                                       mode='tail',
                                                       num_workers=args.num_workers,
                                                       shuffle=True,
                                                       exclude_positive=False)
        train_sampler = NewBidirectionalOneShotIterator(train_sampler_head, train_sampler_tail,
                                                        args.neg_sample_size, args.neg_sample_size,
                                                        True, dataset.n_entities)

    if args.valid or args.test:
        # Cap the number of test processes by the number of GPUs in use.
        if len(args.gpu) > 1:
            args.num_test_proc = args.num_proc if args.num_proc < len(args.gpu) else len(args.gpu)
        else:
            args.num_test_proc = args.num_proc
        eval_dataset = EvalDataset(dataset, args)

    if args.valid:
        if args.num_proc > 1:
            valid_sampler_heads = []
            valid_sampler_tails = []
            for i in range(args.num_proc):
                valid_sampler_head = eval_dataset.create_sampler('valid', args.batch_size_eval,
                                                                 args.neg_sample_size_eval,
                                                                 args.neg_sample_size_eval,
                                                                 args.eval_filter,
                                                                 mode='chunk-head',
                                                                 num_workers=args.num_workers,
                                                                 rank=i, ranks=args.num_proc)
                valid_sampler_tail = eval_dataset.create_sampler('valid', args.batch_size_eval,
                                                                 args.neg_sample_size_eval,
                                                                 args.neg_sample_size_eval,
                                                                 args.eval_filter,
                                                                 mode='chunk-tail',
                                                                 num_workers=args.num_workers,
                                                                 rank=i, ranks=args.num_proc)
                valid_sampler_heads.append(valid_sampler_head)
                valid_sampler_tails.append(valid_sampler_tail)
        else: # This is used for debug
            valid_sampler_head = eval_dataset.create_sampler('valid', args.batch_size_eval,
                                                             args.neg_sample_size_eval,
                                                             args.neg_sample_size_eval,
                                                             args.eval_filter,
                                                             mode='chunk-head',
                                                             num_workers=args.num_workers,
                                                             rank=0, ranks=1)
            valid_sampler_tail = eval_dataset.create_sampler('valid', args.batch_size_eval,
                                                             args.neg_sample_size_eval,
                                                             args.neg_sample_size_eval,
                                                             args.eval_filter,
                                                             mode='chunk-tail',
                                                             num_workers=args.num_workers,
                                                             rank=0, ranks=1)
    if args.test:
        if args.num_test_proc > 1:
            test_sampler_tails = []
            test_sampler_heads = []
            for i in range(args.num_test_proc):
                test_sampler_head = eval_dataset.create_sampler('test', args.batch_size_eval,
                                                                args.neg_sample_size_eval,
                                                                args.neg_sample_size_eval,
                                                                args.eval_filter,
                                                                mode='chunk-head',
                                                                num_workers=args.num_workers,
                                                                rank=i, ranks=args.num_test_proc)
                test_sampler_tail = eval_dataset.create_sampler('test', args.batch_size_eval,
                                                                args.neg_sample_size_eval,
                                                                args.neg_sample_size_eval,
                                                                args.eval_filter,
                                                                mode='chunk-tail',
                                                                num_workers=args.num_workers,
                                                                rank=i, ranks=args.num_test_proc)
                test_sampler_heads.append(test_sampler_head)
                test_sampler_tails.append(test_sampler_tail)
        else:
            test_sampler_head = eval_dataset.create_sampler('test', args.batch_size_eval,
                                                            args.neg_sample_size_eval,
                                                            args.neg_sample_size_eval,
                                                            args.eval_filter,
                                                            mode='chunk-head',
                                                            num_workers=args.num_workers,
                                                            rank=0, ranks=1)
            test_sampler_tail = eval_dataset.create_sampler('test', args.batch_size_eval,
                                                            args.neg_sample_size_eval,
                                                            args.neg_sample_size_eval,
                                                            args.eval_filter,
                                                            mode='chunk-tail',
                                                            num_workers=args.num_workers,
                                                            rank=0, ranks=1)

    # load model
    model = load_model(logger, args, dataset.n_entities, dataset.n_relations)
    if args.num_proc > 1 or args.async_update:
        # Worker processes and async updaters share the embedding tensors.
        model.share_memory()

    # We need to free all memory referenced by dataset.
    eval_dataset = None
    dataset = None

    print('Total initialize time {:.3f} seconds'.format(time.time() - init_time_start))

    # train
    start = time.time()
    rel_parts = train_data.rel_parts if args.strict_rel_part or args.soft_rel_part else None
    cross_rels = train_data.cross_rels if args.soft_rel_part else None
    if args.num_proc > 1:
        procs = []
        barrier = mp.Barrier(args.num_proc)
        for i in range(args.num_proc):
            valid_sampler = [valid_sampler_heads[i], valid_sampler_tails[i]] if args.valid else None
            proc = mp.Process(target=train_mp, args=(args,
                                                     model,
                                                     train_samplers[i],
                                                     valid_sampler,
                                                     i,
                                                     rel_parts,
                                                     cross_rels,
                                                     barrier))
            procs.append(proc)
            proc.start()
        for proc in procs:
            proc.join()
    else:
        valid_samplers = [valid_sampler_head, valid_sampler_tail] if args.valid else None
        train(args, model, train_sampler, valid_samplers, rel_parts=rel_parts)

    print('training takes {} seconds'.format(time.time() - start))

    if args.save_emb is not None:
        if not os.path.exists(args.save_emb):
            os.mkdir(args.save_emb)
        model.save_emb(args.save_emb, args.dataset)

        # We need to save the model configurations as well.
        conf_file = os.path.join(args.save_emb, 'config.json')
        with open(conf_file, 'w') as outfile:
            json.dump({'dataset': args.dataset,
                       'model': args.model_name,
                       'emb_size': args.hidden_dim,
                       'max_train_step': args.max_step,
                       'batch_size': args.batch_size,
                       'neg_sample_size': args.neg_sample_size,
                       'lr': args.lr,
                       'gamma': args.gamma,
                       'double_ent': args.double_ent,
                       'double_rel': args.double_rel,
                       'neg_adversarial_sampling': args.neg_adversarial_sampling,
                       'adversarial_temperature': args.adversarial_temperature,
                       'regularization_coef': args.regularization_coef,
                       'regularization_norm': args.regularization_norm},
                      outfile, indent=4)

    # test
    if args.test:
        start = time.time()
        if args.num_test_proc > 1:
            # Workers push their per-triple result logs through this queue.
            queue = mp.Queue(args.num_test_proc)
            procs = []
            for i in range(args.num_test_proc):
                proc = mp.Process(target=test_mp, args=(args,
                                                        model,
                                                        [test_sampler_heads[i], test_sampler_tails[i]],
                                                        i,
                                                        'Test',
                                                        queue))
                procs.append(proc)
                proc.start()

            # Average each metric over all logs gathered from the workers.
            metrics = {}
            logs = []
            for _ in range(args.num_test_proc):
                logs = logs + queue.get()

            for metric in logs[0].keys():
                metrics[metric] = sum([log[metric] for log in logs]) / len(logs)
            for k, v in metrics.items():
                print('Test average {} : {}'.format(k, v))

            for proc in procs:
                proc.join()
        else:
            test(args, model, [test_sampler_head, test_sampler_tail])
        print('testing takes {:.3f} seconds'.format(time.time() - start))
# Entry point: parse CLI arguments, set up per-run file logging, then train.
if __name__ == '__main__':
    args = ArgParser().parse_args()
    logger = get_logger(args)
    run(args, logger)