#!/usr/bin/env python3
import logging
import sys
import os

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from seek import *  # noqa: F403 (unable to detect undefined names)
from seek.instructions.gfx906 import *  # noqa: F403 (unable to detect undefined names)


def s_add_u64_u32(sdst64, ssrc32, comment=None):
    """Emit ``sdst64 += ssrc32`` as an add / add-with-carry instruction pair.

    ``sdst64`` is indexed as ``[0]`` (low dword) and ``[1]`` (high dword) and
    is updated in place; ``ssrc32`` is added to the low dword and the second
    instruction adds 0 to the high dword via ``s_addc_u32``.

    ``comment`` is attached to the second (high-dword) instruction only.
    """
    low_dword, high_dword = sdst64[0], sdst64[1]
    s_add_u32(low_dword, ssrc32, low_dword)
    s_addc_u32(high_dword, 0, high_dword, comment=comment)


def v_div_f32(vdst, v0, v2):
    """
    Emit the instruction sequence for a float32 division: vdst := v0 / v2.

    ``v0`` is the numerator and ``v2`` the denominator; the parameter names
    follow the register numbers used in the reference listing below. Every
    call requests ten scratch registers through ``new(GprType.V)``, and both
    ``v_div_scale_f32`` instructions are given ``vcc`` as their second
    destination.

    Differences from the reference listing, as visible in this function:
      * the first ``v_div_scale_f32`` targets ``vcc`` instead of ``s[2:3]``;
      * the two ``s_setreg_imm32_b32`` denorm-mode switches are not emitted;
      * the addend of the first ``v_fma_f32`` is passed as the Python int
        ``1`` where the listing has ``1.0``.

    Reference kernel and assembly the sequence was taken from:

    __global__
    void kfn_div_f32(float* a, float* b, float* c)
    {
        a[threadIdx.x] = b[threadIdx.x] / c[threadIdx.x];
    }


    s_load_dwordx2 s[4:5], s[0:1], 0x10     // load from kernargs: float* c@s[4:5]
    s_load_dwordx4 s[0:3], s[0:1], 0x0      // load from kernargs: float* a@s[0:1], float* b@s[2:3]
    v_lshlrev_b32 v4, 2, v0                 // v4 := threadIdx.x * sizeof(float)
    s_waitcnt lgkmcnt(0)

    v_mov_b32 v1, s3
    v_add_co_u32 v0, vcc, s2, v4
    v_addc_co_u32 v1, vcc, 0, v1, vcc       // v[0:1] = &b[threadIdx.x]

    v_mov_b32 v3, s5
    v_add_co_u32 v2, vcc, s4, v4
    v_addc_co_u32 v3, vcc, 0, v3, vcc       // v[2:3] = &c[threadIdx.x]

    global_load_dword v2, v[2:3], off           // v2 := c[threadIdx.x]
    global_load_dword v0, v[0:1], off           // v0 := b[threadIdx.x]  // we will compute v0/v2 -> v13
    s_waitcnt vmcnt(0)

    v_div_scale_f32 v1, s[2:3], v2, v2, v0      // scale c@v2 -> v1
    v_div_scale_f32 v3, vcc   , v0, v2, v0      // scale b@v0 -> v3
    v_rcp_f32 v5, v1                            // v5 = 1.0 / v1
    s_setreg_imm32_b32 hwreg(HW_REG_MODE, 4, 2), 3  // amdhsa_float_denorm_mode_32 = 3
    v_fma_f32 v6, -v1, v5, 1.0
    v_fma_f32 v8, v6, v5, v5
    v_mul_f32 v9, v3, v8
    v_fma_f32 v7, -v1, v9, v3
    v_fma_f32 v10, v7, v8, v9
    v_fma_f32 v11, -v1, v10, v3
    s_setreg_imm32_b32 hwreg(HW_REG_MODE, 4, 2), 0  // amdhsa_float_denorm_mode_32 = 0
    v_div_fmas_f32 v12, v11, v8, v10
    v_div_fixup_f32 v13, v12, v2, v0

    v_mov_b32 v3, s1
    v_add_co_u32 v0, vcc, s0, v4
    v_addc_co_u32 v1, vcc, 0, v3, vcc       // v[0:1] := &a[threadIdx.x]

    global_store_dword v[0:1], v13, off     // a[threadIdx.x] = v13
    s_endpgm
    """
    # Scratch registers, requested in the same order as the reference uses
    # v1, v3, v5, v6, v8, v9, v7, v10, v11, v12.
    den_scaled = new(GprType.V)
    num_scaled = new(GprType.V)
    rcp = new(GprType.V)
    rcp_err = new(GprType.V)
    rcp_refined = new(GprType.V)
    quot = new(GprType.V)
    resid = new(GprType.V)
    quot_refined = new(GprType.V)
    resid_refined = new(GprType.V)
    quot_fmas = new(GprType.V)

    # Scale denominator, then numerator. Both write vcc, so the value left in
    # vcc afterwards is the one produced by the numerator scale.
    v_div_scale_f32(den_scaled, vcc, v2, v2, v0)
    v_div_scale_f32(num_scaled, vcc, v0, v2, v0)

    # rcp = 1.0 / den_scaled
    v_rcp_f32(rcp, den_scaled)

    # NOTE(review): the reference switches the float denorm mode to 3 here
    # (s_setreg_imm32_b32 hwreg(HW_REG_MODE, 4, 2), 3); this port does not.

    # rcp_err = -den_scaled * rcp + 1
    # NOTE(review): the reference addend is 1.0; confirm how seek encodes the
    # Python int 1 for a float operand.
    v_fma_f32(rcp_err, -den_scaled, rcp, 1)
    # rcp_refined = rcp_err * rcp + rcp
    v_fma_f32(rcp_refined, rcp_err, rcp, rcp)
    # quot = num_scaled * rcp_refined
    v_mul_f32(quot, num_scaled, rcp_refined)
    # resid = -den_scaled * quot + num_scaled
    v_fma_f32(resid, -den_scaled, quot, num_scaled)
    # quot_refined = resid * rcp_refined + quot
    v_fma_f32(quot_refined, resid, rcp_refined, quot)
    # resid_refined = -den_scaled * quot_refined + num_scaled
    v_fma_f32(resid_refined, -den_scaled, quot_refined, num_scaled)

    # NOTE(review): the reference restores the float denorm mode to 0 here
    # (s_setreg_imm32_b32 hwreg(HW_REG_MODE, 4, 2), 0); this port does not.

    # Final fmas / fixup steps, operating on the original v2 and v0 again.
    v_div_fmas_f32(quot_fmas, resid_refined, rcp_refined, quot_refined)
    v_div_fixup_f32(vdst, quot_fmas, v2, v0)



