from __future__ import absolute_import

import mxnet as mx
import mxnet.ndarray as nd
import numpy as np


def cuda():
    return mx.gpu()


def is_cuda_available():
    # TODO: Does MXNet have a convenient function to test GPU availability/compilation?
    try:
        # MXNet executes asynchronously, so force the GPU copy to complete;
        # otherwise a missing GPU may not raise until much later.
        nd.array([1, 2, 3], ctx=mx.gpu()).wait_to_read()
        return True
    except mx.MXNetError:
        return False
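

# A possible answer to the TODO above: mx.context.num_gpus() reports the
# number of visible GPUs (assuming a recent-enough MXNet release; which
# version introduced it is an assumption, not something this file confirms).
# Kept as a separate sketch rather than a replacement for the probe above.
def _is_cuda_available_via_num_gpus():
    return mx.context.num_gpus() > 0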


def array_equal(a, b):
    return nd.equal(a, b).asnumpy().all()


def allclose(a, b, rtol=1e-4, atol=1e-4):
    return np.allclose(a.asnumpy(), b.asnumpy(), rtol=rtol, atol=atol)


def randn(shape):
    return nd.random.randn(*shape)


def full(shape, fill_value, dtype, ctx):
    return nd.full(shape, fill_value, dtype=dtype, ctx=ctx)


def narrow_row_set(x, start, stop, new):
    # In-place assignment to the row slice [start, stop).
    x[start:stop] = new


def sparse_to_numpy(x):
    # CSRNDArray -> scipy CSR -> dense numpy array.
    return x.asscipy().toarray()


def clone(x):
    return x.copy()


def reduce_sum(x):
    return x.sum()


def softmax(x, dim):
    return nd.softmax(x, axis=dim)


def spmm(x, y):
    # Sparse-dense matrix product; nd.dot dispatches on the storage type of
    # its arguments, so a CSR lhs with a dense rhs works here.
    return nd.dot(x, y)


def add(a, b):
    return a + b


def sub(a, b):
    return a - b


def mul(a, b):
    return a * b


def div(a, b):
    return a / b


def sum(x, dim, keepdims=False):
    return x.sum(dim, keepdims=keepdims)


def max(x, dim):
    return x.max(dim)


def min(x, dim):
    return x.min(dim)


def prod(x, dim):
    return x.prod(dim)


def matmul(a, b):
    return nd.dot(a, b)


def dot(a, b):
    # Batched inner product: elementwise multiply, then sum over the last axis.
    return nd.sum(mul(a, b), axis=-1)


def abs(a):
    return nd.abs(a)
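

# A minimal smoke test of the wrappers above; the shapes, tolerances, and the
# use of tostype("csr") are illustrative choices, not part of any backend
# contract. Runs on CPU.
if __name__ == "__main__":
    x = randn((3, 4))
    # add/mul agree with scaling by a constant full() tensor.
    two = full((3, 4), 2, "float32", mx.cpu())
    assert allclose(add(x, x), mul(two, x))
    # softmax rows sum to one.
    assert allclose(sum(softmax(x, dim=1), dim=1), nd.ones(3))
    # CSR round-trip and the sparse-dense product match their dense versions.
    csr = x.tostype("csr")
    assert np.allclose(sparse_to_numpy(csr), x.asnumpy())
    assert allclose(spmm(csr, x.T), matmul(x, x.T))
    print("backend smoke test passed")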