test_numpy_support.py 17 KB
Newer Older
dugupeiwen's avatar
dugupeiwen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
"""
Test helper functions from numba.np.numpy_support.
"""


import sys
from itertools import product

import numpy as np

import unittest
from numba.core import types
from numba.core.errors import NumbaNotImplementedError
from numba.tests.support import TestCase
from numba.tests.enum_usecases import Shake, RequestError
from numba.np import numpy_support


class TestFromDtype(TestCase):
    """
    Tests for numpy_support.from_dtype() and numpy_support.as_dtype().
    """

    def test_number_types(self):
        """
        Test from_dtype() and as_dtype() with the various scalar number types.
        """
        f = numpy_support.from_dtype

        def check(typechar, numba_type):
            # The dtype must map to the very same Numba type object, with
            # or without an explicit native byte order marker ('='), and
            # as_dtype() must give the dtype back.
            dtype = np.dtype(typechar)
            self.assertIs(f(dtype), numba_type)
            self.assertIs(f(np.dtype('=' + typechar)), numba_type)
            self.assertEqual(dtype, numpy_support.as_dtype(numba_type))

        check('?', types.bool_)
        check('f', types.float32)
        check('f4', types.float32)
        check('d', types.float64)
        check('f8', types.float64)

        check('F', types.complex64)
        check('c8', types.complex64)
        check('D', types.complex128)
        check('c16', types.complex128)

        check('O', types.pyobject)

        check('b', types.int8)
        check('i1', types.int8)
        check('B', types.uint8)
        check('u1', types.uint8)

        check('h', types.int16)
        check('i2', types.int16)
        check('H', types.uint16)
        check('u2', types.uint16)

        check('i', types.int32)
        check('i4', types.int32)
        check('I', types.uint32)
        check('u4', types.uint32)

        check('q', types.int64)
        check('Q', types.uint64)
        # Named dtypes map to the Numba type of the same name
        for name in ('int8', 'uint8', 'int16', 'uint16', 'int32', 'uint32',
                     'int64', 'uint64', 'intp', 'uintp'):
            self.assertIs(f(np.dtype(name)), getattr(types, name))

        # Non-native byte orders are unsupported (except for 1-byte types)
        foreign_order = '>' if sys.byteorder == 'little' else '<'
        for letter in 'hHiIlLqQfdFD':
            self.assertRaises(NumbaNotImplementedError, f,
                              np.dtype(foreign_order + letter))

    def test_string_types(self):
        """
        Test from_dtype() and as_dtype() with the character string types.
        """
        def check(typestring, numba_type):
            # Character sequence types are compared by equality, in both
            # conversion directions
            dtype = np.dtype(typestring)
            self.assertEqual(numpy_support.from_dtype(dtype), numba_type)
            self.assertEqual(dtype, numpy_support.as_dtype(numba_type))

        check('S10', types.CharSeq(10))
        # 'a' is a legacy alias of 'S' which NumPy 2.0 removed; np.dtype()
        # rejects it there, so only exercise it on older versions.
        if np.lib.NumpyVersion(np.__version__) < '2.0.0':
            check('a11', types.CharSeq(11))
        check('U12', types.UnicodeCharSeq(12))

    def check_datetime_types(self, letter, nb_class):
        """
        Check the conversion of the unit-less dtype spelled *letter*
        to and from the Numba type class *nb_class*.
        """
        def check(dtype, numba_type, code):
            tp = numpy_support.from_dtype(dtype)
            self.assertEqual(tp, numba_type)
            self.assertEqual(tp.unit_code, code)
            # Both the expected type and the converted one must map back
            # to the original dtype
            self.assertEqual(numpy_support.as_dtype(numba_type), dtype)
            self.assertEqual(numpy_support.as_dtype(tp), dtype)

        # Unit-less ("generic") type; 14 is the unit code expected for it
        check(np.dtype(letter), nb_class(''), 14)

    def test_datetime_types(self):
        """
        Test from_dtype() and as_dtype() with the datetime types.
        """
        self.check_datetime_types('M', types.NPDatetime)

    def test_timedelta_types(self):
        """
        Test from_dtype() and as_dtype() with the timedelta types.
        """
        self.check_datetime_types('m', types.NPTimedelta)

    def test_struct_types(self):
        """
        Test from_dtype() with structured dtypes, packed and aligned.
        """
        def check(dtype, fields, size, aligned):
            tp = numpy_support.from_dtype(dtype)
            self.assertIsInstance(tp, types.Record)
            # Only check for dtype equality, as the Numba type may be interned
            self.assertEqual(tp.dtype, dtype)
            self.assertEqual(tp.fields, fields)
            self.assertEqual(tp.size, size)
            self.assertEqual(tp.aligned, aligned)

        # Packed layout: 'b' directly follows the 2-byte 'a'
        dtype = np.dtype([('a', np.int16), ('b', np.int32)])
        check(dtype,
              fields={'a': (types.int16, 0, None, None),
                      'b': (types.int32, 2, None, None)},
              size=6, aligned=False)

        # Aligned layout: 'b' is moved to offset 4 and the size grows to 8
        dtype = np.dtype([('a', np.int16), ('b', np.int32)], align=True)
        check(dtype,
              fields={'a': (types.int16, 0, None, None),
                      'b': (types.int32, 4, None, None)},
              size=8, aligned=True)

        # Mixing a number and a fixed-size byte string
        dtype = np.dtype([('m', np.int32), ('n', 'S5')])
        check(dtype,
              fields={'m': (types.int32, 0, None, None),
                      'n': (types.CharSeq(5), 4, None, None)},
              size=9, aligned=False)

    def test_enum_type(self):
        """
        Test as_dtype() with enum member types: the dtype of the
        underlying type must be recovered.
        """

        def check(base_inst, enum_def, type_class):
            np_dt = np.dtype(base_inst)
            nb_ty = numpy_support.from_dtype(np_dt)
            inst = type_class(enum_def, nb_ty)
            recovered = numpy_support.as_dtype(inst)
            self.assertEqual(np_dt, recovered)

        dts = [np.float64, np.int32, np.complex128, np.bool_]
        enums = [Shake, RequestError]

        for dt, enum in product(dts, enums):
            check(dt, enum, types.EnumMember)

        for dt, enum in product(dts, enums):
            check(dt, enum, types.IntEnumMember)


