#!/usr/bin/env python3
import logging
import sys
import os

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from seek import *  # noqa: F403 (unable to detect undefined names)
from seek.instructions.gfx906 import *  # noqa: F403 (unable to detect undefined names)


class SpmvProgram(Program):
    """Naive CSR sparse matrix-vector product kernel: y <- alpha*A*x + beta*y.

    Emits gfx906 assembly through the ``seek`` DSL. Each work-item handles one
    matrix row ``r`` and walks that row's nonzeros sequentially.
    """

    def __init__(self) -> None:
        super().__init__()

    def get_signature(self) -> str:
        """Return the C-style prototype of the generated kernel.

        The kernel computes ``y <- alpha*A*x + beta*y``. ``A`` has ``m`` rows
        and is stored in CSR form (``d_csr_rowptr`` / ``d_csr_colindex`` /
        ``d_csr_values``). The kernarg dword offsets decoded in ``setup``
        assume exactly this parameter order.

        :return: the kernel signature as a source-code string.
        """
        return """
            void kfn_asm_naive_n(
                /*in*/ const double alpha,                              // converted to double
                /*in*/ const double beta,                               // converted to double
                /*in*/ const uint32_t* __restrict__ d_csr_rowptr,       // size: m+1
                /*in*/ const uint32_t* __restrict__ d_csr_colindex,     // size: row_ptr[m], aka nnz
                /*in*/ const double* __restrict__ d_csr_values,         // size: row_ptr[m], aka nnz
                /*in*/ const double* __restrict__ d_X,
                /*in,out*/ double* __restrict__ d_Y,
                /*in*/ uint32_t m,
                /*in*/ uint32_t n)
            """

    def setup(self) -> None:
        """Emit the kernel body as three blocks: MAIN, LOOP and EXIT.

        MAIN decodes the kernel arguments, computes this lane's row ``r``,
        masks off lanes with ``r >= m`` and loads the row's nonzero range.
        LOOP accumulates ``sum += value * x`` over that range, narrowing
        ``exec`` each iteration until no lane has work left. EXIT restores
        the ``r < m`` mask and stores ``d_Y[r] = alpha*sum + beta*d_Y[r]``.

        NOTE(review): no ``s_waitcnt`` is emitted anywhere in this program.
        Presumably the ``seek`` compiler inserts the waits needed before
        kernarg and global-load results are consumed — confirm.
        NOTE(review): all byte offsets are built with 32-bit shifts
        (``v_lshlrev_b32``). Any index whose scaled offset reaches 2**32
        would wrap, e.g. ``j_at >= 2**29`` for the double arrays.
        """
        with self.add_block("MAIN"):
            # Kernarg dwords follow get_signature() parameter order. alpha,
            # beta (f64) and the five pointers take two dwords each; m and n
            # (u32) take one dword each, 16 dwords in total.
            s_args = new(count=16, align=4)
            s_load_dwordx16(s_args, s_kernarg, 0, comment="load all kernargs")
            s_alpha = s_args[0:1].alias()
            s_beta = s_args[2:3].alias()
            s_d_csr_rowptr = s_args[4:5].alias()
            s_d_csr_colindex = s_args[6:7].alias()
            s_d_csr_values = s_args[8:9].alias()
            s_d_X = s_args[10:11].alias()
            s_d_Y = s_args[12:13].alias()
            s_m = s_args[14].alias()
            s_n = s_args[15].alias()
            # NOTE(review): s_n is decoded but never used by this kernel.

            # r = (blockIdx.x << 6) + threadIdx.x. The shift hard-codes 64
            # work-items per block, so the kernel presumably must be launched
            # with blockDim.x == 64 — confirm against the launcher.
            v_r = new()
            v_lshl_add_u32(v_r, blockIdx.x, 6, threadIdx.x, comment="r = blockIdx.x * BlockDimX + threadIdx.x")
            # Disable lanes whose row index is out of range.
            v_cmpx_lt_u32_e64(vcc_exec, v_r, s_m, comment="(r < m)?  sets exec")

            # Get j_at_begin, j_at_end
            # Row r's nonzeros occupy [d_csr_rowptr[r], d_csr_rowptr[r+1]).
            # Both bounds come from a single two-dword load at byte offset r*4.
            v_j_at_begin_end = new(count=2)
            v_offset = new()
            v_lshlrev_b32(v_offset, 2, v_r, comment="offset = r * SIZEOF_U32")
            global_load_dwordx2(v_j_at_begin_end, v_offset, s_d_csr_rowptr,
                                comment="j_at_begin_end = d_csr_rowptr[r:r+1]")
            v_j_at_begin = v_j_at_begin_end[0].alias()
            v_j_at_end = v_j_at_begin_end[1].alias()

            # Repeat the r < m compare with s_saved_exec as its destination.
            # EXIT needs this row mask back after LOOP has narrowed exec.
            # exec already equals this mask here, so this presumably just
            # copies the current exec.
            s_saved_exec = new(count=2, align=2)
            v_cmpx_lt_u32_e64(s_saved_exec, v_r, s_m, comment="(r < m)?")

            # sum = 0.0 as a double: both 32-bit halves set to zero.
            v_sum = new(count=2)
            v_mov_b32(v_sum[0], 0)
            v_mov_b32(v_sum[1], 0)

            # Per-lane cursor over the row's nonzeros.
            v_j_at = new()
            v_mov_b32(v_j_at, v_j_at_begin, comment="j_at = j_begin")

        with self.add_block("LOOP"):
            # Each pass drops from exec the lanes that have finished their row.
            # Once no lane is active, leave the loop.
            v_cmpx_lt_u32(vcc_exec, v_j_at, v_j_at_end, comment="(j_at < j_at_end)?  sets exec")
            s_cbranch_execz("EXIT", comment="exit loop if all threads are done")
            # New temporary for this block's byte offsets; MAIN's v_offset is
            # not reused.
            v_offset = new()

            # Get col, value, x
            v_col = new()
            v_lshlrev_b32(v_offset, 2, v_j_at, comment="offset = j_at * sizeof(uint32_t)")
            global_load_dword(v_col, v_offset, s_d_csr_colindex, comment="col = d_csr_colindex[j_at]")

            v_value = new(count=2)
            v_lshlrev_b32(v_offset, 3, v_j_at, comment="offset = j_at * sizeof(double)")
            global_load_dwordx2(v_value, v_offset, s_d_csr_values, comment="value = d_csr_values[j_at]")

            # The d_X access is indirect: its address depends on the col just loaded.
            v_x = new(count=2)
            v_lshlrev_b32(v_offset, 3, v_col, comment="offset = col * sizeof(double)")
            global_load_dwordx2(v_x, v_offset, s_d_X, comment="x = d_X[col]")

            # Accumulate sum and loop again
            v_fma_f64(v_sum, v_value, v_x, v_sum, comment="sum += value * x")

            v_add_u32(v_j_at, 1, v_j_at, comment="j_at += 1")
            s_branch("LOOP")

        with self.add_block("EXIT"):
            # Re-enable every lane with r < m, including lanes that finished
            # LOOP early.
            s_mov_b64(exec, s_saved_exec, comment="restore exec")

            v_y = new(count=2)
            v_y_offset = new()
            v_lshlrev_b32(v_y_offset, 3, v_r, comment="y_offset = r * sizeof(double)")
            global_load_dwordx2(v_y, v_y_offset, s_d_Y, comment="y = d_Y[r]")

            # result = alpha*sum + beta*y.
            # NOTE(review): y is read and multiplied by beta unconditionally.
            # With beta == 0, any NaN/Inf already in d_Y[r] still reaches the
            # result. BLAS-style SpMV usually skips reading y when beta == 0.
            v_mul_f64(v_sum, s_alpha, v_sum, comment="sum *= alpha")
            v_fma_f64(v_sum, s_beta, v_y, v_sum, comment="sum += beta * y")

            # Write back to the same element that was read, d_Y[r].
            global_store_dwordx2(v_y_offset, v_sum, s_d_Y, comment="d_Y[r] = sum")
            s_endpgm()


if __name__ == "__main__":
    # When run as a script, build the SpMV kernel program and compile it
    # with INFO-level logging.
    program = SpmvProgram()
    program.compile(log_level=logging.INFO)
