test_obj_lifetime.py 15.2 KB
Newer Older
dugupeiwen's avatar
dugupeiwen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
import collections
import sys
import weakref
import gc
import operator
from itertools import takewhile

import unittest
from numba import njit, jit
from numba.core.controlflow import CFGraph, Loop
from numba.core.compiler import (compile_extra, compile_isolated, Flags,
                                 CompilerBase, DefaultPassBuilder)
from numba.core.untyped_passes import PreserveIR
from numba.core.typed_passes import IRLegalization
from numba.core import types, ir
from numba.tests.support import TestCase, override_config, SerialMixin

# Module-level compiler flag presets.
# NOTE(review): none of these three names is referenced in the visible part
# of this file (the tests below compile through `jit(..., forceobj=True)`);
# they may be kept for legacy reasons -- confirm before removing.

# Flags instance with `enable_pyobject` switched on.
enable_pyobj_flags = Flags()
enable_pyobj_flags.enable_pyobject = True

# Flags instance with `force_pyobject` switched on.
forceobj_flags = Flags()
forceobj_flags.force_pyobject = True

# Flags instance left at its defaults.
no_pyobj_flags = Flags()


class _Dummy:
    """
    A named placeholder object that announces itself to a recorder.

    Every instance keeps a reference to the *recorder* that created it and
    registers itself through ``recorder._add_dummy()`` on construction.
    Adding two dummies or iterating over one produces further dummies
    attached to the same recorder, with a name derived from the operands.
    """

    def __init__(self, recorder, name):
        self.name = name
        self.recorder = recorder
        # Register last: the recorder reads ``self.name`` while registering.
        self.recorder._add_dummy(self)

    def __iter__(self):
        # A fresh iterator object (itself a dummy) is created on each call.
        iterator_name = "iter(%s)" % self.name
        return _DummyIterator(self.recorder, iterator_name)

    def __add__(self, other):
        assert isinstance(other, _Dummy)
        # The result's name spells out the operation, e.g. "b + c".
        combined_name = "%s + %s" % (self.name, other.name)
        return _Dummy(self.recorder, combined_name)


class _DummyIterator(_Dummy):
    """
    Iterator dummy that hands out exactly three new `_Dummy` items.

    Items are named ``"<iterator name>#<n>"`` with *n* running from 1 to 3;
    after the third item, `StopIteration` is raised.
    """

    # Class-level default; the first ``+=`` in __next__ creates a
    # per-instance attribute shadowing it.
    count = 0

    def __next__(self):
        if self.count < 3:
            self.count += 1
            item_name = "%s#%s" % (self.name, self.count)
            return _Dummy(self.recorder, item_name)
        raise StopIteration

    # Python 2 spelling of the iterator protocol method.
    next = __next__


class RefRecorder:
    """
    Records an event each time an object created through it is deleted.

    Objects are tracked by weak reference only, so the recorder never keeps
    them alive itself.  Custom events can be interleaved with the deletion
    events through `mark()` to aid in diagnosis.
    """

    def __init__(self):
        # weakref -> name of the tracked object it points to
        self._wrs = {}
        # chronological list of recorded event strings
        self._events = []
        # how many times each raw `mark()` event string has been seen
        self._counts = collections.defaultdict(int)

    def _add_dummy(self, dummy):
        # The weakref callback fires when *dummy* is finalized.
        ref = weakref.ref(dummy, self._on_disposal)
        self._wrs[ref] = dummy.name

    def _on_disposal(self, wr):
        # Forget the dead weakref and log the name it was registered under.
        self._events.append(self._wrs.pop(wr))

    def make_dummy(self, name):
        """
        Create an object whose deletion will be recorded as *name*.
        """
        return _Dummy(self, name)

    # Calling the recorder itself is a shorthand for make_dummy().
    __call__ = make_dummy

    def mark(self, event):
        """
        Manually append *event* to the recorded events.
        *event* is passed through format() with ``count`` set to the number
        of times this exact *event* string has been marked so far.
        """
        self._counts[event] += 1
        self._events.append(event.format(count=self._counts[event]))

    @property
    def recorded(self):
        """
        The list of events recorded so far, in order.
        """
        return self._events

    @property
    def alive(self):
        """
        The tracked objects which haven't been deleted yet, as a list.
        """
        survivors = []
        for ref in self._wrs:
            survivors.append(ref())
        return survivors


