test_nested_calls.py 3.59 KB
Newer Older
dugupeiwen's avatar
dugupeiwen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
"""
Test problems in nested calls between jitted functions.
Such problems are usually caused by invalid type conversions across
function boundaries.
"""


from numba import int32, int64
from numba import jit, generated_jit
from numba.core import types
from numba.tests.support import TestCase, tag
import unittest


@jit(nopython=True)
def f_inner(a, b, c):
    # Echo the arguments back in declaration order so the caller can see
    # which value ended up bound to which parameter.
    packed = (a, b, c)
    return packed

def f(x, y, z):
    # Mixes a positional argument with keyword arguments given out of
    # declaration order (c before b); f_inner should return (x, z, y).
    return f_inner(x, c=y, b=z)

@jit(nopython=True)
def g_inner(a, b=2, c=3):
    # Return every parameter, including any that fell back to its
    # default (b=2, c=3), so callers can verify default handling.
    bound = (a, b, c)
    return bound

def g(x, y, z):
    # First call omits ``c`` (so it defaults to 3); second call omits ``b``
    # (so it defaults to 2) and passes ``a`` by keyword.
    return g_inner(x, b=y), g_inner(a=z, c=x)

@jit(nopython=True)
def star_inner(a=5, *b):
    # ``b`` gathers any extra positional arguments as a tuple; hand both
    # the leading argument and that tuple back to the caller.
    result = (a, b)
    return result

def star(x, y, z):
    # First call binds only ``a`` by keyword (so ``b`` is an empty tuple);
    # second call passes three positionals, so ``b`` becomes (y, z).
    return star_inner(a=x), star_inner(x, y, z)

def star_call(x, y, z):
    # Exercises *-unpacking at the call site: ``y`` and ``z`` are unpacked
    # into positional arguments (the tests pass 1-tuples such as (2,)).
    return star_inner(x, *y), star_inner(*z)

@jit(nopython=True)
def argcast_inner(a, b):
    # Reassigning the argument variable ``a`` in one branch forces numba to
    # unify its type across branches (regression pattern for issue #1488).
    if b:
        # Here `a` is unified to int64 (from int32 originally)
        a = int64(0)
    return a

def argcast(a, b):
    # Narrow ``a`` to int32 before the nested call, so the argument type
    # differs from the int64 that argcast_inner may assign to it.
    return argcast_inner(int32(a), b)

@generated_jit(nopython=True)
def generated_inner(x, y=5, z=6):
    # Called with numba *types* at compile time; returns the Python
    # implementation to compile. Implementations must keep the same
    # parameter names as this generator.
    def impl_complex(x, y, z):
        return x + y, z

    def impl_other(x, y, z):
        return x - y, z

    return impl_complex if isinstance(x, types.Complex) else impl_other

def call_generated(a, b):
    # Omits ``y`` so generated_inner falls back to its default (5), and
    # passes ``z`` by keyword.
    return generated_inner(a, z=b)


class TestNestedCall(TestCase):
    """Tests for calls from one function into a jitted function.

    Most tests compile an outer fixture function and check that the compiled
    result precisely matches the pure-Python result for the same arguments.
    """

    def compile_func(self, pyfunc, objmode=False):
        """Compile *pyfunc* and return ``(cfunc, check)``.

        ``cfunc`` is the jitted version of *pyfunc*: object mode
        (``forceobj=True``) when *objmode* is true, nopython mode otherwise.
        ``check(*args, **kwargs)`` calls both *pyfunc* and ``cfunc`` with the
        same arguments and asserts the results are precisely equal.
        """
        flags = dict(forceobj=True) if objmode else dict(nopython=True)
        # Named ``cfunc`` rather than ``f`` so it does not shadow the
        # module-level fixture ``f`` (which test_named_args passes in as
        # *pyfunc*).
        cfunc = jit(**flags)(pyfunc)

        def check(*args, **kwargs):
            expected = pyfunc(*args, **kwargs)
            result = cfunc(*args, **kwargs)
            self.assertPreciseEqual(result, expected)

        return cfunc, check

    def test_boolean_return(self):
        """
        Test branching on a boolean returned by a nested jitted call.
        """
        @jit(nopython=True)
        def inner(x):
            return not x

        @jit(nopython=True)
        def outer(x):
            if inner(x):
                return True
            else:
                return False

        self.assertFalse(outer(True))
        self.assertTrue(outer(False))

    def test_named_args(self, objmode=False):
        """
        Test a nested function call with named (keyword) arguments.
        """
        _, check = self.compile_func(f, objmode)
        check(1, 2, 3)
        check(1, y=2, z=3)

    def test_named_args_objmode(self):
        self.test_named_args(objmode=True)

    def test_default_args(self, objmode=False):
        """
        Test a nested function call using default argument values.
        """
        _, check = self.compile_func(g, objmode)
        check(1, 2, 3)
        check(1, y=2, z=3)

    def test_default_args_objmode(self):
        self.test_default_args(objmode=True)

    def test_star_args(self):
        """
        Test a nested function call to a function with *args in its signature.
        """
        _, check = self.compile_func(star)
        check(1, 2, 3)

    def test_star_call(self, objmode=False):
        """
        Test a function call with a *args.
        """
        _, check = self.compile_func(star_call, objmode)
        check(1, (2,), (3,))

    def test_star_call_objmode(self):
        self.test_star_call(objmode=True)

    def test_argcast(self):
        """
        Issue #1488: implicitly casting an argument variable should not
        break nested calls.
        """
        _, check = self.compile_func(argcast)
        check(1, 0)
        check(1, 1)

    def test_call_generated(self):
        """
        Test a nested function call to a generated jit function.
        """
        cfunc = jit(nopython=True)(call_generated)
        # Integer input takes the ``x - y`` branch with the default y=5.
        self.assertPreciseEqual(cfunc(1, 2), (-4, 2))
        # Complex input takes the ``x + y`` branch.
        self.assertPreciseEqual(cfunc(1j, 2), (1j + 5, 2))


if __name__ == "__main__":
    # Run this module's tests when it is executed directly as a script.
    unittest.main()