from __future__ import absolute_import

import numpy as np
import tensorflow as tf
from scipy.sparse import coo_matrix


def cuda():
    return '/gpu:0'


def is_cuda_available():
    return tf.test.is_gpu_available(cuda_only=True)


def array_equal(a, b):
    return np.array_equal(a.numpy(), b.numpy())


def allclose(a, b, rtol=1e-4, atol=1e-4):
    return np.allclose(tf.convert_to_tensor(a).numpy(),
                       tf.convert_to_tensor(b).numpy(), rtol=rtol, atol=atol)


def randn(shape):
    return tf.random.normal(shape)


class GradContext:
    """Tracks tensors registered for gradients and the GradientTape recording them."""

    def __init__(self):
        self.tensor_for_grad = []   # tensors registered via attach_grad
        self.grad_list = []         # gradients from the most recent backward()
        self.tape = None            # active tf.GradientTape, set by record_grad

    def set_tape(self, tape):
        self.tape = tape

    def add_tensor(self, x):
        # If x is already tracked (matched by tensor id), drop the stale entry
        # before watching and re-registering it.
        idx_pop = []
        for idx, ele in enumerate(self.tensor_for_grad):
            if ele._id == x._id:
                idx_pop.append(idx)
        if len(idx_pop) > 0:
            self.tensor_for_grad.pop(idx_pop[0])
        if self.tape is not None:
            self.tape.watch(x)
        self.tensor_for_grad.append(x)

    def backward(self, x, head_gradient=None):
        # Compute gradients of x with respect to every tracked tensor,
        # optionally scaling x by a head gradient first.
        if head_gradient is not None:
            x = x * head_gradient
        self.grad_list = self.tape.gradient(x, self.tensor_for_grad)

    def is_no_grad(self, x):
        # A tensor has no gradient if it was never tracked or if the last
        # backward() produced None for it.
        idx_pop = []
        for idx, ele in enumerate(self.tensor_for_grad):
            if ele._id == x._id:
                idx_pop.append(idx)
        if len(idx_pop) == 0:
            return True
        return self.grad_list[idx_pop[0]] is None

    def grad(self, x):
        # Look up the gradient computed for x by the last backward() call.
        idx_pop = []
        for idx, ele in enumerate(self.tensor_for_grad):
            if ele._id == x._id:
                idx_pop.append(idx)
        assert len(idx_pop) == 1
        t = self.grad_list[idx_pop[0]]
        return tf.convert_to_tensor(t)


cgrad = GradContext()


def get_cgrad():
    return cgrad


class record_grad:
    """Context manager that records operations on watched tensors with a GradientTape."""

    def __init__(self):
        self.tape = tf.GradientTape()

    def __enter__(self):
        cgrad.set_tape(self.tape)
        self.tape.__enter__()
        # Watch every tensor previously registered via attach_grad.
        for x in cgrad.tensor_for_grad:
            self.tape.watch(x)

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.tape.__exit__(exc_type, exc_value, exc_traceback)
        cgrad.tape = None


def attach_grad(x):
    cgrad.add_tensor(x)
    return x


def backward(x, head_gradient=None):
    cgrad.backward(x, head_gradient)


def grad(x):
    return cgrad.grad(x)


def is_no_grad(x):
    return cgrad.is_no_grad(x)
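
# Usage sketch for the gradient helpers above (illustrative only; the tensor
# values and shapes are hypothetical, not taken from this module):
#
#     x = attach_grad(tf.ones((2, 2)))
#     with record_grad():
#         y = tf.reduce_sum(x * x)
#         backward(y)       # gradients are computed while the tape is active
#     dx = grad(x)          # equals 2 * x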


def full(shape, fill_value, dtype, ctx):
    with tf.device(ctx):
        t = tf.constant(fill_value, shape=shape, dtype=dtype)
    return t
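
# For example (hypothetical call): full((2, 3), 0.0, tf.float32, '/cpu:0')
# creates a 2x3 tensor of zeros on the CPU.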


def narrow_row_set(x, start, stop, new):
    # Would perform x[start:stop] = new, but TensorFlow tensors are immutable
    # and do not support in-place slice assignment.
    raise NotImplementedError("TF doesn't support inplace update")


def sparse_to_numpy(x):
    # tf.sparse.to_dense assumes sorted indices, so validate_indices is turned
    # off here to handle unsorted input.
    return tf.sparse.to_dense(x, validate_indices=False).numpy()


def clone(x):
    return tf.identity(x)


def reduce_sum(x):
    return tf.reduce_sum(x)


def softmax(x, dim):
    return tf.math.softmax(x, axis=dim)


def spmm(x, y):
    return tf.sparse.sparse_dense_matmul(x, y)
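
# Illustrative sketch of spmm / sparse_to_numpy (hypothetical values, not part
# of this module):
#
#     sp = tf.SparseTensor(indices=[[0, 1], [1, 0]], values=[1.0, 2.0],
#                          dense_shape=[2, 2])
#     out = spmm(sp, tf.ones((2, 3)))      # dense (2, 3) result
#     arr = sparse_to_numpy(sp)            # dense 2x2 numpy array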


def add(a, b):
    return a + b


def sub(a, b):
    return a - b


def mul(a, b):
    return a * b


def div(a, b):
    return a / b


def sum(x, dim):
    return tf.reduce_sum(x, axis=dim)


def max(x, dim):
    return tf.reduce_max(x, axis=dim)


def min(x, dim):
    return tf.reduce_min(x, axis=dim)


def prod(x, dim):
    return tf.reduce_prod(x, axis=dim)


def matmul(a, b):
    return tf.linalg.matmul(a, b)


def dot(a, b):
    return sum(mul(a, b), dim=-1)


# No dedicated no-grad context is defined for this backend.
no_grad = None