import re
# from llvmlite.llvmpy import core as lc
# from llvmlite import ir as llvmir
from llvmlite import ir
from llvmlite import binding as ll

from numba.core import typing, types, utils, datamodel, cgutils
# from numba.core.utils import cached_property
from functools import cached_property
from numba.core.base import BaseContext
from numba.core.callconv import MinimalCallConv
from numba.roc import codegen

from .hlc import DATALAYOUT

# Calling-convention strings assigned to generated LLVM functions.
CC_SPIR_KERNEL = "spir_kernel"
CC_SPIR_FUNC = ""


# -----------------------------------------------------------------------------
# Typing


class HSATypingContext(typing.BaseContext):
    """Typing context that installs the HSA and math declaration registries."""

    def load_additional_registries(self):
        """Install the ``hsadecl`` and ``mathdecl`` typing registries."""
        # Imported lazily, as in the rest of this module.
        from . import hsadecl, mathdecl

        self.install_registry(hsadecl.registry)
        self.install_registry(mathdecl.registry)


# -----------------------------------------------------------------------------
# Implementation

# Matches every character that is NOT an ASCII letter (case-insensitive) or a
# digit; used by the mangler to escape such characters.
VALID_CHARS = re.compile(r'[^a-z0-9]', re.I)

# Address spaces
SPIR_GENERIC_ADDRSPACE = 0
SPIR_GLOBAL_ADDRSPACE = 1
SPIR_REGION_ADDRSPACE = 2
SPIR_CONSTANT_ADDRSPACE = 4
SPIR_LOCAL_ADDRSPACE = 3
SPIR_PRIVATE_ADDRSPACE = 5
SPIR_CONSTANT_32BIT_ADDRSPACE = 6

SPIR_VERSION = (2, 0)


class GenericPointerModel(datamodel.PrimitiveModel):
    """Data model for ``types.CPointer`` using the generic address space."""

    def __init__(self, dmm, fe_type):
        """Build the backend pointer type for *fe_type*.

        The pointee data type is looked up through *dmm* and turned into a
        pointer in ``SPIR_GENERIC_ADDRSPACE``.
        """
        adrsp = SPIR_GENERIC_ADDRSPACE
        be_type = dmm.lookup(fe_type.dtype).get_data_type().as_pointer(adrsp)
        super(GenericPointerModel, self).__init__(dmm, fe_type, be_type)


def _init_data_model_manager():
    """Return a copy of the default data model manager with
    ``types.CPointer`` mapped to :class:`GenericPointerModel`."""
    dmm = datamodel.default_manager.copy()
    dmm.register(types.CPointer, GenericPointerModel)
    return dmm


hsa_data_model_manager = _init_data_model_manager()