def simple_usecase1(rec):
    """
    Straight-line usecase: rebind `a` to `b + c`, then return `a + a`.

    `test_simple1` checks that the original 'a', 'b' and 'c' objects are
    disposed of before the '--1--' mark, and that the 'b + c' object is
    disposed of between the '--1--' and '--2--' marks.
    """
    a = rec('a')
    b = rec('b')
    c = rec('c')
    a = b + c
    rec.mark('--1--')
    d = a + a   # b + c + b + c
    rec.mark('--2--')
    return d

def simple_usecase2(rec):
    """
    Aliasing usecase: the object first bound to `a` is also bound to `x`
    and `y`, `a` is rebound to None, and `y` is returned.  `b` is never used.

    `test_simple2` checks that 'b' is disposed of before the '--1--' mark
    and 'a' only after it.
    """
    a = rec('a')
    b = rec('b')
    rec.mark('--1--')
    x = a
    y = x
    a = None
    return y

def looping_usecase1(rec):
    """
    Single-loop usecase: starting from `b`, add every item produced by
    iterating over `a`, then add `c` after the loop and return the result.

    A mark is recorded at the bottom of each iteration and once after the
    loop exits; `test_looping1` checks the disposal order against them.
    """
    a = rec('a')
    b = rec('b')
    c = rec('c')
    x = b
    for y in a:
        x = x + y
        rec.mark('--loop bottom--')
    rec.mark('--loop exit--')
    x = x + c
    return x

def looping_usecase2(rec):
    """
    Nested-loop usecase with a `for ... else` and a `break`.

    The outer loop iterates over `a`, the inner loop over `b`; `cum` is
    rebound in both loops.  Marks containing ``{count}`` are numbered by
    `rec.mark()`, so each pass through a given mark gets its own event name.
    `test_looping2` checks the disposal order against these marks.
    """
    a = rec('a')
    b = rec('b')
    cum = rec('cum')
    for x in a:
        rec.mark('--outer loop top--')
        cum = cum + x
        z = x + x
        rec.mark('--inner loop entry #{count}--')
        for y in b:
            rec.mark('--inner loop top #{count}--')
            cum = cum + y
            rec.mark('--inner loop bottom #{count}--')
        rec.mark('--inner loop exit #{count}--')
        # `y` still refers to the last item of the inner loop here
        if cum:
            cum = y + z
        else:
            # Never gets here, but let the Numba compiler see a `break` opcode
            break
        rec.mark('--outer loop bottom #{count}--')
    else:
        rec.mark('--outer loop else--')
    rec.mark('--outer loop exit--')
    return cum

def generator_usecase1(rec):
    """
    Generator usecase: create 'a' and 'b' up front, then yield each in turn.
    """
    a = rec('a')
    b = rec('b')
    yield a
    yield b

def generator_usecase2(rec):
    """
    Generator usecase with a loop: yield every item obtained by iterating
    over 'a', then yield 'b'.
    """
    a = rec('a')
    b = rec('b')
    for x in a:
        yield x
    yield b


class MyError(RuntimeError):
    """Exception type raised on purpose by the raising usecases."""

def do_raise(x):
    """
    Unconditionally raise `MyError` constructed from *x*.

    Called from `raising_usecase1` and `raising_usecase2`.
    """
    raise MyError(x)

def raising_usecase1(rec):
    """
    Raising usecase: call `do_raise("foo")` while 'a', 'b' and 'd' exist.

    When *rec* is a `RefRecorder`, `a` is a `_Dummy`, which defines neither
    `__bool__` nor `__len__` and is therefore truthy: the `if` branch is
    taken and `do_raise` raises `MyError`, so the statements after the call
    are not executed.
    """
    a = rec('a')
    b = rec('b')
    d = rec('d')
    if a:
        do_raise("foo")
        # NOTE(review): the remaining statements are presumably here only so
        # the compiler sees uses of the variables after the raising call --
        # confirm.
        c = rec('c')
        c + a
    c + b

def raising_usecase2(rec):
    """
    Raising usecase: call `do_raise(b)`, i.e. the raised `MyError` is
    constructed from one of the recorded objects, after 'c' was created.

    When *rec* is a `RefRecorder`, `a` is a truthy `_Dummy`, so the `if`
    branch is taken and `a + c` is not executed.
    """
    a = rec('a')
    b = rec('b')
    if a:
        c = rec('c')
        do_raise(b)
    a + c

def raising_usecase3(rec):
    """
    Raising usecase: raise `MyError(b)` directly with a `raise` statement
    inside the function (rather than through the `do_raise` helper) when
    `a` is truthy.
    """
    a = rec('a')
    b = rec('b')
    if a:
        raise MyError(b)


