decorators.py 1.38 KB
Newer Older
dugupeiwen's avatar
dugupeiwen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from numba.core import types, sigutils
from .compiler import (compile_kernel, compile_device, AutoJitHSAKernel,
                       compile_device_template)


def jit(signature=None, device=False):
    """JIT compile a Python function for the HSA target.

    The function supports three call shapes:

    * ``jit()`` / ``jit(device=...)`` -- no signature given; the decorator
      produced by ``autojit(device=device)`` is returned.
    * ``jit(func)`` -- the first argument is not recognised as a signature
      by ``sigutils.is_signature``, so it is taken to be the function to
      decorate and is passed straight to the ``autojit`` decorator.
    * ``jit(signature, device=...)`` -- an explicit signature; a decorator
      bound to that signature is returned.

    Parameters
    ----------
    signature : optional
        ``None``, a signature accepted by ``sigutils.is_signature``, or the
        function being decorated.
    device : bool, optional
        When true, the device-function code path is selected instead of the
        kernel code path.

    Returns
    -------
    A decorator, or (in the ``jit(func)`` form) whatever the ``autojit``
    decorator returns for ``func``.
    """
    # No signature at all: defer to the signature-less decorator.
    if signature is None:
        return autojit(device=device)

    # Not a signature: the positional argument is the decorated function.
    if not sigutils.is_signature(signature):
        return autojit(device=device)(signature)

    # Explicit signature: choose the factory matching the requested kind.
    decorator_factory = _device_jit if device else _kernel_jit
    return decorator_factory(signature)


def autojit(device=False):
    """Return the signature-less decorator for the requested function kind.

    Parameters
    ----------
    device : bool, optional
        When true, ``_device_autojit`` is returned; otherwise
        ``_kernel_autojit`` is returned.

    Returns
    -------
    callable
        A one-argument decorator taking the Python function to wrap.
    """
    return _device_autojit if device else _kernel_autojit


def _device_jit(signature):
    """Build a decorator that compiles a device function for ``signature``.

    The signature is normalized once, up front, via
    ``sigutils.normalize_signature``; the resulting argument and return
    types are captured by the returned decorator.

    Parameters
    ----------
    signature
        A signature accepted by ``sigutils.normalize_signature``.

    Returns
    -------
    callable
        A decorator that forwards the function, return type and argument
        types to ``compile_device`` and returns its result.
    """
    argtypes, restype = sigutils.normalize_signature(signature)

    def decorate(pyfunc):
        # Note the argument order expected here: return type before argtypes.
        return compile_device(pyfunc, restype, argtypes)

    return decorate


def _kernel_jit(signature):
    """Build a decorator that compiles a kernel for ``signature``.

    The signature is normalized via ``sigutils.normalize_signature`` and its
    return type is validated immediately, before any function is decorated.

    Parameters
    ----------
    signature
        A signature accepted by ``sigutils.normalize_signature``.

    Returns
    -------
    callable
        A decorator that forwards the function and argument types to
        ``compile_kernel`` and returns its result.

    Raises
    ------
    TypeError
        If the normalized return type is neither ``None`` nor
        ``types.void``.
    """
    argtypes, restype = sigutils.normalize_signature(signature)

    # An omitted return type (None) is accepted alongside an explicit void.
    has_non_void_return = restype is not None and restype != types.void
    if has_non_void_return:
        msg = "HSA kernel must have void return type but got {restype}"
        raise TypeError(msg.format(restype=restype))

    def decorate(pyfunc):
        return compile_kernel(pyfunc, argtypes)

    return decorate


def _device_autojit(pyfunc):
    """Wrap ``pyfunc`` as a signature-less device function.

    Delegates to ``compile_device_template`` and returns its result
    unchanged.
    """
    device_template = compile_device_template(pyfunc)
    return device_template


def _kernel_autojit(pyfunc):
    """Wrap ``pyfunc`` as a signature-less kernel.

    Returns a new ``AutoJitHSAKernel`` constructed from ``pyfunc``.
    """
    autojit_kernel = AutoJitHSAKernel(pyfunc)
    return autojit_kernel