# Number of `iq` slices handled by the generated kernel. The per-iq loops in
# Tracer2dStep5and6Program.setup() are unrolled at generation time over
# range(NQ), and the value is baked into the kernel name in get_signature().
NQ = 7
# log2 of the NumGroups divisor used to split the j range across groups; it is
# applied as a right-shift amount when j_start / j_end are computed in setup().
LOG2_NUM_GROUPS = 6  # NumGroups = 64


class Tracer2dStep5and6Program(Program):
    """seek ``Program`` that generates the ``tracer_2d_1l_step5and6_nq{NQ}_kfn`` kernel.

    For the (group, k) pair taken from blockIdx.x / blockIdx.y, the generated
    code walks j from j_start to j_end and, for each of the NQ values of iq,
    stores

        (q * area + yfx[j] * fy2[j] - yfx[j+1] * fy2[j+1]) / ra_y

    into d_q_i, where ra_y = area + yfx[j] - yfx[j+1]. Each thread addresses
    its element with a byte offset of threadIdx.x * 4.
    """

    def __init__(self) -> None:
        super().__init__()

    def get_signature(self) -> str:
        """Return the C-style declaration of the generated kernel.

        The parameter order given here is the layout that setup() relies on
        when it aliases the loaded kernarg dwords.
        """
        return f"""
        __global__
        static void tracer_2d_1l_step5and6_nq{NQ}_kfn(
            /*const int is, const int ie,*/
            const int je_minus_js_plus_1,
            const int je_minus_js_plus_2,
            const int je_minus_js_plus_7,
            const int delta_common,
            const int delta_common_mul_npz,
            /*const int isd, const int ied,*/ /*const int jsd, const int jed,*/
            /*const int npx, const int npy,*/ const int npz, /*const int nq,*/
            real* __restrict d_q_i,            // [out] real(isd:ied,js:je) for k in (1..npz), iq in (1..nq)
            const real* __restrict d_q,        // [in] real(isd:ied, jsd:jed, npz, nq)
            const real* __restrict d_area,     // [in] real(isd:ied, jsd:jed)
            /*const real* __restrict d_ra_y,*/     // [in] real(isd:ied,js:je) for k in (1..npz)
            const real* __restrict d_yfx,      // [in] real(isd:ied, js:je+1) for k in (1..npz)
            const real* __restrict d_fy2)      // [in] real(isd:ied,js:je+1) for k in (1..npz), iq in (1..nq)
        """

    def setup(self) -> None:
        """Emit the kernel body into three blocks: MAIN, LOOP and EXIT.

        MAIN loads the kernargs, derives this group's [j_start, j_end] range,
        advances the five base pointers to their starting positions and
        preloads yfx and the NQ fy2 values for j_start (already multiplied
        together). LOOP processes one j per iteration and carries the j+1
        values of yfx and yfx*fy2 over to the next iteration. EXIT ends the
        program.
        """

        block_main = self.add_block("MAIN")
        block_loop = self.add_block("LOOP")
        block_exit = self.add_block("EXIT")

        with block_main:
            # Kernarg layout follows get_signature(): six 32-bit ints in
            # dwords 0-5, then five 64-bit pointers in dword pairs 6-15.
            # NOTE(review): the [a:b] slices look like inclusive, assembly-style
            # register ranges (each pointer is later indexed [0] and [1]) -
            # confirm against seek.
            s_args = new(count=16, align=4)
            s_load_dwordx16(s_args, s_kernarg, 0, comment="load all kernargs")
            je_minus_js_plus_1 = s_args[0].alias()
            je_minus_js_plus_2 = s_args[1].alias()
            je_minus_js_plus_7 = s_args[2].alias()
            delta_common = s_args[3].alias()
            delta_common_mul_npz = s_args[4].alias()
            # npz is aliased here but not referenced again in this method.
            npz = s_args[5].alias()
            d_q_i = s_args[6:7].alias()
            d_q = s_args[8:9].alias()
            d_area = s_args[10:11].alias()
            d_yfx = s_args[12:13].alias()
            d_fy2 = s_args[14:15].alias()

            group = blockIdx.x.alias()
            k = blockIdx.y.alias()

            # Per-thread byte offset, reused for every load/store below.
            v_i4 = new()
            v_lshlrev_b32(v_i4, 2, threadIdx.x, comment="v_i4 = threadIdx.x * 4")
            v_offset = new()

            # const int j_start = je_minus_js_plus_1 * group / NumGroups;
            # (the division is emitted as a right shift by LOG2_NUM_GROUPS)
            j_start = new(GprType.S)
            s_mul_i32(j_start, je_minus_js_plus_1, group)
            s_lshr_b32(j_start, j_start, LOG2_NUM_GROUPS, comment="j_start = je_minus_js_plus_1 * group / NumGroups")

            # const int j_end = je_minus_js_plus_1 * (group+1) / NumGroups - 1;
            # computed as (je_minus_js_plus_1*group + je_minus_js_plus_1) >> LOG2_NUM_GROUPS, minus 1
            j_end = new(GprType.S)
            s_mul_i32(j_end, je_minus_js_plus_1, group)
            s_add_i32(j_end, je_minus_js_plus_1, j_end)
            s_lshr_b32(j_end, j_end, LOG2_NUM_GROUPS)
            s_sub_i32(j_end, j_end, 1, comment="j_end = je_minus_js_plus_1 * (group+1) / NumGroups - 1")

            # const int s_tmp3 = delta_common*(j_start + k*je_minus_js_plus_2);
            s_tmp3 = new()
            s_mul_i32(s_tmp3, k, je_minus_js_plus_2)
            s_add_i32(s_tmp3, j_start, s_tmp3)
            s_mul_i32(s_tmp3, delta_common, s_tmp3, comment="s_tmp3 = delta_common*(j_start + k*je_minus_js_plus_2)")

            # const char* curr_yfx = (const char*)d_yfx + s_tmp3;
            # (the curr_* names alias the kernarg pointer registers, which are
            # then advanced in place)
            curr_yfx = d_yfx.alias()
            s_add_u64_u32(curr_yfx, s_tmp3)

            # real yfx_i_j_k = *(real*)(curr_yfx + v_i4);
            yfx_i_j_k = new(GprType.V)
            global_load_dword(yfx_i_j_k, v_i4, curr_yfx, comment="yfx_i_j_k = *(real*)(curr_yfx + v_i4)")

            # const char* curr_fy2 = (const char*)d_fy2 + s_tmp3;
            curr_fy2 = d_fy2.alias()
            s_add_u64_u32(curr_fy2, s_tmp3)

            # const int deltaiq_fy2 = delta_common_mul_npz*je_minus_js_plus_2;
            # (offset between consecutive iq slices of fy2)
            deltaiq_fy2 = new(GprType.S)
            s_mul_i32(deltaiq_fy2, delta_common_mul_npz, je_minus_js_plus_2)

            # real fyy_i_j_k_iq[NQ];
            fyy_i_j_k_iq = new[NQ](GprType.V)

            # v_offset = v_i4;
            v_mov_b32(v_offset, v_i4)

            # for (int iq = 0; iq < NQ; ++iq) {
            #     fyy_i_j_k_iq[iq] = *(real*)(curr_fy2 + v_offset);
            #     v_offset += deltaiq_fy2;
            # }
            # The Python loop unrolls at generation time; the offset bump is
            # skipped after the last load because v_offset is reset before its
            # next use.
            for iq in range(0, NQ):
                global_load_dword(fyy_i_j_k_iq[iq], v_offset, curr_fy2, comment=f"load fyy_i_j_k_iq[{iq}]")
                if iq < NQ - 1:
                    v_add_u32(v_offset, deltaiq_fy2, v_offset, comment="v_offset += deltaiq_fy2")

            # const int s_tmp5 = delta_common*(j_start+3);
            s_tmp5 = new()
            s_add_i32(s_tmp5, j_start, 3)
            s_mul_i32(s_tmp5, delta_common, s_tmp5, comment="s_tmp5 = delta_common*(j_start+3)")

            # const char* curr_area = (const char*)d_area + s_tmp5;
            curr_area = d_area.alias()
            s_add_u64_u32(curr_area, s_tmp5)

            # const int s_tmp6 = delta_common*(j_start + k*je_minus_js_plus_1);
            s_tmp6 = new()
            s_mul_i32(s_tmp6, k, je_minus_js_plus_1)
            s_add_i32(s_tmp6, j_start, s_tmp6)
            s_mul_i32(s_tmp6, delta_common, s_tmp6, comment="s_tmp6 = delta_common*(j_start + k*je_minus_js_plus_1)")

            # char* curr_q_i = (char*)d_q_i + s_tmp6;
            curr_q_i = d_q_i.alias()
            s_add_u64_u32(curr_q_i, s_tmp6)

            # const int deltaiq_q_i = delta_common_mul_npz*je_minus_js_plus_1;
            # (offset between consecutive iq slices of q_i)
            deltaiq_q_i = new(GprType.S)
            s_mul_i32(deltaiq_q_i, delta_common_mul_npz, je_minus_js_plus_1)

            # const int s_tmp8 = s_tmp5 + delta_common*(k*je_minus_js_plus_7);
            s_tmp8 = new()
            s_mul_i32(s_tmp8, k, je_minus_js_plus_7)
            s_mul_i32(s_tmp8, delta_common, s_tmp8)
            s_add_i32(s_tmp8, s_tmp5, s_tmp8, comment="s_tmp8 = s_tmp5 + delta_common*(k*je_minus_js_plus_7)")

            # const char* curr_q = (const char*)d_q + s_tmp8;
            curr_q = d_q.alias()
            s_add_u64_u32(curr_q, s_tmp8)

            # const int deltaiq_q = delta_common_mul_npz*je_minus_js_plus_7;
            # (offset between consecutive iq slices of q)
            deltaiq_q = new(GprType.S)
            s_mul_i32(deltaiq_q, delta_common_mul_npz, je_minus_js_plus_7)

            # for (int iq = 0; iq < NQ; ++iq) {
            #     fyy_i_j_k_iq[iq] = yfx_i_j_k * fyy_i_j_k_iq[iq];
            # }
            for iq in range(0, NQ):
                v_mul_f32(fyy_i_j_k_iq[iq], yfx_i_j_k, fyy_i_j_k_iq[iq])

            # The loop counter is an alias of j_start, which is not read under
            # that name after this point.
            j = j_start.alias()

        with block_loop:
            # Branch to EXIT when the (j <= j_end) comparison fails.
            s_cmp_le_i32(j, j_end, comment="(j <= j_end) ?")
            s_cbranch_scc0(block_exit)

            # const real area_i_j = *(real*)(curr_area + v_i4);
            area_i_j = new(GprType.V)
            global_load_dword(area_i_j, v_i4, curr_area)

            # curr_area += delta_common
            s_add_u64_u32(curr_area, delta_common)

            # curr_yfx += delta_common;
            # (advanced before the load so the load below reads the j+1 entry)
            s_add_u64_u32(curr_yfx, delta_common)

            # const real yfx_i_jplus1_k = *(real*)(curr_yfx + v_i4);
            yfx_i_jplus1_k = new(GprType.V)
            global_load_dword(yfx_i_jplus1_k, v_i4, curr_yfx)

            # real q[NQ];
            q = new[NQ](GprType.V)

            # v_offset = v_i4;
            v_mov_b32(v_offset, v_i4)

            # real fyy_i_jplus1_k_iq[NQ];
            fyy_i_jplus1_k_iq = new[NQ](GprType.V)

            # curr_fy2 += delta_common;
            # (advanced before the loads so they read the j+1 entries)
            s_add_u64_u32(curr_fy2, delta_common)

            # int v_offset2 = v_i4;
            v_offset2 = new()
            v_mov_b32(v_offset2, v_i4)

            # for (int iq = 0; iq < NQ; ++iq) {
            #     q[iq] = *(real*)(curr_q + v_offset);
            #     fyy_i_jplus1_k_iq[iq] = *(real*)(curr_fy2 + v_offset2);
            #     v_offset += deltaiq_q;
            #     v_offset2 += deltaiq_fy2;
            # }
            for iq in range(0, NQ):
                global_load_dword(q[iq], v_offset, curr_q)
                global_load_dword(fyy_i_jplus1_k_iq[iq], v_offset2, curr_fy2)
                if iq < NQ - 1:
                    v_add_u32(v_offset, deltaiq_q, v_offset)
                    v_add_u32(v_offset2, deltaiq_fy2, v_offset2)

            # curr_q += delta_common;
            s_add_u64_u32(curr_q, delta_common)

            # const real ra_y_i_j_k = area_i_j + yfx_i_j_k - yfx_i_jplus1_k;
            ra_y_i_j_k = new(GprType.V)
            v_add_f32(ra_y_i_j_k, area_i_j, yfx_i_j_k)
            v_sub_f32(ra_y_i_j_k, ra_y_i_j_k, yfx_i_jplus1_k,
                      comment="ra_y_i_j_k = area_i_j + yfx_i_j_k - yfx_i_jplus1_k")

            # v_offset = v_i4;
            v_mov_b32(v_offset, v_i4)

            # for (int iq = 0; iq < NQ; ++iq) {
            #     real tmp = q[iq]*area_i_j;
            #     fyy_i_jplus1_k_iq[iq] = yfx_i_jplus1_k * fyy_i_jplus1_k_iq[iq];
            #     *(real*)(curr_q_i + v_offset) = (tmp + fyy_i_j_k_iq[iq] - fyy_i_jplus1_k_iq[iq]) / ra_y_i_j_k;
            #     v_offset += deltaiq_q_i;
            # }
            # tmp and tmp_dst are shared by all NQ unrolled iterations;
            # v_div_f32 additionally requests its own scratch registers on
            # every call.
            tmp = new(GprType.V)
            tmp_dst = new(GprType.V)
            for iq in range(0, NQ):
                v_mul_f32(tmp, q[iq], area_i_j, comment="tmp = q[iq]*area_i_j")
                v_add_f32(tmp, tmp, fyy_i_j_k_iq[iq], comment="tmp = q[iq]*area_i+fyy_i_j_k_iq[iq]")
                v_mul_f32(fyy_i_jplus1_k_iq[iq], yfx_i_jplus1_k, fyy_i_jplus1_k_iq[iq],
                          comment="fyy_i_jplus1_k_iq[iq] = yfx_i_jplus1_k * fyy_i_jplus1_k_iq[iq]")
                v_sub_f32(tmp, tmp, fyy_i_jplus1_k_iq[iq],
                          comment="tmp = q[iq]*area_i_j + fyy_i_j_k_iq[iq] - fyy_i_jplus1_k_iq[iq]")
                # tmp_dst := tmp / ra_y_i_j_k
                v_div_f32(tmp_dst, tmp, ra_y_i_j_k)
                if iq < NQ - 1:
                    global_store_dword(v_offset, tmp_dst, curr_q_i)
                    v_add_u32(v_offset, deltaiq_q_i, v_offset)
                else:  # iq == NQ-1
                    # Tag the final store of the iteration so that it can be
                    # waited on by name before the loop branches back.
                    global_store_dword(v_offset, tmp_dst, curr_q_i, mem_token_object="last_write_token")

            # curr_q_i += delta_common;
            s_add_u64_u32(curr_q_i, delta_common)

            # yfx_i_j_k = yfx_i_jplus1_k;
            # (the j+1 value becomes the j value of the next iteration)
            v_mov_b32(yfx_i_j_k, yfx_i_jplus1_k)

            # for (int iq = 0; iq < NQ; ++iq)
            #     fyy_i_j_k_iq[iq] = fyy_i_jplus1_k_iq[iq];
            # (these were already multiplied by yfx_i_jplus1_k above)
            for iq in range(0, NQ):
                v_mov_b32(fyy_i_j_k_iq[iq], fyy_i_jplus1_k_iq[iq])

            s_add_i32(j, j, 1, comment="++j")
            # NOTE(review): presumably blocks until the store tagged
            # "last_write_token" has completed - confirm against seek's
            # explicit_wait semantics.
            explicit_wait("last_write_token")
            s_branch(block_loop)

        with block_exit:
            # Nothing more to do
            s_endpgm()


if __name__ == "__main__":
    # Script entry point: build the program object, then compile it with
    # INFO-level logging.
    program = Tracer2dStep5and6Program()
    program.compile(log_level=logging.INFO)
