transforms.py 30.9 KB
Newer Older
dugupeiwen's avatar
dugupeiwen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
"""
Implements transformations on Numba IR: loop lifting, with-lifting and
CFG canonicalization.
"""


from collections import namedtuple, defaultdict
import logging
import operator

from numba.core.analysis import compute_cfg_from_blocks, find_top_level_loops
from numba.core import errors, ir, ir_utils
from numba.core.analysis import compute_use_defs, compute_cfg_from_blocks
from numba.core.utils import PYVERSION


# Module-level logger, named after this module via ``__name__``.
_logger = logging.getLogger(__name__)


def _extract_loop_lifting_candidates(cfg, blocks):
    """
    Collect the top-level loops of *cfg* that are eligible for loop lifting.

    A loop qualifies only when every exit has successors and all of them lead
    to one common block, the loop has exactly one entry, none of its blocks
    assigns from a ``yield``, and the CFG entry point is not a loop entry.
    """
    def exits_converge(lp):
        "every exit of the loop must flow into one and the same block"
        destinations = set()
        for exit_label in lp.exits:
            following = {dst for dst, _ in cfg.successors(exit_label)}
            if len(following) == 0:
                # An exit without successors holds a return statement, which
                # the looplifting code does not handle: reject the loop.
                _logger.debug("return-statement in loop.")
                return False
            destinations.update(following)
        converged = (len(destinations) == 1)
        _logger.debug("same_exit_point=%s (%s)", converged, destinations)
        return converged

    def single_entry(lp):
        "the loop is entered through exactly one block"
        unique = (len(lp.entries) == 1)
        _logger.debug("one_entry=%s", unique)
        return unique

    def yield_free(lp):
        "no block belonging to the loop may assign from a yield"
        labels = set(lp.body) | set(lp.entries) | set(lp.exits)
        for label in labels:
            for stmt in blocks[label].body:
                if (isinstance(stmt, ir.Assign)
                        and isinstance(stmt.value, ir.Yield)):
                    _logger.debug("has yield")
                    return False
        _logger.debug("no yield")
        return True

    _logger.info('finding looplift candidates')
    accepted = []
    for lp in find_top_level_loops(cfg):
        _logger.debug("top-level loop: %s", lp)
        if not exits_converge(lp):
            continue
        if not single_entry(lp):
            continue
        if not yield_free(lp):
            continue
        # A loop entered at the CFG entry point is skipped: the prelude of
        # the lifted loop would otherwise have to be written into block -1.
        if cfg.entry_point() in lp.entries:
            continue
        accepted.append(lp)
        _logger.debug("add candidate: %s", lp)
    return accepted


def find_region_inout_vars(blocks, livemap, callfrom, returnto, body_block_ids):
    """Compute the input and output variables of a region of blocks.

    *inputs* are the variables live at ``callfrom`` that the region uses or
    defines; *outputs* are the variables live at ``returnto`` that the region
    defines. Both are returned as sorted lists, ``(inputs, outputs)``.
    """
    live_at_entry = livemap[callfrom]
    live_at_return = livemap[returnto]

    # Restrict the analysis to the region so that live variables the region
    # never touches are dropped from the result.
    region = {label: blocks[label] for label in body_block_ids}
    usedefs = compute_use_defs(region)
    used = set().union(*usedefs.usemap.values())
    defined = set().union(*usedefs.defmap.values())
    touched = used | defined

    # sorted() keeps the ordering stable between runs
    inputs = sorted(touched.intersection(live_at_entry))
    outputs = sorted(touched.intersection(live_at_return) & defined)
    return inputs, outputs


# Record describing one loop-lifting candidate: the loop itself, its sorted
# input/output variable names, and the labels of the block it is entered
# from (``callfrom``) and the block control continues at (``returnto``).
_loop_lift_info = namedtuple(
    'loop_lift_info',
    ['loop', 'inputs', 'outputs', 'callfrom', 'returnto'],
)


