test_compiler_flags.py 3.2 KB
Newer Older
dugupeiwen's avatar
dugupeiwen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import re

from numba import njit
from numba.core.extending import overload
from numba.core.targetconfig import ConfigStack
from numba.core.compiler import Flags, DEFAULT_FLAGS
from numba.core import types
from numba.core.funcdesc import default_mangler

from numba.tests.support import TestCase, unittest


class TestCompilerFlagCachedOverload(TestCase):
    """Check that an ``@overload`` implementation observes the compiler
    flags of the function being compiled: the same overloaded stub must
    resolve differently under a fastmath caller and a non-fastmath caller.
    """

    def test_fastmath_in_overload(self):
        # Pure-Python stub; its jitted behaviour is supplied by the overload.
        def fastmath_status():
            pass

        @overload(fastmath_status)
        def ov_fastmath_status():
            # Inspect the flags on top of the config stack while the
            # overload is being resolved; the assertions below rely on this
            # reflecting the enclosing njit function's ``fastmath`` setting.
            current_flags = ConfigStack().top()
            if current_flags.fastmath:
                message = "Has fastmath"
            else:
                message = "No fastmath"

            def impl():
                return message

            return impl

        @njit(fastmath=True)
        def set_fastmath():
            return fastmath_status()

        @njit()
        def foo():
            plain = fastmath_status()
            fast = set_fastmath()
            return (plain, fast)

        outer_status, inner_status = foo()
        self.assertEqual(outer_status, "No fastmath")
        self.assertEqual(inner_status, "Has fastmath")


class TestFlagMangling(TestCase):
    """Tests for ``Flags.get_mangle_string()`` and ``Flags.demangle()``."""

    def test_demangle(self):
        """Demangling a mangled flag string must reproduce ``summary()``."""

        def check(flags):
            mangled = flags.get_mangle_string()
            out = flags.demangle(mangled)
            # Demangle result MUST match summary()
            self.assertEqual(out, flags.summary())

        # test empty flags
        flags = Flags()
        check(flags)

        # test default
        check(DEFAULT_FLAGS)

        # test other
        flags = Flags()
        flags.no_cpython_wrapper = True
        flags.nrt = True
        flags.fastmath = True
        check(flags)

    def test_mangled_flags_is_shorter(self):
        """The mangled string is shorter than ``summary()``."""
        # at least for these control cases
        flags = Flags()
        flags.nrt = True
        flags.auto_parallel = True
        self.assertLess(len(flags.get_mangle_string()), len(flags.summary()))

    def test_mangled_flags_with_fastmath_parfors_inline(self):
        """With fastmath/parallel/inline set, mangling is still shorter than
        ``summary()`` and the demangled text contains no pointer values."""
        # at least for these control cases
        flags = Flags()
        flags.nrt = True
        flags.auto_parallel = True
        flags.fastmath = True
        flags.inline = "always"
        self.assertLess(len(flags.get_mangle_string()), len(flags.summary()))
        demangled = flags.demangle(flags.get_mangle_string())
        # There should be no pointer value in the demangled string.
        self.assertNotIn("0x", demangled)

    def test_demangling_from_mangled_symbols(self):
        """Test demangling of flags from mangled symbol"""
        # Use default mangler to mangle the string
        fname = 'foo'
        argtypes = (types.int32,)
        flags = Flags()
        flags.nrt = True
        flags.target_backend = "myhardware"
        name = default_mangler(
            fname, argtypes, abi_tags=[flags.get_mangle_string()],
        )
        # Expected layout: "_Z3foo", then the ABI tag, which starts with "B"
        # followed by its decimal length and the tag text itself. Assert the
        # layout explicitly so a mangler change fails with a clear message
        # instead of an AttributeError on a failed match or a mismatch on a
        # garbage slice.
        prefix = "_Z3fooB"
        self.assertTrue(
            name.startswith(prefix),
            f"unexpected mangled name layout: {name!r}",
        )
        # Find the length of the ABI-tag
        m = re.match(r"[0-9]+", name[len(prefix):])
        self.assertIsNotNone(m, f"no ABI-tag length found in {name!r}")
        size = m.group(0)
        # Extract the ABI tag
        base = len(prefix) + len(size)
        abi_mangled = name[base:base + int(size)]
        # Demangle and check
        demangled = Flags.demangle(abi_mangled)
        self.assertEqual(demangled, flags.summary())


if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()