from typing import List, Optional, TextIO, Set, Dict, Union, DefaultDict, Tuple
from ..basic.utility import IndentedWriter
from ..basic.exception import SeekException, check, SeekTODOException
from ..basic.register import Gpr, blockIdx, threadIdx, s_kernarg
from ..basic.instr import Block
from .base_pass import BasePass, OptimizerState, PassTag
from .divide_basic_block_pass import DivideBasicBlockPass, BasicBlock
from .optimize_basic_block_pass import OptimizeBasicBlockPass
from .print_basic_block_pass import PrintBasicBlockPass
from .analyze_live_var_pass import AnalyzeLiveVarPass
from .eliminate_dead_code_pass import EliminateDeadCodePass
from .annotate_clause_pass import AnnotateClausePass
from .insert_waitcnt_pass import InsertWaitcntPass
from .compute_register_interference_pass import ComputeRegisterInterferencePass
from .allocate_register_rig_pass import AllocateRegisterRIGPass
from abc import ABC
from collections import defaultdict
import sys
import logging
import textwrap
import yaml
import io


class MyDumper(yaml.Dumper):
    """`yaml.Dumper` variant whose `increase_indent` never honours `indentless`."""

    def increase_indent(self, flow=False, indentless=False):
        # The caller's `indentless` value is deliberately discarded and False is
        # forwarded instead; `flow` is passed through unchanged.
        # NOTE(review): presumably this makes sequence items nested under a mapping
        # key get indented instead of starting at the key's column -- confirm with PyYAML.
        return super().increase_indent(flow, False)


