from __future__ import annotations
from typing import Set, List, Optional, Union, Dict, Any
from ..basic.const import _INSTR_STR_WIDTH
from ..basic.register import Gpr, GprSet
from ..basic.exception import SeekException, check
from ..basic.instr import InstrCall, ExplicitWaitCall, ExplicitUsesCall, MemToken, Block, ControlFlowEnum
from .base_pass import BasePass, PassTag, OptimizerState


class BasicBlockJumpInstr:
    """
    The control-flow instruction that terminates a basic-block.

    Stores the instruction name, its ControlFlowEnum (never Fallthrough) and
    the three GPR maps (uses / holds / defs), each shaped as
    base_gpr -> {offset -> count}.
    """

    def __init__(self,
                 control_flow_enum: ControlFlowEnum,
                 instr_name: str,
                 gpr_uses: Dict[Gpr, Dict[int, int]],
                 gpr_holds: Dict[Gpr, Dict[int, int]],
                 gpr_defs: Dict[Gpr, Dict[int, int]]):
        """
        :param control_flow_enum: kind of control flow; must not be Fallthrough
        :param instr_name: name of the instruction, as emitted by `generate`
        :param gpr_uses: base_gpr -> {offset -> count}
        :param gpr_holds: base_gpr -> {offset -> count}
        :param gpr_defs: base_gpr -> {offset -> count}
        """
        # Fallthrough is rejected here: a jump instruction always alters control flow
        assert control_flow_enum != ControlFlowEnum.Fallthrough

        self.control_flow_enum = control_flow_enum  # type: ControlFlowEnum
        self.instr_name = instr_name  # type: str

        # All three maps: base_gpr -> {offset -> count}
        self.gpr_uses = gpr_uses  # type: Dict[Gpr, Dict[int, int]]
        self.gpr_holds = gpr_holds  # type: Dict[Gpr, Dict[int, int]]
        self.gpr_defs = gpr_defs  # type: Dict[Gpr, Dict[int, int]]

    def gpr_uses_to_gprset(self) -> GprSet:
        """Convert `gpr_uses` to a GprSet via InstrCall.gpr_uses_holds_defs_to_gprset."""
        return InstrCall.gpr_uses_holds_defs_to_gprset(self.gpr_uses)

    def gpr_holds_to_gprset(self) -> GprSet:
        """Convert `gpr_holds` to a GprSet via InstrCall.gpr_uses_holds_defs_to_gprset."""
        return InstrCall.gpr_uses_holds_defs_to_gprset(self.gpr_holds)

    def gpr_defs_to_gprset(self) -> GprSet:
        """Convert `gpr_defs` to a GprSet via InstrCall.gpr_uses_holds_defs_to_gprset."""
        return InstrCall.gpr_uses_holds_defs_to_gprset(self.gpr_defs)

    def generate(self, program, wr, successor_if_jump):
        """
        Emit one line of text for this instruction through the `wr` callable.

        :param program: not used by this method
        :param wr: callable invoked once with the generated line
        :param successor_if_jump: the jump target (its `.name` is emitted),
                                  or None for a terminating instruction
        """
        kind = self.control_flow_enum
        padded_name = f"{self.instr_name} ".ljust(_INSTR_STR_WIDTH)

        if successor_if_jump is not None:
            # A jump target is only meaningful for (un)conditional branches
            assert kind in (ControlFlowEnum.AlwaysJump, ControlFlowEnum.CondJump)
            wr(f"{padded_name}{successor_if_jump.name}  // {kind.name}")
            return

        # No target at all: this must be a terminating instruction
        assert kind == ControlFlowEnum.Terminate
        wr(f"{padded_name}// {kind.name}")

    @staticmethod
    def from_instr_call(instr: InstrCall) -> BasicBlockJumpInstr:
        """
        Build a jump instruction from a control-flow InstrCall.

        NOTE: the uses/holds/defs dicts of `instr` are passed through as-is (not copied).
        """
        assert instr.control_flow_enum != ControlFlowEnum.Fallthrough
        return BasicBlockJumpInstr(control_flow_enum=instr.control_flow_enum,
                                   instr_name=instr.instr_name,
                                   gpr_uses=instr.gpr_uses,
                                   gpr_holds=instr.gpr_holds,
                                   gpr_defs=instr.gpr_defs)

    @staticmethod
    def make_alwaysjump() -> BasicBlockJumpInstr:
        """Create an `s_branch` (AlwaysJump) with empty uses/holds/defs."""
        return BasicBlockJumpInstr(ControlFlowEnum.AlwaysJump, "s_branch",
                                   gpr_uses={}, gpr_holds={}, gpr_defs={})

    @staticmethod
    def make_terminate() -> BasicBlockJumpInstr:
        """Create an `s_endpgm` (Terminate) with empty uses/holds/defs."""
        return BasicBlockJumpInstr(ControlFlowEnum.Terminate, "s_endpgm",
                                   gpr_uses={}, gpr_holds={}, gpr_defs={})


class BasicBlock:
    """
    A node of the control-flow graph: a straight-line list of clauses plus
    an optional jump instruction and links to successor/predecessor blocks.
    """

    def __init__(self, name: str):
        self.name = name  # type: str

        # How the two successor links relate to the instruction ending this basic-block:
        #   s_branch          : successor_if_jump = target,  successor_if_fallthrough = None
        #   s_cbranch_xxx     : successor_if_jump = target,  successor_if_fallthrough = target if not jumping
        #   s_endpgm family   : successor_if_jump = None,    successor_if_fallthrough = None
        #   plain fall-through: successor_if_jump = None,    successor_if_fallthrough = target
        self.successor_if_jump = None  # type: Optional[BasicBlock]
        self.successor_if_fallthrough = None  # type: Optional[BasicBlock]
        self.predecessors = []  # type: List[BasicBlock]

        # None means this basic-block ends by falling through
        self.jump_instr = None  # type: Optional[BasicBlockJumpInstr]

        self.clauses = []  # type: List[Union[InstrCall, ExplicitWaitCall, ExplicitUsesCall]]

        # Slots for AnalyzeLiveVarPass (initialised to None here)
        self.live_var_defs = None  # type: GprSet
        self.live_var_uses = None  # type: GprSet
        self.live_var_in = None  # type: GprSet
        self.live_var_out = None  # type: GprSet
        self.live_var_gpr_life_span = None  # type: Dict[Gpr, List[int]]

        # Slot for AnnotateClausePass (initialised to None here)
        self.annotate_clauses = None  # type: List  # List[AnnClause]

        # Slots for InsertWaitcntPass (initialised to None here)
        self.pending_mem_in = None  # type: Set  # Set[PendingMem]
        self.pending_mem_out = None  # type: Set  # Set[PendingMem]

    def _sanity_check_control_flow_at_exit(self):
        """Assert that `jump_instr` and the two successor links agree with each other."""
        jump_target = self.successor_if_jump
        fallthrough_target = self.successor_if_fallthrough

        if self.jump_instr is None:
            # Plain fall-through: only the fall-through successor exists
            assert jump_target is None
            assert isinstance(fallthrough_target, BasicBlock)
            return

        control_flow = self.jump_instr.control_flow_enum
        if control_flow == ControlFlowEnum.AlwaysJump:
            # s_branch: only the jump successor exists
            assert isinstance(jump_target, BasicBlock)
            assert fallthrough_target is None
        elif control_flow == ControlFlowEnum.CondJump:
            # s_cbranch_xxx family: both successors exist
            assert isinstance(jump_target, BasicBlock)
            assert isinstance(fallthrough_target, BasicBlock)

            # NOTE:
            # A conditional jump whose target is also its fall-through successor would make
            # both links refer to the same basic-block. Such a block is expected to have been
            # turned into a plain fall-through, so that no basic-block B ever has the same
            # predecessor A twice (A.successor_if_jump is B and A.successor_if_fallthrough is B).
            assert jump_target is not fallthrough_target
        elif control_flow == ControlFlowEnum.Terminate:
            # s_endpgm family: no successor at all
            assert jump_target is None
            assert fallthrough_target is None
        else:
            assert False, f"Unexpected control_flow: {control_flow}"

    @property
    def control_flow_at_exit(self) -> ControlFlowEnum:
        """
        How control leaves this basic-block: Fallthrough when there is no
        `jump_instr`, otherwise the jump instruction's ControlFlowEnum.
        The links are sanity-checked first.
        """
        self._sanity_check_control_flow_at_exit()

        jump = self.jump_instr
        return ControlFlowEnum.Fallthrough if jump is None else jump.control_flow_enum

    def __repr__(self):
        return f"BasicBlock({self.name!r})"


