import multiprocessing import platform import threading import pickle import weakref from itertools import chain from io import StringIO import numpy as np from numba import njit, jit, generated_jit, typeof, vectorize from numba.core import types, errors from numba import _dispatcher from numba.core.compiler import compile_isolated from numba.tests.support import TestCase, captured_stdout from numba.np.numpy_support import as_dtype from numba.core.dispatcher import Dispatcher from numba.tests.support import needs_lapack, SerialMixin from numba.testing.main import _TIMEOUT as _RUNNER_TIMEOUT import unittest _TEST_TIMEOUT = _RUNNER_TIMEOUT - 60. try: import jinja2 except ImportError: jinja2 = None try: import pygments except ImportError: pygments = None _is_armv7l = platform.machine() == 'armv7l' def dummy(x): return x def add(x, y): return x + y def addsub(x, y, z): return x - y + z def addsub_defaults(x, y=2, z=3): return x - y + z def star_defaults(x, y=2, *z): return x, y, z def generated_usecase(x, y=5): if isinstance(x, types.Complex): def impl(x, y): return x + y else: def impl(x, y): return x - y return impl def bad_generated_usecase(x, y=5): if isinstance(x, types.Complex): def impl(x): return x else: def impl(x, y=6): return x - y return impl def dtype_generated_usecase(a, b, dtype=None): if isinstance(dtype, (types.misc.NoneType, types.misc.Omitted)): out_dtype = np.result_type(*(np.dtype(ary.dtype.name) for ary in (a, b))) elif isinstance(dtype, (types.DType, types.NumberClass)): out_dtype = as_dtype(dtype) else: raise TypeError("Unhandled Type %s" % type(dtype)) def _fn(a, b, dtype=None): return np.ones(a.shape, dtype=out_dtype) return _fn class BaseTest(TestCase): jit_args = dict(nopython=True) def compile_func(self, pyfunc): def check(*args, **kwargs): expected = pyfunc(*args, **kwargs) result = f(*args, **kwargs) self.assertPreciseEqual(result, expected) f = jit(**self.jit_args)(pyfunc) return f, check class TestDispatcher(BaseTest): def test_equality(self): @jit def foo(x): return x @jit def bar(x): return x # Written this way to verify `==` returns a bool (gh-5838). Using # `assertTrue(foo == foo)` or `assertEqual(foo, foo)` would defeat the # purpose of this test. self.assertEqual(foo == foo, True) self.assertEqual(foo == bar, False) self.assertEqual(foo == None, False) # noqa: E711 def test_dyn_pyfunc(self): @jit def foo(x): return x foo(1) [cr] = foo.overloads.values() # __module__ must be match that of foo self.assertEqual(cr.entry_point.__module__, foo.py_func.__module__) def test_no_argument(self): @jit def foo(): return 1 # Just make sure this doesn't crash foo() def test_coerce_input_types(self): # Issue #486: do not allow unsafe conversions if we can still # compile other specializations. c_add = jit(nopython=True)(add) self.assertPreciseEqual(c_add(123, 456), add(123, 456)) self.assertPreciseEqual(c_add(12.3, 45.6), add(12.3, 45.6)) self.assertPreciseEqual(c_add(12.3, 45.6j), add(12.3, 45.6j)) self.assertPreciseEqual(c_add(12300000000, 456), add(12300000000, 456)) # Now force compilation of only a single specialization c_add = jit('(i4, i4)', nopython=True)(add) self.assertPreciseEqual(c_add(123, 456), add(123, 456)) # Implicit (unsafe) conversion of float to int self.assertPreciseEqual(c_add(12.3, 45.6), add(12, 45)) with self.assertRaises(TypeError): # Implicit conversion of complex to int disallowed c_add(12.3, 45.6j) def test_ambiguous_new_version(self): """Test compiling new version in an ambiguous case """ @jit def foo(a, b): return a + b INT = 1 FLT = 1.5 self.assertAlmostEqual(foo(INT, FLT), INT + FLT) self.assertEqual(len(foo.overloads), 1) self.assertAlmostEqual(foo(FLT, INT), FLT + INT) self.assertEqual(len(foo.overloads), 2) self.assertAlmostEqual(foo(FLT, FLT), FLT + FLT) self.assertEqual(len(foo.overloads), 3) # The following call is ambiguous because (int, int) can resolve # to (float, int) or (int, float) with equal weight. self.assertAlmostEqual(foo(1, 1), INT + INT) self.assertEqual(len(foo.overloads), 4, "didn't compile a new " "version") def test_lock(self): """ Test that (lazy) compiling from several threads at once doesn't produce errors (see issue #908). """ errors = [] @jit def foo(x): return x + 1 def wrapper(): try: self.assertEqual(foo(1), 2) except Exception as e: errors.append(e) threads = [threading.Thread(target=wrapper) for i in range(16)] for t in threads: t.start() for t in threads: t.join() self.assertFalse(errors) def test_explicit_signatures(self): f = jit("(int64,int64)")(add) # Approximate match (unsafe conversion) self.assertPreciseEqual(f(1.5, 2.5), 3) self.assertEqual(len(f.overloads), 1, f.overloads) f = jit(["(int64,int64)", "(float64,float64)"])(add) # Exact signature matches self.assertPreciseEqual(f(1, 2), 3) self.assertPreciseEqual(f(1.5, 2.5), 4.0) # Approximate match (int32 -> float64 is a safe conversion) self.assertPreciseEqual(f(np.int32(1), 2.5), 3.5) # No conversion with self.assertRaises(TypeError) as cm: f(1j, 1j) self.assertIn("No matching definition", str(cm.exception)) self.assertEqual(len(f.overloads), 2, f.overloads) # A more interesting one... f = jit(["(float32,float32)", "(float64,float64)"])(add) self.assertPreciseEqual(f(np.float32(1), np.float32(2**-25)), 1.0) self.assertPreciseEqual(f(1, 2**-25), 1.0000000298023224) # Fail to resolve ambiguity between the two best overloads f = jit(["(float32,float64)", "(float64,float32)", "(int64,int64)"])(add) with self.assertRaises(TypeError) as cm: f(1.0, 2.0) # The two best matches are output in the error message, as well # as the actual argument types. self.assertRegexpMatches( str(cm.exception), r"Ambiguous overloading for ]*> " r"\(float64, float64\):\n" r"\(float32, float64\) -> float64\n" r"\(float64, float32\) -> float64" ) # The integer signature is not part of the best matches self.assertNotIn("int64", str(cm.exception)) def test_signature_mismatch(self): tmpl = ("Signature mismatch: %d argument types given, but function " "takes 2 arguments") with self.assertRaises(TypeError) as cm: jit("()")(add) self.assertIn(tmpl % 0, str(cm.exception)) with self.assertRaises(TypeError) as cm: jit("(intc,)")(add) self.assertIn(tmpl % 1, str(cm.exception)) with self.assertRaises(TypeError) as cm: jit("(intc,intc,intc)")(add) self.assertIn(tmpl % 3, str(cm.exception)) # With forceobj=True, an empty tuple is accepted jit("()", forceobj=True)(add) with self.assertRaises(TypeError) as cm: jit("(intc,)", forceobj=True)(add) self.assertIn(tmpl % 1, str(cm.exception)) def test_matching_error_message(self): f = jit("(intc,intc)")(add) with self.assertRaises(TypeError) as cm: f(1j, 1j) self.assertEqual(str(cm.exception), "No matching definition for argument type(s) " "complex128, complex128") def test_disabled_compilation(self): @jit def foo(a): return a foo.compile("(float32,)") foo.disable_compile() with self.assertRaises(RuntimeError) as raises: foo.compile("(int32,)") self.assertEqual(str(raises.exception), "compilation disabled") self.assertEqual(len(foo.signatures), 1) def test_disabled_compilation_through_list(self): @jit(["(float32,)", "(int32,)"]) def foo(a): return a with self.assertRaises(RuntimeError) as raises: foo.compile("(complex64,)") self.assertEqual(str(raises.exception), "compilation disabled") self.assertEqual(len(foo.signatures), 2) def test_disabled_compilation_nested_call(self): @jit(["(intp,)"]) def foo(a): return a @jit def bar(): foo(1) foo(np.ones(1)) # no matching definition with self.assertRaises(TypeError) as raises: bar() m = "No matching definition for argument type(s) array(float64, 1d, C)" self.assertEqual(str(raises.exception), m) def test_fingerprint_failure(self): """ Failure in computing the fingerprint cannot affect a nopython=False function. On the other hand, with nopython=True, a ValueError should be raised to report the failure with fingerprint. """ @jit def foo(x): return x # Empty list will trigger failure in compile_fingerprint errmsg = 'cannot compute fingerprint of empty list' with self.assertRaises(ValueError) as raises: _dispatcher.compute_fingerprint([]) self.assertIn(errmsg, str(raises.exception)) # It should work in fallback self.assertEqual(foo([]), []) # But, not in nopython=True strict_foo = jit(nopython=True)(foo.py_func) with self.assertRaises(ValueError) as raises: strict_foo([]) self.assertIn(errmsg, str(raises.exception)) # Test in loop lifting context @jit def bar(): object() # force looplifting x = [] for i in range(10): x = foo(x) return x self.assertEqual(bar(), []) # Make sure it was looplifted [cr] = bar.overloads.values() self.assertEqual(len(cr.lifted), 1) def test_serialization(self): """ Test serialization of Dispatcher objects """ @jit(nopython=True) def foo(x): return x + 1 self.assertEqual(foo(1), 2) # get serialization memo memo = Dispatcher._memo Dispatcher._recent.clear() memo_size = len(memo) # pickle foo and check memo size serialized_foo = pickle.dumps(foo) # increases the memo size self.assertEqual(memo_size + 1, len(memo)) # unpickle foo_rebuilt = pickle.loads(serialized_foo) self.assertEqual(memo_size + 1, len(memo)) self.assertIs(foo, foo_rebuilt) # do we get the same object even if we delete all the explicit # references? id_orig = id(foo_rebuilt) del foo del foo_rebuilt self.assertEqual(memo_size + 1, len(memo)) new_foo = pickle.loads(serialized_foo) self.assertEqual(id_orig, id(new_foo)) # now clear the recent cache ref = weakref.ref(new_foo) del new_foo Dispatcher._recent.clear() self.assertEqual(memo_size, len(memo)) # show that deserializing creates a new object pickle.loads(serialized_foo) self.assertIs(ref(), None) @needs_lapack @unittest.skipIf(_is_armv7l, "Unaligned loads unsupported") def test_misaligned_array_dispatch(self): # for context see issue #2937 def foo(a): return np.linalg.matrix_power(a, 1) jitfoo = jit(nopython=True)(foo) n = 64 r = int(np.sqrt(n)) dt = np.int8 count = np.complex128().itemsize // dt().itemsize tmp = np.arange(n * count + 1, dtype=dt) # create some arrays as Cartesian production of: # [F/C] x [aligned/misaligned] C_contig_aligned = tmp[:-1].view(np.complex128).reshape(r, r) C_contig_misaligned = tmp[1:].view(np.complex128).reshape(r, r) F_contig_aligned = C_contig_aligned.T F_contig_misaligned = C_contig_misaligned.T # checking routine def check(name, a): a[:, :] = np.arange(n, dtype=np.complex128).reshape(r, r) expected = foo(a) got = jitfoo(a) np.testing.assert_allclose(expected, got) # The checks must be run in this order to create the dispatch key # sequence that causes invalid dispatch noted in #2937. # The first two should hit the cache as they are aligned, supported # order and under 5 dimensions. The second two should end up in the # fallback path as they are misaligned. check("C_contig_aligned", C_contig_aligned) check("F_contig_aligned", F_contig_aligned) check("C_contig_misaligned", C_contig_misaligned) check("F_contig_misaligned", F_contig_misaligned) @unittest.skipIf(_is_armv7l, "Unaligned loads unsupported") def test_immutability_in_array_dispatch(self): # RO operation in function def foo(a): return np.sum(a) jitfoo = jit(nopython=True)(foo) n = 64 r = int(np.sqrt(n)) dt = np.int8 count = np.complex128().itemsize // dt().itemsize tmp = np.arange(n * count + 1, dtype=dt) # create some arrays as Cartesian production of: # [F/C] x [aligned/misaligned] C_contig_aligned = tmp[:-1].view(np.complex128).reshape(r, r) C_contig_misaligned = tmp[1:].view(np.complex128).reshape(r, r) F_contig_aligned = C_contig_aligned.T F_contig_misaligned = C_contig_misaligned.T # checking routine def check(name, a, disable_write_bit=False): a[:, :] = np.arange(n, dtype=np.complex128).reshape(r, r) if disable_write_bit: a.flags.writeable = False expected = foo(a) got = jitfoo(a) np.testing.assert_allclose(expected, got) # all of these should end up in the fallback path as they have no write # bit set check("C_contig_aligned", C_contig_aligned, disable_write_bit=True) check("F_contig_aligned", F_contig_aligned, disable_write_bit=True) check("C_contig_misaligned", C_contig_misaligned, disable_write_bit=True) check("F_contig_misaligned", F_contig_misaligned, disable_write_bit=True) @needs_lapack @unittest.skipIf(_is_armv7l, "Unaligned loads unsupported") def test_misaligned_high_dimension_array_dispatch(self): def foo(a): return np.linalg.matrix_power(a[0, 0, 0, 0, :, :], 1) jitfoo = jit(nopython=True)(foo) def check_properties(arr, layout, aligned): self.assertEqual(arr.flags.aligned, aligned) if layout == "C": self.assertEqual(arr.flags.c_contiguous, True) if layout == "F": self.assertEqual(arr.flags.f_contiguous, True) n = 729 r = 3 dt = np.int8 count = np.complex128().itemsize // dt().itemsize tmp = np.arange(n * count + 1, dtype=dt) # create some arrays as Cartesian production of: # [F/C] x [aligned/misaligned] C_contig_aligned = tmp[:-1].view(np.complex128).\ reshape(r, r, r, r, r, r) check_properties(C_contig_aligned, 'C', True) C_contig_misaligned = tmp[1:].view(np.complex128).\ reshape(r, r, r, r, r, r) check_properties(C_contig_misaligned, 'C', False) F_contig_aligned = C_contig_aligned.T check_properties(F_contig_aligned, 'F', True) F_contig_misaligned = C_contig_misaligned.T check_properties(F_contig_misaligned, 'F', False) # checking routine def check(name, a): a[:, :] = np.arange(n, dtype=np.complex128).\ reshape(r, r, r, r, r, r) expected = foo(a) got = jitfoo(a) np.testing.assert_allclose(expected, got) # these should all hit the fallback path as the cache is only for up to # 5 dimensions check("F_contig_misaligned", F_contig_misaligned) check("C_contig_aligned", C_contig_aligned) check("F_contig_aligned", F_contig_aligned) check("C_contig_misaligned", C_contig_misaligned) def test_dispatch_recompiles_for_scalars(self): # for context #3612, essentially, compiling a lambda x:x for a # numerically wide type (everything can be converted to a complex128) # and then calling again with e.g. an int32 would lead to the int32 # being converted to a complex128 whereas it ought to compile an int32 # specialization. def foo(x): return x # jit and compile on dispatch for 3 scalar types, expect 3 signatures jitfoo = jit(nopython=True)(foo) jitfoo(np.complex128(1 + 2j)) jitfoo(np.int32(10)) jitfoo(np.bool_(False)) self.assertEqual(len(jitfoo.signatures), 3) expected_sigs = [(types.complex128,), (types.int32,), (types.bool_,)] self.assertEqual(jitfoo.signatures, expected_sigs) # now jit with signatures so recompilation is forbidden # expect 1 signature and type conversion jitfoo = jit([(types.complex128,)], nopython=True)(foo) jitfoo(np.complex128(1 + 2j)) jitfoo(np.int32(10)) jitfoo(np.bool_(False)) self.assertEqual(len(jitfoo.signatures), 1) expected_sigs = [(types.complex128,)] self.assertEqual(jitfoo.signatures, expected_sigs) def test_dispatcher_raises_for_invalid_decoration(self): # For context see https://github.com/numba/numba/issues/4750. @jit(nopython=True) def foo(x): return x with self.assertRaises(TypeError) as raises: jit(foo) err_msg = str(raises.exception) self.assertIn( "A jit decorator was called on an already jitted function", err_msg) self.assertIn("foo", err_msg) self.assertIn(".py_func", err_msg) with self.assertRaises(TypeError) as raises: jit(BaseTest) err_msg = str(raises.exception) self.assertIn("The decorated object is not a function", err_msg) self.assertIn(f"{type(BaseTest)}", err_msg) class TestSignatureHandling(BaseTest): """ Test support for various parameter passing styles. """ def test_named_args(self): """ Test passing named arguments to a dispatcher. """ f, check = self.compile_func(addsub) check(3, z=10, y=4) check(3, 4, 10) check(x=3, y=4, z=10) # All calls above fall under the same specialization self.assertEqual(len(f.overloads), 1) # Errors with self.assertRaises(TypeError) as cm: f(3, 4, y=6, z=7) self.assertIn("too many arguments: expected 3, got 4", str(cm.exception)) with self.assertRaises(TypeError) as cm: f() self.assertIn("not enough arguments: expected 3, got 0", str(cm.exception)) with self.assertRaises(TypeError) as cm: f(3, 4, y=6) self.assertIn("missing argument 'z'", str(cm.exception)) def test_default_args(self): """ Test omitting arguments with a default value. """ f, check = self.compile_func(addsub_defaults) check(3, z=10, y=4) check(3, 4, 10) check(x=3, y=4, z=10) # Now omitting some values check(3, z=10) check(3, 4) check(x=3, y=4) check(3) check(x=3) # Errors with self.assertRaises(TypeError) as cm: f(3, 4, y=6, z=7) self.assertIn("too many arguments: expected 3, got 4", str(cm.exception)) with self.assertRaises(TypeError) as cm: f() self.assertIn("not enough arguments: expected at least 1, got 0", str(cm.exception)) with self.assertRaises(TypeError) as cm: f(y=6, z=7) self.assertIn("missing argument 'x'", str(cm.exception)) def test_star_args(self): """ Test a compiled function with starargs in the signature. """ f, check = self.compile_func(star_defaults) check(4) check(4, 5) check(4, 5, 6) check(4, 5, 6, 7) check(4, 5, 6, 7, 8) check(x=4) check(x=4, y=5) check(4, y=5) with self.assertRaises(TypeError) as cm: f(4, 5, y=6) self.assertIn("some keyword arguments unexpected", str(cm.exception)) with self.assertRaises(TypeError) as cm: f(4, 5, z=6) self.assertIn("some keyword arguments unexpected", str(cm.exception)) with self.assertRaises(TypeError) as cm: f(4, x=6) self.assertIn("some keyword arguments unexpected", str(cm.exception)) class TestSignatureHandlingObjectMode(TestSignatureHandling): """ Sams as TestSignatureHandling, but in object mode. """ jit_args = dict(forceobj=True) class TestGeneratedDispatcher(TestCase): """ Tests for @generated_jit. """ def test_generated(self): f = generated_jit(nopython=True)(generated_usecase) self.assertEqual(f(8), 8 - 5) self.assertEqual(f(x=8), 8 - 5) self.assertEqual(f(x=8, y=4), 8 - 4) self.assertEqual(f(1j), 5 + 1j) self.assertEqual(f(1j, 42), 42 + 1j) self.assertEqual(f(x=1j, y=7), 7 + 1j) def test_generated_dtype(self): f = generated_jit(nopython=True)(dtype_generated_usecase) a = np.ones((10,), dtype=np.float32) b = np.ones((10,), dtype=np.float64) self.assertEqual(f(a, b).dtype, np.float64) self.assertEqual(f(a, b, dtype=np.dtype('int32')).dtype, np.int32) self.assertEqual(f(a, b, dtype=np.int32).dtype, np.int32) def test_signature_errors(self): """ Check error reporting when implementation signature doesn't match generating function signature. """ f = generated_jit(nopython=True)(bad_generated_usecase) # Mismatching # of arguments with self.assertRaises(TypeError) as raises: f(1j) self.assertIn("should be compatible with signature '(x, y=5)', " "but has signature '(x)'", str(raises.exception)) # Mismatching defaults with self.assertRaises(TypeError) as raises: f(1) self.assertIn("should be compatible with signature '(x, y=5)', " "but has signature '(x, y=6)'", str(raises.exception)) class TestDispatcherMethods(TestCase): def test_recompile(self): closure = 1 @jit def foo(x): return x + closure self.assertPreciseEqual(foo(1), 2) self.assertPreciseEqual(foo(1.5), 2.5) self.assertEqual(len(foo.signatures), 2) closure = 2 self.assertPreciseEqual(foo(1), 2) # Recompiling takes the new closure into account. foo.recompile() # Everything was recompiled self.assertEqual(len(foo.signatures), 2) self.assertPreciseEqual(foo(1), 3) self.assertPreciseEqual(foo(1.5), 3.5) def test_recompile_signatures(self): # Same as above, but with an explicit signature on @jit. closure = 1 @jit("int32(int32)") def foo(x): return x + closure self.assertPreciseEqual(foo(1), 2) self.assertPreciseEqual(foo(1.5), 2) closure = 2 self.assertPreciseEqual(foo(1), 2) # Recompiling takes the new closure into account. foo.recompile() self.assertPreciseEqual(foo(1), 3) self.assertPreciseEqual(foo(1.5), 3) def test_inspect_llvm(self): # Create a jited function @jit def foo(explicit_arg1, explicit_arg2): return explicit_arg1 + explicit_arg2 # Call it in a way to create 3 signatures foo(1, 1) foo(1.0, 1) foo(1.0, 1.0) # base call to get all llvm in a dict llvms = foo.inspect_llvm() self.assertEqual(len(llvms), 3) # make sure the function name shows up in the llvm for llvm_bc in llvms.values(): # Look for the function name self.assertIn("foo", llvm_bc) # Look for the argument names self.assertIn("explicit_arg1", llvm_bc) self.assertIn("explicit_arg2", llvm_bc) def test_inspect_asm(self): # Create a jited function @jit def foo(explicit_arg1, explicit_arg2): return explicit_arg1 + explicit_arg2 # Call it in a way to create 3 signatures foo(1, 1) foo(1.0, 1) foo(1.0, 1.0) # base call to get all llvm in a dict asms = foo.inspect_asm() self.assertEqual(len(asms), 3) # make sure the function name shows up in the llvm for asm in asms.values(): # Look for the function name self.assertTrue("foo" in asm) def _check_cfg_display(self, cfg, wrapper=''): # simple stringify test if wrapper: wrapper = "{}{}".format(len(wrapper), wrapper) module_name = __name__.split('.', 1)[0] module_len = len(module_name) prefix = r'^digraph "CFG for \'_ZN{}{}{}'.format(wrapper, module_len, module_name) self.assertRegexpMatches(str(cfg), prefix) # .display() requires an optional dependency on `graphviz`. # just test for the attribute without running it. self.assertTrue(callable(cfg.display)) def test_inspect_cfg(self): # Exercise the .inspect_cfg(). These are minimal tests and do not fully # check the correctness of the function. @jit def foo(the_array): return the_array.sum() # Generate 3 overloads a1 = np.ones(1) a2 = np.ones((1, 1)) a3 = np.ones((1, 1, 1)) foo(a1) foo(a2) foo(a3) # Call inspect_cfg() without arguments cfgs = foo.inspect_cfg() # Correct count of overloads self.assertEqual(len(cfgs), 3) # Makes sure all the signatures are correct [s1, s2, s3] = cfgs.keys() self.assertEqual(set([s1, s2, s3]), set(map(lambda x: (typeof(x),), [a1, a2, a3]))) for cfg in cfgs.values(): self._check_cfg_display(cfg) self.assertEqual(len(list(cfgs.values())), 3) # Call inspect_cfg(signature) cfg = foo.inspect_cfg(signature=foo.signatures[0]) self._check_cfg_display(cfg) def test_inspect_cfg_with_python_wrapper(self): # Exercise the .inspect_cfg() including the python wrapper. # These are minimal tests and do not fully check the correctness of # the function. @jit def foo(the_array): return the_array.sum() # Generate 3 overloads a1 = np.ones(1) a2 = np.ones((1, 1)) a3 = np.ones((1, 1, 1)) foo(a1) foo(a2) foo(a3) # Call inspect_cfg(signature, show_wrapper="python") cfg = foo.inspect_cfg(signature=foo.signatures[0], show_wrapper="python") self._check_cfg_display(cfg, wrapper='cpython') def test_inspect_types(self): @jit def foo(a, b): return a + b foo(1, 2) # Exercise the method foo.inspect_types(StringIO()) # Test output expected = str(foo.overloads[foo.signatures[0]].type_annotation) with captured_stdout() as out: foo.inspect_types() assert expected in out.getvalue() def test_inspect_types_with_signature(self): @jit def foo(a): return a + 1 foo(1) foo(1.0) # Inspect all signatures with captured_stdout() as total: foo.inspect_types() # Inspect first signature with captured_stdout() as first: foo.inspect_types(signature=foo.signatures[0]) # Inspect second signature with captured_stdout() as second: foo.inspect_types(signature=foo.signatures[1]) self.assertEqual(total.getvalue(), first.getvalue() + second.getvalue()) @unittest.skipIf(jinja2 is None, "please install the 'jinja2' package") @unittest.skipIf(pygments is None, "please install the 'pygments' package") def test_inspect_types_pretty(self): @jit def foo(a, b): return a + b foo(1, 2) # Exercise the method, dump the output with captured_stdout(): ann = foo.inspect_types(pretty=True) # ensure HTML is found in the annotation output for k, v in ann.ann.items(): span_found = False for line in v['pygments_lines']: if 'span' in line[2]: span_found = True self.assertTrue(span_found) # check that file+pretty kwarg combo raises with self.assertRaises(ValueError) as raises: foo.inspect_types(file=StringIO(), pretty=True) self.assertIn("`file` must be None if `pretty=True`", str(raises.exception)) def test_get_annotation_info(self): @jit def foo(a): return a + 1 foo(1) foo(1.3) expected = dict(chain.from_iterable(foo.get_annotation_info(i).items() for i in foo.signatures)) result = foo.get_annotation_info() self.assertEqual(expected, result) def test_issue_with_array_layout_conflict(self): """ This test an issue with the dispatcher when an array that is both C and F contiguous is supplied as the first signature. The dispatcher checks for F contiguous first but the compiler checks for C contiguous first. This results in an C contiguous code inserted as F contiguous function. """ def pyfunc(A, i, j): return A[i, j] cfunc = jit(pyfunc) ary_c_and_f = np.array([[1.]]) ary_c = np.array([[0., 1.], [2., 3.]], order='C') ary_f = np.array([[0., 1.], [2., 3.]], order='F') exp_c = pyfunc(ary_c, 1, 0) exp_f = pyfunc(ary_f, 1, 0) self.assertEqual(1., cfunc(ary_c_and_f, 0, 0)) got_c = cfunc(ary_c, 1, 0) got_f = cfunc(ary_f, 1, 0) self.assertEqual(exp_c, got_c) self.assertEqual(exp_f, got_f) class TestDispatcherFunctionBoundaries(TestCase): def test_pass_dispatcher_as_arg(self): # Test that a Dispatcher object can be pass as argument @jit(nopython=True) def add1(x): return x + 1 @jit(nopython=True) def bar(fn, x): return fn(x) @jit(nopython=True) def foo(x): return bar(add1, x) # Check dispatcher as argument inside NPM inputs = [1, 11.1, np.arange(10)] expected_results = [x + 1 for x in inputs] for arg, expect in zip(inputs, expected_results): self.assertPreciseEqual(foo(arg), expect) # Check dispatcher as argument from python for arg, expect in zip(inputs, expected_results): self.assertPreciseEqual(bar(add1, arg), expect) def test_dispatcher_as_arg_usecase(self): @jit(nopython=True) def maximum(seq, cmpfn): tmp = seq[0] for each in seq[1:]: cmpval = cmpfn(tmp, each) if cmpval < 0: tmp = each return tmp got = maximum([1, 2, 3, 4], cmpfn=jit(lambda x, y: x - y)) self.assertEqual(got, 4) got = maximum(list(zip(range(5), range(5)[::-1])), cmpfn=jit(lambda x, y: x[0] - y[0])) self.assertEqual(got, (4, 0)) got = maximum(list(zip(range(5), range(5)[::-1])), cmpfn=jit(lambda x, y: x[1] - y[1])) self.assertEqual(got, (0, 4)) def test_dispatcher_can_return_to_python(self): @jit(nopython=True) def foo(fn): return fn fn = jit(lambda x: x) self.assertEqual(foo(fn), fn) def test_dispatcher_in_sequence_arg(self): @jit(nopython=True) def one(x): return x + 1 @jit(nopython=True) def two(x): return one(one(x)) @jit(nopython=True) def three(x): return one(one(one(x))) @jit(nopython=True) def choose(fns, x): return fns[0](x), fns[1](x), fns[2](x) # Tuple case self.assertEqual(choose((one, two, three), 1), (2, 3, 4)) # List case self.assertEqual(choose([one, one, one], 1), (2, 2, 2)) class TestBoxingDefaultError(unittest.TestCase): # Testing default error at boxing/unboxing def test_unbox_runtime_error(self): # Dummy type has no unbox support def foo(x): pass cres = compile_isolated(foo, (types.Dummy("dummy_type"),)) with self.assertRaises(TypeError) as raises: # Can pass in whatever and the unbox logic will always raise # without checking the input value. cres.entry_point(None) self.assertEqual(str(raises.exception), "can't unbox dummy_type type") def test_box_runtime_error(self): def foo(): return unittest # Module type has no boxing logic cres = compile_isolated(foo, ()) with self.assertRaises(TypeError) as raises: # Can pass in whatever and the unbox logic will always raise # without checking the input value. cres.entry_point() pat = "cannot convert native Module.* to Python object" self.assertRegexpMatches(str(raises.exception), pat) class TestNoRetryFailedSignature(unittest.TestCase): """Test that failed-to-compile signatures are not recompiled. """ def run_test(self, func): fcom = func._compiler self.assertEqual(len(fcom._failed_cache), 0) # expected failure because `int` has no `__getitem__` with self.assertRaises(errors.TypingError): func(1) self.assertEqual(len(fcom._failed_cache), 1) # retry with self.assertRaises(errors.TypingError): func(1) self.assertEqual(len(fcom._failed_cache), 1) # retry with double with self.assertRaises(errors.TypingError): func(1.0) self.assertEqual(len(fcom._failed_cache), 2) def test_direct_call(self): @jit(nopython=True) def foo(x): return x[0] self.run_test(foo) def test_nested_call(self): @jit(nopython=True) def bar(x): return x[0] @jit(nopython=True) def foobar(x): bar(x) @jit(nopython=True) def foo(x): return bar(x) + foobar(x) self.run_test(foo) def test_error_count(self): def check(field, would_fail): # Slightly modified from the reproducer in issue #4117. # Before the patch, the compilation time of the failing case is # much longer than of the successful case. This can be detected # by the number of times `trigger()` is visited. k = 10 counter = {'c': 0} @generated_jit def trigger(x): # Keep track of every visit counter['c'] += 1 if would_fail: raise errors.TypingError("invoke_failed") return lambda x: x @jit(nopython=True) def ident(out, x): pass def chain_assign(fs, inner=ident): tab_head, tab_tail = fs[-1], fs[:-1] @jit(nopython=True) def assign(out, x): inner(out, x) out[0] += tab_head(x) if tab_tail: return chain_assign(tab_tail, assign) else: return assign chain = chain_assign((trigger,) * k) out = np.ones(2) if would_fail: with self.assertRaises(errors.TypingError) as raises: chain(out, 1) self.assertIn('invoke_failed', str(raises.exception)) else: chain(out, 1) # Returns the visit counts return counter['c'] ct_ok = check('a', False) ct_bad = check('c', True) # `trigger()` is visited exactly once for both successful and failed # compilation. self.assertEqual(ct_ok, 1) self.assertEqual(ct_bad, 1) @njit def add_y1(x, y=1): return x + y @njit def add_ynone(x, y=None): return x + (1 if y else 2) @njit def mult(x, y): return x * y @njit def add_func(x, func=mult): return x + func(x, x) def _checker(f1, arg): assert f1(arg) == f1.py_func(arg) class TestMultiprocessingDefaultParameters(SerialMixin, unittest.TestCase): def run_fc_multiproc(self, fc): try: ctx = multiprocessing.get_context('spawn') except AttributeError: ctx = multiprocessing # RE: issue #5973, this doesn't use multiprocessing.Pool.map as doing so # causes the TBB library to segfault under certain conditions. It's not # clear whether the cause is something in the complexity of the Pool # itself, e.g. watcher threads etc, or if it's a problem synonymous with # a "timing attack". for a in [1, 2, 3]: p = ctx.Process(target=_checker, args=(fc, a,)) p.start() p.join(_TEST_TIMEOUT) self.assertEqual(p.exitcode, 0) def test_int_def_param(self): """ Tests issue #4888""" self.run_fc_multiproc(add_y1) def test_none_def_param(self): """ Tests None as a default parameter""" self.run_fc_multiproc(add_func) def test_function_def_param(self): """ Tests a function as a default parameter""" self.run_fc_multiproc(add_func) class TestVectorizeDifferentTargets(unittest.TestCase): """Test that vectorize can be reapplied if the target is different """ def test_cpu_vs_parallel(self): @jit def add(x, y): return x + y custom_vectorize = vectorize([], identity=None, target='cpu') custom_vectorize(add) custom_vectorize_2 = vectorize([], identity=None, target='parallel') custom_vectorize_2(add) if __name__ == '__main__': unittest.main()