from __future__ import annotations
from typing import Union, Tuple, Optional, List, Dict, Iterable
from .exception import check, SeekException
from .utility import get_caller_assignments
import functools
import enum
import numpy as np


class GprType(enum.Enum):
    """Register file (kind) a Gpr belongs to.

    ``Special`` is used for registers referred to by name (vcc, exec, m0, ...);
    ``V``, ``S`` and ``A`` are the vector, scalar and accumulation register files
    (printed as VGpr / SGpr / Acc by ``Gpr.__repr__``).
    """
    Special = enum.auto()
    V = enum.auto()
    S = enum.auto()
    A = enum.auto()

    def is_special(self) -> bool:
        """True for named special registers."""
        return self is GprType.Special

    def is_vgpr(self) -> bool:
        """True for vector registers."""
        return self is GprType.V

    def is_sgpr(self) -> bool:
        """True for scalar registers."""
        return self is GprType.S

    def is_agpr(self) -> bool:
        """True for accumulation registers."""
        return self is GprType.A


class Align:
    """Alignment constraint on a register index: indices ``i`` with ``i % divisor == remainder``.

    Immutable and hashable; ``divisor > 0`` and ``0 <= remainder < divisor`` are enforced
    by the constructor.
    """
    __slots__ = ["__divisor", "__remainder"]

    def __init__(self, divisor: int, remainder: int = 0):
        check(isinstance(divisor, int))
        check(isinstance(remainder, int))
        check(0 < divisor, f"Invalid divisor: {divisor}")
        check(0 <= remainder < divisor, f"Invalid remainder: {remainder}")
        self.__divisor = divisor
        self.__remainder = remainder

    @property
    def divisor(self) -> int:
        return self.__divisor

    @property
    def remainder(self) -> int:
        return self.__remainder

    def offset_align(self, offset: int) -> Align:
        """Return the alignment satisfied by ``i + offset`` when ``i`` satisfies this one."""
        new_remainder = (self.remainder + offset) % self.divisor
        # Python's % with a positive divisor always lands in [0, divisor)
        assert 0 <= new_remainder < self.divisor
        return Align(self.divisor, new_remainder)

    def is_aligned(self, div: int, rem: int = 0) -> bool:
        """Return True if every index satisfying this alignment also satisfies ``i % div == rem``.

        That holds exactly when ``div`` divides ``divisor`` and ``remainder % div == rem``.
        ``rem`` must lie in ``[0, div)`` (which also requires ``div > 0``); otherwise the
        query is rejected via ``check``.
        """
        # Validated with check() rather than assert so the guard also holds under `python -O`
        # (where a div <= 0 would otherwise reach the `%` below)
        check(0 <= rem < div, f"Invalid alignment query: div={div}, rem={rem}")
        return self.divisor % div == 0 and self.remainder % div == rem

    def __hash__(self):
        return hash((self.divisor, self.remainder))

    def __eq__(self, other):
        if not isinstance(other, Align):
            return NotImplemented
        return (self.divisor, self.remainder) == (other.divisor, other.remainder)

    def __repr__(self):
        return f"Align({self.divisor}, {self.remainder})"


