# test_casting.py
# (committed by dugupeiwen)
import numpy as np
from numba.core.compiler import compile_isolated
from numba.core.errors import TypingError
from numba import njit
from numba.core import types
import struct
import unittest


def float_to_int(x):
    """Return ``x`` converted with ``types.int32``.

    Compiled with a ``float32`` argument by
    ``TestCasting.test_float_to_int``.
    """
    as_int32 = types.int32(x)
    return as_int32


def int_to_float(x):
    """Convert ``x`` with ``types.float64`` and return half of that value.

    Compiled with an ``int64`` argument by
    ``TestCasting.test_int_to_float``.
    """
    as_float64 = types.float64(x)
    return as_float64 / 2


def float_to_unsigned(x):
    """Return ``x`` converted with ``types.uint32``.

    Compiled with a ``float32`` argument by
    ``TestCasting.test_float_to_unsigned``.
    """
    as_uint32 = types.uint32(x)
    return as_uint32


def float_to_complex(x):
    """Return ``x`` converted with ``types.complex128``.

    Compiled with a ``float64`` argument by
    ``TestCasting.test_float_to_complex``.
    """
    as_complex128 = types.complex128(x)
    return as_complex128


def numpy_scalar_cast_error():
    """Pass a 4-element zeros array to ``np.int32`` and discard the result.

    ``TestCasting.test_array_to_scalar`` compiles this function and expects
    the compilation to fail with a ``TypingError``.
    """
    zeros = np.zeros((4,))
    np.int32(zeros)

class TestCasting(unittest.TestCase):
    """Tests for scalar, array and Optional casts in compiled functions."""

    def test_float_to_int(self):
        """float32 -> int32 agrees with the Python function and ``int()``."""
        pyfunc = float_to_int
        cr = compile_isolated(pyfunc, [types.float32])
        cfunc = cr.entry_point

        self.assertEqual(cr.signature.return_type, types.int32)
        # Both a positive and a negative input are checked against the
        # uncompiled function and against the builtin int() conversion.
        self.assertEqual(cfunc(12.3), pyfunc(12.3))
        self.assertEqual(cfunc(12.3), int(12.3))
        self.assertEqual(cfunc(-12.3), pyfunc(-12.3))
        self.assertEqual(cfunc(-12.3), int(-12.3))

    def test_int_to_float(self):
        """int64 -> float64 followed by a division by two."""
        pyfunc = int_to_float
        cr = compile_isolated(pyfunc, [types.int64])
        cfunc = cr.entry_point

        self.assertEqual(cr.signature.return_type, types.float64)
        self.assertEqual(cfunc(321), pyfunc(321))
        self.assertEqual(cfunc(321), 321. / 2)

    def test_float_to_unsigned(self):
        """float32 -> uint32 agrees with the Python function."""
        pyfunc = float_to_unsigned
        cr = compile_isolated(pyfunc, [types.float32])
        cfunc = cr.entry_point

        self.assertEqual(cr.signature.return_type, types.uint32)
        self.assertEqual(cfunc(3.21), pyfunc(3.21))
        # Expected value: the bytes of the signed int 3 read back as an
        # unsigned int.
        expected = struct.unpack('I', struct.pack('i', 3))[0]
        self.assertEqual(cfunc(3.21), expected)

    def test_float_to_complex(self):
        """float64 -> complex128 yields the input with a zero imaginary part."""
        pyfunc = float_to_complex
        cr = compile_isolated(pyfunc, [types.float64])
        cfunc = cr.entry_point
        self.assertEqual(cr.signature.return_type, types.complex128)
        self.assertEqual(cfunc(-3.21), pyfunc(-3.21))
        self.assertEqual(cfunc(-3.21), -3.21 + 0j)

    def test_array_to_array(self):
        """Make sure this compiles.

        Cast C to A array: ``driver`` takes ``f8[::1]`` and forwards it to
        ``inner``, which is declared with ``f8[:]``.
        """
        @njit("f8(f8[:])")
        def inner(x):
            return x[0]

        # Compilation is disabled on ``inner`` before ``driver`` calls it;
        # the final assertion checks that it still has exactly one overload.
        inner.disable_compile()

        @njit("f8(f8[::1])")
        def driver(x):
            return inner(x)

        x = np.array([1234], dtype=np.float64)
        self.assertEqual(driver(x), x[0])
        self.assertEqual(len(inner.overloads), 1)

    def test_0darrayT_to_T(self):
        """Calling ``x.dtype.type(x)`` on a 0-d array returns its scalar."""
        @njit
        def inner(x):
            return x.dtype.type(x)

        # (dtype, value) pairs used to build each 0-d input array.
        inputs = [
            (np.bool_, True),
            (np.float32, 12.3),
            (np.float64, 12.3),
            (np.int64, 12),
            (np.complex64, 2j+3),
            (np.complex128, 2j+3),
            (np.timedelta64, np.timedelta64(3, 'h')),
            (np.datetime64, np.datetime64('2016-01-01')),
            ('<U3', 'ABC'),
        ]

        for T, inp in inputs:
            # subTest lets every dtype run even if one fails, and labels
            # the failure report with the dtype that broke.
            with self.subTest(dtype=T):
                x = np.array(inp, dtype=T)
                self.assertEqual(inner(x), x[()])

    def test_array_to_scalar(self):
        """
        Ensure that a TypingError exception is raised if
        user tries to convert numpy array to scalar
        """

        with self.assertRaises(TypingError) as raises:
            compile_isolated(numpy_scalar_cast_error, ())

        expected_msg = ("Casting array(float64, 1d, C) to int32 directly "
                        "is unsupported.")
        self.assertIn(expected_msg, str(raises.exception))

    def test_optional_to_optional(self):
        """
        Test error due to mishandling of Optional to Optional casting

        Related issue: https://github.com/numba/numba/issues/1718
        """
        # Attempt to cast optional(intp) to optional(float64)
        opt_int = types.Optional(types.intp)
        opt_flt = types.Optional(types.float64)
        sig = opt_flt(opt_int)

        @njit(sig)
        def foo(a):
            return a

        # Both the value path and the None path are exercised.
        self.assertEqual(foo(2), 2)
        self.assertIsNone(foo(None))


# Run the unittest runner when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()