vectorizers.py 4.21 KB
Newer Older
dugupeiwen's avatar
dugupeiwen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
from numba import roc

from numba.roc import dispatch
from numba.np.ufunc import deviceufunc

# Source template for the element-wise ufunc kernel.
#
# The placeholders ``{name}``, ``{args}`` and ``{argitems_0}`` ..
# ``{argitems_3}`` are filled in with ``str.format`` by
# ``HsaVectorize._get_kernel_source``.  The generated function refers to two
# names that are not defined in the template itself, ``__hsa__`` and
# ``__core__``; they have to be present in the globals the generated source
# is evaluated with (see ``HsaVectorize._get_globals``).
#
# Each work-item handles up to four output elements, at indices
# ``local_id + local_size * (4 * group_id + k)`` for ``k`` in 0..3:
#   * if the first index is already outside ``__out__.shape[0]`` the
#     work-item returns immediately;
#   * otherwise the if/elif chain picks the largest ``k`` whose index is
#     still in bounds and processes elements 0..k.  This relies on the four
#     indices being increasing (i.e. a positive local size), so that a valid
#     higher index implies the lower ones are valid as well.
# Within a branch all arguments are gathered first, then ``__core__`` is
# called for each element, and finally the results are stored.
vectorizer_stager_source = '''
def __vectorized_{name}({args}, __out__):

    __tid__ = __hsa__.get_local_id(0)
    __blksz__ = __hsa__.get_local_size(0)
    __blkid__ = __hsa__.get_group_id(0)

    __tid0__ = __tid__ + __blksz__ * (4 * __blkid__)
    __tid1__ = __tid__ + __blksz__ * (4 * __blkid__ + 1)
    __tid2__ = __tid__ + __blksz__ * (4 * __blkid__ + 2)
    __tid3__ = __tid__ + __blksz__ * (4 * __blkid__ + 3)

    __ilp0__ = __tid0__ < __out__.shape[0]
    if not __ilp0__:
        # Early escape
        return
    __ilp1__ = __tid1__ < __out__.shape[0]
    __ilp2__ = __tid2__ < __out__.shape[0]
    __ilp3__ = __tid3__ < __out__.shape[0]

    if __ilp3__:
        __args0__ = {argitems_0}
        __args1__ = {argitems_1}
        __args2__ = {argitems_2}
        __args3__ = {argitems_3}

        __r0__ = __core__(*__args0__)
        __r1__ = __core__(*__args1__)
        __r2__ = __core__(*__args2__)
        __r3__ = __core__(*__args3__)

        __out__[__tid0__] = __r0__
        __out__[__tid1__] = __r1__
        __out__[__tid2__] = __r2__
        __out__[__tid3__] = __r3__

    elif __ilp2__:
        __args0__ = {argitems_0}
        __args1__ = {argitems_1}
        __args2__ = {argitems_2}

        __r0__ = __core__(*__args0__)
        __r1__ = __core__(*__args1__)
        __r2__ = __core__(*__args2__)

        __out__[__tid0__] = __r0__
        __out__[__tid1__] = __r1__
        __out__[__tid2__] = __r2__

    elif __ilp1__:
        __args0__ = {argitems_0}
        __args1__ = {argitems_1}

        __r0__ = __core__(*__args0__)
        __r1__ = __core__(*__args1__)

        __out__[__tid0__] = __r0__
        __out__[__tid1__] = __r1__

    else:
        __args0__ = {argitems_0}
        __r0__ = __core__(*__args0__)
        __out__[__tid0__] = __r0__

'''