def _loop_lift_get_candidate_infos(cfg, blocks, livemap):
    """
    Build a ``_loop_lift_info`` record for every looplifting candidate.
    """
    def describe(lp):
        # exactly one entry is guaranteed by the candidate selection
        (entry_label,) = lp.entries
        first_exit = next(iter(lp.exits))  # any exit block will do
        if len(lp.exits) <= 1:
            # Post-Py3.8 DO NOT have multiple exits
            return_label = first_exit
        else:
            # Pre-Py3.8 may have multiple exits; they share one successor
            # (also guaranteed by the candidate selection)
            ((return_label, _edge),) = cfg.successors(first_exit)

        region_labels = set(lp.body).union(lp.entries, lp.exits)
        ins, outs = find_region_inout_vars(
            blocks=blocks,
            livemap=livemap,
            callfrom=entry_label,
            returnto=return_label,
            body_block_ids=region_labels,
        )
        return _loop_lift_info(loop=lp, inputs=ins, outputs=outs,
                               callfrom=entry_label, returnto=return_label)

    candidates = _extract_loop_lifting_candidates(cfg, blocks)
    return [describe(lp) for lp in candidates]


def _loop_lift_modify_call_block(liftedloop, block, inputs, outputs, returnto):
    """
    Build the replacement for *block* in the top-level function: a fresh
    block (same scope and location) filled with a call to *liftedloop* that
    passes *inputs*, receives *outputs* and continues at *returnto*.
    """
    call_block = ir.Block(scope=block.scope, loc=block.loc)
    ir_utils.fill_block_with_call(
        newblock=call_block,
        callee=liftedloop,
        label_next=returnto,
        inputs=inputs,
        outputs=outputs,
    )
    return call_block


def _loop_lift_prepare_loop_func(loopinfo, blocks):
    """
    Turn the loop *blocks* (modified in place) into the body of a lifted
    loop by adding a prologue block in front and an epilogue block at the
    loop's return label.
    """
    head = blocks[loopinfo.callfrom]

    def empty_block():
        # new blocks share the scope and location of the loop's entry block
        return ir.Block(scope=head.scope, loc=head.loc)

    # Lowering assumes the first block to be the one with the smallest
    # offset, so the prologue gets a label below every existing one.
    prologue_label = min(blocks) - 1
    prologue = ir_utils.fill_callee_prologue(
        block=empty_block(),
        inputs=loopinfo.inputs,
        label_next=loopinfo.callfrom,
    )
    blocks[prologue_label] = prologue

    epilogue = ir_utils.fill_callee_epilogue(
        block=empty_block(),
        outputs=loopinfo.outputs,
    )
    blocks[loopinfo.returnto] = epilogue


def _loop_lift_modify_blocks(func_ir, loopinfo, blocks,
                             typingctx, targetctx, flags, locals):
    """
    Replace the loop described by *loopinfo* inside *blocks* (modified in
    place) with a single block that calls the lifted loop.
    Returns the ``LiftedLoop`` object created for the loop.
    """
    from numba.core.dispatcher import LiftedLoop

    lp = loopinfo.loop

    # Labels of the blocks that move into the lifted loop
    lifted_labels = set(lp.body).union(lp.entries)
    if len(lp.exits) > 1:
        # Pre-Py3.8 may have multiple exits
        lifted_labels.update(lp.exits)

    # Work on copies, then reshape them into a standalone function body
    lifted_blocks = {label: blocks[label].copy() for label in lifted_labels}
    _loop_lift_prepare_loop_func(loopinfo, lifted_blocks)

    # Derive a separate IR for the lifted loop and wrap it
    lifted_ir = func_ir.derive(blocks=lifted_blocks,
                               arg_names=tuple(loopinfo.inputs),
                               arg_count=len(loopinfo.inputs),
                               force_non_generator=True)
    lifted = LiftedLoop(lifted_ir, typingctx, targetctx, flags, locals)

    # Block that calls into the lifted loop from the main function
    call_block = _loop_lift_modify_call_block(
        lifted,
        blocks[loopinfo.callfrom],
        loopinfo.inputs,
        loopinfo.outputs,
        loopinfo.returnto,
    )

    # Drop the original loop blocks, then install the call site
    for label in lifted_labels:
        del blocks[label]
    blocks[loopinfo.callfrom] = call_block
    return lifted


def _has_multiple_loop_exits(cfg, lpinfo):
    """Tell whether the loop has more than one effective exit.

    An exit that is post-dominated by another block is not counted on its
    own: every post-dominator of an exit (other than the exit itself) is
    removed from the set of exits before counting. This covers "common
    exits", where one loop exit has another loop exit as its successor.
    """
    if len(lpinfo.exits) < 2:
        return False

    surviving = set(lpinfo.exits)
    post_doms = cfg.post_dominators()

    # Visit each surviving exit once and discard its post-dominators.
    visited = set()
    pending = set(surviving)  # independent copy used as the work set
    while pending:
        candidate = pending.pop()
        visited.add(candidate)
        dominated_by = post_doms[candidate] - {candidate}
        surviving -= dominated_by
        pending = surviving - visited

    return len(surviving) >= 2


