"""PyTorch backend shims: thin wrappers over ``torch`` tensor and autograd APIs."""
from __future__ import absolute_import

import torch as th

def cuda():
    """Return the default CUDA device handle (device index 0).

    Only constructs the ``torch.device`` object; it does not check that
    CUDA is actually available (see ``is_cuda_available``).
    """
    # Fix: stray copy-paste gutter numbers ("6", "7") removed from the body.
    return th.device('cuda:0')

def is_cuda_available():
    """Report whether a usable CUDA runtime is present."""
    available = th.cuda.is_available()
    return available

def array_equal(a, b):
    """Return True iff ``a`` and ``b`` have the same shape and identical elements.

    Both tensors are moved to CPU first so tensors living on different
    devices can still be compared.
    """
    # Fix: stray copy-paste gutter numbers ("12", "13") removed from the body.
    return th.equal(a.cpu(), b.cpu())

def allclose(a, b, rtol=1e-4, atol=1e-4):
    """Return True iff ``a`` and ``b`` are elementwise close within tolerances.

    Both tensors are cast to float32 and moved to CPU before comparing, so
    mixed-dtype and mixed-device inputs work. Defaults (1e-4) are looser
    than ``torch.allclose``'s own defaults, matching this backend's tests.
    """
    # Fix: stray copy-paste gutter numbers ("15", "16") removed from the body.
    return th.allclose(a.float().cpu(),
            b.float().cpu(), rtol=rtol, atol=atol)

def randn(shape):
    """Sample a tensor of the given ``shape`` from a standard normal."""
    dims = tuple(shape)
    return th.randn(dims)

def attach_grad(x):
    """Prepare ``x`` to record gradients and return it.

    First call: flips on ``requires_grad``. Subsequent calls: clears the
    previously accumulated gradient in place.
    """
    if x.grad is None:
        return x.requires_grad_()
    x.grad.zero_()
    return x

def backward(x, head_gradient=None):
    """Backpropagate from ``x``, seeding with ``head_gradient`` if given."""
    x.backward(gradient=head_gradient)

def grad(x):
    """Fetch the gradient accumulated on ``x`` (None if none recorded)."""
    g = x.grad
    return g

def is_no_grad(x):
    """True when ``x`` has no gradient recorded, or an all-zero one."""
    g = x.grad
    return g is None or (g == 0).all()

def full(shape, fill_value, dtype, ctx):
    """Allocate a ``shape`` tensor on device ``ctx`` filled with ``fill_value``."""
    out = th.empty(shape, dtype=dtype, device=ctx)
    return out.fill_(fill_value)

def narrow_row_set(x, start, stop, new):
    """Overwrite rows ``[start, stop)`` of ``x`` with ``new``, in place."""
    rows = slice(start, stop)
    x[rows] = new

def sparse_to_numpy(x):
    """Densify sparse tensor ``x`` and return it as a NumPy array."""
    dense = x.to_dense()
    return dense.numpy()

def clone(x):
    """Return an independent copy of tensor ``x``."""
    return th.clone(x)

def reduce_sum(x):
    """Sum all elements of ``x`` into a scalar tensor."""
    return th.sum(x)

def softmax(x, dim):
    """Softmax of ``x`` along dimension ``dim``."""
    return x.softmax(dim)

def spmm(x, y):
    """Multiply sparse matrix ``x`` by dense matrix ``y``."""
    return th.sparse.mm(x, y)

def add(a, b):
    """Elementwise ``a + b``."""
    result = a + b
    return result

def sub(a, b):
    """Elementwise ``a - b``."""
    result = a - b
    return result

def mul(a, b):
    """Elementwise ``a * b``."""
    result = a * b
    return result

def div(a, b):
    """Elementwise ``a / b``."""
    result = a / b
    return result

def sum(x, dim):
    """Sum of ``x`` along dimension ``dim``. (Shadows builtin ``sum`` by design.)"""
    return th.sum(x, dim)

def max(x, dim):
    """Maximum values of ``x`` along ``dim`` (values only, indices dropped)."""
    values, _ = x.max(dim)
    return values

def min(x, dim):
    """Minimum values of ``x`` along ``dim`` (values only, indices dropped)."""
    values, _ = x.min(dim)
    return values

def prod(x, dim):
    """Product of ``x`` along dimension ``dim``."""
    return th.prod(x, dim)

class record_grad(object):
    """No-op gradient-recording context, kept for cross-backend API parity.

    PyTorch records gradients by default, so entering this context does
    nothing; it exists so backend-agnostic code can always write
    ``with record_grad(): ...``.
    """

    def __init__(self):
        """Nothing to configure."""

    def __enter__(self):
        """Enter the (no-op) context; yields None."""

    def __exit__(self, exc_type, exc_value, exc_traceback):
        """Exit the context; returns None so exceptions propagate."""

no_grad = th.no_grad