from __future__ import absolute_import

import numpy as np
import mxnet as mx
import mxnet.ndarray as nd
import numbers

def data_type_dict():
    """Return a mapping from dtype name strings to their numpy dtype objects."""
    supported = [np.float16, np.float32, np.float64,
                 np.uint8, np.int8, np.int16, np.int32, np.int64]
    return {np.dtype(t).name: t for t in supported}

def cpu():
    """Return the MXNet CPU context used as the default device."""
    return mx.cpu()

def tensor(data, dtype=None):
    """Create an MXNet NDArray from `data`.

    MXNet always returns a float tensor regardless of the type of the
    elements inside `data`; as a workaround, infer an integer/float dtype
    from the first element when `dtype` is not given.

    Parameters
    ----------
    data : sequence or array-like
        Values to convert.
    dtype : numpy dtype, optional
        Target dtype. If None, int64 is used when the first element is
        integral; float32 otherwise (including for empty input).

    Returns
    -------
    mxnet.ndarray.NDArray
    """
    if dtype is None:
        # Guard the empty case: indexing data[0] would raise IndexError.
        if len(data) > 0 and isinstance(data[0], numbers.Integral):
            dtype = np.int64
        else:
            dtype = np.float32
    return nd.array(data, dtype=dtype)

def get_preferred_sparse_format():
    """Get the preferred sparse matrix format supported by the backend.

    Different backends have their preferred backend. This info is useful
    when constructing a sparse matrix. MXNet prefers CSR.
    """
    preferred = "csr"
    return preferred

def sparse_matrix(data, index, shape, force_format=False):
    """Build an MXNet CSR sparse matrix from COO or CSR inputs.

    Parameters
    ----------
    data : NDArray
        Non-zero values.
    index : tuple
        Either ('coo', coord) with coord[0]/coord[1] the row/col indices,
        or ('csr', indices, indptr).
    shape : sequence of int
        Dense shape of the matrix.
    force_format : bool
        If True, raise when the requested format is not CSR.

    Returns
    -------
    (spmat, convert_idx)
        The CSR matrix and, for COO input, the int64 permutation that maps
        the original entry order to CSR order (None for CSR input).
    """
    fmt = index[0]
    if fmt == 'csr':
        csr_indices = index[1]
        csr_indptr = index[2]
        spmat = nd.sparse.csr_matrix((data, csr_indices, csr_indptr),
                tuple(shape), ctx=data.context)
        # No conversion is required.
        return spmat, None
    if fmt == 'coo':
        if force_format:
            raise TypeError('MXNet backend only supports CSR format,'
                            ' but COO format is forced.')
        coord = index[1]
        # Build a helper CSR matrix whose values are the original entry
        # positions; after construction, its `data` field reveals how MXNet
        # permuted the entries into CSR order.
        # FIXME: cannot use int64
        positions = nd.arange(len(coord[0]), dtype=data.dtype,
                              ctx=coord[0].context)
        helper = nd.sparse.csr_matrix((positions, (coord[0], coord[1])),
                tuple(shape), ctx=data.context)
        convert_idx = nd.cast(helper.data, dtype='int64')
        # shuffle the data into CSR order
        shuffled = data[convert_idx]
        spmat = nd.sparse.csr_matrix((shuffled, helper.indices, helper.indptr),
                tuple(shape), ctx=data.context)
        return spmat, convert_idx
    raise TypeError('Invalid format: %s.' % fmt)

def sparse_matrix_indices(spmat):
    """Return the index arrays of a CSR matrix as ('csr', indices, indptr)."""
    return 'csr', spmat.indices, spmat.indptr

def is_tensor(obj):
    """Return True iff `obj` is an MXNet NDArray."""
    return isinstance(obj, nd.NDArray)

def shape(input):
    """Return the shape tuple of `input`.

    NOTE: the input cannot be a symbol.
    """
    return input.shape

def dtype(input):
    """Return the dtype of `input`.

    NOTE: the input cannot be a symbol.
    """
    return input.dtype

def ndim(input):
    """Return the number of dimensions of `input`."""
    return input.ndim

def context(input):
    """Return the MXNet context (device) that `input` lives on."""
    return input.context

def astype(input, ty):
    """Cast `input` to dtype `ty` and return the new tensor."""
    return nd.cast(input, ty)