def _pre_looplift_transform(func_ir):
    """Put the loops of *func_ir* into the shape looplifting expects.

    Every loop with several exits has them merged into one; afterwards the
    analysis variables are reset and the IR is post-processed again.
    Returns the resulting IR.
    """
    from numba.core.postproc import PostProcessor

    cfg = compute_cfg_from_blocks(func_ir.blocks)
    for lp in cfg.loops().values():
        if not _has_multiple_loop_exits(cfg, lp):
            continue
        # fold all exits of this loop into one common exit
        func_ir, _unused_label = _fix_multi_exit_blocks(func_ir, lp.exits)

    # The blocks may have changed: refresh the derived analysis data
    func_ir._reset_analysis_variables()
    postprocessor = PostProcessor(func_ir)
    postprocessor.run()
    return func_ir


def loop_lifting(func_ir, typingctx, targetctx, flags, locals):
    """
    Perform the loop lifting transformation on `func_ir`.

    Returns a 2 tuple `(main_ir, lifted_loops)`: the IR of the top-level
    function with each candidate loop replaced by a call, and the list of
    lifted loops (one per candidate, in candidate order).
    """
    canonical_ir = _pre_looplift_transform(func_ir)
    working_blocks = canonical_ir.blocks.copy()
    cfg = compute_cfg_from_blocks(working_blocks)
    candidates = _loop_lift_get_candidate_infos(
        cfg, working_blocks, canonical_ir.variable_lifetime.livemap,
    )
    if candidates:
        _logger.debug('loop lifting this IR with %d candidates:\n%s',
                      len(candidates), canonical_ir.dump_to_string())

    # Each call rewrites ``working_blocks`` in place and yields a lifted loop
    lifted_loops = [
        _loop_lift_modify_blocks(canonical_ir, info, working_blocks,
                                 typingctx, targetctx, flags, locals)
        for info in candidates
    ]

    # The main IR is derived from what is left of the blocks
    main_ir = canonical_ir.derive(blocks=working_blocks)
    return main_ir, lifted_loops


def canonicalize_cfg_single_backedge(blocks):
    """
    Return a copy of *blocks* in which every loop that had several backedges
    has exactly one: the former backedges are redirected to a new tail block
    whose only statement jumps back to the loop header.
    """
    cfg = compute_cfg_from_blocks(blocks)
    rewritten = blocks.copy()

    def backedge_sources(lp):
        # lazily yield the body blocks (of the input) that jump to the header
        for label in lp.body:
            if lp.header in blocks[label].terminator.get_targets():
                yield label

    def has_many_backedges(lp):
        found = 0
        for _ in backedge_sources(lp):
            found += 1
            if found == 2:
                # a second backedge is enough, stop counting
                return True
        return False

    def retarget(term, old, new):
        # Rebuild terminator *term* with every target *old* replaced by *new*
        def swap(label):
            if label == old:
                return new
            return label

        if isinstance(term, ir.Branch):
            return ir.Branch(cond=term.cond,
                             truebr=swap(term.truebr),
                             falsebr=swap(term.falsebr),
                             loc=term.loc)
        if isinstance(term, ir.Jump):
            return ir.Jump(target=swap(term.target), loc=term.loc)
        # any other terminator must not have targets at all
        assert not term.get_targets()
        return term

    def funnel_backedges(lp):
        """
        Route all backedges of *lp* through one newly created tail block
        """
        header = lp.header
        tail_label = max(rewritten.keys()) + 1
        for label in lp.body:
            current = rewritten[label]
            if header not in current.terminator.get_targets():
                continue
            # copy the block and point its backedge at the tail block
            patched = current.copy()
            patched.body[-1] = retarget(current.terminator, header,
                                        tail_label)
            rewritten[label] = patched
        # the tail block carries the single remaining backedge
        header_block = rewritten[header]
        tail_block = ir.Block(scope=header_block.scope, loc=header_block.loc)
        tail_block.append(ir.Jump(target=header, loc=tail_block.loc))
        rewritten[tail_label] = tail_block

    for lp in cfg.loops().values():
        if has_many_backedges(lp):
            funnel_backedges(lp)

    return rewritten


