import re

# from llvmlite.llvmpy import core as lc
# from llvmlite import ir as llvmir
from llvmlite import ir
from llvmlite import binding as ll

from numba.core import typing, types, utils, datamodel, cgutils
# from numba.core.utils import cached_property
from functools import cached_property
from numba.core.base import BaseContext
from numba.core.callconv import MinimalCallConv
from numba.roc import codegen
from .hlc import DATALAYOUT

# Calling-convention strings assigned to the `calling_convention` attribute
# of IR functions in this module.
# Used for kernel entry points (see set_hsa_kernel).
CC_SPIR_KERNEL = "spir_kernel"
# Used for device functions and for the callee declaration inside the kernel
# wrapper module.  NOTE(review): this is the empty string, which presumably
# means "no explicit calling convention" -- confirm against llvmlite.
CC_SPIR_FUNC = ""


# -----------------------------------------------------------------------------
# Typing


class HSATypingContext(typing.BaseContext):
    """Typing context that adds the HSA and math declaration registries."""

    def load_additional_registries(self):
        """Install the registries from ``hsadecl`` and ``mathdecl``."""
        # Imported here rather than at module level, as in the original code.
        from . import hsadecl, mathdecl

        for decl_module in (hsadecl, mathdecl):
            self.install_registry(decl_module.registry)


# -----------------------------------------------------------------------------
# Implementation

# Despite its name, this pattern matches the characters that are NOT valid in
# a mangled symbol: anything other than an ASCII letter or digit
# (case-insensitive).  The mangler substitutes every match.
VALID_CHARS = re.compile(r'[^a-z0-9]', re.IGNORECASE)


# Address spaces, listed in numeric order
SPIR_GENERIC_ADDRSPACE = 0
SPIR_GLOBAL_ADDRSPACE = 1
SPIR_REGION_ADDRSPACE = 2
SPIR_LOCAL_ADDRSPACE = 3
SPIR_CONSTANT_ADDRSPACE = 4
SPIR_PRIVATE_ADDRSPACE = 5
SPIR_CONSTANT_32BIT_ADDRSPACE = 6

# (major, minor) written into the module-level version metadata
SPIR_VERSION = (2, 0)


class GenericPointerModel(datamodel.PrimitiveModel):
    """Data model whose backend type is a pointer in the generic address space."""

    def __init__(self, dmm, fe_type):
        """
        Build the backend pointer type for *fe_type*.

        The pointee type is the data type that *dmm* reports for
        ``fe_type.dtype``; the pointer is placed in SPIR_GENERIC_ADDRSPACE.
        """
        pointee_model = dmm.lookup(fe_type.dtype)
        pointee_type = pointee_model.get_data_type()
        be_type = pointee_type.as_pointer(SPIR_GENERIC_ADDRSPACE)
        super().__init__(dmm, fe_type, be_type)


def _init_data_model_manager():
    """
    Return a copy of the default data model manager in which
    ``types.CPointer`` is handled by GenericPointerModel.
    """
    manager = datamodel.default_manager.copy()
    manager.register(types.CPointer, GenericPointerModel)
    return manager


# Module-level manager, built once at import time.
hsa_data_model_manager = _init_data_model_manager()


