test_structref_usage.py 4.74 KB
Newer Older
dugupeiwen's avatar
dugupeiwen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
# "magictoken" markers delimit the beginning and end of example text.

import unittest

# magictoken.ex_structref_type_definition.begin
import numpy as np

from numba import njit
from numba.core import types
from numba.experimental import structref

from numba.tests.support import skip_unless_scipy


# The StructRef type backing MyStruct.
# Decorating with `structref.register` associates the type with the
# default data model and installs getters and setters for the fields
# of the StructRef, so jit-code can access them as attributes.
@structref.register
class MyStructType(types.StructRef):
    def preprocess_fields(self, fields):
        """Drop Literal types from every field before the type is built.

        The type constructor calls this hook for additional preprocessing
        of ``fields``, an iterable of ``(name, type)`` pairs. The struct
        should never hold Literal types, so each field type is passed
        through ``types.unliteral``.
        """
        cleaned = []
        for field_name, field_type in fields:
            cleaned.append((field_name, types.unliteral(field_type)))
        return tuple(cleaned)


# A Python type that serves as a proxy for the StructRef allocated
# inside Numba. Users can build a StructRef through this type's
# constructor from both ordinary Python code and jit-code.
class MyStruct(structref.StructRefProxy):
    def __new__(cls, name, vector):
        # Overriding __new__ is optional. The inherited __new__ only
        # accepts positional `*args`; overriding it lets Python callers
        # pass keyword arguments or add other customized behavior.
        # IMPORTANT: users should not override __init__.
        return super().__new__(cls, name, vector)

    # The proxy does not mirror the StructRef's attributes or methods
    # onto the Python side by default (this may be automated in the
    # future), so each field is exposed explicitly as a read-only
    # property that fetches it through a small jit-compiled helper.

    @property
    def name(self):
        """The ``name`` field, read via the jit helper ``MyStruct_get_name``."""
        return MyStruct_get_name(self)

    @property
    def vector(self):
        """The ``vector`` field, read via the jit helper ``MyStruct_get_vector``."""
        return MyStruct_get_vector(self)


@njit
def MyStruct_get_name(self):
    """Return the ``name`` field of a MyStruct.

    In jit-code the field is readable as a plain attribute because
    ``structref.register`` exposed it on the StructRef type.
    """
    field_value = self.name
    return field_value


@njit
def MyStruct_get_vector(self):
    """Return the ``vector`` field of a MyStruct.

    Like ``MyStruct_get_name``, this relies on the attribute access that
    ``structref.register`` installed for the StructRef's fields.
    """
    field_value = self.vector
    return field_value


# Associate the MyStruct proxy with MyStructType for the fields
# "name" and "vector". Only the field names are given here; no type is
# pinned for either field, so the field types remain generic
# (presumably resolved from the values passed at construction time —
# see the Numba StructRef docs).
structref.define_proxy(MyStruct, MyStructType, ["name", "vector"])
# magictoken.ex_structref_type_definition.end


@skip_unless_scipy
class TestStructRefUsage(unittest.TestCase):
    """Exercise the MyStruct StructRef example defined in this module.

    The code between the ``magictoken`` markers is extracted as example
    text, so it is kept verbatim; the checks that pin its behavior live
    outside the markers. The class is skipped without SciPy, presumably
    because ``np.linalg.norm`` in jit-code relies on SciPy's BLAS/LAPACK
    bindings.
    """

    def test_type_definition(self):
        # Seeds NumPy's global generator, which fixes Alice's vector (drawn
        # in Python). Random draws inside jit-code such as make_bob() use
        # Numba's own generator, which this call does not seed.
        np.random.seed(0)
        # Capture what the example "prints" instead of writing to stdout;
        # the local name intentionally shadows the builtin in this method.
        buf = []

        def print(*args):
            buf.append(args)

        # magictoken.ex_structref_type_definition_test.begin
        # Let's test our new StructRef.

        # Define one in Python
        alice = MyStruct("Alice", vector=np.random.random(3))

        # Define one in jit-code
        @njit
        def make_bob():
            bob = MyStruct("unnamed", vector=np.zeros(3))
            # Mutate the attributes
            bob.name = "Bob"
            bob.vector = np.random.random(3)
            return bob

        bob = make_bob()

        # Out: Alice: [0.5488135  0.71518937 0.60276338]
        print(f"{alice.name}: {alice.vector}")
        # Out: Bob: [0.88325739 0.73527629 0.87746707]
        print(f"{bob.name}: {bob.vector}")

        # Define a jit function to operate on the structs.
        @njit
        def distance(a, b):
            return np.linalg.norm(a.vector - b.vector)

        # Out: 0.4332647200356598
        print(distance(alice, bob))
        # magictoken.ex_structref_type_definition_test.end

        self.assertEqual(len(buf), 3)
        # Pin the example's results, not just the number of prints: the
        # Python-built and jit-mutated structs expose the expected fields,
        # and the jit distance agrees with NumPy's.
        self.assertEqual(alice.name, "Alice")
        self.assertEqual(bob.name, "Bob")
        self.assertEqual(buf[0], (f"Alice: {alice.vector}",))
        self.assertEqual(buf[1], (f"Bob: {bob.vector}",))
        expected = np.linalg.norm(alice.vector - bob.vector)
        self.assertAlmostEqual(buf[2][0], expected)

    def test_overload_method(self):
        # magictoken.ex_structref_method.begin
        from numba.core.extending import overload_method
        from numba.core.errors import TypingError

        # Use @overload_method to add a method for
        # MyStructType.distance(other)
        # where *other* is an instance of MyStructType.
        @overload_method(MyStructType, "distance")
        def ol_distance(self, other):
            # Guard that *other* is an instance of MyStructType
            if not isinstance(other, MyStructType):
                raise TypingError(
                    f"*other* must be a {MyStructType}; got {other}"
                )

            def impl(self, other):
                return np.linalg.norm(self.vector - other.vector)

            return impl

        # Test
        @njit
        def test():
            alice = MyStruct("Alice", vector=np.random.random(3))
            bob = MyStruct("Bob", vector=np.random.random(3))
            # Use the method
            return alice.distance(bob)
        # magictoken.ex_structref_method.end

        self.assertIsInstance(test(), float)

        # Pin the computed value on structs with a known distance (3-4-5).
        @njit
        def call_distance(a, b):
            return a.distance(b)

        origin = MyStruct("Origin", vector=np.zeros(3))
        point = MyStruct("Point", vector=np.array([3.0, 4.0, 0.0]))
        self.assertAlmostEqual(call_distance(origin, point), 5.0)
        self.assertAlmostEqual(call_distance(point, origin), 5.0)

        # The guard in ol_distance must reject a non-MyStructType argument
        # while the calling function is being typed.
        @njit
        def call_distance_bad(a):
            return a.distance(1.0)

        with self.assertRaises(TypingError):
            call_distance_bad(origin)