test_errorhandling.py 17.5 KB
Newer Older
dugupeiwen's avatar
dugupeiwen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
"""
Unspecified error handling tests
"""

import numpy as np
import os
import warnings

from numba import jit, njit, typed, int64, types
from numba.core import errors
from numba.experimental import structref
from numba.extending import (overload, intrinsic, overload_method,
                             overload_attribute)
from numba.core.compiler import CompilerBase
from numba.core.untyped_passes import (TranslateByteCode, FixupArgs,
                                       IRProcessing,)
from numba.core.typed_passes import (NopythonTypeInference, DeadCodeElimination,
                                     NoPythonBackend, NativeLowering)
from numba.core.compiler_machinery import PassManager
from numba.core.types.functions import _err_reasons as error_reasons

from numba.tests.support import (skip_parfors_unsupported, override_config,
                                 SerialMixin, skip_unless_cffi,
                                 skip_unless_scipy, TestCase)
import unittest

# used in TestMiscErrorHandling::test_handling_of_write_to_*_global
# Module-level containers that jitted code attempts to mutate; Numba must
# reject such writes because globals are frozen as compile-time constants.
_global_list = list(range(1, 5))
_global_dict = typed.Dict.empty(key_type=int64, value_type=int64)


class TestErrorHandlingBeforeLowering(unittest.TestCase):
    """Errors that must be raised by the frontend, before lowering starts."""

    def test_unsupported_make_function_return_inner_func(self):
        # Returning a closure that captures a runtime value is not
        # compilable; every decorator variant must report a TypingError.
        def func(x):
            """ return the closure """
            z = x + 1

            def inner(x):
                return x + z
            return inner

        expected = "Cannot capture the non-constant value"
        for decorator in (jit, njit):
            with self.assertRaises(errors.TypingError) as ctx:
                decorator(func)(1)
            self.assertIn(expected, str(ctx.exception))


class TestUnsupportedReporting(unittest.TestCase):
    """Checks the wording used when an unsupported NumPy API is called."""

    def test_unsupported_numpy_function(self):
        # np.asanyarray(list) currently unsupported
        @njit
        def func():
            np.asanyarray([1, 2, 3])

        msg = "Use of unsupported NumPy function 'numpy.asanyarray'"
        with self.assertRaises(errors.TypingError) as ctx:
            func()
        self.assertIn(msg, str(ctx.exception))


class TestMiscErrorHandling(unittest.TestCase):
    """Assorted checks that compilation problems are reported sensibly."""

    def test_use_of_exception_for_flow_control(self):
        # constant inference uses exceptions with no Loc specified to determine
        # flow control, this asserts that the construction of the lowering
        # error context handler works in the case of an exception with no Loc
        # specified. See issue #3135.
        @njit
        def fn(x):
            return 10**x

        a = np.array([1.0], dtype=np.float64)
        fn(a) # should not raise

    def test_commented_func_definition_is_not_a_definition(self):
        # See issue #4056, the commented def should not be found as the
        # definition for reporting purposes when creating the synthetic
        # traceback because it is commented! Use of def in docstring would also
        # cause this issue hence is tested.

        def foo_commented():
            #def commented_definition()
            raise Exception('test_string')

        def foo_docstring():
            """ def docstring containing def might match function definition!"""
            raise Exception('test_string')

        for func in (foo_commented, foo_docstring):
            with self.assertRaises(Exception) as raises:
                # Compile with njit so Numba's source-location machinery is
                # actually exercised; calling the plain Python function would
                # only test CPython's own ``raise``.
                njit(func)()

            self.assertIn("test_string", str(raises.exception))

    def test_use_of_ir_unknown_loc(self):
        # for context see # 3390
        class TestPipeline(CompilerBase):
            def define_pipelines(self):
                name = 'bad_DCE_pipeline'
                pm = PassManager(name)
                pm.add_pass(TranslateByteCode, "analyzing bytecode")
                pm.add_pass(FixupArgs, "fix up args")
                pm.add_pass(IRProcessing, "processing IR")
                # remove dead before type inference so that the Arg node is
                # removed and the location of the arg cannot be found
                pm.add_pass(DeadCodeElimination, "DCE")
                # typing
                pm.add_pass(NopythonTypeInference, "nopython frontend")
                pm.add_pass(NativeLowering, "native lowering")
                pm.add_pass(NoPythonBackend, "nopython mode backend")
                pm.finalize()
                return [pm]

        @njit(pipeline_class=TestPipeline)
        def f(a):
            return 0

        with self.assertRaises(errors.TypingError) as raises:
            f(iter([1, 2]))  # use a type that Numba doesn't recognize

        expected = 'File "unknown location", line 0:'
        self.assertIn(expected, str(raises.exception))

    def check_write_to_globals(self, func):
        """Assert that calling ``func`` fails with the "write to global"
        TypingError message."""
        with self.assertRaises(errors.TypingError) as raises:
            func()

        expected = ["The use of a", "in globals, is not supported as globals"]
        for ex in expected:
            self.assertIn(ex, str(raises.exception))

    def test_handling_of_write_to_reflected_global(self):
        @njit
        def foo():
            _global_list[0] = 10

        self.check_write_to_globals(foo)

    def test_handling_of_write_to_typed_dict_global(self):
        @njit
        def foo():
            _global_dict[0] = 10

        self.check_write_to_globals(foo)

    @skip_parfors_unsupported
    def test_handling_forgotten_numba_internal_import(self):
        @njit(parallel=True)
        def foo():
            for i in prange(10): # noqa: F821 prange is not imported
                pass

        with self.assertRaises(errors.TypingError) as raises:
            foo()

        expected = ("'prange' looks like a Numba internal function, "
                    "has it been imported")
        self.assertIn(expected, str(raises.exception))

    def test_handling_unsupported_generator_expression(self):
        def foo():
            (x for x in range(10))

        expected = "The use of yield in a closure is unsupported."

        # Both object mode and nopython mode must reject the genexpr.
        for dec in jit(forceobj=True), njit:
            with self.assertRaises(errors.UnsupportedError) as raises:
                dec(foo)()
            self.assertIn(expected, str(raises.exception))

    def test_handling_undefined_variable(self):
        @njit
        def foo():
            return a # noqa: F821

        expected = "NameError: name 'a' is not defined"

        with self.assertRaises(errors.TypingError) as raises:
            foo()
        self.assertIn(expected, str(raises.exception))


