# __init__.py: backend helper functions implemented with PyTorch.
from __future__ import absolute_import

import torch as th

def cuda():
    """Return a ``torch.device`` handle for the CUDA device.

    Returns:
        ``torch.device("cuda")``.
    """
    return th.device("cuda")


def is_cuda_available():
    """Report whether PyTorch can use a CUDA device in this process.

    Returns:
        The boolean from ``torch.cuda.is_available()``.
    """
    available = th.cuda.is_available()
    return available


def array_equal(a, b):
    """Check exact equality of two tensors, regardless of their device.

    Both inputs are copied to the CPU first, so tensors on different devices
    can be compared.

    Args:
        a: First tensor.
        b: Second tensor.

    Returns:
        ``True`` if ``torch.equal`` reports equal size and elements.
    """
    return th.equal(a.cpu(), b.cpu())


def allclose(a, b, rtol=1e-4, atol=1e-4):
    """Check that two tensors are element-wise close within tolerances.

    Both inputs are cast to float32 (``.float()``) and moved to the CPU, so
    integer/float and cross-device inputs can be compared.

    Args:
        a: First tensor.
        b: Second tensor.
        rtol: Relative tolerance forwarded to ``torch.allclose``.
        atol: Absolute tolerance forwarded to ``torch.allclose``.

    Returns:
        ``True`` if every element pair satisfies
        ``|a - b| <= atol + rtol * |b|``.
    """
    return th.allclose(a.float().cpu(), b.float().cpu(), rtol=rtol, atol=atol)


def randn(shape):
    """Sample a tensor of standard-normal values with the given shape.

    Args:
        shape: Iterable of dimension sizes; each entry becomes a positional
            argument to ``torch.randn``.
    """
    dims = tuple(shape)
    return th.randn(*dims)


def full(shape, fill_value, dtype, ctx):
    """Build a tensor of ``shape`` where every element equals ``fill_value``.

    Args:
        shape: Size of the output tensor.
        fill_value: Value written to every element.
        dtype: Element dtype of the result.
        ctx: Device the tensor is allocated on (passed as ``device``).
    """
    return th.full(shape, fill_value, device=ctx, dtype=dtype)


def narrow_row_set(x, start, stop, new):
    """Overwrite rows ``start`` (inclusive) to ``stop`` (exclusive) of ``x`` in place.

    Args:
        x: Tensor that is mutated.
        start: First row index to overwrite.
        stop: Row index one past the last row to overwrite.
        new: Values assigned into the row range.

    Returns:
        ``None``; the update happens on ``x`` itself.
    """
    rows = slice(start, stop)
    x[rows] = new


def sparse_to_numpy(x):
    """Densify a sparse tensor and return it as a NumPy array.

    The dense result is detached from autograd and moved to the CPU before
    conversion. ``Tensor.numpy()`` refuses CUDA tensors and tensors that
    require grad, and ``array_equal``/``allclose`` in this module likewise
    normalize to the CPU.

    Args:
        x: Sparse tensor on any device.

    Returns:
        A ``numpy.ndarray`` with the dense contents of ``x``.
    """
    return x.to_dense().detach().cpu().numpy()


def clone(x):
    """Return an independent copy of tensor ``x`` (via ``x.clone()``)."""
    duplicate = x.clone()
    return duplicate


def reduce_sum(x):
    """Sum every element of ``x`` into a zero-dimensional tensor."""
    total = x.sum()
    return total


def softmax(x, dim):
    """Apply softmax to ``x`` along dimension ``dim`` (``torch.softmax``)."""
    return th.softmax(x, dim=dim)


def spmm(x, y):
    """Multiply sparse matrix ``x`` by dense matrix ``y`` with ``torch.spmm``."""
    product = th.spmm(x, y)
    return product


def add(a, b):
    """Return ``a + b`` using the operands' own ``+`` operator."""
    result = a + b
    return result


def sub(a, b):
    """Return ``a - b`` using the operands' own ``-`` operator."""
    difference = a - b
    return difference


def mul(a, b):
    """Return the element-wise product ``a * b``."""
    product = a * b
    return product


def div(a, b):
    """Return the true-division quotient ``a / b``."""
    quotient = a / b
    return quotient


def sum(x, dim, keepdims=False):
    """Sum ``x`` over dimension(s) ``dim``.

    ``keepdims`` keeps this module's NumPy-style name. It is forwarded as
    torch's documented ``keepdim`` keyword rather than relying on torch's
    undocumented ``keepdims`` alias.

    Args:
        x: Input tensor.
        dim: Dimension (or dimensions) to reduce.
        keepdims: If ``True``, reduced dimensions are retained with size 1.

    Returns:
        The reduced tensor.
    """
    return x.sum(dim, keepdim=keepdims)


def max(x, dim):
    """Return the maximum values of ``x`` along ``dim``, without their indices."""
    values, _indices = x.max(dim)
    return values


def min(x, dim):
    """Return the minimum values of ``x`` along ``dim``, without their indices."""
    values, _indices = x.min(dim)
    return values


def prod(x, dim):
    """Multiply the elements of ``x`` together along dimension ``dim``."""
    return x.prod(dim=dim)


def matmul(a, b):
    """Return the matrix product of ``a`` and ``b`` (the ``@`` operator)."""
    product = a @ b
    return product


def dot(a, b):
    """Compute dot products of ``a`` and ``b`` along their last dimension.

    Multiplies element-wise, then sums over ``dim=-1``.
    """
    elementwise = a * b
    return elementwise.sum(-1)


def abs(a):
    """Return the element-wise absolute value of tensor ``a``."""
    magnitude = a.abs()
    return magnitude


def seed(a):
    """Seed PyTorch's global random number generator with ``a``.

    Returns:
        The generator object returned by ``torch.manual_seed``.
    """
    generator = th.manual_seed(a)
    return generator