def del_before_definition(rec):
    """
    This test reveals a bug where a del was emitted on an uninitialized
    variable.

    Expected execution: the outer loop marks '0', '1' and '2'; the first
    inner loop never iterates because `n` has been set to 0, so its `else`
    clause runs; for i < 2 the outer loop continues, and for i == 2 the
    function returns from inside the second inner loop.  None of the
    'FAILED' or 'OK' marks is therefore reached.
    """
    n = 5
    for i in range(n):
        rec.mark(str(i))
        # Rebinding `n` does not affect the already-created outer range(5)
        n = 0
        for j in range(n):
            return 0
        else:
            if i < 2:
                continue
            elif i == 2:
                for j in range(i):
                    return i
                rec.mark('FAILED')
            rec.mark('FAILED')
        rec.mark('FAILED')
    rec.mark('OK')
    return -1


def inf_loop_multiple_back_edge(rec):
    """
    Test to reveal a bug of invalid liveness when an infinite loop has
    multiple back edges.

    This is a generator: each iteration marks "yield", yields, creates the
    object 'p' and, if it is truthy, marks 'bra' before looping again.
    """
    while True:
        rec.mark("yield")
        yield
        p = rec('p')
        if p:
            rec.mark('bra')
            pass


class TestObjLifetime(TestCase):
    """
    Test lifetime of Python objects inside jit-compiled functions.

    Each test compiles one of the module-level usecase functions, runs it
    with a `RefRecorder` and inspects the recorder's events and/or the
    objects it still reports as alive.
    """

    def compile(self, pyfunc):
        """
        Compile *pyfunc* with `jit` for a single `types.pyobject` argument,
        using `forceobj=True` and with loop-lifting disabled.
        """
        # Note: looplift must be disabled. The tests require the function
        #       control-flow to be unchanged.
        cfunc = jit((types.pyobject,), forceobj=True, looplift=False)(pyfunc)
        return cfunc

    def compile_and_record(self, pyfunc, raises=None):
        """
        Compile *pyfunc*, call it once with a fresh `RefRecorder` and return
        that recorder.  If *raises* is given, the call is expected to raise
        that exception type.
        """
        rec = RefRecorder()
        cfunc = self.compile(pyfunc)
        if raises is not None:
            with self.assertRaises(raises):
                cfunc(rec)
        else:
            cfunc(rec)
        return rec

    def assertRecordOrder(self, rec, expected):
        """
        Check that the *expected* markers occur in that order in *rec*'s
        recorded events.

        Recorded events that are not listed in *expected* are ignored, so
        only the relative order of the expected events is checked.
        """
        actual = []
        recorded = rec.recorded
        remaining = list(expected)
        # Find out in which order, if any, the expected events were recorded
        for d in recorded:
            if d in remaining:
                actual.append(d)
                # User may or may not expect duplicates, handle them properly
                remaining.remove(d)
        self.assertEqual(actual, expected,
                         "the full list of recorded events is: %r" % (recorded,))

    def test_simple1(self):
        rec = self.compile_and_record(simple_usecase1)
        # Nothing created through the recorder may outlive the call
        self.assertFalse(rec.alive)
        # The original a, b and c are all disposed of before the first mark
        self.assertRecordOrder(rec, ['a', 'b', '--1--'])
        self.assertRecordOrder(rec, ['a', 'c', '--1--'])
        # The intermediate `b + c` is disposed of between the two marks
        self.assertRecordOrder(rec, ['--1--', 'b + c', '--2--'])

    def test_simple2(self):
        rec = self.compile_and_record(simple_usecase2)
        self.assertFalse(rec.alive)
        # The unused `b` goes away before the mark, the returned `a` after it
        self.assertRecordOrder(rec, ['b', '--1--', 'a'])

    def test_looping1(self):
        rec = self.compile_and_record(looping_usecase1)
        self.assertFalse(rec.alive)
        # a and b are unneeded after the loop, check they were disposed of
        self.assertRecordOrder(rec, ['a', 'b', '--loop exit--', 'c'])
        # check disposal order of iterator items and iterator
        self.assertRecordOrder(rec, ['iter(a)#1', '--loop bottom--',
                                     'iter(a)#2', '--loop bottom--',
                                     'iter(a)#3', '--loop bottom--',
                                     'iter(a)', '--loop exit--',
                                     ])

    def test_looping2(self):
        rec = self.compile_and_record(looping_usecase2)
        self.assertFalse(rec.alive)
        # `a` is disposed of after its iterator is taken
        self.assertRecordOrder(rec, ['a', '--outer loop top--'])
        # Check disposal of iterators
        self.assertRecordOrder(rec, ['iter(a)', '--outer loop else--',
                                     '--outer loop exit--'])
        self.assertRecordOrder(rec, ['iter(b)', '--inner loop exit #1--',
                                     'iter(b)', '--inner loop exit #2--',
                                     'iter(b)', '--inner loop exit #3--',
                                     ])
        # Disposal of in-loop variable `x`
        self.assertRecordOrder(rec, ['iter(a)#1', '--inner loop entry #1--',
                                     'iter(a)#2', '--inner loop entry #2--',
                                     'iter(a)#3', '--inner loop entry #3--',
                                     ])
        # Disposal of in-loop variable `z`
        self.assertRecordOrder(rec, ['iter(a)#1 + iter(a)#1',
                                     '--outer loop bottom #1--',
                                     ])

    def exercise_generator(self, genfunc):
        """
        Compile the generator function *genfunc* and check that no recorded
        object stays alive in three scenarios: the generator is exhausted,
        it is never iterated, or it is abandoned part-way through.
        """
        # NOTE(review): `assertRefCount` is inherited from the numba test
        # support `TestCase`; presumably it checks that the refcount of its
        # argument is unchanged when the block exits -- confirm.
        cfunc = self.compile(genfunc)
        # Exhaust the generator
        rec = RefRecorder()
        with self.assertRefCount(rec):
            gen = cfunc(rec)
            next(gen)
            self.assertTrue(rec.alive)
            list(gen)
            self.assertFalse(rec.alive)
        # Instantiate the generator but never iterate
        rec = RefRecorder()
        with self.assertRefCount(rec):
            gen = cfunc(rec)
            del gen
            gc.collect()
            self.assertFalse(rec.alive)
        # Stop iterating before exhaustion
        rec = RefRecorder()
        with self.assertRefCount(rec):
            gen = cfunc(rec)
            next(gen)
            self.assertTrue(rec.alive)
            del gen
            gc.collect()
            self.assertFalse(rec.alive)

    def test_generator1(self):
        self.exercise_generator(generator_usecase1)

    def test_generator2(self):
        self.exercise_generator(generator_usecase2)

    def test_del_before_definition(self):
        rec = self.compile_and_record(del_before_definition)
        # Only the three outer-loop marks are expected: neither 'FAILED'
        # nor 'OK' may have been recorded.
        self.assertEqual(rec.recorded, ['0', '1', '2'])

    def test_raising1(self):
        # The exception must propagate and leave no recorded object alive
        with self.assertRefCount(do_raise):
            rec = self.compile_and_record(raising_usecase1, raises=MyError)
            self.assertFalse(rec.alive)

    def test_raising2(self):
        # Same as test_raising1, with the exception built from a recorded
        # object
        with self.assertRefCount(do_raise):
            rec = self.compile_and_record(raising_usecase2, raises=MyError)
            self.assertFalse(rec.alive)

    def test_raising3(self):
        # Same check for a `raise` statement inside the compiled function
        with self.assertRefCount(MyError):
            rec = self.compile_and_record(raising_usecase3, raises=MyError)
            self.assertFalse(rec.alive)

    def test_inf_loop_multiple_back_edge(self):
        cfunc = self.compile(inf_loop_multiple_back_edge)
        rec = RefRecorder()
        iterator = iter(cfunc(rec))
        # At every suspension point no recorded object may still be alive
        next(iterator)
        self.assertEqual(rec.alive, [])
        next(iterator)
        self.assertEqual(rec.alive, [])
        next(iterator)
        self.assertEqual(rec.alive, [])
        # 'p' is recorded (i.e. disposed of) before the following 'bra' mark
        self.assertEqual(rec.recorded,
                         ['yield', 'p', 'bra', 'yield', 'p', 'bra', 'yield'])