@functools.total_ordering
class Gpr:
    """A register operand: a base register allocation, or a view into one.

    * A *base* Gpr is created by ``Gpr._create_gpr`` (e.g. through ``new``). It owns a unique,
      creation-ordered id, a ``GprType`` and, for non-special registers, an ``Align``.
    * A *view* covers ``count`` consecutive registers of a base Gpr starting at
      ``base_offset`` (``x[i]``, ``x[a:b]``, ``x.alias()``) and may carry neg/abs/sext
      annotations (``-x``, ``abs(x)``, ``x.sext()``).

    Equality, hashing and ordering use only (base id, base_offset, count); names and
    annotations are ignored. The constructor is private (guarded by ``__create_key``), and
    instances cannot be copied.
    """
    __slots__ = [
        "__is_view", "__name", "__count",
        "__gpr_rtype", "__gpr_id", "__gpr_align",
        "__view_base_gpr", "__view_base_offset", "__view_annotation",
    ]

    # Raw (unmangled) slot names of the attributes carried only by a base Gpr / only by a view
    __attrs_gpr = [s for s in __slots__ if s.startswith("__gpr_")]
    __attrs_view = [s for s in __slots__ if s.startswith("__view_")]

    # Bit flags stored in a view's annotation
    __ANNOTATION_NEG = 1 << 0
    __ANNOTATION_ABS = 1 << 1
    __ANNOTATION_SEXT = 1 << 2
    __ANNOTATION_ALL = (__ANNOTATION_NEG | __ANNOTATION_ABS | __ANNOTATION_SEXT)

    # Token that makes the constructor effectively private
    __create_key = object()
    # Last id handed out to a base Gpr; ids increase in creation order
    __id_counter = 0

    def __init__(self, create_key: object, is_view: bool, name: str, count: int, **kwargs):
        """Private constructor; use ``Gpr._create_gpr`` or the view-producing methods.

        ``kwargs`` must be exactly the ``__gpr_*`` attributes (base Gpr) or exactly the
        ``__view_*`` attributes (view). Keyword-argument names in a call are not
        private-name-mangled, so they arrive here as the raw strings listed in ``__slots__``
        and the mangled attribute name is built by hand below.
        """
        check(create_key is Gpr.__create_key, "Gpr constructor is private")
        self.__is_view = is_view  # type: bool
        self.__name = name  # type: str
        self.__count = count  # type: int

        if is_view:  # this is a view
            assert len(kwargs) == len(Gpr.__attrs_view)
            for key in kwargs:
                assert key in Gpr.__attrs_view, key
        else:  # this is a base Gpr
            assert len(kwargs) == len(Gpr.__attrs_gpr)
            for key in kwargs:
                assert key in Gpr.__attrs_gpr, key

        for key, value in kwargs.items():
            # Gpr's "__xxx" maps to "_Gpr__xxx" in object's attributes
            setattr(self, f"_{Gpr.__name__}{key}", value)

    @property
    def is_view(self) -> bool:
        return self.__is_view

    @property
    def base_gpr(self) -> Gpr:
        """The base Gpr this operand refers to (``self`` for a base Gpr)."""
        if self.is_view:  # this is a view
            return self.__view_base_gpr
        else:  # this is a base Gpr
            return self

    @property
    def base_offset(self) -> int:
        """Offset of this operand's first register within its base Gpr (0 for a base Gpr)."""
        if self.is_view:  # this is a view
            return self.__view_base_offset
        else:  # this is a base Gpr
            return 0

    @property
    def name(self) -> str:
        return self.__name

    @property
    def count(self) -> int:
        """Number of consecutive registers covered."""
        return self.__count

    @property
    def rtype(self) -> GprType:
        return self.base_gpr.__gpr_rtype

    @property
    def align(self) -> Optional[Align]:
        """Alignment of this operand's first register; None for special registers."""
        # This works for both base Gpr and views
        base_align = self.base_gpr.__gpr_align  # type: Optional[Align]
        if self.rtype.is_special():
            assert base_align is None
            return None
        else:
            assert isinstance(base_align, Align)
            return base_align.offset_align(self.base_offset)

    @property
    def __annotation(self) -> int:
        # Base Gprs never carry annotations
        if self.is_view:
            return self.__view_annotation
        else:
            return 0

    @property
    def is_neg(self):
        return bool(self.__annotation & Gpr.__ANNOTATION_NEG)

    @property
    def is_abs(self):
        return bool(self.__annotation & Gpr.__ANNOTATION_ABS)

    @property
    def is_sext(self):
        return bool(self.__annotation & Gpr.__ANNOTATION_SEXT)

    def __hash__(self):
        # Consistent with __eq__: name and annotations do not participate
        return hash((self.base_gpr.__gpr_id, self.base_offset, self.count))

    def __eq__(self, other) -> bool:
        if not isinstance(other, Gpr):
            return NotImplemented

        if self.base_gpr.__gpr_id != other.base_gpr.__gpr_id:
            return False
        # Ids are unique per base Gpr, so equal ids mean the same base object
        assert self.base_gpr is other.base_gpr

        return self.base_offset == other.base_offset and self.count == other.count

    def __lt__(self, other: Optional[Gpr]) -> bool:
        # Order by base Gpr creation order, then by (offset, count) within the same base
        if not isinstance(other, Gpr):
            return NotImplemented

        if self.base_gpr.__gpr_id != other.base_gpr.__gpr_id:
            return self.base_gpr.__gpr_id < other.base_gpr.__gpr_id
        assert self.base_gpr is other.base_gpr

        return (self.base_offset, self.count) < (other.base_offset, other.count)

    def __copy__(self):
        raise SeekException("Gpr is not copyable")

    def __deepcopy__(self, *args, **kwargs):
        raise SeekException("Gpr is not deeply copyable")

    @staticmethod
    def _check_and_get_name(name: str, rtype: Optional[GprType]) -> Tuple[str, GprType]:
        """Strip a ``v_``/``s_``/``a_`` prefix from ``name`` and reconcile it with ``rtype``.

        The prefix determines the rtype when ``rtype`` is None; a prefix that disagrees with
        an explicit ``rtype``, or no prefix and no ``rtype``, is rejected via ``check``.

        :return: ``(name_without_prefix, rtype)``
        """
        if rtype is not None:
            check(rtype in GprType.__members__.values(), f"Invalid Gpr rtype: {rtype}")

        if name.startswith("v_"):
            check(rtype is None or rtype == GprType.V, f"Conflict name {name!r} and rtype {rtype}")
            name = name[2:]
            rtype = GprType.V
        elif name.startswith("s_"):
            check(rtype is None or rtype == GprType.S, f"Conflict name {name!r} and rtype {rtype}")
            name = name[2:]
            rtype = GprType.S
        elif name.startswith("a_"):
            check(rtype is None or rtype == GprType.A, f"Conflict name {name!r} and rtype {rtype}")
            name = name[2:]
            rtype = GprType.A
        else:
            check(rtype is not None, f"Gpr rtype is not given, nor can we infer from its name {name!r}")
        return name, rtype

    @staticmethod
    def _create_gpr(name: str, count: int, rtype: Optional[GprType], align: Optional[Union[Align, Tuple[int, int], int]]) -> Gpr:
        """Create a new base Gpr with the next id.

        ``name`` may carry a ``v_``/``s_``/``a_`` prefix (stripped, see ``_check_and_get_name``).
        ``align`` may be None (no constraint, i.e. ``Align(1, 0)``), an int divisor,
        a ``(divisor, remainder)`` tuple or an ``Align``; special registers must pass None.
        """
        # Check name and rtype
        name, rtype = Gpr._check_and_get_name(name, rtype)

        # Check count
        check(count > 0, f"Invalid Gpr count: {count}")

        # Check the alignment if this is not a special Gpr
        if rtype.is_special():
            check(align is None, "Special Gpr can't have alignment requirement")
        else:
            if align is None:
                align = Align(1, 0)
            elif isinstance(align, int):
                align = Align(align, 0)
            elif isinstance(align, tuple):
                check(len(align) == 2)
                align = Align(align[0], align[1])
            else:
                check(isinstance(align, Align))

        # Increase counter before we assign its value to current Gpr
        Gpr.__id_counter += 1

        gpr = Gpr(Gpr.__create_key, is_view=False, name=name, count=count,
                  __gpr_rtype=rtype,
                  __gpr_id=Gpr.__id_counter,
                  __gpr_align=align)
        return gpr

    def __create_view(self, name: str, offset: int, count: int, annotation: int) -> Gpr:
        """Create a view of ``count`` registers starting ``offset`` registers into ``self``.

        ``name`` is used verbatim. Callers taking a user-supplied name (``alias``) strip the
        ``v_``/``s_``/``a_`` prefix themselves; derived names (subscripts, neg/abs/sext) are
        built from an already-stripped name and must not be parsed again. Re-parsing them
        used to reject valid operands, e.g. subscripting a V register named "s_tmp".
        """
        # Check offset and count
        check(offset >= 0, f"Invalid offset: {offset}")
        check(count > 0, f"Invalid Gpr count: {count}")
        check(offset + count <= self.count,
              f"Invalid offset and count combination: {offset} + {count} > {self.count}")

        # Check annotation
        check(annotation | Gpr.__ANNOTATION_ALL == Gpr.__ANNOTATION_ALL,
              f"Invalid annotation: 0b{annotation:b}")

        # Views always point at the base Gpr, so views of views stay one level deep
        gpr = Gpr(Gpr.__create_key, is_view=True, name=name, count=count,
                  __view_base_gpr=self.base_gpr,
                  __view_base_offset=self.base_offset + offset,
                  __view_annotation=annotation)
        return gpr

    def __getitem__(self, item: Union[int, slice]) -> Gpr:
        """Return a view of one register (``x[i]``) or an inclusive range (``x[a:b]``).

        Slices are inclusive on both ends: ``x[a:b]`` covers registers ``a..b``
        (count ``b - a + 1``), matching the ``v[a:b]`` spelling of ``to_physical_str``.
        Steps and open-ended slices are rejected, as is subscripting an annotated view.
        """
        check(not self.is_neg, "Can't subscribe a view of neg()")
        check(not self.is_abs, "Can't subscribe a view of abs()")
        check(not self.is_sext, "Can't subscribe a view of sext()")

        if isinstance(item, int):
            offset = item
            check(offset >= 0, f"Invalid offset value: {offset}")
            check(offset < self.count, f"Invalid offset value: {offset}")

            name = f"{self.name}[{item}]"
            return self.__create_view(name=name, offset=offset, count=1, annotation=self.__annotation)
        elif isinstance(item, slice):
            check(item.start is not None and isinstance(item.start, int), f"Invalid slice: {item}")
            check(item.stop is not None and isinstance(item.stop, int), f"Invalid slice: {item}")
            check(item.step is None, f"Invalid slice: {item}")
            check(item.stop >= item.start, f"Invalid slice: {item.stop} < {item.start}")

            offset = item.start
            count = item.stop - item.start + 1  # inclusive stop

            name = f"{self.name}[{item.start}:{item.stop}]"
            return self.__create_view(name=name, offset=offset, count=count, annotation=self.__annotation)
        else:
            raise SeekException(f"Invalid subscription index: {item}")

    def alias(self, name: Optional[str] = None) -> Gpr:
        """Return a full-range view of this operand under another name (annotations kept).

        When ``name`` is omitted it is taken from ``get_caller_assignments()``. A
        ``v_``/``s_``/``a_`` prefix is stripped and must agree with this operand's rtype.
        """
        if name is None:
            name = get_caller_assignments()
            if not isinstance(name, str):
                raise SeekException(f"Gpr alias name is not given, nor can we infer from the caller")
        # User-supplied name: strip its prefix and reject one that conflicts with our rtype
        name, _ = Gpr._check_and_get_name(name, self.rtype)
        return self.__create_view(name=name, offset=0, count=self.count, annotation=self.__annotation)

    def __neg__(self) -> Gpr:
        """``-x``: a view annotated with neg (allowed on an abs view, giving ``-|x|``)."""
        check(not self.is_sext, "Can't perform neg() on a view of sext()")
        check(not self.is_neg, "Can't perform neg() on a view of neg()")
        # self.is_abs may be True or False (both are OK)
        return self.__create_view(name=self.name, offset=0, count=self.count,
                                  annotation=(self.__annotation | Gpr.__ANNOTATION_NEG))

    def __abs__(self) -> Gpr:
        """``abs(x)``: a view annotated with abs; only allowed on an unannotated operand."""
        check(not self.is_sext, "Can't perform abs() on a view of sext()")
        check(not self.is_neg, "Can't perform abs() on a view of neg()")
        check(not self.is_abs, "Can't perform abs() on a view of abs()")
        return self.__create_view(name=self.name, offset=0, count=self.count,
                                  annotation=(self.__annotation | Gpr.__ANNOTATION_ABS))

    def sext(self) -> Gpr:
        """A view annotated with sext; only allowed on an unannotated operand."""
        check(not self.is_sext, "Can't perform sext() on a view of sext()")
        check(not self.is_neg, "Can't perform sext() on a view of neg()")
        check(not self.is_abs, "Can't perform sext() on a view of abs()")
        return self.__create_view(name=self.name, offset=0, count=self.count,
                                  annotation=(self.__annotation | Gpr.__ANNOTATION_SEXT))

    def __repr__(self):
        # Views whose name is the one a plain subscript / annotation would produce are shown
        # compactly ('@name'); renamed views (aliases) also show the base Gpr they refer to.
        str_align = ''
        if self.align is not None and self.align.divisor != 1:
            assert self.align.divisor > 1
            if self.align.remainder == 0:
                str_align = f', align={self.align.divisor}'
            else:
                str_align = f', align=({self.align.divisor},{self.align.remainder})'

        str_rtype = "SGpr" if self.rtype.is_sgpr() else \
                    "VGpr" if self.rtype.is_vgpr() else \
                    "Acc" if self.rtype.is_agpr() else \
                    "Spec"
        str_rtype = f"{str_rtype}{self.count}"

        if self.is_sext:
            str_rtype = f"{str_rtype}.sext"
        else:
            if self.is_abs:
                str_rtype = f"|{str_rtype}|"
            if self.is_neg:
                str_rtype = f"-{str_rtype}"

        if not self.is_view:
            return f'{str_rtype}({self.name!r}{str_align})'
        else:  # self.is_view
            if self.count == self.base_gpr.count:
                if self.name == self.base_gpr.name:
                    # This might be due to abs, neg, sext
                    return f"{str_rtype}({('@'+self.name)!r}{str_align})"
                else:  # self.name != self.base_gpr.name
                    # This might be an alias
                    return f"{str_rtype}({('@'+self.name)!r}, ref={self.base_gpr!r})"
            else:  # self.count != self.base_gpr.count
                assert self.count < self.base_gpr.count
                if self.count == 1:
                    if self.name == f"{self.base_gpr.name}[{self.base_offset}]":
                        return f"{str_rtype}({('@'+self.name)!r}{str_align})"
                    else:
                        return f"{str_rtype}({('@'+self.name)!r}, ref={self.base_gpr!r}[{self.base_offset}])"
                else:  # self.count > 1
                    assert self.count > 1
                    to_offset = self.base_offset + self.count - 1
                    if self.name == f"{self.base_gpr.name}[{self.base_offset}:{to_offset}]":
                        return f"{str_rtype}({('@'+self.name)!r}{str_align})"
                    else:
                        return f"{str_rtype}({('@'+self.name)!r}, ref={self.base_gpr!r}[{self.base_offset}:{to_offset}])"  # noqa: E501 (line too long)

    def to_physical_str(self, program):
        """Return the assembly spelling of this operand's register(s).

        Special registers are printed by name. Otherwise the base Gpr's index is looked up in
        ``program.forced_index`` (preferred) or ``program.assigned_index`` (both indexed by
        base Gpr) and offset by ``base_offset``, giving e.g. ``v5`` or ``s[4:5]``.
        neg/abs/sext annotations are not included in the result.
        """
        if self.rtype.is_special():
            return self.name

        prefix = self.rtype.name[0].lower()
        assert prefix in ('v', 's', 'a')

        if self.base_gpr in program.forced_index:
            base_gpr_index = program.forced_index[self.base_gpr]
        else:
            assert self.base_gpr in program.assigned_index
            base_gpr_index = program.assigned_index[self.base_gpr]
        index = base_gpr_index + self.base_offset

        if self.count > 1:
            return f"{prefix}[{index}:{index+self.count-1}]"
        else:
            assert self.count == 1
            return f"{prefix}{index}"


