# tensor_models.py
"""
Knowledge Graph Embedding Models.
1. TransE
2. DistMult
3. ComplEx
4. RotatE
5. pRotatE
6. TransH
7. TransR
8. TransD
9. RESCAL
"""
import os
import numpy as np

import torch as th
import torch.nn as nn
import torch.nn.functional as functional
import torch.nn.init as INIT

from .. import *

logsigmoid = functional.logsigmoid
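# logsigmoid is the building block of the usual KGE training loss, e.g.
# -logsigmoid(score(pos)) - logsigmoid(-score(neg)); the loss code itself
# is not part of this excerpt.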

def get_device(args):
    return th.device('cpu') if args.gpu[0] < 0 else th.device('cuda:' + str(args.gpu[0]))
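
# A minimal sketch of the translational scoring that the model list above
# refers to (illustrative only; `gamma` is assumed to be the margin
# hyperparameter). TransE scores a triple (h, r, t) as gamma - ||h + r - t||_p.
def transe_score(head, rel, tail, gamma=12.0, p=1):
    # Higher score = more plausible triple; works on batched [..., dim] tensors.
    return gamma - (head + rel - tail).norm(p=p, dim=-1)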

# ||x||_p^p, used when regularizing embedding norms.
norm = lambda x, p: x.norm(p=p)**p

# Pull a Python scalar out of a 0-dim tensor without keeping the graph.
get_scalar = lambda x: x.detach().item()

reshape = lambda arr, x, y: arr.view(x, y)

cuda = lambda arr, gpu: arr.cuda(gpu)

class ExternalEmbedding:
    """An embedding table kept outside any nn.Module and updated manually
    with a sparse Adagrad step (see update()), so only the rows touched in
    a mini-batch pay an optimizer cost."""

    def __init__(self, args, num, dim, device):
        self.gpu = args.gpu
        self.args = args
        # (index, leaf-tensor) pairs recorded by __call__ when trace=True.
        self.trace = []

        self.emb = th.empty(num, dim, dtype=th.float32, device=device)
        # Adagrad accumulator: a running sum of squared gradients per row.
        self.state_sum = th.zeros(self.emb.size(0), dtype=th.float32, device=device)
        self.state_step = 0

    def init(self, emb_init):
        """Initialize embeddings uniformly in [-emb_init, emb_init]."""
        INIT.uniform_(self.emb, -emb_init, emb_init)
        INIT.zeros_(self.state_sum)

    def share_memory(self):
        """Move the tensors into shared memory so that worker processes
        (e.g. under th.multiprocessing) can read and update them."""
        self.emb.share_memory_()
        self.state_sum.share_memory_()

    def __call__(self, idx, gpu_id=-1, trace=True):
        """Look up the rows in `idx`, optionally tracing them for update()."""
        s = self.emb[idx]
        if gpu_id >= 0:
            s = s.cuda(gpu_id)
        # During training we trace each lookup: the returned tensor is a
        # detached leaf that accumulates gradients, which update() later
        # applies back to the corresponding rows of self.emb.
        if trace:
            data = s.clone().detach().requires_grad_(True)
            self.trace.append((idx, data))
        else:
            data = s
        return data

    def update(self, gpu_id=-1):
        self.state_step += 1
        with th.no_grad():
            for idx, data in self.trace:
                grad = data.grad.data

                clr = self.args.lr
                #clr = self.args.lr / (1 + (self.state_step - 1) * group['lr_decay'])

                # The Adagrad update is non-linear, so the indices within a
                # batch must be unique for index_add_ to be correct.
                grad_indices = idx
                grad_values = grad

                # Accumulate the mean squared gradient of every touched row.
                grad_sum = (grad_values * grad_values).mean(1)
                device = self.state_sum.device
                if device != grad_indices.device:
                    grad_indices = grad_indices.to(device)
                if device != grad_sum.device:
                    grad_sum = grad_sum.to(device)
                self.state_sum.index_add_(0, grad_indices, grad_sum)
                std = self.state_sum[grad_indices]  # _sparse_mask
                if gpu_id >= 0:
                    std = std.cuda(gpu_id)
                # Adagrad step: w <- w - lr * g / (sqrt(G) + eps).
                std_values = std.sqrt_().add_(1e-10).unsqueeze(1)
                tmp = (-clr * grad_values / std_values)
                if tmp.device != device:
                    tmp = tmp.to(device)
                # TODO(zhengda) the overhead is here.
                self.emb.index_add_(0, grad_indices, tmp)
        # Start a fresh trace for the next mini-batch.
        self.trace = []

    def curr_emb(self):
        """Concatenate all embeddings recorded in the current trace."""
        data = [data for _, data in self.trace]
        return th.cat(data, 0)

    def save(self, path, name):
        """Save the embeddings to path/name.npy."""
        file_name = os.path.join(path, name+'.npy')
        np.save(file_name, self.emb.cpu().detach().numpy())

    def load(self, path, name):
        """Load the embeddings from path/name.npy (onto the CPU)."""
        file_name = os.path.join(path, name+'.npy')
        self.emb = th.Tensor(np.load(file_name))
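

# A minimal usage sketch (assumptions: a simple args object with `gpu` and
# `lr` attributes; sizes are illustrative). It shows the intended cycle:
# look up rows with trace=True, compute a loss, backpropagate, then apply
# the sparse Adagrad step with update(). Because this module uses a relative
# import, run it from the enclosing package rather than as a bare script.
if __name__ == '__main__':
    class _Args:
        gpu = [-1]
        lr = 0.1

    args = _Args()
    device = get_device(args)

    entity_emb = ExternalEmbedding(args, num=1000, dim=400, device=device)
    entity_emb.init(emb_init=1.0 / 400)

    idx = th.tensor([0, 5, 42])          # unique indices, as update() requires
    vec = entity_emb(idx, trace=True)    # records (idx, leaf tensor) in the trace
    loss = (vec * vec).sum()             # placeholder loss for illustration
    loss.backward()                      # fills vec.grad
    entity_emb.update()                  # Adagrad step on the touched rows only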