test_api.py 2.55 KB
Newer Older
dugupeiwen's avatar
dugupeiwen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import warnings

import numba
from numba import jit, njit

from numba.tests.support import TestCase, always_test
import unittest


class TestNumbaModule(TestCase):
    """
    Test the APIs exposed by the top-level `numba` module.
    """

    def check_member(self, name):
        """Assert that *name* is an attribute of `numba` and is exported
        through `numba.__all__`."""
        self.assertTrue(hasattr(numba, name), name)
        self.assertIn(name, numba.__all__)

    @always_test
    def test_numba_module(self):
        """Check that the expected public names exist on `numba`."""
        # jit
        self.check_member("jit")
        self.check_member("vectorize")
        self.check_member("guvectorize")
        self.check_member("njit")
        # errors
        self.check_member("NumbaError")
        self.check_member("TypingError")
        # types
        self.check_member("int32")
        # misc: `__version__` must exist but is deliberately not listed in
        # `__all__`, so it is checked explicitly instead of via
        # check_member (a bare attribute access would only fail with an
        # incidental AttributeError rather than a test assertion).
        self.assertTrue(hasattr(numba, "__version__"), "__version__")
        self.assertIsInstance(numba.__version__, str)


class TestJitDecorator(TestCase):
    """
    Test the jit and njit decorators
    """
    def test_jit_nopython_forceobj(self):
        """`jit` rejects nopython=True combined with forceobj=True, and each
        flag on its own selects the corresponding compilation mode."""
        with self.assertRaises(ValueError) as cm:
            jit(nopython=True, forceobj=True)
        self.assertIn(
            "Only one of 'nopython' or 'forceobj' can be True.",
            str(cm.exception)
        )

        def py_func(x):
            return x

        jit_func = jit(nopython=True)(py_func)
        jit_func(1)
        # Check length of nopython_signatures to check
        # which mode the function was compiled in
        self.assertEqual(len(jit_func.nopython_signatures), 1)

        jit_func = jit(forceobj=True)(py_func)
        jit_func(1)
        self.assertEqual(len(jit_func.nopython_signatures), 0)

    def _assert_njit_flag_ignored(self, flag):
        """Assert that ``njit(<flag>=True)`` emits exactly one
        RuntimeWarning stating that *flag* is ignored.

        :param flag: name of the keyword passed to `njit`
            (``'nopython'`` or ``'forceobj'``).
        """
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter('always', RuntimeWarning)
            njit(**{flag: True})
        # record=True captures every warning category, so only count
        # RuntimeWarnings; an unrelated warning (e.g. a deprecation notice)
        # raised during the call would otherwise fail the count spuriously.
        runtime_warnings = [
            rec for rec in w if issubclass(rec.category, RuntimeWarning)
        ]
        self.assertEqual(len(runtime_warnings), 1)
        self.assertIn(
            f'{flag} is set for njit and is ignored',
            str(runtime_warnings[0].message)
        )

    def test_njit_nopython_forceobj(self):
        """`njit` warns that nopython/forceobj are ignored and always
        compiles in nopython mode."""
        self._assert_njit_flag_ignored('forceobj')
        self._assert_njit_flag_ignored('nopython')

        def py_func(x):
            return x

        # The "is ignored" warnings were verified above; silence them here
        # so they do not leak into the test output (or become errors when
        # the suite is run with ``-W error``).
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', RuntimeWarning)
            nopython_func = njit(nopython=True)(py_func)
            forceobj_func = njit(forceobj=True)(py_func)

        nopython_func(1)
        self.assertEqual(len(nopython_func.nopython_signatures), 1)

        forceobj_func(1)
        # Since forceobj is ignored this has to compile in nopython mode
        self.assertEqual(len(forceobj_func.nopython_signatures), 1)


if __name__ == '__main__':
    # Only runs when this file is executed as a script; hands control to
    # the standard unittest command-line runner for this module.
    unittest.main(module='__main__')