def sext(gpr: Gpr) -> Gpr:
    """Functional spelling of ``gpr.sext()``: return a sign-extension-annotated view of ``gpr``.

    Errors from ``Gpr.sext`` (e.g. ``gpr`` already carries an annotation) propagate unchanged.
    """
    make_view = gpr.sext
    return make_view()


class __Allocator:
    """Factory behind the module-level ``new``.

    ``new(...)`` returns a single base Gpr; ``new[n](...)`` / ``new[d0, d1, ...](...)``
    returns an ``np.ndarray`` (dtype object) of that shape holding independent Gprs named
    ``<name>_<i0>_<i1>...``.
    """
    __slots__ = ["__shape"]

    def __init__(self, shape: List[int]):
        self.__shape = shape  # type: List[int]  # empty list -> scalar allocation

    def __getitem__(self, item) -> "__Allocator":
        """Return an allocator producing arrays of shape ``item`` (an int or a tuple of ints)."""
        if isinstance(item, int):
            dims = [item]
        else:
            check(isinstance(item, tuple))
            dims = list(item)
        # Validate every dimension, not only the single-int form: a negative entry in a tuple
        # was previously accepted and silently produced an empty array
        for dim in dims:
            check(isinstance(dim, int) and dim >= 0, f"Invalid Gpr list size: {dim}")
        return self.__class__(shape=dims)

    def __call__(self, rtype: GprType = None, name: str = None, count: int = 1, align: Optional[Union[int, Tuple[int, int], Align]] = None) -> Union[Gpr, np.ndarray]:  # noqa: E501 (line too long)
        """Allocate base Gpr(s).

        :param rtype: register type; may be omitted when ``name`` has a ``v_``/``s_``/``a_`` prefix
        :param name: register name; when omitted it is taken from ``get_caller_assignments()``
        :param count: number of consecutive registers in each Gpr
        :param align: forwarded to ``Gpr._create_gpr`` (None, int divisor, (divisor, remainder) or Align)
        :return: a Gpr for a scalar allocator, otherwise an ndarray of Gprs with this allocator's shape
        """
        if name is None:
            name = get_caller_assignments()
            if not isinstance(name, str):
                raise SeekException(f"Gpr name is not given, nor can we infer from the caller")
        check(name, "Gpr name shall not be empty")

        # Validate name/rtype once up front (and infer rtype from the prefix if needed).
        # The *unstripped* name is forwarded because _create_gpr strips the prefix itself;
        # forwarding the stripped name made e.g. "v_s_tmp" fail with a conflict error
        # and "v_v_x" lose two prefixes.
        # noinspection PyProtectedMember
        stripped_name, rtype = Gpr._check_and_get_name(name, rtype)
        check(stripped_name, "Gpr name shall not be empty")

        if not self.__shape:
            # noinspection PyProtectedMember
            return Gpr._create_gpr(name=name, count=count, rtype=rtype, align=align)

        # Fill element by element in row-major order (last index fastest), so creation order
        # (and hence Gpr ids) follows the array layout. Assigning into an object array also
        # avoids numpy probing Gpr (which defines __getitem__) as a nested sequence.
        results = np.empty(self.__shape, dtype=object)
        for index in np.ndindex(*self.__shape):
            suffix = "".join(f"_{i}" for i in index)
            # noinspection PyProtectedMember
            results[index] = Gpr._create_gpr(name=f"{name}{suffix}", count=count, rtype=rtype, align=align)
        return results


