mathimpl.py 3.19 KB
Newer Older
dugupeiwen's avatar
dugupeiwen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import math
import warnings

from numba.core.imputils import Registry
from numba.core import types
from numba.core.itanium_mangler import mangle
from .hsaimpl import _declare_function

# Module-local Registry instance. `lower` is its registration hook, used at
# the bottom of this file as `lower(key, *argtypes)(impl)`.
registry = Registry()
lower = registry.lower

# -----------------------------------------------------------------------------

# Signature shorthands, named _<arity>_<return>_<arguments>:
#   f = float32, d = float64,
#   b = int32 return value (presumably a boolean-style flag - TODO confirm).
# Each shape comes in a float32 and a float64 flavour.
_unary_b_f = types.int32(types.float32)
_unary_b_d = types.int32(types.float64)
_unary_f_f = types.float32(types.float32)
_unary_d_d = types.float64(types.float64)
_binary_f_ff = types.float32(types.float32, types.float32)
_binary_d_dd = types.float64(types.float64, types.float64)

# Maps a math function name to its (float32 signature, float64 signature)
# pair. Entries are grouped by signature shape; key order is significant only
# for iteration and matches the historical one-entry-per-line layout.
function_descriptors = {
    # one floating-point argument, int32 result
    **dict.fromkeys(
        ('isnan', 'isinf'),
        (_unary_b_f, _unary_b_d),
    ),

    # one floating-point argument, result of the same width
    **dict.fromkeys(
        (
            'ceil', 'floor',
            'fabs',
            'sqrt', 'exp', 'expm1', 'log', 'log10', 'log1p',
            'sin', 'cos', 'tan', 'asin', 'acos', 'atan',
            'sinh', 'cosh', 'tanh', 'asinh', 'acosh', 'atanh',
        ),
        (_unary_f_f, _unary_d_d),
    ),

    # two floating-point arguments, result of the same width
    **dict.fromkeys(
        ('copysign', 'atan2', 'pow', 'fmod'),
        (_binary_f_ff, _binary_d_dd),
    ),

    # one floating-point argument, result of the same width
    **dict.fromkeys(
        ('erf', 'erfc', 'gamma', 'lgamma'),
        (_unary_f_f, _unary_d_d),
    ),

    # unsupported functions listed in the math module documentation:
    # frexp, ldexp, trunc, modf, factorial, fsum
}


# Some functions are named differently by the underlying math library than
# by Python's math module. Maps Python name -> symbol name to declare;
# names that are absent here are declared under their Python name.
_lib_counterpart = {
    'gamma': 'tgamma'
}


def _mk_fn_decl(name, decl_sig):
    """Build an implementation that forwards to an externally declared function.

    ``name`` is the Python-side function name. The symbol that gets declared
    is taken from ``_lib_counterpart`` when ``name`` has an entry there, and
    is ``name`` itself otherwise. ``decl_sig`` supplies both the signature and
    the argument types passed to ``_declare_function``.

    Returns a callable taking ``(context, builder, sig, args)`` whose
    ``__name__`` is set to ``name``.
    """
    # Resolved once, when the implementation is created.
    symbol = _lib_counterpart[name] if name in _lib_counterpart else name

    def core(context, builder, sig, args):
        declared = _declare_function(
            context, builder, symbol, decl_sig, decl_sig.args,
            mangler=mangle,
        )
        call_result = builder.call(declared, args)
        # Cast from the declared return type to the return type of the
        # signature this implementation was invoked with.
        return context.cast(
            builder, call_result, decl_sig.return_type, sig.return_type,
        )

    core.__name__ = name
    return core


# Names of the math functions to register, in registration order.
_supported = [
    'sin', 'cos', 'tan', 'asin', 'acos', 'atan', 'atan2',
    'sinh', 'cosh', 'tanh', 'asinh', 'acosh', 'atanh',
    'isnan', 'isinf',
    'ceil', 'floor', 'fabs', 'sqrt',
    'exp', 'expm1', 'log', 'log10', 'log1p',
    'copysign', 'pow', 'fmod',
    'erf', 'erfc',
    'gamma', 'lgamma',
]

# Register one implementation per signature for every supported name.
for name in _supported:
    sigs = function_descriptors.get(name)
    if sigs is None:
        # No signatures are known for this name: warn and move on.
        warnings.warn("HSA - failed to register '{0}'".format(name))
    elif hasattr(math, name):
        # only symbols present in the math module
        key = getattr(math, name)
        for sig in sigs:
            fn = _mk_fn_decl(name, sig)
            lower(key, *sig.args)(fn)