test_withlifting.py 34.8 KB
Newer Older
dugupeiwen's avatar
dugupeiwen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
import copy
import os
import signal
import subprocess
import sys
import tempfile
import threading
import warnings
import numpy as np

import numba
from numba.core.transforms import find_setupwiths, with_lifting
from numba.core.withcontexts import bypass_context, call_context, objmode_context
from numba.core.bytecode import FunctionIdentity, ByteCode
from numba.core.interpreter import Interpreter
from numba.core import typing, errors, cpu
from numba.core.registry import cpu_target
from numba.core.compiler import compile_ir, DEFAULT_FLAGS
from numba import njit, typeof, objmode, types
from numba.core.extending import overload
from numba.tests.support import (MemoryLeak, TestCase, captured_stdout,
                                 skip_unless_scipy, linux_only,
                                 strace_supported, strace,
                                 expected_failure_py311)
from numba.core.utils import PYVERSION
from numba.experimental import jitclass
import unittest


def get_func_ir(func):
    """Return the Numba IR obtained by interpreting ``func``'s bytecode.

    A ``FunctionIdentity`` is derived from ``func``; the same identity is
    used to build the ``ByteCode`` and the ``Interpreter`` that consumes it.
    """
    identity = FunctionIdentity.from_function(func)
    bytecode = ByteCode(func_id=identity)
    return Interpreter(identity).interpret(bytecode)


# Fixture: a single ``with bypass_context`` region between two prints.
# TestLiftByPass.test_lift1 expects only "A" and "C" on stdout, i.e. the
# with-body (including the call to ``b``, which is not defined in the
# visible part of this module) is not executed after lifting.
def lift1():
    print("A")
    with bypass_context:
        print("B")
        b()
    print("C")


# Fixture: two consecutive ``with bypass_context`` regions.
# TestLiftByPass.test_lift2 expects stdout "A 1\nD 3\n": only the two
# ``x += 1`` statements outside the regions affect the printed value.
def lift2():
    x = 1
    print("A", x)
    x = 1
    with bypass_context:
        print("B", x)
        x += 100
        b()
    x += 1
    with bypass_context:
        print("C", x)
        b()
        x += 10
    x += 1
    print("D", x)


# Fixture: a ``with bypass_context`` region nested inside another one.
# The tests in this file expect the pair to be counted as one region
# (expect_count=1) and stdout to be "A 1 100\nD 2 101\n".
def lift3():
    x = 1
    y = 100
    print("A", x, y)
    with bypass_context:
        print("B")
        b()
        x += 100
        with bypass_context:
            print("C")
            y += 100000
            b()
    x += 1
    y += 1
    print("D", x, y)


# Fixture: two outer ``with bypass_context`` regions; the first contains a
# loop with a nested region, the second contains a conditional.
# The tests in this file expect expect_count=2 and stdout "A 0\nE 11\n".
def lift4():
    x = 0
    print("A", x)
    x += 10
    with bypass_context:
        print("B")
        b()
        x += 1
        for i in range(10):
            with bypass_context:
                print("C")
                b()
                x += i
    with bypass_context:
        print("D")
        b()
        if x:
            x *= 10
    x += 1
    print("E", x)


# Fixture: contains no ``with`` statement at all; the tests in this file
# expect expect_count=0 and stdout "A\n".
def lift5():
    print("A")


# Fixture: one ``with call_context`` region that increments ``x``.
# TestLiftCall.test_liftcall1 expects stdout "A 1\nB 2\n", so the effect of
# the with-body on ``x`` is visible after the region.
def liftcall1():
    x = 1
    print("A", x)
    with call_context:
        x += 1
    print("B", x)
    return x


# Fixture: two consecutive ``with call_context`` regions, each updating
# ``x``.  TestLiftCall.test_liftcall2 expects stdout "A 1\nB 2\nC 12\n".
def liftcall2():
    x = 1
    print("A", x)
    with call_context:
        x += 1
    print("B", x)
    with call_context:
        x += 10
    print("C", x)
    return x


# Fixture: ``with call_context`` regions whose bodies contain control flow
# (a branch in the first, a loop in the second).
# TestLiftCall.test_liftcall3 expects stdout "A 1\nB 2\nC 47\n".
def liftcall3():
    x = 1
    print("A", x)
    with call_context:
        if x > 0:
            x += 1
    print("B", x)
    with call_context:
        for i in range(10):
            x += i
    print("C", x)
    return x


# Fixture: a ``with call_context`` region nested directly inside another.
# TestLiftCall.test_liftcall4 expects njit-compiling this to raise an error
# mentioning "compiler re-entrant to the same function signature".
def liftcall4():
    with call_context:
        with call_context:
            pass


# Fixture: a ``break`` issued from inside a ``with call_context`` region
# that sits in a loop.  TestLiftCall.test_liftcall5 expects stdout
# "0\n1\n2\n3\n4\n5\nA\n".
def liftcall5():
    for i in range(10):
        with call_context:
            print(i)
            if i == 5:
                print("A")
                break
    return i


# Fixture: the context-manager expression is a global name that is not
# defined in the visible part of this module.
# NOTE(review): "undefiend" is a misspelling of "undefined"; the name is
# left unchanged because it may be referenced elsewhere in the file.
def lift_undefiend():
    with undefined_global_var:
        pass


# A plain ``object()`` (no ``__enter__``/``__exit__``) used as the context
# manager expression in ``lift_invalid``.
bogus_contextmanager = object()


# Fixture: uses the module-level ``bogus_contextmanager`` (a bare
# ``object()``) as the context manager of a ``with`` statement.
def lift_invalid():
    with bogus_contextmanager:
        pass


# Module-level (global) Numba type; passed as an objmode type annotation by
# test_objmode_gv_variable and test_objmode_multi_type_args.
gv_type = types.intp