class DivideBasicBlockPassState:
    """
    Result of dividing a program into basic-blocks.

    `basic_blocks` is the flat, ordered list of all basic-blocks: a block's
    fall-through successor must be the element right after it.
    `mem_token_object_to_mem_tokens` maps each mem_token_object to its MemToken.
    """

    def __init__(self):
        self.basic_blocks = []  # type: List[BasicBlock]
        self.mem_token_object_to_mem_tokens = {}  # type: Dict[Any, MemToken]

    def _sanity_check(self):
        """
        Assert that the control-flow graph stored in `basic_blocks` is self-consistent.

        Checks that fall-through successors are adjacent in the list, that no
        basic-block is listed twice, that predecessor and successor links mirror each
        other, and that each block's `jump_instr` agrees with its successor links.

        :raises AssertionError: when any invariant is violated
        """
        assert isinstance(self.basic_blocks, list)
        block_count = len(self.basic_blocks)

        # Check all `successor_if_fallthrough` are sane
        for idx, bb in enumerate(self.basic_blocks):
            fallthrough = bb.successor_if_fallthrough
            if fallthrough is None:
                continue
            # Guard the index: the last basic-block has nothing to fall through to.
            # Without this a malformed graph would surface as a bare IndexError.
            assert idx + 1 < block_count, \
                f"Last basic-block {bb} falls through but has no following basic-block"
            assert fallthrough is self.basic_blocks[idx + 1], \
                f"{bb} falls through to {fallthrough}, which is not the next basic-block"

        # Check there are no duplicated basic-blocks
        bb_set = set()  # type: Set[BasicBlock]
        for bb in self.basic_blocks:
            assert isinstance(bb, BasicBlock)
            assert bb not in bb_set, f"Duplicates in self.basic_blocks: {bb}"
            bb_set.add(bb)

        # Check `predecessors` and `jump_instr` are sane.
        # Walk the list (not the set) so the first failing assertion is reproducible.
        for bb in self.basic_blocks:
            assert len(bb.predecessors) == len(set(bb.predecessors)), \
                f"Duplicated predecessors in {bb}: {bb.predecessors}"
            for pred in bb.predecessors:
                assert isinstance(pred, BasicBlock)
                assert pred in bb_set
                # `pred` must reach `bb` through exactly one of its two successor links
                assert pred.successor_if_jump is bb or pred.successor_if_fallthrough is bb
                assert not (pred.successor_if_jump is bb and pred.successor_if_fallthrough is bb)

            # Every successor link must be mirrored by a predecessor entry
            if bb.successor_if_jump is not None:
                assert bb in bb.successor_if_jump.predecessors
            if bb.successor_if_fallthrough is not None:
                assert bb in bb.successor_if_fallthrough.predecessors

            # Also, check that `jump_instr` is sane
            bb._sanity_check_control_flow_at_exit()


