from typing import Set
from ..basic.instr import InstrCall, ExplicitWaitCall, ExplicitUsesCall
from .base_pass import BasePass, PassTag, OptimizerState
from .divide_basic_block_pass import BasicBlock


class EliminateDeadCodePass(BasePass):
    """
    Dead code elimination (DCE) pass.

    Each basic block is scanned backwards, starting from its `live_var_out`
    set. An `InstrCall` is dropped when it has no memory token, defines at
    least one Gpr, and none of the Gprs it defines is live at that point.
    """
    def __init__(self, /, priority: int = PassTag.EliminateDeadCode.value):
        super().__init__(priority)

    def required_tags(self) -> Set[PassTag]:
        """Tags that must be present before this pass runs."""
        return {PassTag.AnalyzeLiveVar}

    def generated_tags(self) -> Set[PassTag]:
        """Tags produced by this pass."""
        return {PassTag.EliminateDeadCode}

    def invalidated_tags(self) -> Set[PassTag]:
        """Tags reported as invalidated by this pass."""
        return {
            PassTag.AnalyzeLiveVar,
            PassTag.OptimizeBasicBlock,
            PassTag.AnnotateClause,
        }

    def reset(self, program):
        """Nothing to reset: this pass keeps no state of its own."""
        return None

    def run(self, program) -> bool:
        """
        Apply dead code elimination to every basic block of `program`.
        Returns True if at least one clause was removed from any block.
        """
        opt_state: OptimizerState = program.optimizer_state
        bb_state = opt_state.divide_basic_block
        assert bb_state is not None

        # Every block must be visited (no short-circuiting), since the helper
        # rewrites `bb.clauses` as a side effect.
        changed = False
        for block in bb_state.basic_blocks:
            if self.__dead_code_elimination(program, block):
                changed = True
        return changed

    # noinspection PyMethodMayBeStatic
    def __dead_code_elimination(self, program, bb: BasicBlock) -> bool:
        """
        Drop the unneeded InstrCall clauses of a single basic-block.
        Returns True if `bb.clauses` was actually shortened.
        """
        # Gprs that are live *after* the clause currently being examined
        live = bb.live_var_out.clone()
        dead_positions: Set[int] = set()

        # The trailing jump (if any) reads Gprs as well; it is asserted to
        # neither hold nor define any
        terminator = bb.jump_instr
        if terminator is not None:
            live.union_update(terminator.gpr_uses_to_gprset())
            assert len(terminator.gpr_holds) == 0
            assert len(terminator.gpr_defs) == 0

        # Walk the clauses backwards, from the last one to the first one
        snapshot = list(bb.clauses)
        for pos in range(len(snapshot) - 1, -1, -1):
            clause = snapshot[pos]

            # ExplicitWaitCall is always kept and does not affect liveness
            if isinstance(clause, ExplicitWaitCall):
                continue

            # ExplicitUsesCall is always kept; whatever it names becomes live
            # TODO: maybe there could be a hint to remove them too?
            if isinstance(clause, ExplicitUsesCall):
                live.union_update(*clause.uses)
                continue

            assert isinstance(clause, InstrCall)

            # Memory-related instructions (those with a `mem_token`) are
            # conservatively kept without looking at their defs
            if clause.mem_token is None:
                if len(clause.gpr_defs) == 0:
                    # No Gpr is defined, so supposedly it exists for its side effects
                    program.logger.debug(f"{self}: assumed side effects: `{clause.instr_name}` at {clause.srcloc}")
                else:
                    defined = clause.gpr_defs_to_gprset()
                    if not curr_is_live(live, defined):
                        # None of the defined Gprs is needed later: dead code
                        dead_positions.add(pos)
                        program.logger.debug(f"{self}: eliminated: {clause}")
                        continue
                    # The defs are satisfied here, so they are not live above
                    live.difference_update(defined)

            # The clause survives, hence everything it reads is live above it
            live.union_update(clause.gpr_uses_to_gprset())

        if not dead_positions:
            return False

        # Rebuild the clause list without the positions marked as dead
        bb.clauses = [kept for pos, kept in enumerate(bb.clauses) if pos not in dead_positions]
        return True


def curr_is_live(live, defined):
    """Tell whether any Gpr in `defined` is currently in the `live` set."""
    return live.is_intersected(defined)
