Commit 3e5f428e authored by dugupeiwen
Browse files

Remove use of llvmlite.llvmpy for 0.58

parent 5be111ee
from llvmlite import binding as ll from llvmlite import binding as ll
from llvmlite.llvmpy import core as lc # from llvmlite.llvmpy import core as lc
import llvmlite.ir as llvmir
from numba.core import utils from numba.core import utils
from numba.core.codegen import BaseCPUCodegen, CodeLibrary from numba.core.codegen import Codegen, CodeLibrary, CPUCodeLibrary
from .hlc import DATALAYOUT, TRIPLE, hlc from .hlc import DATALAYOUT, TRIPLE, hlc
class HSACodeLibrary(CPUCodeLibrary):
class HSACodeLibrary(CodeLibrary):
def _optimize_functions(self, ll_module): def _optimize_functions(self, ll_module):
pass pass
...@@ -25,17 +25,55 @@ class HSACodeLibrary(CodeLibrary): ...@@ -25,17 +25,55 @@ class HSACodeLibrary(CodeLibrary):
return str(out.hsail) return str(out.hsail)
class JITHSACodegen(BaseCPUCodegen): # class JITHSACodegen(Codegen):
# _library_class = HSACodeLibrary
# def _init(self, llvm_module):
# assert list(llvm_module.global_variables) == [], "Module isn't empty"
# self._data_layout = DATALAYOUT[utils.MACHINE_BITS]
# self._target_data = ll.create_target_data(self._data_layout)
# def _create_empty_module(self, name):
# ir_module = llvmir.Module(name)
# ir_module.triple = TRIPLE
# return ir_module
# def _module_pass_manager(self):
# raise NotImplementedError
# def _function_pass_manager(self, llvm_module):
# raise NotImplementedError
# def _add_module(self, module):
# pass
class JITHSACodegen(Codegen):
_library_class = HSACodeLibrary _library_class = HSACodeLibrary
def __init__(self, module_name):
    """Initialise LLVM and build this codegen's global module.

    An empty IR module named ``module_name`` is created, rendered to text
    and parsed into a binding-level module. That module is stored on
    ``self._llvm_module``, renamed to "global_codegen_module" and handed
    to ``self._init``.
    """
    # Explicit LLVM start-up; these three calls stand in for the
    # previously used (now disabled) ``initialize_llvm()`` helper.
    ll.initialize()
    ll.initialize_native_target()
    ll.initialize_native_asmprinter()

    # Has to exist before _create_empty_module runs: that method reads
    # the attribute and only sets a data layout when it is truthy.
    self._data_layout = None

    empty_ir = self._create_empty_module(module_name)
    parsed_module = ll.parse_assembly(str(empty_ir))
    self._llvm_module = parsed_module
    parsed_module.name = "global_codegen_module"

    # No runtime linker is created here (the RuntimeLinker line is disabled).
    self._init(parsed_module)
