from typing import Dict, Set, Tuple, List, DefaultDict, Optional
from collections import defaultdict
from ..basic.register import Gpr
from .base_pass import BasePass, PassTag
from .divide_basic_block_pass import BasicBlock
from .compute_register_interference_pass import OneRegisterInterference


class AllocateRegisterRIGPass(BasePass):
    """
    Register allocation pass driven by the register interference graph (RIG).

    VGprs and SGprs are allocated independently. For each register kind the
    interference graph is colored greedily while visiting Gprs in reverse
    perfect elimination ordering. Every Gpr of one color then shares the same
    base index, and distinct colors receive disjoint index ranges.

    Results are written to ``program.assigned_index`` (for Gprs not listed in
    ``program.forced_index``) and to the ``register_allocation_*`` fields of
    ``program.optimizer_state``.
    """

    def __init__(self, /, priority: int = PassTag.AllocateRegisterRIG.value):
        super().__init__(priority)

    def required_tags(self) -> Set[PassTag]:
        """This pass consumes the interference graph and its perfect elimination ordering."""
        return {PassTag.ComputeRegisterInterference}

    def generated_tags(self) -> Set[PassTag]:
        """Tags produced by this pass."""
        return {PassTag.AllocateRegisterRIG}

    def invalidated_tags(self) -> Set[PassTag]:
        """This pass invalidates no other pass results."""
        return set()

    def reset(self, program):
        """Clear the allocation results this pass stores on ``program.optimizer_state``."""
        state = program.optimizer_state
        if state is None:
            return

        # For AllocateRegisterRIGPass
        state.register_allocation_vgpr_by_color = None  # # type: List[List[Gpr]]
        state.register_allocation_sgpr_by_color = None  # # type: List[List[Gpr]]
        state.register_allocation_vgpr_count = None  # # type: int
        state.register_allocation_sgpr_count = None  # # type: int

    @staticmethod
    def _color_in_reverse_peo(
            interference: Dict[Gpr, OneRegisterInterference],
            perfect_elimination_ordering: List[Gpr]) -> Dict[Gpr, int]:
        """
        Greedily color Gprs, visiting them in reverse perfect elimination order.

        Each Gpr receives the smallest color not already used by one of its
        colored conflicting neighbors. Used colors are therefore contiguous
        from 0. A Gpr listed more than once is colored only the first time it
        is reached.
        """
        colors = dict()  # type: Dict[Gpr, int]
        for gpr in reversed(perfect_elimination_ordering):
            if gpr in colors:
                continue
            taken = {colors[other] for other in interference[gpr].conflicts.keys()
                     if other in colors}  # type: Set[int]
            color = 0
            while color in taken:
                color += 1
            colors[gpr] = color
        return colors

    @staticmethod
    def _group_by_color(colors: Dict[Gpr, int]) -> List[List[Gpr]]:
        """
        Invert a Gpr -> color map into one list per color, ordered by color.

        The Gprs inside each list are sorted by ``repr`` so the output is
        deterministic.
        """
        groups = defaultdict(list)  # type: DefaultDict[int, List[Gpr]]
        for gpr, color in colors.items():
            groups[color].append(gpr)
        return [sorted(groups[color], key=repr) for color in sorted(groups)]

    @staticmethod
    def __allocate_by_perfect_elimination_ordering(
            program,
            interference: Dict[Gpr, OneRegisterInterference],
            perfect_elimination_ordering: List[Gpr]) -> List[List[Gpr]]:
        """
        Color one interference graph and assign a base index to every color class.

        Color classes containing a Gpr from ``program.forced_index`` are placed
        at that forced index first. Every remaining class is then placed at the
        lowest index that satisfies its alignment and does not overlap an
        already occupied range. A class occupies ``max(gpr.count)`` consecutive
        indexes.

        :param program: object exposing ``forced_index`` and ``assigned_index``
            mappings (Gpr -> index); ``assigned_index`` is filled in here.
        :param interference: per-Gpr interference info; ``.conflicts`` keys are
            the Gprs it must not share registers with.
        :param perfect_elimination_ordering: ordering used to drive the coloring.
        :return: Gprs grouped by color, color 0 first.

        See also: https://blog.csdn.net/corsica6/article/details/88979383
        """
        by_colors = AllocateRegisterRIGPass._group_by_color(
            AllocateRegisterRIGPass._color_in_reverse_peo(interference, perfect_elimination_ordering))

        alloc_index_set = set()  # type: Set[int]

        def set_gpr_list_index(gpr_list: List[Gpr], index: int, count: int):
            """Reserve ``[index, index+count)`` and record ``index`` for every Gpr in ``gpr_list``."""
            span = range(index, index + count)
            assert alloc_index_set.isdisjoint(span)
            alloc_index_set.update(span)

            for gpr in gpr_list:
                assert not gpr.is_view
                assert count >= gpr.count
                assert index % gpr.align.divisor == gpr.align.remainder, gpr
                if gpr in program.forced_index:
                    # Forced Gprs are not copied into assigned_index; just confirm placement.
                    assert program.forced_index[gpr] == index
                else:
                    assert gpr not in program.assigned_index
                    program.assigned_index[gpr] = index

        def find_free_index(gpr_list: List[Gpr], count: int) -> int:
            """
            Return the lowest free index satisfying every alignment in ``gpr_list``.

            The search is anchored at the remainder of the most strictly aligned
            Gpr and steps by its divisor. When all divisors are powers of two, every
            candidate then has the same residue modulo each smaller divisor, so
            checking the start index once covers the whole search. With all
            remainders equal to 0 this is a search from 0 in steps of the largest
            divisor.
            """
            anchor = max(gpr_list, key=lambda g: g.align.divisor)
            step = anchor.align.divisor
            assert step in (1, 2, 4, 8, 16, 32, 64, 128)

            index = anchor.align.remainder
            for gpr in gpr_list:
                assert index % gpr.align.divisor == gpr.align.remainder, \
                    "Incompatible alignment constraints within one color: {}".format(gpr_list)

            while not alloc_index_set.isdisjoint(range(index, index + count)):
                index += step
            return index

        # First, allocate every color class that contains a pre-indexed Gpr
        for gpr_list in by_colors:
            pre_indexed_gpr_list = [gpr for gpr in gpr_list if gpr in program.forced_index]
            # At most one forced Gpr per color class is supported.
            assert len(pre_indexed_gpr_list) <= 1
            if pre_indexed_gpr_list:
                count = max(gpr.count for gpr in gpr_list)
                set_gpr_list_index(gpr_list, program.forced_index[pre_indexed_gpr_list[0]], count)

        # Then, pack the remaining color classes into the lowest free aligned slots
        for gpr_list in by_colors:
            if any(gpr in program.forced_index for gpr in gpr_list):
                continue  # already placed in the first phase
            count = max(gpr.count for gpr in gpr_list)
            set_gpr_list_index(gpr_list, find_free_index(gpr_list, count), count)

        return by_colors

    def run(self, program) -> bool:
        """
        Allocate VGprs and SGprs, then record how many registers of each kind are used.

        The per-kind "count" is one past the highest register index occupied by
        any allocated Gpr (0 if nothing was allocated). Always returns True.
        """
        state = program.optimizer_state

        # Check that basic blocks have been divided
        assert state.divide_basic_block is not None

        # Check that register interference graph and perfect elimination ordering has been computed
        assert state.register_interference_perfect_elimination_ordering_vgpr is not None, \
            "You haven't run RegisterInterferencePass or VGpr interference graph is not a chordal graph"
        assert state.register_interference_perfect_elimination_ordering_sgpr is not None, \
            "You haven't run RegisterInterferencePass or SGpr interference graph is not a chordal graph"

        state.register_allocation_vgpr_by_color = \
            AllocateRegisterRIGPass.__allocate_by_perfect_elimination_ordering(
                program,
                state.register_interference_vgpr,
                state.register_interference_perfect_elimination_ordering_vgpr)

        state.register_allocation_sgpr_by_color = \
            AllocateRegisterRIGPass.__allocate_by_perfect_elimination_ordering(
                program,
                state.register_interference_sgpr,
                state.register_interference_perfect_elimination_ordering_sgpr)

        def get_gpr_index(gpr):
            if gpr in program.forced_index:
                return program.forced_index[gpr]
            assert gpr in program.assigned_index
            return program.assigned_index[gpr]

        def used_register_count(by_color: List[List[Gpr]]) -> int:
            # One past the highest index touched by any Gpr of this register kind.
            return max((get_gpr_index(gpr) + gpr.count
                        for gpr_list in by_color
                        for gpr in gpr_list),
                       default=0)

        # Compute used VGPR / SGPR counts
        state.register_allocation_vgpr_count = used_register_count(state.register_allocation_vgpr_by_color)
        state.register_allocation_sgpr_count = used_register_count(state.register_allocation_sgpr_by_color)

        return True