# Module-level allocator. `new(...)` creates a single Gpr; `new[n](...)` or
# `new[d0, d1, ...](...)` creates an ndarray (dtype object) of Gprs with that shape.
new = __Allocator(shape=[])  # By default, we allocate scalar


"""
Pre-created registers
"""
# These are created at import time in the order below; each base Gpr takes the next id,
# which also fixes its position when Gprs are sorted.
# When no `name=` is passed, `new(...)` and `.alias()` get the name from
# `get_caller_assignments()` -- presumably the assignment target on each line
# (e.g. "vcc_lo"), so keep each call a plain `name = ...` statement.

# Two-register special registers come with single-register `_lo` / `_hi` aliases
vcc = new(rtype=GprType.Special, count=2)
vcc_lo = vcc[0].alias()
vcc_hi = vcc[1].alias()
vccz = new(rtype=GprType.Special)

vcc_exec = vcc.alias()  # for use in v_cmpx_... only

# noinspection PyShadowingBuiltins
exec = new(rtype=GprType.Special, count=2)
exec_lo = exec[0].alias()
exec_hi = exec[1].alias()
execz = new(rtype=GprType.Special)

flat_scratch = new(rtype=GprType.Special, count=2)
flat_scratch_lo = flat_scratch[0].alias()
flat_scratch_hi = flat_scratch[1].alias()

