from typing import Set, Optional, Union
from ..basic.const import _INSTR_STR_WIDTH
from ..basic.instr import InstrCall, ExplicitWaitCall, ExplicitUsesCall, Waitcnt
from .base_pass import BasePass, PassTag, OptimizerState
from .divide_basic_block_pass import BasicBlock, DivideBasicBlockPassState

class AnnClause:
    """Pair an original clause with annotations attached by later passes.

    Rendering (``__repr__``) and emission (``generate``) both delegate to the
    wrapped clause, preceded by an extra waitcnt line when one has been set.
    """

    def __init__(self, clause: Union[InstrCall, ExplicitWaitCall, ExplicitUsesCall]):
        # The wrapped clause; stored as given.
        self.clause = clause  # type: Union[InstrCall, ExplicitWaitCall, ExplicitUsesCall]

        # For InsertWaitcntPass: stays None until a waitcnt is attached.
        self.insert_waitcnt = None  # type: Optional[Waitcnt]

    def __repr__(self):
        """Return the clause repr, with a leading waitcnt line if one is set."""
        waitcnt = self.insert_waitcnt
        if waitcnt is None:
            return repr(self.clause)
        header = f"/*auto waitcnt*/ {waitcnt}"
        return header + "\n" + repr(self.clause)

    def generate(self, program, wr):
        """Emit the optional waitcnt line through ``wr``, then the clause.

        ``program`` and ``wr`` are forwarded unchanged to the wrapped clause's
        own ``generate``.
        """
        waitcnt = self.insert_waitcnt
        if waitcnt is not None:
            # Pad the mnemonic column to the shared instruction width.
            mnemonic = "/*auto*/ s_waitcnt ".ljust(_INSTR_STR_WIDTH)
            wr(mnemonic + waitcnt.waitcnt_str())

        self.clause.generate(program, wr)


class AnnotateClausePass(BasePass):
    """
    Wrap every clause of every basic block in an AnnClause, without filtering.
    NOTE:
        The priority is expected to schedule this pass after EliminateDeadCodePass.
    """
    def __init__(self, /, priority: int = PassTag.AnnotateClause.value):
        super().__init__(priority)

    def required_tags(self) -> Set[PassTag]:
        """Tags required before this pass; run() reads the divide_basic_block state."""
        return {PassTag.DivideBasicBlock}

    def generated_tags(self) -> Set[PassTag]:
        """Tags this pass reports as generated."""
        return {PassTag.AnnotateClause}

    def invalidated_tags(self) -> Set[PassTag]:
        """Tags this pass invalidates: always an empty set."""
        return set()

    def reset(self, program):
        """Set ``annotate_clauses`` to None on each basic block.

        Does nothing when the divide_basic_block state is None.
        """
        divide_state = program.optimizer_state.divide_basic_block
        if divide_state is not None:
            for block in divide_state.basic_blocks:
                assert isinstance(block, BasicBlock)
                block.annotate_clauses = None

    def run(self, program) -> bool:
        """Assign each basic block a fresh list of AnnClause wrappers.

        The list holds one AnnClause per entry of the block's ``clauses``, in
        the same order. Always returns True.
        """
        divide_state = program.optimizer_state.divide_basic_block  # type: DivideBasicBlockPassState
        assert divide_state is not None

        for block in divide_state.basic_blocks:
            assert isinstance(block, BasicBlock)
            block.annotate_clauses = [AnnClause(item) for item in block.clauses]

        return True