class ValueTypingTestBase(object):
    """
    Shared checks for functions typing values.  Also used by test_special.
    """

    def check_number_values(self, func):
        """
        Exercise *func*() on scalar numbers, Python and Numpy ones alike.
        """
        # Python integers may be typed as either of these two widths
        accepted_ints = (types.int32, types.int64)
        for py_int in (1, 2**31 - 1, -2**31):
            self.assertIn(func(py_int), accepted_ints)
        # The other Python scalars each have a single expected type
        py_expectations = [
            (1.0, types.float64),
            (1.0j, types.complex128),
            (True, types.bool_),
            (False, types.bool_),
        ]
        for py_value, expected in py_expectations:
            self.assertIs(func(py_value), expected)
        # Numpy scalars: a default-constructed instance of each scalar
        # class must get the Numba type bearing the same name
        np_names = (
            'int8', 'uint8', 'int16', 'uint16', 'int32', 'uint32',
            'int64', 'uint64', 'intc', 'uintc', 'intp', 'uintp',
            'float32', 'float64', 'complex64', 'complex128',
            'bool_',
        )
        for np_name in np_names:
            scalar = getattr(np, np_name)()
            self.assertIs(func(scalar), getattr(types, np_name))

    def _base_check_datetime_values(self, func, np_type, nb_type):
        """
        Exercise *func*() on *np_type* values of every unit, expecting
        *nb_type* instances of the same unit.
        """
        units = ('', 'Y', 'M', 'D', 'h', 'm', 's',
                 'ms', 'us', 'ns', 'ps', 'fs', 'as')
        for unit in units:
            # The empty unit denotes the "generic" datetime / timedelta,
            # which is obtained from a NaT literal
            value = np_type(3, unit) if unit else np_type('Nat')
            # Comparing against the unit-carrying type ensures the unit
            # hasn't been lost
            self.assertEqual(func(value), nb_type(unit))

    def check_datetime_values(self, func):
        """
        Exercise *func*() on np.datetime64 values.
        """
        self._base_check_datetime_values(func, np.datetime64, types.NPDatetime)

    def check_timedelta_values(self, func):
        """
        Exercise *func*() on np.timedelta64 values.
        """
        self._base_check_datetime_values(func, np.timedelta64,
                                         types.NPTimedelta)


