# train_pytorch.py
from models import KEModel

from torch.utils.data import DataLoader
import torch.optim as optim
import torch as th
import torch.multiprocessing as mp
from torch.multiprocessing import Queue
from _thread import start_new_thread
import traceback

from distutils.version import LooseVersion
TH_VERSION = LooseVersion(th.__version__)
if TH_VERSION.version[0] == 1 and TH_VERSION.version[1] < 2:
    raise Exception("DGL-KE requires PyTorch 1.2 or later")

import os
import logging
import time
from functools import wraps

def thread_wrapped_func(func):
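    """Wrap a process entry point so that func runs in a fresh thread.

    The wrapped call relays its result, or any raised exception together
    with its traceback, through a queue back to the caller. Running the
    body in a new thread sidesteps problems with OpenMP state inherited
    across fork() when this is used as a torch.multiprocessing target.
    """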
    @wraps(func)
    def decorated_function(*args, **kwargs):
        queue = Queue()
        def _queue_result():
            exception, trace, res = None, None, None
            try:
                res = func(*args, **kwargs)
            except Exception as e:
                exception = e
                trace = traceback.format_exc()
            queue.put((res, exception, trace))

        start_new_thread(_queue_result, ())
        result, exception, trace = queue.get()
        if exception is None:
            return result
        else:
            assert isinstance(exception, Exception)
            raise exception.__class__(trace)
    return decorated_function

def load_model(logger, args, n_entities, n_relations, ckpt=None):
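    """Build a KEModel from the run arguments and the graph's entity and
    relation counts. Loading embeddings from a checkpoint here is not
    supported; use load_model_from_checkpoint instead."""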
    model = KEModel(args, args.model_name, n_entities, n_relations,
                    args.hidden_dim, args.gamma,
                    double_entity_emb=args.double_ent, double_relation_emb=args.double_rel)
    if ckpt is not None:
        assert False, "We do not support loading model embeddings for the general Embedding"
    return model


def load_model_from_checkpoint(logger, args, n_entities, n_relations, ckpt_path):
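    """Build a fresh model and load pretrained embeddings from ckpt_path."""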
    model = load_model(logger, args, n_entities, n_relations)
    model.load_emb(ckpt_path, args.dataset)
    return model

@thread_wrapped_func
def train(args, model, train_sampler, rank=0, rel_parts=None, valid_samplers=None):
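    """Training loop for one worker process.

    Each step draws a (positive, negative) graph pair from train_sampler,
    runs a forward and backward pass, updates the embeddings, and reports
    averaged metrics and timing every args.log_interval steps.
    """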
    if args.num_proc > 1:
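        # Limit intra-op threads per worker to avoid CPU oversubscription.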
        th.set_num_threads(4)
    logs = []
    for arg in vars(args):
        logging.info('{:20}:{}'.format(arg, getattr(args, arg)))

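    # Pick a device: round-robin over the visible GPUs across worker
    # processes in mixed CPU-GPU mode; -1 means train on CPU.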
    if len(args.gpu) > 0:
        gpu_id = args.gpu[rank % len(args.gpu)] if args.mix_cpu_gpu and args.num_proc > 1 else args.gpu[0]
    else:
        gpu_id = -1

    start = time.time()
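    # Wall-clock accumulators, reported and reset every args.log_interval steps.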
    sample_time = 0
    update_time = 0
    forward_time = 0
    backward_time = 0
    for step in range(args.init_step, args.max_step):
        start1 = time.time()
        pos_g, neg_g = next(train_sampler)
        sample_time += time.time() - start1
        args.step = step

        start1 = time.time()
        loss, log = model.forward(pos_g, neg_g, gpu_id)
        forward_time += time.time() - start1

        start1 = time.time()
        loss.backward()
        backward_time += time.time() - start1

        start1 = time.time()
        model.update(gpu_id)
        update_time += time.time() - start1
        logs.append(log)

        if step % args.log_interval == 0:
            for k in logs[0].keys():
                v = sum(l[k] for l in logs) / len(logs)
                print('[Train]({}/{}) average {}: {}'.format(step, args.max_step, k, v))
            logs = []
            print('[Train] {} steps take {:.3f} seconds'.format(args.log_interval,
                                                                 time.time() - start))
            print('sample: {:.3f}, forward: {:.3f}, backward: {:.3f}, update: {:.3f}'.format(
                sample_time, forward_time, backward_time, update_time))
            sample_time = 0
            update_time = 0
            forward_time = 0
            backward_time = 0
            start = time.time()

        if args.valid and step % args.eval_interval == 0 and step > 1 and valid_samplers is not None:
            start = time.time()
            test(args, model, valid_samplers, mode='Valid')
            print('test:', time.time() - start)

@thread_wrapped_func
def test(args, model, test_samplers, rank=0, mode='Test', queue=None):
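    """Evaluate the model on the given samplers.

    Aggregated metrics are put on queue when one is provided (multi-process
    evaluation); otherwise they are printed directly.
    """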
    if args.num_proc > 1:
        th.set_num_threads(4)

    if len(args.gpu) > 0:
        gpu_id = args.gpu[rank % len(args.gpu)] if args.mix_cpu_gpu and args.num_proc > 1 else args.gpu[0]
    else:
        gpu_id = -1

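    # Score evaluation batches without building autograd graphs.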
    with th.no_grad():
        logs = []
        for sampler in test_samplers:
            for pos_g, neg_g in sampler:
                model.forward_test(pos_g, neg_g, logs, gpu_id)

        metrics = {}
        if len(logs) > 0:
            for metric in logs[0].keys():
                metrics[metric] = sum([log[metric] for log in logs]) / len(logs)
        if queue is not None:
            queue.put(metrics)
        else:
            for k, v in metrics.items():
                print('{} average {} at [{}/{}]: {}'.format(mode, k, args.step, args.max_step, v))
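    # The code assumes exactly two evaluation samplers; reset both for reuse.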
    test_samplers[0] = test_samplers[0].reset()
    test_samplers[1] = test_samplers[1].reset()
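

# A minimal launch sketch, not part of the original file: it assumes the
# caller has prepared `args`, a model whose embeddings live in shared
# memory, and one training sampler per worker. The helper name
# `launch_train_procs` is hypothetical.
def launch_train_procs(args, model, train_samplers):
    """Spawn one thread-wrapped train() process per worker and wait."""
    procs = []
    for rank in range(args.num_proc):
        proc = mp.Process(target=train,
                          args=(args, model, train_samplers[rank], rank))
        proc.start()
        procs.append(proc)
    for proc in procs:
        proc.join()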