class ParsedSignature:
    def parse_signature(self, signature: str) -> Tuple[str, List, int]:
        """Parse a C-like kernel declaration into its name and kernel-argument layout.

        The parser is intentionally naive: it understands `//` and `/* */` comments,
        leading `__global__` / `static` / `void` / `inline` keywords, a kernel name and a
        parenthesized, comma-separated argument list of the form `<type> [*] <name>`.
        `const`, `volatile`, `__restrict` and `__restrict__` are ignored.

        :param signature: the declaration text, e.g. `__global__ void k(float *x, int n)`
        :return: `(kernel_name, kernel_arguments, kernel_arguments_size)` where
                 `kernel_arguments` is a list of dicts with the keys `.name`, `.size`,
                 `.offset`, `.value_kind`, `.value_type` (plus `.address_space` for
                 pointers) and `kernel_arguments_size` is the end offset of the last
                 argument rounded up to 8 bytes.
        :raises SeekTODOException: for argument types this parser does not know
        Malformed signatures are rejected through `check(...)`.
        """
        # ---- Tokenization -------------------------------------------------------
        # The extra newline guarantees that a trailing `//` comment is terminated.
        signature += "\n"
        punctuation = "`~!@#$%^&*()-=+;:'\",<.>/?[{]}\\|"

        tokens = []  # type: List[str]
        curr_token = ""
        pos = 0
        end = len(signature)
        while pos < end:
            ch = signature[pos]
            starts_line_comment = signature.startswith("//", pos)
            starts_block_comment = signature.startswith("/*", pos)

            if starts_line_comment or starts_block_comment:
                # A comment separates tokens exactly like whitespace does, so
                # `int/* c */x` must not be glued together into `intx`.
                if curr_token:
                    tokens.append(curr_token)
                    curr_token = ""
                if starts_line_comment:
                    close = signature.find("\n", pos + 2)
                    pos = end if close < 0 else close + 1
                else:
                    # Search after the opening `/*` so that `/*/` is not mistaken
                    # for a complete comment. An unterminated comment swallows the rest.
                    close = signature.find("*/", pos + 2)
                    pos = end if close < 0 else close + 2
                continue

            if ch in " \t\r\n":
                if curr_token:
                    tokens.append(curr_token)
                    curr_token = ""
            elif ch in punctuation:
                # Every punctuation character is a token of its own
                if curr_token:
                    tokens.append(curr_token)
                    curr_token = ""
                tokens.append(ch)
            else:
                curr_token += ch
            pos += 1

        # ---- Parse the tokens (in a naive way) ----------------------------------
        # `tokens` may run empty here (e.g. signature "void"); the check below reports it.
        while tokens and tokens[0] in ("__global__", "static", "void", "inline"):
            tokens.pop(0)

        check(tokens, "Kernel name not given")
        kernel_name = tokens.pop(0)

        check(tokens, f"Expecting '(' after kernel name {kernel_name}")
        popped = tokens.pop(0)
        check(popped == "(", f"Invalid signature: {signature}")

        check(tokens, f"Expecting ')' at the end of signature")
        popped = tokens.pop(-1)
        check(popped == ")", f"Invalid signature: {signature}")

        # Split the remaining tokens into one token list per argument. Commas nested
        # in `<...>` do not split. Qualifiers are dropped. A trailing comma is tolerated.
        qualifiers = ("const", "volatile", "__restrict", "__restrict__")
        arguments = []  # type: List[List[str]]
        current = []  # type: List[str]
        template_depth = 0
        for token in tokens:
            if token == "," and template_depth == 0:
                check(len(current) > 0, f"Empty kernel argument in signature: {signature}")
                arguments.append(current)
                current = []
                continue
            if token == "<":
                template_depth += 1
            elif token == ">":
                template_depth -= 1
            if token not in qualifiers:
                current.append(token)
        check(template_depth == 0, f"Unbalanced '<' and '>' in signature: {signature}")
        if current:
            arguments.append(current)

        # Scalar C type name -> (value type, size in bytes)
        scalar_types = {
            "double": ("f64", 8),
            "float": ("f32", 4),
            "real": ("f32", 4),
            "int64_t": ("i64", 8),
            "int": ("i32", 4),
            "int32_t": ("i32", 4),
            "uint64_t": ("u64", 8),
            "unsigned": ("u32", 4),
            "uint32_t": ("u32", 4),
        }

        arg_offset = 0
        kernel_arguments = []
        for arg_tokens in arguments:
            check(len(arg_tokens) >= 2, f"Expecting '<type> <name>' for kernel argument, got: {arg_tokens}")

            arg_name = arg_tokens.pop(-1)

            is_pointer = arg_tokens[-1] == '*'
            if is_pointer:
                arg_tokens.pop(-1)

            # Only a single, known scalar type name is supported (optionally pointed to)
            if len(arg_tokens) != 1 or arg_tokens[0] not in scalar_types:
                raise SeekTODOException(arg_tokens)
            arg_value_type, scalar_size = scalar_types[arg_tokens[0]]

            if is_pointer:
                # A pointer always occupies 8 bytes; `.value_type` still names the pointee
                arg_size = 8
                arg_value_kind = "global_buffer"
                arg_address_space = "generic"
            else:
                arg_size = scalar_size
                arg_value_kind = "by_value"
                arg_address_space = None

            # align `arg_offset` to `arg_alignment` (every argument is aligned to its own size)
            assert arg_size in (1, 2, 4, 8)
            arg_alignment = arg_size
            arg_offset = (arg_offset + arg_alignment - 1) // arg_alignment * arg_alignment

            # Key order is significant: the metadata is dumped without sorting keys
            result = {
                ".name": arg_name,
                ".size": arg_size,
                ".offset": arg_offset,
                ".value_kind": arg_value_kind,
                ".value_type": arg_value_type,
            }
            if arg_address_space is not None:
                result[".address_space"] = arg_address_space

            kernel_arguments.append(result)

            arg_offset += arg_size

        # The whole kernarg segment is padded to a multiple of 8 bytes
        kernarg_alignment = 8
        kernel_arguments_size = (arg_offset + kernarg_alignment - 1) // kernarg_alignment * kernarg_alignment
        return kernel_name, kernel_arguments, kernel_arguments_size

    def __init__(self, signature: str):
        """Parse `signature` and expose the result as attributes.

        Surrounding whitespace is stripped first; an empty result is rejected via `check`.
        """
        stripped = signature.strip()
        check(stripped, f"Invalid signature: {stripped!r}")

        parsed = self.parse_signature(stripped)
        name, arguments, arguments_size = parsed

        self.kernel_name = name  # type: str
        self.kernel_arguments = arguments  # type: List
        self.kernel_arguments_size = arguments_size  # type: int