class TestArrayScalars(ValueTypingTestBase, TestCase):
    """
    Run the shared value typing checks against map_arrayscalar_type().
    """

    def test_number_values(self):
        """
        Check map_arrayscalar_type() on scalar number values.
        """
        self.check_number_values(numpy_support.map_arrayscalar_type)

    def test_datetime_values(self):
        """
        Check map_arrayscalar_type() on np.datetime64 values.
        """
        typer = numpy_support.map_arrayscalar_type
        self.check_datetime_values(typer)
        # A datetime64 whose unit has a non-one factor shouldn't be supported
        scaled = np.datetime64('2014', '10Y')
        self.assertRaises(NotImplementedError, typer, scaled)

    def test_timedelta_values(self):
        """
        Check map_arrayscalar_type() on np.timedelta64 values.
        """
        typer = numpy_support.map_arrayscalar_type
        self.check_timedelta_values(typer)
        # A timedelta64 whose unit has a non-one factor shouldn't be supported
        scaled = np.timedelta64(10, '10Y')
        self.assertRaises(NotImplementedError, typer, scaled)


class FakeUFunc(object):
    """
    Minimal stand-in for a Numpy ufunc, carrying only the attributes
    assigned below: *types*, *nin*, *nout* and *ntypes*.

    *types* is a non-empty sequence of loop signature strings of the form
    "<inputs>-><outputs>", with one character per operand (e.g. 'dd->d').
    The operand counts are taken from the first signature and every
    signature is asserted to agree with them.
    """
    __slots__ = ('nin', 'nout', 'types', 'ntypes')
    __name__ = "fake ufunc"

    def __init__(self, types):
        self.types = types
        # The first signature determines the number of inputs and outputs
        in_, out = self.types[0].split('->')
        self.nin = len(in_)
        self.nout = len(out)
        self.ntypes = len(types)
        for tp in types:
            # Validate each signature in turn (not just the first one)
            in_, out = tp.split('->')
            assert len(in_) == self.nin
            assert len(out) == self.nout


# Loop signatures typical of np.add, np.multiply, np.isnan and np.sqrt.
# Each entry reads "<input typecodes>-><output typecodes>".
_add_types = [
    # boolean
    '??->?',
    # integers
    'bb->b', 'BB->B', 'hh->h', 'HH->H', 'ii->i', 'II->I',
    'll->l', 'LL->L', 'qq->q', 'QQ->Q',
    # floats
    'ee->e', 'ff->f', 'dd->d', 'gg->g',
    # complex
    'FF->F', 'DD->D', 'GG->G',
    # datetime64 / timedelta64
    'Mm->M', 'mm->m', 'mM->M',
    # object
    'OO->O',
]

_mul_types = [
    # boolean
    '??->?',
    # integers
    'bb->b', 'BB->B', 'hh->h', 'HH->H', 'ii->i', 'II->I',
    'll->l', 'LL->L', 'qq->q', 'QQ->Q',
    # floats
    'ee->e', 'ff->f', 'dd->d', 'gg->g',
    # complex
    'FF->F', 'DD->D', 'GG->G',
    # timedelta64 scaled by an integer or a float
    'mq->m', 'qm->m', 'md->m', 'dm->m',
    # object
    'OO->O',
]