class TestErrorMessages(unittest.TestCase):
    """Checks the content of typing-error messages produced by the various
    template kinds (overload, intrinsic, concrete/abstract/callable, etc.)."""

    def test_specific_error(self):

        given_reason = "specific_reason"

        def foo():
            pass

        @overload(foo)
        def ol_foo():
            raise errors.NumbaValueError(given_reason)

        @njit
        def call_foo():
            foo()

        with self.assertRaises(errors.TypingError) as raises:
            call_foo()

        excstr = str(raises.exception)
        self.assertIn(error_reasons['specific_error'].splitlines()[0], excstr)
        self.assertIn(given_reason, excstr)

    def test_no_match_error(self):

        def foo():
            pass

        @overload(foo)
        def ol_foo():
            return None # emulate no impl available for type

        @njit
        def call_foo():
            foo()

        with self.assertRaises(errors.TypingError) as raises:
            call_foo()

        excstr = str(raises.exception)
        self.assertIn("No match", excstr)

    @skip_unless_scipy
    def test_error_function_source_is_correct(self):
        """ Checks that the reported source location for an overload is the
        overload implementation source, not the actual function source from the
        target library."""

        @njit
        def foo():
            np.linalg.svd("chars")

        with self.assertRaises(errors.TypingError) as raises:
            foo()

        excstr = str(raises.exception)
        self.assertIn(error_reasons['specific_error'].splitlines()[0], excstr)
        expected_file = os.path.join("numba", "np", "linalg.py")
        # Already interpolated by the f-string; a further ``.format`` call
        # would be redundant and would break on braces in the path.
        expected = f"Overload in function 'svd_impl': File: {expected_file}:"
        self.assertIn(expected, excstr)

    def test_concrete_template_source(self):
        # hits ConcreteTemplate
        @njit
        def foo():
            return 'a' + 1

        with self.assertRaises(errors.TypingError) as raises:
            foo()

        excstr = str(raises.exception)

        self.assertIn("Overload of function 'add'", excstr)
        # there'll be numerous matched templates that don't work but as they
        # are mostly "overload"s they'll just appear as "No match".
        self.assertIn("No match.", excstr)

    def test_abstract_template_source(self):
        # hits AbstractTemplate
        @njit
        def foo():
            return len(1)

        with self.assertRaises(errors.TypingError) as raises:
            foo()

        excstr = str(raises.exception)
        self.assertIn("Overload of function 'len'", excstr)

    def test_callable_template_source(self):
        # hits CallableTemplate
        @njit
        def foo():
            return np.angle(None)

        with self.assertRaises(errors.TypingError) as raises:
            foo()

        excstr = str(raises.exception)
        self.assertIn("No implementation of function Function(<function angle",
                      excstr)

    def test_overloadfunction_template_source(self):
        # hits _OverloadFunctionTemplate
        def bar(x):
            pass

        @overload(bar)
        def ol_bar(x):
            pass

        @njit
        def foo():
            return bar(1)

        with self.assertRaises(errors.TypingError) as raises:
            foo()

        excstr = str(raises.exception)
        # there will not be "numerous" matched templates, there's just one,
        # the one above, so assert it is reported
        self.assertNotIn("<numerous>", excstr)
        expected_file = os.path.join("numba", "tests",
                                     "test_errorhandling.py")
        # Already interpolated by the f-string; no further ``.format`` needed.
        expected_ol = f"Overload of function 'bar': File: {expected_file}:"
        self.assertIn(expected_ol, excstr)
        self.assertIn("No match.", excstr)

    def test_intrinsic_template_source(self):
        # hits _IntrinsicTemplate
        given_reason1 = "x must be literal"
        given_reason2 = "array.ndim must be 1"

        @intrinsic
        def myintrin(typingctx, x, arr):
            if not isinstance(x, types.IntegerLiteral):
                raise errors.RequireLiteralValue(given_reason1)

            if arr.ndim != 1:
                raise errors.NumbaValueError(given_reason2)

            sig = types.intp(x, arr)

            def codegen(context, builder, signature, args):
                pass
            return sig, codegen

        @njit
        def call_intrin():
            arr = np.zeros((2, 2))
            myintrin(1, arr)

        with self.assertRaises(errors.TypingError) as raises:
            call_intrin()

        excstr = str(raises.exception)
        self.assertIn(error_reasons['specific_error'].splitlines()[0], excstr)
        # Both typing attempts (non-literal first, then literal) fail, so
        # both reasons are expected in the report.
        self.assertIn(given_reason1, excstr)
        self.assertIn(given_reason2, excstr)
        self.assertIn("Intrinsic in function", excstr)

    def test_overloadmethod_template_source(self):
        # doesn't hit _OverloadMethodTemplate for source as it's a nested
        # exception
        @overload_method(types.UnicodeType, 'isnonsense')
        def ol_unicode_isnonsense(self):
            pass

        @njit
        def foo():
            "abc".isnonsense()

        with self.assertRaises(errors.TypingError) as raises:
            foo()

        excstr = str(raises.exception)
        self.assertIn("Overload of function 'ol_unicode_isnonsense'", excstr)

    def test_overloadattribute_template_source(self):
        # doesn't hit _OverloadMethodTemplate for source as it's a nested
        # exception
        @overload_attribute(types.UnicodeType, 'isnonsense')
        def ol_unicode_isnonsense(self):
            pass

        @njit
        def foo():
            "abc".isnonsense

        with self.assertRaises(errors.TypingError) as raises:
            foo()

        excstr = str(raises.exception)
        self.assertIn("Overload of function 'ol_unicode_isnonsense'", excstr)

    def test_external_function_pointer_template_source(self):
        from numba.tests.ctypes_usecases import c_cos

        @njit
        def foo():
            c_cos('a')

        with self.assertRaises(errors.TypingError) as raises:
            foo()

        excstr = str(raises.exception)
        self.assertIn("Type Restricted Function in function 'unknown'", excstr)

    @skip_unless_cffi
    def test_cffi_function_pointer_template_source(self):
        from numba.tests import cffi_usecases as mod
        mod.init()
        func = mod.cffi_cos

        @njit
        def foo():
            func('a')

        with self.assertRaises(errors.TypingError) as raises:
            foo()

        excstr = str(raises.exception)
        self.assertIn("Type Restricted Function in function 'unknown'", excstr)

    def test_missing_source(self):

        @structref.register
        class ParticleType(types.StructRef):
            pass

        class Particle(structref.StructRefProxy):
            def __new__(cls, pos, mass):
                return structref.StructRefProxy.__new__(cls, pos)
                # didn't provide the required mass argument ----^

        structref.define_proxy(Particle, ParticleType, ["pos", "mass"])

        with self.assertRaises(errors.TypingError) as raises:
            Particle(pos=1, mass=2)

        excstr = str(raises.exception)
        self.assertIn("missing a required argument: 'mass'", excstr)