class DivideBasicBlockPass(BasePass):
    """
    Divide program blocks into basic-blocks.
    NOTE: Dead basic-blocks are **not** pruned in this pass. Leave them to OptimizeBasicBlockPass

    The result (a DivideBasicBlockPassState) is stored in
    `program.optimizer_state.divide_basic_block`.
    """
    def __init__(self, /, priority: int = PassTag.DivideBasicBlock.value):
        """Forward `priority` (default: PassTag.DivideBasicBlock.value) to BasePass."""
        super().__init__(priority)

    def required_tags(self) -> Set[PassTag]:
        """Return an empty set: this pass declares no required tags."""
        return set()

    def generated_tags(self) -> Set[PassTag]:
        """Return {PassTag.DivideBasicBlock}, the tag this pass generates."""
        return {PassTag.DivideBasicBlock}

    def invalidated_tags(self) -> Set[PassTag]:
        """Return an empty set: this pass declares no invalidated tags."""
        return set()

    def reset(self, program):
        """Clear this pass's result by setting `optimizer_state.divide_basic_block` to None."""
        optimizer_state = program.optimizer_state  # type: OptimizerState
        optimizer_state.divide_basic_block = None

    def run(self, program) -> bool:
        """
        Split every Block of `program.blocks` into basic-blocks and link them into a graph.

        Steps:
          1. create the first basic-block of every Block (so jump targets can be resolved);
          2. walk each Block's clauses, starting a new basic-block after every
             control-flow instruction and recording the successor links;
          3. flatten the per-Block lists, move the trailing control-flow instruction of
             each basic-block from `clauses` into `jump_instr`;
          4. compute `predecessors`, sanity-check, and resolve mem_tokens.

        :param program: object providing `blocks`, `get_block(label)` and `optimizer_state`
        :return: always True
        :raises SeekException: if the last Block does not end with an AlwaysJump/Terminate
                               instruction, or if no basic-block was produced
        """
        optimizer_state = program.optimizer_state  # type: OptimizerState

        #
        # Each `Block` could be divided to one or more basic-blocks (due to branching instructions)
        # Here, each element of `bb_by_block` is a list, corresponding to all basic-blocks of a Block
        #
        # The first basic-block of every Block is created up-front, so that a jump to a
        # Block that appears later in the program can already be linked in the loop below.
        bb_by_block = {}  # type: Dict[Block, List[BasicBlock]]
        for block in program.blocks:  # type: Block
            assert block not in bb_by_block
            bb = BasicBlock(block.block_name)
            bb_by_block[block] = [bb]

        for block_idx, block in enumerate(program.blocks):  # type: (int, Block)
            curr_bb = bb_by_block[block][0]  # this should be the very first basic-block

            # Init `alwaysjump_or_terminate_just_now` in case `block.clauses` is empty
            # It is True only if the most recently processed clause was an AlwaysJump/Terminate instruction
            alwaysjump_or_terminate_just_now = False
            for clause in block.clauses:  # type: Union[InstrCall, ExplicitWaitCall, ExplicitUsesCall]
                assert curr_bb is not None
                alwaysjump_or_terminate_just_now = False

                # NOTE: InstrCall is added to curr_bb even if they are control-flow instructions
                # This will be fixed (aka removed) later when setting `jump_instr`
                curr_bb.clauses.append(clause)

                if isinstance(clause, ExplicitWaitCall):
                    # Nothing to do with ExplicitWaitCall
                    continue

                if isinstance(clause, ExplicitUsesCall):
                    # Nothing to do with ExplicitUsesCall
                    continue

                assert isinstance(clause, InstrCall)
                instr = clause  # type: InstrCall

                new_bb = None
                if instr.control_flow_enum in (ControlFlowEnum.AlwaysJump, ControlFlowEnum.Terminate):
                    alwaysjump_or_terminate_just_now = True

                if instr.control_flow_enum != ControlFlowEnum.Fallthrough:
                    # We need to switch to a new basic-block **after** this instruction
                    # The new basic-block is named after the Block plus its index within the Block
                    new_bb = BasicBlock(f"{block.block_name}_BB{len(bb_by_block[block])}")
                    bb_by_block[block].append(new_bb)

                # curr_bb has not seen a control-flow instruction yet, so it cannot have successors
                assert curr_bb.successor_if_jump is None
                assert curr_bb.successor_if_fallthrough is None
                if instr.control_flow_enum == ControlFlowEnum.CondJump:
                    # Jump targets always resolve to the first basic-block of the target Block
                    cond_jump_target = program.get_block(instr.operands["label"])
                    curr_bb.successor_if_jump = bb_by_block[cond_jump_target][0]
                    curr_bb.successor_if_fallthrough = new_bb
                elif instr.control_flow_enum == ControlFlowEnum.AlwaysJump:
                    alwaysjump_target = program.get_block(instr.operands["label"])
                    curr_bb.successor_if_jump = bb_by_block[alwaysjump_target][0]
                    # curr_bb.successor_if_fallthrough = None  # not necessary as it is already None
                elif instr.control_flow_enum == ControlFlowEnum.Terminate:
                    # curr_bb.successor_if_jump = None  # not necessary as it is already None
                    # curr_bb.successor_if_fallthrough = None  # not necessary as it is already None
                    pass

                # Switch to a new basic-block if desired
                if new_bb is not None:
                    curr_bb = new_bb

            # Here curr_bb is the trailing basic-block of this Block, which has no successors yet
            assert curr_bb is not None
            assert curr_bb.successor_if_jump is None
            assert curr_bb.successor_if_fallthrough is None

            if block_idx+1 == len(program.blocks):
                # If this is the last block: we must have ended with s_endpgm or s_branch
                if not alwaysjump_or_terminate_just_now:
                    raise SeekException("Program doesn't terminate with s_endpgm family or s_branch")

                # Remove this (empty) basic-block from current block
                # (it was created after the final AlwaysJump/Terminate and has nothing to fall through to)
                assert len(curr_bb.clauses) == 0  # this basic-block must be empty
                assert len(bb_by_block[block]) >= 1  # the last block must have at least 1 (just added) basic-block
                popped_bb = bb_by_block[block].pop()
                assert popped_bb is curr_bb
            else:
                # Now all instructions in this Block have been added to basic-blocks
                # Link curr_bb to next block unconditionally (aka fall-through)
                # NOTE: this also applies to the empty basic-block created after an AlwaysJump/Terminate
                next_block = program.blocks[block_idx+1]
                curr_bb.successor_if_fallthrough = bb_by_block[next_block][0]

        # Flatten basic-block list of all Blocks
        # (dict insertion order follows `program.blocks`, see the first loop above)
        state = DivideBasicBlockPassState()
        for bb_list in bb_by_block.values():
            state.basic_blocks += bb_list
        if len(state.basic_blocks) == 0:
            raise SeekException("Program doesn't contain any instructions")

        # Fix-up setting `jump_instr`
        # Fix-up all last branching instruction for basic-blocks: remove them from current BasicBlock
        for bb in state.basic_blocks:
            assert bb.jump_instr is None
            if bb.successor_if_jump is None and bb.successor_if_fallthrough is not None:
                # For fall-through: successor_if_jump is None, successor_if_fallthrough is target
                pass
            else:
                # Otherwise the last clause is the control-flow instruction that closed this basic-block
                last_instr = bb.clauses.pop()
                assert isinstance(last_instr, InstrCall)

                # NOTE: may be changed back to None for ControlFlowEnum.CondJump
                bb.jump_instr = BasicBlockJumpInstr.from_instr_call(last_instr)

                if last_instr.control_flow_enum == ControlFlowEnum.CondJump:
                    # If we are conditionally jumping to the successor basic-block, eliminate the s_cbranch_xxx
                    # This is not only an optimization, but also necessary
                    #
                    # After this fix-up, no (successor) basic-block B will have two same precedent basic-blocks A
                    # (aka A.successor_if_jump is B and A.successor_if_fallthrough is B)
                    #
                    # NOTE(review): as the graph is built above, the fall-through successor of a CondJump
                    # is always a freshly created `_BBn` basic-block, while the jump successor is always
                    # the first basic-block of a Block, so this condition looks unreachable within this
                    # pass -- confirm before relying on it.
                    if bb.successor_if_jump is bb.successor_if_fallthrough:
                        # We simply set successor_if_jump to None, and leave successor_if_fallthrough as is (like normal fall-through)  # noqa E501: line too long
                        # Moreover, `last_instr` (s_cbranch_xxx) is dropped
                        bb.successor_if_jump = None
                        bb.jump_instr = None

            bb._sanity_check_control_flow_at_exit()

        # Fix-up: compute predecessors for all basic-blocks
        for bb in state.basic_blocks:
            if bb.successor_if_jump is not None:
                assert bb not in bb.successor_if_jump.predecessors
                bb.successor_if_jump.predecessors.append(bb)
            if bb.successor_if_fallthrough is not None:
                assert bb not in bb.successor_if_fallthrough.predecessors
                bb.successor_if_fallthrough.predecessors.append(bb)

        # All done
        state._sanity_check()

        # Resolve mem_token: `state.mem_token_object_to_mem_tokens` is filled
        self.__resolve_mem_token(program, state)

        optimizer_state.divide_basic_block = state
        return True

    # noinspection PyMethodMayBeStatic
    def __resolve_mem_token(self, program, state: DivideBasicBlockPassState):
        """
        Check there are no duplicated mem_token_object.
        Resolves mem_token_object to mem_token in explicit_call.

        This must be done after basic-blocks are divided, but before OptimizeBasicBlockPass,
        in case some mem_token are removed by dead basic-block elimination.

        :param program: not used by this method
        :param state: the pass state; its `mem_token_object_to_mem_tokens` is filled in place
        """
        # Everything referenced by any ExplicitWaitCall; validated after all mem_tokens are collected
        explicit_wait_mem_token_or_token_objects = set()  # type: Set[Any]

        for bb in state.basic_blocks:  # type: BasicBlock
            # NOTE: only `bb.clauses` is scanned; `bb.jump_instr` is not
            for clause in bb.clauses:
                if isinstance(clause, ExplicitWaitCall):
                    for mem_token_or_token_object in clause.mem_token_or_token_objects:
                        assert mem_token_or_token_object is not None
                        explicit_wait_mem_token_or_token_objects.add(mem_token_or_token_object)
                elif isinstance(clause, ExplicitUsesCall):
                    pass
                else:
                    assert isinstance(clause, InstrCall)
                    if clause.mem_token is not None:
                        assert isinstance(clause.mem_token, MemToken)
                        # Each token_object may be registered by at most one instruction
                        check(clause.mem_token.token_object not in state.mem_token_object_to_mem_tokens,
                              f"Duplicated mem_token: {clause.mem_token}")
                        state.mem_token_object_to_mem_tokens[clause.mem_token.token_object] = clause.mem_token

        # Check that all explicit waits are valid
        for mem_token_or_token_object in explicit_wait_mem_token_or_token_objects:
            if isinstance(mem_token_or_token_object, MemToken):
                # A MemToken must be the very same object that was registered for its token_object
                mem_token = mem_token_or_token_object
                check(mem_token.token_object in state.mem_token_object_to_mem_tokens,
                      f"mem_token not found: {mem_token}")
                my_mem_token = state.mem_token_object_to_mem_tokens[mem_token.token_object]
                check(mem_token is my_mem_token,
                      f"mem_token mismatched: {mem_token} is not {my_mem_token}")
            else:  # this is a mem_token_object
                mem_token_object = mem_token_or_token_object
                check(mem_token_object in state.mem_token_object_to_mem_tokens,
                      f"mem_token_object not found: {mem_token_object}")
