from __future__ import absolute_import

import torch as th

def cuda():
    return th.device('cuda:0')

def is_cuda_available():
    return th.cuda.is_available()

def array_equal(a, b):
    return th.equal(a.cpu(), b.cpu())

def allclose(a, b, rtol=1e-4, atol=1e-4):
    return th.allclose(a.float().cpu(),
                       b.float().cpu(), rtol=rtol, atol=atol)

def randn(shape):
    return th.randn(*shape)

def attach_grad(x):
    if x.grad is not None:
        x.grad.zero_()
        return x
    else:
        return x.requires_grad_()

def backward(x, head_gradient=None):
    if head_gradient is not None and len(head_gradient.shape) == 1 and head_gradient.shape[0] == 1:
        # Fix for torch 1.3.1: convert a one-element 1-D head gradient into a
        # 0-D (scalar) tensor on the same device before calling backward.
        head_gradient = th.tensor(head_gradient.item()).to(head_gradient.device)
    x.backward(head_gradient)

def grad(x):
    return x.grad

def is_no_grad(x):
    return x.grad is None or (x.grad == 0).all()
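
# Minimal usage sketch (not part of the original file; the tensor values and
# the assertion are assumptions) showing how the gradient helpers above fit
# together in a test:
#
#     t = attach_grad(th.ones(3))     # start tracking gradients on t
#     backward((t * t).sum())         # d/dt sum(t*t) = 2*t
#     assert allclose(grad(t), 2 * th.ones(3))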

def full(shape, fill_value, dtype, ctx):
    return th.full(shape, fill_value, dtype=dtype, device=ctx)

def narrow_row_set(x, start, stop, new):
    x[start:stop] = new

def sparse_to_numpy(x):
    return x.to_dense().numpy()

def clone(x):
    return x.clone()

def reduce_sum(x):
    return x.sum()

def softmax(x, dim):
    return th.softmax(x, dim)

def spmm(x, y):
    return th.spmm(x, y)
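
# Sketch (assumed example values, not from the original file): spmm multiplies
# a sparse COO matrix by a dense one, and sparse_to_numpy densifies a sparse
# tensor for comparison in tests:
#
#     i = th.tensor([[0, 1], [1, 0]])                        # indices (2 x nnz)
#     s = th.sparse_coo_tensor(i, th.tensor([1.0, 2.0]), (2, 2))
#     d = spmm(s, th.eye(2))                                 # dense (2 x 2) result
#     assert (sparse_to_numpy(s) == d.numpy()).all()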

def add(a, b):
    return a + b

def sub(a, b):
    return a - b

def mul(a, b):
    return a * b

def div(a, b):
    return a / b

def sum(x, dim, keepdims=False):
    # torch uses 'keepdim' (no trailing 's') as the reduction keyword.
    return x.sum(dim, keepdim=keepdims)

def max(x, dim):
    return x.max(dim)[0]

def min(x, dim):
    return x.min(dim)[0]

def prod(x, dim):
    return x.prod(dim)

def matmul(a, b):
    return a @ b

def dot(a, b):
    return sum(mul(a, b), dim=-1)

class record_grad(object):
    """No-op gradient-recording scope; PyTorch records gradients by default."""
    def __init__(self):
        pass

    def __enter__(self):
        pass

    def __exit__(self, exc_type, exc_value, exc_traceback):
        pass

no_grad = th.no_grad
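
# Usage sketch (assumed, not from the original file): record_grad is a no-op
# scope because PyTorch already records operations for autograd, while
# no_grad disables that recording.
#
#     with record_grad():
#         y = t * t       # tracked by autograd
#     with no_grad():
#         z = t * t       # not tracked; z.requires_grad is False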