class HSATargetContext(BaseContext):
    """
    Target (lowering) context for the HSA/ROCm backend.

    It owns the HSA code generator, supplies the HSA data model manager and
    calling convention, mangles device-function names, and builds the kernel
    wrapper that adapts a compiled function to a SPIR kernel entry point.
    """
    implement_powi_as_math_call = True
    generic_addrspace = SPIR_GENERIC_ADDRSPACE

    def __init__(self, typingctx, target='ROCm'):
        super().__init__(typingctx, target)

    def init(self):
        """Set up the code generator, target data and data model manager."""
        self._internal_codegen = codegen.JITHSACodegen("numba.hsa.jit")
        self._target_data = ll.create_target_data(DATALAYOUT)
        # Override data model manager
        self.data_model_manager = hsa_data_model_manager

    def load_additional_registries(self):
        """Register the lowering implementations from hsaimpl and mathimpl."""
        from . import hsaimpl, mathimpl

        self.insert_func_defn(hsaimpl.registry.functions)
        self.insert_func_defn(mathimpl.registry.functions)

    # sugon: adapt for numba-0.58
    # Overrides
    def create_module(self, name):
        """Return a new empty module called *name* from the HSA codegen."""
        return self._internal_codegen._create_empty_module(name)

    @cached_property
    def call_conv(self):
        """The HSACallConv bound to this context (created on first access)."""
        return HSACallConv(self)

    def codegen(self):
        """Return the code generator created in init()."""
        return self._internal_codegen

    @property
    def target_data(self):
        """The target data object created from DATALAYOUT in init()."""
        return self._target_data

    def mangler(self, name, argtypes, *, abi_tags=(), uid=None):
        """
        Return the device-function symbol for *name* called with *argtypes*.

        The qualified name is *name*, a dot, and the dot-joined ``str()`` of
        each argument type.  Every character matched by VALID_CHARS is
        replaced by ``_<HEX>_`` where ``<HEX>`` is its code point in
        upper-case hexadecimal, and the result is prefixed with
        ``hsapy_devfn_``.

        *abi_tags* and *uid* are accepted for interface compatibility but do
        not contribute to the mangled name.
        """
        def repl(m):
            ch = m.group(0)
            return "_%X_" % ord(ch)

        qualified = name + '.' + '.'.join(str(a) for a in argtypes)
        mangled = VALID_CHARS.sub(repl, qualified)
        return 'hsapy_devfn_' + mangled

    def prepare_hsa_kernel(self, func, argtypes):
        """
        Turn the compiled function *func* into a launchable kernel.

        Sets ``linkonce_odr`` linkage on *func*, sets the data layout of its
        module to DATALAYOUT, and returns the wrapper function produced by
        generate_kernel_wrapper().
        """
        module = func.module
        func.linkage = 'linkonce_odr'

        module.data_layout = DATALAYOUT
        return self.generate_kernel_wrapper(func, argtypes)

    def mark_hsa_device(self, func):
        """
        Mark *func* as a device function and return it.

        The function gets the CC_SPIR_FUNC calling convention and
        ``linkonce_odr`` linkage.
        """
        # Adapt to SPIR
        func.calling_convention = CC_SPIR_FUNC
        func.linkage = 'linkonce_odr'
        return func

    def generate_kernel_wrapper(self, func, argtypes):
        """
        Build a void kernel wrapper around *func* and link it into its module.

        The wrapper (named ``hsaPy_<func.name>``) takes the packed arguments
        of *argtypes*, with every pointer argument moved to
        SPIR_GLOBAL_ADDRSPACE.  Inside the wrapper each such pointer is cast
        back to its original address space before *func* is called through
        the context's calling convention.  The wrapper is built in a
        temporary module, marked as a kernel with set_hsa_kernel(), and then
        linked into ``func.module``.

        Returns the wrapper function looked up in ``func.module`` after
        linking.
        """
        module = func.module
        arginfo = self.get_arg_packer(argtypes)

        def sub_gen_with_global(lty):
            # Returns (type seen by the kernel, original address space).
            # The second item is None for non-pointer arguments, which need
            # no cast inside the wrapper.
            if isinstance(lty, ir.PointerType):
                return (lty.pointee.as_pointer(SPIR_GLOBAL_ADDRSPACE),
                        lty.addrspace)
            return lty, None

        if arginfo.argument_types:
            llargtys, changed = zip(*map(sub_gen_with_global,
                                         arginfo.argument_types))
        else:
            # zip(*[]) yields nothing to unpack, so handle "no arguments"
            # explicitly.
            llargtys = changed = ()
        wrapperfnty = ir.FunctionType(ir.VoidType(), llargtys)

        wrapper_module = self.create_module("hsa.kernel.wrapper")
        wrappername = 'hsaPy_{name}'.format(name=func.name)

        # Declare the compiled function inside the wrapper module so the
        # wrapper can call it: i32 status return, a return-value pointer as
        # the first parameter, then the packed arguments.
        argtys = list(arginfo.argument_types)
        fnty = ir.FunctionType(ir.IntType(32),
                               [self.call_conv.get_return_type(
                                   types.pyobject)] + argtys)
        callee_decl = ir.Function(wrapper_module, fnty, name=func.name)
        callee_decl.calling_convention = CC_SPIR_FUNC

        wrapper = ir.Function(wrapper_module, wrapperfnty, name=wrappername)

        builder = ir.IRBuilder(wrapper.append_basic_block(''))

        # Adjust address space of each kernel argument
        fixed_args = []
        for av, adrsp in zip(wrapper.args, changed):
            if adrsp is not None:
                fixed_args.append(self.addrspacecast(builder, av, adrsp))
            else:
                fixed_args.append(av)

        callargs = arginfo.from_arguments(builder, fixed_args)

        # XXX handle error status (the returned status is currently ignored)
        self.call_conv.call_function(builder, callee_decl, types.void,
                                     argtypes, callargs)
        builder.ret_void()

        set_hsa_kernel(wrapper)

        # Link
        module.link_in(ll.parse_assembly(str(wrapper_module)))
        # To enable inlining which is essential because addrspacecast 1->0 is
        # illegal.  Inlining will optimize the addrspacecast out.
        # The linkage is set on the function inside the linked module; the
        # declaration in wrapper_module was already serialized above, so
        # changing it would have no effect.
        wrapper = module.get_function(wrapper.name)
        module.get_function(callee_decl.name).linkage = 'internal'
        return wrapper

    def declare_function(self, module, fndesc):
        """
        Declare *fndesc* in *module* and return the declaration.

        Functions whose LLVM name starts with ``hsapy_devfn`` (the prefix
        produced by mangler()) additionally get the CC_SPIR_FUNC calling
        convention.
        """
        ret = super().declare_function(module, fndesc)
        # XXX: Refactor fndesc instead of this special case
        if fndesc.llvm_func_name.startswith('hsapy_devfn'):
            ret.calling_convention = CC_SPIR_FUNC
        return ret

    def make_constant_array(self, builder, typ, ary):
        """
        Constant arrays are not supported by this target.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    def addrspacecast(self, builder, src, addrspace):
        """
        Emit an addrspacecast of pointer *src* to *addrspace*.

        The pointee type is kept; only the address space changes.  Returns
        the value produced by the cast.
        """
        ptras = ir.PointerType(src.type.pointee, addrspace=addrspace)
        return builder.addrspacecast(src, ptras)


def set_hsa_kernel(fn):
    """
    Ensure `fn` is usable as a SPIR kernel.

    - Set the SPIR kernel calling convention on `fn`.
    - Register `fn` and its per-argument metadata under ``opencl.kernels``.
    - Record SPIR_VERSION under both ``opencl.spir.version`` and
      ``opencl.ocl.version``, unless that named metadata already has an
      operand.

    `fn` and the module that owns it are modified in place; nothing is
    returned.
    """
    mod = fn.module

    # Set SPIR kernel calling convention
    fn.calling_convention = CC_SPIR_KERNEL

    # Mark kernels
    ocl_kernels = cgutils.get_or_insert_named_metadata(mod, 'opencl.kernels')
    ocl_kernels.add(mod.add_metadata([fn,
                                      gen_arg_addrspace_md(fn),
                                      gen_arg_access_qual_md(fn),
                                      gen_arg_type(fn),
                                      gen_arg_type_qual(fn),
                                      gen_arg_base_type(fn)]))

    # SPIR version 2.0, also used as the OpenCL version
    i32 = ir.IntType(32)
    version_operands = [ir.Constant(i32, part) for part in SPIR_VERSION]

    # Each version node is written only once per module, so marking several
    # kernels in the same module does not duplicate it.
    for md_name in ('opencl.spir.version', 'opencl.ocl.version'):
        version_md = cgutils.get_or_insert_named_metadata(mod, md_name)
        if not version_md.operands:
            version_md.add(mod.add_metadata(version_operands))

    # Other named metadata (opencl.used.extensions,
    # opencl.used.optional.core.features, opencl.compiler.options) is not
    # emitted; it does not seem to be necessary.


def gen_arg_addrspace_md(fn):
    """
    Build the ``kernel_arg_addr_space`` metadata node for `fn`.

    Pointer arguments are reported as SPIR_GLOBAL_ADDRSPACE and every other
    argument as SPIR_PRIVATE_ADDRSPACE, each as a 32-bit integer constant.
    """
    mod = fn.module
    arg_types = fn.type.pointee.args
    spaces = [
        SPIR_GLOBAL_ADDRSPACE if cgutils.is_pointer(argty)
        else SPIR_PRIVATE_ADDRSPACE
        for argty in arg_types
    ]

    operands = [ir.Constant(ir.IntType(32), space) for space in spaces]
    tag = ir.MetaDataString(mod, "kernel_arg_addr_space")
    return ir.Module.add_metadata(mod, [tag, *operands])


def gen_arg_access_qual_md(fn):
    """
    Build the ``kernel_arg_access_qual`` metadata node for `fn`.

    Every argument of `fn` gets the access qualifier ``"none"``.
    """
    mod = fn.module
    # One shared metadata string is repeated once per argument.
    none_qual = ir.MetaDataString(mod, "none")
    operands = [none_qual for _ in range(len(fn.args))]
    tag = ir.MetaDataString(mod, "kernel_arg_access_qual")
    return ir.Module.add_metadata(mod, [tag, *operands])


def gen_arg_type(fn):
    """
    Build the ``kernel_arg_type`` metadata node for `fn`.

    Each argument contributes the string form of its LLVM type.
    """
    mod = fn.module
    type_names = []
    for argty in fn.type.pointee.args:
        type_names.append(ir.MetaDataString(mod, str(argty)))
    tag = ir.MetaDataString(mod, "kernel_arg_type")
    return ir.Module.add_metadata(mod, [tag, *type_names])


def gen_arg_type_qual(fn):
    """
    Build the ``kernel_arg_type_qual`` metadata node for `fn`.

    Every argument gets an empty type-qualifier string.
    """
    mod = fn.module
    arg_count = len(fn.type.pointee.args)
    # A separate metadata string is created for each argument.
    qualifiers = [ir.MetaDataString(mod, "") for _ in range(arg_count)]
    tag = ir.MetaDataString(mod, "kernel_arg_type_qual")
    return ir.Module.add_metadata(mod, [tag, *qualifiers])


def gen_arg_base_type(fn):
    """
    Build the ``kernel_arg_base_type`` metadata node for `fn`.

    Each argument contributes the string form of its LLVM type.
    """
    mod = fn.module
    base_names = []
    for argty in fn.type.pointee.args:
        base_names.append(ir.MetaDataString(mod, str(argty)))
    tag = ir.MetaDataString(mod, "kernel_arg_base_type")
    return ir.Module.add_metadata(mod, [tag, *base_names])


class HSACallConv(MinimalCallConv):
    """Calling convention used for functions compiled by the HSA target."""

    def call_function(self, builder, callee, resty, argtys, args, env=None):
        """
        Emit a call to the Numba-compiled *callee*.

        The callee's first parameter is a pointer that receives the return
        value.  A null-initialized stack slot of the pointee type is passed
        there, followed by the packed form of *args*.

        Returns a ``(status, value)`` pair: the decoded return status and the
        value loaded from the slot, converted for *resty*.  *env* must be
        None.
        """
        assert env is None
        ret_lltype = callee.args[0].type.pointee
        ret_slot = cgutils.alloca_once(builder, ret_lltype)
        # Start from a null value so the slot holds a defined value even if
        # the callee does not store to it.
        builder.store(cgutils.get_null_value(ret_lltype), ret_slot)

        packer = self.context.get_arg_packer(argtys)
        packed_args = packer.as_arguments(builder, args)
        status_code = builder.call(callee, [ret_slot, *packed_args])
        status = self._get_return_status(builder, status_code)
        loaded = builder.load(ret_slot)
        return status, self.context.get_returned_value(builder, resty, loaded)