def canonicalize_cfg(blocks):
    """
    Canonicalize the CFG formed by *blocks*.

    Currently this only rewrites loops with multiple backedges. The result
    is a new dictionary of blocks.
    """
    canonical_blocks = canonicalize_cfg_single_backedge(blocks)
    return canonical_blocks


def with_lifting(func_ir, typingctx, targetctx, flags, locals):
    """With-lifting transformation

    Extract the top-level with-regions of *func_ir*; each one is handed to
    its context manager's ``mutate_with_body``.
    Returns ``(new_ir, lifted_with_irs)``; when no with-region is found the
    IR returned by ``find_setupwiths`` is passed through with an empty list.
    """
    from numba.core import postproc

    def dispatcher_factory(func_ir, objectmode=False, **kwargs):
        from numba.core.dispatcher import LiftedWith, ObjModeLiftedWith

        lifted_flags = flags.copy()
        if not objectmode:
            return LiftedWith(func_ir, typingctx, targetctx, lifted_flags,
                              locals, **kwargs)
        # A lifted with-block in object mode cannot looplift ...
        lifted_flags.enable_looplift = False
        # ... and is compiled in object mode
        lifted_flags.enable_pyobject = True
        lifted_flags.force_pyobject = True
        lifted_flags.no_cpython_wrapper = False
        return ObjModeLiftedWith(func_ir, typingctx, targetctx, lifted_flags,
                                 locals, **kwargs)

    # Locate the with-context regions
    regions, func_ir = find_setupwiths(func_ir)
    if not regions:
        return func_ir, []

    # Post-processing provides the variable lifetime information
    postproc.PostProcessor(func_ir).run()
    assert func_ir.variable_lifetime
    lifetime = func_ir.variable_lifetime
    blocks = func_ir.blocks.copy()
    cfg = lifetime.cfg

    # Let each region be mutated by the kind of contextmanager it uses
    lifted = []
    for region_start, region_end in regions:
        region_body = list(_cfg_nodes_in_region(cfg, region_start,
                                                region_end))
        _legalize_with_head(blocks[region_start])
        # Look up the contextmanager, then mutate the body to get new IR
        manager, extra = _get_with_contextmanager(func_ir, blocks,
                                                  region_start)
        lifted.append(
            manager.mutate_with_body(func_ir, blocks, region_start,
                                     region_end, region_body,
                                     dispatcher_factory, extra)
        )

    # Without any lifted region the IR is returned unchanged
    new_ir = func_ir.derive(blocks) if lifted else func_ir
    return new_ir, lifted


def _get_with_contextmanager(func_ir, blocks, blk_start):
    """Resolve the context-manager object of the with-region at *blk_start*.

    Returns ``(ctxobj, extra)``. *extra* is ``None`` unless the context
    manager is used as a call, in which case it is a dict holding the
    definitions of the call arguments under ``'args'`` and ``'kwargs'``.
    Raises ``errors.CompilerError`` when the head block has no ENTER_WITH or
    the context manager is undefined, cannot be resolved, or lacks
    ``mutate_with_body``.
    """
    head = blocks[blk_start]

    # The first ENTER_WITH in the head block names the contextmanager
    enter = next(
        (stmt for stmt in head.body if isinstance(stmt, ir.EnterWith)),
        None,
    )
    if enter is None:
        raise errors.CompilerError(
            "malformed with-context usage",
            loc=head.loc,
            )

    target = enter.contextmanager
    definition = func_ir.get_definition(target)
    extra = None
    if isinstance(definition, ir.Expr) and definition.op == 'call':
        # Used as a call: keep the argument definitions and resolve the
        # callee instead of the call result.
        extra = {
            'args': [func_ir.get_definition(arg)
                     for arg in definition.args],
            'kwargs': {name: func_ir.get_definition(val)
                       for name, val in definition.kws},
        }
        target = definition.func

    ctxobj = ir_utils.guard(ir_utils.find_outer_value, func_ir, target)

    # Validate what was found
    if ctxobj is ir.UNDEFINED:
        raise errors.CompilerError(
            "Undefined variable used as context manager",
            loc=head.loc,
            )
    if ctxobj is None:
        raise errors.CompilerError("Illegal use of context-manager.",
                                   loc=definition.loc)
    if not hasattr(ctxobj, 'mutate_with_body'):
        raise errors.CompilerError(
            "Unsupported context manager in use",
            loc=head.loc,
            )
    return ctxobj, extra


