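# __init__.py: thin PyTorch-based helper functions (device selection,
# tensor comparison, tensor construction, and arithmetic/reduction wrappers).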
from __future__ import absolute_import

import torch as th

def cuda():
    """Return the torch device handle for CUDA device index 0.

    Building the handle does not itself check that a GPU is present.
    """
    first_gpu = th.device('cuda:0')
    return first_gpu

def is_cuda_available():
    """Report whether PyTorch can use a CUDA device in this process."""
    available = th.cuda.is_available()
    return available

def array_equal(a, b):
    """Return True when ``a`` and ``b`` have the same size and elements.

    Both tensors are copied to CPU before comparing, so inputs may live on
    different devices.
    """
    lhs, rhs = a.cpu(), b.cpu()
    return th.equal(lhs, rhs)

def allclose(a, b, rtol=1e-4, atol=1e-4):
    """Elementwise closeness test with relative/absolute tolerances.

    Both inputs are cast to float32 and moved to CPU first, so integer
    tensors and tensors on different devices can be compared.

    Args:
        a, b: tensors to compare.
        rtol: relative tolerance forwarded to ``th.allclose``.
        atol: absolute tolerance forwarded to ``th.allclose``.
    """
    lhs = a.float().cpu()
    rhs = b.float().cpu()
    return th.allclose(lhs, rhs, rtol=rtol, atol=atol)

def randn(shape):
    """Sample a tensor of standard-normal values with the given shape.

    Args:
        shape: the output shape. This is either an iterable of ints (tuple,
            list, ``torch.Size``) or a single int for a 1-D result. An empty
            shape ``()`` yields a 0-dim tensor.

    Returns:
        A new tensor filled by ``th.randn``.
    """
    try:
        size = tuple(shape)
    except TypeError:
        # A bare scalar such as ``5`` is not iterable; treat it as a 1-D length.
        size = (shape,)
    # Pass the size as one tuple instead of unpacking it. This keeps the empty
    # shape working, where ``th.randn(*())`` would receive no size argument.
    return th.randn(size)

def full(shape, fill_value, dtype, ctx):
    """Create a tensor of ``shape`` filled with ``fill_value``.

    Args:
        shape: size forwarded to ``th.full``.
        fill_value: value written to every element.
        dtype: torch dtype of the result.
        ctx: device on which the tensor is allocated.
    """
    return th.full(shape, fill_value, device=ctx, dtype=dtype)

def narrow_row_set(x, start, stop, new):
    """Overwrite rows ``start:stop`` of ``x`` with ``new``, in place.

    The slice is taken along the first dimension. Nothing is returned; ``x``
    is mutated.
    """
    rows = slice(start, stop)
    x[rows] = new

def sparse_to_numpy(x):
    """Densify the sparse tensor ``x`` and return it as a NumPy array.

    ``Tensor.numpy()`` only works on CPU tensors that do not require grad.
    The dense copy is therefore detached and moved to CPU first, matching
    how the comparison helpers in this module call ``.cpu()``. For a CPU
    tensor without grad, both calls are no-ops.
    """
    return x.to_dense().detach().cpu().numpy()

def clone(x):
    """Return a copy of tensor ``x`` that does not share storage with it."""
    return th.clone(x)

def reduce_sum(x):
    """Sum every element of ``x``, producing a 0-dim tensor."""
    return th.sum(x)

def softmax(x, dim):
    """Apply softmax to ``x`` along dimension ``dim``."""
    return x.softmax(dim)

def spmm(x, y):
    """Multiply sparse matrix ``x`` by dense matrix ``y`` via ``th.spmm``."""
    product = th.spmm(x, y)
    return product

def add(a, b):
    """Return ``a + b`` using the operands' own ``+`` semantics."""
    total = a + b
    return total

def sub(a, b):
    """Return ``a - b`` using the operands' own ``-`` semantics."""
    difference = a - b
    return difference

def mul(a, b):
    """Return the elementwise product ``a * b``."""
    product = a * b
    return product

def div(a, b):
    """Return the true division ``a / b``."""
    quotient = a / b
    return quotient

def sum(x, dim, keepdims=False):
    """Sum ``x`` over dimension ``dim``.

    Note: this name shadows the builtin ``sum`` inside this module.

    Args:
        x: input tensor.
        dim: dimension to reduce.
        keepdims: if True, keep the reduced dimension with size 1.

    Returns:
        The reduced tensor.
    """
    # ``keepdim`` is the documented keyword of ``Tensor.sum``. ``keepdims`` is
    # only accepted through PyTorch's NumPy-compatibility aliasing and is not
    # part of the documented signature.
    return x.sum(dim, keepdim=keepdims)

def max(x, dim):
    """Return the maximum values of ``x`` along ``dim``, dropping the indices."""
    values, _indices = x.max(dim)
    return values

def min(x, dim):
    """Return the minimum values of ``x`` along ``dim``, dropping the indices."""
    values, _indices = x.min(dim)
    return values

def prod(x, dim):
    """Multiply the elements of ``x`` together along dimension ``dim``."""
    return th.prod(x, dim)

def matmul(a, b):
    """Return the matrix product ``a @ b``."""
    product = a @ b
    return product

def dot(a, b):
    """Return the row-wise dot product of ``a`` and ``b``.

    Multiplies elementwise with this module's ``mul`` helper, then reduces
    the last dimension (``dim=-1``) with this module's ``sum`` helper.
    """
    elementwise = mul(a, b)
    return sum(elementwise, dim=-1)