__init__.py 1.83 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
from __future__ import absolute_import

import numpy as np
import mxnet as mx
import mxnet.ndarray as nd
import mxnet.autograd as autograd

def cuda():
    """Return the context object produced by ``mx.gpu()``."""
    gpu_context = mx.gpu()
    return gpu_context

11
12
13
14
15
16
17
18
def is_cuda_available():
    """Report whether a small array can be created on ``mx.gpu()``.

    Returns:
        bool: ``True`` when the probe allocation succeeds, ``False`` when
        MXNet raises ``MXNetError`` while creating it.
    """
    # TODO: Does MXNet have a convenient function to test GPU availability/compilation?
    # NOTE(review): newer MXNet releases appear to expose
    # ``mx.context.num_gpus()``; confirm the minimum supported version
    # before switching to it.
    try:
        # Only the success or failure of the allocation matters; the array
        # itself is discarded.
        nd.array([1, 2, 3], ctx=mx.gpu())
    except mx.MXNetError:
        return False
    else:
        return True

19
20
21
def array_equal(a, b):
    """Return whether ``nd.equal(a, b)`` is truthy at every position."""
    elementwise = nd.equal(a, b)
    return elementwise.asnumpy().all()

22
23
def allclose(a, b, rtol=1e-4, atol=1e-4):
    """Compare ``a`` and ``b`` with ``numpy.allclose``.

    Both inputs are converted through their ``asnumpy()`` method first;
    ``rtol`` and ``atol`` are forwarded unchanged.
    """
    lhs = a.asnumpy()
    rhs = b.asnumpy()
    return np.allclose(lhs, rhs, rtol=rtol, atol=atol)
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55

def randn(shape):
    """Call ``nd.random.randn`` with the entries of ``shape`` as positional arguments."""
    dims = tuple(shape)
    return nd.random.randn(*dims)

def attach_grad(x):
    """Invoke ``x.attach_grad()`` and hand the same object back."""
    tensor = x
    tensor.attach_grad()
    return tensor

def backward(x, head_gradient=None):
    """Call ``x.backward`` with ``head_gradient`` as its single positional argument.

    Nothing is returned.
    """
    run_backward = x.backward
    run_backward(head_gradient)

def grad(x):
    """Return the ``grad`` attribute of ``x``."""
    gradient = x.grad
    return gradient

def is_no_grad(x):
    """Check that ``x`` has no non-zero entries.

    Returns the result of comparing the count of ``x != 0`` against zero.
    """
    nonzero_mask = x != 0
    nonzero_count = nonzero_mask.sum()
    return nonzero_count == 0

def full(shape, fill_value, dtype, ctx):
    """Forward to ``nd.full``, passing ``dtype`` and ``ctx`` by keyword."""
    options = {"dtype": dtype, "ctx": ctx}
    return nd.full(shape, fill_value, **options)

def narrow_row_set(x, start, stop, new):
    """Assign ``new`` to the slice ``x[start:stop]`` in place.

    Nothing is returned.
    """
    rows = slice(start, stop)
    x[rows] = new

def sparse_to_numpy(x):
    """Convert ``x`` to a dense NumPy array via its SciPy representation.

    Args:
        x: Object exposing ``asscipy()``, whose result provides ``toarray()``.

    Returns:
        The dense array produced by ``toarray()``.
    """
    # ``toarray()`` yields an ndarray directly, avoiding the intermediate
    # ``np.matrix`` that ``todense().A`` went through.
    return x.asscipy().toarray()

def clone(x):
    """Return the result of ``x.copy()``."""
    duplicate = x.copy()
    return duplicate

def reduce_sum(x):
    """Return ``x.sum()`` called with no arguments."""
    total = x.sum()
    return total

56
def softmax(x, dim):
    """Apply ``nd.softmax`` to ``x`` along axis ``dim``."""
    normalized = nd.softmax(x, axis=dim)
    return normalized
58

59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
def spmm(x, y):
    """Return ``nd.dot(x, y)``."""
    product = nd.dot(x, y)
    return product

def add(a, b):
    """Return ``a + b``."""
    total = a + b
    return total

def sub(a, b):
    """Return ``a - b``."""
    difference = a - b
    return difference

def mul(a, b):
    """Return ``a * b``."""
    product = a * b
    return product

def div(a, b):
    """Return ``a / b``."""
    quotient = a / b
    return quotient

74
75
def sum(x, dim, keepdims=False):
    """Return ``x.sum(dim, keepdims=keepdims)``.

    ``dim`` is passed positionally and ``keepdims`` by keyword.
    """
    reduced = x.sum(dim, keepdims=keepdims)
    return reduced
76
77
78
79
80
81
82
83
84
85

def max(x, dim):
    """Return ``x.max(dim)``."""
    largest = x.max(dim)
    return largest

def min(x, dim):
    """Return ``x.min(dim)``."""
    smallest = x.min(dim)
    return smallest

def prod(x, dim):
    """Return ``x.prod(dim)``."""
    product = x.prod(dim)
    return product

86
87
88
def matmul(a, b):
    """Return ``nd.dot(a, b)``."""
    product = nd.dot(a, b)
    return product

89
90
91
def dot(a, b):
    """Multiply ``a`` and ``b`` element-wise and sum the result over the last axis."""
    elementwise = a * b
    return nd.sum(elementwise, axis=-1)

92
93
94
95
96
97
98
99
100
101
102
103
# Alias: exposes ``mxnet.autograd.record`` under the name ``record_grad``.
record_grad = autograd.record


class no_grad(object):
    """Context manager that performs no action on entry or exit."""

    def __init__(self):
        """Take no arguments and store no state."""

    def __enter__(self):
        """Enter the block; an ``as`` target receives ``None``."""
        return None

    def __exit__(self, exc_type, exc_value, exc_traceback):
        """Leave the block; returns ``None``, so exceptions are not suppressed."""
        return None