import itertools
from collections import deque
from typing import Set, List, Dict, Union

from ..basic.register import Gpr, GprSet
from ..basic.instr import InstrCall, ExplicitWaitCall, ExplicitUsesCall
from .base_pass import BasePass, PassTag, OptimizerState
from .divide_basic_block_pass import BasicBlock, BasicBlockJumpInstr
from .annotate_clause_pass import AnnClause


class AnalyzeLiveVarPass(BasePass):
    """
    Live-variable analysis over the basic blocks stored in `optimizer_state.divide_basic_block`.

    For every basic block, `run` fills in:
    - `live_var_defs` / `live_var_uses`: Gprs defined / used-before-defined inside the block
    - `live_var_in` / `live_var_out`: Gprs live at the entry / exit of the block
    - `live_var_gpr_life_span`: for each base Gpr, one bitmap per offset marking the points
      inside the block at which that register is alive

    `reset` sets all of these attributes back to None.
    """
    def __init__(self, /, priority: int = PassTag.AnalyzeLiveVar.value):
        super().__init__(priority)

    def required_tags(self) -> Set[PassTag]:
        return {PassTag.OptimizeBasicBlock}

    def generated_tags(self) -> Set[PassTag]:
        return {PassTag.AnalyzeLiveVar}

    def invalidated_tags(self) -> Set[PassTag]:
        return {PassTag.EliminateDeadCode}

    def reset(self, program):
        """
        Clear every result this pass stores on the basic blocks.
        Does nothing when the basic-block state does not exist.
        """
        optimizer_state = program.optimizer_state  # type: OptimizerState
        state = optimizer_state.divide_basic_block
        if state is None:
            return
        for bb in state.basic_blocks:
            assert isinstance(bb, BasicBlock)
            bb.live_var_defs = None
            bb.live_var_uses = None
            bb.live_var_in = None
            bb.live_var_out = None
            # `run` also stores the life-span table on the block; clear it as well
            # so that no stale result survives a reset
            bb.live_var_gpr_life_span = None

    def run(self, program) -> bool:
        """
        Run the whole analysis on all basic blocks of `program`.

        :return: always True, because the results of every basic block are recomputed
        """
        optimizer_state = program.optimizer_state  # type: OptimizerState
        state = optimizer_state.divide_basic_block
        assert state is not None

        # Compute defs and uses for each individual basic-block
        for bb in state.basic_blocks:
            self.__compute_basic_block_var_defs_uses(bb)

        # Compute in and out for all basic-blocks
        self.__compute_all_basic_blocks_var_in_out(state.basic_blocks)

        # Compute life span for each single Gpr
        # NOTE: different views of the same Gpr may have different life span
        for bb in state.basic_blocks:
            self.__compute_basic_block_var_life_span(bb)

        # This pass always updates live_var_{uses,defs,in,out} for all basic-blocks
        return True

    @staticmethod
    def __clause_gpr_uses_defs(clause):
        """
        Return the Gprs a single clause (or jump instruction) reads and writes.

        :return: a `(uses, defs)` tuple of GprSet, or None for an ExplicitWaitCall
                 (which takes no part in the analysis)
        """
        if isinstance(clause, ExplicitWaitCall):
            return None

        if isinstance(clause, ExplicitUsesCall):
            # An explicit-uses clause only reads registers; it never defines any
            return GprSet(*clause.uses), GprSet()

        assert isinstance(clause, (InstrCall, BasicBlockJumpInstr))
        return clause.gpr_uses_to_gprset(), clause.gpr_defs_to_gprset()

    def __compute_basic_block_var_defs_uses(self, bb: BasicBlock):
        """
        Within a basic-block:
        - If a Gpr is used before defined, it's added to live_var_uses
        - If a Gpr is used after defined, silently go on (It's **not** added to live_var_uses)
        - If a Gpr defined, it's added to live_var_defs, no matter whether it's used or not later
        """
        bb.live_var_defs = GprSet()
        bb.live_var_uses = GprSet()

        # Loop from the beginning to the end
        # Don't forget the last `jump_instr`, if exists
        for clause in itertools.chain(bb.clauses, [bb.jump_instr] if bb.jump_instr is not None else []):
            uses_defs = self.__clause_gpr_uses_defs(clause)
            if uses_defs is None:
                continue
            tmp_uses, tmp_defs = uses_defs

            # Only the part that was not defined earlier in this block counts as a "use"
            tmp_uses.difference_update(bb.live_var_defs)
            bb.live_var_uses.union_update(tmp_uses)

            # No matter these Gprs are previously used or not, we add them into defs
            bb.live_var_defs.union_update(tmp_defs)

    # noinspection PyMethodMayBeStatic
    def __compute_all_basic_blocks_var_in_out(self, all_basic_block_list: List[BasicBlock]):
        """
        Iterate the data-flow equations with a work queue until nothing changes:

        B.in = B.uses UNION (B.out DIFF B.defs)
        B.out = {UNION S.in}  // S is a successor of B

        IN[Exit] = {}  // This doesn't matter
        """
        # Every block starts with empty in/out sets and is visited at least once.
        # A deque gives O(1) removal from the front of the work queue.
        update_queue = deque()
        for bb in all_basic_block_list:
            bb.live_var_in = GprSet()
            bb.live_var_out = GprSet()
            update_queue.append(bb)

        while update_queue:
            bb = update_queue.popleft()

            # bb.live_var_out = bb.successor_if_jump.live_var_in UNION
            #                   bb.successor_if_fallthrough.live_var_in
            bb.live_var_out = GprSet()
            if bb.successor_if_jump is not None:  # `bb.successor_if_jump` might be `bb`
                bb.live_var_out.union_update(bb.successor_if_jump.live_var_in)
            if bb.successor_if_fallthrough is not None:
                bb.live_var_out.union_update(bb.successor_if_fallthrough.live_var_in)

            # `bb.live_var_out` was updated, now update `bb.live_var_in` accordingly
            # bb.live_var_in = bb.live_var_uses UNION (bb.live_var_out DIFF bb.live_var_defs)
            # (the chained calls rely on the update methods returning the set they modified)
            old_live_var_in = bb.live_var_in
            bb.live_var_in = bb.live_var_out.clone() \
                               .difference_update(bb.live_var_defs) \
                               .union_update(bb.live_var_uses)

            # If `bb.live_var_in` is indeed updated, `live_var_out` of each predecessor
            # may change as well, so the predecessors are visited again
            if bb.live_var_in != old_live_var_in:
                update_queue.extend(bb.predecessors)

    def __compute_basic_block_var_life_span(self, bb: BasicBlock):
        """
        Build `bb.live_var_gpr_life_span` from the clauses of `bb` and its live_var_in / live_var_out.

        Must run after live_var_in / live_var_out have been computed for `bb`.
        """
        # We represent life span of each (base_gpr,offset) by a non-negative integer
        # Here, each (base_gpr,offset) tuple has count == 1
        #
        # Let N be the number of clauses of the basic block, plus one if it has a jump instruction.
        # We use 2*N+2 bits to represent the life span:
        #   bit 0: live_var_in (aka defined by prior basic blocks)
        #   bit 1: used    by the 1st instruction
        #   bit 2: defined by the 1st instruction
        #   bit 3: used    by the 2nd instruction
        #   bit 4: defined by the 2nd instruction
        #   ......
        #   bit 2*N-1: used by the N-th instruction
        #   bit 2*N:   defined by the N-th instruction
        #   bit 2*N+1: live_var_out (aka used by following basic blocks)
        #
        # For live_var_gpr_life_span:
        #   - the key is base_gpr (not Gpr views!)
        #   - the value is a list of bitmap regarding all offsets (len == base_gpr.count)
        #
        live_var_gpr_life_span = dict()  # type: Dict[Gpr, List[int]]

        # var_last_def_bit_at is the last bit where a (base_gpr,offset) is defined
        # The keys are (base_gpr, offset) tuples, the values must be in {0,2,4,...,2*N}
        var_last_def_bit_at = dict()  # type: Dict[tuple, int]

        def mark_active(base_gpr: Gpr, offset: int, first_bit_at: int, last_bit_at: int):
            # Mark every bit in the inclusive interval [first_bit_at,last_bit_at] active
            assert 0 <= first_bit_at <= last_bit_at
            assert not base_gpr.is_view
            assert 0 <= offset < base_gpr.count
            if base_gpr not in live_var_gpr_life_span:
                live_var_gpr_life_span[base_gpr] = [0] * base_gpr.count
            # One mask covers the whole interval, so no per-bit loop is needed
            span_mask = ((1 << (last_bit_at - first_bit_at + 1)) - 1) << first_bit_at
            live_var_gpr_life_span[base_gpr][offset] |= span_mask

        def process_used(base_gpr: Gpr, offset: int, bit_at: int):
            assert bit_at >= 0 and bit_at % 2 == 1
            assert not base_gpr.is_view
            assert 0 <= offset < base_gpr.count

            # Get the latest define bit_at of this (base_gpr, offset), which must exist
            # Every bit in interval [last_def_bit_at,bit_at] is marked active
            assert (base_gpr, offset) in var_last_def_bit_at
            last_def_bit_at = var_last_def_bit_at[(base_gpr, offset)]
            assert last_def_bit_at % 2 == 0

            mark_active(base_gpr, offset, last_def_bit_at, bit_at)

        def process_defined(base_gpr: Gpr, offset: int, bit_at: int):
            assert bit_at >= 0 and bit_at % 2 == 0
            assert not base_gpr.is_view
            assert 0 <= offset < base_gpr.count

            # We mark this (base_gpr, offset) active at this point, no matter whether it will be used later
            mark_active(base_gpr, offset, bit_at, bit_at)

            # We update the latest bit_at of this (base_gpr, offset), regardless of its previous value
            var_last_def_bit_at[(base_gpr, offset)] = bit_at

        clauses_and_jump_instr = bb.clauses.copy()  # type: List[Union[InstrCall, ExplicitWaitCall, ExplicitUsesCall, BasicBlockJumpInstr]]  # shallow copy
        if bb.jump_instr is not None:
            clauses_and_jump_instr.append(bb.jump_instr)

        # Process bit 0: live_var_in (aka defined by prior basic blocks)
        for base_gpr in bb.live_var_in.base_gprs:
            offset_list = bb.live_var_in.get_offset_list(base_gpr)
            assert offset_list
            for offset in offset_list:
                process_defined(base_gpr, offset=offset, bit_at=0)

        # Process all instructions (bit 1,2,3,4,...,2*N-1,2*N)
        for idx, clause in enumerate(clauses_and_jump_instr):  # idx in {0,1,...,N-1}
            # Use the same per-clause uses/defs as the defs/uses computation, so the
            # life spans agree with the live_var_in / live_var_out sets derived from it
            uses_defs = self.__clause_gpr_uses_defs(clause)
            if uses_defs is None:
                # The clause still owns bits 2*idx+1 and 2*idx+2, it just touches no Gpr
                continue
            gprset_uses, gprset_defs = uses_defs

            for base_gpr in gprset_uses.base_gprs:
                for offset in gprset_uses.get_offset_list(base_gpr):
                    process_used(base_gpr, offset=offset, bit_at=2*idx+1)

            for base_gpr in gprset_defs.base_gprs:
                for offset in gprset_defs.get_offset_list(base_gpr):
                    process_defined(base_gpr, offset=offset, bit_at=2*idx+2)

        # Process bit 2*N+1: live_var_out (aka used by following basic blocks)
        live_out_bit_at = 2*len(clauses_and_jump_instr)+1
        for base_gpr in bb.live_var_out.base_gprs:
            offset_list = bb.live_var_out.get_offset_list(base_gpr)
            assert offset_list
            for offset in offset_list:
                process_used(base_gpr, offset=offset, bit_at=live_out_bit_at)

        # Assign the result to basic block
        bb.live_var_gpr_life_span = live_var_gpr_life_span