class Program(ABC):
    def __init__(self):
        """Create an empty program: no blocks, fresh optimizer state, nothing parsed yet."""
        # Code blocks of the program, in emission order
        self.blocks = []  # type: List[Block]
        self.optimizer_state = OptimizerState()

        # Both are filled in by `__setup()` during `compile()`
        self.parsed_signature = None  # type: Optional[ParsedSignature]
        self.logger = None  # type: Optional[logging.Logger]

        # Register index tables: base_gpr -> index
        self.forced_index = {}  # type: Dict[Gpr, int]
        self.assigned_index = {}  # type: Dict[Gpr, int]

    def __try_get_block(self, block_name: str) -> Optional[Block]:
        for block in self.blocks:
            if block.block_name == block_name:
                return block
        return None

    def add_block(self, block_name: str, /,
                  after: Optional[Union[Block, str]] = None,
                  before: Optional[Union[Block, str]] = None) -> Block:
        """Create a block named `block_name`, insert it into the program and return it.

        :param block_name: name of the new block; must not be used by an existing block
        :param after: block (or block name) after which the new block is inserted
        :param before: block (or block name) before which the new block is inserted
        :raises SeekException: if the name is taken, if both `after` and `before` are
                               given, or if the anchor block cannot be resolved

        With neither `after` nor `before`, the block is appended at the end.
        """
        already_there = self.__try_get_block(block_name)
        if already_there is not None:
            raise SeekException(f"Block {block_name} already exists")

        new_block = Block(block_name)
        if not (after is None or before is None):
            raise SeekException("Can't specify both `after` and `before`")

        # Work out the insertion position; an anchor may be a Block or a block name
        if after is not None:
            anchor = self.get_block(after)
            position = self.blocks.index(anchor) + 1
        elif before is not None:
            anchor = self.get_block(before)
            position = self.blocks.index(anchor)
        else:
            position = len(self.blocks)  # i.e. append

        self.blocks.insert(position, new_block)
        return new_block

    def get_block(self, /, block_or_name: Union[Block, str]) -> Block:
        """Resolve `block_or_name` to a block that belongs to this program.

        :param block_or_name: a `Block` of this program, or the name of one
        :return: the resolved block
        :raises SeekException: if the block is not in this program, if no block has the
                               given name, or if the argument is neither a Block nor a str
        """
        if isinstance(block_or_name, Block):
            if block_or_name in self.blocks:
                return block_or_name
            raise SeekException(f"{block_or_name} is not in current program")

        if not isinstance(block_or_name, str):
            raise SeekException(f"Unknown parameter: {block_or_name}")

        found = self.__try_get_block(block_or_name)
        if found is not None:
            return found
        raise SeekException(f"Can't find block {block_or_name}")

    def compile(self, file: TextIO = sys.stdout, /,
                log_level: Union[int, str] = logging.INFO,
                code_object_version: int = 2,
                lds_size: int = 0,
                force_vgpr_count: Optional[int] = None,
                force_sgpr_count: Optional[int] = None):
        """Build the program from scratch and write the generated assembly to `file`.

        The steps are: reset all state, run the user's `setup()`, force the indices of
        the predefined registers, run the configured passes, then generate the output.

        :param file: text stream receiving the output (positional-only)
        :param log_level: level for the per-kernel logger created during setup
        :param code_object_version: 2 or 3; selects the header format
        :param lds_size: LDS size in bytes, written into the header
        :param force_vgpr_count: if given, VGPR count to declare instead of the allocated one
        :param force_sgpr_count: if given, SGPR count to declare instead of the allocated one
        """
        # Start from a clean slate, then let the derived class add blocks and instructions
        self.__reset()
        self.__setup(log_level)

        # Nothing may have been forced or assigned by setup()
        assert not len(self.forced_index)
        assert not len(self.assigned_index)

        # Predefined registers get fixed indices before any pass runs
        forced = (
            (threadIdx.x, 0),
            (threadIdx.y, 1),
            (threadIdx.z, 2),
            (s_kernarg, 0),
            (blockIdx.x, 2),
            (blockIdx.y, 3),
            (blockIdx.z, 4),
        )
        for gpr, index in forced:
            self.forced_index[gpr] = index

        # Run all optimization passes, then write the result to the given file
        self.__run_passes()
        self.__generate(file, code_object_version, lds_size, force_vgpr_count, force_sgpr_count)

    def __setup(self, log_level: Union[int, str]):
        """Parse the kernel signature, create the logger and call the user's `setup()`.

        :param log_level: level applied to the newly created logger
        """
        # The derived class declares the kernel signature; parse it right away
        self.parsed_signature = ParsedSignature(self.get_signature())

        # One logger per program, named after the kernel, writing to stderr
        self.logger = program_logger = logging.Logger(self.parsed_signature.kernel_name)
        program_logger.setLevel(log_level)

        stderr_handler = logging.StreamHandler(sys.stderr)
        log_format = logging.Formatter("[%(levelname)s] %(name)s: %(message)s")
        stderr_handler.setFormatter(log_format)
        program_logger.addHandler(stderr_handler)

        # Hand over to the user-provided setup() method
        self.setup()

    def __reset(self):
        """Return the program to its freshly constructed state.

        The user-defined `reset()` runs first, in the same way a derived class is
        destructed before its base class.
        """
        self.reset()

        # Drop all blocks in place (the list object itself is kept)
        del self.blocks[:]

        self.optimizer_state = OptimizerState()
        self.parsed_signature = self.logger = None

        # Register index tables: base_gpr -> index
        self.forced_index = {}  # type: Dict[Gpr, int]
        self.assigned_index = {}  # type: Dict[Gpr, int]

    def __run_passes(self):
        """Run the passes returned by `config_passes()` until none is left to run.

        Each pass declares the tags it requires, generates and invalidates. A pass is
        ready when it is not done and all of its required tags are currently done.
        Among the ready passes the one with the smallest `(priority, repr)` runs next.
        When a pass reports a modification, its invalidated tags are dropped and every
        pass that requires or generates such a tag is scheduled again.

        :raises SeekException: if the loop ends while some pass is still not done
        """
        # Get optimization passes, and build dependency graph
        tags_done = set()  # type: Set[PassTag]
        passes_done = {}  # type: Dict[BasePass, bool]
        required_tags_to_passes = defaultdict(list)  # type: DefaultDict[PassTag, List[BasePass]]
        generated_tags_to_passes = defaultdict(list)  # type: DefaultDict[PassTag, List[BasePass]]

        for pss in self.config_passes():  # type: BasePass
            for prerequisite_tag in pss.required_tags():
                required_tags_to_passes[prerequisite_tag].append(pss)
            for generated_tag in pss.generated_tags():
                generated_tags_to_passes[generated_tag].append(pss)
            passes_done[pss] = False

            # Reset this pass anyway
            pss.reset(self)

        # Loop until nothing to do
        while True:
            # Collect every pass that has not run (or was invalidated) and whose
            # prerequisites are all satisfied
            ready_passes = []  # type: List[BasePass]
            for pss in passes_done:  # type: BasePass  # the order doesn't matter
                if passes_done[pss]:  # this pass is done
                    # Invariant: a done pass never has a missing prerequisite tag
                    for tag in pss.required_tags():
                        assert tag in tags_done
                else:  # this pass has not run or is invalidated
                    prerequisites_ok = True
                    for tag in pss.required_tags():
                        if tag not in tags_done:
                            prerequisites_ok = False
                            break
                    if prerequisites_ok:
                        ready_passes.append(pss)

            if not ready_passes:
                # No pass can be selected to run
                break

            # Select pass with minimal priority value (highest priority) to run;
            # `repr` breaks ties so that the choice is deterministic
            ready_passes.sort(key=lambda x: (x.priority, repr(x)))
            select_pass = ready_passes[0]  # type: BasePass

            modified = select_pass.run(self)  # type: bool
            assert modified is True or modified is False, modified
            self.logger.debug(f"{select_pass}: return {modified}")

            # Mark `select_pass` as done before going on.
            # If this pass modifies anything, it might be set back to not done below
            passes_done[select_pass] = True

            invalidated_tags = set(select_pass.invalidated_tags())  # type: Set[PassTag]  # the order doesn't matter

            # Tags are only invalidated when the pass actually changed something
            if modified:
                for invalidated_tag in invalidated_tags:
                    if invalidated_tag in tags_done:
                        tags_done.remove(invalidated_tag)
                    for pss in required_tags_to_passes[invalidated_tag]:
                        # `pss` may be `select_pass`
                        passes_done[pss] = False
                    for pss in generated_tags_to_passes[invalidated_tag]:
                        assert pss is not select_pass  # `generated_tags` is not intersected with `invalidated_tags`
                        passes_done[pss] = False

            # The tags generated by this pass are now available
            for generated_tag in select_pass.generated_tags():
                assert generated_tag not in invalidated_tags, \
                    f"BUG: {select_pass} generates {generated_tag} but also invalidates it"
                tags_done.add(generated_tag)

        # Now we can't select a pass to run any more
        # Make sure every pass has run
        not_done_passes = [pss for pss, done in passes_done.items() if not done]
        if not_done_passes:
            raise SeekException(f"These passes are not done: {not_done_passes}")

    def __write_headers_code_object_v2(
            self, wr,
            lds_size: int,
            origin_vgpr_count: int, origin_sgpr_count: int,
            vgpr_count: int, sgpr_count: int,
            use_blockIdx_x: bool, use_blockIdx_y: bool, use_blockIdx_z: bool,
            use_threadIdx: int):

        str_header = textwrap.dedent(f"""
            .hsa_code_object_version 2,0
            .hsa_code_object_isa 9, 0, 6, "AMD", "AMDGPU"
            .text
            .protected {self.parsed_signature.kernel_name}
            .globl {self.parsed_signature.kernel_name}
            .p2align 8
            .type {self.parsed_signature.kernel_name},@function
            .amdgpu_hsa_kernel {self.parsed_signature.kernel_name}
        
            #define ORIGIN_VGPR_COUNT {origin_vgpr_count}
            #define ORIGIN_SGPR_COUNT {origin_sgpr_count}
            #define VGPR_COUNT {vgpr_count}
            #define SGPR_COUNT {sgpr_count}
        
            {self.parsed_signature.kernel_name}:
                .amd_kernel_code_t
                    is_ptr64 = 1
                    enable_sgpr_kernarg_segment_ptr = 1
                    kernarg_segment_byte_size = {self.parsed_signature.kernel_arguments_size}  // bytes of kern args
                    workitem_vgpr_count = VGPR_COUNT  // vgprs
                    wavefront_sgpr_count = SGPR_COUNT+2  // sgprs (plus VCC)
                    compute_pgm_rsrc1_vgprs = (VGPR_COUNT-1)/4  // (vgprs-1)/4
                    compute_pgm_rsrc1_sgprs = (SGPR_COUNT+2-1)/8  // (sgprs-1)/8
                    compute_pgm_rsrc2_tidig_comp_cnt = {use_threadIdx}  // use threadIdx: 0=x 1=xy 2=xyz
                    compute_pgm_rsrc2_tgid_x_en = {int(use_blockIdx_x or use_blockIdx_y or use_blockIdx_z)}  // use blockIdx.x
                    compute_pgm_rsrc2_tgid_y_en = {int(use_blockIdx_y or use_blockIdx_z)}  // use blockIdx.y
                    compute_pgm_rsrc2_tgid_z_en = {int(use_blockIdx_z)}  // use blockIdx.z
                    workgroup_group_segment_byte_size = {lds_size}  // lds bytes
                    compute_pgm_rsrc2_user_sgpr = 2  // VCC
                    group_segment_alignment = 8
                    private_segment_alignment = 8
                .end_amd_kernel_code_t
            """)
        wr(str_header)

    def __write_headers_code_object_v3(
            self, wr,
            lds_size: int,
            origin_vgpr_count: int, origin_sgpr_count: int,
            vgpr_count: int, sgpr_count: int,
            use_blockIdx_x: bool, use_blockIdx_y: bool, use_blockIdx_z: bool,
            use_threadIdx: int):
        """Write the code-object-v3 header via `wr`.

        The header consists of the `.amdhsa_kernel` descriptor, the `.amdgpu_metadata`
        YAML document (built from the parsed signature) and the kernel's entry label.

        :param wr: callable receiving the whole header as a single string
        :param lds_size: LDS bytes; written to `.amdhsa_group_segment_fixed_size` and to
                         the metadata key `.group_segment_fixed_size`
        :param origin_vgpr_count: unused here; kept for a signature parallel to the v2 writer
        :param origin_sgpr_count: unused here; kept for a signature parallel to the v2 writer
        :param vgpr_count: written to `.amdhsa_next_free_vgpr` and `.vgpr_count`
        :param sgpr_count: written to `.amdhsa_next_free_sgpr` and `.sgpr_count`
        :param use_blockIdx_x: whether blockIdx.x is used
        :param use_blockIdx_y: whether blockIdx.y is used
        :param use_blockIdx_z: whether blockIdx.z is used
        :param use_threadIdx: written to `.amdhsa_system_vgpr_workitem_id` (0=x 1=xy 2=xyz)
        """
        # Yaml for kernel arguments
        metadata = {
            "amdhsa.version": [1, 0],
            "amdhsa.kernels": [
                {
                    ".name": f"{self.parsed_signature.kernel_name}",
                    ".symbol": f"{self.parsed_signature.kernel_name}.kd",
                    ".language": "OpenCL C",
                    ".language_version": [2, 0],
                    ".args": self.parsed_signature.kernel_arguments,  # kernel_arguments is a list of dict
                    # Group segment (LDS) bytes required by a work-group. This used to be
                    # hard-coded to 4096, which disagreed with the kernel descriptor below
                    # whenever `lds_size` was anything else.
                    ".group_segment_fixed_size": lds_size,
                    ".kernarg_segment_align": 8,
                    ".kernarg_segment_size": self.parsed_signature.kernel_arguments_size,
                    ".max_flat_workgroup_size": 1024,  # TODO: what's this? (Maximum flat work-group size supported by the kernel in work-items.)
                    ".private_segment_fixed_size": 0,
                    ".sgpr_count": sgpr_count,
                    ".sgpr_spill_count": 0,
                    ".vgpr_count": vgpr_count,
                    ".vgpr_spill_count": 0,
                    ".wavefront_size": 64,
                }
            ]
        }
        # Key order is kept as written above (sort_keys=False)
        str_io = io.StringIO()
        yaml.dump(metadata, str_io, Dumper=MyDumper, default_flow_style=False, sort_keys=False)
        str_metadata = str_io.getvalue()

        # The template starts at column 0 on purpose: the multi-line YAML text is
        # inserted verbatim and must keep its own indentation.
        str_header = textwrap.dedent(f"""
.amdgcn_target "amdgcn-amd-amdhsa--gfx906+sram-ecc"
.text
.protected {self.parsed_signature.kernel_name}
.globl {self.parsed_signature.kernel_name}
.p2align 8
.type {self.parsed_signature.kernel_name},@function
.section .rodata,#alloc
.p2align 6
.amdhsa_kernel {self.parsed_signature.kernel_name}
    .amdhsa_user_sgpr_kernarg_segment_ptr 1
    .amdhsa_next_free_vgpr {vgpr_count}  // vgprs
    .amdhsa_next_free_sgpr {sgpr_count}  // sgprs
    .amdhsa_group_segment_fixed_size {lds_size}  // lds bytes
    .amdhsa_private_segment_fixed_size 0
    .amdhsa_system_sgpr_workgroup_id_x {int(use_blockIdx_x or use_blockIdx_y or use_blockIdx_z)}  // use blockIdx.x
    .amdhsa_system_sgpr_workgroup_id_y {int(use_blockIdx_y or use_blockIdx_z)}  // use blockIdx.y
    .amdhsa_system_sgpr_workgroup_id_z {int(use_blockIdx_z)}  // use blockIdx.z
    .amdhsa_system_vgpr_workitem_id {use_threadIdx}  // use threadIdx: 0=x 1=xy 2=xyz
.end_amdhsa_kernel

.text

.amdgpu_metadata
---
{str_metadata}
...
.end_amdgpu_metadata

{self.parsed_signature.kernel_name}:
""")
        wr(str_header)

    def __generate(self, file: TextIO, /,
                   code_object_version: int,
                   lds_size: int = 0,
                   force_vgpr_count=None,
                   force_sgpr_count=None):
        """Write the complete assembly listing of the compiled program to `file`.

        The output consists of the code-object header (v2 or v3), a fixed set of debug
        macros, every basic block with its analysis results as comments, and a summary
        of register interference and allocation.

        :param file: text stream receiving the output (positional-only)
        :param code_object_version: 2 or 3
        :param lds_size: LDS size in bytes; must not be negative
        :param force_vgpr_count: if not None, VGPR count to declare; must be at least
                                 the allocated count
        :param force_sgpr_count: if not None, SGPR count to declare; must be at least
                                 the allocated count

        Invalid arguments are rejected through `check(...)`.
        """
        check(code_object_version in (2, 3), f"Invalid code object version: {code_object_version}")
        check(lds_size >= 0, f"Invalid lds_size: {lds_size}")

        # `parsed_signature` should have been initialized by `__setup()`
        assert self.parsed_signature is not None

        wr = IndentedWriter(file)

        # Register counts: start from the allocator's result, optionally raised by the caller
        vgpr_count = self.optimizer_state.register_allocation_vgpr_count
        sgpr_count = self.optimizer_state.register_allocation_sgpr_count
        origin_vgpr_count = vgpr_count
        origin_sgpr_count = sgpr_count
        if force_vgpr_count is not None:
            check(force_vgpr_count >= vgpr_count, f"{force_vgpr_count} VGPR is not enough, {vgpr_count} VGPR required")
            vgpr_count = force_vgpr_count
        if force_sgpr_count is not None:
            check(force_sgpr_count >= sgpr_count, f"{force_sgpr_count} SGPR is not enough, {sgpr_count} SGPR required")
            sgpr_count = force_sgpr_count
        assert vgpr_count > 0
        assert sgpr_count > 0

        # Have we used blockIdx.{x,y,z}?
        use_blockIdx_x = False
        use_blockIdx_y = False
        use_blockIdx_z = False
        for gpr_list in self.optimizer_state.register_allocation_sgpr_by_color:
            for gpr in gpr_list:
                if gpr.base_gpr is blockIdx.x.base_gpr: use_blockIdx_x = True
                if gpr.base_gpr is blockIdx.y.base_gpr: use_blockIdx_y = True
                if gpr.base_gpr is blockIdx.z.base_gpr: use_blockIdx_z = True

        # Have we used threadIdx.{x,y,z}?  (0=x 1=xy 2=xyz)
        use_threadIdx = 0
        for gpr_list in self.optimizer_state.register_allocation_vgpr_by_color:
            for gpr in gpr_list:
                if gpr.base_gpr is threadIdx.x.base_gpr: use_threadIdx = max(use_threadIdx, 0)
                if gpr.base_gpr is threadIdx.y.base_gpr: use_threadIdx = max(use_threadIdx, 1)
                if gpr.base_gpr is threadIdx.z.base_gpr: use_threadIdx = max(use_threadIdx, 2)

        if code_object_version == 3:
            self.__write_headers_code_object_v3(
                wr,
                lds_size,
                origin_vgpr_count, origin_sgpr_count,
                vgpr_count, sgpr_count,
                use_blockIdx_x, use_blockIdx_y, use_blockIdx_z,
                use_threadIdx)
        else:
            assert code_object_version == 2
            self.__write_headers_code_object_v2(
                wr,
                lds_size,
                origin_vgpr_count, origin_sgpr_count,
                vgpr_count, sgpr_count,
                use_blockIdx_x, use_blockIdx_y, use_blockIdx_z,
                use_threadIdx)

        # A raw string keeps the assembler macro arguments (`\V`, `\Lane`, `\S`) literal
        # without relying on Python passing unknown escape sequences through.
        wr(textwrap.dedent(r"""
            #define vcc_exec        vcc         // Just vcc, but used in v_cmpx_... for clarity

                .macro __dbg_print_vgpr V Lane
                    s_mov_b64 exec, -1
                    v_readlane_b32 s0, \V, \Lane
                    v_mov_b32 v0, s0
                    v_mov_b32 v1, 0
                    v_lshlrev_b64 v[0:1], 16, v[0:1]
                    global_load_dword v0, v[0:1], off offset:0
                    s_endpgm
                .endm

                .macro __dbg_print_sgpr S
                    s_mov_b64 exec, -1
                    v_mov_b32 v0, \S
                    v_mov_b32 v1, 0
                    v_lshlrev_b64 v[0:1], 16, v[0:1]
                    global_load_dword v0, v[0:1], off offset:0
                    s_endpgm
                .endm

            """))

        # Write all basic-blocks:
        for bb in self.optimizer_state.divide_basic_block.basic_blocks:  # type: BasicBlock
            clauses_and_jump_instr_count = len(bb.clauses)
            if bb.jump_instr is not None:
                clauses_and_jump_instr_count += 1

            wr(f'{bb.name}:')
            with wr.indent():
                wr('//')
                wr(f'// bb_predecessors: {[pred.name for pred in bb.predecessors]}')
                wr('//')
                wr(f'// live_var_in: {bb.live_var_in}')
                wr(f'// live_var_uses: {bb.live_var_uses}')
                wr("//")
                if False:  # TODO: add verbosity option (life-span dump is disabled until then)
                    wr(f"// live_var_gpr_life_span:")
                    wr(f"//   /* number of instructions: {clauses_and_jump_instr_count} */")
                    wr(f"//   /* number of base Spec: {sum(1 for x in bb.live_var_gpr_life_span if x.rtype.is_special())} */")
                    wr(f"//   /* number of base SGpr: {sum(1 for x in bb.live_var_gpr_life_span if x.rtype.is_sgpr())} */")
                    wr(f"//   /* number of base VGpr: {sum(1 for x in bb.live_var_gpr_life_span if x.rtype.is_vgpr())} */")
                    for base_gpr, bitmap_list in sorted(bb.live_var_gpr_life_span.items(), key=repr):
                        def bitmap_to_string(bitmap: int):
                            s = ""
                            for idx in range(0, 2*clauses_and_jump_instr_count+2):
                                s += 'x' if (bitmap & 1) else '_'
                                bitmap >>= 1
                                if idx % 2 == 0: s += ' '
                            assert bitmap == 0
                            return s
                        wr(f"//   {base_gpr}:")
                        for idx, bitmap in enumerate(bitmap_list):
                            wr(f"//     [{idx:-2}] = {bitmap_to_string(bitmap)}")
                    wr(f"//")
                wr(f'// pending_mem_in: {bb.pending_mem_in if bb.pending_mem_in is not None else "-  // not run"}')
                wr('//')
                # The instructions themselves: annotated clauses, then the jump (if any)
                for annclause in bb.annotate_clauses:
                    annclause.generate(self, wr)
                if bb.jump_instr is not None:
                    bb.jump_instr.generate(self, wr, bb.successor_if_jump)
                wr('//')
                wr(f'// bb_successor_if_jump: {repr(bb.successor_if_jump.name) if bb.successor_if_jump is not None else "-"}')  # noqa E501: line too long
                wr(f'// bb_successor_if_fallthrough: {repr(bb.successor_if_fallthrough.name) if bb.successor_if_fallthrough is not None else "-"}')  # noqa E501: line too long
                wr(f'// bb_jump_instr: {bb.jump_instr.instr_name if bb.jump_instr is not None else "-"}  // {bb.control_flow_at_exit.name}')  # noqa E501: line too long
                wr('//')
                wr(f'// live_var_defs: {bb.live_var_defs if bb.live_var_defs is not None else "-  // not run"}')
                wr(f'// live_var_out: {bb.live_var_out if bb.live_var_out is not None else "-  // not run"}')
                wr('//')
                wr(f'// pending_mem_out: {bb.pending_mem_out if bb.pending_mem_out is not None else "-  // not run"}')  # noqa E501: line too long
                wr('//')
                wr()

        # All basic-blocks have been printed...
        # Let's print some global information
        wr("//" + "=" * 32)
        wr("//" + "SUMMARY".center(32))
        wr("//" + "=" * 32)
        state = self.optimizer_state  # type: OptimizerState

        # Print register interference if computed
        if state.register_interference_vgpr:
            register_interference_vgpr = state.register_interference_vgpr
            wr("//")
            wr("// VGpr interference:")
            for base_gpr in sorted(register_interference_vgpr.keys(), key=repr):
                wr(f"//   {register_interference_vgpr[base_gpr]}")
            wr("//")

        if state.register_interference_sgpr:
            register_interference_sgpr = state.register_interference_sgpr
            wr("//")
            wr("// SGpr interference:")
            for base_gpr in sorted(register_interference_sgpr.keys(), key=repr):
                wr(f"//   {register_interference_sgpr[base_gpr]}")
            wr("//")

        # Print register allocation if computed
        if state.register_allocation_vgpr_by_color:
            register_allocation_vgpr_by_color = state.register_allocation_vgpr_by_color
            wr("//")
            wr("// VGPR allocation by coloring:")
            for color, gpr_list in enumerate(register_allocation_vgpr_by_color):
                wr(f"//   Color #{color} (count={max(gpr.count for gpr in gpr_list)}): {gpr_list}")
            wr("//")

        # Guard on the allocation result itself (this used to test the interference
        # table, which says nothing about whether the allocation is available)
        if state.register_allocation_sgpr_by_color:
            register_allocation_sgpr_by_color = state.register_allocation_sgpr_by_color
            wr("//")
            wr("// SGPR allocation by coloring:")
            for color, gpr_list in enumerate(register_allocation_sgpr_by_color):
                wr(f"//   Color #{color} (count={max(gpr.count for gpr in gpr_list)}): {gpr_list}")
            wr("//")

        if state.register_allocation_vgpr_count is not None:
            wr(f"// VGPR Count: {state.register_allocation_vgpr_count}")

        if state.register_allocation_sgpr_count is not None:
            wr(f"// SGPR Count: {state.register_allocation_sgpr_count}")

    def setup(self) -> None:  # pure virtual method
        raise NotImplementedError("Setup this kernel program in derived class")

    def reset(self) -> None:  # virtual method
        # By default, this does nothing
        pass

    def get_signature(self) -> str:  # pure virtual method
        raise NotImplementedError("Declare signature of the kernel program in derived class")

    # noinspection PyMethodMayBeStatic
    def config_divide_basic_block_pass(self) -> Optional[DivideBasicBlockPass]:  # virtual method
        """Build the default `DivideBasicBlockPass`; `config_passes()` skips a None result."""
        divide_pass = DivideBasicBlockPass()
        return divide_pass

    # noinspection PyMethodMayBeStatic
    def config_optimize_basic_block_pass(self) -> Optional[OptimizeBasicBlockPass]:  # virtual method
        """Build the default `OptimizeBasicBlockPass`; `config_passes()` skips a None result."""
        optimize_pass = OptimizeBasicBlockPass()
        return optimize_pass

    # noinspection PyMethodMayBeStatic
    def config_print_basic_block_pass(self) -> Optional[PrintBasicBlockPass]:  # virtual method
        # By default, we don't print basic-blocks
        return None

    # noinspection PyMethodMayBeStatic
    def config_analyze_live_var_pass(self) -> Optional[AnalyzeLiveVarPass]:  # virtual method
        """Build the default `AnalyzeLiveVarPass`; `config_passes()` skips a None result."""
        live_var_pass = AnalyzeLiveVarPass()
        return live_var_pass

    # noinspection PyMethodMayBeStatic
    def config_eliminate_dead_code_pass(self) -> Optional[EliminateDeadCodePass]:  # virtual method
        """Build the default `EliminateDeadCodePass`; `config_passes()` skips a None result."""
        dead_code_pass = EliminateDeadCodePass()
        return dead_code_pass

    # noinspection PyMethodMayBeStatic
    def config_annotate_clause_pass(self) -> Optional[AnnotateClausePass]:  # virtual method
        """Build the default `AnnotateClausePass`; `config_passes()` skips a None result."""
        annotate_pass = AnnotateClausePass()
        return annotate_pass

    # noinspection PyMethodMayBeStatic
    def config_insert_waitcnt_pass(self) -> Optional[InsertWaitcntPass]:  # virtual method
        """Build the default `InsertWaitcntPass`; `config_passes()` skips a None result."""
        waitcnt_pass = InsertWaitcntPass()
        return waitcnt_pass

    # noinspection PyMethodMayBeStatic
    def config_compute_register_interference_pass(self) -> Optional[ComputeRegisterInterferencePass]:  # virtual method
        """Build the default `ComputeRegisterInterferencePass`; `config_passes()` skips a None result."""
        interference_pass = ComputeRegisterInterferencePass()
        return interference_pass

    # noinspection PyMethodMayBeStatic
    def config_allocate_register_rig(self) -> Optional[AllocateRegisterRIGPass]:  # virtual method
        """Build the default `AllocateRegisterRIGPass`; `config_passes()` skips a None result."""
        allocate_pass = AllocateRegisterRIGPass()
        return allocate_pass

    def config_passes(self) -> List[BasePass]:  # virtual method
        passes = []

        def __add_pass_from_config(opt_pass):
            if opt_pass is not None:
                assert isinstance(opt_pass, BasePass)
                passes.append(opt_pass)

        __add_pass_from_config(self.config_divide_basic_block_pass())
        __add_pass_from_config(self.config_optimize_basic_block_pass())
        __add_pass_from_config(self.config_print_basic_block_pass())
        __add_pass_from_config(self.config_analyze_live_var_pass())
        __add_pass_from_config(self.config_eliminate_dead_code_pass())
        __add_pass_from_config(self.config_annotate_clause_pass())
        __add_pass_from_config(self.config_insert_waitcnt_pass())
        __add_pass_from_config(self.config_compute_register_interference_pass())
        __add_pass_from_config(self.config_allocate_register_rig())
        return passes
