import operator
from functools import reduce

# from llvmlite.llvmpy.core import Type
# import llvmlite.llvmpy.core as lc
import llvmlite.binding as ll
from llvmlite import ir

from numba import roc
from numba.core.imputils import Registry
from numba.core import types, cgutils
# from numba.core.itanium_mangler import mangle_c, mangle, mangle_type
from numba.core.itanium_mangler import mangle, mangle_type
from numba.core.typing.npydecl import parse_dtype
from numba.roc import target
from numba.roc import stubs
from numba.roc import hlc
from numba.roc import enums

registry = Registry()
lower = registry.lower

_void_value = ir.Constant(ir.PointerType(ir.IntType(8)), None)

# -----------------------------------------------------------------------------


def _declare_function(context, builder, name, sig, cargs,
                      mangler=mangle):
    """Insert a declaration for an OpenCL builtin function.

    Uses the Itanium mangler.

    Args
    ----
    context: target context

    builder: llvm builder

    name: str
        symbol name

    sig: signature
        function signature of the symbol being declared

    cargs: sequence of types
        argument types used to mangle the symbol

    mangler: a mangler function
        function to use to mangle the symbol
    """
    mod = builder.module
    if sig.return_type == types.void:
        llretty = ir.VoidType()
    else:
        llretty = context.get_value_type(sig.return_type)
    llargs = [context.get_value_type(t) for t in sig.args]
    fnty = ir.FunctionType(llretty, llargs)
    mangled = mangler(name, cargs)
    fn = cgutils.get_or_insert_function(mod, fnty, mangled)
    fn.calling_convention = target.CC_SPIR_FUNC
    return fn


# Note: the cargs passed to _declare_function must be Numba type objects such
# as types.uint32, not C type name strings like 'unsigned int'.
@lower(stubs.get_global_id, types.uint32)
def get_global_id_impl(context, builder, sig, args):
    [dim] = args
    get_global_id = _declare_function(context, builder, 'get_global_id', sig,
                                      [types.uint32])
    res = builder.call(get_global_id, [dim])
    return context.cast(builder, res, types.uintp, types.intp)


@lower(stubs.get_local_id, types.uint32)
def get_local_id_impl(context, builder, sig, args):
    [dim] = args
    get_local_id = _declare_function(context, builder, 'get_local_id', sig,
                                     [types.uint32])
    res = builder.call(get_local_id, [dim])
    return context.cast(builder, res, types.uintp, types.intp)


@lower(stubs.get_group_id, types.uint32)
def get_group_id_impl(context, builder, sig, args):
    [dim] = args
    get_group_id = _declare_function(context, builder, 'get_group_id', sig,
                                     [types.uint32])
    res = builder.call(get_group_id, [dim])
    return context.cast(builder, res, types.uintp, types.intp)


@lower(stubs.get_num_groups, types.uint32)
def get_num_groups_impl(context, builder, sig, args):
    [dim] = args
    get_num_groups = _declare_function(context, builder, 'get_num_groups',
                                       sig, [types.uint32])
    res = builder.call(get_num_groups, [dim])
    return context.cast(builder, res, types.uintp, types.intp)


@lower(stubs.get_work_dim)
def get_work_dim_impl(context, builder, sig, args):
    get_work_dim = _declare_function(context, builder, 'get_work_dim', sig,
                                     [types.void])
    res = builder.call(get_work_dim, [])
    return res


@lower(stubs.get_global_size, types.uint32)
def get_global_size_impl(context, builder, sig, args):
    [dim] = args
    get_global_size = _declare_function(context, builder, 'get_global_size',
                                        sig, [types.uint32])
    res = builder.call(get_global_size, [dim])
    return context.cast(builder, res, types.uintp, types.intp)


@lower(stubs.get_local_size, types.uint32)
def get_local_size_impl(context, builder, sig, args):
    [dim] = args
    get_local_size = _declare_function(context, builder, 'get_local_size',
                                       sig, [types.uint32])
    res = builder.call(get_local_size, [dim])
    return context.cast(builder, res, types.uintp, types.intp)
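# -----------------------------------------------------------------------------
# Illustrative sketch (not part of the lowering): the stubs lowered above back
# the user-facing work-item indexing intrinsics of numba.roc, so a kernel
# compiled for this target reaches them roughly as below. The kernel body is
# a hypothetical example, not code from this module.
#
#     from numba import roc
#
#     @roc.jit
#     def copy_kernel(dst, src):
#         i = roc.get_global_id(0)      # lowered by get_global_id_impl
#         if i < dst.shape[0]:
#             dst[i] = src[i]
# -----------------------------------------------------------------------------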
@lower(stubs.barrier, types.uint32)
def barrier_one_arg_impl(context, builder, sig, args):
    [flags] = args
    barrier = _declare_function(context, builder, 'barrier', sig,
                                [types.uint32])
    builder.call(barrier, [flags])
    return _void_value


@lower(stubs.barrier)
def barrier_no_arg_impl(context, builder, sig, args):
    assert not args
    sig = types.void(types.uint32)
    barrier = _declare_function(context, builder, 'barrier', sig,
                                [types.uint32])
    flags = context.get_constant(types.uint32, enums.CLK_GLOBAL_MEM_FENCE)
    builder.call(barrier, [flags])
    return _void_value


@lower(stubs.mem_fence, types.uint32)
def mem_fence_impl(context, builder, sig, args):
    [flags] = args
    mem_fence = _declare_function(context, builder, 'mem_fence', sig,
                                  [types.uint32])
    builder.call(mem_fence, [flags])
    return _void_value


@lower(stubs.wavebarrier)
def wavebarrier_impl(context, builder, sig, args):
    assert not args
    fnty = ir.FunctionType(ir.VoidType(), [])
    fn = builder.module.declare_intrinsic('llvm.amdgcn.wave.barrier',
                                          fnty=fnty)
    builder.call(fn, [])
    return _void_value


@lower(stubs.activelanepermute_wavewidth,
       types.Any, types.uint32, types.Any, types.bool_)
def activelanepermute_wavewidth_impl(context, builder, sig, args):
    [src, laneid, identity, use_ident] = args
    assert sig.args[0] == sig.args[2]
    elem_type = sig.args[0]
    bitwidth = elem_type.bitwidth
    intbitwidth = ir.IntType(bitwidth)
    i32 = ir.IntType(32)
    i1 = ir.IntType(1)
    name = "__hsail_activelanepermute_wavewidth_b{0}".format(bitwidth)
    fnty = ir.FunctionType(intbitwidth, [intbitwidth, i32, intbitwidth, i1])
    fn = cgutils.get_or_insert_function(builder.module, fnty, name=name)
    fn.calling_convention = target.CC_SPIR_FUNC

    def cast(val):
        return builder.bitcast(val, intbitwidth)

    result = builder.call(fn, [cast(src), laneid, cast(identity), use_ident])
    return builder.bitcast(result, context.get_value_type(elem_type))


def _gen_ds_permute(intrinsic_name):
    def _impl(context, builder, sig, args):
        """
        args are (index, src)
        """
        assert sig.return_type == sig.args[1]
        idx, src = args
        i32 = ir.IntType(32)
        fnty = ir.FunctionType(i32, [i32, i32])
        fn = builder.module.declare_intrinsic(intrinsic_name, fnty=fnty)
        # the args are byte addressable, VGPRs are 4 wide so mul idx by 4
        # the idx might be an int64, it is ok to trunc to int32 as the
        # wavefront size is never likely to overflow an int32
        idx = builder.trunc(idx, i32)
        four = ir.Constant(i32, 4)
        idx = builder.mul(idx, four)
        # bit cast is so float32 works as packed i32, the return casts back
        result = builder.call(fn, (idx, builder.bitcast(src, i32)))
        return builder.bitcast(result,
                               context.get_value_type(sig.return_type))
    return _impl


lower(stubs.ds_permute, types.Any, types.Any)(
    _gen_ds_permute('llvm.amdgcn.ds.permute'))
lower(stubs.ds_bpermute, types.Any, types.Any)(
    _gen_ds_permute('llvm.amdgcn.ds.bpermute'))


@lower(stubs.atomic.add, types.Array, types.intp, types.Any)
@lower(stubs.atomic.add, types.Array, types.UniTuple, types.Any)
@lower(stubs.atomic.add, types.Array, types.Tuple, types.Any)
def hsail_atomic_add_tuple(context, builder, sig, args):
    aryty, indty, valty = sig.args
    ary, inds, val = args
    dtype = aryty.dtype

    if indty == types.intp:
        indices = [inds]  # just a single integer
        indty = [indty]
    else:
        indices = cgutils.unpack_tuple(builder, inds, count=len(indty))
        indices = [context.cast(builder, i, t, types.intp)
                   for t, i in zip(indty, indices)]

    if dtype != valty:
        raise TypeError("expecting %s but got %s" % (dtype, valty))

    if aryty.ndim != len(indty):
        raise TypeError("indexing %d-D array with %d-D index" %
                        (aryty.ndim, len(indty)))

    lary = context.make_array(aryty)(context, builder, ary)
    ptr = cgutils.get_item_pointer(context, builder, aryty, lary, indices)

    return builder.atomic_rmw("add", ptr, val, ordering='monotonic')
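# -----------------------------------------------------------------------------
# Illustrative sketch (hypothetical kernel, an assumption for illustration):
# how the barrier and atomic lowerings above are reached from user code. Only
# roc.barrier, roc.atomic.add and enums.CLK_GLOBAL_MEM_FENCE are names taken
# from this target; the kernel body itself is not code from this module.
#
#     from numba import roc
#     from numba.roc import enums
#
#     @roc.jit
#     def hist_kernel(bins, data):
#         i = roc.get_global_id(0)
#         if i < data.shape[0]:
#             roc.atomic.add(bins, data[i], 1)         # hsail_atomic_add_tuple
#         roc.barrier(enums.CLK_GLOBAL_MEM_FENCE)      # barrier_one_arg_impl
# -----------------------------------------------------------------------------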
@lower(roc.shared.array, types.IntegerLiteral, types.Any)
def hsail_smem_alloc_array_integer(context, builder, sig, args):
    length = sig.args[0].literal_value
    dtype = parse_dtype(sig.args[1])
    return _generic_array(context, builder, shape=(length,), dtype=dtype,
                          symbol_name='_hsapy_smem',
                          addrspace=target.SPIR_LOCAL_ADDRSPACE)


@lower(roc.shared.array, types.Tuple, types.Any)
@lower(roc.shared.array, types.UniTuple, types.Any)
def hsail_smem_alloc_array_tuple(context, builder, sig, args):
    shape = [s.literal_value for s in sig.args[0]]
    dtype = parse_dtype(sig.args[1])
    return _generic_array(context, builder, shape=shape, dtype=dtype,
                          symbol_name='_hsapy_smem',
                          addrspace=target.SPIR_LOCAL_ADDRSPACE)


def _generic_array(context, builder, shape, dtype, symbol_name, addrspace):
    elemcount = reduce(operator.mul, shape, 1)
    lldtype = context.get_data_type(dtype)
    laryty = ir.ArrayType(lldtype, elemcount)

    if addrspace == target.SPIR_LOCAL_ADDRSPACE:
        lmod = builder.module

        # Create global variable in the requested address-space
        gvmem = cgutils.add_global_variable(lmod, laryty, symbol_name,
                                            addrspace)

        if elemcount <= 0:
            raise ValueError("array length <= 0")
        else:
            gvmem.linkage = 'internal'

        if dtype not in types.number_domain:
            raise TypeError("unsupported type: %s" % dtype)

        # Convert to generic address-space
        dataptr = context.addrspacecast(builder, gvmem,
                                        target.SPIR_GENERIC_ADDRSPACE)

    else:
        raise NotImplementedError("addrspace {addrspace}".format(**locals()))

    return _make_array(context, builder, dataptr, dtype, shape)


def _make_array(context, builder, dataptr, dtype, shape, layout='C'):
    ndim = len(shape)

    # Create array object
    aryty = types.Array(dtype=dtype, ndim=ndim, layout='C')
    ary = context.make_array(aryty)(context, builder)

    targetdata = _get_target_data(context)
    lldtype = context.get_data_type(dtype)
    itemsize = lldtype.get_abi_size(targetdata)

    # Compute strides
    rstrides = [itemsize]
    for i, lastsize in enumerate(reversed(shape[1:])):
        rstrides.append(lastsize * rstrides[-1])
    strides = [s for s in reversed(rstrides)]

    kshape = [context.get_constant(types.intp, s) for s in shape]
    kstrides = [context.get_constant(types.intp, s) for s in strides]

    context.populate_array(ary,
                           data=builder.bitcast(dataptr, ary.data.type),
                           shape=kshape,
                           strides=kstrides,
                           itemsize=context.get_constant(types.intp, itemsize),
                           meminfo=None)

    return ary._getvalue()


def _get_target_data(context):
    return ll.create_target_data(hlc.DATALAYOUT)
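# -----------------------------------------------------------------------------
# Illustrative sketch (hypothetical kernel, an assumption for illustration):
# roc.shared.array allocations are lowered by the functions above into a
# group-local global variable and wrapped as a C-contiguous array; e.g. a
# (16, 16) float32 tile gets byte strides (64, 4) from _make_array.
#
#     from numba import roc, float32
#
#     @roc.jit
#     def tile_kernel(out, inp):
#         tile = roc.shared.array((16, 16), dtype=float32)
#         tx = roc.get_local_id(0)
#         ty = roc.get_local_id(1)
#         tile[ty, tx] = inp[roc.get_global_id(1), roc.get_global_id(0)]
#         roc.barrier()                                 # barrier_no_arg_impl
#         out[roc.get_global_id(1), roc.get_global_id(0)] = tile[ty, tx]
# -----------------------------------------------------------------------------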