test_print.py 6.06 KB
Newer Older
dugupeiwen's avatar
dugupeiwen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
import sys

import numpy as np

import unittest
from numba.core.compiler import compile_isolated, Flags
from numba import jit
from numba.core import types, errors, utils
from numba.tests.support import (captured_stdout, tag, TestCase,
                                 EnableNRTStatsMixin)


# Compiler flags with ``force_pyobject`` set (object-mode compilation
# is forced rather than merely permitted).
force_pyobj_flags = Flags()
force_pyobj_flags.force_pyobject = True

# Compiler flags with ``enable_pyobject`` set; the array case in
# TestPrint.test_print_values compiles with these because printing an
# array requires object mode.
enable_pyobj_flags = Flags()
enable_pyobj_flags.enable_pyobject = True


def print_value(x):
    """Test fixture: print a single value ``x`` (compiled by TestPrint)."""
    print(x)

def print_array_item(arr, i):
    """Test fixture: print the ``x`` field/attribute of ``arr[i]``."""
    item = arr[i]
    print(item.x)

def print_values(a, b, c):
    """Test fixture: print three positional values separated by spaces."""
    print(a, b, c)

def print_empty():
    """Test fixture: call print() with no arguments (emits only a newline)."""
    print()

def print_string(x):
    """Test fixture: print ``x`` mixed with a string and a float literal."""
    print(x, "hop!", 3.5)

def print_vararg(a, b, c):
    """
    Test fixture: print ``a`` and ``b``, then the elements of ``c`` unpacked
    as extra positional arguments.

    The ``*c`` call form is deliberate: it exercises ``*args`` support in
    Numba's print() implementation, so keep this exact syntax.
    """
    print(a, b, *c)

def print_string_vararg(a, b, c):
    """
    Test fixture: like ``print_vararg`` but with a string literal mixed in
    before the unpacked ``*c`` arguments.
    """
    print(a, "hop!", b, *c)

def make_print_closure(x):
    """
    Build a nopython-compiled closure that prints the captured value ``x``.

    Parameters
    ----------
    x : object
        Value captured by the closure and printed each time it is called.

    Returns
    -------
    A Numba dispatcher wrapping the zero-argument closure.
    """
    def print_closure():
        print(x)
    # Compile the closure itself; compiling ``x`` (a plain value) was a bug.
    return jit(nopython=True)(print_closure)


