Commit e00dc7c6 authored by catchyrime's avatar catchyrime
Browse files

Add Seek

parents
Pipeline #121 canceled with stages
[flake8]
max-line-length = 120
ignore =
# F541: f-string is missing placeholders
F541
# F405: 'xxx' may be undefined, or defined from star imports: xxx
F405
# missing whitespace after ','
E231
# missing whitespace around arithmetic operator
E226
# line break after binary operator
W504
# module level import not at top of file
E402
# Ignore IDE files
.idea/
# Ignore cache files
*.pyc
__pycache__/
The MIT License (MIT)
Copyright © 2021 Wenbin Hou
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the “Software”), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
#!/usr/bin/env python3
import logging
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from seek import * # noqa: F403 (unable to detect undefined names)
from seek.instructions.gfx906 import * # noqa: F403 (unable to detect undefined names)
class SpmvProgram(Program):
    """Naive CSR SpMV kernel generator: y <= alpha*A*x + beta*y.

    One thread handles one matrix row r: it walks the row's nonzeros
    (indices in [rowptr[r], rowptr[r+1])), accumulates sum = A[r,:].x
    in double precision, then writes d_Y[r] = alpha*sum + beta*d_Y[r].
    The block size is fixed at 64 threads (the `<< 6` below).
    """

    def __init__(self):
        super().__init__()

    def get_signature(self) -> str:
        """
        y <= alpha*A x + beta*y
        :return:
        """
        return """
void kfn_asm_naive_n(
    /*in*/ const double alpha, // converted to double
    /*in*/ const double beta, // converted to double
    /*in*/ const uint32_t* __restrict__ d_csr_rowptr, // size: m+1
    /*in*/ const uint32_t* __restrict__ d_csr_colindex, // size: row_ptr[m], aka nnz
    /*in*/ const double* __restrict__ d_csr_values, // size: row_ptr[m], aka nnz
    /*in*/ const double* __restrict__ d_X,
    /*in,out*/ double* __restrict__ d_Y,
    /*in*/ uint32_t m,
    /*in*/ uint32_t n)
"""

    def setup(self) -> None:
        # MAIN: load kernargs, compute this thread's row r and the row's
        # nonzero index range [j_at_begin, j_at_end).
        with self.add_block("MAIN"):
            # Kernarg layout (16 dwords): alpha, beta, 5 pointers, m, n.
            s_args = new(count=16, align=4)
            s_load_dwordx16(s_args, s_kernarg, 0, comment="load all kernargs")
            s_alpha = s_args[0:1].alias()
            s_beta = s_args[2:3].alias()
            s_d_csr_rowptr = s_args[4:5].alias()
            s_d_csr_colindex = s_args[6:7].alias()
            s_d_csr_values = s_args[8:9].alias()
            s_d_X = s_args[10:11].alias()
            s_d_Y = s_args[12:13].alias()
            s_m = s_args[14].alias()
            s_n = s_args[15].alias()
            v_r = new()
            # BlockDimX = 64, hence the shift by 6.
            v_lshl_add_u32(v_r, blockIdx.x, 6, threadIdx.x, comment="r = blockIdx.x * BlockDimX + threadIdx.x")
            # Mask off out-of-range rows before the rowptr load below.
            v_cmpx_lt_u32_e64(vcc_exec, v_r, s_m, comment="(r < m)? sets exec")
            # Get j_at_begin, j_at_end
            v_j_at_begin_end = new(count=2)
            v_offset = new()
            v_lshlrev_b32(v_offset, 2, v_r, comment="offset = r * SIZEOF_U32")
            global_load_dwordx2(v_j_at_begin_end, v_offset, s_d_csr_rowptr,
                                comment="j_at_begin_end = d_csr_rowptr[r:r+1]")
            v_j_at_begin = v_j_at_begin_end[0].alias()
            v_j_at_end = v_j_at_begin_end[1].alias()
            # NOTE(review): this recomputes the same (r < m) mask just to
            # capture it in s_saved_exec for the EXIT block; confirm the DSL
            # does not allow saving exec directly after the first compare.
            s_saved_exec = new(count=2, align=2)
            v_cmpx_lt_u32_e64(s_saved_exec, v_r, s_m, comment="(r < m)?")
            # sum (f64) = 0.0
            v_sum = new(count=2)
            v_mov_b32(v_sum[0], 0)
            v_mov_b32(v_sum[1], 0)
            v_j_at = new()
            v_mov_b32(v_j_at, v_j_at_begin, comment="j_at = j_begin")
        # LOOP: one iteration per nonzero; lanes drop out of exec as their
        # row is exhausted, and the loop exits once every lane is done.
        with self.add_block("LOOP"):
            v_cmpx_lt_u32(vcc_exec, v_j_at, v_j_at_end, comment="(j_at < j_at_end)? sets exec")
            s_cbranch_execz("EXIT", comment="exit loop if all threads are done")
            v_offset = new()
            # Get col, value, x
            v_col = new()
            v_lshlrev_b32(v_offset, 2, v_j_at, comment="offset = j_at * sizeof(uint32_t)")
            global_load_dword(v_col, v_offset, s_d_csr_colindex, comment="col = d_csr_colindex[j_at]")
            v_value = new(count=2)
            v_lshlrev_b32(v_offset, 3, v_j_at, comment="offset = j_at * sizeof(double)")
            global_load_dwordx2(v_value, v_offset, s_d_csr_values, comment="value = d_csr_values[j_at]")
            v_x = new(count=2)
            v_lshlrev_b32(v_offset, 3, v_col, comment="offset = col * sizeof(double)")
            global_load_dwordx2(v_x, v_offset, s_d_X, comment="x = d_X[col]")
            # Accumulate sum and loop again
            v_fma_f64(v_sum, v_value, v_x, v_sum, comment="sum += value * x")
            v_add_u32(v_j_at, 1, v_j_at, comment="j_at += 1")
            s_branch("LOOP")
        # EXIT: restore the saved (r < m) mask, then d_Y[r] = alpha*sum + beta*y.
        with self.add_block("EXIT"):
            s_mov_b64(exec, s_saved_exec, comment="restore exec")
            v_y = new(count=2)
            v_y_offset = new()
            v_lshlrev_b32(v_y_offset, 3, v_r, comment="y_offset = r * sizeof(double)")
            global_load_dwordx2(v_y, v_y_offset, s_d_Y, comment="y = d_Y[r]")
            v_mul_f64(v_sum, s_alpha, v_sum, comment="sum *= alpha")
            v_fma_f64(v_sum, s_beta, v_y, v_sum, comment="sum += beta * y")
            global_store_dwordx2(v_y_offset, v_sum, s_d_Y, comment="d_Y[r] = sum")
            s_endpgm()
if __name__ == "__main__":
    # Build, optimize and assemble the SpMV kernel (logs progress at INFO).
    SpmvProgram().compile(log_level=logging.INFO)
#!/usr/bin/env python3
import logging
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from seek import * # noqa: F403 (unable to detect undefined names)
from seek.instructions.gfx906 import * # noqa: F403 (unable to detect undefined names)
def s_add_u64_u32(sdst64, ssrc32, comment=None):
    """Emit `sdst64 += ssrc32` (64-bit scalar pair plus 32-bit scalar), in place.

    s_add_u32 on the low dword sets the SCC carry, which the following
    s_addc_u32 consumes for the high dword; the two instructions must stay
    adjacent so nothing clobbers SCC in between.
    """
    s_add_u32(sdst64[0], ssrc32, sdst64[0])
    s_addc_u32(sdst64[1], 0, sdst64[1], comment=comment)
def v_div_f32(vdst, v0, v2):
    """Emit `vdst = v0 / v2` (f32) via the v_div_scale/fmas/fixup refinement
    sequence. Expansion mirrors the compiler output reproduced below:

    __global__
    void kfn_div_f32(float* a, float* b, float* c)
    {
    a[threadIdx.x] = b[threadIdx.x] / c[threadIdx.x];
    }
    s_load_dwordx2 s[4:5], s[0:1], 0x10 // load from kernargs: float* c@s[4:5]
    s_load_dwordx4 s[0:3], s[0:1], 0x0 // load from kernargs: float* a@s[0:1], float* b@s[2:3]
    v_lshlrev_b32 v4, 2, v0 // v4 := threadIdx.x * sizeof(float)
    s_waitcnt lgkmcnt(0)
    v_mov_b32 v1, s3
    v_add_co_u32 v0, vcc, s2, v4
    v_addc_co_u32 v1, vcc, 0, v1, vcc // v[0:1] = &b[threadIdx.x]
    v_mov_b32 v3, s5
    v_add_co_u32 v2, vcc, s4, v4
    v_addc_co_u32 v3, vcc, 0, v3, vcc // v[2:3] = &c[threadIdx.x]
    global_load_dword v2, v[2:3], off // v2 := c[threadIdx.x]
    global_load_dword v0, v[0:1], off // v0 := b[threadIdx.x] // we will compute v0/v2 -> v13
    s_waitcnt vmcnt(0)
    v_div_scale_f32 v1, s[2:3], v2, v2, v0 // scale c@v2 -> v1
    v_div_scale_f32 v3, vcc , v0, v2, v0 // scale b@v3 -> v0
    v_rcp_f32 v5, v1 // v5 = 1.0 / v1
    s_setreg_imm32_b32 hwreg(HW_REG_MODE, 4, 2), 3 // amdhsa_float_denorm_mode_32 = 3
    v_fma_f32 v6, -v1, v5, 1.0
    v_fma_f32 v8, v6, v5, v5
    v_mul_f32 v9, v3, v8
    v_fma_f32 v7, -v1, v9, v3
    v_fma_f32 v10, v7, v8, v9
    v_fma_f32 v11, -v1, v10, v3
    s_setreg_imm32_b32 hwreg(HW_REG_MODE, 4, 2), 0 // amdhsa_float_denorm_mode_32 = 0
    v_div_fmas_f32 v12, v11, v8, v10
    v_div_fixup_f32 v13, v12, v2, v0
    v_mov_b32 v3, s1
    v_add_co_u32 v0, vcc, s0, v4
    v_addc_co_u32 v1, vcc, 0, v3, vcc // v[0:1] := &a[threadIdx.x]
    global_store_dword v[0:1], v13, off // a[threadIdx.x] = v13
    s_endpgm
    """
    # Fresh VGPR temporaries for the Newton-Raphson refinement chain.
    v1 = new(GprType.V)
    v3 = new(GprType.V)
    v5 = new(GprType.V)
    v6 = new(GprType.V)
    v8 = new(GprType.V)
    v9 = new(GprType.V)
    v7 = new(GprType.V)
    v10 = new(GprType.V)
    v11 = new(GprType.V)
    v12 = new(GprType.V)
    #s2and3 = new(GprType.S, count=2, align=2)
    # NOTE(review): the reference output writes the first scale's sdst to
    # s[2:3]; here both scales write vcc, so the first vcc is overwritten by
    # the second (the one v_div_fmas_f32 consumes). Confirm the first
    # carry-out is indeed unused.
    v_div_scale_f32(v1, vcc, v2, v2, v0)  # scale c@v2 -> v1
    v_div_scale_f32(v3, vcc, v0, v2, v0)  # scale b@v3 -> v0
    v_rcp_f32(v5, v1)  # v5 = 1.0 / v1
    # NOTE(review): the denorm-mode s_setreg pair from the reference sequence
    # is omitted here — confirm f32 denorm handling is acceptable for this use.
    #s_setreg_imm32_b32 hwreg(HW_REG_MODE, 4, 2), 3 // amdhsa_float_denorm_mode_32 = 3
    v_fma_f32(v6, -v1, v5, 1)  # 1.0
    v_fma_f32(v8, v6, v5, v5)
    v_mul_f32(v9, v3, v8)
    v_fma_f32(v7, -v1, v9, v3)
    v_fma_f32(v10, v7, v8, v9)
    v_fma_f32(v11, -v1, v10, v3)
    #s_setreg_imm32_b32 hwreg(HW_REG_MODE, 4, 2), 0 // amdhsa_float_denorm_mode_32 = 0
    v_div_fmas_f32(v12, v11, v8, v10)
    v_div_fixup_f32(vdst, v12, v2, v0)
