test_remove_dead.py 10.2 KB
Newer Older
dugupeiwen's avatar
dugupeiwen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
#
# Copyright (c) 2017 Intel Corporation
# SPDX-License-Identifier: BSD-2-Clause
#

import numba
import numba.parfors.parfor
from numba.core import ir_utils, cpu
from numba.core.compiler import compile_isolated, Flags
from numba.core import types, typing, ir, config, compiler
from numba.core.registry import cpu_target
from numba.core.annotations import type_annotations
from numba.core.ir_utils import (copy_propagate, apply_copy_propagate,
                            get_name_var_table, remove_dels, remove_dead,
                            remove_call_handlers, alias_func_extensions)
from numba.core.typed_passes import type_inference_stage
from numba.core.compiler_machinery import FunctionPass, register_pass, PassManager
from numba.core.untyped_passes import (ExtractByteCode, TranslateByteCode, FixupArgs,
                             IRProcessing, DeadBranchPrune,
                             RewriteSemanticConstants, GenericRewrites,
                             WithLifting, PreserveIR, InlineClosureLikes)

from numba.core.typed_passes import (NopythonTypeInference, AnnotateTypes,
                           NopythonRewrites, PreParforPass, ParforPass,
                           DumpParforDiagnostics, NativeLowering,
                           IRLegalization, NoPythonBackend, NativeLowering)
import numpy as np
from numba.tests.support import skip_parfors_unsupported, needs_blas
import unittest


def test_will_propagate(b, z, w):
    """IR fixture consumed by ``TestRemoveDead.test1``.

    ``x`` is a plain copy of ``x1`` and ``y`` is assigned on both branches
    but never read.  ``test1`` asserts that no assignment to ``x`` is left
    in the IR after copy propagation followed by ``remove_dead``.

    The local names are looked up by string in that test, so they must not
    be renamed.
    """
    x1 = 3
    x = x1
    if b > 0:
        y = z + w
    else:
        y = 0
    a = 2 * x
    return a < b

def null_func(a, b, c, d):
    """Accept four arguments, ignore them, and return ``None``."""
    return None

@numba.njit
def dummy_aliased_func(A):
    """Return ``A`` itself.

    Used by ``TestRemoveDead.test_alias_func_ext`` as a call whose result
    has to be treated as an alias of its argument.  The function name is
    part of the key registered in ``alias_func_extensions`` there, so it
    must stay in sync with that test.
    """
    return A

def alias_ext_dummy_func(lhs_name, args, alias_map, arg_aliases):
    """Alias extension handler registered for ``dummy_aliased_func``.

    Records, via ``ir_utils._add_alias``, that the call's left-hand side
    ``lhs_name`` aliases the variable passed as the first call argument.
    """
    first_arg_name = args[0].name
    ir_utils._add_alias(lhs_name, first_arg_name, alias_map, arg_aliases)

def findLhsAssign(func_ir, var):
    """Tell whether ``func_ir`` still assigns to the variable named ``var``.

    Scans the body of every block in ``func_ir.blocks`` and returns ``True``
    as soon as an ``ir.Assign`` whose target name equals ``var`` is found,
    ``False`` if there is none.
    """
    return any(
        isinstance(stmt, ir.Assign) and stmt.target.name == var
        for blk in func_ir.blocks.values()
        for stmt in blk.body
    )

