# test_inlining.py (11.3 KB) -- committed by dugupeiwen
import re
import numpy as np

from numba.tests.support import (TestCase, override_config, captured_stdout,
                      skip_parfors_unsupported)
from numba import jit, njit
from numba.core import types, ir, postproc, compiler
from numba.core.ir_utils import (guard, find_callname, find_const,
                                 get_definition, simplify_CFG)
from numba.core.registry import CPUDispatcher
from numba.core.inline_closurecall import inline_closure_call

from numba.core.untyped_passes import (ExtractByteCode, TranslateByteCode, FixupArgs,
                             IRProcessing, DeadBranchPrune,
                             RewriteSemanticConstants, GenericRewrites,
                             WithLifting, PreserveIR, InlineClosureLikes)

from numba.core.typed_passes import (NopythonTypeInference, AnnotateTypes,
                           NopythonRewrites, PreParforPass, ParforPass,
                           DumpParforDiagnostics, NativeLowering,
                           NativeParforLowering, IRLegalization,
                           NoPythonBackend, NativeLowering,
                           ParforFusionPass, ParforPreLoweringPass)

from numba.core.compiler_machinery import FunctionPass, PassManager, register_pass
import unittest

@njit((types.int32,))
def inner(a):
    """Return ``a + 1``; declared with an explicit ``int32`` signature."""
    return a + 1

@njit((types.int32,))
def more(a):
    """Apply ``inner`` twice to ``a`` and return the result."""
    once = inner(a)
    return inner(once)

def outer_simple(a):
    """Return ``inner(a)`` multiplied by two (single jitted callee)."""
    incremented = inner(a)
    return incremented * 2

def outer_multiple(a):
    """Return the product of ``inner(a)`` and ``more(a)``."""
    # Operands are evaluated left to right, as in a single expression.
    lhs = inner(a)
    rhs = more(a)
    return lhs * rhs

@njit
def __dummy__():
    """Jitted no-op used as a placeholder call site; returns ``None``."""
    return None

@register_pass(analysis_only=False, mutates_CFG=True)
class InlineTestPass(FunctionPass):
    """Function pass that inlines ``lambda: None`` at the first call site.

    The function IR is expected to consist of exactly one block.
    """
    _name = "inline_test_pass"

    def __init__(self):
        super().__init__()

    def run_pass(self, state):
        """Inline over the first call in the sole block, then fix up the IR.

        Always returns ``True``.
        """
        func_ir = state.func_ir
        # assuming the function has one block with one call inside
        assert len(func_ir.blocks) == 1
        sole_block = list(func_ir.blocks.values())[0]
        for index, stmt in enumerate(sole_block.body):
            if guard(find_callname, func_ir, stmt.value) is None:
                # not a recognisable call, keep scanning
                continue
            inline_closure_call(func_ir, {}, sole_block, index, lambda: None,
                                state.typingctx, state.targetctx, (),
                                state.typemap, state.calltypes)
            # only the first call is replaced
            break
        # also fix up the IR
        fixer = postproc.PostProcessor(func_ir)
        fixer.run()
        fixer.remove_dels()
        return True


def gen_pipeline(state, test_pass):
    """Build the nopython pass pipeline used by the inlining tests.

    ``test_pass`` is scheduled after type inference (and after the parfor
    passes when ``state.flags.auto_parallel.enabled`` is set) and before IR
    legalisation.  The returned ``PassManager`` is not finalized here.
    """
    use_rewrites = not state.flags.no_rewrites
    use_parfors = state.flags.auto_parallel.enabled

    stages = [
        (TranslateByteCode, "analyzing bytecode"),
        (FixupArgs, "fix up args"),
        (IRProcessing, "processing IR"),
        (WithLifting, "Handle with contexts"),
    ]

    # pre typing
    if use_rewrites:
        stages += [
            (GenericRewrites, "nopython rewrites"),
            (RewriteSemanticConstants, "rewrite semantic constants"),
            (DeadBranchPrune, "dead branch pruning"),
        ]
    stages.append(
        (InlineClosureLikes, "inline calls to locally defined closures"))

    # typing
    stages.append((NopythonTypeInference, "nopython frontend"))

    if use_parfors:
        stages.append((PreParforPass, "Preprocessing for parfors"))
    if use_rewrites:
        stages.append((NopythonRewrites, "nopython rewrites"))
    if use_parfors:
        stages += [
            (ParforPass, "convert to parfors"),
            (ParforFusionPass, "fuse parfors"),
            (ParforPreLoweringPass, "parfor prelowering"),
        ]

    # the pass under test
    stages.append((test_pass, "inline test"))

    # legalise
    stages += [
        (IRLegalization, "ensure IR is legal prior to lowering"),
        (AnnotateTypes, "annotate types"),
        (PreserveIR, "preserve IR"),
    ]

    # lower
    if use_parfors:
        stages.append((NativeParforLowering, "native parfor lowering"))
    else:
        stages.append((NativeLowering, "native lowering"))
    stages += [
        (NoPythonBackend, "nopython mode backend"),
        (DumpParforDiagnostics, "dump parfor diagnostics"),
    ]

    pipeline = PassManager('inline_test')
    for pass_cls, description in stages:
        pipeline.add_pass(pass_cls, description)
    return pipeline