def _legalize_with_head(blk):
    """Given *blk*, the head block of the with-context, check that it doesn't
    do anything else.

    The block must consist of exactly one ``ir.EnterWith``, exactly one
    ``ir.Jump`` and any number of ``ir.Del`` statements.

    Raises
    ------
    errors.CompilerError
        If the ENTER_WITH or JUMP count is not exactly one (including zero),
        or if any other kind of statement is present.
    """
    # Tally the statements of the block by their type
    counters = defaultdict(int)
    for stmt in blk.body:
        counters[type(stmt)] += 1
    # ``pop`` does not go through ``defaultdict.__missing__``: an explicit
    # default is required so that a block without ENTER_WITH is reported as
    # a CompilerError instead of surfacing as a bare KeyError.
    if counters.pop(ir.EnterWith, 0) != 1:
        raise errors.CompilerError(
            "with's head-block must have exactly 1 ENTER_WITH",
            loc=blk.loc,
            )
    if counters.pop(ir.Jump, 0) != 1:
        raise errors.CompilerError(
            "with's head-block must have exactly 1 JUMP",
            loc=blk.loc,
            )
    # Can have any number of del
    counters.pop(ir.Del, None)
    # There MUST NOT be any other statements
    if counters:
        raise errors.CompilerError(
            "illegal statements in with's head-block",
            loc=blk.loc,
            )


def _cfg_nodes_in_region(cfg, region_begin, region_end):
    """Collect the CFG nodes reachable from *region_begin* by following
    successor edges without entering *region_end*.

    *region_end* is never part of the result; *region_begin* is only
    included when it is reached again through an edge.
    """
    found = set()
    worklist = [region_begin]
    while worklist:
        current = worklist.pop()
        # a node without successors (e.g. in a single block function)
        # simply contributes nothing
        fresh = {
            succ for succ, _edge in cfg.successors(current)
            if succ != region_end and succ not in found
        }
        worklist.extend(fresh)
        found.update(fresh)

    return found