# The two tables below have no integer loops (only floating-point and
# complex ones, plus an object loop for sqrt)
_isnan_types = [
    'e->?', 'f->?', 'd->?', 'g->?',
    'F->?', 'D->?', 'G->?',
]
_sqrt_types = [
    'e->e', 'f->f', 'd->d', 'g->g',
    'F->F', 'D->D', 'G->G',
    'O->O',
]


class TestUFuncs(TestCase):
    """
    Tests for the ufunc-related helpers of numpy_support.
    """

    def test_ufunc_find_matching_loop(self):
        """
        Check the loops selected by ufunc_find_matching_loop() on fake
        ufuncs, for exact, casting and non-matching argument types.
        """
        find_loop = numpy_support.ufunc_find_matching_loop
        np_add = FakeUFunc(_add_types)
        np_mul = FakeUFunc(_mul_types)
        np_isnan = FakeUFunc(_isnan_types)
        np_sqrt = FakeUFunc(_sqrt_types)
        # Short aliases for the datetime-like type constructors
        td = types.NPTimedelta
        dt = types.NPDatetime

        def check(ufunc, input_types, sigs, output_types=()):
            """
            Assert that ufunc_find_matching_loop() picks one of *sigs*
            (a single signature or a tuple of them) for *ufunc* given
            *input_types* and the optional *output_types*.  The matched
            loop is returned.
            """
            loop = find_loop(ufunc, input_types + output_types)
            self.assertTrue(loop)
            accepted = (sigs,) if isinstance(sigs, str) else sigs
            failure_msg = ("inputs=%s and outputs=%s should have selected "
                           "one of %s, got %s"
                           % (input_types, output_types, accepted,
                              loop.ufunc_sig))
            self.assertIn(loop.ufunc_sig, accepted, failure_msg)
            self.assertEqual(len(loop.numpy_inputs), len(loop.inputs))
            self.assertEqual(len(loop.numpy_outputs), len(loop.outputs))
            if output_types:
                self.assertEqual(loop.outputs, list(output_types))
            else:
                # Passing the inferred outputs explicitly must lead to
                # the same loop
                explicit = find_loop(ufunc, list(input_types) + loop.outputs)
                self.assertEqual(explicit, loop)
            # Feeding the loop's own inputs and outputs back must
            # round-trip
            round_trip = find_loop(ufunc, loop.inputs + loop.outputs)
            self.assertEqual(round_trip, loop)
            return loop

        def check_exact(ufunc, input_types, sigs, output_types=()):
            """
            Same as check(), additionally asserting that the inputs
            were not cast.
            """
            loop = check(ufunc, input_types, sigs, output_types)
            self.assertEqual(loop.inputs, list(input_types))

        def check_no_match(ufunc, input_types):
            self.assertIsNone(find_loop(ufunc, input_types))

        # Exact matching for number types
        exact_number_cases = [
            (types.bool_, '??->?'),
            (types.int8, 'bb->b'),
            (types.uint8, 'BB->B'),
            (types.int64, ('ll->l', 'qq->q')),
            (types.uint64, ('LL->L', 'QQ->Q')),
            (types.float32, 'ff->f'),
            (types.float64, 'dd->d'),
            (types.complex64, 'FF->F'),
            (types.complex128, 'DD->D'),
        ]
        for ty, sigs in exact_number_cases:
            check_exact(np_add, (ty, ty), sigs)

        # Exact matching for datetime64 and timedelta64 types
        exact_datetime_cases = [
            (np_add, (td('s'), td('s')), 'mm->m', td('s')),
            (np_add, (td('ms'), dt('s')), 'mM->M', dt('ms')),
            (np_add, (dt('s'), td('s')), 'Mm->M', dt('s')),
            (np_add, (dt('s'), td('')), 'Mm->M', dt('s')),
            (np_add, (dt('ns'), td('')), 'Mm->M', dt('ns')),
            (np_add, (td(''), dt('s')), 'mM->M', dt('s')),
            (np_add, (td(''), dt('ns')), 'mM->M', dt('ns')),
            (np_mul, (td('s'), types.int64), 'mq->m', td('s')),
            (np_mul, (types.float64, td('s')), 'dm->m', td('s')),
        ]
        for ufunc, inputs, sig, output in exact_datetime_cases:
            check_exact(ufunc, inputs, sig, output_types=(output,))

        # Mix and match number types, with casting
        casting_cases = [
            ((types.bool_, types.int8), 'bb->b'),
            ((types.uint8, types.bool_), 'BB->B'),
            ((types.int16, types.uint16), 'ii->i'),
            ((types.complex64, types.float64), 'DD->D'),
            ((types.float64, types.complex64), 'DD->D'),
        ]
        for inputs, sig in casting_cases:
            check(np_add, inputs, sig)

        # Integers, when used together with floating-point numbers,
        # should cast to any real or complex (see #2006)
        int_types = [types.int32, types.uint32, types.int64, types.uint64]
        inexact_partners = [
            (types.float32, 'ff->f'),
            (types.float64, 'dd->d'),
            (types.complex64, 'FF->F'),
            (types.complex128, 'DD->D'),
        ]
        for intty in int_types:
            for inexact_ty, sig in inexact_partners:
                check(np_add, (inexact_ty, intty), sig)
        # However, when used alone, they should cast only to
        # floating-point types of sufficient precision
        # (typical use case: np.sqrt(2) should give an accurate enough value)
        for intty in int_types:
            check(np_sqrt, (intty,), 'd->d')
            check(np_isnan, (intty,), 'd->?')

        # With some timedelta64 arguments as well
        timedelta_casting_cases = [
            ((td('s'), types.int32), 'mq->m', td('s')),
            ((td('s'), types.uint32), 'mq->m', td('s')),
            ((td('s'), types.float32), 'md->m', td('s')),
            ((types.float32, td('s')), 'dm->m', td('s')),
        ]
        for inputs, sig, output in timedelta_casting_cases:
            check(np_mul, inputs, sig, output_types=(output,))

        # No match
        check_no_match(np_add, (dt('s'), dt('s')))
        # No implicit casting from int64 to timedelta64 (Numpy would allow
        # this).
        check_no_match(np_add, (td('s'), types.int64))

    def test_layout_checker(self):
        """
        Compare is_contiguous() and is_fortran() with the contiguity
        flags Numpy reports, over a variety of array views.
        """
        def check_arr(arr):
            shape, strides = arr.shape, arr.strides
            itemsize = arr.dtype.itemsize
            got_c = numpy_support.is_contiguous(shape, strides, itemsize)
            got_f = numpy_support.is_fortran(shape, strides, itemsize)
            self.assertEqual(got_c, arr.flags['C_CONTIGUOUS'])
            self.assertEqual(got_f, arr.flags['F_CONTIGUOUS'])

        flat = np.arange(24)
        mat = flat.reshape((3, 8))
        cube = flat.reshape((2, 3, 4))
        hyper = flat.reshape((2, 2, 3, 2))

        views = [
            # 1D
            flat,
            # 2D
            mat,
            mat.T,
            mat[::2],
            # 3D
            cube,
            cube.T,
            # middle axis is shape 1
            cube[:, ::3],
            cube.T[:, ::3],
            # leading axis is shape 1 (trailing axis of the transpose)
            cube[::2],
            cube.T[:, :, ::2],
            # 2 leading axis are shape 1 (trailing ones of the transpose)
            cube[::2, ::3],
            cube.T[:, ::3, ::2],
            # single item slices for all axis
            cube[::2, ::3, ::4],
            cube.T[::4, ::3, ::2],
            # 4D
            hyper[::2, ::2, ::3],
            hyper.T[:, ::3, ::2, ::2],
            # step larger than the length of the outermost axis
            hyper[::5, ::2, ::3],
            hyper.T[:, ::3, ::2, ::5],
        ]
        for view in views:
            check_arr(view)


# Run the tests of this module when it is executed as a script.
if __name__ == '__main__':
    unittest.main()