class HSATargetContext(BaseContext):
    """Target context used to lower Numba functions for the HSA/ROC target."""

    implement_powi_as_math_call = True
    generic_addrspace = SPIR_GENERIC_ADDRSPACE

    def __init__(self, typingctx, target='ROCm'):
        super().__init__(typingctx, target)

    def init(self):
        """Create the internal codegen and target data for this context."""
        self._internal_codegen = codegen.JITHSACodegen("numba.hsa.jit")
        self._target_data = ll.create_target_data(DATALAYOUT)
        # Override data model manager
        self.data_model_manager = hsa_data_model_manager

    def load_additional_registries(self):
        """Register the function implementations from ``hsaimpl`` and
        ``mathimpl``."""
        from . import hsaimpl, mathimpl

        self.insert_func_defn(hsaimpl.registry.functions)
        self.insert_func_defn(mathimpl.registry.functions)

    # sugon: adapt for numba-0.58
    # Overrides
    def create_module(self, name):
        """Return a new empty module called *name* from the internal codegen."""
        return self._internal_codegen._create_empty_module(name)
        # return lc.Module(name)

    @cached_property
    def call_conv(self):
        """The :class:`HSACallConv` bound to this context (created once)."""
        return HSACallConv(self)

    def codegen(self):
        """Return the internal codegen object created in :meth:`init`."""
        return self._internal_codegen

    @property
    def target_data(self):
        """The target data created from ``DATALAYOUT`` in :meth:`init`."""
        return self._target_data

    def mangler(self, name, argtypes, *, abi_tags=(), uid=None):
        """Return the mangled device-function name for *name*/*argtypes*.

        The qualified name is ``name`` joined with the ``str()`` of every
        argument type using ``'.'``; each character matched by
        ``VALID_CHARS`` is replaced by ``_<HEX CODEPOINT>_`` and the result is
        prefixed with ``hsapy_devfn_``.  *abi_tags* and *uid* are accepted
        but not used.
        """
        def repl(m):
            ch = m.group(0)
            return "_%X_" % ord(ch)

        qualified = name + '.' + '.'.join(str(a) for a in argtypes)
        mangled = VALID_CHARS.sub(repl, qualified)
        return 'hsapy_devfn_' + mangled

    def prepare_hsa_kernel(self, func, argtypes):
        """Prepare *func* as a kernel and return its generated wrapper.

        Sets ``linkonce_odr`` linkage on *func*, stamps ``DATALAYOUT`` on its
        module and delegates to :meth:`generate_kernel_wrapper`.
        """
        module = func.module
        func.linkage = 'linkonce_odr'

        module.data_layout = DATALAYOUT
        wrapper = self.generate_kernel_wrapper(func, argtypes)

        return wrapper

    def mark_hsa_device(self, func):
        """Mark *func* as a device function and return it.

        Sets the ``CC_SPIR_FUNC`` calling convention and ``linkonce_odr``
        linkage.
        """
        # Adapt to SPIR
        # module = func.module
        func.calling_convention = CC_SPIR_FUNC
        func.linkage = 'linkonce_odr'
        return func

    def generate_kernel_wrapper(self, func, argtypes):
        """Generate a kernel entry point that calls *func*.

        A wrapper function named ``hsaPy_<func.name>`` is built in a separate
        module, marked as a kernel via :func:`set_hsa_kernel`, and linked into
        the module that owns *func*.  Pointer arguments of the wrapper are
        declared in ``SPIR_GLOBAL_ADDRSPACE`` and cast back to their original
        address space before the call.

        Returns the wrapper function looked up from the linked module.
        """
        module = func.module
        arginfo = self.get_arg_packer(argtypes)

        def sub_gen_with_global(lty):
            # Pointer types are re-declared in the global address space; the
            # original address space is remembered so it can be cast back.
            if isinstance(lty, ir.PointerType):
                return (lty.pointee.as_pointer(SPIR_GLOBAL_ADDRSPACE),
                        lty.addrspace)
            return lty, None

        if len(arginfo.argument_types) > 0:
            llargtys, changed = zip(*map(sub_gen_with_global,
                                         arginfo.argument_types))
        else:
            llargtys = changed = ()
        wrapperfnty = ir.FunctionType(ir.VoidType(), llargtys)

        wrapper_module = self.create_module("hsa.kernel.wrapper")
        wrappername = 'hsaPy_{name}'.format(name=func.name)

        argtys = list(arginfo.argument_types)
        fnty = ir.FunctionType(
            ir.IntType(32),
            [self.call_conv.get_return_type(types.pyobject)] + argtys)

        # NOTE: from here on ``func`` refers to the declaration created in
        # the wrapper module, not to the function passed in.
        func = ir.Function(wrapper_module, fnty, name=func.name)
        func.calling_convention = CC_SPIR_FUNC

        wrapper = ir.Function(wrapper_module, wrapperfnty, name=wrappername)

        builder = ir.IRBuilder(wrapper.append_basic_block(''))

        # Adjust address space of each kernel argument
        fixed_args = []
        for av, adrsp in zip(wrapper.args, changed):
            if adrsp is not None:
                casted = self.addrspacecast(builder, av, adrsp)
                fixed_args.append(casted)
            else:
                fixed_args.append(av)

        callargs = arginfo.from_arguments(builder, fixed_args)

        # XXX handle error status
        status, _ = self.call_conv.call_function(builder, func, types.void,
                                                 argtypes, callargs)
        builder.ret_void()

        set_hsa_kernel(wrapper)

        # Link
        module.link_in(ll.parse_assembly(str(wrapper_module)))
        # To enable inlining which is essential because addrspacecast 1->0 is
        # illegal.  Inlining will optimize the addrspacecast out.
        func.linkage = 'internal'
        wrapper = module.get_function(wrapper.name)
        module.get_function(func.name).linkage = 'internal'
        return wrapper

    def declare_function(self, module, fndesc):
        """Declare *fndesc* in *module*, applying the ``CC_SPIR_FUNC`` calling
        convention when its LLVM name starts with ``hsapy_devfn``."""
        ret = super(HSATargetContext, self).declare_function(module, fndesc)
        # XXX: Refactor fndesc instead of this special case
        if fndesc.llvm_func_name.startswith('hsapy_devfn'):
            ret.calling_convention = CC_SPIR_FUNC
        return ret

    def make_constant_array(self, builder, typ, ary):
        """
        Not implemented for this target; always raises NotImplementedError.
        """
        #
        # a = self.make_array(typ)(self, builder)
        # return a._getvalue()
        raise NotImplementedError

    def addrspacecast(self, builder, src, addrspace):
        """
        Cast pointer *src* to a pointer with the same pointee type in
        *addrspace* and return the resulting value.
        """
        ptras = ir.PointerType(src.type.pointee, addrspace=addrspace)
        return builder.addrspacecast(src, ptras)


