from __future__ import annotations
from typing import Set, List, Optional, Callable
from ..basic.exception import SeekException
from ..basic.register import GprSet
from ..basic.instr import MemToken, ExplicitWaitCall, ExplicitUsesCall, InstrCall, Waitcnt, ControlFlowEnum
from .base_pass import BasePass, PassTag, OptimizerState
from .divide_basic_block_pass import BasicBlock, DivideBasicBlockPassState
from .annotate_clause_pass import AnnClause


class PendingMem:
    """Issue-ordered record of memory tokens that have not been waited for yet.

    Tokens are kept in six lists (vector, LDS, GDS, scalar, message, export).
    Each list is ordered oldest-first.  The vector list is accounted against
    `vmcnt`, the export list against `expcnt`, and the remaining four against
    `lgkmcnt`.  The scalar list is special: scalar accesses return out of
    order, so it can only be drained by a wait of `lgkmcnt(0)`.

    Instances compare and hash by the contents of the six lists, which lets
    callers de-duplicate states with `set(...)`.
    """

    # Attribute names of the six pending lists, in the canonical order used by
    # comparison, hashing, repr and cloning.
    _LIST_NAMES = ("pending_vector", "pending_lds", "pending_gds",
                   "pending_scalar", "pending_msg", "pending_export")

    def __init__(self):
        self.pending_vector = []  # type: List[MemToken]
        self.pending_lds = []  # type: List[MemToken]
        self.pending_gds = []  # type: List[MemToken]
        self.pending_scalar = []  # type: List[MemToken]
        self.pending_msg = []  # type: List[MemToken]
        self.pending_export = []  # type: List[MemToken]

    def _pending_lists(self):
        """Return the six pending lists as a tuple, in `_LIST_NAMES` order."""
        return tuple(getattr(self, name) for name in self._LIST_NAMES)

    def _counters(self):
        """Describe how each pending list maps onto a `Waitcnt` field.

        Returns one row per pending list, in `_LIST_NAMES` order:
        `(pending_list, waitcnt_field, out_of_order_return, select_waitcnt)`
        where `select_waitcnt(token)` is the amount the token contributes to
        that counter.
        """
        def vmcnt_of(token):
            return token.total_inc_vmcnt

        def lgkmcnt_of(token):
            return token.total_inc_lgkmcnt

        def expcnt_of(token):
            return token.total_inc_expcnt

        return (
            (self.pending_vector, "vmcnt", False, vmcnt_of),
            (self.pending_lds, "lgkmcnt", False, lgkmcnt_of),
            (self.pending_gds, "lgkmcnt", False, lgkmcnt_of),
            (self.pending_scalar, "lgkmcnt", True, lgkmcnt_of),  # out of order return!
            (self.pending_msg, "lgkmcnt", False, lgkmcnt_of),
            (self.pending_export, "expcnt", False, expcnt_of),
        )

    def __eq__(self, other):
        if not isinstance(other, PendingMem):
            return NotImplemented
        return self._pending_lists() == other._pending_lists()

    def __hash__(self):
        # Hash a snapshot of the contents; must stay consistent with __eq__.
        return hash(tuple(tuple(pending_list) for pending_list in self._pending_lists()))

    def __repr__(self):
        # Only non-empty lists are shown, so an empty state prints as `PendingMem()`.
        pending_list_reprs = [
            f"{name}={pending_list}"
            for name, pending_list in zip(self._LIST_NAMES, self._pending_lists())
            if pending_list
        ]
        return f"PendingMem({', '.join(pending_list_reprs)})"

    def clone(self) -> PendingMem:
        """Return a new `PendingMem` whose lists are shallow copies of this one's."""
        result = PendingMem()
        for name in self._LIST_NAMES:
            setattr(result, name, getattr(self, name).copy())
        return result

    def add_pending(self, mem_token: MemToken):
        """Append `mem_token` to the one pending list whose `inc_xxx` is positive."""
        assert isinstance(mem_token, MemToken), mem_token

        increments = (
            (mem_token.inc_vector, self.pending_vector),
            (mem_token.inc_lds, self.pending_lds),
            (mem_token.inc_gds, self.pending_gds),
            (mem_token.inc_scalar, self.pending_scalar),
            (mem_token.inc_msg, self.pending_msg),
            (mem_token.inc_export, self.pending_export),
        )
        # Currently, there shall be **exactly** one inc_xxx > 0
        assert sum(int(inc > 0) for inc, _ in increments) == 1, mem_token

        for inc, pending_list in increments:
            if inc > 0:
                assert mem_token not in pending_list
                pending_list.append(mem_token)

    def update_by_waitcnt(self, waitcnt: Waitcnt) -> None:
        """Drop the tokens that are guaranteed complete once `waitcnt` is satisfied.

        For every counter that `waitcnt` specifies (field is not None), the
        oldest tokens are removed until the summed contribution of the
        remaining tokens no longer exceeds the counter value.
        """
        for pending_list, key, out_of_order_return, select_waitcnt in self._counters():
            waitcnt_value = getattr(waitcnt, key)  # type: Optional[int]
            if waitcnt_value is None:
                continue
            if out_of_order_return and waitcnt_value != 0:
                # Scalar reads/writes return out of order! Only valid wait is: `s_waitcnt lgkmcnt(0)`
                continue
            cnt_sum = sum(select_waitcnt(x) for x in pending_list)
            while cnt_sum > waitcnt_value:
                cnt_sum -= select_waitcnt(pending_list.pop(0))
            assert sum(select_waitcnt(x) for x in pending_list) <= waitcnt_value

    def update_by_explicit_wait(self, mem_tokens: Set[MemToken]) -> Optional[Waitcnt]:
        """Retire the given tokens and report the wait needed to do so.

        For each token that is still pending: an in-order list drops
        everything up to and including the token and requires its counter to
        reach the summed contribution of what remains; the out-of-order
        (scalar) list is cleared and requires the counter to reach 0.
        Requirements are merged with `Waitcnt.lcm`.

        Returns the merged `Waitcnt`, or None if none of the tokens was pending.
        """
        waitcnt = None  # type: Optional[Waitcnt]
        for mem_token in mem_tokens:
            assert isinstance(mem_token, MemToken), mem_token

            for pending_list, key, out_of_order_return, select_waitcnt in self._counters():
                if mem_token not in pending_list:
                    continue
                if out_of_order_return:
                    pending_list.clear()
                    remaining = 0
                else:
                    # Everything issued before the token completes before it does.
                    del pending_list[:pending_list.index(mem_token) + 1]
                    assert mem_token not in pending_list
                    remaining = sum(select_waitcnt(x) for x in pending_list)
                waitcnt = Waitcnt.lcm(waitcnt, Waitcnt(**{key: remaining}))

        # Well, this waitcnt may cause further updates on pending_xxx
        if waitcnt is not None:
            self.update_by_waitcnt(waitcnt)

        return waitcnt

    def update_by_uses(self, gprset: GprSet) -> Optional[Waitcnt]:
        """Retire the pending tokens whose loaded registers are about to be used.

        In each list the newest token whose `load_mem_to_gprs` intersects
        `gprset` is located.  For an in-order list that token and everything
        older are dropped and the counter must reach the summed contribution of
        what remains; for the out-of-order (scalar) list the whole list is
        cleared and the counter must reach 0.

        Returns the merged `Waitcnt`, or None if no pending token intersects
        `gprset`.
        """
        waitcnt = None  # type: Optional[Waitcnt]

        for pending_list, key, out_of_order_return, select_waitcnt in self._counters():
            # Scan newest-first: waiting for the newest match covers the older ones.
            for idx in range(len(pending_list) - 1, -1, -1):
                if not GprSet(*pending_list[idx].load_mem_to_gprs).is_intersected(gprset):
                    continue
                if out_of_order_return:
                    pending_list.clear()
                    remaining = 0
                else:
                    del pending_list[:idx + 1]
                    remaining = sum(select_waitcnt(x) for x in pending_list)
                waitcnt = Waitcnt.lcm(waitcnt, Waitcnt(**{key: remaining}))
                break

        # Well, this waitcnt may cause further updates on pending_xxx
        if waitcnt is not None:
            self.update_by_waitcnt(waitcnt)

        return waitcnt

    def check_by_defs(self, gprset: GprSet, instr: InstrCall) -> None:
        """Raise if `instr` defines a register that a pending token still loads into.

        Raises:
            SeekException: when any pending token's `load_mem_to_gprs`
                intersects `gprset`.
        """
        for pending_list in self._pending_lists():
            for mem_token in pending_list:
                if GprSet(*mem_token.load_mem_to_gprs).is_intersected(gprset):
                    raise SeekException(f"Instruction defines Gpr writen by pending memory {mem_token}: {instr}")


