test_literal_container_usage.py 5.63 KB
Newer Older
dugupeiwen's avatar
dugupeiwen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
# Contents in this file are referenced from the sphinx-generated docs.
# "magictoken" comments mark the beginning and end of each example's text.

import unittest
from numba.tests.support import captured_stdout
from numba import typed


class DocsLiteralContainerUsageTest(unittest.TestCase):
    """Runnable versions of the literal / initial-value container examples.

    The code between each ``magictoken.<name>.begin`` and ``.end`` marker is
    pulled into the Sphinx docs, so the ``print(result) # ...`` comments
    inside those regions are shown to readers as the expected output and
    must agree with the assertion that follows each region.
    """

    def test_ex_literal_dict_compile_time_consts(self):
        """Specialize on the per-entry types of a heterogeneous literal dict.

        The overload walks ``x.literal_value`` at compile time and bakes a
        tuple of strings, one per dict entry, into the implementation.
        """
        with captured_stdout():
            # magictoken.test_ex_literal_dict_compile_time_consts.begin
            import numpy as np
            from numba import njit, types
            from numba.extending import overload

            # overload this function
            def specialize(x):
                pass

            @overload(specialize)
            def ol_specialize(x):
                ld = x.literal_value
                const_expr = []
                for k, v in ld.items():
                    if isinstance(v, types.Literal):
                        lv = v.literal_value
                        if lv == 'cat':
                            const_expr.append("Meow!")
                        elif lv == 'dog':
                            const_expr.append("Woof!")
                        elif isinstance(lv, int):
                            const_expr.append(k.literal_value * lv)
                    else: # it's an array
                        const_expr.append(f"Array(dim={v.ndim})")
                const_strings = tuple(const_expr)

                def impl(x):
                    return const_strings
                return impl

            @njit
            def foo():
                pets_ints_and_array = {'a': 1,
                                       'b': 2,
                                       'c': 'cat',
                                       'd': 'dog',
                                       'e': np.ones(5,)}
                return specialize(pets_ints_and_array)

            result = foo()
            print(result) # ('a', 'bb', 'Meow!', 'Woof!', 'Array(dim=1)')
            # magictoken.test_ex_literal_dict_compile_time_consts.end

        expected = ('a', 'bb', 'Meow!', 'Woof!', 'Array(dim=1)')
        self.assertEqual(result, expected)

    def test_ex_initial_value_dict_compile_time_consts(self):
        """Read a dict's ``initial_value`` at compile time.

        The overload asserts ``initial_value`` equals the dict literal
        ``{'a': 1, 'b': 2, 'c': 3}``; the later writes in ``foo`` appear in
        the runtime result but not in that compile-time value.
        """
        with captured_stdout():
            # magictoken.test_ex_initial_value_dict_compile_time_consts.begin
            from numba import njit, literally
            from numba.extending import overload

            # overload this function
            def specialize(x):
                pass

            @overload(specialize)
            def ol_specialize(x):
                iv = x.initial_value
                if iv is None:
                    return lambda x: literally(x) # Force literal dispatch
                assert iv == {'a': 1, 'b': 2, 'c': 3} # INITIAL VALUE
                return lambda x: x

            @njit
            def foo():
                d = {'a': 1, 'b': 2, 'c': 3}
                d['c'] = 20 # no impact on .initial_value
                d['d'] = 30 # no impact on .initial_value
                return specialize(d)

            result = foo()
            print(result) # {a: 1, b: 2, c: 20, d: 30} # NOT INITIAL VALUE!
            # magictoken.test_ex_initial_value_dict_compile_time_consts.end

        expected = typed.Dict()
        for k, v in {'a': 1, 'b': 2, 'c': 20, 'd': 30}.items():
            expected[k] = v
        self.assertEqual(result, expected)

    def test_ex_literal_list(self):
        """Specialize on the per-element types of a heterogeneous literal list.

        The overload stringifies each entry of ``x.literal_value`` at compile
        time and returns the resulting tuple of strings.
        """
        with captured_stdout():
            # magictoken.test_ex_literal_list.begin
            from numba import njit
            from numba.extending import overload

            # overload this function
            def specialize(x):
                pass

            @overload(specialize)
            def ol_specialize(x):
                const_strings = tuple(str(v) for v in x.literal_value)

                def impl(x):
                    return const_strings
                return impl

            @njit
            def foo():
                const_list = ['a', 10, 1j, ['another', 'list']]
                return specialize(const_list)

            result = foo()
            print(result) # ('Literal[str](a)', 'Literal[int](10)', 'complex128', "list(unicode_type)<iv=['another', 'list']>") # noqa: E501
            # magictoken.test_ex_literal_list.end

        expected = ('Literal[str](a)', 'Literal[int](10)', 'complex128',
                    "list(unicode_type)<iv=['another', 'list']>")
        self.assertEqual(result, expected)

    def test_ex_initial_value_list_compile_time_consts(self):
        """Read a list's ``initial_value`` at compile time.

        The overload asserts ``initial_value`` equals the list literal
        ``[1, 2, 3]``; the later mutations in ``foo`` appear in the runtime
        result but not in that compile-time value.
        """
        with captured_stdout():
            # magictoken.test_ex_initial_value_list_compile_time_consts.begin
            from numba import njit, literally
            from numba.extending import overload

            # overload this function
            def specialize(x):
                pass

            @overload(specialize)
            def ol_specialize(x):
                iv = x.initial_value
                if iv is None:
                    return lambda x: literally(x) # Force literal dispatch
                assert iv == [1, 2, 3] # INITIAL VALUE
                return lambda x: x

            @njit
            def foo():
                lst = [1, 2, 3]
                lst[2] = 20 # no impact on .initial_value
                lst.append(30) # no impact on .initial_value
                return specialize(lst)

            result = foo()
            print(result) # [1, 2, 20, 30] # NOT INITIAL VALUE!
            # magictoken.test_ex_initial_value_list_compile_time_consts.end

        expected = [1, 2, 20, 30]
        self.assertEqual(result, expected)


if __name__ == "__main__":
    # Run this module's test case(s) when the file is executed directly.
    unittest.main(module="__main__")