class TestWithFinding(TestCase):
    """Count the with-regions that ``find_setupwiths`` reports."""

    def check_num_of_with(self, func, expect_count):
        """Assert the number of with-regions found in ``func``.

        The count is the length of the first item returned by
        ``find_setupwiths`` when applied to the IR of ``func``.
        """
        regions = find_setupwiths(get_func_ir(func))[0]
        self.assertEqual(len(regions), expect_count)

    def test_lift1(self):
        self.check_num_of_with(lift1, 1)

    def test_lift2(self):
        self.check_num_of_with(lift2, 2)

    def test_lift3(self):
        self.check_num_of_with(lift3, 1)

    def test_lift4(self):
        self.check_num_of_with(lift4, 2)

    def test_lift5(self):
        self.check_num_of_with(lift5, 0)


class BaseTestWithLifting(TestCase):
    """Shared helpers for tests that apply ``with_lifting`` to raw IR."""

    def setUp(self):
        super().setUp()
        typing_context = typing.Context()
        self.typingctx = typing_context
        self.targetctx = cpu.CPUContext(typing_context)
        self.flags = DEFAULT_FLAGS

    def check_extracted_with(self, func, expect_count, expected_stdout):
        """Lift the with-regions of ``func``, compile and run the result.

        Asserts that ``with_lifting`` extracted ``expect_count`` regions and
        that calling the compiled entry point writes exactly
        ``expected_stdout`` to stdout.
        """
        lifted_ir, lifted_regions = with_lifting(
            get_func_ir(func),
            self.typingctx,
            self.targetctx,
            self.flags,
            locals={},
        )
        self.assertEqual(len(lifted_regions), expect_count)

        compiled = self.compile_ir(lifted_ir)
        with captured_stdout() as stream:
            compiled.entry_point()

        self.assertEqual(stream.getvalue(), expected_stdout)

    def compile_ir(self, the_ir, args=(), return_type=None):
        """Compile ``the_ir`` with this test's typing/target contexts."""
        # Register the contexts in case for nested @jit or @overload calls
        with cpu_target.nested_context(self.typingctx, self.targetctx):
            return compile_ir(self.typingctx, self.targetctx, the_ir, args,
                              return_type, self.flags, locals={})


class TestLiftByPass(BaseTestWithLifting):
    """Lifting tests for ``with bypass_context`` fixtures.

    Each case gives the expected number of extracted regions and the exact
    stdout produced by the compiled function.
    """

    def test_lift1(self):
        self.check_extracted_with(lift1, 1, "A\nC\n")

    def test_lift2(self):
        self.check_extracted_with(lift2, 2, "A 1\nD 3\n")

    def test_lift3(self):
        self.check_extracted_with(lift3, 1, "A 1 100\nD 2 101\n")

    def test_lift4(self):
        self.check_extracted_with(lift4, 2, "A 0\nE 11\n")

    def test_lift5(self):
        self.check_extracted_with(lift5, 0, "A\n")


class TestLiftCall(BaseTestWithLifting):
    """Lifting tests for ``with call_context`` fixtures."""

    def check_same_semantic(self, func):
        """Assert that ``njit(func)`` prints the same text as ``func``.
        """
        compiled = njit(func)
        with captured_stdout() as jit_out:
            compiled()

        with captured_stdout() as py_out:
            func()

        self.assertEqual(jit_out.getvalue(), py_out.getvalue())

    def test_liftcall1(self):
        self.check_extracted_with(liftcall1, 1, "A 1\nB 2\n")
        self.check_same_semantic(liftcall1)

    def test_liftcall2(self):
        self.check_extracted_with(liftcall2, 2, "A 1\nB 2\nC 12\n")
        self.check_same_semantic(liftcall2)

    def test_liftcall3(self):
        self.check_extracted_with(liftcall3, 2, "A 1\nB 2\nC 47\n")
        self.check_same_semantic(liftcall3)

    def test_liftcall4(self):
        # Known error.  We only support one context manager per function
        # for body that are lifted.
        expected_errors = (
            errors.TypingError,
            errors.NumbaRuntimeError,
            errors.NumbaValueError,
            errors.CompilerError,
        )
        with self.assertRaises(expected_errors) as caught:
            njit(liftcall4)()
        self.assertIn("compiler re-entrant to the same function signature",
                      str(caught.exception))

    # 3.8 and earlier fails to interpret the bytecode for this example
    @unittest.skipIf(PYVERSION <= (3, 8),
                     "unsupported on py3.8 and before")
    @expected_failure_py311
    def test_liftcall5(self):
        self.check_extracted_with(liftcall5, 1,
                                  "0\n1\n2\n3\n4\n5\nA\n")
        self.check_same_semantic(liftcall5)


def expected_failure_for_list_arg(fn):
    """Decorate a test method that must fail because of a list argument.

    The returned wrapper calls ``fn`` and succeeds only if that call raises
    ``errors.TypingError`` with a message containing
    ``'Does not support list type'``.

    ``functools.wraps`` is applied so that the wrapper keeps the original
    test's ``__name__``, ``__doc__`` and ``__wrapped__`` for introspection.
    """
    import functools

    @functools.wraps(fn)
    def core(self, *args, **kwargs):
        with self.assertRaises(errors.TypingError) as raises:
            fn(self, *args, **kwargs)
        self.assertIn('Does not support list type',
                      str(raises.exception))
    return core


def expected_failure_for_function_arg(fn):
    """Decorate a test method that must fail because of a function argument.

    The returned wrapper calls ``fn`` and succeeds only if that call raises
    ``errors.TypingError`` with a message containing
    ``'Does not support function type'``.

    ``functools.wraps`` is applied so that the wrapper keeps the original
    test's ``__name__``, ``__doc__`` and ``__wrapped__`` for introspection.
    """
    import functools

    @functools.wraps(fn)
    def core(self, *args, **kwargs):
        with self.assertRaises(errors.TypingError) as raises:
            fn(self, *args, **kwargs)
        self.assertIn('Does not support function type',
                      str(raises.exception))
    return core