class TestPrint(EnableNRTStatsMixin, TestCase):
    """
    Tests for print() support in Numba-compiled functions.

    Each test compiles one of the module-level ``print_*`` fixtures (or a
    locally defined jitted function), runs it with stdout captured, and
    compares the captured text against the expected output.
    """

    def test_print_values(self):
        """
        Test printing a single argument value.
        """
        pyfunc = print_value

        def check_values(typ, values):
            # Compile once for ``typ``, then check that every value prints
            # exactly as str(value) followed by a newline.
            cr = compile_isolated(pyfunc, (typ,))
            cfunc = cr.entry_point
            for val in values:
                with captured_stdout():
                    cfunc(val)
                    self.assertEqual(sys.stdout.getvalue(), str(val) + '\n')

        # Various scalars
        check_values(types.int32, (1, -234))
        check_values(types.int64, (1, -234,
                                   123456789876543210, -123456789876543210))
        check_values(types.uint64, (1, 234,
                                    123456789876543210, 2**63 + 123))
        check_values(types.boolean, (True, False))
        check_values(types.float64, (1.5, 100.0**10.0, float('nan')))
        check_values(types.complex64, (1+1j,))
        check_values(types.NPTimedelta('ms'), (np.timedelta64(100, 'ms'),))

        cr = compile_isolated(pyfunc, (types.float32,))
        cfunc = cr.entry_point
        with captured_stdout():
            cfunc(1.1)
            # Float32 will lose precision, so only the leading digits of the
            # printed value are compared.
            got = sys.stdout.getvalue()
            expect = '1.10000002384'
            self.assertTrue(got.startswith(expect))
            self.assertTrue(got.endswith('\n'))

        # NRT-enabled type: the list is printed under the mixin's
        # assertNoNRTLeak / assertRefCount guards, which are meant to flag
        # NRT leaks or reference-count changes on ``x``.
        with self.assertNoNRTLeak():
            x = [1, 3, 5, 7]
            with self.assertRefCount(x):
                check_values(types.List(types.int32), (x,))

        # Array will have to use object mode
        arraytype = types.Array(types.int32, 1, 'C')
        cr = compile_isolated(pyfunc, (arraytype,), flags=enable_pyobj_flags)
        cfunc = cr.entry_point
        with captured_stdout():
            cfunc(np.arange(10, dtype=np.int32))
            self.assertEqual(sys.stdout.getvalue(),
                             '[0 1 2 3 4 5 6 7 8 9]\n')

    def test_print_array_item(self):
        """
        Test printing a NumPy character sequence (an ``S4`` record field)
        in nopython mode.
        """
        dtype = np.dtype([('x', 'S4')])
        arr = np.frombuffer(bytearray(range(1, 9)), dtype=dtype)

        pyfunc = print_array_item
        cfunc = jit(nopython=True)(pyfunc)
        for i, record in enumerate(arr):
            with captured_stdout():
                cfunc(arr, i)
                self.assertEqual(sys.stdout.getvalue(),
                                 str(record['x']) + '\n')

    def test_print_multiple_values(self):
        """
        Test printing several positional arguments in one call.
        """
        pyfunc = print_values
        cr = compile_isolated(pyfunc, (types.int32,) * 3)
        cfunc = cr.entry_point
        with captured_stdout():
            cfunc(1, 2, 3)
            self.assertEqual(sys.stdout.getvalue(), '1 2 3\n')

    def test_print_nogil(self):
        """
        Test that print() works in a function compiled with ``nogil=True``.
        """
        pyfunc = print_values
        cfunc = jit(nopython=True, nogil=True)(pyfunc)
        with captured_stdout():
            cfunc(1, 2, 3)
            self.assertEqual(sys.stdout.getvalue(), '1 2 3\n')

    def test_print_empty(self):
        """
        Test that print() with no arguments emits just a newline.
        """
        pyfunc = print_empty
        cr = compile_isolated(pyfunc, ())
        cfunc = cr.entry_point
        with captured_stdout():
            cfunc()
            self.assertEqual(sys.stdout.getvalue(), '\n')

    def test_print_strings(self):
        """
        Test printing an argument mixed with string and float literals.
        """
        pyfunc = print_string
        cr = compile_isolated(pyfunc, (types.int32,))
        cfunc = cr.entry_point
        with captured_stdout():
            cfunc(1)
            self.assertEqual(sys.stdout.getvalue(), '1 hop! 3.5\n')

    def test_print_vararg(self):
        """
        Test ``*args`` unpacking in print() calls.
        """
        # Test *args support for print().  This is desired since
        # print() can use a dedicated IR node.
        pyfunc = print_vararg
        cfunc = jit(nopython=True)(pyfunc)
        with captured_stdout():
            cfunc(1, (2, 3), (4, 5j))
            self.assertEqual(sys.stdout.getvalue(), '1 (2, 3) 4 5j\n')

        pyfunc = print_string_vararg
        cfunc = jit(nopython=True)(pyfunc)
        with captured_stdout():
            cfunc(1, (2, 3), (4, 5j))
            self.assertEqual(sys.stdout.getvalue(), '1 hop! (2, 3) 4 5j\n')

    def test_inner_fn_print(self):
        """
        Test printing from a jitted function called by another jitted
        function.
        """
        @jit(nopython=True)
        def foo(x):
            print(x)

        @jit(nopython=True)
        def bar(x):
            foo(x)
            foo('hello')

        # Printing an array requires the Env.
        # We need to make sure the inner function can obtain the Env.
        x = np.arange(5)
        with captured_stdout():
            bar(x)
            self.assertEqual(sys.stdout.getvalue(), '[0 1 2 3 4]\nhello\n')

    def test_print_w_kwarg_raises(self):
        """
        Test that print() with keyword arguments raises UnsupportedError
        with an explanatory message.
        """
        @jit(nopython=True)
        def print_kwarg():
            print('x', flush=True)

        with self.assertRaises(errors.UnsupportedError) as raises:
            print_kwarg()
        expected = ("Numba's print() function implementation does not support "
                    "keyword arguments.")
        # assertIn(member, container): the expected explanation must appear
        # somewhere in the raised error's message.
        self.assertIn(expected, str(raises.exception))

    def test_print_no_truncation(self):
        """
        Test that printing a long string is not truncated.

        See: https://github.com/numba/numba/issues/3811
        """
        @jit(nopython=True)
        def foo():
            # The long string is built inside the compiled function so that
            # nopython code produces and prints it.
            print(''.join(['a'] * 10000))

        expected = 'a' * 10000 + '\n'
        with captured_stdout():
            foo()
            self.assertEqual(sys.stdout.getvalue(), expected)

if __name__ == "__main__":
    # Run this module's tests when it is executed directly as a script.
    unittest.main()