"examples/nas/legacy/cream/lib/utils/flops_table.py" did not exist on "abc221589c65d75b494407c60a81ca87c3020463"
libdeviceimpl.py 2.74 KB
Newer Older
dugupeiwen's avatar
dugupeiwen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from llvmlite import ir
from numba.core import cgutils, types
from numba.core.imputils import Registry
from numba.cuda import libdevice, libdevicefuncs

# Registry that collects the lowering implementations defined in this module.
registry = Registry()
# Shorthand for the registry's decorator factory; used below in the form
# ``lower(key, *argtys)(core)`` to register ``core`` as the lowering of ``key``.
lower = registry.lower


def libdevice_implement(func, retty, nbargs):
    """Register a lowering that forwards a call directly to ``func``.

    Intended for functions whose arguments are all passed by value; the
    lowered call passes the incoming argument values through unchanged and
    returns the result of the call.

    :param func: Name of the external function to declare and call. The
        attribute ``func[5:]`` of ``libdevice`` is used as the lowering key
        (presumably this strips a ``__nv_`` prefix -- confirm against
        ``libdevicefuncs.functions``).
    :param retty: Numba type of the return value.
    :param nbargs: Sequence of argument descriptors exposing ``.ty`` (the
        Numba type) and ``.is_ptr``.
    """
    def core(context, builder, sig, args):
        lmod = builder.module
        # Build the LLVM prototype of the external function from the Numba
        # types and declare it in the current module (or reuse it if it has
        # already been declared).
        fretty = context.get_value_type(retty)
        fargtys = [context.get_value_type(arg.ty) for arg in nbargs]
        fnty = ir.FunctionType(fretty, fargtys)
        fn = cgutils.get_or_insert_function(lmod, fnty, func)
        return builder.call(fn, args)

    key = getattr(libdevice, func[5:])

    # Derive the registered argument types from this function's own
    # ``nbargs`` parameter rather than from a module-level ``args`` name, so
    # the registration does not depend on the caller's loop variable.
    argtys = [arg.ty for arg in nbargs if not arg.is_ptr]
    lower(key, *argtys)(core)


def libdevice_implement_multiple_returns(func, retty, prototype_args):
    """Register a lowering for ``func`` when some arguments are pointers.

    Pointer arguments are treated as extra return values: the lowering
    allocates a temporary for each one, passes its address to ``func``, and
    after the call loads the temporaries and packs them (preceded by the
    direct return value when ``retty`` is not ``types.void``) into a tuple.

    :param func: Name of the external function to declare and call. The
        attribute ``func[5:]`` of ``libdevice`` is used as the lowering key.
    :param retty: Numba type of the function's direct return value.
    :param prototype_args: Sequence of argument descriptors exposing ``.ty``
        (the Numba type) and ``.is_ptr``.
    """
    # The Numba-level signature supplies both the argument types to register
    # the lowering for and the tuple type that the lowering must return.
    signature = libdevicefuncs.create_signature(retty, prototype_args)
    nb_retty = signature.return_type
    key = getattr(libdevice, func[5:])

    def core(context, builder, sig, args):
        module = builder.module

        def llvm_param_type(proto):
            # Pointer parameters take a pointer to the value type.
            value_ty = context.get_value_type(proto.ty)
            return value_ty.as_pointer() if proto.is_ptr else value_ty

        # Declare the external function (or reuse an existing declaration).
        param_tys = [llvm_param_type(proto) for proto in prototype_args]
        fnty = ir.FunctionType(context.get_value_type(retty), param_tys)
        callee = cgutils.get_or_insert_function(module, fnty, func)

        # Values returned through a pointer need a stack slot whose address
        # is passed to the callee; by-value parameters consume the incoming
        # arguments in order.
        call_args = []
        out_slots = []
        next_input = 0
        for proto in prototype_args:
            if not proto.is_ptr:
                call_args.append(args[next_input])
                next_input += 1
                continue
            slot = cgutils.alloca_once(builder,
                                       context.get_value_type(proto.ty))
            call_args.append(slot)
            out_slots.append(slot)

        result = builder.call(callee, call_args)

        # Gather everything the caller should receive: the direct return
        # value (if there is one) followed by each pointer-returned value.
        outputs = [result] if retty != types.void else []
        outputs.extend(builder.load(slot) for slot in out_slots)

        # Homogeneous tuples are packed as arrays, anything else as a struct.
        if isinstance(nb_retty, types.UniTuple):
            packer = cgutils.pack_array
        else:
            packer = cgutils.pack_struct
        return packer(builder, outputs)

    lower(key, *signature.args)(core)


# Register a lowering for every function described by libdevicefuncs. Entries
# with at least one pointer argument go through the multiple-returns path;
# all others use the direct-call path.
# NOTE: the loop variables are module-level names; ``args`` in particular is
# kept under that name so any module code that reads it keeps working.
for func, (retty, args) in libdevicefuncs.functions.items():
    if any(arg.is_ptr for arg in args):
        libdevice_implement_multiple_returns(func, retty, args)
    else:
        libdevice_implement(func, retty, args)