class InsertWaitcntPass(BasePass):
    """Compute, for every annotated clause, the `Waitcnt` it requires.

    The pass runs a worklist fixed-point over the basic blocks.  Each block
    carries a set of possible `PendingMem` states at entry (`pending_mem_in`)
    and at exit (`pending_mem_out`); the result for each clause is stored in
    `annclause.insert_waitcnt` (None when no wait is required).
    NOTE(review): the stored value is presumably turned into an actual
    `s_waitcnt` by a later stage -- that code is not in this file.
    """

    def __init__(self, /,
                 s_barrier_implies_wait_vmcnt_0: bool = True,
                 priority: int = PassTag.InsertWaitcnt.value):
        """
        s_barrier_implies_wait_vmcnt_0:
            Does `s_barrier` imply `s_waitcnt vmcnt(0)`?
            This is not a documented behavior, but it seems true from micro_benchmark.
        priority:
            Forwarded unchanged to `BasePass.__init__`.
        """
        super().__init__(priority)
        self.s_barrier_implies_wait_vmcnt_0 = s_barrier_implies_wait_vmcnt_0

    def required_tags(self) -> Set[PassTag]:
        return {PassTag.AnnotateClause}

    def generated_tags(self) -> Set[PassTag]:
        return {PassTag.InsertWaitcnt}

    def invalidated_tags(self) -> Set[PassTag]:
        return set()

    def reset(self, program):
        """Clear everything this pass writes onto the basic blocks and clauses.

        Does nothing when the divide-basic-block state is absent.
        """
        optimizer_state = program.optimizer_state  # type: OptimizerState
        state = optimizer_state.divide_basic_block  # type: DivideBasicBlockPassState
        if state is None:
            return
        for bb in state.basic_blocks:
            assert isinstance(bb, BasicBlock)
            bb.pending_mem_in = None
            bb.pending_mem_out = None
            if bb.annotate_clauses is None:
                continue
            for annclause in bb.annotate_clauses:
                assert isinstance(annclause, AnnClause)
                annclause.insert_waitcnt = None

    @staticmethod
    def _recompute_pending_mem_in(bb: BasicBlock) -> None:
        """Set `bb.pending_mem_in` to the de-duplicated union of its predecessors' exit states."""
        merged = []  # type: List[PendingMem]
        for pred in bb.predecessors:
            merged += pred.pending_mem_out
        merged = list(set(merged))  # unique
        merged.sort(key=repr)  # deterministic order, so list comparison is meaningful
        bb.pending_mem_in = merged

    def run(self, program) -> bool:
        """Run the fixed-point analysis and fill in `annclause.insert_waitcnt`.

        Always returns True.

        Raises:
            SeekException: (from `PendingMem.check_by_defs`) when an
                instruction defines a register that a pending memory token
                still loads into.
        """
        # Note: let's reset all existing `annclause.insert_waitcnt`, if any
        self.reset(program)

        optimizer_state = program.optimizer_state  # type: OptimizerState
        state = optimizer_state.divide_basic_block  # type: DivideBasicBlockPassState
        assert state is not None

        update_queue = []  # type: List[BasicBlock]

        for bb in state.basic_blocks:
            assert isinstance(bb, BasicBlock)
            assert bb.annotate_clauses is not None
            for annclause in bb.annotate_clauses:
                assert isinstance(annclause, AnnClause)
                assert annclause.insert_waitcnt is None
            bb.pending_mem_in = []
            bb.pending_mem_out = []

        # The first basic block is entered with nothing pending.
        state.basic_blocks[0].pending_mem_in = [PendingMem()]
        update_queue.append(state.basic_blocks[0])

        while update_queue:
            bb = update_queue.pop(0)

            # Copy current basic-block's set of PendingMem at entrance to a list.
            # After being updated by all clauses of the current basic-block,
            # it is de-duplicated and then becomes the set of PendingMem at exit.
            pending_mem_list = [pending_mem.clone() for pending_mem in bb.pending_mem_in]  # type: List[PendingMem]

            # Loop from first to last clause
            for annclause in bb.annotate_clauses:
                clause = annclause.clause

                # Start from the requirement found on earlier visits of this
                # block, so the stored requirement accumulates across visits.
                waitcnt = annclause.insert_waitcnt  # type: Optional[Waitcnt]
                if waitcnt is not None:
                    for pending_mem in pending_mem_list:
                        pending_mem.update_by_waitcnt(waitcnt)

                if isinstance(clause, ExplicitWaitCall):
                    mem_tokens = set()  # type: Set[MemToken]
                    for mem_token_or_token_object in clause.mem_token_or_token_objects:
                        if isinstance(mem_token_or_token_object, MemToken):
                            mem_tokens.add(mem_token_or_token_object)
                        else:
                            assert mem_token_or_token_object in state.mem_token_object_to_mem_tokens
                            # NOTE(review): the mapping's name is plural; this `add` assumes each
                            # value is a single hashable MemToken -- confirm against its producer.
                            mem_tokens.add(state.mem_token_object_to_mem_tokens[mem_token_or_token_object])

                    for pending_mem in pending_mem_list:
                        new_waitcnt = pending_mem.update_by_explicit_wait(mem_tokens)
                        waitcnt = Waitcnt.lcm(waitcnt, new_waitcnt)
                elif isinstance(clause, ExplicitUsesCall):
                    for pending_mem in pending_mem_list:
                        new_waitcnt = pending_mem.update_by_uses(GprSet(*clause.uses))
                        waitcnt = Waitcnt.lcm(waitcnt, new_waitcnt)
                else:
                    assert isinstance(clause, InstrCall)

                    # Update by the instruction's uses first
                    for pending_mem in pending_mem_list:
                        new_waitcnt = pending_mem.update_by_uses(clause.gpr_uses_to_gprset())
                        waitcnt = Waitcnt.lcm(waitcnt, new_waitcnt)

                    # Then, check by the instruction's defs
                    for pending_mem in pending_mem_list:
                        pending_mem.check_by_defs(clause.gpr_defs_to_gprset(), clause)

                    # Deal with s_waitcnt
                    if clause.instr_name == "s_waitcnt":
                        operand = clause.operands["waitcnt"]
                        if isinstance(operand, int):
                            explicit_waitcnt = Waitcnt.from_int(operand)
                        else:
                            assert isinstance(operand, Waitcnt)
                            # Honor the operand itself; merging `waitcnt` with
                            # itself would silently ignore the explicit wait.
                            explicit_waitcnt = operand
                        waitcnt = Waitcnt.lcm(waitcnt, explicit_waitcnt)

                    # Deal with s_barrier
                    if self.s_barrier_implies_wait_vmcnt_0 and clause.instr_name == "s_barrier":
                        waitcnt = Waitcnt.lcm(waitcnt, Waitcnt(vmcnt=0))

                # OK, `waitcnt` may (or may not) be updated.
                # Let's apply it to `pending_mem_list`.
                if waitcnt is not None:
                    for pending_mem in pending_mem_list:
                        pending_mem.update_by_waitcnt(waitcnt)
                annclause.insert_waitcnt = waitcnt

                # If this is a memory instruction, we have to update `pending_mem_list` here
                if isinstance(clause, InstrCall) and clause.mem_token is not None:
                    # A fresh empty state is added first, so it too receives the new token below.
                    pending_mem_list.append(PendingMem())
                    for pending_mem in pending_mem_list:
                        pending_mem.add_pending(clause.mem_token)

            # Possibly update `bb.pending_mem_out`
            if bb.control_flow_at_exit == ControlFlowEnum.Terminate:
                # s_endpgm family implies wait all pending memory
                pending_mem_list = []
            else:
                pending_mem_list = list(set(pending_mem_list))  # unique
                pending_mem_list.sort(key=repr)

            if bb.pending_mem_out == pending_mem_list:
                # Exit state unchanged: successors need no re-evaluation.
                continue
            bb.pending_mem_out = pending_mem_list

            # The exit state changed: refresh each successor's entry state and revisit it.
            for successor in (bb.successor_if_jump, bb.successor_if_fallthrough):
                if successor is None:
                    continue
                self._recompute_pending_mem_in(successor)
                update_queue.append(successor)

        return True