# Number of tracer fields (iq index); the per-iq loops below are fully
# unrolled at DSL build time.
NQ = 7
LOG2_NUM_GROUPS = 6  # NumGroups = 64
class Tracer2dStep5and6Program(Program):
    """Fused steps 5 and 6 of a 2-D tracer advection kernel (y direction).

    Grid mapping: blockIdx.x ("group") owns a contiguous slice of j rows
    (je_minus_js_plus_1 rows split over NumGroups = 2**LOG2_NUM_GROUPS
    groups) and blockIdx.y is the vertical level k. All strides
    (delta_common, delta_common_mul_npz, ...) are byte offsets precomputed
    on the host. Per row j and tracer iq the emitted code computes (per
    the instruction comments below):

        fyy_j   = yfx_j * fy2_j
        ra_y    = area + yfx_j - yfx_{j+1}
        q_i     = (q*area + fyy_j - fyy_{j+1}) / ra_y

    with yfx/fyy values carried over from one j iteration to the next.
    """

    def __init__(self):
        super().__init__()

    def get_signature(self) -> str:
        return f"""
__global__
static void tracer_2d_1l_step5and6_nq{NQ}_kfn(
    /*const int is, const int ie,*/
    const int je_minus_js_plus_1,
    const int je_minus_js_plus_2,
    const int je_minus_js_plus_7,
    const int delta_common,
    const int delta_common_mul_npz,
    /*const int isd, const int ied,*/ /*const int jsd, const int jed,*/
    /*const int npx, const int npy,*/ const int npz, /*const int nq,*/
    real* __restrict d_q_i, // [out] real(isd:ied,js:je) for k in (1..npz), iq in (1..nq)
    const real* __restrict d_q, // [in] real(isd:ied, jsd:jed, npz, nq)
    const real* __restrict d_area, // [in] real(isd:ied, jsd:jed)
    /*const real* __restrict d_ra_y,*/ // [in] real(isd:ied,js:je) for k in (1..npz)
    const real* __restrict d_yfx, // [in] real(isd:ied, js:je+1) for k in (1..npz)
    const real* __restrict d_fy2) // [in] real(isd:ied,js:je+1) for k in (1..npz), iq in (1..nq)
"""

    def setup(self) -> None:
        # Create all blocks up front so the loop can branch to them by handle.
        block_main = self.add_block("MAIN")
        block_loop = self.add_block("LOOP")
        block_exit = self.add_block("EXIT")
        # MAIN: load kernargs, compute this group's j range, and prime the
        # j-loop-carried values (yfx_i_j_k, fyy_i_j_k_iq[]).
        with block_main:
            s_args = new(count=16, align=4)
            s_load_dwordx16(s_args, s_kernarg, 0, comment="load all kernargs")
            je_minus_js_plus_1 = s_args[0].alias()
            je_minus_js_plus_2 = s_args[1].alias()
            je_minus_js_plus_7 = s_args[2].alias()
            delta_common = s_args[3].alias()
            delta_common_mul_npz = s_args[4].alias()
            npz = s_args[5].alias()
            d_q_i = s_args[6:7].alias()
            d_q = s_args[8:9].alias()
            d_area = s_args[10:11].alias()
            d_yfx = s_args[12:13].alias()
            d_fy2 = s_args[14:15].alias()
            group = blockIdx.x.alias()
            k = blockIdx.y.alias()
            v_i4 = new()
            v_lshlrev_b32(v_i4, 2, threadIdx.x, comment="v_i4 = threadIdx.x * 4")
            v_offset = new()
            # const int j_start = je_minus_js_plus_1 * group / NumGroups;
            j_start = new(GprType.S)
            s_mul_i32(j_start, je_minus_js_plus_1, group)
            s_lshr_b32(j_start, j_start, LOG2_NUM_GROUPS, comment="j_start = je_minus_js_plus_1 * group / NumGroups")
            # const int j_end = je_minus_js_plus_1 * (group+1) / NumGroups - 1;
            j_end = new(GprType.S)
            s_mul_i32(j_end, je_minus_js_plus_1, group)
            s_add_i32(j_end, je_minus_js_plus_1, j_end)
            s_lshr_b32(j_end, j_end, LOG2_NUM_GROUPS)
            s_sub_i32(j_end, j_end, 1, comment="j_end = je_minus_js_plus_1 * (group+1) / NumGroups - 1")
            # const int s_tmp3 = delta_common*(j_start + k*je_minus_js_plus_2);
            # NOTE(review): other scalar temporaries use new(GprType.S) —
            # confirm bare new() resolves to a scalar register here.
            s_tmp3 = new()
            s_mul_i32(s_tmp3, k, je_minus_js_plus_2)
            s_add_i32(s_tmp3, j_start, s_tmp3)
            s_mul_i32(s_tmp3, delta_common, s_tmp3, comment="s_tmp3 = delta_common*(j_start + k*je_minus_js_plus_2)")
            # const char* curr_yfx = (const char*)d_yfx + s_tmp3;
            curr_yfx = d_yfx.alias()
            s_add_u64_u32(curr_yfx, s_tmp3)
            # real yfx_i_j_k = *(real*)(curr_yfx + v_i4);
            yfx_i_j_k = new(GprType.V)
            global_load_dword(yfx_i_j_k, v_i4, curr_yfx, comment="yfx_i_j_k = *(real*)(curr_yfx + v_i4)")
            # const char* curr_fy2 = (const char*)d_fy2 + s_tmp3;
            curr_fy2 = d_fy2.alias()
            s_add_u64_u32(curr_fy2, s_tmp3)
            # const int deltaiq_fy2 = delta_common_mul_npz*je_minus_js_plus_2;
            deltaiq_fy2 = new(GprType.S)
            s_mul_i32(deltaiq_fy2, delta_common_mul_npz, je_minus_js_plus_2)
            # real fyy_i_j_k_iq[NQ];
            fyy_i_j_k_iq = new[NQ](GprType.V)
            # v_offset = v_i4;
            v_mov_b32(v_offset, v_i4)
            # for (int iq = 0; iq < NQ; ++iq) {
            #     fyy_i_j_k_iq[iq] = *(real*)(curr_fy2 + v_offset);
            #     v_offset += deltaiq_fy2;
            # }
            for iq in range(0, NQ):
                global_load_dword(fyy_i_j_k_iq[iq], v_offset, curr_fy2, comment=f"load fyy_i_j_k_iq[{iq}]")
                if iq < NQ - 1:
                    v_add_u32(v_offset, deltaiq_fy2, v_offset, comment="v_offset += deltaiq_fy2")
            # const int s_tmp5 = delta_common*(j_start+3);
            s_tmp5 = new()
            s_add_i32(s_tmp5, j_start, 3)
            s_mul_i32(s_tmp5, delta_common, s_tmp5, comment="s_tmp5 = delta_common*(j_start+3)")
            # const char* curr_area = (const char*)d_area + s_tmp5;
            curr_area = d_area.alias()
            s_add_u64_u32(curr_area, s_tmp5)
            # const int s_tmp6 = delta_common*(j_start + k*je_minus_js_plus_1);
            s_tmp6 = new()
            s_mul_i32(s_tmp6, k, je_minus_js_plus_1)
            s_add_i32(s_tmp6, j_start, s_tmp6)
            s_mul_i32(s_tmp6, delta_common, s_tm6 if False else s_tmp6, comment="s_tmp6 = delta_common*(j_start + k*je_minus_js_plus_1)")
            # char* curr_q_i = (char*)d_q_i + s_tmp6;
            curr_q_i = d_q_i.alias()
            s_add_u64_u32(curr_q_i, s_tmp6)
            # const int deltaiq_q_i = delta_common_mul_npz*je_minus_js_plus_1;
            deltaiq_q_i = new(GprType.S)
            s_mul_i32(deltaiq_q_i, delta_common_mul_npz, je_minus_js_plus_1)
            # const int s_tmp8 = s_tmp5 + delta_common*(k*je_minus_js_plus_7);
            s_tmp8 = new()
            s_mul_i32(s_tmp8, k, je_minus_js_plus_7)
            s_mul_i32(s_tmp8, delta_common, s_tmp8)
            s_add_i32(s_tmp8, s_tmp5, s_tmp8, comment="s_tmp8 = s_tmp5 + delta_common*(k*je_minus_js_plus_7)")
            # const char* curr_q = (const char*)d_q + s_tmp8;
            curr_q = d_q.alias()
            s_add_u64_u32(curr_q, s_tmp8)
            # const int deltaiq_q = delta_common_mul_npz*je_minus_js_plus_7;
            deltaiq_q = new(GprType.S)
            s_mul_i32(deltaiq_q, delta_common_mul_npz, je_minus_js_plus_7)
            # for (int iq = 0; iq < NQ; ++iq) {
            #     fyy_i_j_k_iq[iq] = yfx_i_j_k * fyy_i_j_k_iq[iq];
            # }
            for iq in range(0, NQ):
                v_mul_f32(fyy_i_j_k_iq[iq], yfx_i_j_k, fyy_i_j_k_iq[iq])
            j = j_start.alias()
        # LOOP: one iteration per j row in [j_start, j_end].
        with block_loop:
            s_cmp_le_i32(j, j_end, comment="(j <= j_end) ?")
            s_cbranch_scc0(block_exit)
            # const real area_i_j = *(real*)(curr_area + v_i4);
            area_i_j = new(GprType.V)
            global_load_dword(area_i_j, v_i4, curr_area)
            # curr_area += delta_common
            s_add_u64_u32(curr_area, delta_common)
            # curr_yfx += delta_common;
            s_add_u64_u32(curr_yfx, delta_common)
            # const real yfx_i_jplus1_k = *(real*)(curr_yfx + v_i4);
            yfx_i_jplus1_k = new(GprType.V)
            global_load_dword(yfx_i_jplus1_k, v_i4, curr_yfx)
            # real q[NQ];
            q = new[NQ](GprType.V)
            # v_offset = v_i4;
            v_mov_b32(v_offset, v_i4)
            # real fyy_i_jplus1_k_iq[NQ];
            fyy_i_jplus1_k_iq = new[NQ](GprType.V)
            # curr_fy2 += delta_common;
            s_add_u64_u32(curr_fy2, delta_common)
            # int v_offset2 = v_i4;
            v_offset2 = new()
            v_mov_b32(v_offset2, v_i4)
            # for (int iq = 0; iq < NQ; ++iq) {
            #     q[iq] = *(real*)(curr_q + v_offset);
            #     fyy_i_jplus1_k_iq[iq] = *(real*)(curr_fy2 + v_offset2);
            #     v_offset += deltaiq_q;
            #     v_offset2 += deltaiq_fy2;
            # }
            for iq in range(0, NQ):
                global_load_dword(q[iq], v_offset, curr_q)
                global_load_dword(fyy_i_jplus1_k_iq[iq], v_offset2, curr_fy2)
                if iq < NQ - 1:
                    v_add_u32(v_offset, deltaiq_q, v_offset)
                    v_add_u32(v_offset2, deltaiq_fy2, v_offset2)
            # curr_q += delta_common;
            s_add_u64_u32(curr_q, delta_common)
            # const real ra_y_i_j_k = area_i_j + yfx_i_j_k - yfx_i_jplus1_k;
            ra_y_i_j_k = new(GprType.V)
            v_add_f32(ra_y_i_j_k, area_i_j, yfx_i_j_k)
            v_sub_f32(ra_y_i_j_k, ra_y_i_j_k, yfx_i_jplus1_k,
                      comment="ra_y_i_j_k = area_i_j + yfx_i_j_k - yfx_i_jplus1_k")
            # v_offset = v_i4;
            v_mov_b32(v_offset, v_i4)
            # for (int iq = 0; iq < NQ; ++iq) {
            #     real tmp = q[iq]*area_i_j;
            #     fyy_i_jplus1_k_iq[iq] = yfx_i_jplus1_k * fyy_i_jplus1_k_iq[iq];
            #     *(real*)(curr_q_i + v_offset) = (tmp + fyy_i_j_k_iq[iq] - fyy_i_jplus1_k_iq[iq]) / ra_y_i_j_k;
            #     v_offset += deltaiq_q_i;
            # }
            tmp = new(GprType.V)
            tmp_dst = new(GprType.V)
            for iq in range(0, NQ):
                v_mul_f32(tmp, q[iq], area_i_j, comment="tmp = q[iq]*area_i_j")
                v_add_f32(tmp, tmp, fyy_i_j_k_iq[iq], comment="tmp = q[iq]*area_i+fyy_i_j_k_iq[iq]")
                v_mul_f32(fyy_i_jplus1_k_iq[iq], yfx_i_jplus1_k, fyy_i_jplus1_k_iq[iq],
                          comment="fyy_i_jplus1_k_iq[iq] = yfx_i_jplus1_k * fyy_i_jplus1_k_iq[iq]")
                v_sub_f32(tmp, tmp, fyy_i_jplus1_k_iq[iq],
                          comment="tmp = q[iq]*area_i_j + fyy_i_j_k_iq[iq] - fyy_i_jplus1_k_iq[iq]")
                # tmp_dst := tmp / ra_y_i_j_k  (software f32 divide, see v_div_f32)
                v_div_f32(tmp_dst, tmp, ra_y_i_j_k)
                if iq < NQ - 1:
                    global_store_dword(v_offset, tmp_dst, curr_q_i)
                    v_add_u32(v_offset, deltaiq_q_i, v_offset)
                else:  # iq == NQ-1
                    # The last store is tagged so the next iteration can wait
                    # on it before overwriting tmp_dst.
                    global_store_dword(v_offset, tmp_dst, curr_q_i, mem_token_object="last_write_token")
            # curr_q_i += delta_common;
            s_add_u64_u32(curr_q_i, delta_common)
            # Carry yfx/fyy of row j+1 into the next iteration as row j.
            # yfx_i_j_k = yfx_i_jplus1_k;
            v_mov_b32(yfx_i_j_k, yfx_i_jplus1_k)
            # for (int iq = 0; iq < NQ; ++iq)
            #     fyy_i_j_k_iq[iq] = fyy_i_jplus1_k_iq[iq];
            for iq in range(0, NQ):
                v_mov_b32(fyy_i_j_k_iq[iq], fyy_i_jplus1_k_iq[iq])
            s_add_i32(j, j, 1, comment="++j")
            explicit_wait("last_write_token")
            s_branch(block_loop)
        with block_exit:
            # Nothing more to do
            s_endpgm()
if __name__ == "__main__":
    # Build, optimize and assemble the tracer step-5/6 kernel.
    Tracer2dStep5and6Program().compile(log_level=logging.INFO)
#!/usr/bin/env python3
import logging
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from seek import * # noqa: F403 (unable to detect undefined names)
from seek.instructions.gfx906 import * # noqa: F403 (unable to detect undefined names)
class VectorAddProgram(Program):
    """
    Elementwise integer add, one element per thread:

    void vector_add(const int* a, const int* b, int* sum)
    {
    sum[threadIdx.x] = a[threadIdx.x] + b[threadIdx.x]
    }
    """

    def __init__(self):
        super().__init__()

    def get_signature(self) -> str:
        return """
__global__
void vector_add(const int* a, const int* b, int* sum)
"""

    def setup(self):
        with self.add_block("MAIN") as block_main:
            # Kernargs: three 64-bit pointers = 6 dwords (a, b, sum).
            s_args = new(count=6, align=4)
            # dwordx4 covers a and b; the trailing dwordx2 fetches sum.
            s_load_dwordx4(s_args[0:3], s_kernarg, 0)
            s_load_dwordx2(s_args[4:5], s_kernarg, 16)
            s_a_ptr = s_args[0:1].alias()
            s_b_ptr = s_args[2:3].alias()
            s_sum_ptr = s_args[4:5].alias()
            v_offset = new()
            v_lshlrev_b32(v_offset, 2, threadIdx.x, comment="offset = threadIdx.x * sizeof(int)")
            v_a = new()
            v_b = new()
            v_sum = new()
            global_load_dword(v_a, v_offset, s_a_ptr, comment="a = a_ptr[threadIdx.x]")
            global_load_dword(v_b, v_offset, s_b_ptr, comment="b = b_ptr[threadIdx.x]")
            v_add_i32(v_sum, v_a, v_b, comment="sum = a + b")
            global_store_dword(v_offset, v_sum, s_sum_ptr, comment="sum_ptr[threadIdx.x] = sum")
            s_endpgm()
