from typing import Set
from .base_pass import BasePass, PassTag, OptimizerState
from .divide_basic_block_pass import DivideBasicBlockPassState, ControlFlowEnum, BasicBlockJumpInstr


class OptimizeBasicBlockPass(BasePass):
    """
    Simplify the basic-block graph built by the DivideBasicBlock pass.

    `run()` repeatedly applies every enabled sub-pass until one full round changes nothing:
      - prune dead (unreachable from the first basic-block) basic-blocks
      - prune empty basic-blocks ending with fall-through / s_branch / s_endpgm
      - merge a basic-block into its only predecessor when that predecessor falls through into it
      - promote clauses of a single-entry basic-block into its fall-through / s_branch predecessor
      - rewrite an s_branch that targets the very next basic-block into a fall-through

    Each sub-pass returns whether it modified the graph, and calls `state._sanity_check()` if it did.
    """

    def __init__(self, /,
                 prune_dead_basic_block: bool = True,
                 prune_empty_fallthrough_basic_block: bool = True,
                 prune_empty_alwaysjump_basic_block: bool = True,
                 prune_empty_terminate_basic_block: bool = True,
                 merge_consecutive_fallthrough_basic_block: bool = True,
                 promote_single_entry_basic_block_clauses: bool = True,
                 rewrite_alwaysjump_to_next_basic_block: bool = True,
                 priority: int = PassTag.OptimizeBasicBlock.value):
        """
        Every boolean flag enables the `run()` sub-pass of the same name (all enabled by default).

        :param priority: forwarded to `BasePass.__init__`
        """
        super().__init__(priority)
        self.prune_dead_basic_block = prune_dead_basic_block  # type: bool
        self.prune_empty_fallthrough_basic_block = prune_empty_fallthrough_basic_block  # type: bool
        self.prune_empty_alwaysjump_basic_block = prune_empty_alwaysjump_basic_block  # type: bool
        self.prune_empty_terminate_basic_block = prune_empty_terminate_basic_block  # type: bool
        self.merge_consecutive_fallthrough_basic_block = merge_consecutive_fallthrough_basic_block  # type: bool
        self.promote_single_entry_basic_block_clauses = promote_single_entry_basic_block_clauses  # type: bool
        self.rewrite_alwaysjump_to_next_basic_block = rewrite_alwaysjump_to_next_basic_block  # type: bool

    def required_tags(self) -> Set[PassTag]:
        """This pass operates on the basic-block state produced by the DivideBasicBlock pass."""
        return {PassTag.DivideBasicBlock}

    def generated_tags(self) -> Set[PassTag]:
        """Tags established once this pass has run."""
        return {PassTag.OptimizeBasicBlock}

    def invalidated_tags(self) -> Set[PassTag]:
        """Tags whose results are invalidated by running this pass."""
        return {PassTag.AnalyzeLiveVar, PassTag.AnnotateClause}

    def reset(self, program):
        # Reset the same state DivideBasicBlockPass resets: `program.optimizer_state.divide_basic_block`
        optimizer_state = program.optimizer_state  # type: OptimizerState
        optimizer_state.divide_basic_block = None

    def __prune_dead_basic_block(self, program, state: DivideBasicBlockPassState) -> bool:
        """
        As the name suggests, prune all dead basic-blocks, i.e. those not reachable from the first
        (entrypoint) basic-block through `successor_if_jump` / `successor_if_fallthrough`.

        Dead basic-blocks are processed in program order, so the debug log is reproducible
        (iterating the set itself would visit them in hash order).

        Returns True if at least one basic-block was removed.
        """
        assert self.prune_dead_basic_block

        # Worklist traversal from the entrypoint: every basic-block never visited is dead
        dead_bb_set = set(state.basic_blocks)
        worklist = [state.basic_blocks[0]]
        while worklist:
            bb = worklist.pop()
            if bb not in dead_bb_set:
                continue  # already visited
            dead_bb_set.remove(bb)

            if bb.successor_if_jump is not None:
                worklist.append(bb.successor_if_jump)
            if bb.successor_if_fallthrough is not None:
                worklist.append(bb.successor_if_fallthrough)

        if not dead_bb_set:
            return False

        # Keep program order for determinism; drop all dead basic-blocks in one pass
        # (slice assignment keeps the identity of the `state.basic_blocks` list)
        dead_bbs = [bb for bb in state.basic_blocks if bb in dead_bb_set]
        state.basic_blocks[:] = [bb for bb in state.basic_blocks if bb not in dead_bb_set]

        for bb in dead_bbs:
            # Now `bb` is a target to remove
            program.logger.debug(f"{self}: prune dead basic-block: {bb}")

            # A dead basic-block can only be entered from other dead basic-blocks
            for pred in bb.predecessors:
                assert pred in dead_bb_set

            # Detach `bb` from its successors (alive or dead)
            if bb.successor_if_jump is not None:
                bb.successor_if_jump.predecessors.remove(bb)
            if bb.successor_if_fallthrough is not None:
                bb.successor_if_fallthrough.predecessors.remove(bb)

        state._sanity_check()
        return True

    def __prune_empty_fallthrough_basic_block(self, program, state: DivideBasicBlockPassState) -> bool:
        """
        B_empty(remove!) --fallthrough--> C
                          {OTHERS} --*--> C

        Every predecessor of B_empty is redirected to C. A predecessor whose conditional branch then
        targets the same basic-block it falls through to has its s_cbranch_xxx dropped.

        Returns True if any basic-block was removed.
        """
        assert self.prune_empty_fallthrough_basic_block
        modified = False

        # NOTE: we **include** the first (entrypoint) basic-block
        # Iterate backwards so that popping index `idx` does not shift the indices still to visit
        for idx in range(len(state.basic_blocks)-1, -1, -1):
            bb = state.basic_blocks[idx]
            if not (len(bb.clauses) == 0 and bb.control_flow_at_exit == ControlFlowEnum.Fallthrough):
                continue

            # Now `bb` is a target to remove
            assert bb.successor_if_jump is None
            assert bb.successor_if_fallthrough is not None

            program.logger.debug(f"{self}: prune empty fallthrough basic-block: {bb}")
            bb_popped = state.basic_blocks.pop(idx)
            assert bb_popped is bb
            modified = True

            assert bb.successor_if_fallthrough is not bb
            bb.successor_if_fallthrough.predecessors.remove(bb)

            for pred in bb.predecessors:
                assert (pred.successor_if_jump is bb) ^ (pred.successor_if_fallthrough is bb)
                if pred is bb:
                    continue

                if pred.successor_if_jump is bb:
                    assert pred.successor_if_fallthrough is not bb.successor_if_fallthrough  # `pred` is not `bb`
                    pred.successor_if_jump = bb.successor_if_fallthrough

                    assert pred not in bb.successor_if_fallthrough.predecessors
                    bb.successor_if_fallthrough.predecessors.append(pred)
                else:  # `pred.successor_if_fallthrough` is `bb`
                    assert pred.successor_if_fallthrough is bb
                    pred.successor_if_fallthrough = bb.successor_if_fallthrough

                    # If we are conditionally jumping to the successor basic-block, eliminate the s_cbranch_xxx
                    # This is not only an optimization, but also necessary
                    #
                    # After this fix-up, no (successor) basic-block B will have two same precedent basic-blocks A
                    # (aka A.successor_if_jump is B and A.successor_if_fallthrough is B)
                    if pred.successor_if_jump is bb.successor_if_fallthrough:
                        assert pred.jump_instr.control_flow_enum == ControlFlowEnum.CondJump  # don't use `pred.control_flow_at_exit` here  # noqa E501: line too long

                        # We simply set successor_if_jump to None, and leave successor_if_fallthrough as is (like normal fall-through)  # noqa E501: line too long
                        # Moreover, `last_instr` (s_cbranch_xxx) is dropped
                        pred.successor_if_jump = None
                        pred.jump_instr = None

                        # Originally, `pred.successor_if_jump` is `bb.successor_if_fallthrough`,
                        # so `pred` must be in `bb.successor_if_fallthrough.predecessors`
                        # We shall not add `pred` to `bb.successor_if_fallthrough.predecessors` again
                        assert pred in bb.successor_if_fallthrough.predecessors
                    else:
                        # Originally, `pred.successor_if_jump` is **not** `bb.successor_if_fallthrough`,
                        # (and trivially, `pred.successor_if_fallthrough` is not `bb.successor_if_fallthrough`)
                        # so `pred` must **not** be in `bb.successor_if_fallthrough.predecessors`
                        # We shall add `pred` to `bb.successor_if_fallthrough.predecessors`
                        assert pred not in bb.successor_if_fallthrough.predecessors
                        bb.successor_if_fallthrough.predecessors.append(pred)

        if modified:
            state._sanity_check()
        return modified

    def __prune_empty_alwaysjump_basic_block(self, program, state: DivideBasicBlockPassState) -> bool:
        """
        B_empty(remove!) --alwaysjump--> C
                         {OTHERS} --*--> C

        There is one case when we can't remove B_empty:
            PRED:
                // do something
                s_cbranch_xxx D  // D != C
            B_empty:
                s_branch C

        B_empty is not removed here: its predecessors are redirected to C, which leaves B_empty
        without predecessors so that the dead-basic-block sub-pass can prune it in a later round
        (only if `prune_dead_basic_block` is enabled).

        Returns True if any edge was redirected.
        """
        assert self.prune_empty_alwaysjump_basic_block
        modified = False

        # NOTE: we do **not** include the first (entrypoint) basic-block
        for bb in state.basic_blocks[1:]:
            if not (len(bb.clauses) == 0 and bb.control_flow_at_exit == ControlFlowEnum.AlwaysJump):
                continue

            assert bb.successor_if_jump is not None
            assert bb.successor_if_fallthrough is None
            if bb.successor_if_jump is bb:  # `bb` dumbly jumps to itself...
                continue
            assert bb not in bb.predecessors

            # Now `bb` might be a target to remove
            # We do not directly remove `bb`, but make it a dead basic-block, and prune it later
            for pred in bb.predecessors.copy():  # Loop over a copy to prevent remove-while-iterate issue
                assert pred is not bb
                assert (pred.successor_if_jump is bb) ^ (pred.successor_if_fallthrough is bb)

                if pred.control_flow_at_exit == ControlFlowEnum.AlwaysJump:
                    # pred: s_branch bb  ==>  pred: s_branch C
                    assert pred.successor_if_jump is bb
                    assert pred.successor_if_fallthrough is None
                    assert pred.jump_instr.control_flow_enum == ControlFlowEnum.AlwaysJump
                    program.logger.debug(f"{self}: maybe prune empty basic-block ending with s_branch: {bb}")
                    pred.successor_if_jump = bb.successor_if_jump
                    # pred.successor_if_fallthrough = None  # not necessary as it's already None
                    bb.predecessors.remove(pred)
                    assert pred not in bb.successor_if_jump.predecessors
                    bb.successor_if_jump.predecessors.append(pred)
                    modified = True
                elif pred.control_flow_at_exit == ControlFlowEnum.Fallthrough:
                    # pred falls through into bb  ==>  pred: s_branch C
                    assert pred.successor_if_jump is None
                    assert pred.successor_if_fallthrough is bb
                    assert pred.jump_instr is None
                    program.logger.debug(f"{self}: maybe prune empty basic-block ending with s_branch: {bb}")
                    pred.successor_if_jump = bb.successor_if_jump
                    pred.successor_if_fallthrough = None
                    pred.jump_instr = BasicBlockJumpInstr.make_alwaysjump()
                    assert pred not in bb.successor_if_jump.predecessors
                    bb.successor_if_jump.predecessors.append(pred)
                    bb.predecessors.remove(pred)
                    modified = True
                elif pred.control_flow_at_exit == ControlFlowEnum.CondJump:
                    if pred.successor_if_jump is bb:
                        assert pred.successor_if_fallthrough is not bb
                        assert pred.jump_instr.control_flow_enum == ControlFlowEnum.CondJump
                        program.logger.debug(f"{self}: maybe prune empty basic-block ending with s_branch: {bb}")

                        # If we are conditionally jumping to the successor basic-block, eliminate the s_cbranch_xxx
                        # This is not only an optimization, but also necessary
                        #
                        # After this fix-up, no (successor) basic-block B will have two same precedent basic-blocks A
                        # (aka A.successor_if_jump is B and A.successor_if_fallthrough is B)
                        if pred.successor_if_fallthrough is bb.successor_if_jump:
                            assert pred in bb.successor_if_jump.predecessors
                            pred.successor_if_jump = None
                            # `pred.successor_if_fallthrough` is not touched (keep it `bb.successor_if_jump`)
                            pred.jump_instr = None
                        else:  # pred.successor_if_fallthrough is not bb.successor_if_jump:
                            pred.successor_if_jump = bb.successor_if_jump
                            # `pred.successor_if_fallthrough` is not changed
                            assert pred not in bb.successor_if_jump.predecessors
                            bb.successor_if_jump.predecessors.append(pred)
                        bb.predecessors.remove(pred)
                        modified = True
                    else:  # pred.successor_if_jump is not bb, aka pred.successor_if_fallthrough is bb
                        assert pred.successor_if_fallthrough is bb
                        if pred.successor_if_jump is bb.successor_if_jump:
                            # Both exits of `pred` end up at C: turn its s_cbranch_xxx into s_branch C
                            assert pred.successor_if_jump is bb.successor_if_jump
                            assert pred.successor_if_fallthrough is bb
                            assert pred in bb.successor_if_jump.predecessors
                            program.logger.debug(f"{self}: maybe prune empty basic-block ending with s_branch: {bb}")
                            # `pred.successor_if_jump` is not touched (keep it `bb.successor_if_jump`)
                            pred.successor_if_fallthrough = None
                            pred.jump_instr = BasicBlockJumpInstr.make_alwaysjump()
                            bb.predecessors.remove(pred)
                            modified = True
                        else:  # pred.successor_if_jump is not bb.successor_if_jump
                            # In this case, we could **not** remove `bb`!
                            program.logger.debug(f"{self}: cannot prune empty basic-block ending with s_branch: {bb}")
                            pass
                else:  # pred.control_flow_at_exit == ControlFlowEnum.Terminate
                    assert False, f"Why {pred.control_flow_at_exit}?"

        if modified:
            state._sanity_check()
        return modified

    def __prune_empty_terminate_basic_block(self, program, state: DivideBasicBlockPassState) -> bool:
        """
        B_empty(remove!) --terminate-->

        Predecessors that reach B_empty unconditionally are rewritten to terminate themselves
        (s_endpgm), leaving B_empty to be pruned later as a dead basic-block. A predecessor ending
        with s_cbranch_xxx is left untouched, and so is the fall-through predecessor while any
        other predecessor is not an s_branch.

        Returns True if any predecessor was rewritten.
        """
        assert self.prune_empty_terminate_basic_block
        modified = False

        # NOTE: we do **not** include the first (entrypoint) basic-block
        for bb in state.basic_blocks[1:]:
            if not (len(bb.clauses) == 0 and bb.control_flow_at_exit == ControlFlowEnum.Terminate):
                continue

            assert bb.successor_if_jump is None
            assert bb.successor_if_fallthrough is None
            assert bb not in bb.predecessors

            # Now `bb` might be a target to remove
            # We do not directly remove `bb`, but make it a dead basic-block, and prune it later
            for pred in bb.predecessors.copy():  # Loop over a copy to prevent remove-while-iterate issue
                assert pred is not bb
                assert (pred.successor_if_jump is bb) ^ (pred.successor_if_fallthrough is bb)

                if pred.control_flow_at_exit == ControlFlowEnum.AlwaysJump:
                    # pred: s_branch bb  ==>  pred: s_endpgm
                    assert pred.successor_if_jump is bb
                    assert pred.successor_if_fallthrough is None
                    assert pred.jump_instr.control_flow_enum == ControlFlowEnum.AlwaysJump
                    program.logger.debug(f"{self}: maybe prune empty basic-block ending with s_endpgm: {bb}")
                    pred.successor_if_jump = None
                    # pred.successor_if_fallthrough = None  # not necessary as it's already None
                    pred.jump_instr = BasicBlockJumpInstr.make_terminate()
                    bb.predecessors.remove(pred)
                    modified = True
                elif pred.control_flow_at_exit == ControlFlowEnum.Fallthrough:
                    assert pred.successor_if_jump is None
                    assert pred.successor_if_fallthrough is bb
                    assert pred.jump_instr is None
                    remaining_other_predecessors = [p for p in bb.predecessors if p is not pred]
                    # Other predecessors shall never end with fall-through (to `bb`) or Terminate
                    assert not any(p.control_flow_at_exit == ControlFlowEnum.Fallthrough
                                   for p in remaining_other_predecessors)
                    assert not any(p.control_flow_at_exit == ControlFlowEnum.Terminate
                                   for p in remaining_other_predecessors)

                    all_alwaysjump = all(p.control_flow_at_exit == ControlFlowEnum.AlwaysJump
                                         for p in remaining_other_predecessors)
                    if all_alwaysjump:
                        # pred falls through into bb  ==>  pred: s_endpgm
                        program.logger.debug(f"{self}: maybe prune empty basic-block ending with s_endpgm: {bb}")
                        # pred.successor_if_jump = None  # not necessary as it's already None
                        pred.successor_if_fallthrough = None
                        pred.jump_instr = BasicBlockJumpInstr.make_terminate()
                        bb.predecessors.remove(pred)
                        modified = True
                    else:
                        # In this case, we don't prune `bb`:
                        #   pred_1 --fallthrough--> bb
                        #   pred_2 --condjump--> bb
                        program.logger.debug(f"{self}: cannot prune empty basic-block ending with s_endpgm: {bb}")
                elif pred.control_flow_at_exit == ControlFlowEnum.CondJump:
                    # In this case, we could **not** remove `bb`!
                    program.logger.debug(f"{self}: cannot prune empty basic-block ending with s_endpgm: {bb}")
                    pass
                else:  # pred.control_flow_at_exit == ControlFlowEnum.Terminate
                    assert False, f"Why {pred.control_flow_at_exit}?"

        if modified:
            state._sanity_check()
        return modified

    def __merge_consecutive_fallthrough_basic_block(self, program, state: DivideBasicBlockPassState) -> bool:
        """
        A --fallthrough--> B(remove!)

        When B's only predecessor A falls through into B, B's clauses are appended to A and A takes
        over B's exit (successors and jump instruction).

        Returns True if any basic-block was merged.
        """
        assert self.merge_consecutive_fallthrough_basic_block
        modified = False

        # NOTE: we **include** the first (entrypoint) basic-block
        # Iterate backwards so that popping index `idx` does not shift the indices still to visit
        for idx in range(len(state.basic_blocks)-1, -1, -1):
            bb = state.basic_blocks[idx]
            if not (len(bb.predecessors) == 1 and
                    bb.predecessors[0].control_flow_at_exit == ControlFlowEnum.Fallthrough):
                continue

            # Now `bb` is a target to remove
            pred = bb.predecessors[0]
            assert pred is not bb
            assert pred.successor_if_jump is None
            assert pred.successor_if_fallthrough is bb

            program.logger.debug(f"{self}: merge consecutive fallthrough basic-block: {bb}")
            bb_popped = state.basic_blocks.pop(idx)
            assert bb_popped is bb
            modified = True

            # Append `bb` at the end of `pred`
            pred.clauses += bb.clauses
            pred.successor_if_jump = bb.successor_if_jump
            pred.successor_if_fallthrough = bb.successor_if_fallthrough
            pred.jump_instr = bb.jump_instr

            # Successors of `bb` now see `pred` instead of `bb` as their predecessor
            if bb.successor_if_jump is not None:
                assert pred not in bb.successor_if_jump.predecessors
                bb.successor_if_jump.predecessors.remove(bb)
                bb.successor_if_jump.predecessors.append(pred)

            if bb.successor_if_fallthrough is not None:
                assert pred not in bb.successor_if_fallthrough.predecessors
                bb.successor_if_fallthrough.predecessors.remove(bb)
                bb.successor_if_fallthrough.predecessors.append(pred)

        if modified:
            state._sanity_check()
        return modified

    def __promote_single_entry_basic_block_clauses(self, program, state: DivideBasicBlockPassState) -> bool:
        """
        A --fallthrough/alwaysjump--> B(promote!)

        When B has exactly one predecessor A that reaches it unconditionally, B's clauses are moved
        to the end of A. B itself stays in place (now empty), keeping its own exit.

        Returns True if any clauses were moved.
        """
        assert self.promote_single_entry_basic_block_clauses
        modified = False

        # NOTE: we do **not** include the first (entrypoint) basic-block
        # Clauses in the first (entrypoint) basic-block should never be "promoted" to its predecessor
        for idx in range(len(state.basic_blocks)-1, 0, -1):
            bb = state.basic_blocks[idx]
            if not (len(bb.predecessors) == 1 and len(bb.clauses) > 0):
                continue

            pred = bb.predecessors[0]
            if pred is bb:
                continue
            if pred.control_flow_at_exit == ControlFlowEnum.CondJump:
                continue

            assert pred.control_flow_at_exit in [ControlFlowEnum.AlwaysJump, ControlFlowEnum.Fallthrough]

            # Now `bb` is a target to promote
            assert pred is not bb
            assert (pred.successor_if_jump is bb) ^ (pred.successor_if_fallthrough is bb)
            assert (pred.successor_if_jump is None) ^ (pred.successor_if_fallthrough is None)

            assert len(bb.clauses) > 0
            program.logger.debug(f"{self}: promote single entry basic-block clauses: {bb}")
            pred.clauses += bb.clauses
            bb.clauses.clear()
            modified = True

        if modified:
            state._sanity_check()
        return modified

    def __rewrite_alwaysjump_to_next_basic_block(self, program, state: DivideBasicBlockPassState) -> bool:
        """
        A --alwaysjump--> B, and A is immediately before B:
            A --fallthrough--> B

        Returns True if any s_branch was rewritten.
        """
        assert self.rewrite_alwaysjump_to_next_basic_block
        modified = False

        # NOTE: `bb` starts at index 1 because it needs a previous basic-block;
        # `prev_bb` therefore covers every basic-block except the last one (including the entrypoint)
        for idx in range(len(state.basic_blocks)-1, 0, -1):
            bb = state.basic_blocks[idx]
            prev_bb = state.basic_blocks[idx-1]

            if prev_bb.control_flow_at_exit == ControlFlowEnum.AlwaysJump:
                assert prev_bb.successor_if_jump is not None
                assert prev_bb.successor_if_fallthrough is None
                if prev_bb.successor_if_jump is bb:
                    # Now `prev_bb` is a target to rewrite
                    program.logger.debug(f"{self}: rewrite s_branch to next basic-block: {prev_bb}")
                    modified = True

                    prev_bb.successor_if_jump = None
                    prev_bb.successor_if_fallthrough = bb
                    prev_bb.jump_instr = None

                    # The edge prev_bb -> bb already existed, so the predecessor list needs no update
                    assert prev_bb in bb.predecessors

        if modified:
            state._sanity_check()
        return modified

    def run(self, program) -> bool:
        """
        Apply the enabled sub-passes in a fixed order, round after round, until one full round makes
        no change.

        Returns True if any sub-pass modified the basic-block graph.
        """
        assert program.optimizer_state is not None
        assert program.optimizer_state.divide_basic_block is not None
        state = program.optimizer_state.divide_basic_block  # type: DivideBasicBlockPassState
        assert state is not None

        # Run a sanity check before optimization
        state._sanity_check()

        modified_any = False
        while True:
            curr_loop_modified = False
            if self.prune_dead_basic_block:
                curr_loop_modified |= self.__prune_dead_basic_block(program, state)

            if self.prune_empty_fallthrough_basic_block:
                curr_loop_modified |= self.__prune_empty_fallthrough_basic_block(program, state)
            if self.prune_empty_alwaysjump_basic_block:
                curr_loop_modified |= self.__prune_empty_alwaysjump_basic_block(program, state)
            if self.prune_empty_terminate_basic_block:
                curr_loop_modified |= self.__prune_empty_terminate_basic_block(program, state)

            if self.merge_consecutive_fallthrough_basic_block:
                curr_loop_modified |= self.__merge_consecutive_fallthrough_basic_block(program, state)
            if self.promote_single_entry_basic_block_clauses:
                curr_loop_modified |= self.__promote_single_entry_basic_block_clauses(program, state)
            if self.rewrite_alwaysjump_to_next_basic_block:
                curr_loop_modified |= self.__rewrite_alwaysjump_to_next_basic_block(program, state)

            modified_any |= curr_loop_modified
            if not curr_loop_modified:
                break

        # Run a sanity check after optimization (no matter modified or not)
        state._sanity_check()

        return modified_any
