# test_jit_module.py
import os
import sys
import inspect
import contextlib
import numpy as np
import logging
from io import StringIO

import unittest
from numba.tests.support import SerialMixin, create_temp_module
from numba.core import dispatcher


@contextlib.contextmanager
def captured_logs(l):
    """Capture output logged through *l* into an in-memory text buffer.

    A temporary ``logging.StreamHandler`` writing to a ``StringIO`` is
    attached to *l* for the duration of the ``with`` block and is detached
    again on exit, including when the block raises.

    Parameters
    ----------
    l
        Logger-like object exposing ``addHandler`` and ``removeHandler``.

    Yields
    ------
    io.StringIO
        The buffer that receives the handler's output; read it with
        ``getvalue()``.
    """
    # Do the setup before entering ``try``: if any of it fails there is
    # nothing to undo, and ``finally`` must never touch an unbound
    # ``handler`` (which would mask the original error).
    buffer = StringIO()
    handler = logging.StreamHandler(buffer)
    l.addHandler(handler)
    try:
        yield buffer
    finally:
        l.removeHandler(handler)


class TestJitModule(SerialMixin, unittest.TestCase):
    """Tests for ``jit_module`` run against temporary modules.

    Each test builds a throw-away module from a source string via
    ``create_temp_module`` and inspects the resulting module attributes.
    """

    # Source of the temporary module used by most tests.  ``{jit_options}``
    # is a placeholder inside the trailing ``jit_module(...)`` call
    # (presumably filled in by ``create_temp_module`` -- confirm there).
    source_lines = """
from numba import jit_module

def inc(x):
    return x + 1

def add(x, y):
    return x + y

def inc_add(x):
    y = inc(x)
    return add(x, y)

import numpy as np
mean = np.mean

class Foo(object):
    pass

jit_module({jit_options})
"""

    def _set_logger_level(self, logger, level):
        """Set *level* on *logger*, restoring the previous level afterwards.

        The loggers used here are process-wide objects, so a level change
        that is not undone would leak into every test that runs later.
        """
        self.addCleanup(logger.setLevel, logger.level)
        logger.setLevel(level)

    def test_create_temp_jitted_module(self):
        """``create_temp_module`` updates and then restores import state."""
        sys_path_original = list(sys.path)
        sys_modules_original = dict(sys.modules)
        with create_temp_module(self.source_lines) as test_module:
            temp_module_dir = os.path.dirname(test_module.__file__)
            self.assertEqual(temp_module_dir, sys.path[0])
            self.assertEqual(sys.path[1:], sys_path_original)
            self.assertIn(test_module.__name__, sys.modules)
        # Test that modifications to sys.path / sys.modules are reverted
        self.assertEqual(sys.path, sys_path_original)
        self.assertEqual(sys.modules, sys_modules_original)

    def test_create_temp_jitted_module_with_exception(self):
        """Import state is restored when the ``with`` body raises."""
        sys_path_original = list(sys.path)
        sys_modules_original = dict(sys.modules)
        # ``assertRaises`` makes the test fail if the error is swallowed,
        # instead of silently skipping the checks below.
        with self.assertRaises(ValueError):
            with create_temp_module(self.source_lines):
                raise ValueError("Something went wrong!")
        # Test that modifications to sys.path / sys.modules are reverted
        self.assertEqual(sys.path, sys_path_original)
        self.assertEqual(sys.modules, sys_modules_original)

    def test_jit_module(self):
        """Functions become dispatchers; other module members are untouched."""
        with create_temp_module(self.source_lines) as test_module:
            self.assertIsInstance(test_module.inc, dispatcher.Dispatcher)
            self.assertIsInstance(test_module.add, dispatcher.Dispatcher)
            self.assertIsInstance(test_module.inc_add, dispatcher.Dispatcher)
            self.assertIs(test_module.mean, np.mean)
            self.assertTrue(inspect.isclass(test_module.Foo))

            # Test output of jitted functions is as expected
            x, y = 1.7, 2.3
            self.assertEqual(test_module.inc(x),
                             test_module.inc.py_func(x))
            self.assertEqual(test_module.add(x, y),
                             test_module.add.py_func(x, y))
            self.assertEqual(test_module.inc_add(x),
                             test_module.inc_add.py_func(x))

    def test_jit_module_jit_options(self):
        """Options given to ``jit_module`` show up as ``targetoptions``."""
        jit_options = {"nopython": True,
                       "nogil": False,
                       "error_model": "numpy",
                       "boundscheck": False,
                       }
        with create_temp_module(self.source_lines,
                                **jit_options) as test_module:
            self.assertEqual(test_module.inc.targetoptions, jit_options)

    def test_jit_module_jit_options_override(self):
        """An explicit ``@jit`` decoration wins over ``jit_module`` options."""
        source_lines = """
from numba import jit, jit_module

@jit(nogil=True, forceobj=True)
def inc(x):
    return x + 1

def add(x, y):
    return x + y

jit_module({jit_options})
"""
        jit_options = {"nopython": True,
                       "error_model": "numpy",
                       "boundscheck": False,
                       }
        with create_temp_module(source_lines=source_lines,
                                **jit_options) as test_module:
            self.assertEqual(test_module.add.targetoptions, jit_options)
            # Test that manual jit-wrapping overrides jit_module options
            self.assertEqual(test_module.inc.targetoptions,
                             {'nogil': True, 'forceobj': True,
                              'boundscheck': None})

    def test_jit_module_logging_output(self):
        """At DEBUG level the auto-decoration messages are logged."""
        logger = logging.getLogger('numba.core.decorators')
        self._set_logger_level(logger, logging.DEBUG)
        jit_options = {"nopython": True,
                       "error_model": "numpy",
                       }
        with captured_logs(logger) as logs:
            with create_temp_module(self.source_lines,
                                    **jit_options) as test_module:
                output = logs.getvalue()
                expected = ["Auto decorating function",
                            "from module {}".format(test_module.__name__),
                            "with jit and options: {}".format(jit_options)]
                # One assertion per fragment so a failure names the
                # missing text.
                for fragment in expected:
                    self.assertIn(fragment, output)

    def test_jit_module_logging_level(self):
        """At INFO level nothing is logged while the module is created."""
        logger = logging.getLogger('numba.core.decorators')
        # Test there's no logging for INFO level
        self._set_logger_level(logger, logging.INFO)
        with captured_logs(logger) as logs:
            with create_temp_module(self.source_lines):
                self.assertEqual(logs.getvalue(), '')