if __name__ == "__main__":
    # NOTE(review): explicitly targets HSA code-object v3 ABI — confirm this
    # matches the ROCm runtime version in use.
    VectorAddProgram().compile(log_level=logging.INFO, code_object_version=3)
# We disable these warnings from flake8:
# F401: 'xxx' imported but unused
# F403: 'from xxx import *' used; unable to detect undefined names
from .basic import * # noqa: F401, F403
from .backend import * # noqa: F401, F403
# We disable these warnings from flake8:
# F401: 'xxx' imported but unused
# F403: 'from xxx import *' used; unable to detect undefined names
from .base_pass import BasePass # noqa: F401
from .divide_basic_block_pass import DivideBasicBlockPass # noqa: F401
from .optimize_basic_block_pass import OptimizeBasicBlockPass # noqa: F401
from .print_basic_block_pass import PrintBasicBlockPass # noqa: F401
from .analyze_live_var_pass import AnalyzeLiveVarPass # noqa: F401
from .eliminate_dead_code_pass import EliminateDeadCodePass # noqa: F401
from .annotate_clause_pass import AnnotateClausePass # noqa: F401
from .insert_waitcnt_pass import InsertWaitcntPass # noqa: F401
from .compute_register_interference_pass import ComputeRegisterInterferencePass # noqa: F401
from .allocate_register_rig_pass import AllocateRegisterRIGPass # noqa: F401
from .program import Program, Optional # noqa: F401
from typing import Dict, Set, Tuple, List, DefaultDict, Optional
from collections import defaultdict
from ..basic.register import Gpr
from .base_pass import BasePass, PassTag
from .divide_basic_block_pass import BasicBlock
from .compute_register_interference_pass import OneRegisterInterference
class AllocateRegisterRIGPass(BasePass):
    """Register allocation by greedy coloring of the interference graph (RIG).

    Requires ComputeRegisterInterferencePass to have produced a perfect
    elimination ordering, i.e. the interference graph must be chordal;
    greedy coloring in reverse PEO order is then optimal.

    Outputs (on program.optimizer_state):
      - register_allocation_{v,s}gpr_by_color: Gprs grouped by color
      - register_allocation_{v,s}gpr_count: number of hardware registers used
    Side effect: fills program.assigned_index for every non-forced Gpr.
    """

    def __init__(self, /, priority: int = PassTag.AllocateRegisterRIG.value):
        super().__init__(priority)

    def required_tags(self) -> Set[PassTag]:
        return {PassTag.ComputeRegisterInterference}

    def generated_tags(self) -> Set[PassTag]:
        return {PassTag.AllocateRegisterRIG}

    def invalidated_tags(self) -> Set[PassTag]:
        return set()

    def reset(self, program):
        # Clear every output of this pass from the optimizer state.
        state = program.optimizer_state
        if state is None:
            return
        # For AllocateRegisterRIGPass
        state.register_allocation_vgpr_by_color = None  # type: List[List[Gpr]]
        state.register_allocation_sgpr_by_color = None  # type: List[List[Gpr]]
        state.register_allocation_vgpr_count = None  # type: int
        state.register_allocation_sgpr_count = None  # type: int

    @staticmethod
    def __allocate_by_perfect_elimination_ordering(
            program,
            interference: Dict[Gpr, OneRegisterInterference],
            perfect_elimination_ordering: List[Gpr]) -> List[List[Gpr]]:
        """Greedy-color base Gprs in reverse PEO, then map colors to indexes.

        Returns the color classes (list of Gpr lists). As a side effect,
        records a hardware index for every Gpr in program.assigned_index
        (or validates it against program.forced_index).

        See also: https://blog.csdn.net/corsica6/article/details/88979383
        """
        # Pass 1: assign each base Gpr the smallest color not used by any
        # already-colored neighbor (reverse PEO makes this optimal).
        base_gpr_colors = dict()  # type: Dict[Gpr, int]
        for idx in range(len(perfect_elimination_ordering)-1, -1, -1):
            base_gpr = perfect_elimination_ordering[idx]
            if base_gpr not in base_gpr_colors:
                neighbors = set(interference[base_gpr].conflicts.keys())  # type: Set[Gpr]
                neighbor_colors = set(base_gpr_colors[x] for x in neighbors if x in base_gpr_colors)  # type: Set[int]
                color = 0
                while color in neighbor_colors:
                    color += 1
                base_gpr_colors[base_gpr] = color
        # Pass 2: group Gprs by color; colors are dense (0..max), so the loop
        # stops at the first empty class. Sorting by repr keeps output stable.
        by_colors = []  # type: List[List[Gpr]]
        while True:
            curr_color_gprs = set(gpr for (gpr, c) in base_gpr_colors.items() if c == len(by_colors))
            if not curr_color_gprs:
                break
            by_colors.append(list(sorted(curr_color_gprs, key=repr)))
        # All hardware register indexes handed out so far (one int per reg).
        alloc_index_set = set()  # type: Set[int]

        def set_gpr_list_index(gpr_list: List[Gpr], index: int, count: int):
            # Assign the index range [index, index+count) to one color class.
            # Mark these indexes used in alloc_index_set
            assert not alloc_index_set.intersection(set(range(index, index+count)))
            alloc_index_set.update(range(index, index+count))
            for gpr in gpr_list:
                assert not gpr.is_view
                assert count >= gpr.count
                assert index % gpr.align.divisor == gpr.align.remainder, gpr
                if gpr in program.forced_index:
                    assert program.forced_index[gpr] == index
                else:
                    assert gpr not in program.assigned_index
                    program.assigned_index[gpr] = index

        # First, allocate all pre-indexed Gprs
        for gpr_list in by_colors:
            # Compute the max count
            count = max(gpr.count for gpr in gpr_list)
            # Deal with pre-indexed Gpr, if exists
            pre_indexed_gpr_list = [gpr for gpr in gpr_list if gpr in program.forced_index]
            assert len(pre_indexed_gpr_list) <= 1
            if len(pre_indexed_gpr_list) == 1:
                index = program.forced_index[pre_indexed_gpr_list[0]]
                # Assign index to Gprs in gpr_list
                set_gpr_list_index(gpr_list, index, count)
        # Then, allocate other Gprs
        for gpr_list in by_colors:
            # Skip pre-indexed Gpr
            if [gpr for gpr in gpr_list if gpr in program.forced_index]:
                continue
            # Compute the max count, align
            count = max(gpr.count for gpr in gpr_list)
            align = max(gpr.align.divisor for gpr in gpr_list)  # TODO: deal with remainder!
            assert align in (1, 2, 4, 8, 16, 32, 64, 128)
            # Linear scan, in steps of `align`, for the first free gap that
            # can hold `count` consecutive registers.
            index = 0
            while alloc_index_set.intersection(set(range(index, index+count))):
                index += align
            # Assign index to Gprs in gpr_list
            set_gpr_list_index(gpr_list, index, count)
        return by_colors

    def run(self, program) -> bool:
        """Allocate VGPRs then SGPRs; record max used register counts."""
        # Check that basic blocks have been divided
        assert program.optimizer_state.divide_basic_block is not None
        # Check that register interference graph and perfect elimination ordering has been computed
        assert program.optimizer_state.register_interference_perfect_elimination_ordering_vgpr is not None, \
            "You haven't run RegisterInterferencePass or VGpr interference graph is not a chordal graph"
        assert program.optimizer_state.register_interference_perfect_elimination_ordering_sgpr is not None, \
            "You haven't run RegisterInterferencePass or SGpr interference graph is not a chordal graph"
        program.optimizer_state.register_allocation_vgpr_by_color = \
            AllocateRegisterRIGPass.__allocate_by_perfect_elimination_ordering(
                program,
                program.optimizer_state.register_interference_vgpr,
                program.optimizer_state.register_interference_perfect_elimination_ordering_vgpr)
        program.optimizer_state.register_allocation_sgpr_by_color = \
            AllocateRegisterRIGPass.__allocate_by_perfect_elimination_ordering(
                program,
                program.optimizer_state.register_interference_sgpr,
                program.optimizer_state.register_interference_perfect_elimination_ordering_sgpr)

        def get_gpr_index(gpr):
            # Forced (ABI-pinned) indexes take precedence over assigned ones.
            if gpr in program.forced_index:
                return program.forced_index[gpr]
            else:
                assert gpr in program.assigned_index
                return program.assigned_index[gpr]

        # Compute used VGPR count
        program.optimizer_state.register_allocation_vgpr_count = 0
        for gpr_list in program.optimizer_state.register_allocation_vgpr_by_color:
            program.optimizer_state.register_allocation_vgpr_count = max(
                program.optimizer_state.register_allocation_vgpr_count,
                max((get_gpr_index(gpr)+gpr.count) for gpr in gpr_list))
        # Compute used SGPR count
        program.optimizer_state.register_allocation_sgpr_count = 0
        for gpr_list in program.optimizer_state.register_allocation_sgpr_by_color:
            program.optimizer_state.register_allocation_sgpr_count = max(
                program.optimizer_state.register_allocation_sgpr_count,
                max((get_gpr_index(gpr)+gpr.count) for gpr in gpr_list))
        return True