xnack_mask = new(rtype=GprType.Special, count=2)
xnack_mask_lo = xnack_mask[0].alias()
xnack_mask_hi = xnack_mask[1].alias()

scc = new(rtype=GprType.Special)

m0 = new(rtype=GprType.Special)
lds_direct = new(rtype=GprType.Special)

# Scalar register pair with Align(2, 0) (even index); stored under the name "kernarg"
# once the "s_" prefix is stripped
s_kernarg = new(rtype=GprType.S, count=2, align=2)

load_to_lds = new(rtype=GprType.Special)  # for buffer_load_xxx lds

off = new(rtype=GprType.Special, count=1)  # count=1 is dummy


# noinspection PyPep8Naming
class threadIdx:
    """Vector registers named ``threadIdx.x``, ``threadIdx.y`` and ``threadIdx.z`` (created in that order)."""
    x, y, z = [new(rtype=GprType.V, name=f"threadIdx.{axis}") for axis in "xyz"]


# noinspection PyPep8Naming
class blockIdx:
    """Scalar registers named ``blockIdx.x``, ``blockIdx.y`` and ``blockIdx.z`` (created in that order)."""
    x, y, z = [new(rtype=GprType.S, name=f"blockIdx.{axis}") for axis in "xyz"]


class GprSet:
    """A mutable set of individual registers, tracked per base Gpr as a bitmask of offsets.

    Adding a Gpr (base or view) adds every register it covers; entries for the same base
    Gpr merge. Invariant: every key of the internal dict is a base Gpr and every mask is
    non-zero. ``in`` is deliberately unsupported; use ``is_superset`` / ``is_intersected``.
    """
    __slots__ = ["__d"]

    @staticmethod
    def __offsets_to_mask(offsets: Iterable[int]) -> int:
        """Return a mask with bit ``o`` set for each ``o`` in ``offsets`` (must be non-empty)."""
        mask = 0
        for offset in offsets:
            mask |= 1 << offset
        assert mask > 0
        return mask

    @staticmethod
    def __gpr_to_mask(gpr: Gpr) -> int:
        """Return the mask of base-Gpr offsets covered by ``gpr``."""
        return GprSet.__offsets_to_mask(range(gpr.base_offset, gpr.base_offset + gpr.count))

    @staticmethod
    def __mask_to_offset_list(mask: int) -> List[int]:
        """Return the offsets of the set bits of ``mask`` in ascending order."""
        assert mask > 0
        results = []  # type: List[int]
        offset = 0
        while mask > 0:
            if mask & 1:
                results.append(offset)
            mask >>= 1
            offset += 1
        return results

    @staticmethod
    def __as_gpr_set(item: Union[GprSet, Gpr]) -> GprSet:
        """Return ``item`` if it is a GprSet, else a new GprSet holding it (unknown types raise)."""
        return item if isinstance(item, GprSet) else GprSet(item)

    def __init__(self, *items: Union[GprSet, Gpr]):
        self.__d = {}  # type: Dict[Gpr, int]  # base Gpr -> mask of covered offsets
        self.union_update(*items)

    def clone(self) -> GprSet:
        """Return an independent copy (the Gpr objects themselves are shared)."""
        cloned = GprSet()
        cloned.__d = self.__d.copy()  # shallow copy
        return cloned

    def union_update(self, *others: Union[GprSet, Gpr]) -> GprSet:
        """Add every register of ``others`` in place; return ``self``."""
        for other in others:
            if isinstance(other, Gpr):
                if other.base_gpr not in self.__d:
                    self.__d[other.base_gpr] = 0
                self.__d[other.base_gpr] |= GprSet.__gpr_to_mask(other)
            elif isinstance(other, GprSet):
                for other_base_gpr, other_mask in other.__d.items():
                    assert not other_base_gpr.is_view
                    assert other_mask > 0
                    if other_base_gpr not in self.__d:
                        self.__d[other_base_gpr] = 0
                    self.__d[other_base_gpr] |= other_mask
            else:
                raise SeekException(f"Unknown item: {other}")
        return self

    def union(self, *others: Union[GprSet, Gpr]) -> GprSet:
        result = self.clone()
        result.union_update(*others)
        return result

    def intersect_update(self, other: Union[GprSet, Gpr]) -> GprSet:
        """Keep only registers also present in ``other``; return ``self``."""
        if isinstance(other, Gpr):
            other = GprSet(other)
        elif isinstance(other, GprSet):
            pass
        else:
            raise SeekException(f"Unknown item: {other}")

        base_gpr_to_remove = []
        for base_gpr, mask in self.__d.items():
            other_mask = other.__d[base_gpr] if base_gpr in other.__d else 0
            intersect_mask = mask & other_mask
            if intersect_mask != 0:
                # Overwriting an existing key does not resize the dict, so this is safe mid-iteration
                self.__d[base_gpr] = intersect_mask
            else:
                base_gpr_to_remove.append(base_gpr)

        for base_gpr in base_gpr_to_remove:
            del self.__d[base_gpr]
        return self

    def intersect(self, other: Union[GprSet, Gpr]) -> GprSet:
        result = self.clone()
        result.intersect_update(other)
        return result

    def difference_update(self, *others: Union[GprSet, Gpr]) -> GprSet:
        """Remove every register of ``others`` in place; return ``self``."""
        for other in others:
            if isinstance(other, Gpr):
                other = GprSet(other)
            elif isinstance(other, GprSet):
                pass
            else:
                raise SeekException(f"Unknown item: {other}")

            base_gpr_to_remove = []
            for other_base_gpr, other_mask in other.__d.items():
                assert other_mask > 0
                if other_base_gpr not in self.__d:
                    continue
                intersect_mask = self.__d[other_base_gpr] & other_mask
                # what we really mean below: self.__d[other_base_gpr] &= ~intersect_mask
                self.__d[other_base_gpr] -= intersect_mask
                if self.__d[other_base_gpr] == 0:
                    base_gpr_to_remove.append(other_base_gpr)

            for base_gpr in base_gpr_to_remove:
                del self.__d[base_gpr]
        return self

    def difference(self, *others: Union[GprSet, Gpr]) -> GprSet:
        result = self.clone()
        result.difference_update(*others)
        return result

    @property
    def base_gprs(self) -> List[Gpr]:
        """Base Gprs with at least one register in the set, sorted by creation order."""
        return sorted(self.__d)

    @property
    def expanded_gprs(self) -> List[Gpr]:
        """Every register in the set as a single-register view, sorted."""
        results = []  # type: List[Gpr]
        for base_gpr, mask in self.__d.items():
            results += [base_gpr[o] for o in GprSet.__mask_to_offset_list(mask)]
        results.sort()
        return results

    def get_offset_list(self, base_gpr: Gpr) -> List[int]:
        """Offsets of ``base_gpr`` present in the set (ascending; empty if none)."""
        assert not base_gpr.is_view
        if base_gpr in self.__d:
            mask = self.__d[base_gpr]
            assert mask > 0
            return GprSet.__mask_to_offset_list(mask)
        else:
            return []

    def is_superset(self, item: Union[GprSet, Gpr]) -> bool:
        """Return True if every register of ``item`` is in this set."""
        # Compare masks directly: cost is proportional to `item`, instead of cloning the
        # whole set via `self.union(item) == self` on every query
        other = GprSet.__as_gpr_set(item)
        return all(
            (self.__d.get(base_gpr, 0) & mask) == mask
            for base_gpr, mask in other.__d.items()
        )

    def is_intersected(self, item: Union[GprSet, Gpr]) -> bool:
        """Return True if this set shares at least one register with ``item``."""
        # Same idea as is_superset: no clone of the whole set
        other = GprSet.__as_gpr_set(item)
        return any(
            (self.__d.get(base_gpr, 0) & mask) != 0
            for base_gpr, mask in other.__d.items()
        )

    def __contains__(self, item):
        raise RuntimeError("Don't use __contains__() as it's ambiguous")

    def __eq__(self, other):
        if not isinstance(other, GprSet):
            return NotImplemented
        return self.__d == other.__d

    def __repr__(self):
        str_gprs = []
        for base_gpr in self.base_gprs:  # has been sorted
            offset_list = self.get_offset_list(base_gpr)
            assert offset_list
            if len(offset_list) == base_gpr.count:
                # All offsets of a base Gpr is here
                # As an abbreviation, we don't explicitly list all the offsets 1..count
                str_gprs.append(f"{base_gpr}")
            else:
                str_gprs.append(f"{base_gpr}[{','.join(map(str, offset_list))}]")
        return f"GprSet({{{', '.join(str_gprs)}}})"

    def __or__(self, other: Union[GprSet, Gpr]) -> GprSet:
        return self.union(other)

    def __ior__(self, other: Union[GprSet, Gpr]) -> GprSet:
        return self.union_update(other)

    def __and__(self, other: Union[GprSet, Gpr]) -> GprSet:
        return self.intersect(other)

    def __iand__(self, other: Union[GprSet, Gpr]) -> GprSet:
        return self.intersect_update(other)

    def __sub__(self, other: Union[GprSet, Gpr]) -> GprSet:
        return self.difference(other)

    def __isub__(self, other: Union[GprSet, Gpr]) -> GprSet:
        return self.difference_update(other)

    def __copy__(self) -> GprSet:
        raise SeekException("Use clone() instead")

    def __deepcopy__(self, *args, **kwargs) -> GprSet:
        # Don't really deep copy the Gpr
        raise SeekException("Use clone() instead")

    def __bool__(self):
        return not self.empty()

    def empty(self) -> bool:
        return len(self.__d) == 0

    def _sanity_check(self):
        """
        For debugging purpose only: assert the class invariant (base-Gpr keys, non-zero masks)
        """
        for base_gpr, mask in self.__d.items():
            assert not base_gpr.is_view
            assert mask > 0