class TestExtendingVariableLifetimes(SerialMixin, TestCase):
    # Test for `numba.config.EXTEND_VARIABLE_LIFETIMES` which moves the ir.Del
    # nodes to just before a block's terminator, i.e. their lifetime is extended
    # beyond the point of last use.
    #
    # NOTE(review): both tests below define the same `IRPreservingCompiler`
    # and jitted `foo` inside their local `get_ir` helper; the duplication
    # could be factored out if desired.

    def test_lifetime_basic(self):
        # Compare the statement sequence of the preserved IR with the config
        # option switched off and on.

        def get_ir(extend_lifetimes):
            # Compile `foo` with EXTEND_VARIABLE_LIFETIMES overridden to
            # *extend_lifetimes* and return the IR saved by PreserveIR.
            class IRPreservingCompiler(CompilerBase):

                def define_pipelines(self):
                    # Default nopython pipeline, plus a PreserveIR pass
                    # inserted right after IRLegalization
                    pm = DefaultPassBuilder.define_nopython_pipeline(self.state)
                    pm.add_pass_after(PreserveIR, IRLegalization)
                    pm.finalize()
                    return [pm]

            @njit(pipeline_class=IRPreservingCompiler)
            def foo():
                a = 10
                b = 20
                c = a + b
                # a and b are now unused, standard behaviour is ir.Del for them here
                d = c / c
                return d

            with override_config('EXTEND_VARIABLE_LIFETIMES', extend_lifetimes):
                # Calling foo triggers compilation under the overridden config
                foo()
                cres = foo.overloads[foo.signatures[0]]
                func_ir = cres.metadata['preserved_ir']

            return func_ir


        def check(func_ir, expect):
            # assert single block
            self.assertEqual(len(func_ir.blocks), 1)
            blk = next(iter(func_ir.blocks.values()))

            # check sequencing
            # NOTE(review): zip() stops at the shorter of the two sequences,
            # so a block body that is longer or shorter than `expect` is not
            # detected by this loop.
            for expect_class, got_stmt in zip(expect, blk.body):
                self.assertIsInstance(got_stmt, expect_class)

        del_after_use_ir = get_ir(False)
        # should be 3 assigns (a, b, c), 2 del (a, b), assign (d), del (c)
        # assign for cast d to return, del (d), return
        expect = [*((ir.Assign,) * 3), ir.Del, ir.Del, ir.Assign, ir.Del,
                  ir.Assign, ir.Del, ir.Return]
        check(del_after_use_ir, expect)

        del_at_block_end_ir = get_ir(True)
        # should be 4 assigns (a, b, c, d), assign for cast d to return,
        # 4 dels (a, b, c, d) then the return.
        expect = [*((ir.Assign,) * 4), ir.Assign, *((ir.Del,) * 4), ir.Return]
        check(del_at_block_end_ir, expect)

    def test_dbg_extend_lifetimes(self):
        # Check how the `debug` and `_dbg_extend_lifetimes` jit options
        # affect where the ir.Del statements end up.

        def get_ir(**options):
            # Compile `foo` with the given jit *options* and return the IR
            # saved by PreserveIR.
            class IRPreservingCompiler(CompilerBase):

                def define_pipelines(self):
                    # Default nopython pipeline, plus a PreserveIR pass
                    # inserted right after IRLegalization
                    pm = DefaultPassBuilder.define_nopython_pipeline(self.state)
                    pm.add_pass_after(PreserveIR, IRLegalization)
                    pm.finalize()
                    return [pm]

            @njit(pipeline_class=IRPreservingCompiler, **options)
            def foo():
                a = 10
                b = 20
                c = a + b
                # a and b are now unused, standard behaviour is ir.Del for them here
                d = c / c
                return d

            foo()
            cres = foo.overloads[foo.signatures[0]]
            func_ir = cres.metadata['preserved_ir']

            return func_ir

        # _dbg_extend_lifetimes is on when debug=True
        ir_debug = get_ir(debug=True)
        # explicitly turn on _dbg_extend_lifetimes
        ir_debug_ext = get_ir(debug=True, _dbg_extend_lifetimes=True)
        # explicitly turn off _dbg_extend_lifetimes
        ir_debug_no_ext = get_ir(debug=True, _dbg_extend_lifetimes=False)

        def is_del_grouped_at_the_end(fir):
            # Return True when the single block of *fir* consists of a run
            # of non-Del statements, then a run of Del statements, then
            # exactly one final statement (the terminator).
            [blk] = fir.blocks.values()
            # Mark all statements that are ir.Del
            inst_is_del = [isinstance(stmt, ir.Del) for stmt in blk.body]
            # Get the leading segment that are not dels
            not_dels = list(takewhile(operator.not_, inst_is_del))
            # Compute the starting position of the dels
            begin = len(not_dels)
            # Get the remaining segment that are all dels
            all_dels = list(takewhile(operator.truth, inst_is_del[begin:]))
            # Compute the ending position of the dels
            end = begin + len(all_dels)
            # If the dels are all grouped at the end (before the terminator),
            # the end position will be the last position of the list
            return end == len(inst_is_del) - 1

        self.assertTrue(is_del_grouped_at_the_end(ir_debug))
        self.assertTrue(is_del_grouped_at_the_end(ir_debug_ext))
        self.assertFalse(is_del_grouped_at_the_end(ir_debug_no_ext))


# Run the tests with unittest when this module is executed as a script.
if __name__ == "__main__":
    unittest.main()