test_pipeline.py 5.61 KB
Newer Older
dugupeiwen's avatar
dugupeiwen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
from numba.core.compiler import Compiler, DefaultPassBuilder
from numba.core.compiler_machinery import (FunctionPass, AnalysisPass,
                                           register_pass)
from numba.core.untyped_passes import InlineInlinables
from numba.core.typed_passes import IRLegalization
from numba import jit, generated_jit, objmode, njit, cfunc
from numba.core import types, postproc, errors
from numba.core.ir import FunctionIR
from numba.tests.support import TestCase


class TestCustomPipeline(TestCase):
    """
    Exercises the ``pipeline_class`` option of the ``jit``, ``cfunc`` and
    ``generated_jit`` decorators (including a ``jit`` function containing an
    ``objmode`` block) using a ``Compiler`` subclass that records whatever it
    is asked to compile.
    """

    def setUp(self):
        super().setUp()

        class CustomPipeline(Compiler):
            # Class-level recorder. The class itself is created anew in every
            # setUp() call, so each test starts with an empty list.
            custom_pipeline_cache = []

            def compile_extra(self, func):
                # Remember the Python function, then defer to the stock
                # implementation.
                self.custom_pipeline_cache.append(func)
                return super().compile_extra(func)

            def compile_ir(self, func_ir, *args, **kwargs):
                # Remember the IR object, then defer to the stock
                # implementation.
                self.custom_pipeline_cache.append(func_ir)
                return super().compile_ir(func_ir, *args, **kwargs)

        self.pipeline_class = CustomPipeline

    def _recorded(self):
        """Return the list of objects the custom pipeline has recorded."""
        return self.pipeline_class.custom_pipeline_cache

    def test_jit_custom_pipeline(self):
        # Nothing has been compiled yet.
        self.assertListEqual(self._recorded(), [])

        @jit(pipeline_class=self.pipeline_class)
        def foo(x):
            return x

        self.assertEqual(foo(4), 4)
        # Exactly one entry: the undecorated Python function.
        expected = [foo.py_func]
        self.assertListEqual(self._recorded(), expected)

    def test_cfunc_custom_pipeline(self):
        # Nothing has been compiled yet.
        self.assertListEqual(self._recorded(), [])

        signature = types.int64(types.int64)

        @cfunc(signature, pipeline_class=self.pipeline_class)
        def foo(x):
            return x

        self.assertEqual(foo(4), 4)
        # Exactly one entry: the function wrapped by the cfunc object.
        expected = [foo.__wrapped__]
        self.assertListEqual(self._recorded(), expected)

    def test_generated_jit_custom_pipeline(self):
        # Nothing has been compiled yet.
        self.assertListEqual(self._recorded(), [])

        def inner(x):
            return x

        @generated_jit(pipeline_class=self.pipeline_class)
        def foo(x):
            if isinstance(x, types.Integer):
                return inner

        self.assertEqual(foo(5), 5)
        # The recorded function is the implementation returned by the
        # generator, not the generator itself.
        expected = [inner]
        self.assertListEqual(self._recorded(), expected)

    def test_objmode_custom_pipeline(self):
        # Nothing has been compiled yet.
        self.assertListEqual(self._recorded(), [])

        @jit(pipeline_class=self.pipeline_class)
        def foo(x):
            with objmode(x="intp"):
                x += int(0x1)
            return x

        value = 123
        self.assertEqual(foo(value), value + 1)

        recorded = self._recorded()
        # Two items in the list.
        self.assertEqual(len(recorded), 2)
        py_entry, ir_entry = recorded
        # First item is the `foo` function
        self.assertIs(py_entry, foo.py_func)
        # Second item is a FunctionIR of the obj-lifted function
        self.assertIsInstance(ir_entry, FunctionIR)


class TestPassManagerFunctionality(TestCase):
    """
    Tests how the pass manager reacts when a pass leaves ``ir.Del`` nodes in
    the IR, depending on the kind of pass and where it sits in the pipeline.
    """

    def _create_pipeline_w_del(self, base=None, inject_after=None):
        """
        Creates a new compiler pipeline with the _InjectDelsPass injected after
        the pass supplied in kwarg 'inject_after'.

        Parameters
        ----------
        base
            Pass base class that _InjectDelsPass derives from (the tests in
            this class use FunctionPass and AnalysisPass). Required.
        inject_after
            Pass class after which _InjectDelsPass is inserted. Required.

        Returns
        -------
        A ``Compiler`` subclass whose only pipeline is the default nopython
        pipeline with _InjectDelsPass added.
        """
        # Both arguments default to None only so they can be passed by
        # keyword; neither is optional.
        self.assertIsNotNone(inject_after)
        self.assertIsNotNone(base)

        @register_pass(mutates_CFG=False, analysis_only=False)
        class _InjectDelsPass(base):
            """
            This pass injects ir.Del nodes into the IR
            """
            _name = "inject_dels_%s" % str(base)

            def __init__(self):
                base.__init__(self)

            def run_pass(self, state):
                # Run the post-processor with del emission switched on.
                pp = postproc.PostProcessor(state.func_ir)
                pp.run(emit_dels=True)
                return True

        class TestCompiler(Compiler):

            def define_pipelines(self):
                pm = DefaultPassBuilder.define_nopython_pipeline(self.state)
                pm.add_pass_after(_InjectDelsPass, inject_after)
                pm.finalize()
                return [pm]

        return TestCompiler

    def test_compiler_error_on_ir_del_from_functionpass(self):
        new_compiler = self._create_pipeline_w_del(FunctionPass,
                                                   InlineInlinables)

        @njit(pipeline_class=new_compiler)
        def foo(x):
            return x + 1

        # Compilation happens on first call, so the error surfaces here.
        with self.assertRaises(errors.CompilerError) as raises:
            foo(10)

        errstr = str(raises.exception)

        self.assertIn("Illegal IR, del found at:", errstr)
        self.assertIn("del x", errstr)

    def test_no_compiler_error_on_ir_del_after_legalization(self):
        # Legalization should be the last FunctionPass to execute so it's fine
        # for it to emit ir.Del nodes as no further FunctionPasses will run and
        # therefore the checking routine in the PassManager won't execute.
        # This test adds a new pass that is an AnalysisPass into the pipeline
        # after legalisation, this pass will return with already existing dels
        # in the IR but by virtue of it being an AnalysisPass the checking
        # routine won't execute.

        new_compiler = self._create_pipeline_w_del(AnalysisPass,
                                                   IRLegalization)

        @njit(pipeline_class=new_compiler)
        def foo(x):
            return x + 1

        # The compiled result must match the pure-Python result.
        self.assertEqual(foo(10), foo.py_func(10))