def find_setupwiths(func_ir):
    """Find all top-level with.

    Returns a 2-tuple ``(with_ranges, func_ir)``. *with_ranges* is a list of
    ``(start, end)`` block-label pairs: *start* is the label of a block
    containing a SETUP_WITH and *end* is the jump target of the block that
    contains the matching POP_BLOCK. *func_ir* is the FunctionIR after the
    CFG rewrites performed here.

    Raises ``errors.CompilerError`` for unsupported control flow: a raise
    found while scanning a with-region, a POP_BLOCK block whose terminator
    does not have exactly one target, or (on Python 3.8 only) a return
    inside the with-block.
    """
    def find_ranges(blocks):
        """Map the label of each SETUP_WITH block to the set of labels of
        the POP_BLOCK blocks reached from it.
        """
        cfg = compute_cfg_from_blocks(blocks)
        sus_setups, sus_pops = set(), set()
        # traverse the cfg and collect all suspected SETUP_WITH and POP_BLOCK
        # statements so that we can iterate over them
        for label, block in blocks.items():
            for stmt in block.body:
                if ir_utils.is_setup_with(stmt):
                    sus_setups.add(label)
                if ir_utils.is_pop_block(stmt):
                    sus_pops.add(label)

        # now that we do have the statements, iterate through them in reverse
        # topo order and from each start looking for pop_blocks
        setup_with_to_pop_blocks_map = defaultdict(set)
        for setup_block in cfg.topo_sort(sus_setups, reverse=True):
            # begin pop_block, search
            to_visit, seen = [], []
            to_visit.append(setup_block)
            while to_visit:
                # get whatever is next and record that we have seen it
                block = to_visit.pop()
                seen.append(block)
                # go through the body of the block, looking for statements
                for stmt in blocks[block].body:
                    # raise detected before pop_block
                    if ir_utils.is_raise(stmt):
                            raise errors.CompilerError(
                                'unsupported control flow due to raise '
                                'statements inside with block'
                                )
                    # if a pop_block, process it
                    if ir_utils.is_pop_block(stmt) and block in sus_pops:
                        # record this POP_BLOCK block as belonging to this
                        # setup
                        setup_with_to_pop_blocks_map[setup_block].add(block)
                        # remove the block from blocks to be matched, so that
                        # a setup processed later cannot claim it again
                        sus_pops.remove(block)
                        # stop looking, we have reached the frontier; the
                        # break also skips this block's terminator, so its
                        # successors are not visited from here
                        break
                    # if we are still here, by the block terminator,
                    # add all its targets to the to_visit stack, unless we
                    # have seen them already
                    if ir_utils.is_terminator(stmt):
                        for t in stmt.get_targets():
                            if t not in seen:
                                to_visit.append(t)

        return setup_with_to_pop_blocks_map

    blocks = func_ir.blocks
    # initial find, will return a dictionary, mapping indices of blocks
    # containing SETUP_WITH statements to a set of indices of blocks containing
    # POP_BLOCK statements
    with_ranges_dict = find_ranges(blocks)
    # rewrite the CFG in case there are multiple POP_BLOCK statements for one
    # with
    func_ir = consolidate_multi_exit_withs(with_ranges_dict, blocks, func_ir)
    # here we need to turn the withs back into a list of tuples so that the
    # rest of the code can cope; one (arbitrary) element of each set is taken
    with_ranges_tuple = [(s, list(p)[0])
             for (s, p) in with_ranges_dict.items()]

    # check for POP_BLOCKS with multiple outgoing edges and reject
    for (_, p) in with_ranges_tuple:
        targets = blocks[p].terminator.get_targets()
        if len(targets) != 1:
            raise errors.CompilerError(
                "unsupported control flow: with-context contains branches "
                "(i.e. break/return/raise) that can leave the block "
            )
    # now we check for returns inside with: they are rejected on Python 3.8
    # and rewritten (see _rewrite_return) otherwise
    for (_, p) in with_ranges_tuple:
        target_block = blocks[p]
        if ir_utils.is_return(func_ir.blocks[
                target_block.terminator.get_targets()[0]].terminator):
            if PYVERSION == (3, 8):
                # 3.8 needs to bail here, if this is the case, because the
                # later code can't handle it.
                raise errors.CompilerError(
                    "unsupported control flow: due to return statements "
                    "inside with block"
                )
            _rewrite_return(func_ir, p)

    # now we need to rewrite the tuple such that we have SETUP_WITH matching the
    # successor of the block that contains the POP_BLOCK.
    with_ranges_tuple = [(s, func_ir.blocks[p].terminator.get_targets()[0])
                         for (s, p) in with_ranges_tuple]

    # finally we drop nested with statements so that only the outermost
    # (top-level) ranges remain
    with_ranges_tuple = _eliminate_nested_withs(with_ranges_tuple)

    return with_ranges_tuple, func_ir


