from typing import Dict, Set, Tuple, List, DefaultDict, Optional
from collections import defaultdict
from ..basic.register import Gpr, GprType
from .base_pass import BasePass, PassTag
from .divide_basic_block_pass import BasicBlock


class OneRegisterInterference:
    """Interference (conflict) information of one base Gpr against other base Gprs.

    Representation:
        ``self.conflicts`` maps each conflicting base Gpr to a set of integer offsets.
        ComputeRegisterInterferencePass (in this module) records, for every pair of
        component registers that are live at the same time, the offset
        ``offset_in_conflict_gpr - offset_in_target_gpr``. Hence the two Gprs collide when

            index(conflict_base_gpr) == index(self.target_base_gpr) - offset

        or, equivalently, ``index(self.target_base_gpr) == index(conflict_base_gpr) + offset``.

        NOTE(review): confirm the register allocator consuming ``conflicts`` uses this
        same sign convention.
    """

    def __init__(self, target_base_gpr: Gpr):
        assert not target_base_gpr.is_view
        self.target_base_gpr = target_base_gpr  # type: Gpr

        # Forbidden relative placements, grouped by the conflicting base Gpr
        # (sign convention described in the class docstring).
        self.conflicts = defaultdict(set)  # type: DefaultDict[Gpr, Set[int]]

    def __repr__(self):
        """Render as ``<target> <-> {...}``.

        A conflicting Gpr is printed by name alone when its offsets are exactly the set
        produced by two fully overlapping Gprs. Otherwise it is printed as
        ``name@{offsets}``.
        """
        target_count = self.target_base_gpr.count
        str_conflicts = []  # type: List[str]
        for conflict_base_gpr, conflict_list in sorted(self.conflicts.items()):
            assert not conflict_base_gpr.is_view

            # Every pair (offset_in_conflict, offset_in_target) of two fully overlapping Gprs
            # yields offset_in_conflict - offset_in_target, i.e. the range [1 - t, c - 1].
            # Comparing sets, not just sizes, matters: pre-indexed propagation can add
            # offsets outside this range.
            full_overlap_offsets = set(range(1 - target_count, conflict_base_gpr.count))

            if conflict_list == full_overlap_offsets:
                # self.target_base_gpr exactly conflicts with conflict_base_gpr
                # (not even a single 4-byte Gpr can be overlapped)
                str_conflicts.append(repr(conflict_base_gpr))
            else:
                # Less: at least a single 4-byte Gpr can be overlapped.
                # More / different: pre-indexed Gprs may add extra forbidden offsets.
                str_conflicts.append(f"{repr(conflict_base_gpr)}@{{{','.join(repr(x) for x in sorted(conflict_list))}}}")

        return repr(self.target_base_gpr) + " <-> {" + ", ".join(sorted(str_conflicts)) + "}"