class HsaVectorize(deviceufunc.DeviceVectorize):
    """Element-wise ufunc builder for the ROC (HSA) target.

    Supplies the target-specific pieces on top of
    ``deviceufunc.DeviceVectorize``: compiling the user function as a device
    function, generating and compiling the kernel that wraps it, and
    building the final dispatcher object.
    """

    def _compile_core(self, sig):
        """Compile ``self.pyfunc`` as a ROC device function for ``sig``.

        Returns:
            A ``(device_function, return_type)`` pair; the return type is
            taken from the compiled function's signature.
        """
        hsadevfn = roc.jit(sig, device=True)(self.pyfunc)
        return hsadevfn, hsadevfn.cres.signature.return_type

    def _get_globals(self, corefn):
        """Return the global namespace for the generated kernel source.

        The namespace holds everything visible to ``self.pyfunc`` plus the
        two names the kernel template relies on: ``__hsa__`` (the ``roc``
        module) and ``__core__`` (``corefn``, the compiled device function).
        """
        # Work on a copy so the helper names are not injected into the
        # module namespace of the user's function; this matches
        # HsaGUFuncVectorize._get_globals.
        glbl = self.pyfunc.__globals__.copy()
        glbl.update({'__hsa__': roc,
                     '__core__': corefn})
        return glbl

    def _compile_kernel(self, fnobj, sig):
        """Compile the generated kernel function ``fnobj`` for ``sig``."""
        return roc.jit(sig)(fnobj)

    def _get_kernel_source(self, template, sig, funcname):
        """Fill ``template`` in for a kernel named after ``funcname``.

        Args:
            template: Format string with the ``name``, ``args`` and
                ``argitems_0`` .. ``argitems_3`` placeholders.
            sig: Kernel signature; only ``len(sig.args)`` is used, to decide
                how many array parameters (``a0``, ``a1``, ...) to declare.
            funcname: Value substituted for the ``name`` placeholder.

        Returns:
            The formatted kernel source code as a string.
        """
        args = ['a%d' % i for i in range(len(sig.args))]

        def make_argitems(n):
            # Expression that loads every argument at index ``__tid<n>__``.
            out = ', '.join('%s[__tid%d__]' % (arg, n) for arg in args)
            if len(args) < 2:
                # Less than two arguments.
                # We need to wrap the argument in a tuple because
                # we use stararg later.
                return "({0},)".format(out)
            # Two or more comma-separated items already form a tuple.
            return out

        fmts = dict(name=funcname,
                    args=', '.join(args),
                    argitems_0=make_argitems(n=0),
                    argitems_1=make_argitems(n=1),
                    argitems_2=make_argitems(n=2),
                    argitems_3=make_argitems(n=3))
        return template.format(**fmts)

    def build_ufunc(self):
        """Return an ``HsaUFuncDispatcher`` over the compiled kernels."""
        return dispatch.HsaUFuncDispatcher(self.kernelmap)

    @property
    def _kernel_template(self):
        """Source template used to generate the element-wise kernel."""
        return vectorizer_stager_source


# ------------------------------------------------------------------------------
# Generalized HSA ufuncs

# Source template for the generalized-ufunc kernel.
#
# The generated function reads the global work-item id along dimension 0
# and, when it is below ``{checkedarg}``, calls ``__core__`` once with
# ``{argitems}``.  Like the element-wise template, it expects ``__hsa__``
# and ``__core__`` in the globals it is evaluated with (see
# ``HsaGUFuncVectorize._get_globals``).
# NOTE(review): the ``{name}``, ``{args}``, ``{checkedarg}`` and
# ``{argitems}`` placeholders are not filled in anywhere in this file --
# presumably by the ``DeviceGUFuncVectorize`` base class; confirm there.
_gufunc_stager_source = '''
def __gufunc_{name}({args}):
    __tid__ = __hsa__.get_global_id(0)
    if __tid__ < {checkedarg}:
        __core__({argitems})
'''


class HsaGUFuncVectorize(deviceufunc.DeviceGUFuncVectorize):
    """Generalized ufunc builder for the ROC (HSA) target."""

    @property
    def _kernel_template(self):
        """Source template used to generate the gufunc kernel."""
        return _gufunc_stager_source

    def _get_globals(self, sig):
        """Return the global namespace for the generated kernel source.

        The user function is compiled as a ROC device function for ``sig``
        and exposed as ``__core__``; the ``roc`` module is exposed as
        ``__hsa__``.  Both are added to a copy of the user function's
        globals, so the original namespace is left unchanged.
        """
        device_fn = roc.jit(sig, device=True)(self.pyfunc)
        namespace = dict(self.py_func.__globals__)
        namespace['__hsa__'] = roc
        namespace['__core__'] = device_fn
        return namespace

    def _compile_kernel(self, fnobj, sig):
        """Compile the generated kernel function ``fnobj`` for ``sig``."""
        kernel_jit = roc.jit(sig)
        return kernel_jit(fnobj)

    def build_ufunc(self):
        """Return an ``HSAGenerializedUFunc`` over the compiled kernels.

        The dispatcher gets ``self.kernelmap`` together with a
        ``GUFuncEngine`` built from the input and output signatures.
        """
        gufunc_engine = deviceufunc.GUFuncEngine(self.inputsig,
                                                 self.outputsig)
        return dispatch.HSAGenerializedUFunc(
            kernelmap=self.kernelmap,
            engine=gufunc_engine,
        )