def _rewrite_return(func_ir, target_block_label):
    """Rewrite a return block inside a with statement.

    Arguments
    ---------

    func_ir: Function IR
      the CFG to transform
    target_block_label: int
      the block index/label of the block containing the POP_BLOCK statement

    Returns
    -------

    The same ``func_ir``, which is modified in place.


    This implements a CFG transformation to insert a block between two other
    blocks.

    The input situation is:

    ┌───────────────┐
    │   top         │
    │   POP_BLOCK   │
    │   bottom      │
    └───────┬───────┘

    ┌───────▼───────┐
    │               │
    │    RETURN     │
    │               │
    └───────────────┘

    If such a pattern is detected in IR, it means there is a `return` statement
    within a `with` context. The basic idea is to rewrite the CFG as follows:

    ┌───────────────┐
    │               │
    │   top         │
    │               │
    └───────┬───────┘

    ┌───────▼───────┐
    │   POP_BLOCK   │
    │   bottom      │
    │               │
    └───────┬───────┘

    ┌───────▼───────┐
    │               │
    │    RETURN     │
    │               │
    └───────────────┘

    We split the block that contains the `POP_BLOCK` statement into two blocks.
    Everything from the beginning of the block up to, but not including, the
    `POP_BLOCK` statement is considered the 'top'; the `POP_BLOCK` statement
    and everything below it (except the final jump) is considered 'bottom'.
    The 'bottom' statements are moved into the block that used to hold the
    return, and the original contents of that block are moved into a newly
    created block. Finally the jump statements are re-wired to make sure
    the CFG remains valid.

    """
    # the block itself from the index
    target_block = func_ir.blocks[target_block_label]
    # get the index of the block containing the return
    target_block_successor_label = target_block.terminator.get_targets()[0]
    # the return block
    target_block_successor = func_ir.blocks[target_block_successor_label]

    # create the new return block with an appropriate label
    max_label = ir_utils.find_max_label(func_ir.blocks)
    new_label = max_label + 1
    # create the new return block
    new_block_loc = target_block_successor.loc
    new_block_scope = ir.Scope(None, loc=new_block_loc)
    new_block = ir.Block(new_block_scope, loc=new_block_loc)

    # Split the block containing the POP_BLOCK into top and bottom
    # Block must be of the form:
    # -----------------
    # <some stmts>
    # POP_BLOCK
    # <some more stmts>
    # JUMP
    # -----------------
    top_body, bottom_body = [], []
    pop_blocks = [*target_block.find_insts(ir.PopBlock)]
    assert len(pop_blocks) == 1
    assert len([*target_block.find_insts(ir.Jump)]) == 1
    assert isinstance(target_block.body[-1], ir.Jump)
    pb_marker = pop_blocks[0]
    pb_is = target_block.body.index(pb_marker)
    # top: the statements before POP_BLOCK, then a jump to the old successor
    top_body.extend(target_block.body[:pb_is])
    top_body.append(ir.Jump(target_block_successor_label, target_block.loc))
    # bottom: POP_BLOCK and what follows it (minus the original jump), then
    # a jump to the new return block
    bottom_body.extend(target_block.body[pb_is:-1])
    bottom_body.append(ir.Jump(new_label, target_block.loc))

    # get the contents of the return block
    return_body = func_ir.blocks[target_block_successor_label].body
    # finally, re-assign all blocks: the old return block now holds 'bottom'
    # and the original block only 'top'
    new_block.body.extend(return_body)
    target_block_successor.body.clear()
    target_block_successor.body.extend(bottom_body)
    target_block.body.clear()
    target_block.body.extend(top_body)

    # finally, append the new return block and rebuild the IR properties
    func_ir.blocks[new_label] = new_block
    func_ir._definitions = ir_utils.build_definitions(func_ir.blocks)
    return func_ir


def _eliminate_nested_withs(with_ranges):
    """Drop the with-ranges that lie strictly inside an already kept range.

    The ``(start, end)`` ranges are visited in sorted order; the kept ranges
    are returned as a list in that order.
    """
    outermost = []
    for begin, end in sorted(with_ranges):
        # FIXME: this should be a comparison in topological order; comparing
        # the integer block labels probably only works by accident.
        nested = any(begin > lo and end < hi for lo, hi in outermost)
        if not nested:
            outermost.append((begin, end))

    return outermost

def consolidate_multi_exit_withs(withs: dict, blocks, func_ir):
    """Merge the exit blocks of every with construct that has several.

    *withs* maps a SETUP_WITH block label to the set of its POP_BLOCK block
    labels; entries with more than one exit are replaced in place by a set
    holding only the label of the new common exit. Returns the FunctionIR.
    """
    for setup_label, pop_labels in withs.items():
        if len(pop_labels) <= 1:
            # zero or one exit: nothing to merge
            continue
        func_ir, merged_label = _fix_multi_exit_blocks(
            func_ir, pop_labels, split_condition=ir_utils.is_pop_block,
        )
        withs[setup_label] = {merged_label}
    return func_ir


