#!/usr/bin/env python3
"""Naive CSR sparse matrix-vector product kernel for gfx906, written in the seek DSL.

The kernel computes ``y <= alpha * A @ x + beta * y`` in double precision,
with ``A`` stored in CSR form (row pointers, column indices, values). Each
work-item handles one matrix row. Running this file as a script builds a
``SpmvProgram`` and calls its ``compile`` method with INFO-level logging.
"""
import logging
import sys
import os

# Put the parent of the directory containing this file at the front of
# sys.path so the `seek` imports below resolve without installing the
# package -- presumably that directory is the repository root; confirm.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from seek import *  # noqa: F403 (unable to detect undefined names)
from seek.instructions.gfx906 import *  # noqa: F403 (unable to detect undefined names)


class SpmvProgram(Program):
    """Naive double-precision CSR SpMV: ``y <= alpha * A @ x + beta * y``.

    One work-item per matrix row ``r``. Each work-item:

    * walks the nonzeros ``d_csr_rowptr[r] .. d_csr_rowptr[r + 1] - 1``;
    * accumulates ``value * d_X[col]`` into a double;
    * writes ``alpha * sum + beta * d_Y[r]`` back to ``d_Y[r]``.
    """

    def __init__(self) -> None:
        # No program-specific state; everything is delegated to Program.
        super().__init__()

    def get_signature(self) -> str:
        """Return the C-style signature of the generated kernel.

        The kernel computes ``y <= alpha*A x + beta*y``. ``setup`` reads the
        kernel arguments in exactly this order (alpha, beta, the CSR arrays,
        d_X, d_Y, m, n) when it aliases the loaded kernarg dwords, so the two
        must be kept in sync.

        :return: The kernel declaration as a string.
        """
        return """
        void kfn_asm_naive_n(
            /*in*/ const double alpha, // converted to double
            /*in*/ const double beta, // converted to double
            /*in*/ const uint32_t* __restrict__ d_csr_rowptr, // size: m+1
            /*in*/ const uint32_t* __restrict__ d_csr_colindex, // size: row_ptr[m], aka nnz
            /*in*/ const double* __restrict__ d_csr_values, // size: row_ptr[m], aka nnz
            /*in*/ const double* __restrict__ d_X,
            /*in,out*/ double* __restrict__ d_Y,
            /*in*/ uint32_t m,
            /*in*/ uint32_t n)
        """

    def setup(self) -> None:
        """Emit the kernel body as three labeled blocks.

        * ``MAIN``: load the kernel arguments, compute the row index ``r``,
          disable lanes with ``r >= m``, load the row's nonzero range and
          zero the accumulator.
        * ``LOOP``: each active lane accumulates one nonzero
          ``value * x[col]`` per iteration. Lanes whose row is exhausted
          drop out of ``exec``, and the loop exits once none remain.
        * ``EXIT``: restore the row-valid ``exec`` mask, then store
          ``d_Y[r] = alpha * sum + beta * d_Y[r]``.

        NOTE(review): no explicit ``s_waitcnt`` is emitted between any load
        and the first use of its destination register. Presumably the seek
        framework inserts the waits automatically -- confirm.

        NOTE(review): byte offsets are formed with 32-bit shifts
        (``v_lshlrev_b32``). Indices above 2**29 (8-byte elements) or
        2**30 (4-byte elements) would wrap. Confirm this limit on m, nnz
        and column indices is acceptable.
        """
        with self.add_block("MAIN"):
            # Load all 16 kernarg dwords (64 bytes) with one scalar load.
            s_args = new(count=16, align=4)
            s_load_dwordx16(s_args, s_kernarg, 0, comment="load all kernargs")
            # Kernarg layout by dword. Each double or 64-bit pointer takes a
            # pair of dwords (the DSL's [a:b] slices appear to be inclusive
            # register ranges); each uint32_t takes one dword.
            s_alpha = s_args[0:1].alias()
            s_beta = s_args[2:3].alias()
            s_d_csr_rowptr = s_args[4:5].alias()
            s_d_csr_colindex = s_args[6:7].alias()
            s_d_csr_values = s_args[8:9].alias()
            s_d_X = s_args[10:11].alias()
            s_d_Y = s_args[12:13].alias()
            s_m = s_args[14].alias()
            s_n = s_args[15].alias()  # n: not read anywhere below

            # r = (blockIdx.x << 6) + threadIdx.x. The shift of 6 hard-codes a
            # block size of 64 work-items.
            # NOTE(review): confirm the launch configuration uses
            # blockDim.x == 64.
            v_r = new()
            v_lshl_add_u32(v_r, blockIdx.x, 6, threadIdx.x, comment="r = blockIdx.x * BlockDimX + threadIdx.x")
            # Disable lanes past the last row (r >= m).
            v_cmpx_lt_u32_e64(vcc_exec, v_r, s_m, comment="(r < m)? sets exec")

            # Get j_at_begin, j_at_end: the two consecutive uint32 entries
            # d_csr_rowptr[r] and d_csr_rowptr[r + 1] bound this row's nonzeros.
            v_j_at_begin_end = new(count=2)
            v_offset = new()
            v_lshlrev_b32(v_offset, 2, v_r, comment="offset = r * SIZEOF_U32")
            global_load_dwordx2(v_j_at_begin_end, v_offset, s_d_csr_rowptr, comment="j_at_begin_end = d_csr_rowptr[r:r+1]")
            v_j_at_begin = v_j_at_begin_end[0].alias()
            v_j_at_end = v_j_at_begin_end[1].alias()

            # Capture the row-valid lane mask (exec already holds r < m here)
            # so EXIT can restore it after LOOP has narrowed exec. The compare
            # is repeated only to get that mask into s_saved_exec.
            s_saved_exec = new(count=2, align=2)
            v_cmpx_lt_u32_e64(s_saved_exec, v_r, s_m, comment="(r < m)?")

            # sum = 0.0: zero both dwords of the f64 accumulator.
            v_sum = new(count=2)
            v_mov_b32(v_sum[0], 0)
            v_mov_b32(v_sum[1], 0)

            # j_at walks this row's nonzeros, starting at j_at_begin.
            v_j_at = new()
            v_mov_b32(v_j_at, v_j_at_begin, comment="j_at = j_begin")

        with self.add_block("LOOP"):
            # Keep only lanes that still have nonzeros left. Once no lane is
            # active, leave the loop.
            v_cmpx_lt_u32(vcc_exec, v_j_at, v_j_at_end, comment="(j_at < j_at_end)? sets exec")
            s_cbranch_execz("EXIT", comment="exit loop if all threads are done")
            # Fresh offset register for the loop body. This rebinds the Python
            # name; the MAIN-block register is not used again.
            v_offset = new()

            # Get col, value, x for nonzero j_at.
            v_col = new()
            v_lshlrev_b32(v_offset, 2, v_j_at, comment="offset = j_at * sizeof(uint32_t)")
            global_load_dword(v_col, v_offset, s_d_csr_colindex, comment="col = d_csr_colindex[j_at]")
            v_value = new(count=2)
            v_lshlrev_b32(v_offset, 3, v_j_at, comment="offset = j_at * sizeof(double)")
            global_load_dwordx2(v_value, v_offset, s_d_csr_values, comment="value = d_csr_values[j_at]")
            v_x = new(count=2)
            v_lshlrev_b32(v_offset, 3, v_col, comment="offset = col * sizeof(double)")
            global_load_dwordx2(v_x, v_offset, s_d_X, comment="x = d_X[col]")

            # Accumulate sum and loop again.
            v_fma_f64(v_sum, v_value, v_x, v_sum, comment="sum += value * x")
            v_add_u32(v_j_at, 1, v_j_at, comment="j_at += 1")
            s_branch("LOOP")

        with self.add_block("EXIT"):
            # Re-enable every lane that owns a valid row (r < m).
            s_mov_b64(exec, s_saved_exec, comment="restore exec")
            # y_offset = r * 8 bytes; used for both the load and the store.
            v_y = new(count=2)
            v_y_offset = new()
            v_lshlrev_b32(v_y_offset, 3, v_r, comment="y_offset = r * sizeof(double)")
            global_load_dwordx2(v_y, v_y_offset, s_d_Y, comment="y = d_Y[r]")
            # d_Y[r] = alpha * sum + beta * y.
            # NOTE(review): d_Y[r] is always read, so beta == 0 does not discard
            # a NaN/Inf already stored in d_Y. BLAS convention skips the read
            # in that case -- confirm callers always initialize d_Y.
            v_mul_f64(v_sum, s_alpha, v_sum, comment="sum *= alpha")
            v_fma_f64(v_sum, s_beta, v_y, v_sum, comment="sum += beta * y")
            global_store_dwordx2(v_y_offset, v_sum, s_d_Y, comment="d_Y[r] = sum")
            s_endpgm()


if __name__ == "__main__":
    # Build the program and compile it with INFO-level logging.
    SpmvProgram().compile(log_level=logging.INFO)