from typing import Set, TextIO, Optional
from ..basic.utility import IndentedWriter
from ..basic.exception import check
from .base_pass import BasePass, PassTag
from .divide_basic_block_pass import DivideBasicBlockPassState, OptimizerState
import sys


class PrintBasicBlockPass(BasePass):
    """Debug pass that writes every basic block and a global optimizer summary.

    For each basic block it prints the block label, predecessor names, the
    liveness / pending-memory analysis results (or a "not run" marker when an
    analysis result is None), the clauses, and the successor / jump-instruction
    information. It then prints register interference, register allocation by
    coloring, and the GPR counts, each only when available.

    Requires ``PassTag.DivideBasicBlock``; generates ``PassTag.PrintBasicBlock``;
    invalidates nothing.
    """

    def __init__(self, /, file: TextIO = sys.stdout, indent: int = 0,
                 priority: int = PassTag.PrintBasicBlock.value):
        """Create the pass.

        :param file: Output stream, wrapped in an ``IndentedWriter``; must not be None.
        :param indent: Initial indentation passed to ``IndentedWriter``; must be >= 0.
        :param priority: Pass priority forwarded to ``BasePass``.
        """
        # By default, this pass should run in the very end
        super().__init__(priority)

        check(file is not None)
        check(indent >= 0)
        self.wr = IndentedWriter(file, indent)

    def required_tags(self) -> Set[PassTag]:
        return {PassTag.DivideBasicBlock}

    def generated_tags(self) -> Set[PassTag]:
        return {PassTag.PrintBasicBlock}

    def invalidated_tags(self) -> Set[PassTag]:
        return set()

    def reset(self, program):
        # This pass has nothing to reset
        pass

    def run(self, program) -> bool:
        """Print all basic blocks of ``program`` followed by a summary section.

        :param program: Program whose ``optimizer_state.divide_basic_block`` has
            been populated by the DivideBasicBlock pass.
        :return: Always ``False`` (presumably "program not changed" per the
            ``BasePass`` contract -- confirm there).
        """
        assert program.optimizer_state is not None
        bb_state = program.optimizer_state.divide_basic_block  # type: DivideBasicBlockPassState
        assert bb_state is not None

        # Run a sanity check before printing
        bb_state._sanity_check()

        self.wr()  # write an empty line at first
        for bb in bb_state.basic_blocks:
            self._print_basic_block(bb)
            self.wr()

        # All basic-blocks have been printed...
        # Let's print some global information
        self._print_summary(program.optimizer_state)

        self.wr()  # write an empty line at last

        return False

    def _print_basic_block(self, bb) -> None:
        """Write one basic block: label, analysis results, clauses and exit info."""
        not_run = "-  // not run"

        self.wr(f'{bb.name}:')
        with self.wr.indent():
            self.wr('//')
            self.wr(f'// bb_predecessors: {[pred.name for pred in bb.predecessors]}')
            self.wr('//')
            self.wr(f'// live_var_in: {bb.live_var_in if bb.live_var_in is not None else not_run}')
            self.wr(f'// live_var_uses: {bb.live_var_uses if bb.live_var_uses is not None else not_run}')
            if bb.live_var_gpr_life_span is not None:
                self._print_live_var_gpr_life_span(bb)
            self.wr('//')
            self.wr(f'// pending_mem_in: {bb.pending_mem_in if bb.pending_mem_in is not None else not_run}')
            self.wr('//')
            # Prefer the annotated clauses when they have been computed.
            clauses = bb.annotate_clauses if bb.annotate_clauses is not None else bb.clauses
            for clause in clauses:
                self.wr(repr(clause))
            self.wr('//')
            self.wr(f'// bb_successor_if_jump: {repr(bb.successor_if_jump.name) if bb.successor_if_jump is not None else "-"}')  # noqa E501: line too long
            self.wr(f'// bb_successor_if_fallthrough: {repr(bb.successor_if_fallthrough.name) if bb.successor_if_fallthrough is not None else "-"}')  # noqa E501: line too long
            self.wr(f'// bb_jump_instr: {bb.jump_instr.instr_name if bb.jump_instr is not None else "-"}  // {bb.control_flow_at_exit.name}')  # noqa E501: line too long
            self.wr('//')
            self.wr(f'// live_var_defs: {bb.live_var_defs if bb.live_var_defs is not None else not_run}')
            self.wr(f'// live_var_out: {bb.live_var_out if bb.live_var_out is not None else not_run}')
            self.wr('//')
            self.wr(f'// pending_mem_out: {bb.pending_mem_out if bb.pending_mem_out is not None else not_run}')  # noqa E501: line too long
            self.wr('//')

    def _print_live_var_gpr_life_span(self, bb) -> None:
        """Write per-base-GPR life-span bitmaps of ``bb`` (caller checks they are not None)."""
        life_span = bb.live_var_gpr_life_span

        # The jump instruction, when present, counts as one extra instruction.
        instr_count = len(bb.clauses)
        if bb.jump_instr is not None:
            instr_count += 1

        self.wr("//")
        self.wr("// live_var_gpr_life_span:")
        self.wr(f"//   /* number of instructions: {instr_count} */")
        self.wr(f"//   /* number of base Spec: {sum(1 for x in life_span if x.rtype.is_special())} */")
        self.wr(f"//   /* number of base SGpr: {sum(1 for x in life_span if x.rtype.is_sgpr())} */")
        self.wr(f"//   /* number of base VGpr: {sum(1 for x in life_span if x.rtype.is_vgpr())} */")
        for base_gpr, bitmap_list in sorted(life_span.items(), key=repr):
            self.wr(f"//   {base_gpr}:")
            for idx, bitmap in enumerate(bitmap_list):
                self.wr(f"//     [{idx:-2}] = {self._bitmap_to_string(bitmap, instr_count)}")
        self.wr("//")

    @staticmethod
    def _bitmap_to_string(bitmap: int, instr_count: int) -> str:
        """Render the low ``2 * instr_count + 2`` bits of ``bitmap``, LSB first.

        Each set bit becomes ``'x'`` and each clear bit ``'_'``; a space is
        emitted after every even-indexed bit. Asserts that no bit above that
        width is set.
        """
        width = 2 * instr_count + 2
        chars = []
        for idx in range(width):
            chars.append('x' if (bitmap >> idx) & 1 else '_')
            if idx % 2 == 0:
                chars.append(' ')
        assert bitmap >> width == 0
        return ''.join(chars)

    def _print_summary(self, opt_state) -> None:
        """Write the global summary: interference, allocation by coloring, GPR counts."""
        self.wr("//" + "=" * 32)
        self.wr("//" + "SUMMARY".center(32))
        self.wr("//" + "=" * 32)

        # Print register interference if computed
        self._print_interference("VGpr", opt_state.register_interference_vgpr)
        self._print_interference("SGpr", opt_state.register_interference_sgpr)

        # Print register allocation if computed. Each section is guarded by its
        # own allocation result (the SGPR one used to be guarded by the SGPR
        # interference by mistake, crashing when allocation had not run).
        self._print_allocation("VGPR", opt_state.register_allocation_vgpr_by_color)
        self._print_allocation("SGPR", opt_state.register_allocation_sgpr_by_color)

        if opt_state.register_allocation_vgpr_count is not None:
            self.wr(f"// VGPR Count: {opt_state.register_allocation_vgpr_count}")

        if opt_state.register_allocation_sgpr_count is not None:
            self.wr(f"// SGPR Count: {opt_state.register_allocation_sgpr_count}")

    def _print_interference(self, label: str, interference) -> None:
        """Write one line per base GPR (keys sorted by ``repr``); skip if falsy."""
        if not interference:
            return
        self.wr("//")
        self.wr(f"// {label} interference:")
        for base_gpr in sorted(interference.keys(), key=repr):
            self.wr(f"//   {interference[base_gpr]}")
        self.wr("//")

    def _print_allocation(self, label: str, by_color) -> None:
        """Write each color class with the max ``count`` of its GPRs; skip if falsy."""
        if not by_color:
            return
        self.wr("//")
        self.wr(f"// {label} allocation by coloring:")
        for color, gpr_list in enumerate(by_color):
            # default=0 keeps an empty color class from raising ValueError.
            max_count = max((gpr.count for gpr in gpr_list), default=0)
            self.wr(f"//   Color #{color} (count={max_count}): {gpr_list}")
        self.wr("//")