from typing import Set, List, Dict, Union
from ..basic.register import Gpr, GprSet
from ..basic.instr import InstrCall, ExplicitWaitCall, ExplicitUsesCall
from .base_pass import BasePass, PassTag, OptimizerState
from .divide_basic_block_pass import BasicBlock, BasicBlockJumpInstr
from .annotate_clause_pass import AnnClause
import itertools
class AnalyzeLiveVarPass(BasePass):
    """
    Computes def/use for each basic-block, then computes in/out for each basic-block.

    This is the classic backward live-variable dataflow analysis, followed by
    a per-basic-block, per-register life-span bitmap used by the register
    interference computation downstream.
    """

    def __init__(self, /, priority: int = PassTag.AnalyzeLiveVar.value):
        super().__init__(priority)

    def required_tags(self) -> Set[PassTag]:
        return {PassTag.OptimizeBasicBlock}

    def generated_tags(self) -> Set[PassTag]:
        return {PassTag.AnalyzeLiveVar}

    def invalidated_tags(self) -> Set[PassTag]:
        return {PassTag.EliminateDeadCode}

    def reset(self, program):
        """Clear the live-variable results from every basic block."""
        optimizer_state = program.optimizer_state  # type: OptimizerState
        state = optimizer_state.divide_basic_block
        if state is None:
            return
        for bb in state.basic_blocks:
            assert isinstance(bb, BasicBlock)
            bb.live_var_defs = None
            bb.live_var_uses = None
            bb.live_var_in = None
            bb.live_var_out = None

    def run(self, program) -> bool:
        """Run the full analysis; returns True (results always refreshed)."""
        optimizer_state = program.optimizer_state  # type: OptimizerState
        state = optimizer_state.divide_basic_block
        assert state is not None
        # Compute defs and uses for each individual basic-block
        for bb in state.basic_blocks:
            self.__compute_basic_block_var_defs_uses(bb)
        # Compute in and out for all basic-blocks
        self.__compute_all_basic_blocks_var_in_out(state.basic_blocks)
        # Compute life span for each single Gpr
        # NOTE: different views of the same Gpr may have different life span
        for bb in state.basic_blocks:
            self.__compute_basic_block_var_life_span(bb)
        # This pass always updates live_var_{uses,defs,in,out} for all basic-blocks
        return True

    # noinspection PyMethodMayBeStatic
    def __compute_basic_block_var_defs_uses(self, bb: BasicBlock):
        """
        Within a basic-block:
        - If a Gpr is used before defined, it's added to live_var_uses
        - If a Gpr is used after defined, silently go on (It's **not** added to live_var_uses)
        - If a Gpr defined, it's added to live_var_defs, no matter whether it's used or not later
        """
        bb.live_var_defs = GprSet()
        bb.live_var_uses = GprSet()
        # Loop from the beginning to the end
        # Don't forget the last `jump_instr`, if exists
        for clause in itertools.chain(bb.clauses, [bb.jump_instr] if bb.jump_instr is not None else []):
            if isinstance(clause, ExplicitWaitCall):
                # Pure wait: touches no registers.
                continue
            if isinstance(clause, ExplicitUsesCall):
                tmp_uses = GprSet(*clause.uses)
                tmp_defs = GprSet()
            else:
                assert isinstance(clause, InstrCall) or isinstance(clause, BasicBlockJumpInstr)
                tmp_uses = clause.gpr_uses_to_gprset()
                tmp_defs = clause.gpr_defs_to_gprset()
            tmp_uses.difference_update(bb.live_var_defs)  # this is part of gpr which is not in `defs` before
            bb.live_var_uses.union_update(tmp_uses)  # add this part to `uses`
            # No matter these Gprs are previously used or not, we add them into defs
            # tmp_defs.difference_update(bb.live_var_uses)  # this is part of gpr which is not in `uses` before
            bb.live_var_defs.union_update(tmp_defs)  # add this part to `defs`
        # print(f"{bb.name} defs {bb.live_var_defs}")
        # print(f"{bb.name} uses {bb.live_var_uses}")

    # noinspection PyMethodMayBeStatic
    def __compute_all_basic_blocks_var_in_out(self, all_basic_block_list: List[BasicBlock]):
        """
        Iterate the dataflow equations to a fixpoint (worklist algorithm):

        B.in = B.uses UNION (B.out DIFF B.defs)
        B.out = {UNION S.in}  // S is a successor of B
        IN[Exit] = {}  // This doesn't matter
        """
        # Local import keeps this file's module-level imports untouched.
        from collections import deque
        # Worklist (FIFO). deque.popleft() is O(1); the previous list.pop(0)
        # was O(n) per dequeue, making the fixpoint loop quadratic in the
        # number of pending blocks.
        update_queue = deque()  # worklist of BasicBlock
        for bb in all_basic_block_list:
            bb.live_var_in = GprSet()
            bb.live_var_out = GprSet()
            update_queue.append(bb)
        while update_queue:
            bb = update_queue.popleft()
            # bb.live_var_out = bb.successor_if_jump.live_var_in UNION
            #                   bb.successor_if_fallthrough.live_var_in
            bb.live_var_out = GprSet()
            if bb.successor_if_jump is not None:  # `bb.successor_if_jump` might be `bb`
                bb.live_var_out.union_update(bb.successor_if_jump.live_var_in)
            if bb.successor_if_fallthrough is not None:
                bb.live_var_out.union_update(bb.successor_if_fallthrough.live_var_in)
            # `bb.live_var_out` was updated, now update `bb.live_var_in` accordingly
            # bb.live_var_in = bb.live_var_uses UNION (bb.live_var_out DIFF bb.live_var_defs)
            old_live_var_in = bb.live_var_in
            bb.live_var_in = bb.live_var_out.clone() \
                .difference_update(bb.live_var_defs) \
                .union_update(bb.live_var_uses)
            # If `bb.live_var_in` is indeed updated, update `bb`'s predecessors
            if bb.live_var_in != old_live_var_in:
                for pred in bb.predecessors:
                    update_queue.append(pred)

    # noinspection PyMethodMayBeStatic
    def __compute_basic_block_var_life_span(self, bb: BasicBlock):
        """Build a per-(base_gpr, offset) life-span bitmap for one block."""
        # We represent life span of each (base_gpr,offset) by a non-negative integer
        # Here, each (base_gpr,offset) tuple has count == 1
        #
        # If a basic block has N stateful instructions, we use 2*N+2 bits to represent the life span:
        #   bit 0: live_var_in (aka defined by prior basic blocks)
        #   bit 1: used by the 1st instruction
        #   bit 2: defined by the 1st instruction
        #   bit 3: used by the 2nd instruction
        #   bit 4: defined by the 2nd instruction
        #   ......
        #   bit 2*N-1: used by the N-th instruction
        #   bit 2*N: defined by the N-th instruction
        #   bit 2*N+1: live_var_out (aka used by following basic blocks)
        #
        # For live_var_gpr_life_span:
        #   - the key is base_gpr (not Gpr views!)
        #   - the value is a list of bitmap regarding all offsets (len == base_gpr.count)
        #
        live_var_gpr_life_span = dict()  # type: Dict[Gpr, List[int]]
        # var_last_def_bit_at is the last bit where a (base_gpr,offset) is defined
        # The values must be in {0,2,4,...,2*N}
        var_last_def_bit_at = dict()  # type: Dict[(Gpr, int), int]

        def mark_active(base_gpr: Gpr, offset: int, bit_at: int):
            # Set one bit of the life-span bitmap for (base_gpr, offset).
            assert bit_at >= 0
            assert not base_gpr.is_view
            assert 0 <= offset < base_gpr.count
            if base_gpr not in live_var_gpr_life_span:
                live_var_gpr_life_span[base_gpr] = [0] * base_gpr.count
            live_var_gpr_life_span[base_gpr][offset] |= 1 << bit_at

        def process_used(base_gpr: Gpr, offset: int, bit_at: int):
            # A use at an odd bit; extends liveness back to the latest def.
            assert bit_at >= 0 and bit_at % 2 == 1
            assert not base_gpr.is_view
            assert 0 <= offset < base_gpr.count
            # Get the latest define bit_at of this (base_gpr, offset), which must exist
            # Every bit in interval [last_def_bit_at,bit_at] is marked active
            assert (base_gpr, offset) in var_last_def_bit_at
            last_def_bit_at = var_last_def_bit_at[(base_gpr, offset)]
            assert last_def_bit_at % 2 == 0
            for b in range(last_def_bit_at, bit_at+1):
                mark_active(base_gpr, offset, b)

        def process_defined(base_gpr: Gpr, offset: int, bit_at: int):
            # A def at an even bit; starts a (potential) new live range.
            assert bit_at >= 0 and bit_at % 2 == 0
            assert not base_gpr.is_view
            assert 0 <= offset < base_gpr.count
            # We mark this (base_gpr, offset) active at this point, no matter whether it will be used later
            mark_active(base_gpr, offset, bit_at)
            # We update the latest bit_at of this (base_gpr, offset), regardless of its previous value
            var_last_def_bit_at[(base_gpr, offset)] = bit_at

        clauses_and_jump_instr = bb.clauses.copy()  # type: List[Union[InstrCall, ExplicitWaitCall, ExplicitUsesCall, BasicBlockJumpInstr]]  # shallow copy
        if bb.jump_instr is not None:
            clauses_and_jump_instr.append(bb.jump_instr)
        # Process bit 0: live_var_in (aka defined by prior basic blocks)
        for base_gpr in bb.live_var_in.base_gprs:
            offset_list = bb.live_var_in.get_offset_list(base_gpr)
            assert offset_list
            for offset in offset_list:
                process_defined(base_gpr, offset=offset, bit_at=0)
        # Process all instructions (bit 1,2,3,4,...,2*N-1,2*N)
        for idx, clause in enumerate(clauses_and_jump_instr):  # idx in {0,1,...,N-1}
            gprset_uses = clause.gpr_uses_to_gprset()
            for base_gpr in gprset_uses.base_gprs:
                for offset in gprset_uses.get_offset_list(base_gpr):
                    process_used(base_gpr, offset=offset, bit_at=2*idx+1)
            gprset_defs = clause.gpr_defs_to_gprset()
            for base_gpr in gprset_defs.base_gprs:
                for offset in gprset_defs.get_offset_list(base_gpr):
                    process_defined(base_gpr, offset=offset, bit_at=2*idx+2)
        # Process bit 2*N+1: live_var_out (aka used by following basic blocks)
        for base_gpr in bb.live_var_out.base_gprs:
            offset_list = bb.live_var_out.get_offset_list(base_gpr)
            assert offset_list
            for offset in offset_list:
                process_used(base_gpr, offset=offset, bit_at=2*len(clauses_and_jump_instr)+1)
        # Assign the result to basic block
        bb.live_var_gpr_life_span = live_var_gpr_life_span
from typing import Set, Optional, Union
from ..basic.const import _INSTR_STR_WIDTH
from ..basic.instr import InstrCall, ExplicitWaitCall, ExplicitUsesCall, Waitcnt
from .base_pass import BasePass, PassTag, OptimizerState
from .divide_basic_block_pass import BasicBlock, DivideBasicBlockPassState
class AnnClause:
    """
    A clause wrapped together with per-clause optimizer annotations.
    Currently the only annotation is the s_waitcnt automatically inserted by
    InsertWaitcntPass right before the clause.
    """
    def __init__(self, clause: Union[InstrCall, ExplicitWaitCall, ExplicitUsesCall]):
        self.clause = clause  # type: Union[InstrCall, ExplicitWaitCall, ExplicitUsesCall]
        # Filled by InsertWaitcntPass; None means no automatic wait is needed here
        self.insert_waitcnt = None  # type: Optional[Waitcnt]

    def __repr__(self):
        if self.insert_waitcnt is None:
            return repr(self.clause)
        return f"/*auto waitcnt*/ {self.insert_waitcnt}\n{repr(self.clause)}"

    def generate(self, program, wr):
        """Emit the automatic s_waitcnt (if any), then the wrapped clause itself."""
        if self.insert_waitcnt is not None:
            wr("/*auto*/ s_waitcnt ".ljust(_INSTR_STR_WIDTH) + self.insert_waitcnt.waitcnt_str())
        self.clause.generate(program, wr)
class AnnotateClausePass(BasePass):
    """
    Create annotated clause (AnnClause) from the original clause blindly.
    NOTE:
        Anyway, the priority should make it run after EliminateDeadCodePass.
    """
    def __init__(self, /, priority: int = PassTag.AnnotateClause.value):
        super().__init__(priority)

    def required_tags(self) -> Set[PassTag]:
        return {PassTag.DivideBasicBlock}

    def generated_tags(self) -> Set[PassTag]:
        return {PassTag.AnnotateClause}

    def invalidated_tags(self) -> Set[PassTag]:
        return set()

    def reset(self, program):
        # Drop annotations created by a previous run (only if basic blocks still exist)
        optimizer_state = program.optimizer_state  # type: OptimizerState
        state = optimizer_state.divide_basic_block
        if state is not None:
            for bb in state.basic_blocks:
                assert isinstance(bb, BasicBlock)
                bb.annotate_clauses = None

    def run(self, program) -> bool:
        optimizer_state = program.optimizer_state  # type: OptimizerState
        state = optimizer_state.divide_basic_block  # type: DivideBasicBlockPassState
        assert state is not None
        for bb in state.basic_blocks:
            assert isinstance(bb, BasicBlock)
            # Wrap every clause as-is; annotations are filled by later passes
            bb.annotate_clauses = [AnnClause(clause) for clause in bb.clauses]
        return True
from typing import Set
from abc import ABC
import enum
# Pass ordering tags.  Each tag's integer value doubles as the default priority
# of the corresponding pass (smaller value = the pass runs earlier).
PassTag = enum.IntEnum("PassTag", [
    ("DivideBasicBlock", 0),
    ("OptimizeBasicBlock", 100),
    ("AnalyzeLiveVar", 110),
    ("EliminateDeadCode", 120),
    ("AnnotateClause", 200),
    ("InsertWaitcnt", 300),
    ("ComputeRegisterInterference", 310),
    ("AllocateRegisterRIG", 400),
    ("PrintBasicBlock", 1000),
])
class OptimizerState:
    """
    Blackboard object shared between optimizer passes.
    Every attribute starts as None and is filled by the pass named in the comment.
    """
    def __init__(self):
        # Filled by DivideBasicBlockPass
        self.divide_basic_block = None
        # Filled by ComputeRegisterInterferencePass
        self.register_interference_vgpr = None
        self.register_interference_sgpr = None
        self.register_interference_perfect_elimination_ordering_vgpr = None
        self.register_interference_perfect_elimination_ordering_sgpr = None
        # Filled by AllocateRegisterRIGPass
        self.register_allocation_vgpr_by_color = None  # becomes List[List[Gpr]]
        self.register_allocation_sgpr_by_color = None  # becomes List[List[Gpr]]
        self.register_allocation_vgpr_count = None  # becomes int
        self.register_allocation_sgpr_count = None  # becomes int
class BasePass(ABC):
    """
    Common base class of all optimizer passes.

    A concrete pass declares which tags it requires, generates and invalidates,
    and implements reset()/run().  Passes are scheduled by ascending priority.
    """

    def __init__(self, /, priority: int):
        """
        Specify smaller integer for higher priority (aka. the passes runs earlier).
        By default, priority is the (integer) value of pass's tag.
        """
        if type(self).__name__ == 'BasePass':
            raise NotImplementedError("Can't instantiate BasePass")
        assert isinstance(priority, int)
        self.__priority = priority  # type: int

    @property
    def priority(self) -> int:
        """Scheduling priority (read-only); smaller runs earlier."""
        return self.__priority

    def required_tags(self) -> Set[PassTag]:
        """Tags that must already be established before this pass may run."""
        raise NotImplementedError("Override required_tags() in derived class")

    def generated_tags(self) -> Set[PassTag]:
        """Tags established by a successful run of this pass."""
        raise NotImplementedError("Override generated_tags() in derived class")

    def invalidated_tags(self) -> Set[PassTag]:
        """Tags destroyed by this pass (their producers must run again)."""
        raise NotImplementedError("Override invalidated_tags() in derived class")

    def reset(self, program):
        """Clear any state this pass previously stored on `program`."""
        raise NotImplementedError("Override reset() in derived class")

    def run(self, program) -> bool:
        """
        Returns true if this pass modifies anything. Returns false otherwise.
        """
        raise NotImplementedError("Override run() in derived class")

    def __repr__(self):  # virtual method
        return type(self).__name__
from typing import Dict, Set, Tuple, List, DefaultDict, Optional
from collections import defaultdict
from ..basic.register import Gpr, GprType
from .base_pass import BasePass, PassTag
from .divide_basic_block_pass import BasicBlock
class OneRegisterInterference:
    """
    Interference record ("cannot share physical registers") for one base Gpr.

    Conflict encoding: for every (conflict_base_gpr, offset) stored here, the
    physical index of `conflict_base_gpr` must not equal the physical index of
    `self.target_base_gpr` plus `offset`.  Entries are grouped per conflicting
    base Gpr.
    """
    def __init__(self, target_base_gpr: Gpr):
        assert not target_base_gpr.is_view
        self.target_base_gpr = target_base_gpr  # type: Gpr
        self.conflicts = defaultdict(set)  # type: DefaultDict[Gpr, Set[int]]

    def __repr__(self):
        rendered = []  # type: List[str]
        for other_gpr, offsets in sorted(self.conflicts.items()):
            assert not other_gpr.is_view
            full_overlap = self.target_base_gpr.count + other_gpr.count - 1
            if len(offsets) == full_overlap:
                # Total conflict: not even a single 4-byte Gpr may overlap,
                # so printing the offsets adds no information
                rendered.append(repr(other_gpr))
            else:
                # Partial conflict (or, for pre-indexed Gprs, an extended one):
                # list the forbidden offsets explicitly
                offsets_text = ','.join(repr(x) for x in sorted(offsets))
                rendered.append(f"{repr(other_gpr)}@{{{offsets_text}}}")
        return repr(self.target_base_gpr) + " <-> {" + ", ".join(sorted(rendered)) + "}"