def _fix_multi_exit_blocks(func_ir, exit_nodes, *, split_condition=None):
    """Modify the FunctionIR to create a single common exit node given the
    original exit nodes.

    Parameters
    ----------
    func_ir :
        The FunctionIR. Mutated inplace.
    exit_nodes :
        The original exit nodes. A sequence of block keys.
    split_condition : callable or None
        If not None, it is a callable with the signature
        `split_condition(statement)` that determines if the `statement` is the
        splitting point (e.g. `POP_BLOCK`) in an exit node.
        If it's None, the exit node is not split.

    Returns
    -------
    (func_ir, common_label) :
        The same FunctionIR and the label of the newly created common exit
        block.
    """

    # Convert the following:
    #
    #     |           |
    # +-------+   +-------+
    # | exit0 |   | exit1 |
    # +-------+   +-------+
    #     |           |
    # +-------+   +-------+
    # | after0|   | after1|
    # +-------+   +-------+
    #     |           |
    #
    # To roughly:
    #
    #     |           |
    # +-------+   +-------+
    # | exit0 |   | exit1 |
    # +-------+   +-------+
    #     |           |
    #     +-----+-----+
    #           |
    #      +---------+
    #      | common  |
    #      +---------+
    #           |
    #       +-------+
    #       | post  |
    #       +-------+
    #           |
    #     +-----+-----+
    #     |           |
    # +-------+   +-------+
    # | after0|   | after1|
    # +-------+   +-------+

    blocks = func_ir.blocks
    # Getting the scope
    # NOTE(review): only the scope of the chosen block is used below; this
    # relies on ir.Block objects being orderable for min() - confirm.
    any_blk = min(func_ir.blocks.values())
    scope = any_blk.scope
    # Getting the maximum block label
    max_label = max(func_ir.blocks) + 1
    # Define the new common block for the new exit.
    common_block = ir.Block(any_blk.scope, loc=ir.unknown_loc)
    common_label = max_label
    max_label += 1
    blocks[common_label] = common_block
    # Define the new block after the exit.
    post_block = ir.Block(any_blk.scope, loc=ir.unknown_loc)
    post_label = max_label
    max_label += 1
    blocks[post_label] = post_block

    # Adjust each exit node
    # ``remainings[i]`` holds the statements cut off from the i-th exit node
    remainings = []
    for i, k in enumerate(exit_nodes):
        blk = blocks[k]

        # split the block if needed
        if split_condition is not None:
            # ``pt`` ends up as the index of the first matching statement,
            # or as the index of the last statement when nothing matches
            for pt, stmt in enumerate(blk.body):
                if split_condition(stmt):
                    break
        else:
            # no splitting
            pt = -1

        before = blk.body[:pt]
        after = blk.body[pt:]
        remainings.append(after)

        # Add control-point variable to mark which exit block this is.
        blk.body = before
        loc = blk.loc
        blk.body.append(
            ir.Assign(value=ir.Const(i, loc=loc),
                      target=scope.get_or_define("$cp", loc=loc),
                      loc=loc)
        )
        # Replace terminator with a jump to the common block
        assert not blk.is_terminated
        blk.body.append(ir.Jump(common_label, loc=ir.unknown_loc))

    if split_condition is not None:
        # Move the splitting statement to the common block
        # (only the one taken from the first exit node is inserted)
        common_block.body.append(remainings[0][0])
    assert not common_block.is_terminated
    # Append jump from common block to post block
    # NOTE(review): ``loc`` is the value left over from the last iteration of
    # the loop above, so ``exit_nodes`` must not be empty here.
    common_block.body.append(ir.Jump(post_label, loc=loc))

    # Make if-else tree to jump to target
    # Reserve one fresh label per exit node for the chained switch blocks
    remain_blocks = []
    for remain in remainings:
        remain_blocks.append(max_label)
        max_label += 1

    switch_block = post_block
    loc = ir.unknown_loc
    for i, remain in enumerate(remainings):
        match_expr = scope.redefine("$cp_check", loc=loc)
        match_rhs = scope.redefine("$cp_rhs", loc=loc)

        # Do comparison to match control-point variable to the exit block
        switch_block.body.append(
            ir.Assign(
                value=ir.Const(i, loc=loc),
                target=match_rhs,
                loc=loc
            ),
        )

        # Add assignment for the comparison
        switch_block.body.append(
            ir.Assign(
                value=ir.Expr.binop(
                    fn=operator.eq, lhs=scope.get("$cp"), rhs=match_rhs,
                    loc=loc,
                ),
                target=match_expr,
                loc=loc
            ),
        )

        # Insert jump to the next case
        # The original terminator of the exit node must have exactly one
        # target. Only that target is taken from ``remain``; the other
        # statements in it are not re-inserted by this function.
        [jump_target] = remain[-1].get_targets()
        switch_block.body.append(
            ir.Branch(match_expr, jump_target, remain_blocks[i], loc=loc),
        )
        # The false branch continues in a fresh block holding the next case
        switch_block = ir.Block(scope=scope, loc=loc)
        blocks[remain_blocks[i]] = switch_block

    # Add the final jump
    # (the fall-through block jumps to the target of the last exit node)
    switch_block.body.append(ir.Jump(jump_target, loc=loc))

    return func_ir, common_label