train_pytorch.py 5.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
from models import KEModel

from torch.utils.data import DataLoader
import torch.optim as optim
import torch as th

from distutils.version import LooseVersion
TH_VERSION = LooseVersion(th.__version__)
def torch_major_minor(version_string):
    """Return the leading ``(major, minor)`` integers of a version string.

    Local-build and pre-release suffixes (``+cu102``, ``a0``, ``.dev...``)
    are ignored and a missing component counts as 0, e.g.
    ``'1.2.0+cu102' -> (1, 2)`` and ``'0.4.1' -> (0, 4)``.
    """
    numbers = []
    for piece in version_string.split('.')[:2]:
        digits = ''
        for char in piece:
            if not char.isdecimal():
                break
            digits += char
        numbers.append(int(digits) if digits else 0)
    numbers.extend([0] * (2 - len(numbers)))
    return tuple(numbers)


# Compare the full (major, minor) pair so that 0.x releases are rejected too,
# not only 1.0 and 1.1.
if torch_major_minor(th.__version__) < (1, 2):
    raise Exception("DGL-ke has to work with Pytorch version >= 1.2")
11
from models.pytorch.tensor_models import thread_wrapped_func
12
13
14
15

import os
import logging
import time
16
17
from functools import wraps

18
19
20
21
22
def load_model(logger, args, n_entities, n_relations, ckpt=None):
    """Construct a fresh ``KEModel`` from the hyper-parameters in ``args``.

    Args:
        logger: not used by this function; kept for interface compatibility.
        args: parsed options providing ``model_name``, ``hidden_dim``,
            ``gamma``, ``double_ent`` and ``double_rel``.
        n_entities: number of entities, forwarded to ``KEModel``.
        n_relations: number of relations, forwarded to ``KEModel``.
        ckpt: must be ``None``; loading embeddings through this argument is
            not supported (``load_model_from_checkpoint`` loads them instead).

    Returns:
        The newly constructed ``KEModel``.

    Raises:
        NotImplementedError: if ``ckpt`` is not ``None``.
    """
    # Reject the unsupported option before building the model (whose
    # embedding tables may be large). Use a real exception rather than an
    # ``assert``, so the check still applies under ``python -O``.
    if ckpt is not None:
        raise NotImplementedError("We do not support loading model emb for general Embedding")
    model = KEModel(args, args.model_name, n_entities, n_relations,
                    args.hidden_dim, args.gamma,
                    double_entity_emb=args.double_ent, double_relation_emb=args.double_rel)
    return model


def load_model_from_checkpoint(logger, args, n_entities, n_relations, ckpt_path):
    """Build a model with ``load_model`` and fill its embeddings from disk.

    Args:
        logger: forwarded unchanged to ``load_model``.
        args: parsed options; ``args.dataset`` is handed to ``load_emb``
            together with ``ckpt_path``.
        n_entities: number of entities, forwarded to ``load_model``.
        n_relations: number of relations, forwarded to ``load_model``.
        ckpt_path: location of the saved embeddings.

    Returns:
        The model after ``load_emb`` has been called on it.
    """
    restored = load_model(logger, args, n_entities, n_relations)
    restored.load_emb(ckpt_path, args.dataset)
    return restored