class TestRemoveDead(unittest.TestCase):
    """Tests for ``ir_utils.remove_dead`` and the alias handling it uses.

    Most ``test_alias_*`` cases derive an array ``B`` from the argument
    ``A``, store into ``B``, and check that the parallel-compiled function
    leaves ``A`` in the same state as the interpreted function does.
    """

    # NOTE(review): presumably read by numba's test runner to keep this class
    # out of parallel test execution -- confirm against the runner.
    _numba_parallel_test_ = False

    def compile_parallel(self, func, arg_types):
        """Compile ``func`` for ``arg_types`` with the parallel flag set.

        Auto-parallelization, NRT and fastmath are all enabled.  Returns the
        entry point of the compilation result.
        """
        fast_pflags = Flags()
        fast_pflags.auto_parallel = cpu.ParallelOptions(True)
        fast_pflags.nrt = True
        fast_pflags.fastmath = cpu.FastMathOptions(True)
        return compile_isolated(func, arg_types, flags=fast_pflags).entry_point

    def test1(self):
        """No assignment to ``x`` remains after copy propagation + remove_dead.

        Runs the front end and type inference on ``test_will_propagate``,
        strips dels, applies copy propagation and then ``remove_dead``.
        """
        typingctx = typing.Context()
        targetctx = cpu.CPUContext(typingctx)
        test_ir = compiler.run_frontend(test_will_propagate)
        with cpu_target.nested_context(typingctx, targetctx):
            typingctx.refresh()
            targetctx.refresh()
            args = (types.int64, types.int64, types.int64)
            typemap, _, calltypes, _ = type_inference_stage(
                typingctx, targetctx, test_ir, args, None)
            remove_dels(test_ir.blocks)
            in_cps, out_cps = copy_propagate(test_ir.blocks, typemap)
            apply_copy_propagate(test_ir.blocks, in_cps,
                                 get_name_var_table(test_ir.blocks),
                                 typemap, calltypes)

            remove_dead(test_ir.blocks, test_ir.arg_names, test_ir)
            self.assertFalse(findLhsAssign(test_ir, "x"))

    def test2(self):
        """``remove_dead`` keeps a ``np.random.seed`` call with unused result."""
        def call_np_random_seed():
            np.random.seed(2)

        def seed_call_exists(func_ir):
            # Look in block 0 for a call whose callee is defined as a
            # getattr of 'seed'.
            for inst in func_ir.blocks[0].body:
                if (isinstance(inst, ir.Assign) and
                    isinstance(inst.value, ir.Expr) and
                    inst.value.op == 'call' and
                    func_ir.get_definition(inst.value.func).attr == 'seed'):
                    return True
            return False

        test_ir = compiler.run_frontend(call_np_random_seed)
        remove_dead(test_ir.blocks, test_ir.arg_names, test_ir)
        self.assertTrue(seed_call_exists(test_ir))

    def run_array_index_test(self, func):
        """Run ``func(A, i)`` interpreted and parallel-compiled; compare ``A``.

        Each variant gets its own copy of a 2x3 integer array and ``i = 0``.
        The arrays must be equal afterwards.
        """
        A1 = np.arange(6).reshape(2,3)
        A2 = A1.copy()
        i = 0
        pfunc = self.compile_parallel(func, (numba.typeof(A1), numba.typeof(i)))

        func(A1, i)
        pfunc(A2, i)
        np.testing.assert_array_equal(A1, A2)

    def test_alias_ravel(self):
        """Store through ``A.ravel()``."""
        def func(A, i):
            B = A.ravel()
            B[i] = 3

        self.run_array_index_test(func)

    def test_alias_flat(self):
        """Store through ``A.flat``."""
        def func(A, i):
            B = A.flat
            B[i] = 3

        self.run_array_index_test(func)

    def test_alias_transpose1(self):
        """Store through ``A.T``."""
        def func(A, i):
            B = A.T
            B[i,0] = 3

        self.run_array_index_test(func)

    def test_alias_transpose2(self):
        """Store through ``A.transpose()``."""
        def func(A, i):
            B = A.transpose()
            B[i,0] = 3

        self.run_array_index_test(func)

    def test_alias_transpose3(self):
        """Store through ``np.transpose(A)``."""
        def func(A, i):
            B = np.transpose(A)
            B[i,0] = 3

        self.run_array_index_test(func)

    @skip_parfors_unsupported
    @needs_blas
    def test_alias_ctypes(self):
        """Call a C function through ctypes pointers and compare results.

        A temporary handler is added to ``remove_call_handlers`` for the
        duration of the test and the handler list is restored afterwards.
        """
        # use xxnrm2 to test call a C function with ctypes
        from numba.np.linalg import _BLAS
        xxnrm2 = _BLAS().numba_xxnrm2(types.float64)

        def remove_dead_xxnrm2(rhs, lives, call_list):
            if call_list == [xxnrm2]:
                # True only when the fifth call argument (``ret.ctypes`` in
                # ``func`` below) is not live.
                return rhs.args[4].name not in lives
            return False

        # adding this handler has no-op effect since this function won't match
        # anything else but it's a bit cleaner to save the state and recover
        old_remove_handlers = remove_call_handlers[:]

        def func(ret):
            a = np.ones(4)
            xxnrm2(100, 4, a.ctypes, 1, ret.ctypes)

        A1 = np.zeros(1)
        A2 = A1.copy()

        try:
            # Register inside the try so the finally below always pairs
            # with the registration.
            remove_call_handlers.append(remove_dead_xxnrm2)
            pfunc = self.compile_parallel(func, (numba.typeof(A1),))
            numba.njit(func)(A1)
            pfunc(A2)
        finally:
            # recover global state
            remove_call_handlers[:] = old_remove_handlers

        self.assertEqual(A1[0], A2[0])

    def test_alias_reshape1(self):
        """Store through ``np.reshape(A, (3,2))``."""
        def func(A, i):
            B = np.reshape(A, (3,2))
            B[i,0] = 3

        self.run_array_index_test(func)

    def test_alias_reshape2(self):
        """Store through ``A.reshape(3,2)``."""
        def func(A, i):
            B = A.reshape(3,2)
            B[i,0] = 3

        self.run_array_index_test(func)

    def test_alias_func_ext(self):
        """Store through the result of a call with a registered alias handler.

        ``alias_ext_dummy_func`` is registered for ``dummy_aliased_func`` for
        the duration of the test; the registry is restored afterwards.
        """
        def func(A, i):
            B = dummy_aliased_func(A)
            B[i, 0] = 3

        # Work on the registry object that ir_utils currently exposes.
        ext_handlers = ir_utils.alias_func_extensions
        # save global state
        old_ext_handlers = ext_handlers.copy()
        try:
            ext_handlers[('dummy_aliased_func',
                'numba.tests.test_remove_dead')] = alias_ext_dummy_func
            self.run_array_index_test(func)
        finally:
            # recover global state *in place*: rebinding the module attribute
            # to the saved copy would leave the dummy entry in the original
            # dict, which every by-name importer still references.
            ext_handlers.clear()
            ext_handlers.update(old_ext_handlers)

    def test_rm_dead_rhs_vars(self):
        """make sure lhs variable of assignment is considered live if used in
        rhs (test for #6715).
        """
        def func():
            for i in range(3):
                a = (lambda j: j)(i)
                a = np.array(a)
            return a

        self.assertEqual(func(), numba.njit(func)())

    @skip_parfors_unsupported
    def test_alias_parfor_extension(self):
        """Make sure aliases are considered in remove dead extension for
        parfors.
        """
        def func():
            n = 11
            numba.parfors.parfor.init_prange()
            A = np.empty(n)
            B = A  # create alias to A
            for i in numba.prange(n):
                A[i] = i

            return B

        @register_pass(analysis_only=False, mutates_CFG=True)
        class LimitedParfor(FunctionPass):
            """Pass that converts loops to parfors and then runs remove_dead."""
            _name = "limited_parfor"

            def __init__(self):
                FunctionPass.__init__(self)

            def run_pass(self, state):
                """Run array analysis, loop conversion and ``remove_dead``.

                Always returns True.
                """
                parfor_pass = numba.parfors.parfor.ParforPass(
                    state.func_ir,
                    state.typemap,
                    state.calltypes,
                    state.return_type,
                    state.typingctx,
                    state.flags.auto_parallel,
                    state.flags,
                    state.metadata,
                    state.parfor_diagnostics
                )
                remove_dels(state.func_ir.blocks)
                parfor_pass.array_analysis.run(state.func_ir.blocks)
                parfor_pass._convert_loop(state.func_ir.blocks)
                remove_dead(state.func_ir.blocks,
                            state.func_ir.arg_names,
                            state.func_ir,
                            state.typemap)
                numba.parfors.parfor.get_parfor_params(state.func_ir.blocks,
                                                parfor_pass.options.fusion,
                                                parfor_pass.nested_fusion_info)
                return True

        class TestPipeline(compiler.Compiler):
            """Test pipeline that just converts prange() to parfor and calls
            remove_dead(). Copy propagation can replace B in the example code
            which this pipeline avoids.
            """
            def define_pipelines(self):
                """Build and return the single pass manager for this test."""
                name = 'test parfor aliasing'
                pm = PassManager(name)
                pm.add_pass(TranslateByteCode, "analyzing bytecode")
                pm.add_pass(FixupArgs, "fix up args")
                pm.add_pass(IRProcessing, "processing IR")
                pm.add_pass(WithLifting, "Handle with contexts")
                # pre typing
                if not self.state.flags.no_rewrites:
                    pm.add_pass(GenericRewrites, "nopython rewrites")
                    pm.add_pass(RewriteSemanticConstants, "rewrite semantic constants")
                    pm.add_pass(DeadBranchPrune, "dead branch pruning")
                pm.add_pass(InlineClosureLikes,
                            "inline calls to locally defined closures")
                # typing
                pm.add_pass(NopythonTypeInference, "nopython frontend")

                # NOTE(review): LimitedParfor (defined above) is never added
                # to this pipeline, although the class docstring describes
                # the prange-to-parfor conversion -- confirm whether a
                # pm.add_pass(LimitedParfor, ...) is missing here.

                # lower
                pm.add_pass(NativeLowering, "native lowering")
                pm.add_pass(NoPythonBackend, "nopython mode backend")
                pm.finalize()
                return [pm]

        test_res = numba.jit(pipeline_class=TestPipeline)(func)()
        py_res = func()
        np.testing.assert_array_equal(test_res, py_res)


# Run the tests in this module when it is executed as a script.
if __name__ == "__main__":
    unittest.main()