# -*- coding: utf-8 -*-
#
# train_pytorch.py
#
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import logging
import time
from distutils.version import LooseVersion
from functools import wraps

import torch as th
import torch.multiprocessing as mp
import torch.optim as optim
from torch.utils.data import DataLoader

TH_VERSION = LooseVersion(th.__version__)
if TH_VERSION.version[0] == 1 and TH_VERSION.version[1] < 2:
    raise Exception("DGL-KE requires PyTorch >= 1.2")

import dgl
import dgl.backend as F
from dgl.contrib import KVClient

from models import KEModel
from models.pytorch.tensor_models import thread_wrapped_func

from dataloader import EvalDataset
from dataloader import get_dataset


class KGEClient(KVClient):
    """User-defined kvclient for DGL-KGE
    """
    def _push_handler(self, name, ID, data, target):
        """Row-Sparse Adagrad updater

        Accumulates the squared gradient per embedding row, then applies the
        scaled update only to the rows indexed by ID.
        """
        original_name = name[0:-6]  # strip the '-data-' suffix
        state_sum = target[original_name+'_state-data-']
        grad_sum = (data * data).mean(1)
        state_sum.index_add_(0, ID, grad_sum)
        std = state_sum[ID]  # _sparse_mask
        std_values = std.sqrt_().add_(1e-10).unsqueeze(1)
        tmp = (-self.clr * data / std_values)
        target[name].index_add_(0, ID, tmp)


    def set_clr(self, learning_rate):
        """Set learning rate
        """
        self.clr = learning_rate


    def set_local2global(self, l2g):
        """Set the local-to-global ID mapping
        """
        self._l2g = l2g


    def get_local2global(self):
        """Get the local-to-global ID mapping
        """
        return self._l2g
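
# A minimal sketch (not part of DGL-KE's API) of the row-sparse Adagrad step
# that KGEClient._push_handler performs on the kvstore server, written against
# plain tensors so the update rule can be checked in isolation; `emb`, `state`,
# `ids`, `grad`, and `lr` are illustrative names.
def _sparse_adagrad_step_sketch(emb, state, ids, grad, lr):
    """Apply one Adagrad update to the rows of `emb` selected by `ids`."""
    grad_sum = (grad * grad).mean(1)      # per-row squared-gradient summary
    state.index_add_(0, ids, grad_sum)    # accumulate Adagrad state for those rows
    std = state[ids].sqrt_().add_(1e-10)  # advanced indexing copies, so in-place ops are safe
    emb.index_add_(0, ids, -lr * grad / std.unsqueeze(1))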


def connect_to_kvstore(args, entity_pb, relation_pb, l2g):
    """Create kvclient and connect to kvstore service
    """
    server_namebook = dgl.contrib.read_ip_config(filename=args.ip_config)

    my_client = KGEClient(server_namebook=server_namebook)

    my_client.set_clr(args.lr)

    my_client.connect()

    if my_client.get_id() % args.num_client == 0:
        # only the first client in each group of num_client registers
        # the full partition book
        my_client.set_partition_book(name='entity_emb', partition_book=entity_pb)
        my_client.set_partition_book(name='relation_emb', partition_book=relation_pb)
    else:
        my_client.set_partition_book(name='entity_emb')
        my_client.set_partition_book(name='relation_emb')

    my_client.set_local2global(l2g)

    return my_client
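
# A hedged usage sketch: `args.ip_config` names a text file describing the
# kvstore servers (see dgl.contrib.read_ip_config for the per-line format of
# your DGL version). Each trainer process then does roughly:
#
#     client = connect_to_kvstore(args, entity_pb, relation_pb, l2g)
#     client.barrier()  # wait until every trainer is connected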


def load_model(logger, args, n_entities, n_relations, ckpt=None):
    model = KEModel(args, args.model_name, n_entities, n_relations,
                    args.hidden_dim, args.gamma,
                    double_entity_emb=args.double_ent, double_relation_emb=args.double_rel)
    if ckpt is not None:
        assert False, "We do not support loading model embeddings for the general Embedding"
    return model


def load_model_from_checkpoint(logger, args, n_entities, n_relations, ckpt_path):
    model = load_model(logger, args, n_entities, n_relations)
    model.load_emb(ckpt_path, args.dataset)
    return model
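
# Hypothetical usage, assuming embeddings were saved earlier with save_emb
# (resuming training from a checkpoint is unsupported, as load_model asserts):
#
#     model = load_model_from_checkpoint(logger, args, n_entities,
#                                        n_relations, args.save_emb)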

def train(args, model, train_sampler, valid_samplers=None, rank=0, rel_parts=None, cross_rels=None, barrier=None, client=None):
    logs = []
    for arg in vars(args):
        logging.info('{:20}:{}'.format(arg, getattr(args, arg)))

    # bind this trainer to a device: with mix_cpu_gpu and multiple processes,
    # round-robin the ranks over the available GPUs; otherwise use the first
    # GPU, or the CPU (gpu_id = -1) when none is given
    if len(args.gpu) > 0:
        gpu_id = args.gpu[rank % len(args.gpu)] if args.mix_cpu_gpu and args.num_proc > 1 else args.gpu[0]
    else:
        gpu_id = -1

    if args.async_update:
        model.create_async_update()
    if args.strict_rel_part or args.soft_rel_part:
        model.prepare_relation(th.device('cuda:' + str(gpu_id)))
    if args.soft_rel_part:
        model.prepare_cross_rels(cross_rels)

    train_start = start = time.time()
    sample_time = 0
    update_time = 0
    forward_time = 0
    backward_time = 0
    for step in range(0, args.max_step):
        start1 = time.time()
        pos_g, neg_g = next(train_sampler)
        sample_time += time.time() - start1

        if client is not None:
            # distributed mode: pull the embedding rows used by this minibatch
            model.pull_model(client, pos_g, neg_g)

        start1 = time.time()
        loss, log = model.forward(pos_g, neg_g, gpu_id)
        forward_time += time.time() - start1

        start1 = time.time()
        loss.backward()
        backward_time += time.time() - start1

        start1 = time.time()
        if client is not None:
            # distributed mode: ship the gradients to the kvstore, which
            # applies the row-sparse Adagrad update in _push_handler
            model.push_gradient(client)
        else:
            model.update(gpu_id)
        update_time += time.time() - start1
        logs.append(log)

        # force synchronize embedding across processes every X steps
        if args.force_sync_interval > 0 and \
            (step + 1) % args.force_sync_interval == 0:
            barrier.wait()

        if (step + 1) % args.log_interval == 0:
            for k in logs[0].keys():
                v = sum(l[k] for l in logs) / len(logs)
                print('[{}][Train]({}/{}) average {}: {}'.format(rank, (step + 1), args.max_step, k, v))
            logs = []
            print('[{}][Train] {} steps take {:.3f} seconds'.format(rank, args.log_interval,
                                                            time.time() - start))
            print('[{}]sample: {:.3f}, forward: {:.3f}, backward: {:.3f}, update: {:.3f}'.format(
                rank, sample_time, forward_time, backward_time, update_time))
            sample_time = 0
            update_time = 0
            forward_time = 0
            backward_time = 0
            start = time.time()

        if args.valid and (step + 1) % args.eval_interval == 0 and step > 1 and valid_samplers is not None:
            valid_start = time.time()
            if args.strict_rel_part or args.soft_rel_part:
                model.writeback_relation(rank, rel_parts)
            # forced sync for validation
            if barrier is not None:
                barrier.wait()
            test(args, model, valid_samplers, rank, mode='Valid')
            print('validation takes {:.3f} seconds'.format(time.time() - valid_start))
            if args.soft_rel_part:
                model.prepare_cross_rels(cross_rels)
            if barrier is not None:
                barrier.wait()

    print('[{}] train takes {:.3f} seconds'.format(rank, time.time() - train_start))
    if args.async_update:
        model.finish_async_update()
    if args.strict_rel_part or args.soft_rel_part:
        model.writeback_relation(rank, rel_parts)

def test(args, model, test_samplers, rank=0, mode='Test', queue=None):
    if len(args.gpu) > 0:
        gpu_id = args.gpu[rank % len(args.gpu)] if args.mix_cpu_gpu and args.num_proc > 1 else args.gpu[0]
    else:
        gpu_id = -1

    if args.strict_rel_part or args.soft_rel_part:
        model.load_relation(th.device('cuda:' + str(gpu_id)))

    with th.no_grad():
        logs = []
        for sampler in test_samplers:
            for pos_g, neg_g in sampler:
                model.forward_test(pos_g, neg_g, logs, gpu_id)

        metrics = {}
        if len(logs) > 0:
            for metric in logs[0].keys():
                metrics[metric] = sum([log[metric] for log in logs]) / len(logs)
        if queue is not None:
            queue.put(logs)
        else:
            for k, v in metrics.items():
                print('[{}]{} average {}: {}'.format(rank, mode, k, v))
    # rewind the head and tail samplers so they can be iterated again
    test_samplers[0] = test_samplers[0].reset()
    test_samplers[1] = test_samplers[1].reset()

@thread_wrapped_func
def train_mp(args, model, train_sampler, valid_samplers=None, rank=0, rel_parts=None, cross_rels=None, barrier=None):
    if args.num_proc > 1:
        th.set_num_threads(args.num_thread)
    train(args, model, train_sampler, valid_samplers, rank, rel_parts, cross_rels, barrier)

@thread_wrapped_func
def test_mp(args, model, test_samplers, rank=0, mode='Test', queue=None):
    if args.num_proc > 1:
        th.set_num_threads(args.num_thread)
    test(args, model, test_samplers, rank, mode, queue)
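
# A minimal launch sketch for the *_mp entry points, assuming the caller
# (e.g. DGL-KE's train.py) has already built `args`, a shared-memory `model`,
# and one sampler per rank; the names below are illustrative:
#
#     barrier = mp.Barrier(args.num_proc)
#     procs = []
#     for rank in range(args.num_proc):
#         proc = mp.Process(target=train_mp,
#                           args=(args, model, train_samplers[rank],
#                                 None, rank, rel_parts, cross_rels, barrier))
#         procs.append(proc)
#         proc.start()
#     for proc in procs:
#         proc.join()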

@thread_wrapped_func
def dist_train_test(args, model, train_sampler, entity_pb, relation_pb, l2g, rank=0, rel_parts=None, cross_rels=None, barrier=None):
    if args.num_proc > 1:
        th.set_num_threads(args.num_thread)

    client = connect_to_kvstore(args, entity_pb, relation_pb, l2g)
    client.barrier()
    train_time_start = time.time()
    train(args, model, train_sampler, None, rank, rel_parts, cross_rels, barrier, client)
    client.barrier()
    print('Total train time {:.3f} seconds'.format(time.time() - train_time_start))

    model = None  # release the local model; a fresh copy is pulled from the kvstore below

    if client.get_id() % args.num_client == 0: # pull full model from kvstore

        args.num_test_proc = args.num_client
        dataset_full = get_dataset(args.data_path, args.dataset, args.format)

        print('Full data n_entities: ' + str(dataset_full.n_entities))
        print("Full data n_relations: " + str(dataset_full.n_relations))

        model_test = load_model(None, args, dataset_full.n_entities, dataset_full.n_relations)
        eval_dataset = EvalDataset(dataset_full, args)

        if args.test:
            model_test.share_memory()

        if args.neg_sample_size_eval < 0:
            args.neg_sample_size_eval = dataset_full.n_entities
        args.eval_filter = not args.no_eval_filter
        if args.neg_deg_sample_eval:
            assert not args.eval_filter, "if negative sampling based on degree, we can't filter positive edges."

        print("Pull relation_emb ...")
        relation_id = F.arange(0, model_test.n_relations)
        relation_data = client.pull(name='relation_emb', id_tensor=relation_id)
        model_test.relation_emb.emb[relation_id] = relation_data
 
        print("Pull entity_emb ... ")
        # split model into 100 small parts
        start = 0
        percent = 0
        entity_id = F.arange(0, model_test.n_entities)
        count = int(model_test.n_entities / 100)
        end = start + count
        while True:
            print("Pull %d / 100 ..." % percent)
            if end >= model_test.n_entities:
                end = -1
            tmp_id = entity_id[start:end]
            entity_data = client.pull(name='entity_emb', id_tensor=tmp_id)
            model_test.entity_emb.emb[tmp_id] = entity_data
            if end == -1:
                break
            start = end
            end += count
            percent += 1

        if args.save_emb is not None:
            if not os.path.exists(args.save_emb):
                os.mkdir(args.save_emb)
            model_test.save_emb(args.save_emb, args.dataset)

        if args.test:
            args.num_thread = 1
            test_sampler_tails = []
            test_sampler_heads = []
            for i in range(args.num_test_proc):
                test_sampler_head = eval_dataset.create_sampler('test', args.batch_size_eval,
                                                                args.neg_sample_size_eval,
                                                                args.neg_sample_size_eval,
                                                                args.eval_filter,
                                                                mode='chunk-head',
                                                                num_workers=args.num_workers,
                                                                rank=i, ranks=args.num_test_proc)
                test_sampler_tail = eval_dataset.create_sampler('test', args.batch_size_eval,
                                                                args.neg_sample_size_eval,
                                                                args.neg_sample_size_eval,
                                                                args.eval_filter,
                                                                mode='chunk-tail',
                                                                num_workers=args.num_workers,
                                                                rank=i, ranks=args.num_test_proc)
                test_sampler_heads.append(test_sampler_head)
                test_sampler_tails.append(test_sampler_tail)

            eval_dataset = None
            dataset_full = None

            print("Run test, test processes: %d" % args.num_test_proc)

            queue = mp.Queue(args.num_test_proc)
            procs = []
            for i in range(args.num_test_proc):
                proc = mp.Process(target=test_mp, args=(args,
                                                        model_test,
                                                        [test_sampler_heads[i], test_sampler_tails[i]],
                                                        i,
                                                        'Test',
                                                        queue))
                procs.append(proc)
                proc.start()

            total_metrics = {}
            metrics = {}
            logs = []
            for i in range(args.num_test_proc):
                log = queue.get()
                logs = logs + log

            for metric in logs[0].keys():
                metrics[metric] = sum([log[metric] for log in logs]) / len(logs)
            for k, v in metrics.items():
                print('Test average {} : {}'.format(k, v))

            for proc in procs:
                proc.join()

        if client.get_id() == 0:
            client.shut_down()
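
# Note: this module is imported by DGL-KE's launcher scripts (e.g. train.py),
# which parse `args` and spawn the worker processes; it is not meant to be
# executed directly.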