32
def train(args, model, train_sampler, valid_samplers=None, rank=0, rel_parts=None, barrier=None):
    """Run the training loop of one worker.

    For every step in ``range(args.init_step, args.max_step)`` a
    ``(pos_g, neg_g)`` pair is drawn from ``train_sampler``, the loss is
    computed with ``model.forward``, back-propagated, and applied with
    ``model.update``. Every ``args.log_interval`` steps, the averaged step
    logs and a sample/forward/backward/update timing breakdown are printed.
    Every ``args.eval_interval`` steps, ``test`` is run on
    ``valid_samplers`` when ``args.valid`` is set.

    Args:
        args: parsed options; ``args.step`` is overwritten with the current
            step on every iteration.
        model: the embedding model being trained.
        train_sampler: iterator yielding ``(pos_g, neg_g)`` pairs.
        valid_samplers: samplers passed to ``test`` for validation, or None.
        rank: worker index, used for GPU selection and log prefixes.
        rel_parts: relation partition, used when ``args.strict_rel_part``.
        barrier: object with a ``wait()`` method used to synchronize
            workers, or None (then no synchronization is attempted).
    """
    logs = []
    for arg in vars(args):
        logging.info('{:20}:{}'.format(arg, getattr(args, arg)))

    if len(args.gpu) > 0:
        # Mixed CPU/GPU multi-process runs spread workers round-robin over
        # the listed GPUs; otherwise every worker uses the first GPU.
        gpu_id = args.gpu[rank % len(args.gpu)] if args.mix_cpu_gpu and args.num_proc > 1 else args.gpu[0]
    else:
        gpu_id = -1

    if args.async_update:
        model.create_async_update()
    if args.strict_rel_part:
        model.prepare_relation(th.device('cuda:' + str(gpu_id)))

    # `train_start` marks the beginning of the whole run, while `start` is
    # reset after every log interval for the per-interval timing report.
    train_start = time.time()
    start = train_start
    sample_time = 0
    update_time = 0
    forward_time = 0
    backward_time = 0
    for step in range(args.init_step, args.max_step):
        start1 = time.time()
        pos_g, neg_g = next(train_sampler)
        sample_time += time.time() - start1
        args.step = step

        start1 = time.time()
        loss, log = model.forward(pos_g, neg_g, gpu_id)
        forward_time += time.time() - start1

        start1 = time.time()
        loss.backward()
        backward_time += time.time() - start1

        start1 = time.time()
        model.update(gpu_id)
        update_time += time.time() - start1
        logs.append(log)

        # Force-synchronize embeddings across processes every X steps. The
        # barrier may be None (the validation code below allows for that
        # too), in which case there is nothing to wait on.
        if barrier is not None and args.force_sync_interval > 0 and \
                (step + 1) % args.force_sync_interval == 0:
            barrier.wait()

        if (step + 1) % args.log_interval == 0:
            for k in logs[0].keys():
                v = sum(entry[k] for entry in logs) / len(logs)
                print('[{}][Train]({}/{}) average {}: {}'.format(rank, (step + 1), args.max_step, k, v))
            logs = []
            print('[{}][Train] {} steps take {:.3f} seconds'.format(rank, args.log_interval,
                                                            time.time() - start))
            print('[{}]sample: {:.3f}, forward: {:.3f}, backward: {:.3f}, update: {:.3f}'.format(
                rank, sample_time, forward_time, backward_time, update_time))
            sample_time = 0
            update_time = 0
            forward_time = 0
            backward_time = 0
            start = time.time()

        if args.valid and (step + 1) % args.eval_interval == 0 and step > 1 and valid_samplers is not None:
            valid_start = time.time()
            if args.strict_rel_part:
                model.writeback_relation(rank, rel_parts)
            # forced sync for validation
            if barrier is not None:
                barrier.wait()
            test(args, model, valid_samplers, rank, mode='Valid')
            print('test:', time.time() - valid_start)
            if barrier is not None:
                barrier.wait()

    # Report the duration of the whole run, not just the last log interval.
    print('train {} takes {:.3f} seconds'.format(rank, time.time() - train_start))
    if args.async_update:
        model.finish_async_update()
    if args.strict_rel_part:
        model.writeback_relation(rank, rel_parts)
108

109
110
111
112
113
114
def test(args, model, test_samplers, rank=0, mode='Test', queue=None):
    if len(args.gpu) > 0:
        gpu_id = args.gpu[rank % len(args.gpu)] if args.mix_cpu_gpu and args.num_proc > 1 else args.gpu[0]
    else:
        gpu_id = -1

115
116
117
    if args.strict_rel_part:
        model.load_relation(th.device('cuda:' + str(gpu_id)))

118
119
120
121
122
    with th.no_grad():
        logs = []
        for sampler in test_samplers:
            count = 0
            for pos_g, neg_g in sampler:
123
                model.forward_test(pos_g, neg_g, logs, gpu_id)
124
125
126
127
128

        metrics = {}
        if len(logs) > 0:
            for metric in logs[0].keys():
                metrics[metric] = sum([log[metric] for log in logs]) / len(logs)
129
        if queue is not None:
130
            queue.put(logs)
131
132
        else:
            for k, v in metrics.items():
133
                print('[{}]{} average {} at [{}/{}]: {}'.format(rank, mode, k, args.step, args.max_step, v))
134
135
    test_samplers[0] = test_samplers[0].reset()
    test_samplers[1] = test_samplers[1].reset()
136
137
138
139
140
141
142
143
144
145

@thread_wrapped_func
def train_mp(args, model, train_sampler, valid_samplers=None, rank=0, rel_parts=None, barrier=None):
    """Worker entry point for training; delegates to ``train``.

    When more than one process is configured (``args.num_proc > 1``), the
    PyTorch thread count is first set to ``args.num_thread``, presumably so
    that concurrent workers do not oversubscribe the CPU.
    """
    multi_process = args.num_proc > 1
    if multi_process:
        th.set_num_threads(args.num_thread)
    train(args, model, train_sampler,
          valid_samplers=valid_samplers, rank=rank,
          rel_parts=rel_parts, barrier=barrier)

@thread_wrapped_func
def test_mp(args, model, test_samplers, rank=0, mode='Test', queue=None):
    """Worker entry point for evaluation; forwards everything to ``test``."""
    test(args, model, test_samplers,
         rank=rank, mode=mode, queue=queue)