class ComputeRegisterInterferencePass(BasePass):
    """
    Build the register-interference graphs (one for VGPRs, one for SGPRs) from
    the per-(base_gpr, offset) life-span bitmaps stored on each basic block,
    then attempt to compute a perfect elimination ordering for each graph (for
    chordal-graph register allocation).  Results are stored on
    program.optimizer_state.
    """
    def __init__(self, /, priority: int = PassTag.ComputeRegisterInterference.value):
        super().__init__(priority)

    def required_tags(self) -> Set[PassTag]:
        return {PassTag.AnnotateClause}

    def generated_tags(self) -> Set[PassTag]:
        return {PassTag.ComputeRegisterInterference}

    def invalidated_tags(self) -> Set[PassTag]:
        return set()

    def reset(self, program):
        # Drop every previously-computed interference result (if any)
        state = program.optimizer_state
        if state is None:
            return
        state.register_interference_vgpr = None
        state.register_interference_sgpr = None
        state.register_interference_perfect_elimination_ordering_vgpr = None
        state.register_interference_perfect_elimination_ordering_sgpr = None

    @staticmethod
    def __compute_perfect_elimination_ordering(interference: Dict[Gpr, OneRegisterInterference]) -> Optional[List[Gpr]]:
        """
        Maximum-cardinality search: repeatedly pick a node with the largest
        label, prepend it to the ordering, and bump the labels of its still
        unordered neighbors.  For a chordal graph this produces a perfect
        elimination ordering; afterwards the ordering is verified.
        """
        labels = dict()  # type: Dict[Gpr, int]
        for base_gpr in interference:
            assert base_gpr not in labels
            labels[base_gpr] = 0
        perfect_elimination_ordering = []  # type: List[Gpr]
        while labels:
            # Pick (the first) node carrying the maximum label
            max_label_value = max(labels.values())
            max_label_base_gpr = [base_gpr for base_gpr, label_value in labels.items() if label_value == max_label_value][0]
            perfect_elimination_ordering.insert(0, max_label_base_gpr)
            labels.pop(max_label_base_gpr)
            for neighbor in sorted(interference[max_label_base_gpr].conflicts):
                assert not neighbor.is_view
                if neighbor in labels:
                    labels[neighbor] += 1
        assert len(perfect_elimination_ordering) == len(interference)
        # Check whether the sequence is indeed a perfect elimination ordering:
        # for each node, its first following neighbor must be adjacent to all
        # other following neighbors (they must form a clique).
        # See: https://www.dazhuanlan.com/2019/11/10/5dc7f1f735a82/
        is_chordal_graph = True
        for idx, base_gpr in enumerate(perfect_elimination_ordering):
            following_neighbors = [x for x in perfect_elimination_ordering[idx+1:] if x in interference[base_gpr].conflicts]
            if len(following_neighbors) > 1:
                car = following_neighbors[0]
                cdr = following_neighbors[1:]
                for x in cdr:
                    if car not in interference[x].conflicts:
                        is_chordal_graph = False
        if is_chordal_graph:
            return perfect_elimination_ordering
        else:
            return perfect_elimination_ordering  # None  # TODO: return it even not chordal graph

    def run(self, program) -> bool:
        """
        Compute VGPR and SGPR interference graphs plus their (tentative)
        perfect elimination orderings; always returns True.
        """
        # Check that basic blocks have been divided
        assert program.optimizer_state.divide_basic_block is not None
        all_basic_block_list = program.optimizer_state.divide_basic_block.basic_blocks  # type: List[BasicBlock]
        # Check that Gpr life span has been computed
        for bb in all_basic_block_list:
            assert bb.live_var_gpr_life_span is not None
            for base_gpr, bitmap_list in bb.live_var_gpr_life_span.items():  # type: (Gpr, List[int])
                assert isinstance(base_gpr, Gpr)
                assert isinstance(bitmap_list, list)
                assert not base_gpr.is_view
                assert len(bitmap_list) == base_gpr.count
        register_interference_vgpr = dict()  # type: Dict[Gpr, OneRegisterInterference]
        register_interference_sgpr = dict()  # type: Dict[Gpr, OneRegisterInterference]

        def compute_interference_rtype(result: Dict[Gpr, OneRegisterInterference], rtype: GprType):
            # Scan every basic blocks and collect Gpr interference
            for bb in all_basic_block_list:
                clauses_and_jump_instr_count = len(bb.clauses)
                if bb.jump_instr is not None:
                    clauses_and_jump_instr_count += 1
                # Life spans are sampled at 2*N+2 points: bit 0 is live-in, odd
                # bits are per-instruction uses, even bits are defs (and the
                # last odd bit is live-out) — see the life-span computation
                for bit_at in range(0, 2*clauses_and_jump_instr_count+2):
                    conflict_set = set()  # type: Set[Tuple[Gpr, int]]  # (base_gpr, offset)
                    for base_gpr, bitmap_list in bb.live_var_gpr_life_span.items():  # type: (Gpr, List[int])
                        assert not base_gpr.is_view
                        if base_gpr.rtype != rtype:
                            continue
                        assert len(bitmap_list) == base_gpr.count
                        for offset, bitmap in enumerate(bitmap_list):
                            if bitmap & (1 << bit_at):
                                # This (base_gpr, offset) is active at this point (bit_at)
                                conflict_set.add((base_gpr, offset))
                    # All (base_gpr, offset) in conflict_set is active at this point (bit_at),
                    # thus conflict with each other
                    for (base_gpr1, offset1) in conflict_set:
                        if base_gpr1 not in result:
                            result[base_gpr1] = OneRegisterInterference(base_gpr1)
                        for (base_gpr2, offset2) in conflict_set:
                            if base_gpr1 is not base_gpr2:
                                result[base_gpr1].conflicts[base_gpr2].add(offset2 - offset1)

            def get_forced_index(base_gpr: Gpr) -> Optional[int]:
                # A "pre-indexed" Gpr has a caller-forced physical index
                assert not base_gpr.is_view
                if base_gpr in program.forced_index:
                    return program.forced_index[base_gpr]
                return None

            # Special case: Pre-indexed Gprs conflicts with each other
            pre_indexed_gprs = set(gpr for gpr in result if get_forced_index(gpr) is not None)
            for base_gpr1 in pre_indexed_gprs:
                assert base_gpr1 in result
                for base_gpr2 in pre_indexed_gprs:
                    if base_gpr1 is not base_gpr2:
                        for offset1 in range(0, base_gpr1.count):
                            for offset2 in range(0, base_gpr2.count):
                                result[base_gpr1].conflicts[base_gpr2].add(offset2 - offset1)
                                result[base_gpr2].conflicts[base_gpr1].add(offset1 - offset2)
            # Special case: (quite conservative)
            #
            # If a pre-indexed Gpr r1 has interference with another Gpr r2,
            # then any other pre-indexed Gpr has interference with r2
            #
            for base_gpr1 in result:
                if get_forced_index(base_gpr1) is None:  # base_gpr1 is not pre-indexed
                    continue
                for base_gpr2 in result:
                    if (get_forced_index(base_gpr2) is None) or (base_gpr2 is base_gpr1):
                        continue
                    for gpr, offset1_list in sorted(result[base_gpr1].conflicts.items()):
                        if get_forced_index(gpr) is not None:
                            continue
                        for offset1 in offset1_list:
                            # Translate base_gpr1's conflict offsets into base_gpr2's
                            # frame via their (known) forced physical indices
                            result[base_gpr2].conflicts[gpr].add(get_forced_index(base_gpr2) - (get_forced_index(base_gpr1) - offset1))
                            result[gpr].conflicts[base_gpr2].add((get_forced_index(base_gpr1) - offset1) - get_forced_index(base_gpr2))

        # Compute VGPR & SGPR interference
        # We don't need to compute interference for Special Gpr
        compute_interference_rtype(register_interference_vgpr, GprType.V)
        compute_interference_rtype(register_interference_sgpr, GprType.S)
        # Save the results
        program.optimizer_state.register_interference_vgpr = register_interference_vgpr
        program.optimizer_state.register_interference_sgpr = register_interference_sgpr
        # Check whether this is chordal graph
        program.optimizer_state.register_interference_perfect_elimination_ordering_vgpr = \
            ComputeRegisterInterferencePass.__compute_perfect_elimination_ordering(register_interference_vgpr)
        program.optimizer_state.register_interference_perfect_elimination_ordering_sgpr = \
            ComputeRegisterInterferencePass.__compute_perfect_elimination_ordering(register_interference_sgpr)
        return True
from __future__ import annotations
from typing import Set, List, Optional, Union, Dict, Any
from ..basic.const import _INSTR_STR_WIDTH
from ..basic.register import Gpr, GprSet
from ..basic.exception import SeekException, check
from ..basic.instr import InstrCall, ExplicitWaitCall, ExplicitUsesCall, MemToken, Block, ControlFlowEnum
from .base_pass import BasePass, PassTag, OptimizerState
class BasicBlockJumpInstr:
    """
    The branching/terminating instruction that ends a BasicBlock, detached from
    the clause list by DivideBasicBlockPass.
    """
    def __init__(self,
                 control_flow_enum: ControlFlowEnum,
                 instr_name: str,
                 gpr_uses: Dict[Gpr, Dict[int, int]],  # base_gpr -> {offset -> count}
                 gpr_holds: Dict[Gpr, Dict[int, int]],  # base_gpr -> {offset -> count}
                 gpr_defs: Dict[Gpr, Dict[int, int]]):  # base_gpr -> {offset -> count}
        # A plain fall-through never carries a jump instruction
        assert control_flow_enum != ControlFlowEnum.Fallthrough
        self.control_flow_enum = control_flow_enum  # type: ControlFlowEnum
        self.instr_name = instr_name  # type: str
        self.gpr_uses = gpr_uses  # type: Dict[Gpr, Dict[int, int]]  # base_gpr -> {offset -> count}
        self.gpr_holds = gpr_holds  # type: Dict[Gpr, Dict[int, int]]  # base_gpr -> {offset -> count}
        self.gpr_defs = gpr_defs  # type: Dict[Gpr, Dict[int, int]]  # base_gpr -> {offset -> count}

    def gpr_uses_to_gprset(self) -> GprSet:
        return InstrCall.gpr_uses_holds_defs_to_gprset(self.gpr_uses)

    def gpr_holds_to_gprset(self) -> GprSet:
        return InstrCall.gpr_uses_holds_defs_to_gprset(self.gpr_holds)

    def gpr_defs_to_gprset(self) -> GprSet:
        return InstrCall.gpr_uses_holds_defs_to_gprset(self.gpr_defs)

    def generate(self, program, wr, successor_if_jump):
        """Emit the assembly text; `successor_if_jump` is the jump target block (or None)."""
        text = f"{self.instr_name} ".ljust(_INSTR_STR_WIDTH)
        if successor_if_jump is None:
            # Only a terminating instruction carries no jump target
            assert self.control_flow_enum == ControlFlowEnum.Terminate
            wr(f"{text}// {self.control_flow_enum.name}")
        else:
            assert self.control_flow_enum in (ControlFlowEnum.AlwaysJump, ControlFlowEnum.CondJump)
            wr(f"{text}{successor_if_jump.name} // {self.control_flow_enum.name}")

    @staticmethod
    def from_instr_call(instr: InstrCall) -> BasicBlockJumpInstr:
        """Detach the control-flow information of an InstrCall into a BasicBlockJumpInstr."""
        assert instr.control_flow_enum != ControlFlowEnum.Fallthrough
        return BasicBlockJumpInstr(instr.control_flow_enum, instr.instr_name,
                                   instr.gpr_uses, instr.gpr_holds, instr.gpr_defs)

    @staticmethod
    def make_alwaysjump() -> BasicBlockJumpInstr:
        # NOTE: no uses/holds/defs
        return BasicBlockJumpInstr(ControlFlowEnum.AlwaysJump, "s_branch", {}, {}, {})

    @staticmethod
    def make_terminate() -> BasicBlockJumpInstr:
        # NOTE: no uses/holds/defs
        return BasicBlockJumpInstr(ControlFlowEnum.Terminate, "s_endpgm", {}, {}, {})
class BasicBlock:
    """
    A straight-line run of clauses with a single entry and a single exit.

    Exit-edge conventions:
        s_branch:        successor_if_jump is the target, successor_if_fallthrough is None
        s_cbranch_xxx:   successor_if_jump is the target, successor_if_fallthrough is the no-jump target
        s_endpgm family: both successors are None
        fall-through:    successor_if_jump is None, successor_if_fallthrough is the target
    """
    def __init__(self, name: str):
        self.name = name  # type: str
        self.successor_if_jump = None  # type: Optional[BasicBlock]
        self.successor_if_fallthrough = None  # type: Optional[BasicBlock]
        self.predecessors = []  # type: List[BasicBlock]
        self.jump_instr = None  # type: Optional[BasicBlockJumpInstr]
        self.clauses = []  # type: List[Union[InstrCall, ExplicitWaitCall, ExplicitUsesCall]]
        # Filled by AnalyzeLiveVarPass
        self.live_var_defs = None  # type: GprSet
        self.live_var_uses = None  # type: GprSet
        self.live_var_in = None  # type: GprSet
        self.live_var_out = None  # type: GprSet
        self.live_var_gpr_life_span = None  # type: Dict[Gpr, List[int]]
        # Filled by AnnotateClausePass
        self.annotate_clauses = None  # type: List  # List[AnnClause]
        # Filled by InsertWaitcntPass
        self.pending_mem_in = None  # type: Set  # Set[PendingMem]
        self.pending_mem_out = None  # type: Set  # Set[PendingMem]

    def _sanity_check_control_flow_at_exit(self):
        """Assert that the successor pointers agree with `jump_instr` (see class docstring)."""
        if self.jump_instr is None:
            # Pure fall-through
            assert self.successor_if_jump is None
            assert isinstance(self.successor_if_fallthrough, BasicBlock)
            return
        cf = self.jump_instr.control_flow_enum
        if cf == ControlFlowEnum.AlwaysJump:
            assert isinstance(self.successor_if_jump, BasicBlock)
            assert self.successor_if_fallthrough is None
        elif cf == ControlFlowEnum.CondJump:
            assert isinstance(self.successor_if_jump, BasicBlock)
            assert isinstance(self.successor_if_fallthrough, BasicBlock)
            # A conditional jump to the very next basic-block is rewritten into a
            # plain fall-through by DivideBasicBlockPass, so both successors can
            # never be the same block here.  This is not only an optimization but
            # a necessity: after that fix-up, no successor basic-block B ever has
            # the same predecessor A twice (via both edge kinds at once).
            assert self.successor_if_jump is not self.successor_if_fallthrough
        elif cf == ControlFlowEnum.Terminate:
            assert self.successor_if_jump is None
            assert self.successor_if_fallthrough is None
        else:
            assert False, f"Unexpected control_flow: {cf}"

    @property
    def control_flow_at_exit(self) -> ControlFlowEnum:
        """The effective control flow at this block's exit (sanity-checked first)."""
        self._sanity_check_control_flow_at_exit()
        if self.jump_instr is None:
            return ControlFlowEnum.Fallthrough
        return self.jump_instr.control_flow_enum

    def __repr__(self):
        return f"BasicBlock({self.name!r})"
class DivideBasicBlockPassState:
    """Result of DivideBasicBlockPass: the flat basic-block list plus the mem-token registry."""
    def __init__(self):
        self.basic_blocks = []  # type: List[BasicBlock]
        self.mem_token_object_to_mem_tokens = {}  # type: Dict[Any, MemToken]

    def _sanity_check(self):
        """Assert that the basic-block graph is internally consistent."""
        assert isinstance(self.basic_blocks, list)
        # A fall-through successor must be the textually next basic-block
        for position, block in enumerate(self.basic_blocks):
            if block.successor_if_fallthrough is not None:
                assert block.successor_if_fallthrough is self.basic_blocks[position+1]
        # No basic-block may appear twice in the list
        seen = set()  # type: Set[BasicBlock]
        for block in self.basic_blocks:
            assert isinstance(block, BasicBlock)
            assert block not in seen, f"Duplicates in self.basic_blocks: {block}"
            seen.add(block)
        # Predecessor/successor cross-links must agree, and `jump_instr` must be sane
        for block in seen:
            assert len(block.predecessors) == len(set(block.predecessors)), \
                f"Duplicated predecessors in {block}: {block.predecessors}"
            for pred in block.predecessors:
                assert isinstance(pred, BasicBlock)
                assert pred in seen
                assert pred.successor_if_jump is block or pred.successor_if_fallthrough is block
                assert not (pred.successor_if_jump is block and pred.successor_if_fallthrough is block)
            if block.successor_if_jump is not None:
                assert block in block.successor_if_jump.predecessors
            if block.successor_if_fallthrough is not None:
                assert block in block.successor_if_fallthrough.predecessors
            block._sanity_check_control_flow_at_exit()
class DivideBasicBlockPass(BasePass):
    """
    Divide program blocks into basic-blocks.
    NOTE: Dead basic-blocks are **not** pruned in this pass. Leave them to OptimizeBasicBlockPass
    """
    def __init__(self, /, priority: int = PassTag.DivideBasicBlock.value):
        super().__init__(priority)

    def required_tags(self) -> Set[PassTag]:
        return set()

    def generated_tags(self) -> Set[PassTag]:
        return {PassTag.DivideBasicBlock}

    def invalidated_tags(self) -> Set[PassTag]:
        return set()

    def reset(self, program):
        optimizer_state = program.optimizer_state  # type: OptimizerState
        optimizer_state.divide_basic_block = None

    def run(self, program) -> bool:
        """
        Split every Block at its control-flow instructions, link the resulting
        basic-blocks (jump/fall-through edges and predecessors), detach each
        trailing branch into `bb.jump_instr`, and store the result on
        optimizer_state.divide_basic_block.  Always returns True.

        Raises SeekException if the program is empty or the last block does not
        end with s_endpgm/s_branch.
        """
        optimizer_state = program.optimizer_state  # type: OptimizerState
        #
        # Each `Block` could be divided to one or more basic-blocks (due to branching instructions)
        # Here, each element of `bb_by_block` is a list, corresponding to all basic-blocks of a Block
        #
        bb_by_block = {}  # type: Dict[Block, List[BasicBlock]]
        # First create one (initial) basic-block per Block, so jump targets can
        # be resolved to a Block's first basic-block before that Block is scanned
        for block in program.blocks:  # type: Block
            assert block not in bb_by_block
            bb = BasicBlock(block.block_name)
            bb_by_block[block] = [bb]
        for block_idx, block in enumerate(program.blocks):  # type: (int, Block)
            curr_bb = bb_by_block[block][0]  # this should be the very first basic-block
            # Init `alwaysjump_or_terminate_just_now` in case `block.clauses` is empty
            alwaysjump_or_terminate_just_now = False
            for clause in block.clauses:  # type: Union[InstrCall, ExplicitWaitCall, ExplicitUsesCall]
                assert curr_bb is not None
                alwaysjump_or_terminate_just_now = False
                # NOTE: InstrCall is added to curr_bb even if they are control-flow instructions
                # This will be fixed (aka removed) later when setting `jump_instr`
                curr_bb.clauses.append(clause)
                if isinstance(clause, ExplicitWaitCall):
                    # Nothing to do with ExplicitWaitCall
                    continue
                if isinstance(clause, ExplicitUsesCall):
                    # Nothing to do with ExplicitUsesCall
                    continue
                assert isinstance(clause, InstrCall)
                instr = clause  # type: InstrCall
                new_bb = None
                if instr.control_flow_enum in (ControlFlowEnum.AlwaysJump, ControlFlowEnum.Terminate):
                    alwaysjump_or_terminate_just_now = True
                if instr.control_flow_enum != ControlFlowEnum.Fallthrough:
                    # We need to switch to a new basic-block **after** this instruction
                    new_bb = BasicBlock(f"{block.block_name}_BB{len(bb_by_block[block])}")
                    bb_by_block[block].append(new_bb)
                    assert curr_bb.successor_if_jump is None
                    assert curr_bb.successor_if_fallthrough is None
                    if instr.control_flow_enum == ControlFlowEnum.CondJump:
                        cond_jump_target = program.get_block(instr.operands["label"])
                        curr_bb.successor_if_jump = bb_by_block[cond_jump_target][0]
                        curr_bb.successor_if_fallthrough = new_bb
                    elif instr.control_flow_enum == ControlFlowEnum.AlwaysJump:
                        alwaysjump_target = program.get_block(instr.operands["label"])
                        curr_bb.successor_if_jump = bb_by_block[alwaysjump_target][0]
                        # curr_bb.successor_if_fallthrough = None  # not necessary as it is already None
                    elif instr.control_flow_enum == ControlFlowEnum.Terminate:
                        # curr_bb.successor_if_jump = None  # not necessary as it is already None
                        # curr_bb.successor_if_fallthrough = None  # not necessary as it is already None
                        pass
                # Switch to a new basic-block if desired
                if new_bb is not None:
                    curr_bb = new_bb
            assert curr_bb is not None
            assert curr_bb.successor_if_jump is None
            assert curr_bb.successor_if_fallthrough is None
            if block_idx+1 == len(program.blocks):
                # If this is the last block: we must have ended with s_endpgm or s_branch
                if not alwaysjump_or_terminate_just_now:
                    raise SeekException("Program doesn't terminate with s_endpgm family or s_branch")
                # Remove this (empty) basic-block from current block
                assert len(curr_bb.clauses) == 0  # this basic-block must be empty
                assert len(bb_by_block[block]) >= 1  # the last block must have at least 1 (just added) basic-block
                popped_bb = bb_by_block[block].pop()
                assert popped_bb is curr_bb
            else:
                # Now all instructions in this Block have been added to basic-blocks
                # Link curr_bb to next block unconditionally (aka fall-through)
                next_block = program.blocks[block_idx+1]
                curr_bb.successor_if_fallthrough = bb_by_block[next_block][0]
        # Flatten basic-block list of all Blocks
        state = DivideBasicBlockPassState()
        for bb_list in bb_by_block.values():
            state.basic_blocks += bb_list
        if len(state.basic_blocks) == 0:
            raise SeekException("Program doesn't contain any instructions")
        # Fix-up setting `jump_instr`
        # Fix-up all last branching instruction for basic-blocks: remove them from current BasicBlock
        for bb in state.basic_blocks:
            assert bb.jump_instr is None
            if bb.successor_if_jump is None and bb.successor_if_fallthrough is not None:
                # For fall-through: successor_if_jump is None, successor_if_fallthrough is target
                pass
            else:
                last_instr = bb.clauses.pop()
                assert isinstance(last_instr, InstrCall)
                # NOTE: may be changed back to None for ControlFlowEnum.CondJump
                bb.jump_instr = BasicBlockJumpInstr.from_instr_call(last_instr)
                if last_instr.control_flow_enum == ControlFlowEnum.CondJump:
                    # If we are conditionally jumping to the successor basic-block, eliminate the s_cbranch_xxx
                    # This is not only an optimization, but also necessary
                    #
                    # After this fix-up, no (successor) basic-block B will have two same precedent basic-blocks A
                    # (aka A.successor_if_jump is B and A.successor_if_fallthrough is B)
                    if bb.successor_if_jump is bb.successor_if_fallthrough:
                        # We simply set successor_if_jump to None, and leave successor_if_fallthrough as is (like normal fall-through)  # noqa E501: line too long
                        # Moreover, `last_instr` (s_cbranch_xxx) is dropped
                        bb.successor_if_jump = None
                        bb.jump_instr = None
            bb._sanity_check_control_flow_at_exit()
        # Fix-up: compute predecessors for all basic-blocks
        for bb in state.basic_blocks:
            if bb.successor_if_jump is not None:
                assert bb not in bb.successor_if_jump.predecessors
                bb.successor_if_jump.predecessors.append(bb)
            if bb.successor_if_fallthrough is not None:
                assert bb not in bb.successor_if_fallthrough.predecessors
                bb.successor_if_fallthrough.predecessors.append(bb)
        # All done
        state._sanity_check()
        # Resolve mem_token: `state.mem_token_object_to_mem_tokens` is filled
        self.__resolve_mem_token(program, state)
        optimizer_state.divide_basic_block = state
        return True

    # noinspection PyMethodMayBeStatic
    def __resolve_mem_token(self, program, state: DivideBasicBlockPassState):
        """
        Check there are no duplicated mem_token_object.
        Resolves mem_token_object to mem_token in explicit_call.
        This must be done after basic-blocks are divided, but before OptimizeBasicBlockPass,
        in case some mem_token are removed by dead basic-block elimination.
        """
        explicit_wait_mem_token_or_token_objects = set()  # type: Set[Any]
        # First pass: register every mem_token by its token_object, and collect
        # everything referenced by explicit waits
        for bb in state.basic_blocks:  # type: BasicBlock
            for clause in bb.clauses:
                if isinstance(clause, ExplicitWaitCall):
                    for mem_token_or_token_object in clause.mem_token_or_token_objects:
                        assert mem_token_or_token_object is not None
                        explicit_wait_mem_token_or_token_objects.add(mem_token_or_token_object)
                elif isinstance(clause, ExplicitUsesCall):
                    pass
                else:
                    assert isinstance(clause, InstrCall)
                    if clause.mem_token is not None:
                        assert isinstance(clause.mem_token, MemToken)
                        check(clause.mem_token.token_object not in state.mem_token_object_to_mem_tokens,
                              f"Duplicated mem_token: {clause.mem_token}")
                        state.mem_token_object_to_mem_tokens[clause.mem_token.token_object] = clause.mem_token
        # Check that all explicit waits are valid
        for mem_token_or_token_object in explicit_wait_mem_token_or_token_objects:
            if isinstance(mem_token_or_token_object, MemToken):
                mem_token = mem_token_or_token_object
                check(mem_token.token_object in state.mem_token_object_to_mem_tokens,
                      f"mem_token not found: {mem_token}")
                my_mem_token = state.mem_token_object_to_mem_tokens[mem_token.token_object]
                check(mem_token is my_mem_token,
                      f"mem_token mismatched: {mem_token} is not {my_mem_token}")
            else:  # this is a mem_token_object
                mem_token_object = mem_token_or_token_object
                check(mem_token_object in state.mem_token_object_to_mem_tokens,
                      f"mem_token_object not found: {mem_token_object}")
from typing import Set
from ..basic.instr import InstrCall, ExplicitWaitCall, ExplicitUsesCall
from .base_pass import BasePass, PassTag, OptimizerState
from .divide_basic_block_pass import BasicBlock
class EliminateDeadCodePass(BasePass):
    """
    As the name suggests, this pass does dead code elimination (DCE)
    """
    def __init__(self, /, priority: int = PassTag.EliminateDeadCode.value):
        super().__init__(priority)

    def required_tags(self) -> Set[PassTag]:
        return {PassTag.AnalyzeLiveVar}

    def generated_tags(self) -> Set[PassTag]:
        return {PassTag.EliminateDeadCode}

    def invalidated_tags(self) -> Set[PassTag]:
        # Removing instructions invalidates live-var analysis and downstream results
        return {PassTag.AnalyzeLiveVar, PassTag.OptimizeBasicBlock, PassTag.AnnotateClause}

    def reset(self, program):
        # This pass has nothing to reset
        pass

    def run(self, program) -> bool:
        """Run DCE on every basic-block; returns True if anything was removed."""
        optimizer_state = program.optimizer_state  # type: OptimizerState
        state = optimizer_state.divide_basic_block
        assert state is not None
        # Compute defs and uses for each individual basic-block
        modified = False
        for bb in state.basic_blocks:
            modified |= self.__dead_code_elimination(program, bb)
        return modified

    # noinspection PyMethodMayBeStatic
    def __dead_code_elimination(self, program, bb: BasicBlock) -> bool:
        """
        Eliminate unnecessary InstrCall for a basic-block
        Returns True if we actually removed something
        """
        # Backward scan: `curr_uses` holds the Gprs still needed below the
        # current clause, seeded with the block's live-out set
        curr_uses = bb.live_var_out.clone()
        idx_to_remove = set()  # type: Set[int]
        # Don't forget the last `jump_instr`, if exists
        if bb.jump_instr is not None:
            curr_uses.union_update(bb.jump_instr.gpr_uses_to_gprset())
            assert len(bb.jump_instr.gpr_holds) == 0
            assert len(bb.jump_instr.gpr_defs) == 0
        # Scan from the last to the first
        for idx, clause in reversed(list(enumerate(bb.clauses))):
            if isinstance(clause, ExplicitWaitCall):
                # We don't remove ExplicitWaitCall
                continue
            if isinstance(clause, ExplicitUsesCall):
                # We don't remove ExplicitUsesCall
                # TODO: maybe there could be a hint to remove them too?
                curr_uses.union_update(*clause.uses)
                continue
            assert isinstance(clause, InstrCall)
            if clause.mem_token is not None:
                # If this instruction is memory-related, we (conservatively) don't remove it
                pass
            elif len(clause.gpr_defs) == 0:
                # If this instruction doesn't define any Gpr, supposedly it should have side effects
                program.logger.debug(f"{self}: assumed side effects: `{clause.instr_name}` at {clause.srcloc}")
                pass
            else:
                gprset_defs = clause.gpr_defs_to_gprset()
                if curr_uses.is_intersected(gprset_defs):
                    # Some defined Gpr is needed below: keep the instruction, and
                    # the defined Gprs are no longer "needed" above this point
                    curr_uses.difference_update(gprset_defs)
                else:
                    # Nothing this instruction defines is ever used: drop it
                    idx_to_remove.add(idx)
                    program.logger.debug(f"{self}: eliminated: {clause}")
                    continue
            # Now this instruction is **not** removed
            curr_uses.union_update(clause.gpr_uses_to_gprset())
        # Remove index(es) from `idx_to_remove`
        if len(idx_to_remove) > 0:
            bb.clauses = [clause for idx, clause in enumerate(bb.clauses) if idx not in idx_to_remove]
            return True
        else:
            return False
from __future__ import annotations
from typing import Set, List, Optional, Callable
from ..basic.exception import SeekException
from ..basic.register import GprSet
from ..basic.instr import MemToken, ExplicitWaitCall, ExplicitUsesCall, InstrCall, Waitcnt, ControlFlowEnum
from .base_pass import BasePass, PassTag, OptimizerState
from .divide_basic_block_pass import BasicBlock, DivideBasicBlockPassState
from .annotate_clause_pass import AnnClause
class PendingMem:
def __init__(self):
    """Track in-flight memory operations: one FIFO per hardware wait-counter category."""
    self.pending_vector = []  # type: List[MemToken]  # vector memory, retired via vmcnt
    self.pending_lds = []  # type: List[MemToken]  # LDS, retired via lgkmcnt
    self.pending_gds = []  # type: List[MemToken]  # GDS, retired via lgkmcnt
    self.pending_scalar = []  # type: List[MemToken]  # scalar memory, retired only by lgkmcnt(0) (out-of-order)
    self.pending_msg = []  # type: List[MemToken]  # messages, retired only by lgkmcnt(0)
    self.pending_export = []  # type: List[MemToken]  # exports, retired via expcnt
def __eq__(self, other):
if not isinstance(other, PendingMem):
return NotImplemented
return self.pending_vector == other.pending_vector and \
self.pending_lds == other.pending_lds and \
self.pending_gds == other.pending_gds and \
self.pending_scalar == other.pending_scalar and \
self.pending_msg == other.pending_msg and \
self.pending_export == other.pending_export
def __hash__(self):
return hash((tuple(self.pending_vector), tuple(self.pending_lds), tuple(self.pending_gds),
tuple(self.pending_scalar), tuple(self.pending_msg), tuple(self.pending_export)))
def __repr__(self):
pending_list_reprs = [] # type: List[str]
if self.pending_vector:
pending_list_reprs.append(f"pending_vector={self.pending_vector}")
if self.pending_lds:
pending_list_reprs.append(f"pending_lds={self.pending_lds}")
if self.pending_gds:
pending_list_reprs.append(f"pending_gds={self.pending_gds}")
if self.pending_scalar:
pending_list_reprs.append(f"pending_scalar={self.pending_scalar}")
if self.pending_msg:
pending_list_reprs.append(f"pending_msg={self.pending_msg}")
if self.pending_export:
pending_list_reprs.append(f"pending_export={self.pending_export}")
return f"PendingMem({', '.join(pending_list_reprs)})"
def clone(self) -> PendingMem:
result = PendingMem()
result.pending_vector = self.pending_vector.copy()
result.pending_lds = self.pending_lds.copy()
result.pending_gds = self.pending_gds.copy()
result.pending_scalar = self.pending_scalar.copy()
result.pending_msg = self.pending_msg.copy()
result.pending_export = self.pending_export.copy()
return result
def add_pending(self, mem_token: MemToken):
# Currently, there shall be **exactly** one inc_xxx > 0
assert isinstance(mem_token, MemToken), mem_token
assert int(mem_token.inc_vector > 0) + int(mem_token.inc_lds > 0) + int(mem_token.inc_gds > 0) + \
int(mem_token.inc_scalar > 0) + int(mem_token.inc_msg > 0) + int(mem_token.inc_export > 0) == 1, mem_token
def __add_if_exists(inc: int, pending_list: List[MemToken]):
if inc > 0:
assert mem_token not in pending_list
pending_list.append(mem_token)
__add_if_exists(mem_token.inc_vector, self.pending_vector)
__add_if_exists(mem_token.inc_lds, self.pending_lds)
__add_if_exists(mem_token.inc_gds, self.pending_gds)
__add_if_exists(mem_token.inc_scalar, self.pending_scalar)
__add_if_exists(mem_token.inc_msg, self.pending_msg)
__add_if_exists(mem_token.inc_export, self.pending_export)
def update_by_waitcnt(self, waitcnt: Waitcnt) -> None:
def __maybe_pop(waitcnt_value: Optional[int],
pending_list: List[MemToken],
select_waitcnt: Callable[[MemToken], int]):
if waitcnt_value is not None:
cnt_sum = sum(select_waitcnt(x) for x in pending_list)
while cnt_sum > waitcnt_value:
cnt_sum -= select_waitcnt(pending_list.pop(0))
assert sum(select_waitcnt(x) for x in pending_list) <= waitcnt_value
__maybe_pop(waitcnt.vmcnt, self.pending_vector, lambda w: w.total_inc_vmcnt)
__maybe_pop(waitcnt.lgkmcnt, self.pending_lds, lambda w: w.total_inc_lgkmcnt)
__maybe_pop(waitcnt.lgkmcnt, self.pending_gds, lambda w: w.total_inc_lgkmcnt)
if waitcnt.lgkmcnt == 0:
# Scalar reads/writes return out of order! Only valid wait is: `s_waitcnt lgkmcnt(0)`
__maybe_pop(waitcnt.lgkmcnt, self.pending_scalar, lambda w: w.total_inc_lgkmcnt)
__maybe_pop(waitcnt.lgkmcnt, self.pending_msg, lambda w: w.total_inc_lgkmcnt)
__maybe_pop(waitcnt.expcnt, self.pending_export, lambda w: w.total_inc_expcnt)
def update_by_explicit_wait(self, mem_tokens: Set[MemToken]) -> Optional[Waitcnt]:
waitcnt = None # type: Optional[Waitcnt]
for mem_token in mem_tokens:
assert isinstance(mem_token, MemToken), mem_token
def __maybe_update(pending_list: List[MemToken],
key: str,
out_of_order_return: bool,
select_waitcnt: Callable[[MemToken], int]):
assert key in ("vmcnt", "lgkmcnt", "expcnt")
nonlocal waitcnt
if mem_token in pending_list:
if out_of_order_return:
pending_list.clear()
waitcnt = Waitcnt.lcm(waitcnt, Waitcnt(**{key: 0}))
else:
while True:
popped = pending_list.pop(0)
if popped == mem_token:
break
assert mem_token not in pending_list
waitcnt = Waitcnt.lcm(waitcnt, Waitcnt(**{key: sum(select_waitcnt(x) for x in pending_list)}))
__maybe_update(self.pending_vector, "vmcnt", False, lambda w: w.total_inc_vmcnt)
__maybe_update(self.pending_lds, "lgkmcnt", False, lambda w: w.total_inc_lgkmcnt)
__maybe_update(self.pending_gds, "lgkmcnt", False, lambda w: w.total_inc_lgkmcnt)
__maybe_update(self.pending_scalar, "lgkmcnt", True, lambda w: w.total_inc_lgkmcnt) # out of order return!
__maybe_update(self.pending_msg, "lgkmcnt", False, lambda w: w.total_inc_lgkmcnt)
__maybe_update(self.pending_export, "expcnt", False, lambda w: w.total_inc_expcnt)
# Well, this waitcnt may cause further updates on pending_xxx
if waitcnt is not None:
self.update_by_waitcnt(waitcnt)
return waitcnt
def update_by_uses(self, gprset: GprSet) -> Waitcnt:
waitcnt = None # type: Optional[Waitcnt]
def __maybe_update(pending_list: List[MemToken],
key: str,
out_of_order_return: bool,
select_waitcnt: Callable[[MemToken], int]):
assert key in ("vmcnt", "lgkmcnt", "expcnt")
nonlocal waitcnt
for idx in range(len(pending_list)-1, -1, -1):
if GprSet(*pending_list[idx].load_mem_to_gprs).is_intersected(gprset):
if out_of_order_return:
pending_list.clear()
waitcnt = Waitcnt.lcm(waitcnt, Waitcnt(**{key: 0}))
else:
for _ in range(idx+1):
pending_list.pop(0)
waitcnt = Waitcnt.lcm(waitcnt, Waitcnt(**{key: sum(select_waitcnt(x) for x in pending_list)}))
break
__maybe_update(self.pending_vector, "vmcnt", False, lambda w: w.total_inc_vmcnt)
__maybe_update(self.pending_lds, "lgkmcnt", False, lambda w: w.total_inc_lgkmcnt)
__maybe_update(self.pending_gds, "lgkmcnt", False, lambda w: w.total_inc_lgkmcnt)
__maybe_update(self.pending_scalar, "lgkmcnt", True, lambda w: w.total_inc_lgkmcnt) # out of order return!
__maybe_update(self.pending_msg, "lgkmcnt", False, lambda w: w.total_inc_lgkmcnt)
__maybe_update(self.pending_export, "expcnt", False, lambda w: w.total_inc_expcnt)
# Well, this waitcnt may cause further updates on pending_xxx
if waitcnt is not None:
self.update_by_waitcnt(waitcnt)
return waitcnt
def check_by_defs(self, gprset: GprSet, instr: InstrCall) -> None:
def __check(pending_list: List[MemToken]):
for mem_token in pending_list:
if GprSet(*mem_token.load_mem_to_gprs).is_intersected(gprset):
raise SeekException(f"Instruction defines Gpr writen by pending memory {mem_token}: {instr}")
__check(self.pending_vector)
__check(self.pending_lds)
__check(self.pending_gds)
__check(self.pending_scalar) # out of order return!
__check(self.pending_msg)
__check(self.pending_export)
class InsertWaitcntPass(BasePass):
    """
    Dataflow pass that computes, for every annotated clause, the `s_waitcnt`
    (stored in `annclause.insert_waitcnt`) that must be inserted before it so
    that all Gpr reads and writes are safe with respect to pending memory
    operations.

    Runs a worklist algorithm over basic blocks: the set of possible PendingMem
    states at each block's entrance is propagated through its clauses; whenever
    a block's exit set changes, its successors are re-queued.
    """
    def __init__(self, /,
                 s_barrier_implies_wait_vmcnt_0: bool = True,
                 priority: int = PassTag.InsertWaitcnt.value):
        """
        s_barrier_implies_wait_vmcnt_0:
            Does `s_barrier` imply `s_waitcnt vmcnt(0)`?
            This is not a documented behavior, but it seems true from micro_benchmark.
        """
        super().__init__(priority)
        self.s_barrier_implies_wait_vmcnt_0 = s_barrier_implies_wait_vmcnt_0

    def required_tags(self) -> Set[PassTag]:
        return {PassTag.AnnotateClause}

    def generated_tags(self) -> Set[PassTag]:
        return {PassTag.InsertWaitcnt}

    def invalidated_tags(self) -> Set[PassTag]:
        return set()

    def reset(self, program):
        """Clear every result a previous run of this pass may have left behind."""
        optimizer_state = program.optimizer_state  # type: OptimizerState
        state = optimizer_state.divide_basic_block  # type: DivideBasicBlockPassState
        if state is None:
            return
        for bb in state.basic_blocks:
            assert isinstance(bb, BasicBlock)
            bb.pending_mem_in = None
            bb.pending_mem_out = None
            if bb.annotate_clauses is None:
                continue
            for annclause in bb.annotate_clauses:
                assert isinstance(annclause, AnnClause)
                annclause.insert_waitcnt = None

    def run(self, program) -> bool:
        """
        Compute `insert_waitcnt` for every annotated clause.
        Always returns True (annotations are (re)written unconditionally).
        """
        # Note: let's reset all existing `annclause.insert_waitcnt`, if any
        self.reset(program)
        optimizer_state = program.optimizer_state  # type: OptimizerState
        state = optimizer_state.divide_basic_block  # type: DivideBasicBlockPassState
        assert state is not None
        update_queue = []  # type: List[BasicBlock]
        for bb in state.basic_blocks:
            assert isinstance(bb, BasicBlock)
            assert bb.annotate_clauses is not None
            for annclause in bb.annotate_clauses:
                assert isinstance(annclause, AnnClause)
                assert annclause.insert_waitcnt is None
            bb.pending_mem_in = []
            bb.pending_mem_out = []
        # Entry block starts with nothing pending.
        state.basic_blocks[0].pending_mem_in = [PendingMem()]
        update_queue.append(state.basic_blocks[0])
        while update_queue:
            bb = update_queue.pop(0)
            # Shallow copy current basic-block's set of PendingMem at entrance to a list.
            # After its updates by all clauses by current basic-block,
            # it's distincted and then becomes the set of PendingMem at exit.
            pending_mem_list = list(pending_mem.clone() for pending_mem in bb.pending_mem_in)  # type: List[PendingMem]
            # Loop from first to last clause
            for annclause in bb.annotate_clauses:
                clause = annclause.clause
                waitcnt = annclause.insert_waitcnt  # type: Optional[Waitcnt]
                if waitcnt is not None:
                    # A waitcnt computed in an earlier worklist iteration still applies.
                    for pending_mem in pending_mem_list:
                        pending_mem.update_by_waitcnt(waitcnt)
                if isinstance(clause, ExplicitWaitCall):
                    # Resolve token objects to concrete MemTokens, then wait on them.
                    mem_tokens = set()  # type: Set[MemToken]
                    for mem_token_or_token_object in clause.mem_token_or_token_objects:
                        if isinstance(mem_token_or_token_object, MemToken):
                            mem_tokens.add(mem_token_or_token_object)
                        else:
                            assert mem_token_or_token_object in state.mem_token_object_to_mem_tokens
                            mem_tokens.add(state.mem_token_object_to_mem_tokens[mem_token_or_token_object])
                    for pending_mem in pending_mem_list:
                        new_waitcnt = pending_mem.update_by_explicit_wait(mem_tokens)
                        waitcnt = Waitcnt.lcm(waitcnt, new_waitcnt)
                elif isinstance(clause, ExplicitUsesCall):
                    for pending_mem in pending_mem_list:
                        new_waitcnt = pending_mem.update_by_uses(GprSet(*clause.uses))
                        waitcnt = Waitcnt.lcm(waitcnt, new_waitcnt)
                else:
                    assert isinstance(clause, InstrCall)
                    # Update by the instruction's uses first
                    for pending_mem in pending_mem_list:
                        new_waitcnt = pending_mem.update_by_uses(clause.gpr_uses_to_gprset())
                        waitcnt = Waitcnt.lcm(waitcnt, new_waitcnt)
                    # Then, check by the instruction's defs
                    for pending_mem in pending_mem_list:
                        pending_mem.check_by_defs(clause.gpr_defs_to_gprset(), clause)
                    # Deal with an explicit s_waitcnt written in the program
                    if clause.instr_name == "s_waitcnt":
                        operand = clause.operands["waitcnt"]
                        if isinstance(operand, int):
                            new_waitcnt = Waitcnt.from_int(operand)
                        else:
                            assert isinstance(operand, Waitcnt)
                            # BUGFIX: was `new_waitcnt = waitcnt`, which made the
                            # following lcm a no-op and silently ignored an explicit
                            # `s_waitcnt` carrying a Waitcnt operand.
                            new_waitcnt = operand
                        waitcnt = Waitcnt.lcm(waitcnt, new_waitcnt)
                    # Deal with s_barrier
                    if self.s_barrier_implies_wait_vmcnt_0 and clause.instr_name == "s_barrier":
                        waitcnt = Waitcnt.lcm(waitcnt, Waitcnt(vmcnt=0))
                # OK, `waitcnt` may (or may not) be updated.
                # Let's apply it to `pending_mem_list`.
                if waitcnt is not None:
                    for pending_mem in pending_mem_list:
                        pending_mem.update_by_waitcnt(waitcnt)
                annclause.insert_waitcnt = waitcnt
                # If this is a memory instruction, we have to update `pending_mem_list` here
                if isinstance(clause, InstrCall):
                    if clause.mem_token is not None:
                        pending_mem_list.append(PendingMem())
                        for pending_mem in pending_mem_list:
                            pending_mem.add_pending(clause.mem_token)
            # Possibly update `bb.pending_mem_out`
            if bb.control_flow_at_exit == ControlFlowEnum.Terminate:
                # s_endpgm family implies wait all pending memory
                pending_mem_list = []
            else:
                pending_mem_list = list(set(pending_mem_list))  # unique
                pending_mem_list.sort(key=repr)  # deterministic order for comparison
            if bb.pending_mem_out != pending_mem_list:
                bb.pending_mem_out = pending_mem_list
                if bb.successor_if_jump is not None:
                    # Calculate successor_if_jump basic-block's set of PendingMem at entrance.
                    bb.successor_if_jump.pending_mem_in = []
                    for pred in bb.successor_if_jump.predecessors:
                        bb.successor_if_jump.pending_mem_in += pred.pending_mem_out
                    bb.successor_if_jump.pending_mem_in = list(set(bb.successor_if_jump.pending_mem_in))  # unique
                    bb.successor_if_jump.pending_mem_in.sort(key=repr)
                    update_queue.append(bb.successor_if_jump)
                if bb.successor_if_fallthrough is not None:
                    # Calculate successor_if_fallthrough basic-block's set of PendingMem at entrance.
                    bb.successor_if_fallthrough.pending_mem_in = []
                    for pred in bb.successor_if_fallthrough.predecessors:
                        bb.successor_if_fallthrough.pending_mem_in += pred.pending_mem_out
                    bb.successor_if_fallthrough.pending_mem_in = list(set(bb.successor_if_fallthrough.pending_mem_in))  # unique # noqa E501: line too long
                    bb.successor_if_fallthrough.pending_mem_in.sort(key=repr)
                    update_queue.append(bb.successor_if_fallthrough)
        return True