def asnumpy(input):
    """Copy `input` to host memory as a numpy array."""
    return input.asnumpy()

def copy_to(input, ctx):
    """Return a copy of `input` placed on context `ctx`."""
    return input.as_in_context(ctx)

def sum(input, dim):
    """Reduce-sum `input` along axis `dim`."""
    return nd.sum(input, axis=dim)

def mean(input, dim):
    """Reduce-mean of `input` along axis `dim`."""
    return nd.mean(input, axis=dim)

def max(input, dim):
    """Reduce-max of `input` along axis `dim`."""
    return nd.max(input, axis=dim)

def cat(seq, dim):
    """Concatenate a sequence of tensors along existing axis `dim`."""
    return nd.concat(*seq, dim=dim)

def stack(seq, dim):
    """Stack a sequence of tensors along a new axis `dim`."""
    return nd.stack(*seq, axis=dim)

def split(x, sizes_or_sections, dim):
    """Split tensor `x` along axis `dim`.

    Parameters
    ----------
    x : NDArray
        Tensor to split.
    sizes_or_sections : int, list or numpy.ndarray
        Either the number of equal sections, or per-section sizes.
    dim : int
        Axis to split along.

    Returns
    -------
    list of NDArray
    """
    # Idiom: one isinstance call with a tuple of types.
    if isinstance(sizes_or_sections, (list, np.ndarray)):
        # TODO: fallback to numpy is unfortunate
        np_arr = x.asnumpy()
        # Convert per-section sizes into split offsets for np.split.
        indices = np.cumsum(sizes_or_sections)[:-1]
        res = np.split(np_arr, indices, axis=dim)
        return [tensor(arr, dtype=x.dtype) for arr in res]
    else:
        return nd.split(x, sizes_or_sections, axis=dim)

def gather_row(data, row_index):
    """Select the rows of `data` given by `row_index`."""
    # MXNet workaround for empty row index
    if len(row_index) == 0:
        return data[0:0]
    if isinstance(row_index, nd.NDArray):
        return nd.take(data, row_index)
    return data[row_index,]

def narrow_row(data, start, stop):
    """Return the contiguous row slice data[start:stop]."""
    return data[start:stop]

def scatter_row(data, row_index, value):
    """Return a copy of `data` with the rows in `row_index` replaced by `value`."""
    return mx.nd.contrib.index_copy(data, row_index, value)

def scatter_row_inplace(data, row_index, value):
    """Overwrite the rows of `data` selected by `row_index` with `value`, in place."""
    data[row_index] = value

def squeeze(input, dim):
    """Remove axis `dim` (of size 1) from `input`."""
    return nd.squeeze(input, axis=dim)

def unsqueeze(input, dim):
    """Insert a new axis of size 1 at position `dim`."""
    return nd.expand_dims(input, axis=dim)

def reshape(input, shape):
    """Reshape `input` to `shape`.

    NOTE: the input cannot be a symbol.
    """
    return nd.reshape(input, shape)

def zeros(shape, dtype, ctx):
    """Return a zero-filled tensor of the given shape/dtype on context `ctx`."""
    return nd.zeros(shape, dtype=dtype, ctx=ctx)

def zeros_like(input):
    """Return a zero-filled tensor with the same shape/dtype/context as `input`."""
    return nd.zeros_like(input)

def ones(shape, dtype, ctx):
    """Return a one-filled tensor of the given shape/dtype on context `ctx`."""
    return nd.ones(shape, dtype=dtype, ctx=ctx)

def spmm(x, y):
    """Sparse-dense matrix multiply via nd.dot."""
    return nd.dot(x, y)

