__init__.py 1.78 KB
Newer Older
VoVAllen's avatar
VoVAllen committed
1
2
3
4
5
6
7
8
from __future__ import absolute_import

import numpy as np
import tensorflow as tf
from scipy.sparse import coo_matrix


def cuda():
    """Return the TensorFlow device string for the first GPU.

    Returns:
        str: the device spec ``"/gpu:0"``, usable with ``tf.device``.
    """
    return "/gpu:0"
VoVAllen's avatar
VoVAllen committed
10
11
12
13
14
15
16
17
18
19
20


def is_cuda_available():
    """Report whether a CUDA GPU is usable by TensorFlow.

    Returns:
        bool: True when TensorFlow was built with CUDA support and at least
        one physical GPU device is visible.
    """
    # ``tf.test.is_gpu_available`` is deprecated; the documented replacement
    # is ``tf.config.list_physical_devices("GPU")``. ``is_built_with_cuda``
    # keeps the old ``cuda_only=True`` meaning (a non-CUDA GPU build, e.g.
    # ROCm, does not count).
    return tf.test.is_built_with_cuda() and bool(
        tf.config.list_physical_devices("GPU")
    )


def array_equal(a, b):
    """Return True if ``a`` and ``b`` have the same shape and elements.

    Accepts eager tensors (anything exposing ``.numpy()``) as well as NumPy
    arrays and plain Python sequences, matching what ``allclose`` accepts.

    Args:
        a: tensor or array-like.
        b: tensor or array-like.

    Returns:
        bool: result of ``np.array_equal`` on the NumPy views of both inputs.
    """

    def _as_numpy(x):
        # Eager tensors expose ``.numpy()``; everything else (ndarray, list,
        # scalar) goes through ``np.asarray``.
        return x.numpy() if hasattr(x, "numpy") else np.asarray(x)

    return np.array_equal(_as_numpy(a), _as_numpy(b))


def allclose(a, b, rtol=1e-4, atol=1e-4):
    """Return True if ``a`` and ``b`` are elementwise equal within tolerance.

    Both inputs are passed through ``tf.convert_to_tensor`` and then turned
    into NumPy arrays, so tensors, NumPy arrays and nested Python sequences
    are all accepted.

    Args:
        a: tensor or array-like.
        b: tensor or array-like.
        rtol: relative tolerance forwarded to ``np.allclose``.
        atol: absolute tolerance forwarded to ``np.allclose``.

    Returns:
        bool: result of ``np.allclose`` on the converted inputs.
    """
    return np.allclose(
        tf.convert_to_tensor(a).numpy(),
        tf.convert_to_tensor(b).numpy(),
        rtol=rtol,
        atol=atol,
    )
VoVAllen's avatar
VoVAllen committed
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80


def randn(shape):
    """Draw a tensor of the given ``shape`` from ``tf.random.normal`` (default mean/stddev)."""
    samples = tf.random.normal(shape)
    return samples


def full(shape, fill_value, dtype, ctx):
    """Create a constant tensor filled with ``fill_value`` on device ``ctx``.

    Args:
        shape: shape of the result.
        fill_value: scalar value broadcast into every element.
        dtype: TensorFlow dtype of the result.
        ctx: device string passed to ``tf.device`` (e.g. ``"/gpu:0"``).
    """
    with tf.device(ctx):
        # The constant is materialized while the device scope is active.
        return tf.constant(fill_value, shape=shape, dtype=dtype)


def narrow_row_set(x, start, stop, new):
    """Placeholder for the in-place row update ``x[start:stop] = new``.

    The TensorFlow backend does not support in-place updates, so this
    always raises.

    Raises:
        NotImplementedError: unconditionally.
    """
    message = "TF doesn't support inplace update"
    raise NotImplementedError(message)


def sparse_to_numpy(x):
    """Densify the sparse tensor ``x`` and return it as a NumPy array."""
    # With validation on, tf.sparse.to_dense checks that indices are sorted;
    # the sparse inputs used here are not guaranteed to be, so validation
    # is switched off.
    dense = tf.sparse.to_dense(x, validate_indices=False)
    return dense.numpy()


def clone(x):
    """Return a new tensor with the same contents as ``x`` (via ``tf.identity``)."""
    copied = tf.identity(x)
    return copied


def reduce_sum(x):
    """Sum every element of ``x`` into a scalar tensor."""
    total = tf.reduce_sum(x)
    return total


def softmax(x, dim):
    """Apply softmax to ``x`` along axis ``dim``."""
    probs = tf.math.softmax(x, axis=dim)
    return probs


def spmm(x, y):
    """Multiply sparse tensor ``x`` by dense tensor ``y``."""
    product = tf.sparse.sparse_dense_matmul(x, y)
    return product


def add(a, b):
    """Return ``a + b`` using the operands' own ``+`` semantics."""
    result = a + b
    return result


def sub(a, b):
    """Return ``a - b`` using the operands' own ``-`` semantics."""
    result = a - b
    return result


def mul(a, b):
    """Return ``a * b`` (elementwise for tensors/arrays)."""
    result = a * b
    return result


def div(a, b):
    """Return the true division ``a / b``."""
    result = a / b
    return result


81
82
def sum(x, dim, keepdims=False):
    """Sum ``x`` along axis ``dim``.

    Args:
        x: input tensor.
        dim: axis (or axes) to reduce.
        keepdims: if True, reduced axes are kept with length 1.
    """
    reduced = tf.reduce_sum(x, axis=dim, keepdims=keepdims)
    return reduced
VoVAllen's avatar
VoVAllen committed
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102


def max(x, dim):
    """Return the maximum of ``x`` along axis ``dim``."""
    largest = tf.reduce_max(x, axis=dim)
    return largest


def min(x, dim):
    """Return the minimum of ``x`` along axis ``dim``."""
    smallest = tf.reduce_min(x, axis=dim)
    return smallest


def prod(x, dim):
    """Return the product of ``x`` along axis ``dim``."""
    product = tf.reduce_prod(x, axis=dim)
    return product


def matmul(a, b):
    """Matrix-multiply ``a`` by ``b``."""
    product = tf.linalg.matmul(a, b)
    return product


def dot(a, b):
    """Dot product over the last axis: elementwise product, then sum on axis -1."""
    elementwise = mul(a, b)
    return sum(elementwise, dim=-1)
103

104

105
106
def abs(a):
    """Return the elementwise absolute value of ``a``."""
    magnitude = tf.abs(a)
    return magnitude
107
108
109
110


def seed(a):
    """Set TensorFlow's global random seed to ``a``."""
    result = tf.random.set_seed(a)
    return result