train_pytorch.py
from models import KEModel

from torch.utils.data import DataLoader
import torch.optim as optim
import torch as th

from distutils.version import LooseVersion
TH_VERSION = LooseVersion(th.__version__)
if TH_VERSION.version[0] == 1 and TH_VERSION.version[1] < 2:
    raise Exception("DGL-ke has to work with Pytorch version >= 1.2")
from models.pytorch.tensor_models import thread_wrapped_func

import os
import logging
import time
from functools import wraps

def load_model(logger, args, n_entities, n_relations, ckpt=None):
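    # Build a fresh KEModel from the hyper-parameters in args; embeddings start
    # randomly initialized. Checkpoint loading is handled by
    # load_model_from_checkpoint below.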
    model = KEModel(args, args.model_name, n_entities, n_relations,
                    args.hidden_dim, args.gamma,
                    double_entity_emb=args.double_ent, double_relation_emb=args.double_rel)
    if ckpt is not None:
        assert False, "We do not support loading model emb for general Embedding"
    return model


def load_model_from_checkpoint(logger, args, n_entities, n_relations, ckpt_path):
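    # Rebuild the model, then restore its entity and relation embeddings from
    # ckpt_path.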
    model = load_model(logger, args, n_entities, n_relations)
    model.load_emb(ckpt_path, args.dataset)
    return model

def train(args, model, train_sampler, valid_samplers=None, rank=0, rel_parts=None, cross_rels=None, barrier=None):
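    # Single-process training loop. In multi-process runs each worker calls this
    # with its own rank, sampler, and (optionally) a barrier for synchronization.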
    logs = []
    for arg in vars(args):
        logging.info('{:20}:{}'.format(arg, getattr(args, arg)))

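    # Choose the device for this worker: with mix_cpu_gpu and multiple processes,
    # GPUs in args.gpu are assigned round-robin by rank; otherwise the first GPU
    # is used, and gpu_id = -1 means CPU.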
    if len(args.gpu) > 0:
        gpu_id = args.gpu[rank % len(args.gpu)] if args.mix_cpu_gpu and args.num_proc > 1 else args.gpu[0]
    else:
        gpu_id = -1

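    # Optional optimizations: asynchronous (overlapped) updates of the entity
    # embeddings, and process-local relation embeddings when relations are
    # partitioned across workers (strict_rel_part / soft_rel_part).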
    if args.async_update:
        model.create_async_update()
    if args.strict_rel_part or args.soft_rel_part:
        model.prepare_relation(th.device('cuda:' + str(gpu_id)))
    if args.soft_rel_part:
        model.prepare_cross_rels(cross_rels)

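    # Main training loop: sample positive/negative graphs, run forward/backward,
    # apply embedding updates, and accumulate per-phase timings for logging.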
    start = time.time()
    sample_time = 0
    update_time = 0
    forward_time = 0
    backward_time = 0
    for step in range(args.init_step, args.max_step):
        start1 = time.time()
        pos_g, neg_g = next(train_sampler)
        sample_time += time.time() - start1
        args.step = step

        start1 = time.time()
        loss, log = model.forward(pos_g, neg_g, gpu_id)
        forward_time += time.time() - start1

        start1 = time.time()
        loss.backward()
        backward_time += time.time() - start1

        start1 = time.time()
        model.update(gpu_id)
        update_time += time.time() - start1
        logs.append(log)

        # force synchronize embedding across processes every X steps
        if args.force_sync_interval > 0 and \
            (step + 1) % args.force_sync_interval == 0:
            barrier.wait()

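        # Periodically report metrics averaged over the last log_interval steps,
        # along with the time spent in each phase.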
        if (step + 1) % args.log_interval == 0:
            for k in logs[0].keys():
                v = sum(l[k] for l in logs) / len(logs)
                print('[{}][Train]({}/{}) average {}: {}'.format(rank, (step + 1), args.max_step, k, v))
            logs = []
            print('[{}][Train] {} steps take {:.3f} seconds'.format(rank, args.log_interval,
                                                            time.time() - start))
            print('[{}]sample: {:.3f}, forward: {:.3f}, backward: {:.3f}, update: {:.3f}'.format(
                rank, sample_time, forward_time, backward_time, update_time))
            sample_time = 0
            update_time = 0
            forward_time = 0
            backward_time = 0
            start = time.time()

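        # Periodic validation: write partitioned relation embeddings back to the
        # shared store, synchronize all workers, then evaluate on the validation
        # samplers.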
        if args.valid and (step + 1) % args.eval_interval == 0 and step > 1 and valid_samplers is not None:
            valid_start = time.time()
            if args.strict_rel_part or args.soft_rel_part:
                model.writeback_relation(rank, rel_parts)
            # forced sync for validation
            if barrier is not None:
                barrier.wait()
            test(args, model, valid_samplers, rank, mode='Valid')
            print('test:', time.time() - valid_start)
            if args.soft_rel_part:
                model.prepare_cross_rels(cross_rels)
            if barrier is not None:
                barrier.wait()

    print('train {} takes {:.3f} seconds'.format(rank, time.time() - start))
    if args.async_update:
        model.finish_async_update()
    if args.strict_rel_part or args.soft_rel_part:
        model.writeback_relation(rank, rel_parts)

def test(args, model, test_samplers, rank=0, mode='Test', queue=None):
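    # Evaluate the model on the given samplers; averaged metrics are printed, or
    # the raw per-triple logs are pushed to `queue` for the parent process to
    # aggregate.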
    if len(args.gpu) > 0:
        gpu_id = args.gpu[rank % len(args.gpu)] if args.mix_cpu_gpu and args.num_proc > 1 else args.gpu[0]
    else:
        gpu_id = -1

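    # If relations were partitioned across workers during training, load the
    # relation embeddings onto this worker's device before evaluating.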
    if args.strict_rel_part or args.soft_rel_part:
        model.load_relation(th.device('cuda:' + str(gpu_id)))

    with th.no_grad():
        logs = []
        for sampler in test_samplers:
            count = 0
            for pos_g, neg_g in sampler:
                model.forward_test(pos_g, neg_g, logs, gpu_id)

        metrics = {}
        if len(logs) > 0:
            for metric in logs[0].keys():
                metrics[metric] = sum([log[metric] for log in logs]) / len(logs)
        if queue is not None:
            queue.put(logs)
        else:
            for k, v in metrics.items():
                print('[{}]{} average {} at [{}/{}]: {}'.format(rank, mode, k, args.step, args.max_step, v))
    test_samplers[0] = test_samplers[0].reset()
    test_samplers[1] = test_samplers[1].reset()

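# Per-process entry points for multi-process training and evaluation.
# thread_wrapped_func runs the body in a separate thread, a workaround for
# OpenMP-related hangs in forked worker processes.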
@thread_wrapped_func
def train_mp(args, model, train_sampler, valid_samplers=None, rank=0, rel_parts=None, cross_rels=None, barrier=None):
    if args.num_proc > 1:
        th.set_num_threads(args.num_thread)
    train(args, model, train_sampler, valid_samplers, rank, rel_parts, cross_rels, barrier)

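# Per-process evaluation entry point for multi-process runs.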
@thread_wrapped_func
def test_mp(args, model, test_samplers, rank=0, mode='Test', queue=None):
    test(args, model, test_samplers, rank, mode, queue)