def set_hsa_kernel(fn):
    """
    Ensure `fn` is usable as a SPIR kernel.
    - Fix calling convention
    - Add metadata
    """
    mod = fn.module

    # Set nounwind
    # fn.add_attribute(lc.ATTR_NO_UNWIND)

    # Set SPIR kernel calling convention
    fn.calling_convention = CC_SPIR_KERNEL

    # Mark kernels
    ocl_kernels = cgutils.get_or_insert_named_metadata(mod, 'opencl.kernels')
    ocl_kernels.add(ir.Module.add_metadata(mod, [fn,
                                                 gen_arg_addrspace_md(fn),
                                                 gen_arg_access_qual_md(fn),
                                                 gen_arg_type(fn),
                                                 gen_arg_type_qual(fn),
                                                 gen_arg_base_type(fn)]))

    # SPIR version 2.0
    i32 = ir.IntType(32)
    spir_version_constant = [ir.Constant(i32, x) for x in SPIR_VERSION]

    spir_version = cgutils.get_or_insert_named_metadata(
        mod, 'opencl.spir.version')
    if not spir_version.operands:
        spir_version.add(ir.Module.add_metadata(mod, spir_version_constant))

    # The OpenCL version lives in its own named metadata node.  Looking up
    # 'opencl.spir.version' a second time (as the code previously did) always
    # found the node populated just above, so this was never emitted.
    ocl_version = cgutils.get_or_insert_named_metadata(
        mod, 'opencl.ocl.version')
    if not ocl_version.operands:
        ocl_version.add(ir.Module.add_metadata(mod, spir_version_constant))

    ## The following metadata does not seem to be necessary
    # Other metadata
    # empty_md = lc.MetaData.get(mod, ())
    # others = ["opencl.used.extensions",
    #           "opencl.used.optional.core.features",
    #           "opencl.compiler.options"]
    #
    # for name in others:
    #     nmd = mod.get_or_insert_named_metadata(name)
    #     if not nmd.operands:
    #         nmd.add(empty_md)


def gen_arg_addrspace_md(fn):
    """
    Generate kernel_arg_addr_space metadata

    Pointer arguments are tagged with SPIR_GLOBAL_ADDRSPACE, every other
    argument with SPIR_PRIVATE_ADDRSPACE.
    """
    mod = fn.module
    fnty = fn.type.pointee
    codes = [SPIR_GLOBAL_ADDRSPACE if cgutils.is_pointer(a)
             else SPIR_PRIVATE_ADDRSPACE
             for a in fnty.args]

    consts = [ir.Constant(ir.IntType(32), x) for x in codes]
    name = ir.MetaDataString(mod, "kernel_arg_addr_space")
    return ir.Module.add_metadata(mod, [name] + consts)


def gen_arg_access_qual_md(fn):
    """
    Generate kernel_arg_access_qual metadata

    Every argument gets the qualifier string "none".
    """
    mod = fn.module
    consts = [ir.MetaDataString(mod, "none")] * len(fn.args)
    name = ir.MetaDataString(mod, "kernel_arg_access_qual")
    return ir.Module.add_metadata(mod, [name] + consts)


def gen_arg_type(fn):
    """
    Generate kernel_arg_type metadata

    Each entry is the ``str()`` of the corresponding LLVM argument type.
    """
    mod = fn.module
    fnty = fn.type.pointee
    consts = [ir.MetaDataString(mod, str(a)) for a in fnty.args]
    name = ir.MetaDataString(mod, "kernel_arg_type")
    return ir.Module.add_metadata(mod, [name] + consts)


def gen_arg_type_qual(fn):
    """
    Generate kernel_arg_type_qual metadata

    Every argument gets an empty qualifier string.
    """
    mod = fn.module
    fnty = fn.type.pointee
    consts = [ir.MetaDataString(mod, "") for _ in fnty.args]
    name = ir.MetaDataString(mod, "kernel_arg_type_qual")
    return ir.Module.add_metadata(mod, [name] + consts)


def gen_arg_base_type(fn):
    """
    Generate kernel_arg_base_type metadata

    Each entry is the ``str()`` of the corresponding LLVM argument type.
    """
    mod = fn.module
    fnty = fn.type.pointee
    consts = [ir.MetaDataString(mod, str(a)) for a in fnty.args]
    name = ir.MetaDataString(mod, "kernel_arg_base_type")
    return ir.Module.add_metadata(mod, [name] + consts)


class HSACallConv(MinimalCallConv):
    """Calling convention used by :class:`HSATargetContext`."""

    def call_function(self, builder, callee, resty, argtys, args, env=None):
        """
        Call the Numba-compiled *callee*.

        A null-initialised stack slot for the return value is passed as the
        first argument, followed by the packed *args*.  Returns a
        ``(status, out)`` pair: the status derived from the callee's return
        code and the value loaded from the return slot, converted for *resty*.
        *env* must be None.
        """
        assert env is None
        retty = callee.args[0].type.pointee
        retvaltmp = cgutils.alloca_once(builder, retty)
        # initialize return value
        builder.store(cgutils.get_null_value(retty), retvaltmp)

        arginfo = self.context.get_arg_packer(argtys)
        args = arginfo.as_arguments(builder, args)
        realargs = [retvaltmp] + list(args)
        code = builder.call(callee, realargs)
        status = self._get_return_status(builder, code)
        retval = builder.load(retvaltmp)
        out = self.context.get_returned_value(builder, resty, retval)
        return status, out