#!/usr/bin/env python3
import logging
import sys
import os

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from seek import *  # noqa: F403 (unable to detect undefined names)
from seek.instructions.gfx906 import *  # noqa: F403 (unable to detect undefined names)


class VectorAddProgram(Program):
    """
    Element-wise integer vector addition, hand-written for gfx906.

    Equivalent HIP/CUDA source::

        __global__
        void vector_add(const int* a, const int* b, int* sum)
        {
            sum[threadIdx.x] = a[threadIdx.x] + b[threadIdx.x];
        }

    Each thread handles exactly one element, indexed only by ``threadIdx.x``.
    There is no ``blockIdx`` term and no length check. If the kernel is
    launched with more than one work-group, every group reads and writes the
    same leading elements.
    """

    def __init__(self):
        super().__init__()

    def get_signature(self) -> str:
        """Return the HIP/CUDA-style declaration of the kernel built by ``setup``."""
        return """
            __global__
            void vector_add(const int* a, const int* b, int* sum)
            """

    def setup(self):
        """Emit the kernel body as a single block named "MAIN"."""
        # The block handle is not needed; all instructions below are emitted
        # into the block while the context manager is active.
        with self.add_block("MAIN"):
            # Kernel arguments are three pointers, each two dwords wide:
            # a at byte offset 0, b at 8 and sum at 16. One dwordx4 load
            # fetches a and b, and one dwordx2 load fetches sum.
            # The slices appear to be inclusive: s_args[0:3] is the 4-dword
            # destination of the dwordx4 load. Confirm against seek.
            # align=4 presumably keeps the dwordx4 destination on a
            # 4-register boundary, as multi-dword SMEM loads require.
            s_args = new(count=6, align=4)
            s_load_dwordx4(s_args[0:3], s_kernarg, 0)
            s_load_dwordx2(s_args[4:5], s_kernarg, 16)

            # Give each 64-bit pointer its own name within the argument group.
            s_a_ptr = s_args[0:1].alias()
            s_b_ptr = s_args[2:3].alias()
            s_sum_ptr = s_args[4:5].alias()

            # Byte offset of this thread's element: threadIdx.x << 2.
            v_offset = new()
            v_lshlrev_b32(v_offset, 2, threadIdx.x, comment="offset = threadIdx.x * sizeof(int)")

            v_a = new()
            v_b = new()
            v_sum = new()
            # Scalar base pointer + per-thread vector offset addressing.
            # NOTE(review): there is no explicit s_waitcnt between the s_load_*
            # / global_load_* instructions and the uses of their results.
            # Presumably seek inserts the waits automatically. Confirm.
            global_load_dword(v_a, v_offset, s_a_ptr, comment="a = a_ptr[threadIdx.x]")
            global_load_dword(v_b, v_offset, s_b_ptr, comment="b = b_ptr[threadIdx.x]")

            v_add_i32(v_sum, v_a, v_b, comment="sum = a + b")
            global_store_dword(v_offset, v_sum, s_sum_ptr, comment="sum_ptr[threadIdx.x] = sum")

            s_endpgm()


if __name__ == "__main__":
    # Build the kernel and compile it as a code-object-v3 binary, logging at INFO.
    program = VectorAddProgram()
    program.compile(code_object_version=3, log_level=logging.INFO)