class ComputeRegisterInterferencePass(BasePass):
    """Build VGPR/SGPR interference graphs from per-basic-block Gpr life spans.

    Results stored on ``program.optimizer_state``:
      * ``register_interference_vgpr`` / ``register_interference_sgpr``:
        ``Dict[Gpr, OneRegisterInterference]`` keyed by base Gpr.
      * ``register_interference_perfect_elimination_ordering_vgpr`` / ``..._sgpr``:
        an elimination ordering of each graph, computed by maximum cardinality search.
    """

    def __init__(self, /, priority: int = PassTag.ComputeRegisterInterference.value):
        super().__init__(priority)

    def required_tags(self) -> Set[PassTag]:
        return {PassTag.AnnotateClause}

    def generated_tags(self) -> Set[PassTag]:
        return {PassTag.ComputeRegisterInterference}

    def invalidated_tags(self) -> Set[PassTag]:
        return set()

    def reset(self, program):
        """Clear the results this pass stores on ``program.optimizer_state``.

        Does nothing when there is no optimizer state.
        """
        state = program.optimizer_state
        if state is None:
            return

        state.register_interference_vgpr = None
        state.register_interference_sgpr = None
        state.register_interference_perfect_elimination_ordering_vgpr = None
        state.register_interference_perfect_elimination_ordering_sgpr = None

    @staticmethod
    def __is_perfect_elimination_ordering(interference: Dict[Gpr, OneRegisterInterference],
                                          ordering: List[Gpr]) -> bool:
        """Return True iff ``ordering`` is a perfect elimination ordering of the graph.

        For each vertex, its neighbours that come later in ``ordering`` must form a clique.
        It suffices to check that the earliest such neighbour is adjacent to all the others.
        Neighbours that do not appear in ``ordering`` are ignored.
        See: https://www.dazhuanlan.com/2019/11/10/5dc7f1f735a82/
        """
        position = {gpr: idx for idx, gpr in enumerate(ordering)}  # type: Dict[Gpr, int]
        for idx, base_gpr in enumerate(ordering):
            later_neighbors = sorted(
                (n for n in interference[base_gpr].conflicts if position.get(n, -1) > idx),
                key=position.__getitem__,
            )
            if len(later_neighbors) > 1:
                first, rest = later_neighbors[0], later_neighbors[1:]
                if any(first not in interference[x].conflicts for x in rest):
                    return False
        return True

    @staticmethod
    def __compute_perfect_elimination_ordering(interference: Dict[Gpr, OneRegisterInterference]) -> Optional[List[Gpr]]:
        """Compute an elimination ordering of the interference graph by maximum cardinality search.

        MCS repeatedly visits the unvisited vertex with the most already-visited neighbours.
        Ties go to the first such vertex in ``interference`` iteration order. The reverse of
        the visit order is a perfect elimination ordering iff the graph is chordal.

        Currently the ordering is returned even for non-chordal graphs (see TODO), so this
        never returns None.
        """
        labels = dict()  # type: Dict[Gpr, int]
        for base_gpr in interference:
            assert base_gpr not in labels
            labels[base_gpr] = 0

        visit_order = []  # type: List[Gpr]
        while labels:
            # max() returns the first maximal key in iteration order, which matches picking
            # the first Gpr whose label equals the maximum label.
            picked = max(labels, key=labels.get)
            visit_order.append(picked)
            labels.pop(picked)

            # Label increments commute, so neighbour iteration order is irrelevant.
            for neighbor in interference[picked].conflicts:
                assert not neighbor.is_view
                if neighbor in labels:
                    labels[neighbor] += 1

        # The elimination ordering is the reverse of the visit order.
        perfect_elimination_ordering = visit_order[::-1]
        assert len(perfect_elimination_ordering) == len(interference)

        if ComputeRegisterInterferencePass.__is_perfect_elimination_ordering(interference, perfect_elimination_ordering):
            return perfect_elimination_ordering
        # TODO: return None for non-chordal graphs; for now the ordering is returned anyway
        return perfect_elimination_ordering

    def run(self, program) -> bool:
        """Compute VGPR/SGPR interference graphs and store them on ``program.optimizer_state``.

        Requires ``divide_basic_block`` results whose basic blocks have
        ``live_var_gpr_life_span`` filled in. Always returns True.
        """
        # Check that basic blocks have been divided
        assert program.optimizer_state.divide_basic_block is not None
        all_basic_block_list = program.optimizer_state.divide_basic_block.basic_blocks  # type: List[BasicBlock]

        # Check that Gpr life span has been computed
        for bb in all_basic_block_list:
            assert bb.live_var_gpr_life_span is not None
            for base_gpr, bitmap_list in bb.live_var_gpr_life_span.items():  # type: (Gpr, List[int])
                assert isinstance(base_gpr, Gpr)
                assert isinstance(bitmap_list, list)
                assert not base_gpr.is_view
                assert len(bitmap_list) == base_gpr.count

        register_interference_vgpr = dict()  # type: Dict[Gpr, OneRegisterInterference]
        register_interference_sgpr = dict()  # type: Dict[Gpr, OneRegisterInterference]

        def compute_interference_rtype(result: Dict[Gpr, OneRegisterInterference], rtype: GprType):
            # --- 1) Liveness: (base_gpr, offset) pairs live at the same point conflict ---
            for bb in all_basic_block_list:
                clauses_and_jump_instr_count = len(bb.clauses)
                if bb.jump_instr is not None:
                    clauses_and_jump_instr_count += 1

                # NOTE(review): assumes life-span bitmaps hold 2 bits per clause/jump plus
                # 2 extra bits — confirm against the pass that builds live_var_gpr_life_span.
                for bit_at in range(0, 2*clauses_and_jump_instr_count+2):
                    bit_mask = 1 << bit_at
                    conflict_set = set()  # type: Set[Tuple[Gpr, int]]  # (base_gpr, offset)
                    for base_gpr, bitmap_list in bb.live_var_gpr_life_span.items():  # type: (Gpr, List[int])
                        assert not base_gpr.is_view
                        if base_gpr.rtype != rtype:
                            continue
                        assert len(bitmap_list) == base_gpr.count
                        for offset, bitmap in enumerate(bitmap_list):
                            if bitmap & bit_mask:
                                # This (base_gpr, offset) is active at this point (bit_at)
                                conflict_set.add((base_gpr, offset))

                    # Everything in conflict_set is live here, so all pairs conflict.
                    # The stored offset is offset_in_other - offset_in_self.
                    for (base_gpr1, offset1) in conflict_set:
                        if base_gpr1 not in result:
                            result[base_gpr1] = OneRegisterInterference(base_gpr1)
                        for (base_gpr2, offset2) in conflict_set:
                            if base_gpr1 is not base_gpr2:
                                result[base_gpr1].conflicts[base_gpr2].add(offset2 - offset1)

            # Forced indexes of every Gpr in this graph (None = not pre-indexed). Every
            # conflicting Gpr is itself a key of `result`, and program.forced_index is not
            # modified below, so one lookup per Gpr suffices.
            forced = {
                gpr: (program.forced_index[gpr] if gpr in program.forced_index else None)
                for gpr in result
            }  # type: Dict[Gpr, Optional[int]]

            # --- 2) Pre-indexed Gprs conflict with each other at every relative offset ---
            pre_indexed_gprs = set(gpr for gpr in result if forced[gpr] is not None)
            for base_gpr1 in pre_indexed_gprs:
                assert base_gpr1 in result
                for base_gpr2 in pre_indexed_gprs:
                    if base_gpr1 is not base_gpr2:
                        for offset1 in range(0, base_gpr1.count):
                            for offset2 in range(0, base_gpr2.count):
                                result[base_gpr1].conflicts[base_gpr2].add(offset2 - offset1)
                                result[base_gpr2].conflicts[base_gpr1].add(offset1 - offset2)

            # --- 3) Conservative propagation ---
            # If pre-indexed r1 interferes with a non-pre-indexed r2, then every other
            # pre-indexed Gpr also interferes with r2 at the same absolute index of r2.
            for base_gpr1 in result:
                forced1 = forced[base_gpr1]
                if forced1 is None:  # base_gpr1 is not pre-indexed
                    continue

                for base_gpr2 in result:
                    forced2 = forced[base_gpr2]
                    if (forced2 is None) or (base_gpr2 is base_gpr1):
                        continue

                    for gpr, offset1_list in sorted(result[base_gpr1].conflicts.items()):
                        if forced[gpr] is not None:
                            continue
                        for offset1 in offset1_list:
                            # Absolute base index at which `gpr` would overlap base_gpr1.
                            colliding_index = forced1 - offset1
                            result[base_gpr2].conflicts[gpr].add(forced2 - colliding_index)
                            result[gpr].conflicts[base_gpr2].add(colliding_index - forced2)

        # Compute VGPR & SGPR interference
        # We don't need to compute interference for Special Gpr
        compute_interference_rtype(register_interference_vgpr, GprType.V)
        compute_interference_rtype(register_interference_sgpr, GprType.S)

        # Save the results
        program.optimizer_state.register_interference_vgpr = register_interference_vgpr
        program.optimizer_state.register_interference_sgpr = register_interference_sgpr

        # Elimination orderings (perfect iff the graph is chordal)
        program.optimizer_state.register_interference_perfect_elimination_ordering_vgpr = \
            ComputeRegisterInterferencePass.__compute_perfect_elimination_ordering(register_interference_vgpr)
        program.optimizer_state.register_interference_perfect_elimination_ordering_sgpr = \
            ComputeRegisterInterferencePass.__compute_perfect_elimination_ordering(register_interference_sgpr)

        return True