class InlineTestPipeline(compiler.CompilerBase):
    """Compiler pipeline for testing inlining after optimization.

    Uses ``gen_pipeline`` with ``InlineTestPass`` as the pass under test.
    """
    def define_pipelines(self):
        """Return a single-element list holding the finalized pipeline."""
        pipeline = gen_pipeline(self.state, InlineTestPass)
        pipeline.finalize()
        return [pipeline]

class TestInlining(TestCase):
    """
    Check that jitted inner functions are inlined into outer functions,
    in nopython mode.
    Note that not all inner functions are guaranteed to be inlined.
    We just trust LLVM's inlining heuristics.
    """

    def make_pattern(self, fullname):
        """
        Make a regular expression matching the mangled form of *fullname*.

        The pattern is ``_Z`` with an optional ``N``, followed by every
        dot-separated component of *fullname*, each preceded by one or more
        digits.
        """
        parts = fullname.split('.')
        return r'_ZN?' + r''.join(r'\d+{}'.format(p) for p in parts)

    def assert_has_pattern(self, fullname, text):
        """Fail unless the mangled form of *fullname* occurs in *text*."""
        pat = self.make_pattern(fullname)
        self.assertIsNotNone(re.search(pat, text),
                             msg='expected {}'.format(pat))

    def assert_not_has_pattern(self, fullname, text):
        """Fail if the mangled form of *fullname* occurs in *text*."""
        pat = self.make_pattern(fullname)
        self.assertIsNone(re.search(pat, text),
                          msg='unexpected {}'.format(pat))

    def test_inner_function(self):
        # Compile with assembly dumping switched on and capture the dump.
        with override_config('DUMP_ASSEMBLY', True):
            with captured_stdout() as out:
                cfunc = jit((types.int32,), nopython=True)(outer_simple)
        self.assertPreciseEqual(cfunc(1), 4)
        # Check the inner function was elided from the output (which also
        # guarantees it was inlined into the outer function).
        asm = out.getvalue()
        prefix = __name__
        self.assert_has_pattern('%s.outer_simple' % prefix, asm)
        self.assert_not_has_pattern('%s.inner' % prefix, asm)

    def test_multiple_inner_functions(self):
        # Same with multiple inner functions, and multiple calls to
        # the same inner function (inner()).  This checks that linking in
        # the same library/module twice doesn't produce linker errors.
        with override_config('DUMP_ASSEMBLY', True):
            with captured_stdout() as out:
                cfunc = jit((types.int32,), nopython=True)(outer_multiple)
        self.assertPreciseEqual(cfunc(1), 6)
        asm = out.getvalue()
        prefix = __name__
        self.assert_has_pattern('%s.outer_multiple' % prefix, asm)
        self.assert_not_has_pattern('%s.more' % prefix, asm)
        self.assert_not_has_pattern('%s.inner' % prefix, asm)

    @skip_parfors_unsupported
    def test_inline_call_after_parfor(self):
        # replace the call to make sure inlining doesn't cause label conflict
        # with parfor body
        def test_impl(A):
            __dummy__()
            return A.sum()
        j_func = njit(parallel=True, pipeline_class=InlineTestPipeline)(
            test_impl)
        A = np.arange(10)
        self.assertEqual(test_impl(A), j_func(A))

    @skip_parfors_unsupported
    def test_inline_update_target_def(self):
        # After the `b = 2` assignment is swapped for a call and that call is
        # inlined, `b` must still have exactly two recorded definitions.

        def test_impl(a):
            if a == 1:
                b = 2
            else:
                b = 3
            return b

        func_ir = compiler.run_frontend(test_impl)
        blocks = list(func_ir.blocks.values())
        for block in blocks:
            for i, stmt in enumerate(block.body):
                # match b = 2 and replace with lambda: 2
                if (isinstance(stmt, ir.Assign)
                        and isinstance(stmt.value, ir.Var)
                        and guard(find_const, func_ir, stmt.value) == 2):
                    # replace expr with a dummy call, keeping the
                    # definitions table in sync with the statement
                    func_ir._definitions[stmt.target.name].remove(stmt.value)
                    callee_var = ir.Var(block.scope, "myvar", loc=stmt.loc)
                    stmt.value = ir.Expr.call(callee_var, (), (), stmt.loc)
                    func_ir._definitions[stmt.target.name].append(stmt.value)
                    inline_closure_call(func_ir, {}, block, i, lambda: 2)
                    # block.body was rewritten by the inliner; stop scanning it
                    break

        self.assertEqual(len(func_ir._definitions['b']), 2)

    @skip_parfors_unsupported
    def test_inline_var_dict_ret(self):
        # make sure inline_closure_call returns the variable replacement dict
        # and it contains the original variable name used in locals
        @njit(locals={'b': types.float64})
        def g(a):
            b = a + 1
            return b

        def test_impl():
            return g(1)

        func_ir = compiler.run_frontend(test_impl)
        blocks = list(func_ir.blocks.values())
        # Stays None if no call to a dispatcher is found, so that the test
        # fails with a clear message instead of raising NameError below.
        var_map = None
        for block in blocks:
            for i, stmt in enumerate(block.body):
                if (isinstance(stmt, ir.Assign)
                        and isinstance(stmt.value, ir.Expr)
                        and stmt.value.op == 'call'):
                    func_def = guard(get_definition, func_ir, stmt.value.func)
                    if (isinstance(func_def, (ir.Global, ir.FreeVar))
                            and isinstance(func_def.value, CPUDispatcher)):
                        py_func = func_def.value.py_func
                        _, var_map = inline_closure_call(
                            func_ir, py_func.__globals__, block, i, py_func)
                        break

        self.assertIsNotNone(var_map,
                             msg='no call to a jitted function was inlined')
        self.assertIn('b', var_map)

    @skip_parfors_unsupported
    def test_inline_call_branch_pruning(self):
        # branch pruning pass should run properly in inlining to enable
        # functions with type checks
        @njit
        def foo(A=None):
            if A is None:
                return 2
            else:
                return A

        def test_impl(A=None):
            return foo(A)

        @register_pass(analysis_only=False, mutates_CFG=True)
        class PruningInlineTestPass(FunctionPass):
            _name = "pruning_inline_test_pass"

            def __init__(self):
                FunctionPass.__init__(self)

            def run_pass(self, state):
                # assuming the function has one block with one call inside
                assert len(state.func_ir.blocks) == 1
                block = list(state.func_ir.blocks.values())[0]
                for i, stmt in enumerate(block.body):
                    if (guard(find_callname, state.func_ir, stmt.value)
                            is not None):
                        # inline foo's Python function, passing the type of
                        # the call's first argument as the argument types
                        inline_closure_call(
                            state.func_ir, {}, block, i, foo.py_func,
                            state.typingctx, state.targetctx,
                            (state.typemap[stmt.value.args[0].name],),
                            state.typemap, state.calltypes)
                        break
                return True

        class InlineTestPipelinePrune(compiler.CompilerBase):

            def define_pipelines(self):
                pm = gen_pipeline(self.state, PruningInlineTestPass)
                pm.finalize()
                return [pm]

        # make sure inline_closure_call runs in full pipeline
        j_func = njit(pipeline_class=InlineTestPipelinePrune)(test_impl)
        A = 3
        self.assertEqual(test_impl(A), j_func(A))
        self.assertEqual(test_impl(), j_func())

        # make sure IR doesn't have branches
        fir = j_func.overloads[(types.Omitted(None),)].metadata['preserved_ir']
        fir.blocks = simplify_CFG(fir.blocks)
        self.assertEqual(len(fir.blocks), 1)

# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()