def unsorted_1d_segment_sum(input, seg_id, n_segs, dim):
    """Sum rows of `input` grouped by `seg_id` into `n_segs` output rows.

    Parameters
    ----------
    input : NDArray
        Tensor whose first axis is summed per segment.
    seg_id : NDArray
        Segment id for each row of `input`.
    n_segs : int
        Number of output segments.
    dim : int
        Must be 0; only the first dimension is supported.

    Returns
    -------
    NDArray of shape (n_segs,) + input.shape[1:]
    """
    # TODO: support other dimensions
    assert dim == 0, 'MXNet only supports segment sum on first dimension'

    # Use SPMV to simulate segment sum: y = W @ input where W[s, i] = 1
    # iff row i belongs to segment s.
    ctx = input.context
    n_inputs = input.shape[0]
    input_shape_suffix = input.shape[1:]
    # Flatten trailing dims so the dot is a plain matrix product.
    input = input.reshape(n_inputs, -1)
    n_range = nd.arange(n_inputs, dtype='int64').as_in_context(ctx)
    # NOTE(review): the ones vector uses MXNet's default dtype rather than
    # input.dtype — presumably adequate for the dot; confirm for float64 input.
    w_nnz = nd.ones(n_inputs).as_in_context(ctx)
    # (Removed an unused `w_nid = nd.stack(seg_id, n_range)` local.)
    w = nd.sparse.csr_matrix((w_nnz, (seg_id, n_range)), (n_segs, n_inputs))
    w = w.as_in_context(ctx)
    y = nd.dot(w, input)
    y = nd.reshape(y, (n_segs,) + input_shape_suffix)
    return y

def unsorted_1d_segment_mean(input, seg_id, n_segs, dim):
    """Mean of rows of `input` grouped by `seg_id`: segment sum / segment count."""
    # TODO: support other dimensions
    assert dim == 0, 'MXNet only supports segment mean on first dimension'

    ones = nd.ones_like(seg_id).astype(input.dtype)
    counts = unsorted_1d_segment_sum(ones, seg_id, n_segs, 0)
    # Clamp counts so empty segments divide by 1 instead of 0.
    counts = nd.clip(counts, a_min=1, a_max=np.inf)
    summed = unsorted_1d_segment_sum(input, seg_id, n_segs, dim)
    # Broadcast the per-segment counts over all trailing dimensions.
    return summed / counts.reshape((-1,) + (1,) * (summed.ndim - 1))

def boolean_mask(input, mask):
    """Select the rows of `input` where `mask` is true."""
    return mx.contrib.nd.boolean_mask(input, mask)

def equal(x, y):
    """Elementwise equality; delegates to the operands' __eq__."""
    return x == y

def logical_not(input):
    """Elementwise logical negation of `input`."""
    return nd.logical_not(input)

def unique(input):
    """Return the unique elements of `input`, preserving dtype and context."""
    # TODO: fallback to numpy is unfortunate
    host = np.unique(input.asnumpy())
    return nd.array(host, ctx=input.context, dtype=input.dtype)

def full_1d(length, fill_value, dtype, ctx):
    """Return a 1-D tensor of `length` elements all equal to `fill_value`."""
    return nd.full((length,), fill_value, dtype=dtype, ctx=ctx)

def nonzero_1d(input):
    """Return the indices of the non-zero entries of a 1-D tensor."""
    # TODO: fallback to numpy is unfortunate
    idx = np.nonzero(input.asnumpy())[0]
    return nd.array(idx, ctx=input.context, dtype=input.dtype)

def sort_1d(input):
    """Return (sorted values, int64 argsort indices) of a 1-D tensor in ascending order."""
    # TODO: this isn't an ideal implementation.
    sorted_val = nd.sort(input, axis=None, is_ascend=True)
    sorted_idx = nd.cast(nd.argsort(input, is_ascend=True), dtype='int64')
    return sorted_val, sorted_idx

def arange(start, stop):
    """Return an int64 range tensor [start, stop)."""
    return nd.arange(start, stop, dtype=np.int64)

def rand_shuffle(arr):
    """Return `arr` with its elements randomly permuted."""
    return mx.nd.random.shuffle(arr)

def zerocopy_to_dlpack(arr):
    """Export `arr` as a read-only DLPack capsule without copying."""
    return arr.to_dlpack_for_read()

def zerocopy_from_dlpack(dlpack_arr):
    """Wrap a DLPack capsule as an MXNet NDArray without copying."""
    return nd.from_dlpack(dlpack_arr)

def zerocopy_to_numpy(arr):
    """Convert `arr` to a numpy array.

    NOTE: not zerocopy — MXNet copies to host memory here.
    """
    return arr.asnumpy()

def zerocopy_from_numpy(np_data):
    """Convert a numpy array to an MXNet NDArray, preserving dtype.

    NOTE: not zerocopy — nd.array copies the data.
    """
    return nd.array(np_data, dtype=np_data.dtype)