Commit 3e5f428e authored by dugupeiwen's avatar dugupeiwen
Browse files

Remove use of llvmlite.llvmpy for 0.58

parent 5be111ee
from llvmlite import binding as ll
from llvmlite.llvmpy import core as lc
# from llvmlite.llvmpy import core as lc
import llvmlite.ir as llvmir
from numba.core import utils
from numba.core.codegen import BaseCPUCodegen, CodeLibrary
from numba.core.codegen import Codegen, CodeLibrary, CPUCodeLibrary
from .hlc import DATALAYOUT, TRIPLE, hlc
class HSACodeLibrary(CodeLibrary):
class HSACodeLibrary(CPUCodeLibrary):
def _optimize_functions(self, ll_module):
pass
......@@ -25,17 +25,55 @@ class HSACodeLibrary(CodeLibrary):
return str(out.hsail)
class JITHSACodegen(BaseCPUCodegen):
# class JITHSACodegen(Codegen):
# _library_class = HSACodeLibrary
# def _init(self, llvm_module):
# assert list(llvm_module.global_variables) == [], "Module isn't empty"
# self._data_layout = DATALAYOUT[utils.MACHINE_BITS]
# self._target_data = ll.create_target_data(self._data_layout)
# def _create_empty_module(self, name):
# ir_module = llvmir.Module(name)
# ir_module.triple = TRIPLE
# return ir_module
# def _module_pass_manager(self):
# raise NotImplementedError
# def _function_pass_manager(self, llvm_module):
# raise NotImplementedError
# def _add_module(self, module):
# pass
class JITHSACodegen(Codegen):
_library_class = HSACodeLibrary
def __init__(self, module_name):
# initialize_llvm()
ll.initialize()
ll.initialize_native_target()
ll.initialize_native_asmprinter()
self._data_layout = None
self._llvm_module = ll.parse_assembly(
str(self._create_empty_module(module_name)))
self._llvm_module.name = "global_codegen_module"
# self._rtlinker = RuntimeLinker()
self._init(self._llvm_module)
def _init(self, llvm_module):
assert list(llvm_module.global_variables) == [], "Module isn't empty"
self._data_layout = DATALAYOUT[utils.MACHINE_BITS]
self._target_data = ll.create_target_data(self._data_layout)
def _create_empty_module(self, name):
ir_module = lc.Module(name)
ir_module = llvmir.Module(name)
ir_module.triple = TRIPLE
if self._data_layout:
ir_module.data_layout = self._data_layout
return ir_module
def _module_pass_manager(self):
......
......@@ -26,10 +26,10 @@ def compile_hsa(pyfunc, return_type, args, debug):
# TODO handle debug flag
flags = compiler.Flags()
# Do not compile (generate native code), just lower (to LLVM)
flags.set('no_compile')
flags.set('no_cpython_wrapper')
flags.set('no_cfunc_wrapper')
flags.unset('nrt')
flags.no_compile = True
flags.no_cpython_wrapper = True
flags.no_cfunc_wrapper = True
flags.nrt = False
# Run compilation pipeline
cres = compiler.compile_extra(typingctx=typingctx,
targetctx=targetctx,
......
import numpy as np
from numba.np.ufunc.deviceufunc import (UFuncMechanism, GenerializedUFunc,
# from numba.np.ufunc.deviceufunc import (UFuncMechanism, GenerializedUFunc,
# GUFuncCallSteps)
from numba.np.ufunc.deviceufunc import (UFuncMechanism, GeneralizedUFunc,
GUFuncCallSteps)
from numba.roc.hsadrv.driver import dgpu_present
import numba.roc.hsadrv.devicearray as devicearray
......@@ -119,7 +121,7 @@ class _HsaGUFuncCallSteps(GUFuncCallSteps):
kernel.configure(nelem, min(nelem, 64))(*args)
class HSAGenerializedUFunc(GenerializedUFunc):
class HSAGenerializedUFunc(GeneralizedUFunc):
@property
def _call_steps(self):
return _HsaGUFuncCallSteps
......
import operator
from functools import reduce
from llvmlite.llvmpy.core import Type
import llvmlite.llvmpy.core as lc
# from llvmlite.llvmpy.core import Type
# import llvmlite.llvmpy.core as lc
import llvmlite.binding as ll
from llvmlite import ir
from numba import roc
from numba.core.imputils import Registry
from numba.core import types, cgutils
from numba.core.itanium_mangler import mangle_c, mangle, mangle_type
# from numba.core.itanium_mangler import mangle_c, mangle, mangle_type
from numba.core.itanium_mangler import mangle, mangle_type
from numba.core.typing.npydecl import parse_dtype
from numba.roc import target
from numba.roc import stubs
......@@ -19,13 +20,13 @@ from numba.roc import enums
registry = Registry()
lower = registry.lower
_void_value = lc.Constant.null(lc.Type.pointer(lc.Type.int(8)))
_void_value = ir.Constant(ir.PointerType(ir.IntType(8)), None)
# -----------------------------------------------------------------------------
def _declare_function(context, builder, name, sig, cargs,
mangler=mangle_c):
mangler=mangle):
"""Insert declaration for a opencl builtin function.
Uses the Itanium mangler.
......@@ -50,11 +51,11 @@ def _declare_function(context, builder, name, sig, cargs,
"""
mod = builder.module
if sig.return_type == types.void:
llretty = lc.Type.void()
llretty = ir.VoidType()
else:
llretty = context.get_value_type(sig.return_type)
llargs = [context.get_value_type(t) for t in sig.args]
fnty = Type.function(llretty, llargs)
fnty = ir.FunctionType(llretty, llargs)
mangled = mangler(name, cargs)
fn = mod.get_or_insert_function(fnty, mangled)
fn.calling_convention = target.CC_SPIR_FUNC
......@@ -154,7 +155,7 @@ def mem_fence_impl(context, builder, sig, args):
@lower(stubs.wavebarrier)
def wavebarrier_impl(context, builder, sig, args):
assert not args
fnty = Type.function(Type.void(), [])
fnty = ir.FunctionType(ir.VoidType(), [])
fn = builder.module.declare_intrinsic('llvm.amdgcn.wave.barrier', fnty=fnty)
builder.call(fn, [])
return _void_value
......@@ -166,12 +167,12 @@ def activelanepermute_wavewidth_impl(context, builder, sig, args):
assert sig.args[0] == sig.args[2]
elem_type = sig.args[0]
bitwidth = elem_type.bitwidth
intbitwidth = Type.int(bitwidth)
i32 = Type.int(32)
i1 = Type.int(1)
intbitwidth = ir.IntType(bitwidth)
i32 = ir.IntType(32)
i1 = ir.IntType(1)
name = "__hsail_activelanepermute_wavewidth_b{0}".format(bitwidth)
fnty = Type.function(intbitwidth, [intbitwidth, i32, intbitwidth, i1])
fnty = ir.FunctionType(intbitwidth, [intbitwidth, i32, intbitwidth, i1])
fn = builder.module.get_or_insert_function(fnty, name=name)
fn.calling_convention = target.CC_SPIR_FUNC
......@@ -188,14 +189,14 @@ def _gen_ds_permute(intrinsic_name):
"""
assert sig.return_type == sig.args[1]
idx, src = args
i32 = Type.int(32)
fnty = Type.function(i32, [i32, i32])
i32 = ir.IntType(32)
fnty = ir.FunctionType(i32, [i32, i32])
fn = builder.module.declare_intrinsic(intrinsic_name, fnty=fnty)
# the args are byte addressable, VGPRs are 4 wide so mul idx by 4
# the idx might be an int64, this is ok to trunc to int32 as
# wavefront_size is never likely overflow an int32
idx = builder.trunc(idx, i32)
four = lc.Constant.int(i32, 4)
four = ir.Constant(i32, 4)
idx = builder.mul(idx, four)
# bit cast is so float32 works as packed i32, the return casts back
result = builder.call(fn, (idx, builder.bitcast(src, i32)))
......@@ -258,7 +259,7 @@ def hsail_smem_alloc_array_tuple(context, builder, sig, args):
def _generic_array(context, builder, shape, dtype, symbol_name, addrspace):
elemcount = reduce(operator.mul, shape, 1)
lldtype = context.get_data_type(dtype)
laryty = Type.array(lldtype, elemcount)
laryty = ir.ArrayType(lldtype, elemcount)
if addrspace == target.SPIR_LOCAL_ADDRSPACE:
lmod = builder.module
......@@ -269,7 +270,7 @@ def _generic_array(context, builder, shape, dtype, symbol_name, addrspace):
if elemcount <= 0:
raise ValueError("array length <= 0")
else:
gvmem.linkage = lc.LINKAGE_INTERNAL
gvmem.linkage = 'internal'
if dtype not in types.number_domain:
raise TypeError("unsupported type: %s" % dtype)
......
import re
from llvmlite.llvmpy import core as lc
from llvmlite import ir as llvmir
# from llvmlite.llvmpy import core as lc
# from llvmlite import ir as llvmir
from llvmlite import ir
from llvmlite import binding as ll
from numba.core import typing, types, utils, datamodel, cgutils
from numba.core.utils import cached_property
# from numba.core.utils import cached_property
from functools import cached_property
from numba.core.base import BaseContext
from numba.core.callconv import MinimalCallConv
from numba.roc import codegen
......@@ -65,6 +67,9 @@ class HSATargetContext(BaseContext):
implement_powi_as_math_call = True
generic_addrspace = SPIR_GENERIC_ADDRSPACE
def __init__(self, typingctx, target='ROCm'):
super().__init__(typingctx, target)
def init(self):
self._internal_codegen = codegen.JITHSACodegen("numba.hsa.jit")
self._target_data = \
......@@ -89,7 +94,7 @@ class HSATargetContext(BaseContext):
def target_data(self):
return self._target_data
def mangler(self, name, argtypes):
def mangler(self, name, argtypes, *, abi_tags=(), uid=None):
def repl(m):
ch = m.group(0)
return "_%X_" % ord(ch)
......@@ -119,7 +124,7 @@ class HSATargetContext(BaseContext):
arginfo = self.get_arg_packer(argtypes)
def sub_gen_with_global(lty):
if isinstance(lty, llvmir.PointerType):
if isinstance(lty, ir.PointerType):
return (lty.pointee.as_pointer(SPIR_GLOBAL_ADDRSPACE),
lty.addrspace)
return lty, None
......@@ -129,13 +134,13 @@ class HSATargetContext(BaseContext):
arginfo.argument_types))
else:
llargtys = changed = ()
wrapperfnty = lc.Type.function(lc.Type.void(), llargtys)
wrapperfnty = ir.FunctionType(ir.VoidType(), llargtys)
wrapper_module = self.create_module("hsa.kernel.wrapper")
wrappername = 'hsaPy_{name}'.format(name=func.name)
argtys = list(arginfo.argument_types)
fnty = lc.Type.function(lc.Type.int(),
fnty = ir.FunctionType(ir.IntType(),
[self.call_conv.get_return_type(
types.pyobject)] + argtys)
......@@ -144,7 +149,7 @@ class HSATargetContext(BaseContext):
wrapper = wrapper_module.add_function(wrapperfnty, name=wrappername)
builder = lc.Builder(wrapper.append_basic_block(''))
builder = ir.IRBuilder(wrapper.append_basic_block(''))
# Adjust address space of each kernel argument
fixed_args = []
......@@ -193,7 +198,7 @@ class HSATargetContext(BaseContext):
"""
Handle addrspacecast
"""
ptras = llvmir.PointerType(src.type.pointee, addrspace=addrspace)
ptras = ir.PointerType(src.type.pointee, addrspace=addrspace)
return builder.addrspacecast(src, ptras)
......@@ -213,7 +218,7 @@ def set_hsa_kernel(fn):
# Mark kernels
ocl_kernels = mod.get_or_insert_named_metadata("opencl.kernels")
ocl_kernels.add(lc.MetaData.get(mod, [fn,
ocl_kernels.add(ir.Module.add_metadata(mod, [fn,
gen_arg_addrspace_md(fn),
gen_arg_access_qual_md(fn),
gen_arg_type(fn),
......@@ -221,16 +226,16 @@ def set_hsa_kernel(fn):
gen_arg_base_type(fn)]))
# SPIR version 2.0
make_constant = lambda x: lc.Constant.int(lc.Type.int(), x)
make_constant = lambda x: ir.Constant(ir.IntType(), x)
spir_version_constant = [make_constant(x) for x in SPIR_VERSION]
spir_version = mod.get_or_insert_named_metadata("opencl.spir.version")
if not spir_version.operands:
spir_version.add(lc.MetaData.get(mod, spir_version_constant))
spir_version.add(ir.Module.add_metadata(mod, spir_version_constant))
ocl_version = mod.get_or_insert_named_metadata("opencl.ocl.version")
if not ocl_version.operands:
ocl_version.add(lc.MetaData.get(mod, spir_version_constant))
ocl_version.add(ir.Module.add_metadata(mod, spir_version_constant))
## The following metadata does not seem to be necessary
# Other metadata
......@@ -259,9 +264,9 @@ def gen_arg_addrspace_md(fn):
else:
codes.append(SPIR_PRIVATE_ADDRSPACE)
consts = [lc.Constant.int(lc.Type.int(), x) for x in codes]
name = lc.MetaDataString.get(mod, "kernel_arg_addr_space")
return lc.MetaData.get(mod, [name] + consts)
consts = [ir.Constant(ir.IntType(), x) for x in codes]
name = ir.MetaDataString(mod, "kernel_arg_addr_space")
return ir.Module.add_metadata(mod, [name] + consts)
def gen_arg_access_qual_md(fn):
......@@ -269,9 +274,9 @@ def gen_arg_access_qual_md(fn):
Generate kernel_arg_access_qual metadata
"""
mod = fn.module
consts = [lc.MetaDataString.get(mod, "none")] * len(fn.args)
name = lc.MetaDataString.get(mod, "kernel_arg_access_qual")
return lc.MetaData.get(mod, [name] + consts)
consts = [ir.MetaDataString(mod, "none")] * len(fn.args)
name = ir.MetaDataString(mod, "kernel_arg_access_qual")
return ir.Module.add_metadata(mod, [name] + consts)
def gen_arg_type(fn):
......@@ -280,9 +285,9 @@ def gen_arg_type(fn):
"""
mod = fn.module
fnty = fn.type.pointee
consts = [lc.MetaDataString.get(mod, str(a)) for a in fnty.args]
name = lc.MetaDataString.get(mod, "kernel_arg_type")
return lc.MetaData.get(mod, [name] + consts)
consts = [ir.MetaDataString(mod, str(a)) for a in fnty.args]
name = ir.MetaDataString(mod, "kernel_arg_type")
return ir.Module.add_metadata(mod, [name] + consts)
def gen_arg_type_qual(fn):
......@@ -291,9 +296,9 @@ def gen_arg_type_qual(fn):
"""
mod = fn.module
fnty = fn.type.pointee
consts = [lc.MetaDataString.get(mod, "") for _ in fnty.args]
name = lc.MetaDataString.get(mod, "kernel_arg_type_qual")
return lc.MetaData.get(mod, [name] + consts)
consts = [ir.MetaDataString(mod, "") for _ in fnty.args]
name = ir.MetaDataString(mod, "kernel_arg_type_qual")
return ir.Module.add_metadata(mod, [name] + consts)
def gen_arg_base_type(fn):
......@@ -302,9 +307,9 @@ def gen_arg_base_type(fn):
"""
mod = fn.module
fnty = fn.type.pointee
consts = [lc.MetaDataString.get(mod, str(a)) for a in fnty.args]
name = lc.MetaDataString.get(mod, "kernel_arg_base_type")
return lc.MetaData.get(mod, [name] + consts)
consts = [ir.MetaDataString(mod, str(a)) for a in fnty.args]
name = ir.MetaDataString(mod, "kernel_arg_base_type")
return ir.Module.add_metadata(mod, [name] + consts)
class HSACallConv(MinimalCallConv):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment