"docs/source/models/inception_quant.rst" did not exist on "8edd920d4d8bc07d215c16261cbf8ea3edd37b02"
__init__.py 1.86 KB
Newer Older
1
2
3
4
5
from __future__ import absolute_import

import torch as th

def cuda():
    """Return the torch device object for CUDA GPU index 0."""
    return th.device('cuda', 0)

def is_cuda_available():
    """Return whether PyTorch reports a usable CUDA device (``th.cuda.is_available()``)."""
    available = th.cuda.is_available()
    return available

def array_equal(a, b):
    """Return True iff ``a`` and ``b`` have the same size and the same elements.

    Both tensors are copied to CPU first, so the comparison works regardless
    of which device each one lives on.
    """
    lhs = a.cpu()
    return lhs.equal(b.cpu())

def allclose(a, b, rtol=1e-4, atol=1e-4):
    """Return whether ``a`` and ``b`` are elementwise close within tolerances.

    Both inputs are cast to float32 and copied to CPU before comparing, so
    integer tensors and tensors on other devices can be compared directly.

    Args:
        a, b: tensors to compare.
        rtol: relative tolerance.
        atol: absolute tolerance.
    """
    lhs = a.float().cpu()
    rhs = b.float().cpu()
    return lhs.allclose(rhs, rtol=rtol, atol=atol)

def randn(shape):
    """Return a tensor of the given ``shape`` filled with standard-normal samples.

    Args:
        shape: a tuple / list / ``torch.Size`` of ints (the empty tuple gives a
            0-d tensor), or a single int for a 1-D tensor.
    """
    if isinstance(shape, int):
        shape = (shape,)
    # Pass the size as one sequence instead of unpacking it: unpacking an
    # empty shape would call ``th.randn()`` with no size argument at all.
    return th.randn(tuple(shape))

def attach_grad(x):
    """Prepare ``x`` for gradient accumulation and return it.

    If ``x`` has no gradient yet, ``requires_grad`` is switched on (in place).
    Otherwise the existing gradient is reset to zero in place.
    """
    if x.grad is None:
        return x.requires_grad_()
    x.grad.zero_()
    return x

def backward(x, head_gradient=None):
    """Run autograd backward starting from ``x``.

    Args:
        x: tensor to back-propagate from.
        head_gradient: optional gradient w.r.t. ``x``; ``None`` is forwarded
            unchanged to ``x.backward``. A 1-element vector is accepted for a
            0-d ``x`` and reshaped to match.
    """
    # Check the rank before indexing ``shape[0]`` so a 0-d head gradient does
    # not raise IndexError. Only collapse a (1,) head gradient when ``x`` is
    # itself 0-d: if ``x`` has shape (1,), the shapes already agree.
    if (head_gradient is not None and x.dim() == 0
            and head_gradient.dim() == 1 and head_gradient.shape[0] == 1):
        # Fix for torch 1.3.1: the head gradient must match ``x``'s shape.
        # ``reshape`` keeps dtype and device and avoids the host round-trip
        # that ``.item()`` would incur.
        head_gradient = head_gradient.reshape(())
    x.backward(head_gradient)

def grad(x):
    """Return the gradient stored on ``x`` (``x.grad``), or ``None`` if absent."""
    accumulated = x.grad
    return accumulated

def is_no_grad(x):
    """Return True when ``x`` has no gradient or its gradient is all zeros.

    The result is always a Python ``bool``. ``(g == 0).all()`` on its own
    would yield a 0-d bool tensor.
    """
    g = x.grad
    if g is None:
        return True
    return bool((g == 0).all())

def full(shape, fill_value, dtype, ctx):
    """Create a ``dtype`` tensor of ``shape`` filled with ``fill_value``.

    ``ctx`` is passed to torch as the ``device``.
    """
    return th.full(size=shape, fill_value=fill_value, device=ctx, dtype=dtype)

def narrow_row_set(x, start, stop, new):
    """Overwrite rows ``start:stop`` of ``x`` in place with ``new``; returns None."""
    rows = slice(start, stop)
    x[rows] = new

def sparse_to_numpy(x):
    """Densify tensor ``x`` and return it as a NumPy array.

    The dense copy is detached from autograd and moved to CPU first, because
    ``Tensor.numpy()`` refuses tensors that require grad or live off-CPU.
    """
    return x.to_dense().detach().cpu().numpy()

def clone(x):
    """Return a copy of ``x`` backed by new storage."""
    return th.clone(x)

def reduce_sum(x):
    """Sum every element of ``x`` into a 0-d tensor."""
    return th.sum(x)

def softmax(x, dim):
    """Apply softmax to ``x`` along dimension ``dim``."""
    return x.softmax(dim)

def spmm(x, y):
    """Return the matrix product computed by ``th.spmm(x, y)``.

    Presumably ``x`` is a sparse matrix and ``y`` a dense one (per the name);
    confirm against callers.
    """
    product = th.spmm(x, y)
    return product

def add(a, b):
    """Return ``a + b``."""
    total = a + b
    return total

def sub(a, b):
    """Return ``a - b``."""
    difference = a - b
    return difference

def mul(a, b):
    """Return the elementwise product ``a * b``."""
    product = a * b
    return product

def div(a, b):
    """Return the true (non-floor) quotient ``a / b``."""
    quotient = a / b
    return quotient

def sum(x, dim):
    """Sum ``x`` over dimension ``dim`` (that dimension is removed)."""
    return th.sum(x, dim)

def max(x, dim):
    """Return the maximum values of ``x`` along ``dim``; the argmax indices are dropped."""
    values, _indices = th.max(x, dim)
    return values

def min(x, dim):
    """Return the minimum values of ``x`` along ``dim``; the argmin indices are dropped."""
    values, _indices = th.min(x, dim)
    return values

def prod(x, dim):
    """Multiply the elements of ``x`` together along dimension ``dim``."""
    return th.prod(x, dim)

def matmul(a, b):
    """Return the matrix product of ``a`` and ``b`` (``th.matmul`` semantics)."""
    return th.matmul(a, b)

def dot(a, b):
    """Dot product of ``a`` and ``b`` over their last dimension.

    Built from this module's ``mul`` and ``sum`` helpers: elementwise product,
    then a reduction over ``dim=-1``.
    """
    elementwise = mul(a, b)
    return sum(elementwise, dim=-1)

class record_grad(object):
    """Context manager that does nothing on entry or exit.

    ``with record_grad(): ...`` simply runs its body. ``__enter__`` yields
    ``None``, and exceptions from the body propagate unchanged. Presumably this
    mirrors a backend where gradient recording must be enabled explicitly;
    here, tensors track gradients through ``requires_grad`` alone (TODO
    confirm against callers).
    """

    def __init__(self):
        super(record_grad, self).__init__()

    def __enter__(self):
        return None

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # A falsy return value tells Python not to suppress the exception.
        return None

# Alias for PyTorch's gradient-disabling context manager; ``th.no_grad`` is
# re-exported from ``th.autograd``, so this binds the same object.
no_grad = th.autograd.no_grad