class TestDeveloperSpecificErrorMessages(SerialMixin, unittest.TestCase):
    """Error messages that are only produced with DEVELOPER_MODE enabled."""

    def test_bound_function_error_string(self):
        # See PR #5952
        def foo(x):
            x.max(-1)

        signature = "void(int64[:,:])"
        with override_config('DEVELOPER_MODE', 1):
            with self.assertRaises(errors.TypingError) as ctx:
                njit(signature)(foo)

        self.assertIn("too many positional arguments", str(ctx.exception))


class TestCapturedErrorHandling(SerialMixin, TestCase):
    """Checks that the way errors are captured changes depending on the env
    var "NUMBA_CAPTURED_ERRORS".
    """

    def test_error_in_overload(self):

        def bar(x):
            pass

        @overload(bar)
        def ol_bar(x):
            x.some_invalid_attr # doesn't exist!

            def impl(x):
                pass
            return impl

        expected = "object has no attribute 'some_invalid_attr'"
        with warnings.catch_warnings():
            # 'old_style' capturing emits a pending-deprecation warning;
            # silence it so it does not clutter the test output.
            warnings.simplefilter("ignore",
                                  errors.NumbaPendingDeprecationWarning)
            # Check both new_style and old_style
            for style, err_class in (('new_style', AttributeError),
                                     ('old_style', errors.TypingError)):
                with override_config('CAPTURED_ERRORS', style):
                    with self.assertRaises(err_class) as raises:

                        @njit('void(int64)')
                        def foo(x):
                            bar(x)
                    self.assertIn(expected, str(raises.exception))

    @TestCase.run_test_in_subprocess(
        envvars={"NUMBA_CAPTURED_ERRORS": "old_style"},
    )
    def test_old_style_deprecation(self):
        # Verify that old_style error raise the correct deprecation warning
        warnings.simplefilter("always", errors.NumbaPendingDeprecationWarning)

        def bar(x):
            pass

        @overload(bar)
        def ol_bar(x):
            raise AttributeError("Invalid attribute")

        with self.assertWarns(errors.NumbaPendingDeprecationWarning) as warns:
            with self.assertRaises(errors.TypingError):
                @njit('void(int64)')
                def foo(x):
                    bar(x)

        # Inspect the warnings only after assertWarns has confirmed one was
        # emitted: indexing the list inside the block would raise IndexError
        # on failure, and its first entry may be an unrelated category.
        expected = ("Code using Numba extension API maybe depending on "
                    "'old_style' error-capturing")
        pending = [str(w.message) for w in warns.warnings
                   if issubclass(w.category,
                                 errors.NumbaPendingDeprecationWarning)]
        self.assertTrue(any(expected in msg for msg in pending),
                        msg=f"expected deprecation not found in {pending}")

    @TestCase.run_test_in_subprocess(
        envvars={"NUMBA_CAPTURED_ERRORS": "old_style"},
    )
    def test_old_style_no_deprecation(self):
        # Verify that old_style error with NumbaError does not raise warnings
        warnings.simplefilter("always", errors.NumbaPendingDeprecationWarning)

        def bar(x):
            pass

        @overload(bar)
        def ol_bar(x):
            raise errors.TypingError("Invalid attribute")

        with warnings.catch_warnings(record=True) as warns:
            with self.assertRaises(errors.TypingError):
                @njit('void(int64)')
                def foo(x):
                    bar(x)

            self.assertEqual(len(warns), 0,
                             msg="There should not be any warnings")

    @TestCase.run_test_in_subprocess(
        envvars={"NUMBA_CAPTURED_ERRORS": "new_style"},
    )
    def test_new_style_no_warnings(self):
        # Verify that new_style error raise no warnings
        warnings.simplefilter("always", errors.NumbaPendingDeprecationWarning)

        def bar(x):
            pass

        @overload(bar)
        def ol_bar(x):
            raise AttributeError("Invalid attribute")

        with warnings.catch_warnings(record=True) as warns:
            with self.assertRaises(AttributeError):
                @njit('void(int64)')
                def foo(x):
                    bar(x)
            # There should not be any warnings
            self.assertEqual(len(warns), 0,
                             msg="There should not be any warnings")


if __name__ == "__main__":
    # Allow running this module directly: ``python test_errorhandling.py``.
    unittest.main()