class TestLiftObj(MemoryLeak, TestCase):

    def setUp(self):
        """Turn every ``errors.NumbaWarning`` into an exception."""
        warnings.simplefilter(action="error", category=errors.NumbaWarning)

    def tearDown(self):
        # Drop the "error" filter installed by setUp.
        # NOTE(review): resetwarnings() clears *all* warning filters, not
        # only the one added in setUp — confirm this is intended.
        warnings.resetwarnings()

    def assert_equal_return_and_stdout(self, pyfunc, *args):
        """Run ``pyfunc`` interpreted and njit-compiled and compare both.

        Each run receives its own deep copy of ``args``.  The captured
        stdout text and the return value of the two runs must be equal
        (the return value is compared with ``assertPreciseEqual``).
        """
        interp_args = copy.deepcopy(args)
        jit_args = copy.deepcopy(args)
        jitted = njit(pyfunc)

        with captured_stdout() as buf:
            expected_result = pyfunc(*interp_args)
            expected_text = buf.getvalue()

        # avoid compiling during stdout-capturing for easier print-debugging
        signature = tuple(typeof(arg) for arg in jit_args)
        jitted.compile(signature)
        with captured_stdout() as buf:
            actual_result = jitted(*jit_args)
            actual_text = buf.getvalue()

        self.assertEqual(expected_text, actual_text)
        self.assertPreciseEqual(expected_result, actual_result)

    def test_lift_objmode_basic(self):
        # An objmode region with no outgoing variables that only calls a
        # Python function.  Exercised through both the imported
        # ``objmode_context`` name and the ``numba.objmode`` attribute.
        def bar(ival):
            print("ival =", {'ival': ival // 2})

        def foo(ival):
            ival += 1
            with objmode_context:
                bar(ival)
            return ival + 1

        def foo_nonglobal(ival):
            ival += 1
            with numba.objmode:
                bar(ival)
            return ival + 1

        self.assert_equal_return_and_stdout(foo, 123)
        self.assert_equal_return_and_stdout(foo_nonglobal, 123)

    def test_lift_objmode_array_in(self):
        # An array created outside the objmode region is passed into it,
        # where ``bar`` mutates it in place; return value and stdout must
        # match the pure-Python run.
        def bar(arr):
            print({'arr': arr // 2})
            # arr is modified. the effect is visible outside.
            arr *= 2

        def foo(nelem):
            arr = np.arange(nelem).astype(np.int64)
            with objmode_context:
                # arr is modified inplace inside bar()
                bar(arr)
            return arr + 1

        nelem = 10
        self.assert_equal_return_and_stdout(foo, nelem)

    def test_lift_objmode_define_new_unused(self):
        # Variables first assigned inside the objmode region (``y``, ``a``)
        # are not used after it, so no type annotation is given for them.
        def bar(y):
            print(y)

        def foo(x):
            with objmode_context():
                y = 2 + x           # defined but unused outside
                a = np.arange(y)    # defined but unused outside
                bar(a)
            return x

        arg = 123
        self.assert_equal_return_and_stdout(foo, arg)

    def test_lift_objmode_return_simple(self):
        # A scalar ``y`` annotated as "float64" is produced inside the
        # objmode region and returned afterwards.  Checked with both
        # ``objmode_context`` and ``numba.objmode``.
        def inverse(x):
            print(x)
            return 1 / x

        def foo(x):
            with objmode_context(y="float64"):
                y = inverse(x)
            return x, y

        def foo_nonglobal(x):
            with numba.objmode(y="float64"):
                y = inverse(x)
            return x, y

        arg = 123
        self.assert_equal_return_and_stdout(foo, arg)
        self.assert_equal_return_and_stdout(foo_nonglobal, arg)

    def test_lift_objmode_return_array(self):
        # Two outgoing variables with different annotations: an array
        # ("float64[:]") and a scalar ("int64").
        def inverse(x):
            print(x)
            return 1 / x

        def foo(x):
            with objmode_context(y="float64[:]", z="int64"):
                y = inverse(x)
                z = int(y[0])
            return x, y, z

        arg = np.arange(1, 10, dtype=np.float64)
        self.assert_equal_return_and_stdout(foo, arg)

    @expected_failure_for_list_arg
    def test_lift_objmode_using_list(self):
        # A Python list is used inside the objmode region.  Because of the
        # decorator, this test passes only when a TypingError containing
        # 'Does not support list type' is raised.
        def foo(x):
            with objmode_context(y="float64[:]"):
                print(x)
                x[0] = 4
                print(x)
                y = [1, 2, 3] + x
                y = np.asarray([1 / i for i in y])
            return x, y

        arg = [1, 2, 3]
        self.assert_equal_return_and_stdout(foo, arg)

    def test_lift_objmode_var_redef(self):
        # ``x`` is rebound several times before the objmode region (loop
        # variable, conditional increment) and again inside it, where it is
        # also the annotated outgoing variable.
        def foo(x):
            for x in range(x):
                pass
            if x:
                x += 1
            with objmode_context(x="intp"):
                print(x)
                x -= 1
                print(x)
                for i in range(x):
                    x += i
                    print(x)
            return x

        arg = 123
        self.assert_equal_return_and_stdout(foo, arg)

    @expected_failure_for_list_arg
    def test_case01_mutate_list_ahead_of_ctx(self):
        # A list argument is mutated before and inside objmode regions.
        # Because of the decorator, this test passes only when a TypingError
        # containing 'Does not support list type' is raised.
        def foo(x, z):
            x[2] = z

            with objmode_context():
                # should print [1, 2, 15] but prints [1, 2, 3]
                print(x)

            with objmode_context():
                x[2] = 2 * z
                # should print [1, 2, 30] but prints [1, 2, 15]
                print(x)

            return x

        self.assert_equal_return_and_stdout(foo, [1, 2, 3], 15)

    def test_case02_mutate_array_ahead_of_ctx(self):
        # Array counterpart of test_case01: the array is mutated before the
        # first objmode region and inside the second one; output and result
        # must match the pure-Python run.
        def foo(x, z):
            x[2] = z

            with objmode_context():
                # should print [1, 2, 15]
                print(x)

            with objmode_context():
                x[2] = 2 * z
                # should print [1, 2, 30]
                print(x)

            return x

        x = np.array([1, 2, 3])
        self.assert_equal_return_and_stdout(foo, x, 15)

    @expected_failure_for_list_arg
    def test_case03_create_and_mutate(self):
        # A list is created in one objmode region and mutated in a second
        # one.  Because of the decorator, this test passes only when a
        # TypingError containing 'Does not support list type' is raised.
        def foo(x):
            with objmode_context(y='List(int64)'):
                y = [1, 2, 3]
            with objmode_context():
                y[2] = 10
            return y
        self.assert_equal_return_and_stdout(foo, 1)

    def test_case04_bogus_variable_type_info(self):
        # The annotation names ``k``, which is never assigned in the objmode
        # region; calling the compiled function must raise a TypingError
        # about annotating a non-outgoing variable.

        def foo(x):
            # should specifying nonsense type info be considered valid?
            with objmode_context(k="float64[:]"):
                print(x)
            return x

        x = np.array([1, 2, 3])
        cfoo = njit(foo)
        with self.assertRaises(errors.TypingError) as raises:
            cfoo(x)
        self.assertIn(
            "Invalid type annotation on non-outgoing variables",
            str(raises.exception),
            )

    def test_case05_bogus_type_info(self):
        # ``z`` is annotated as a float64 array but the objmode body assigns
        # ``x + 1.j`` to it; calling the compiled function must raise a
        # TypeError carrying the unboxing message asserted below.
        def foo(x):
            # should specifying the wrong type info be considered valid?
            # z is complex.
            # Note: for now, we will coerce for scalar and raise for array
            with objmode_context(z="float64[:]"):
                z = x + 1.j
            return z

        x = np.array([1, 2, 3])
        cfoo = njit(foo)
        with self.assertRaises(TypeError) as raises:
            # The call is expected to raise, so its result is not bound.
            cfoo(x)
        self.assertIn(
            ("can't unbox array from PyObject into native value."
             "  The object maybe of a different type"),
            str(raises.exception),
        )

    def test_case06_double_objmode(self):
        # An objmode region nested directly inside another objmode region;
        # compiling and calling it must raise a TypingError whose text shows
        # the failure happened while resolving the lifted objmode callee.
        def foo(x):
            # would nested ctx in the same scope ever make sense? Is this
            # pattern useful?
            with objmode_context():
                #with npmmode_context(): not implemented yet
                    with objmode_context():
                        print(x)
            return x

        with self.assertRaises(errors.TypingError) as raises:
            njit(foo)(123)
        # Check that an error occurred in with-lifting in objmode
        # Raw strings: ``\(`` is a regex escape, not a Python string escape.
        pat = (r"During: resolving callee type: "
               r"type\(ObjModeLiftedWith\(<.*>\)\)")
        self.assertRegex(str(raises.exception), pat)

    def test_case07_mystery_key_error(self):
        # ``t`` and ``u`` are assigned inside the objmode region and used
        # after it, but no type annotation is given for them; calling the
        # compiled function must raise a TypingError that names both and
        # shows example code.
        # (comment carried by the original test: "this raises a key error")
        def foo(x):
            with objmode_context():
                t = {'a': x}
                u = 3
            return x, t, u

        arr = np.array([1, 2, 3])
        jitted = njit(foo)

        with self.assertRaises(errors.TypingError) as caught:
            jitted(arr)

        message = str(caught.exception)
        expected_fragments = (
            "Missing type annotation on outgoing variable(s): ['t', 'u']",
            "Example code: with objmode(t='<add_type_as_string_here>')",
        )
        for fragment in expected_fragments:
            self.assertIn(fragment, message)

    def test_case08_raise_from_external(self):
        # An exception raised by Python code inside the objmode region
        # (``d['2']`` is looked up on the first iteration, before that key
        # has been stored) must surface to the caller as KeyError "'2'".
        # this segfaults, expect its because the dict needs to raise as '2' is
        # not in the keys until a later loop (looking for `d['0']` works fine).
        d = dict()

        def foo(x):
            for i in range(len(x)):
                with objmode_context():
                    k = str(i)
                    v = x[i]
                    d[k] = v
                    print(d['2'])
            return x

        x = np.array([1, 2, 3])
        cfoo = njit(foo)
        with self.assertRaises(KeyError) as raises:
            cfoo(x)
        self.assertEqual(str(raises.exception), "'2'")

    def test_case09_explicit_raise(self):
        # A ``raise`` statement written directly in the objmode region must
        # be rejected with a CompilerError about unsupported control flow.
        def foo(x):
            with objmode_context():
                raise ValueError()
            return x

        x = np.array([1, 2, 3])
        cfoo = njit(foo)
        with self.assertRaises(errors.CompilerError) as raises:
            cfoo(x)
        self.assertIn(
            ('unsupported control flow due to raise statements inside '
             'with block'),
            str(raises.exception),
        )

    @expected_failure_for_list_arg
    def test_case10_mutate_across_contexts(self):
        # Because of the decorator, this test passes only when a TypingError
        # containing 'Does not support list type' is raised.
        # This shouldn't work due to using List as input.
        def foo(x):
            with objmode_context(y='List(int64)'):
                y = [1, 2, 3]
            with objmode_context():
                y[2] = 10
            return y

        x = np.array([1, 2, 3])
        self.assert_equal_return_and_stdout(foo, x)

    def test_case10_mutate_array_across_contexts(self):
        # Array variant of test_case10_mutate_across_contexts: ``y`` is
        # created in the first objmode region and mutated in place in the
        # second, which declares no outgoing variables.
        # Sub-case of case-10.
        def foo(x):
            with objmode_context(y='int64[:]'):
                y = np.asarray([1, 2, 3], dtype='int64')
            with objmode_context():
                # Note: `y` is not an output.
                y[2] = 10
            return y

        x = np.array([1, 2, 3])
        self.assert_equal_return_and_stdout(foo, x)

    def test_case11_define_function_in_context(self):
        # A function is defined inside the objmode region; calling the
        # compiled function is expected to raise NameError mentioning that
        # global name 'bar' is not defined.
        # should this work? no, global name 'bar' is not defined
        def foo(x):
            with objmode_context():
                def bar(y):
                    return y + 1
            return x

        x = np.array([1, 2, 3])
        cfoo = njit(foo)
        with self.assertRaises(NameError) as raises:
            cfoo(x)
        self.assertIn(
            "global name 'bar' is not defined",
            str(raises.exception),
        )

    def test_case12_njit_inside_a_objmode_ctx(self):
        # ``njit`` is invoked on ``bar`` from within the objmode region and
        # the jitted result is used as the outgoing array ``y``.
        # TODO: is this still the cases?
        # this works locally but not inside this test, probably due to the way
        # compilation is being done
        def bar(y):
            return y + 1

        def foo(x):
            with objmode_context(y='int64[:]'):
                y = njit(bar)(x).astype('int64')
            return x + y

        x = np.array([1, 2, 3])
        self.assert_equal_return_and_stdout(foo, x)

    def test_case14_return_direct_from_objmode_ctx(self):
        # ``return`` placed directly inside the objmode region.
        def foo(x):
            with objmode_context(x='int64[:]'):
                x += 1
                return x

        if PYVERSION <= (3,8):
            # 3.8 and below don't support return inside with
            with self.assertRaises(errors.CompilerError) as raises:
                cfoo = njit(foo)
                cfoo(np.array([1, 2, 3]))
            msg = "unsupported control flow: due to return statements inside with block"
            self.assertIn(msg, str(raises.exception))
        else:
            # NOTE(review): this branch calls the plain Python ``foo``, not
            # an njit-compiled version — confirm whether that is intended.
            result = foo(np.array([1, 2, 3]))
            np.testing.assert_array_equal(np.array([2, 3, 4]), result)

    # No easy way to handle this yet.
    @unittest.expectedFailure
    def test_case15_close_over_objmode_ctx(self):
        # The objmode region lives in an inner function ``bar`` that closes
        # over ``j`` from ``foo``.  Marked as an expected failure.
        # Fails with Unsupported constraint encountered: enter_with $phi8.1
        def foo(x):
            j = 10

            def bar(x):
                with objmode_context(x='int64[:]'):
                    print(x)
                    return x + j
            return bar(x) + 2
        x = np.array([1, 2, 3])
        self.assert_equal_return_and_stdout(foo, x)

    @skip_unless_scipy
    def test_case16_scipy_call_in_objmode_ctx(self):
        # Calls into ``scipy.sparse`` from inside the objmode region and
        # returns a scalar annotated as 'int64'.
        from scipy import sparse as sp

        def foo(x):
            with objmode_context(k='int64'):
                print(x)
                spx = sp.csr_matrix(x)
                # the np.int64 call is pointless, works around:
                # https://github.com/scipy/scipy/issues/10206
                # which hit the SciPy 1.3 release.
                k = np.int64(spx[0, 0])
            return k
        x = np.array([1, 2, 3])
        self.assert_equal_return_and_stdout(foo, x)

    def test_case17_print_own_bytecode(self):
        # Inside the objmode region the function disassembles itself with
        # ``dis.dis``; the printed text must match the pure-Python run.
        import dis

        def foo(x):
            with objmode_context():
                dis.dis(foo)
        x = np.array([1, 2, 3])
        self.assert_equal_return_and_stdout(foo, x)

    @expected_failure_for_function_arg
    def test_case18_njitfunc_passed_to_objmode_ctx(self):
        # An njit-compiled function is passed as an argument and called in
        # the objmode region.  Because of the decorator, this test passes
        # only when a TypingError containing 'Does not support function
        # type' is raised.
        def foo(func, x):
            with objmode_context():
                func(x[0])

        x = np.array([1, 2, 3])
        fn = njit(lambda z: z + 5)
        self.assert_equal_return_and_stdout(foo, fn, x)

    @expected_failure_py311
    def test_case19_recursion(self):
        # ``foo`` calls itself recursively and returns from inside the
        # objmode region; compiling/calling it is expected to fail with a
        # message containing "Untyped global name 'foo'".
        def foo(x):
            with objmode_context():
                if x == 0:
                    return 7
            ret = foo(x - 1)
            return ret
        with self.assertRaises((errors.TypingError, errors.CompilerError)) as raises:
            cfoo = njit(foo)
            cfoo(np.array([1, 2, 3]))
        msg = "Untyped global name 'foo'"
        self.assertIn(msg, str(raises.exception))

    @unittest.expectedFailure
    def test_case20_rng_works_ok(self):
        # Draws one random number outside and one inside the objmode region
        # after a single seed call outside it.  Marked as an expected
        # failure (see the in-body comment about random state).
        def foo(x):
            np.random.seed(0)
            y = np.random.rand()
            with objmode_context(z="float64"):
                # It's known that the random state does not sync
                z = np.random.rand()
            return x + z + y

        x = np.array([1, 2, 3])
        self.assert_equal_return_and_stdout(foo, x)

    def test_case21_rng_seed_works_ok(self):
        # Same shape as test_case20_rng_works_ok, except that the seed is
        # set again inside the objmode region before drawing ``z``.
        def foo(x):
            np.random.seed(0)
            y = np.random.rand()
            with objmode_context(z="float64"):
                # Similar to test_case20_rng_works_ok but call seed
                np.random.seed(0)
                z = np.random.rand()
            return x + z + y

        x = np.array([1, 2, 3])
        self.assert_equal_return_and_stdout(foo, x)

    def test_example01(self):
        # Compares the compiled ``foo`` against its ``py_func`` and checks
        # that ``objmode`` and ``objmode_context`` are the same object.
        # Example from _ObjModeContextType.__doc__
        def bar(x):
            return np.asarray(list(reversed(x.tolist())))

        @njit
        def foo():
            x = np.arange(5)
            with objmode(y='intp[:]'):  # annotate return type
                # this region is executed by object-mode.
                y = x + bar(x)
            return y

        self.assertPreciseEqual(foo(), foo.py_func())
        self.assertIs(objmode, objmode_context)

    def test_objmode_in_overload(self):
        # The objmode region is inside the implementation returned by an
        # ``@overload`` of ``foo``; the njit caller must get 1 + 3.
        def foo(s):
            pass

        @overload(foo)
        def foo_overload(s):
            def impl(s):
                with objmode(out='intp'):
                    out = s + 3
                return out
            return impl

        @numba.njit
        def f():
            return foo(1)

        self.assertEqual(f(), 1 + 3)

    def test_objmode_gv_variable(self):
        # The objmode annotation is given as the module-level global
        # ``gv_type`` (``types.intp``) instead of a string.
        @njit
        def global_var():
            with objmode(val=gv_type):
                val = 12.3
            return val

        ret = global_var()
        # the result is truncated because of the intp return-type
        self.assertIsInstance(ret, int)
        self.assertEqual(ret, 12)

    def test_objmode_gv_variable_error(self):
        # The objmode annotation refers to ``gv_type2``, which is not
        # assigned in the visible part of this module; calling the function
        # must raise a CompilerError reporting the undefined global.
        @njit
        def global_var():
            with objmode(val=gv_type2):
                val = 123
            return val

        # Raw string for the second piece: ``\.`` is a regex escape, not a
        # Python string escape.
        with self.assertRaisesRegex(
            errors.CompilerError,
            ("Error handling objmode argument 'val'. "
             r"Global 'gv_type2' is not defined\.")
        ):
            global_var()

    def test_objmode_gv_mod_attr(self):
        # The objmode annotation is given as an attribute lookup on a
        # module: ``types.intp`` and ``numba.types.intp``.
        @njit
        def modattr1():
            with objmode(val=types.intp):
                val = 12.3
            return val

        @njit
        def modattr2():
            with objmode(val=numba.types.intp):
                val = 12.3
            return val

        for fn in (modattr1, modattr2):
            with self.subTest(fn=str(fn)):
                ret = fn()
                # the result is truncated because of the intp return-type
                self.assertIsInstance(ret, int)
                self.assertEqual(ret, 12)

    def test_objmode_gv_mod_attr_error(self):
        # The annotation is an attribute lookup (``THIS_DOES_NOT_EXIST`` on
        # ``types``) that the test expects to be unresolvable; calling the
        # function must raise a CompilerError naming argument 'val'.
        @njit
        def moderror():
            with objmode(val=types.THIS_DOES_NOT_EXIST):
                val = 12.3
            return val
        with self.assertRaisesRegex(
            errors.CompilerError,
            ("Error handling objmode argument 'val'. "
             "Getattr cannot be resolved at compile-time"),
        ):
            moderror()

    def test_objmode_gv_mod_attr_error_multiple(self):
        # Of three annotations only ``v2`` uses an unresolvable attribute;
        # the CompilerError must name 'v2' specifically.
        # Note: ``val`` (returned by ``moderror``) is never assigned in that
        # function; the test only expects the error asserted below.
        @njit
        def moderror():
            with objmode(v1=types.intp, v2=types.THIS_DOES_NOT_EXIST,
                         v3=types.float32):
                v1 = 12.3
                v2 = 12.3
                v3 = 12.3
            return val
        with self.assertRaisesRegex(
            errors.CompilerError,
            ("Error handling objmode argument 'v2'. "
             "Getattr cannot be resolved at compile-time"),
        ):
            moderror()

    def test_objmode_closure_type_in_overload(self):
        # The objmode annotation is a closure variable (``shrubbery``) of
        # the overload implementation; the njit caller must get the array
        # produced inside the objmode region.
        def foo():
            pass

        @overload(foo)
        def foo_overload():
            shrubbery = types.float64[:]
            def impl():
                with objmode(out=shrubbery):
                    out = np.arange(10).astype(np.float64)
                return out
            return impl

        @njit
        def bar():
            return foo()

        self.assertPreciseEqual(bar(), np.arange(10).astype(np.float64))

    def test_objmode_closure_type_in_overload_error(self):
        # Same setup as test_objmode_closure_type_in_overload, but the
        # closure variable is deleted before ``impl`` is returned; calling
        # the njit function must raise a TypingError reporting that freevar
        # 'shrubbery' is not defined.
        def foo():
            pass

        @overload(foo)
        def foo_overload():
            shrubbery = types.float64[:]
            def impl():
                with objmode(out=shrubbery):
                    out = np.arange(10).astype(np.float64)
                return out
            # Remove closure var.
            # Otherwise, it will "shrubbery" will be a global
            del shrubbery
            return impl

        @njit
        def bar():
            return foo()

        with self.assertRaisesRegex(
            errors.TypingError,
            ("Error handling objmode argument 'out'. "
             "Freevar 'shrubbery' is not defined"),
        ):
            bar()

    def test_objmode_invalid_use(self):
        # The annotation value is an arbitrary expression (``1 + 1``);
        # calling the function must raise a CompilerError explaining which
        # kinds of values are accepted.
        # Note: ``val`` (returned by ``moderror``) is never assigned in that
        # function; the test only expects the error asserted below.
        @njit
        def moderror():
            with objmode(bad=1 + 1):
                out = 1
            return val
        with self.assertRaisesRegex(
            errors.CompilerError,
            ("Error handling objmode argument 'bad'. "
             "The value must be a compile-time constant either as "
             "a non-local variable or a getattr expression that "
             "refers to a Numba type."),
        ):
            moderror()

    def test_objmode_multi_type_args(self):
        # One objmode block mixes three ways of spelling an output type.
        array_ty = types.int32[:]
        @njit
        def foo():
            # t1 is a string
            # t2 is a global type
            # t3 is a non-local/freevar
            with objmode(t1="float64", t2=gv_type, t3=array_ty):
                t1 = 793856.5
                t2 = t1         # to observe truncation
                t3 = np.arange(5).astype(np.int32)
            return t1, t2, t3

        got_t1, got_t2, got_t3 = foo()
        expected_t3 = np.arange(5).astype(np.int32)

        self.assertPreciseEqual(got_t1, 793856.5)
        # The fractional part of t1 is expected to be dropped for t2.
        self.assertPreciseEqual(got_t2, 793856)
        self.assertPreciseEqual(got_t3, expected_t3)

    def test_objmode_jitclass(self):
        # A jitclass instance is created and mutated inside an objmode block
        # and handed back to jitted code, using the Numba type of an existing
        # instance as the declared objmode output type.
        spec = [
            ('value', types.int32),               # a simple scalar field
            ('array', types.float32[:]),          # an array field
        ]

        @jitclass(spec)
        class Bag(object):
            def __init__(self, value):
                self.value = value
                self.array = np.zeros(value, dtype=np.float32)

            @property
            def size(self):
                return self.array.size

            def increment(self, val):
                # Add `val` to every element in place and return the array.
                for i in range(self.size):
                    self.array[i] += val
                return self.array

            @staticmethod
            def add(x, y):
                return x + y

        # This instance only serves as the source of `_numba_type_` below.
        n = 21
        mybag = Bag(n)

        def foo():
            pass

        @overload(foo)
        def foo_overload():
            # Output type for the objmode block, captured as a closure
            # variable of `impl`.
            shrubbery = mybag._numba_type_
            def impl():
                with objmode(out=shrubbery):
                    out = Bag(123)
                    out.increment(3)
                return out
            return impl

        @njit
        def bar():
            return foo()

        z = bar()
        self.assertIsInstance(z, Bag)
        self.assertEqual(z.add(2, 3), 2 + 3)
        # Bag(123) starts as 123 zeros; increment(3) adds 3 to each element.
        exp_array = np.zeros(123, dtype=np.float32) + 3
        self.assertPreciseEqual(z.array, exp_array)


    @staticmethod
    def case_objmode_cache(x):
        # Divides `x` by ten inside an objmode block whose output variable
        # `output` is declared as 'float64', then returns it.
        # NOTE(review): no caller of this static method is visible in this
        # part of the file; the caching tests use the module-level function
        # of the same name - confirm whether this copy is still needed.
        with objmode(output='float64'):
            output = x / 10
        return output

    def test_objmode_reflected_list(self):
        # The objmode output type is taken from `typeof` on a Python list;
        # compilation is expected to be refused with a CompilerError.
        ret_type = typeof([1, 2, 3, 4, 5])
        @njit
        def test2():
            with objmode(out=ret_type):
                out = [1, 2, 3, 4, 5]
            return out

        # The pattern accepts either int32 or int64 in the type name.
        expected_pattern = (
            r"Objmode context failed. "
            r"Argument 'out' is declared as an unsupported type: "
            r"reflected list\(int(32|64)\)<iv=None>. "
            r"Reflected types are not supported."
        )
        with self.assertRaises(errors.CompilerError) as caught:
            test2()
        self.assertRegex(str(caught.exception), expected_pattern)

    def test_objmode_reflected_set(self):
        # The objmode output type is taken from `typeof` on a Python set;
        # compilation is expected to be refused with a CompilerError.
        ret_type = typeof({1, 2, 3, 4, 5})
        @njit
        def test2():
            with objmode(result=ret_type):
                result = {1, 2, 3, 4, 5}
            return result

        # The pattern accepts either int32 or int64 in the type name.
        expected_pattern = (
            r"Objmode context failed. "
            r"Argument 'result' is declared as an unsupported type: "
            r"reflected set\(int(32|64)\). "
            r"Reflected types are not supported."
        )
        with self.assertRaises(errors.CompilerError) as caught:
            test2()
        self.assertRegex(str(caught.exception), expected_pattern)

    def test_objmode_typed_dict(self):
        # The output is declared as a typed dict but the objmode block
        # produces a builtin dict; a TypeError about unboxing is expected
        # when the function is called.
        ret_type = types.DictType(types.unicode_type, types.int64)
        @njit
        def test4():
            with objmode(res=ret_type):
                res = {'A': 1, 'B': 2}
            return res

        expected_fragment = ("can't unbox a <class 'dict'> "
                             "as a <class 'numba.typed.typeddict.Dict'>")
        with self.assertRaises(TypeError) as caught:
            test4()
        self.assertIn(expected_fragment, str(caught.exception))

    def test_objmode_typed_list(self):
        # The output is declared as a typed list but the objmode block
        # produces a builtin list; a TypeError about unboxing is expected
        # when the function is called.
        ret_type = types.ListType(types.int64)
        @njit
        def test4():
            with objmode(res=ret_type):
                res = [1, 2]
            return res

        # The target type may be rendered with or without the
        # "<class '...'>" wrapper, so both spellings are accepted.
        expected_pattern = (
            r"can't unbox a <class 'list'> "
            r"as a (<class ')?numba.typed.typedlist.List('>)?"
        )
        with self.assertRaises(TypeError) as caught:
            test4()
        self.assertRegex(str(caught.exception), expected_pattern)

    def test_objmode_use_of_view(self):
        # See issue #7158, npm functionality should only be validated if in
        # npm.
        @njit
        def foo(x):
            with numba.objmode(y="int64[::1]"):
                y = x.view("int64")
            return y

        # An int64 array reinterpreted as float64 is the input; the function
        # views it back as int64 inside the objmode block.
        data = np.ones(1, np.int64).view('float64')

        # The pure-Python result is computed first, then the jitted one.
        self.assertPreciseEqual(foo.py_func(data), foo(data))


def case_inner_pyfunc(x):
    # Plain-Python helper: true division of `x` by ten.
    tenth = x / 10
    return tenth


def case_objmode_cache(x):
    # Module-level function used by
    # TestLiftObjCaching.test_objmode_caching_call_closure_good. The objmode
    # block calls the module-level helper `case_inner_pyfunc` and declares
    # its result `output` as 'float64'.
    with objmode(output='float64'):
        output = case_inner_pyfunc(x)
    return output


class TestLiftObjCaching(MemoryLeak, TestCase):
    # Warnings in this test class are converted to errors

    def setUp(self):
        # Snapshot the process-wide warning filters before tightening them,
        # so tearDown can put back exactly what was there. A blanket
        # warnings.resetwarnings() would also discard filters installed by
        # the test runner or by -W options.
        self._saved_warning_filters = warnings.catch_warnings()
        self._saved_warning_filters.__enter__()
        warnings.simplefilter("error", errors.NumbaWarning)

    def tearDown(self):
        # Restore the filters captured in setUp.
        self._saved_warning_filters.__exit__(None, None, None)

    def check(self, py_func):
        # Compile `py_func` twice with cache=True. Both dispatchers must
        # return 12.3 for the input 123; the second one must report no cache
        # hits before it is called and at least one afterwards.
        first = njit(cache=True)(py_func)
        self.assertEqual(first(123), 12.3)

        second = njit(cache=True)(py_func)
        self.assertFalse(second._cache_hits)
        self.assertEqual(second(123), 12.3)
        self.assertTrue(second._cache_hits)

    def test_objmode_caching_basic(self):
        # The objmode block only does arithmetic on the argument.
        def pyfunc(x):
            with objmode(output='float64'):
                output = x / 10
            return output

        self.check(pyfunc)

    def test_objmode_caching_call_closure_bad(self):
        # The objmode block calls a function that is a closure variable of
        # the jitted function.
        def other_pyfunc(x):
            return x / 10

        def pyfunc(x):
            with objmode(output='float64'):
                output = other_pyfunc(x)
            return output

        self.check(pyfunc)

    def test_objmode_caching_call_closure_good(self):
        # Here the function called from the objmode block is a module-level
        # global instead of a closure variable.
        self.check(case_objmode_cache)


class TestBogusContext(BaseTestWithLifting):
    # Error reporting for context managers that cannot be lifted.

    def _assert_lifting_rejected(self, pyfunc, expected_fragment):
        # Build the IR outside the assertion so that only the with-lifting
        # call itself is expected to raise the CompilerError.
        func_ir = get_func_ir(pyfunc)
        with self.assertRaises(errors.CompilerError) as caught:
            with_lifting(
                func_ir, self.typingctx, self.targetctx, self.flags,
                locals={},
            )
        self.assertIn(expected_fragment, str(caught.exception))

    def test_undefined_global(self):
        self._assert_lifting_rejected(
            lift_undefiend,
            "Undefined variable used as context manager",
        )

    def test_invalid(self):
        self._assert_lifting_rejected(
            lift_invalid,
            "Unsupported context manager in use",
        )

    def test_with_as_fails_gracefully(self):
        # A `with ... as name:` statement inside jitted code is expected to
        # produce an UnsupportedError with a dedicated message.
        @njit
        def foo():
            with open('') as f:
                pass

        expected_msg = ("The 'with (context manager) as (variable):' "
                        "construct is not supported.")
        with self.assertRaises(errors.UnsupportedError) as caught:
            foo()
        self.assertIn(expected_msg, str(caught.exception))


class TestMisc(TestCase):
    # Tests for miscellaneous objmode issues. Run serially.

    _numba_parallel_test_ = False

    @linux_only
    @TestCase.run_test_in_subprocess
    def test_no_fork_in_compilation(self):
        # Checks that there is no fork/clone/execve during compilation, see
        # issue #7881. This needs running in a subprocess as the offending
        # fork call that triggered #7881 occurs on the first call to uuid1 as
        # it's part of the initialisation process for that function (gets
        # hardware address of machine).

        # Needs strace support.
        if not strace_supported():
            self.skipTest("strace support missing")

        def force_compile():
            @njit('void()') # force compilation
            def f():
                with numba.objmode():
                    pass

        # Syscalls to capture while `force_compile` runs.
        watched_syscalls = ['fork', 'clone', 'execve']

        # Compilation must not trigger fork, clone or execve, so the captured
        # trace is expected to be empty.
        trace_output = strace(force_compile, watched_syscalls)
        self.assertFalse(trace_output)


# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()