def _init(self, llvm_module): def _init(self, llvm_module):
assert list(llvm_module.global_variables) == [], "Module isn't empty" assert list(llvm_module.global_variables) == [], "Module isn't empty"
self._data_layout = DATALAYOUT[utils.MACHINE_BITS] self._data_layout = DATALAYOUT[utils.MACHINE_BITS]
self._target_data = ll.create_target_data(self._data_layout) self._target_data = ll.create_target_data(self._data_layout)
def _create_empty_module(self, name): def _create_empty_module(self, name):
ir_module = lc.Module(name) ir_module = llvmir.Module(name)
ir_module.triple = TRIPLE ir_module.triple = TRIPLE
if self._data_layout:
ir_module.data_layout = self._data_layout
return ir_module return ir_module
def _module_pass_manager(self): def _module_pass_manager(self):
......
...@@ -26,10 +26,10 @@ def compile_hsa(pyfunc, return_type, args, debug): ...@@ -26,10 +26,10 @@ def compile_hsa(pyfunc, return_type, args, debug):
# TODO handle debug flag # TODO handle debug flag
flags = compiler.Flags() flags = compiler.Flags()
# Do not compile (generate native code), just lower (to LLVM) # Do not compile (generate native code), just lower (to LLVM)
flags.set('no_compile') flags.no_compile = True
flags.set('no_cpython_wrapper') flags.no_cpython_wrapper = True
flags.set('no_cfunc_wrapper') flags.no_cfunc_wrapper = True
flags.unset('nrt') flags.nrt = False
# Run compilation pipeline # Run compilation pipeline
cres = compiler.compile_extra(typingctx=typingctx, cres = compiler.compile_extra(typingctx=typingctx,
targetctx=targetctx, targetctx=targetctx,
......
import numpy as np import numpy as np
from numba.np.ufunc.deviceufunc import (UFuncMechanism, GenerializedUFunc, # from numba.np.ufunc.deviceufunc import (UFuncMechanism, GenerializedUFunc,
# GUFuncCallSteps)
from numba.np.ufunc.deviceufunc import (UFuncMechanism, GeneralizedUFunc,
GUFuncCallSteps) GUFuncCallSteps)
from numba.roc.hsadrv.driver import dgpu_present from numba.roc.hsadrv.driver import dgpu_present
import numba.roc.hsadrv.devicearray as devicearray import numba.roc.hsadrv.devicearray as devicearray
...@@ -119,7 +121,7 @@ class _HsaGUFuncCallSteps(GUFuncCallSteps): ...@@ -119,7 +121,7 @@ class _HsaGUFuncCallSteps(GUFuncCallSteps):
kernel.configure(nelem, min(nelem, 64))(*args) kernel.configure(nelem, min(nelem, 64))(*args)
class HSAGenerializedUFunc(GenerializedUFunc): class HSAGenerializedUFunc(GeneralizedUFunc):
@property @property
def _call_steps(self): def _call_steps(self):
return _HsaGUFuncCallSteps return _HsaGUFuncCallSteps
......
import operator import operator
from functools import reduce from functools import reduce
from llvmlite.llvmpy.core import Type # from llvmlite.llvmpy.core import Type
import llvmlite.llvmpy.core as lc # import llvmlite.llvmpy.core as lc
import llvmlite.binding as ll import llvmlite.binding as ll
from llvmlite import ir from llvmlite import ir
from numba import roc from numba import roc
from numba.core.imputils import Registry from numba.core.imputils import Registry
from numba.core import types, cgutils from numba.core import types, cgutils
from numba.core.itanium_mangler import mangle_c, mangle, mangle_type # from numba.core.itanium_mangler import mangle_c, mangle, mangle_type
from numba.core.itanium_mangler import mangle, mangle_type
from numba.core.typing.npydecl import parse_dtype from numba.core.typing.npydecl import parse_dtype
from numba.roc import target from numba.roc import target
from numba.roc import stubs from numba.roc import stubs
...@@ -19,13 +20,13 @@ from numba.roc import enums ...@@ -19,13 +20,13 @@ from numba.roc import enums
registry = Registry() registry = Registry()
lower = registry.lower lower = registry.lower
_void_value = lc.Constant.null(lc.Type.pointer(lc.Type.int(8))) _void_value = ir.Constant(ir.PointerType(ir.IntType(8)), None)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def _declare_function(context, builder, name, sig, cargs, def _declare_function(context, builder, name, sig, cargs,
mangler=mangle_c): mangler=mangle):
"""Insert declaration for a opencl builtin function. """Insert declaration for a opencl builtin function.
Uses the Itanium mangler. Uses the Itanium mangler.
...@@ -50,11 +51,11 @@ def _declare_function(context, builder, name, sig, cargs, ...@@ -50,11 +51,11 @@ def _declare_function(context, builder, name, sig, cargs,
""" """
mod = builder.module mod = builder.module
if sig.return_type == types.void: if sig.return_type == types.void:
llretty = lc.Type.void() llretty = ir.VoidType()
else: else:
llretty = context.get_value_type(sig.return_type) llretty = context.get_value_type(sig.return_type)
llargs = [context.get_value_type(t) for t in sig.args] llargs = [context.get_value_type(t) for t in sig.args]
fnty = Type.function(llretty, llargs) fnty = ir.FunctionType(llretty, llargs)
mangled = mangler(name, cargs) mangled = mangler(name, cargs)
fn = mod.get_or_insert_function(fnty, mangled) fn = mod.get_or_insert_function(fnty, mangled)
fn.calling_convention = target.CC_SPIR_FUNC fn.calling_convention = target.CC_SPIR_FUNC
...@@ -154,7 +155,7 @@ def mem_fence_impl(context, builder, sig, args): ...@@ -154,7 +155,7 @@ def mem_fence_impl(context, builder, sig, args):
@lower(stubs.wavebarrier) @lower(stubs.wavebarrier)
def wavebarrier_impl(context, builder, sig, args): def wavebarrier_impl(context, builder, sig, args):
assert not args assert not args
fnty = Type.function(Type.void(), []) fnty = ir.FunctionType(ir.VoidType(), [])
fn = builder.module.declare_intrinsic('llvm.amdgcn.wave.barrier', fnty=fnty) fn = builder.module.declare_intrinsic('llvm.amdgcn.wave.barrier', fnty=fnty)
builder.call(fn, []) builder.call(fn, [])
return _void_value return _void_value
...@@ -166,12 +167,12 @@ def activelanepermute_wavewidth_impl(context, builder, sig, args): ...@@ -166,12 +167,12 @@ def activelanepermute_wavewidth_impl(context, builder, sig, args):
assert sig.args[0] == sig.args[2] assert sig.args[0] == sig.args[2]
elem_type = sig.args[0] elem_type = sig.args[0]
bitwidth = elem_type.bitwidth bitwidth = elem_type.bitwidth
intbitwidth = Type.int(bitwidth) intbitwidth = ir.IntType(bitwidth)
i32 = Type.int(32) i32 = ir.IntType(32)
i1 = Type.int(1) i1 = ir.IntType(1)
name = "__hsail_activelanepermute_wavewidth_b{0}".format(bitwidth) name = "__hsail_activelanepermute_wavewidth_b{0}".format(bitwidth)
fnty = Type.function(intbitwidth, [intbitwidth, i32, intbitwidth, i1]) fnty = ir.FunctionType(intbitwidth, [intbitwidth, i32, intbitwidth, i1])
fn = builder.module.get_or_insert_function(fnty, name=name) fn = builder.module.get_or_insert_function(fnty, name=name)
fn.calling_convention = target.CC_SPIR_FUNC fn.calling_convention = target.CC_SPIR_FUNC
...@@ -188,14 +189,14 @@ def _gen_ds_permute(intrinsic_name): ...@@ -188,14 +189,14 @@ def _gen_ds_permute(intrinsic_name):
""" """
assert sig.return_type == sig.args[1] assert sig.return_type == sig.args[1]
idx, src = args idx, src = args
i32 = Type.int(32) i32 = ir.IntType(32)
fnty = Type.function(i32, [i32, i32]) fnty = ir.FunctionType(i32, [i32, i32])
fn = builder.module.declare_intrinsic(intrinsic_name, fnty=fnty) fn = builder.module.declare_intrinsic(intrinsic_name, fnty=fnty)
# the args are byte addressable, VGPRs are 4 wide so mul idx by 4 # the args are byte addressable, VGPRs are 4 wide so mul idx by 4
# the idx might be an int64, this is ok to trunc to int32 as # the idx might be an int64, this is ok to trunc to int32 as
# wavefront_size is never likely overflow an int32 # wavefront_size is never likely overflow an int32
idx = builder.trunc(idx, i32) idx = builder.trunc(idx, i32)
four = lc.Constant.int(i32, 4) four = ir.Constant(i32, 4)
idx = builder.mul(idx, four) idx = builder.mul(idx, four)
# bit cast is so float32 works as packed i32, the return casts back # bit cast is so float32 works as packed i32, the return casts back
result = builder.call(fn, (idx, builder.bitcast(src, i32))) result = builder.call(fn, (idx, builder.bitcast(src, i32)))
...@@ -258,7 +259,7 @@ def hsail_smem_alloc_array_tuple(context, builder, sig, args): ...@@ -258,7 +259,7 @@ def hsail_smem_alloc_array_tuple(context, builder, sig, args):
def _generic_array(context, builder, shape, dtype, symbol_name, addrspace): def _generic_array(context, builder, shape, dtype, symbol_name, addrspace):
elemcount = reduce(operator.mul, shape, 1) elemcount = reduce(operator.mul, shape, 1)
lldtype = context.get_data_type(dtype) lldtype = context.get_data_type(dtype)
laryty = Type.array(lldtype, elemcount) laryty = ir.ArrayType(lldtype, elemcount)
if addrspace == target.SPIR_LOCAL_ADDRSPACE: if addrspace == target.SPIR_LOCAL_ADDRSPACE:
lmod = builder.module lmod = builder.module
...@@ -269,7 +270,7 @@ def _generic_array(context, builder, shape, dtype, symbol_name, addrspace): ...@@ -269,7 +270,7 @@ def _generic_array(context, builder, shape, dtype, symbol_name, addrspace):
if elemcount <= 0: if elemcount <= 0:
raise ValueError("array length <= 0") raise ValueError("array length <= 0")
else: else:
gvmem.linkage = lc.LINKAGE_INTERNAL gvmem.linkage = 'internal'
if dtype not in types.number_domain: if dtype not in types.number_domain:
raise TypeError("unsupported type: %s" % dtype) raise TypeError("unsupported type: %s" % dtype)
......
import re import re
from llvmlite.llvmpy import core as lc # from llvmlite.llvmpy import core as lc
from llvmlite import ir as llvmir # from llvmlite import ir as llvmir
from llvmlite import ir
from llvmlite import binding as ll from llvmlite import binding as ll
from numba.core import typing, types, utils, datamodel, cgutils from numba.core import typing, types, utils, datamodel, cgutils
from numba.core.utils import cached_property # from numba.core.utils import cached_property
from functools import cached_property
from numba.core.base import BaseContext from numba.core.base import BaseContext
from numba.core.callconv import MinimalCallConv from numba.core.callconv import MinimalCallConv
from numba.roc import codegen from numba.roc import codegen
...@@ -65,6 +67,9 @@ class HSATargetContext(BaseContext): ...@@ -65,6 +67,9 @@ class HSATargetContext(BaseContext):
implement_powi_as_math_call = True implement_powi_as_math_call = True
generic_addrspace = SPIR_GENERIC_ADDRSPACE generic_addrspace = SPIR_GENERIC_ADDRSPACE
def __init__(self, typingctx, target='ROCm'):
    """Create the context by delegating to the base-class initialiser.

    Both ``typingctx`` and ``target`` are forwarded unchanged; ``target``
    falls back to ``'ROCm'`` when the caller does not pass one.
    """
    super().__init__(typingctx, target)
def init(self): def init(self):
self._internal_codegen = codegen.JITHSACodegen("numba.hsa.jit") self._internal_codegen = codegen.JITHSACodegen("numba.hsa.jit")
self._target_data = \ self._target_data = \
...@@ -89,7 +94,7 @@ class HSATargetContext(BaseContext): ...@@ -89,7 +94,7 @@ class HSATargetContext(BaseContext):
def target_data(self): def target_data(self):
return self._target_data return self._target_data
def mangler(self, name, argtypes): def mangler(self, name, argtypes, *, abi_tags=(), uid=None):
def repl(m): def repl(m):
ch = m.group(0) ch = m.group(0)
return "_%X_" % ord(ch) return "_%X_" % ord(ch)
...@@ -119,7 +124,7 @@ class HSATargetContext(BaseContext): ...@@ -119,7 +124,7 @@ class HSATargetContext(BaseContext):
arginfo = self.get_arg_packer(argtypes) arginfo = self.get_arg_packer(argtypes)
def sub_gen_with_global(lty): def sub_gen_with_global(lty):
if isinstance(lty, llvmir.PointerType): if isinstance(lty, ir.PointerType):
return (lty.pointee.as_pointer(SPIR_GLOBAL_ADDRSPACE), return (lty.pointee.as_pointer(SPIR_GLOBAL_ADDRSPACE),
lty.addrspace) lty.addrspace)
return lty, None return lty, None
...@@ -129,13 +134,13 @@ class HSATargetContext(BaseContext): ...@@ -129,13 +134,13 @@ class HSATargetContext(BaseContext):
arginfo.argument_types)) arginfo.argument_types))
else: else:
llargtys = changed = () llargtys = changed = ()
wrapperfnty = lc.Type.function(lc.Type.void(), llargtys) wrapperfnty = ir.FunctionType(ir.VoidType(), llargtys)
wrapper_module = self.create_module("hsa.kernel.wrapper") wrapper_module = self.create_module("hsa.kernel.wrapper")
wrappername = 'hsaPy_{name}'.format(name=func.name) wrappername = 'hsaPy_{name}'.format(name=func.name)
argtys = list(arginfo.argument_types) argtys = list(arginfo.argument_types)
fnty = lc.Type.function(lc.Type.int(), fnty = ir.FunctionType(ir.IntType(),
[self.call_conv.get_return_type( [self.call_conv.get_return_type(
types.pyobject)] + argtys) types.pyobject)] + argtys)
...@@ -144,7 +149,7 @@ class HSATargetContext(BaseContext): ...@@ -144,7 +149,7 @@ class HSATargetContext(BaseContext):
wrapper = wrapper_module.add_function(wrapperfnty, name=wrappername) wrapper = wrapper_module.add_function(wrapperfnty, name=wrappername)
builder = lc.Builder(wrapper.append_basic_block('')) builder = ir.IRBuilder(wrapper.append_basic_block(''))
# Adjust address space of each kernel argument # Adjust address space of each kernel argument
fixed_args = [] fixed_args = []
...@@ -193,7 +198,7 @@ class HSATargetContext(BaseContext): ...@@ -193,7 +198,7 @@ class HSATargetContext(BaseContext):
""" """
Handle addrspacecast Handle addrspacecast
""" """
ptras = llvmir.PointerType(src.type.pointee, addrspace=addrspace) ptras = ir.PointerType(src.type.pointee, addrspace=addrspace)
return builder.addrspacecast(src, ptras) return builder.addrspacecast(src, ptras)
...@@ -213,7 +218,7 @@ def set_hsa_kernel(fn): ...@@ -213,7 +218,7 @@ def set_hsa_kernel(fn):
# Mark kernels # Mark kernels
ocl_kernels = mod.get_or_insert_named_metadata("opencl.kernels") ocl_kernels = mod.get_or_insert_named_metadata("opencl.kernels")
ocl_kernels.add(lc.MetaData.get(mod, [fn, ocl_kernels.add(ir.Module.add_metadata(mod, [fn,
gen_arg_addrspace_md(fn), gen_arg_addrspace_md(fn),
gen_arg_access_qual_md(fn), gen_arg_access_qual_md(fn),
gen_arg_type(fn), gen_arg_type(fn),
...@@ -221,16 +226,16 @@ def set_hsa_kernel(fn): ...@@ -221,16 +226,16 @@ def set_hsa_kernel(fn):
gen_arg_base_type(fn)])) gen_arg_base_type(fn)]))
# SPIR version 2.0 # SPIR version 2.0
make_constant = lambda x: lc.Constant.int(lc.Type.int(), x) make_constant = lambda x: ir.Constant(ir.IntType(), x)
spir_version_constant = [make_constant(x) for x in SPIR_VERSION] spir_version_constant = [make_constant(x) for x in SPIR_VERSION]
spir_version = mod.get_or_insert_named_metadata("opencl.spir.version") spir_version = mod.get_or_insert_named_metadata("opencl.spir.version")
if not spir_version.operands: if not spir_version.operands:
spir_version.add(lc.MetaData.get(mod, spir_version_constant)) spir_version.add(ir.Module.add_metadata(mod, spir_version_constant))
ocl_version = mod.get_or_insert_named_metadata("opencl.ocl.version") ocl_version = mod.get_or_insert_named_metadata("opencl.ocl.version")
if not ocl_version.operands: if not ocl_version.operands:
ocl_version.add(lc.MetaData.get(mod, spir_version_constant)) ocl_version.add(ir.Module.add_metadata(mod, spir_version_constant))
## The following metadata does not seem to be necessary ## The following metadata does not seem to be necessary
# Other metadata # Other metadata
...@@ -259,9 +264,9 @@ def gen_arg_addrspace_md(fn): ...@@ -259,9 +264,9 @@ def gen_arg_addrspace_md(fn):
else: else:
codes.append(SPIR_PRIVATE_ADDRSPACE) codes.append(SPIR_PRIVATE_ADDRSPACE)
consts = [lc.Constant.int(lc.Type.int(), x) for x in codes] consts = [ir.Constant(ir.IntType(), x) for x in codes]
name = lc.MetaDataString.get(mod, "kernel_arg_addr_space") name = ir.MetaDataString(mod, "kernel_arg_addr_space")
return lc.MetaData.get(mod, [name] + consts) return ir.Module.add_metadata(mod, [name] + consts)
def gen_arg_access_qual_md(fn): def gen_arg_access_qual_md(fn):
...@@ -269,9 +274,9 @@ def gen_arg_access_qual_md(fn): ...@@ -269,9 +274,9 @@ def gen_arg_access_qual_md(fn):
Generate kernel_arg_access_qual metadata Generate kernel_arg_access_qual metadata
""" """
mod = fn.module mod = fn.module
consts = [lc.MetaDataString.get(mod, "none")] * len(fn.args) consts = [ir.MetaDataString(mod, "none")] * len(fn.args)
name = lc.MetaDataString.get(mod, "kernel_arg_access_qual") name = ir.MetaDataString(mod, "kernel_arg_access_qual")
return lc.MetaData.get(mod, [name] + consts) return ir.Module.add_metadata(mod, [name] + consts)
def gen_arg_type(fn): def gen_arg_type(fn):
...@@ -280,9 +285,9 @@ def gen_arg_type(fn): ...@@ -280,9 +285,9 @@ def gen_arg_type(fn):
""" """
mod = fn.module mod = fn.module
fnty = fn.type.pointee fnty = fn.type.pointee
consts = [lc.MetaDataString.get(mod, str(a)) for a in fnty.args] consts = [ir.MetaDataString(mod, str(a)) for a in fnty.args]
name = lc.MetaDataString.get(mod, "kernel_arg_type") name = ir.MetaDataString(mod, "kernel_arg_type")
return lc.MetaData.get(mod, [name] + consts) return ir.Module.add_metadata(mod, [name] + consts)
def gen_arg_type_qual(fn): def gen_arg_type_qual(fn):
...@@ -291,9 +296,9 @@ def gen_arg_type_qual(fn): ...@@ -291,9 +296,9 @@ def gen_arg_type_qual(fn):
""" """
mod = fn.module mod = fn.module
fnty = fn.type.pointee fnty = fn.type.pointee
consts = [lc.MetaDataString.get(mod, "") for _ in fnty.args] consts = [ir.MetaDataString(mod, "") for _ in fnty.args]
name = lc.MetaDataString.get(mod, "kernel_arg_type_qual") name = ir.MetaDataString(mod, "kernel_arg_type_qual")
return lc.MetaData.get(mod, [name] + consts) return ir.Module.add_metadata(mod, [name] + consts)
def gen_arg_base_type(fn): def gen_arg_base_type(fn):
...@@ -302,9 +307,9 @@ def gen_arg_base_type(fn): ...@@ -302,9 +307,9 @@ def gen_arg_base_type(fn):
""" """
mod = fn.module mod = fn.module
fnty = fn.type.pointee fnty = fn.type.pointee
consts = [lc.MetaDataString.get(mod, str(a)) for a in fnty.args] consts = [ir.MetaDataString(mod, str(a)) for a in fnty.args]
name = lc.MetaDataString.get(mod, "kernel_arg_base_type") name = ir.MetaDataString(mod, "kernel_arg_base_type")
return lc.MetaData.get(mod, [name] + consts) return ir.Module.add_metadata(mod, [name] + consts)
class HSACallConv(MinimalCallConv): class HSACallConv(MinimalCallConv):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment