target_extension.py 4.59 KB
Newer Older
dugupeiwen's avatar
dugupeiwen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
from abc import ABC, abstractmethod
from numba.core.registry import DelayedRegistry, CPUDispatcher
from numba.core.decorators import jit
from numba.core.errors import InternalTargetMismatchError, NumbaValueError
from threading import local as tls


# Thread-local storage for the target override. Its ``target`` attribute is
# written by ``target_override`` and read by ``current_target``.
_active_context = tls()
# Target name reported when the current thread has no ``target`` attribute set.
_active_context_default = 'cpu'


class _TargetRegistry(DelayedRegistry):
    """Registry whose failed lookups list the entries that are registered.

    Item access defers to ``DelayedRegistry.__getitem__``; a ``KeyError`` from
    that lookup is converted into a ``NumbaValueError`` carrying a listing of
    this registry's entries.
    """

    def __getitem__(self, item):
        """Return the entry registered against ``item``.

        Raises
        ------
        NumbaValueError
            If the underlying lookup raises ``KeyError``. The message names
            ``item`` and lists every ``key -> value`` pair in this registry.
        """
        try:
            return super().__getitem__(item)
        except KeyError:
            msg = "No target is registered against '{}', known targets:\n{}"
            # Enumerate this instance rather than the module-level
            # ``target_registry`` so the listing is correct for any instance.
            known = '\n'.join(f"{k: <{10}} -> {v}"
                              for k, v in self.items())
            # ``from None`` hides the uninformative KeyError from the
            # traceback.
            raise NumbaValueError(msg.format(item, known)) from None


# Registry mapping target name strings to Target classes
target_registry = _TargetRegistry()

# Registry mapping Target classes to the @jit decorator for that target
jit_registry = DelayedRegistry()


class target_override(object):
    """Context manager to temporarily override the current target with that
       prescribed.

    The override is stored on the thread-local ``_active_context``. The value
    to restore is captured when the ``with`` block is entered, so an instance
    that is created ahead of time (or on a different thread) still restores
    whatever was active at the point of entry.

    Parameters
    ----------
    name : str
        Target name to make current inside the ``with`` block.
    """
    def __init__(self, name):
        # Retained for backward compatibility: this attribute has always been
        # available immediately after construction. It is also the fallback
        # used if ``__exit__`` is called without a matching ``__enter__``.
        self._orig_target = getattr(_active_context, 'target',
                                    _active_context_default)
        # One saved value per active entry, so the same instance can be
        # nested and each exit restores the value seen by its own entry.
        self._saved_targets = []
        self.target = name

    def __enter__(self):
        # Capture the target active right now, not at construction time.
        self._saved_targets.append(
            getattr(_active_context, 'target', _active_context_default))
        _active_context.target = self.target

    def __exit__(self, ty, val, tb):
        if self._saved_targets:
            restored = self._saved_targets.pop()
        else:
            restored = self._orig_target
        _active_context.target = restored


def current_target():
    """Return the target name active for the calling thread.

    This is the ``target`` attribute of the thread-local ``_active_context``
    when one has been set, otherwise ``_active_context_default``.
    """
    try:
        return _active_context.target
    except AttributeError:
        # No override has been stored on this thread's context.
        return _active_context_default


def get_local_target(context):
    """
    Gets the local target from the call stack if available and the TLS
    override if not.

    Parameters
    ----------
    context
        Object exposing ``callstack``; the first call-stack entry's ``target``
        attribute is used when the stack is non-empty.

    Returns
    -------
    The target taken from the call stack, or the entry registered in
    ``target_registry`` under the name returned by ``current_target()``.

    Raises
    ------
    ValueError
        If no target could be found. The message reports what was looked up:
        the target name for the TLS path, or the call-stack value otherwise.
    """
    # TODO: Should this logic be reversed to prefer TLS override?
    if len(context.callstack._stack) > 0:
        requested = context.callstack[0].target
        target = requested
    else:
        # Remember the name so a failed lookup can say what was asked for;
        # ``target`` itself is None in that case and carries no information.
        requested = current_target()
        target = target_registry.get(requested, None)
    if target is None:
        msg = ("The target found is not registered. "
               "Given target was {}.")
        raise ValueError(msg.format(requested))
    return target


def resolve_target_str(target_str):
    """Look up ``target_str`` in ``target_registry`` and return the entry
    registered under that name."""
    target_class = target_registry[target_str]
    return target_class


def resolve_dispatcher_from_str(target_str):
    """Return the ``dispatcher_registry`` entry for the target that
    ``target_str`` names.

    The name is first resolved via ``resolve_target_str``; the result is then
    used as the key into ``dispatcher_registry``.
    """
    return dispatcher_registry[resolve_target_str(target_str)]


def _get_local_target_checked(tyctx, hwstr, reason):
    """Fetch the local target, verifying it against a declared target name.

    Parameters
    ----------
    tyctx: typing context
        Passed to ``get_local_target`` to find the local target.
    hwstr: str
        Name of the target the local target must be compatible with.
    reason: str
        Reason for the resolution (a noun); forwarded to the exception raised
        on mismatch.

    Returns
    -------
    target_hw : Target
        The local target, which inherits from the target named by ``hwstr``.

    Raises
    ------
    InternalTargetMismatchError
        If the local target does not inherit from the declared target.
    """
    # The declared name is resolved before the local target is looked up.
    declared = resolve_target_str(hwstr)
    local_target = get_local_target(tyctx)
    # Accept the local target only when it derives from the declared one.
    if local_target.inherits_from(declared):
        return local_target
    raise InternalTargetMismatchError(reason, local_target, declared)


class JitDecorator(ABC):
    """Abstract base for jit decorator objects.

    Subclasses have to override ``__call__``; the class cannot be
    instantiated until they do.
    """

    @abstractmethod
    def __call__(self):
        """Placeholder implementation; yields ``NotImplemented``."""
        placeholder = NotImplemented
        return placeholder


class Target(ABC):
    """Base class from which every target class in this module derives."""

    @classmethod
    def inherits_from(cls, other):
        """Report whether this target class is a subclass of ``other``.

        Returns True when ``issubclass(cls, other)`` holds (a class counts as
        inheriting from itself) and False otherwise.
        """
        is_descendant = issubclass(cls, other)
        return is_descendant


class Generic(Target):
    """Mark the target as generic, i.e. suitable for compilation on
    any target.

    The hardware target classes in this module (CPU, GPU and their
    subclasses) derive from this class; NPyUfunc derives directly from
    Target instead.
    """


class CPU(Generic):
    """Mark the target as CPU.

    A direct subclass of Generic with no behaviour of its own.
    """


class GPU(Generic):
    """Mark the target as GPU, i.e. suitable for compilation on a GPU
    target.

    A direct subclass of Generic with no behaviour of its own.
    """


class CUDA(GPU):
    """Mark the target as CUDA.

    A direct subclass of GPU with no behaviour of its own.
    """

dugupeiwen's avatar
dugupeiwen committed
147
148
149
class ROCm(GPU):
    """Mark the target as ROCm.

    A direct subclass of GPU with no behaviour of its own.
    """
dugupeiwen's avatar
dugupeiwen committed
150
151
152
153
154
155
156
157
158
159
160
161
162

class NPyUfunc(Target):
    """Mark the target as a ufunc.

    Derives directly from Target rather than from Generic.
    """


# Register the built-in target names. CPU, GPU and CUDA are each registered
# under both an upper-case and a lower-case spelling of the same class.
target_registry['generic'] = Generic
target_registry['CPU'] = CPU
target_registry['cpu'] = CPU
target_registry['GPU'] = GPU
target_registry['gpu'] = GPU
target_registry['CUDA'] = CUDA
target_registry['cuda'] = CUDA
dugupeiwen's avatar
dugupeiwen committed
163
# Register the ROCm target class under the name 'ROCm'.
target_registry['ROCm'] = ROCm
164
165
# sugon: also register the ROCm target class under the lower-case name 'roc'.
target_registry['roc'] = ROCm
dugupeiwen's avatar
dugupeiwen committed
166
167
168
169
170
171
172
173
174
# Register the NPyUfunc target class under the name 'npyufunc'.
target_registry['npyufunc'] = NPyUfunc

# Registry mapping Target classes to their dispatcher; created with
# key_type=Target.
dispatcher_registry = DelayedRegistry(key_type=Target)


# Register the cpu target token with its dispatcher and jit.
# The target class is looked up by name, then used as the key in both the
# dispatcher registry and the jit registry.
cpu_target = target_registry['cpu']
dispatcher_registry[cpu_target] = CPUDispatcher
jit_registry[cpu_target] = jit