This diff is collapsed.
from typing import Set, TextIO, Optional
from ..basic.utility import IndentedWriter
from ..basic.exception import check
from .base_pass import BasePass, PassTag
from .divide_basic_block_pass import DivideBasicBlockPassState, OptimizerState
import sys
class PrintBasicBlockPass(BasePass):
    """
    Debug pass that pretty-prints every basic block (with its liveness,
    pending-memory and clause annotations, when computed) followed by a global
    summary of register interference/allocation results.

    Purely read-only except for the sanity check; always returns False
    (the program is never modified).
    """
    def __init__(self, /, file: TextIO = sys.stdout, indent: int = 0, priority: int = PassTag.PrintBasicBlock.value):
        # By default, this pass should run in the very end
        super().__init__(priority)
        check(file is not None)
        check(indent >= 0)
        self.wr = IndentedWriter(file, indent)

    def required_tags(self) -> Set[PassTag]:
        return {PassTag.DivideBasicBlock}

    def generated_tags(self) -> Set[PassTag]:
        return {PassTag.PrintBasicBlock}

    def invalidated_tags(self) -> Set[PassTag]:
        return set()

    def reset(self, program):
        # This pass has nothing to reset
        pass

    def run(self, program) -> bool:
        """Print all basic blocks and the optimizer-state summary; returns False."""
        assert program.optimizer_state is not None
        assert program.optimizer_state.divide_basic_block is not None
        state = program.optimizer_state.divide_basic_block  # type: DivideBasicBlockPassState
        assert state is not None
        # Run a sanity check before printing
        state._sanity_check()

        # Get Gpr index (forced index takes precedence over assigned index)
        def get_gpr_index(gpr) -> Optional[int]:
            if gpr in program.forced_index:
                return program.forced_index[gpr]
            elif gpr in program.assigned_index:
                return program.assigned_index[gpr]
            else:
                return None
        self.wr()  # write an empty line at first
        for bb in state.basic_blocks:
            self.wr(f'{bb.name}:')
            with self.wr.indent():
                self.wr(f'//')
                self.wr(f'// bb_predecessors: {[pred.name for pred in bb.predecessors]}')
                self.wr(f'//')
                self.wr(f'// live_var_in: {bb.live_var_in if bb.live_var_in is not None else "- // not run"}')
                self.wr(f'// live_var_uses: {bb.live_var_uses if bb.live_var_uses is not None else "- // not run"}')
                if bb.live_var_gpr_life_span is not None:
                    clauses_and_jump_instr_count = len(bb.clauses)
                    if bb.jump_instr is not None:
                        clauses_and_jump_instr_count += 1
                    self.wr(f"//")
                    self.wr(f"// live_var_gpr_life_span:")
                    self.wr(f"// /* number of instructions: {clauses_and_jump_instr_count} */")
                    self.wr(f"// /* number of base Spec: {sum(1 for x in bb.live_var_gpr_life_span if x.rtype.is_special())} */")
                    self.wr(f"// /* number of base SGpr: {sum(1 for x in bb.live_var_gpr_life_span if x.rtype.is_sgpr())} */")
                    self.wr(f"// /* number of base VGpr: {sum(1 for x in bb.live_var_gpr_life_span if x.rtype.is_vgpr())} */")
                    for base_gpr, bitmap_list in sorted(bb.live_var_gpr_life_span.items(), key=repr):
                        # Render a life-span bitmap as 'x'/'_' per bit, LSB first,
                        # with a space every two bits (two bits per instruction).
                        def bitmap_to_string(bitmap: int):
                            s = ""
                            for idx in range(0, 2*clauses_and_jump_instr_count+2):
                                s += 'x' if (bitmap & 1) else '_'
                                bitmap >>= 1
                                if idx % 2 == 0: s += ' '
                            assert bitmap == 0
                            return s
                        self.wr(f"// {base_gpr}:")
                        for idx, bitmap in enumerate(bitmap_list):
                            self.wr(f"// [{idx:-2}] = {bitmap_to_string(bitmap)}")
                    self.wr(f"//")
                self.wr(f'//')
                self.wr(f'// pending_mem_in: {bb.pending_mem_in if bb.pending_mem_in is not None else "- // not run"}')
                self.wr(f'//')
                if bb.annotate_clauses is not None:
                    for annclause in bb.annotate_clauses:
                        self.wr(repr(annclause))
                else:
                    for clause in bb.clauses:
                        self.wr(repr(clause))
                self.wr(f'//')
                self.wr(f'// bb_successor_if_jump: {repr(bb.successor_if_jump.name) if bb.successor_if_jump is not None else "-"}')  # noqa E501: line too long
                self.wr(f'// bb_successor_if_fallthrough: {repr(bb.successor_if_fallthrough.name) if bb.successor_if_fallthrough is not None else "-"}')  # noqa E501: line too long
                self.wr(f'// bb_jump_instr: {bb.jump_instr.instr_name if bb.jump_instr is not None else "-"} // {bb.control_flow_at_exit.name}')  # noqa E501: line too long
                self.wr(f'//')
                self.wr(f'// live_var_defs: {bb.live_var_defs if bb.live_var_defs is not None else "- // not run"}')
                self.wr(f'// live_var_out: {bb.live_var_out if bb.live_var_out is not None else "- // not run"}')
                self.wr(f'//')
                self.wr(f'// pending_mem_out: {bb.pending_mem_out if bb.pending_mem_out is not None else "- // not run"}')  # noqa E501: line too long
                self.wr(f'//')
            self.wr()
        # All basic-blocks have been printed...
        # Let's print some global information
        self.wr("//" + "=" * 32)
        self.wr("//" + "SUMMARY".center(32))
        self.wr("//" + "=" * 32)
        state = program.optimizer_state  # type: OptimizerState
        # Print register interference if computed
        if state.register_interference_vgpr:
            register_interference_vgpr = state.register_interference_vgpr
            self.wr("//")
            self.wr("// VGpr interference:")
            for base_gpr in sorted(register_interference_vgpr.keys(), key=repr):
                self.wr(f"// {register_interference_vgpr[base_gpr]}")
            self.wr("//")
        if state.register_interference_sgpr:
            register_interference_sgpr = state.register_interference_sgpr
            self.wr("//")
            self.wr("// SGpr interference:")
            for base_gpr in sorted(register_interference_sgpr.keys(), key=repr):
                self.wr(f"// {register_interference_sgpr[base_gpr]}")
            self.wr("//")
        # Print register allocation if computed
        if state.register_allocation_vgpr_by_color:
            register_allocation_vgpr_by_color = state.register_allocation_vgpr_by_color
            self.wr("//")
            self.wr("// VGPR allocation by coloring:")
            for color, gpr_list in enumerate(register_allocation_vgpr_by_color):
                self.wr(f"// Color #{color} (count={max(gpr.count for gpr in gpr_list)}): {gpr_list}")
            self.wr("//")
        # BUGFIX: this guard previously tested `register_interference_sgpr`,
        # so printing crashed (iterating None) whenever SGpr interference was
        # computed but SGpr allocation was not. Mirror the VGPR guard above.
        if state.register_allocation_sgpr_by_color:
            register_allocation_sgpr_by_color = state.register_allocation_sgpr_by_color
            self.wr("//")
            self.wr("// SGPR allocation by coloring:")
            for color, gpr_list in enumerate(register_allocation_sgpr_by_color):
                self.wr(f"// Color #{color} (count={max(gpr.count for gpr in gpr_list)}): {gpr_list}")
            self.wr("//")
        if state.register_allocation_vgpr_count is not None:
            self.wr(f"// VGPR Count: {state.register_allocation_vgpr_count}")
        if state.register_allocation_sgpr_count is not None:
            self.wr(f"// SGPR Count: {state.register_allocation_sgpr_count}")
        self.wr()  # write an empty line at last
        return False
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment