Commit 400cac28 authored by coderfeli's avatar coderfeli
Browse files

Merge branch 'develop' of https://github.com/ROCm/composable_kernel into update_cka8w8

parents 7cec63a6 4c2eff02
......@@ -9,508 +9,509 @@
#endif
"s_mov_b32 s16, %[s_res_a0] \n"
"s_mov_b32 s17, %[s_res_a1] \n"
"s_mov_b32 s18, %[s_res_a2] \n"
"s_mov_b32 s19, %[s_res_a3] \n"
"s_mov_b32 s20, %[s_res_b0] \n"
"s_mov_b32 s21, %[s_res_b1] \n"
"s_mov_b32 s22, %[s_res_b2] \n"
"s_mov_b32 s23, %[s_res_b3] \n"
// "s_nop 4\n"
"; -- prefetch A0\n"
"s_add_u32 m0, 0, %[s_m0_init] \n"
"buffer_load_dword %[v_os_a0], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a1], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a2], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a3], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a4], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a5], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a6], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a7], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[smem_sz], %[s_m0_init] \n"
"s_cmp_gt_i32 %[s_loop_cnt] 1 ; move a with cond \n"
"s_cselect_b32 s86, %[s_tile_os_a], 0 ; move a with cond \n"
"s_add_u32 s16, s86, s16 ; move a with cond \n"
"s_addc_u32 s17, 0, s17 ; move a with cond \n"
"; -- prefetch A1\n"
"buffer_load_dword %[v_os_a0], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a1], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a2], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a3], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a4], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a5], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a6], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a7], s[16:19], 0 offen lds \n"
"s_add_u32 m0, 0, %[s_m0_init] \n"
"s_cmp_gt_i32 %[s_loop_cnt] 2 ; move a with cond \n"
"s_cselect_b32 s86, %[s_tile_os_a], 0 ; move a with cond \n"
"s_add_u32 s16, s86, s16 ; move a with cond \n"
"s_addc_u32 s17, 0, s17 ; move a with cond \n"
"; -- prefetch B0\n"
"buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[20:23], 0 offen \n"
"buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[20:23], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[20:23], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[20:23], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[20:23], 0 offen \n"
"buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[20:23], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[20:23], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[20:23], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[20:23], 0 offen \n"
"buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[20:23], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[20:23], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[20:23], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[20:23], 0 offen \n"
"buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[20:23], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[20:23], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[20:23], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[64:67], %[v_os_b4], s[20:23], 0 offen \n"
"buffer_load_dwordx4 acc[68:71], %[v_os_b4], s[20:23], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[72:75], %[v_os_b4], s[20:23], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[76:79], %[v_os_b4], s[20:23], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[80:83], %[v_os_b5], s[20:23], 0 offen \n"
"buffer_load_dwordx4 acc[84:87], %[v_os_b5], s[20:23], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[88:91], %[v_os_b5], s[20:23], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[92:95], %[v_os_b5], s[20:23], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[96:99], %[v_os_b6], s[20:23], 0 offen \n"
"buffer_load_dwordx4 acc[100:103], %[v_os_b6], s[20:23], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[104:107], %[v_os_b6], s[20:23], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[108:111], %[v_os_b6], s[20:23], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[112:115], %[v_os_b7], s[20:23], 0 offen \n"
"buffer_load_dwordx4 acc[116:119], %[v_os_b7], s[20:23], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[20:23], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[20:23], 0 offen offset:3072 \n"
"s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n"
"s_cselect_b32 s86, %[s_tile_os_b], 0 ; move b with cond \n"
"s_add_u32 s20, s86, s20 ; move b with cond \n"
"s_addc_u32 s21, 0, s21 ; move b with cond \n"
"s_waitcnt vmcnt(40) \n"
"s_barrier \n"
"ds_read_b128 v[64:67], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_0]\n" // 1024: N stride, 64 K stride
"ds_read_b128 v[68:71], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_1]\n"
"ds_read_b128 v[72:75], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_2]\n"
"ds_read_b128 v[76:79], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_3]\n"
"ds_read_b128 v[80:83], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_4]\n"
"ds_read_b128 v[84:87], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_5]\n"
"ds_read_b128 v[88:91], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_6]\n"
"ds_read_b128 v[92:95], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_7]\n"
"L_start%=: \n"
" s_waitcnt vmcnt(24) & lgkmcnt(0) \n"
" s_barrier \n"
_UK_MFMA_ " %[v_acc_0], acc[0:1], v[64:65], %[v_acc_0] \n"
_UK_MFMA_ " %[v_acc_0], acc[2:3], v[66:67], %[v_acc_0] \n"
" buffer_load_dwordx4 acc[128:131], %[v_os_b0], s[20:23], 0 offen \n"
_UK_MFMA_ " %[v_acc_0], acc[4:5], v[68:69], %[v_acc_0] \n"
_UK_MFMA_ " %[v_acc_0], acc[6:7], v[70:71], %[v_acc_0] \n"
" buffer_load_dword %[v_os_a0], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
_UK_MFMA_ " %[v_acc_0], acc[8:9], v[72:73], %[v_acc_0] \n"
_UK_MFMA_ " %[v_acc_0], acc[10:11], v[74:75], %[v_acc_0] \n"
" buffer_load_dwordx4 acc[132:135], %[v_os_b0], s[20:23], 0 offen offset:1024 \n"
_UK_MFMA_ " %[v_acc_0], acc[12:13], v[76:77], %[v_acc_0] \n"
_UK_MFMA_ " %[v_acc_0], acc[14:15], v[78:79], %[v_acc_0] \n"
" buffer_load_dword %[v_os_a1], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
_UK_MFMA_ " %[v_acc_1], acc[0:1], v[80:81], %[v_acc_1] \n"
_UK_MFMA_ " %[v_acc_1], acc[2:3], v[82:83], %[v_acc_1] \n"
" buffer_load_dwordx4 acc[136:139], %[v_os_b0], s[20:23], 0 offen offset:2048 \n"
_UK_MFMA_ " %[v_acc_1], acc[4:5], v[84:85], %[v_acc_1] \n"
_UK_MFMA_ " %[v_acc_1], acc[6:7], v[86:87], %[v_acc_1] \n"
" buffer_load_dword %[v_os_a2], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
_UK_MFMA_ " %[v_acc_1], acc[8:9], v[88:89], %[v_acc_1] \n"
_UK_MFMA_ " %[v_acc_1], acc[10:11], v[90:91], %[v_acc_1] \n"
" buffer_load_dwordx4 acc[140:143], %[v_os_b0], s[20:23], 0 offen offset:3072 \n"
_UK_MFMA_ " %[v_acc_1], acc[12:13], v[92:93], %[v_acc_1] \n"
_UK_MFMA_ " %[v_acc_1], acc[14:15], v[94:95], %[v_acc_1] \n"
" buffer_load_dword %[v_os_a3], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
_UK_MFMA_ " %[v_acc_2], acc[16:17], v[64:65], %[v_acc_2] \n"
_UK_MFMA_ " %[v_acc_2], acc[18:19], v[66:67], %[v_acc_2] \n"
" buffer_load_dwordx4 acc[144:147], %[v_os_b1], s[20:23], 0 offen \n"
_UK_MFMA_ " %[v_acc_2], acc[20:21], v[68:69], %[v_acc_2] \n"
_UK_MFMA_ " %[v_acc_2], acc[22:23], v[70:71], %[v_acc_2] \n"
" buffer_load_dword %[v_os_a4], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
_UK_MFMA_ " %[v_acc_2], acc[24:25], v[72:73], %[v_acc_2] \n"
_UK_MFMA_ " %[v_acc_2], acc[26:27], v[74:75], %[v_acc_2] \n"
" buffer_load_dwordx4 acc[148:151], %[v_os_b1], s[20:23], 0 offen offset:1024 \n"
_UK_MFMA_ " %[v_acc_2], acc[28:29], v[76:77], %[v_acc_2] \n"
_UK_MFMA_ " %[v_acc_2], acc[30:31], v[78:79], %[v_acc_2] \n"
" buffer_load_dword %[v_os_a5], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
_UK_MFMA_ " %[v_acc_3], acc[16:17], v[80:81], %[v_acc_3] \n"
_UK_MFMA_ " %[v_acc_3], acc[18:19], v[82:83], %[v_acc_3] \n"
" buffer_load_dwordx4 acc[152:155], %[v_os_b1], s[20:23], 0 offen offset:2048 \n"
_UK_MFMA_ " %[v_acc_3], acc[20:21], v[84:85], %[v_acc_3] \n"
_UK_MFMA_ " %[v_acc_3], acc[22:23], v[86:87], %[v_acc_3] \n"
" buffer_load_dword %[v_os_a6], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
_UK_MFMA_ " %[v_acc_3], acc[24:25], v[88:89], %[v_acc_3] \n"
_UK_MFMA_ " %[v_acc_3], acc[26:27], v[90:91], %[v_acc_3] \n"
" buffer_load_dwordx4 acc[156:159], %[v_os_b1], s[20:23], 0 offen offset:3072 \n"
_UK_MFMA_ " %[v_acc_3], acc[28:29], v[92:93], %[v_acc_3] \n"
_UK_MFMA_ " %[v_acc_3], acc[30:31], v[94:95], %[v_acc_3] \n"
" buffer_load_dword %[v_os_a7], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[smem_sz], %[s_m0_init] \n"
" s_waitcnt vmcnt(32) \n"
_UK_MFMA_ " %[v_acc_4], acc[32:33], v[64:65], %[v_acc_4] \n"
_UK_MFMA_ " %[v_acc_4], acc[34:35], v[66:67], %[v_acc_4] \n"
" buffer_load_dwordx4 acc[160:163], %[v_os_b2], s[20:23], 0 offen \n"
_UK_MFMA_ " %[v_acc_4], acc[36:37], v[68:69], %[v_acc_4] \n"
_UK_MFMA_ " %[v_acc_4], acc[38:39], v[70:71], %[v_acc_4] \n"
" ds_read_b128 v[96:99], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_0] \n"
_UK_MFMA_ " %[v_acc_4], acc[40:41], v[72:73], %[v_acc_4] \n"
_UK_MFMA_ " %[v_acc_4], acc[42:43], v[74:75], %[v_acc_4] \n"
" buffer_load_dwordx4 acc[164:167], %[v_os_b2], s[20:23], 0 offen offset:1024 \n"
_UK_MFMA_ " %[v_acc_4], acc[44:45], v[76:77], %[v_acc_4] \n"
_UK_MFMA_ " %[v_acc_4], acc[46:47], v[78:79], %[v_acc_4] \n"
" ds_read_b128 v[100:103], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_1] \n"
_UK_MFMA_ " %[v_acc_5], acc[32:33], v[80:81], %[v_acc_5] \n"
_UK_MFMA_ " %[v_acc_5], acc[34:35], v[82:83], %[v_acc_5] \n"
" buffer_load_dwordx4 acc[168:171], %[v_os_b2], s[20:23], 0 offen offset:2048 \n"
_UK_MFMA_ " %[v_acc_5], acc[36:37], v[84:85], %[v_acc_5] \n"
_UK_MFMA_ " %[v_acc_5], acc[38:39], v[86:87], %[v_acc_5] \n"
" ds_read_b128 v[104:107], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_2] \n"
_UK_MFMA_ " %[v_acc_5], acc[40:41], v[88:89], %[v_acc_5] \n"
_UK_MFMA_ " %[v_acc_5], acc[42:43], v[90:91], %[v_acc_5] \n"
" buffer_load_dwordx4 acc[172:175], %[v_os_b2], s[20:23], 0 offen offset:3072 \n"
_UK_MFMA_ " %[v_acc_5], acc[44:45], v[92:93], %[v_acc_5] \n"
_UK_MFMA_ " %[v_acc_5], acc[46:47], v[94:95], %[v_acc_5] \n"
" ds_read_b128 v[108:111], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_3] \n"
_UK_MFMA_ " %[v_acc_6], acc[48:49], v[64:65], %[v_acc_6] \n"
_UK_MFMA_ " %[v_acc_6], acc[50:51], v[66:67], %[v_acc_6] \n"
" buffer_load_dwordx4 acc[176:179], %[v_os_b3], s[20:23], 0 offen \n"
_UK_MFMA_ " %[v_acc_6], acc[52:53], v[68:69], %[v_acc_6] \n"
_UK_MFMA_ " %[v_acc_6], acc[54:55], v[70:71], %[v_acc_6] \n"
" ds_read_b128 v[112:115], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_4] \n"
_UK_MFMA_ " %[v_acc_6], acc[56:57], v[72:73], %[v_acc_6] \n"
_UK_MFMA_ " %[v_acc_6], acc[58:59], v[74:75], %[v_acc_6] \n"
" buffer_load_dwordx4 acc[180:183], %[v_os_b3], s[20:23], 0 offen offset:1024 \n"
_UK_MFMA_ " %[v_acc_6], acc[60:61], v[76:77], %[v_acc_6] \n"
_UK_MFMA_ " %[v_acc_6], acc[62:63], v[78:79], %[v_acc_6] \n"
" ds_read_b128 v[116:119], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_5] \n"
_UK_MFMA_ " %[v_acc_7], acc[48:49], v[80:81], %[v_acc_7] \n"
_UK_MFMA_ " %[v_acc_7], acc[50:51], v[82:83], %[v_acc_7] \n"
" buffer_load_dwordx4 acc[184:187], %[v_os_b3], s[20:23], 0 offen offset:2048 \n"
_UK_MFMA_ " %[v_acc_7], acc[52:53], v[84:85], %[v_acc_7] \n"
_UK_MFMA_ " %[v_acc_7], acc[54:55], v[86:87], %[v_acc_7] \n"
" ds_read_b128 v[120:123], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_6] \n"
_UK_MFMA_ " %[v_acc_7], acc[56:57], v[88:89], %[v_acc_7] \n"
_UK_MFMA_ " %[v_acc_7], acc[58:59], v[90:91], %[v_acc_7] \n"
" buffer_load_dwordx4 acc[188:191], %[v_os_b3], s[20:23], 0 offen offset:3072 \n"
_UK_MFMA_ " %[v_acc_7], acc[60:61], v[92:93], %[v_acc_7] \n"
_UK_MFMA_ " %[v_acc_7], acc[62:63], v[94:95], %[v_acc_7] \n"
" ds_read_b128 v[124:127], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_7] \n"
" s_waitcnt vmcnt(32) \n"
_UK_MFMA_ " %[v_acc_8], acc[64:65], v[64:65], %[v_acc_8] \n"
_UK_MFMA_ " %[v_acc_8], acc[66:67], v[66:67], %[v_acc_8] \n"
" buffer_load_dwordx4 acc[192:195], %[v_os_b4], s[20:23], 0 offen \n"
_UK_MFMA_ " %[v_acc_8], acc[68:69], v[68:69], %[v_acc_8] \n"
_UK_MFMA_ " %[v_acc_8], acc[70:71], v[70:71], %[v_acc_8] \n"
_UK_MFMA_ " %[v_acc_8], acc[72:73], v[72:73], %[v_acc_8] \n"
_UK_MFMA_ " %[v_acc_8], acc[74:75], v[74:75], %[v_acc_8] \n"
" buffer_load_dwordx4 acc[196:199], %[v_os_b4], s[20:23], 0 offen offset:1024 \n"
_UK_MFMA_ " %[v_acc_8], acc[76:77], v[76:77], %[v_acc_8] \n"
_UK_MFMA_ " %[v_acc_8], acc[78:79], v[78:79], %[v_acc_8] \n"
_UK_MFMA_ " %[v_acc_9], acc[64:65], v[80:81], %[v_acc_9] \n"
_UK_MFMA_ " %[v_acc_9], acc[66:67], v[82:83], %[v_acc_9] \n"
" buffer_load_dwordx4 acc[200:203], %[v_os_b4], s[20:23], 0 offen offset:2048 \n"
_UK_MFMA_ " %[v_acc_9], acc[68:69], v[84:85], %[v_acc_9] \n"
_UK_MFMA_ " %[v_acc_9], acc[70:71], v[86:87], %[v_acc_9] \n"
_UK_MFMA_ " %[v_acc_9], acc[72:73], v[88:89], %[v_acc_9] \n"
_UK_MFMA_ " %[v_acc_9], acc[74:75], v[90:91], %[v_acc_9] \n"
" buffer_load_dwordx4 acc[204:207], %[v_os_b4], s[20:23], 0 offen offset:3072 \n"
_UK_MFMA_ " %[v_acc_9], acc[76:77], v[92:93], %[v_acc_9] \n"
_UK_MFMA_ " %[v_acc_9], acc[78:79], v[94:95], %[v_acc_9] \n"
_UK_MFMA_ " %[v_acc_10], acc[80:81], v[64:65], %[v_acc_10] \n"
_UK_MFMA_ " %[v_acc_10], acc[82:83], v[66:67], %[v_acc_10] \n"
" buffer_load_dwordx4 acc[208:211], %[v_os_b5], s[20:23], 0 offen \n"
_UK_MFMA_ " %[v_acc_10], acc[84:85], v[68:69], %[v_acc_10] \n"
_UK_MFMA_ " %[v_acc_10], acc[86:87], v[70:71], %[v_acc_10] \n"
_UK_MFMA_ " %[v_acc_10], acc[88:89], v[72:73], %[v_acc_10] \n"
_UK_MFMA_ " %[v_acc_10], acc[90:91], v[74:75], %[v_acc_10] \n"
" buffer_load_dwordx4 acc[212:215], %[v_os_b5], s[20:23], 0 offen offset:1024 \n"
_UK_MFMA_ " %[v_acc_10], acc[92:93], v[76:77], %[v_acc_10] \n"
_UK_MFMA_ " %[v_acc_10], acc[94:95], v[78:79], %[v_acc_10] \n"
_UK_MFMA_ " %[v_acc_11], acc[80:81], v[80:81], %[v_acc_11] \n"
_UK_MFMA_ " %[v_acc_11], acc[82:83], v[82:83], %[v_acc_11] \n"
" buffer_load_dwordx4 acc[216:219], %[v_os_b5], s[20:23], 0 offen offset:2048 \n"
_UK_MFMA_ " %[v_acc_11], acc[84:85], v[84:85], %[v_acc_11] \n"
_UK_MFMA_ " %[v_acc_11], acc[86:87], v[86:87], %[v_acc_11] \n"
_UK_MFMA_ " %[v_acc_11], acc[88:89], v[88:89], %[v_acc_11] \n"
_UK_MFMA_ " %[v_acc_11], acc[90:91], v[90:91], %[v_acc_11] \n"
" buffer_load_dwordx4 acc[220:223], %[v_os_b5], s[20:23], 0 offen offset:3072 \n"
_UK_MFMA_ " %[v_acc_11], acc[92:93], v[92:93], %[v_acc_11] \n"
_UK_MFMA_ " %[v_acc_11], acc[94:95], v[94:95], %[v_acc_11] \n"
" s_waitcnt vmcnt(32) \n"
_UK_MFMA_ " %[v_acc_12], acc[96:97], v[64:65], %[v_acc_12] \n"
_UK_MFMA_ " %[v_acc_12], acc[98:99], v[66:67], %[v_acc_12] \n"
" buffer_load_dwordx4 acc[224:227], %[v_os_b6], s[20:23], 0 offen \n"
_UK_MFMA_ " %[v_acc_12], acc[100:101], v[68:69], %[v_acc_12] \n"
_UK_MFMA_ " %[v_acc_12], acc[102:103], v[70:71], %[v_acc_12] \n"
_UK_MFMA_ " %[v_acc_12], acc[104:105], v[72:73], %[v_acc_12] \n"
_UK_MFMA_ " %[v_acc_12], acc[106:107], v[74:75], %[v_acc_12] \n"
" buffer_load_dwordx4 acc[228:231], %[v_os_b6], s[20:23], 0 offen offset:1024 \n"
_UK_MFMA_ " %[v_acc_12], acc[108:109], v[76:77], %[v_acc_12] \n"
_UK_MFMA_ " %[v_acc_12], acc[110:111], v[78:79], %[v_acc_12] \n"
_UK_MFMA_ " %[v_acc_13], acc[96:97], v[80:81], %[v_acc_13] \n"
_UK_MFMA_ " %[v_acc_13], acc[98:99], v[82:83], %[v_acc_13] \n"
" buffer_load_dwordx4 acc[232:235], %[v_os_b6], s[20:23], 0 offen offset:2048 \n"
_UK_MFMA_ " %[v_acc_13], acc[100:101], v[84:85], %[v_acc_13] \n"
_UK_MFMA_ " %[v_acc_13], acc[102:103], v[86:87], %[v_acc_13] \n"
_UK_MFMA_ " %[v_acc_13], acc[104:105], v[88:89], %[v_acc_13] \n"
_UK_MFMA_ " %[v_acc_13], acc[106:107], v[90:91], %[v_acc_13] \n"
" buffer_load_dwordx4 acc[236:239], %[v_os_b6], s[20:23], 0 offen offset:3072 \n"
_UK_MFMA_ " %[v_acc_13], acc[108:109], v[92:93], %[v_acc_13] \n"
_UK_MFMA_ " %[v_acc_13], acc[110:111], v[94:95], %[v_acc_13] \n"
_UK_MFMA_ " %[v_acc_14], acc[112:113], v[64:65], %[v_acc_14] \n"
_UK_MFMA_ " %[v_acc_14], acc[114:115], v[66:67], %[v_acc_14] \n"
" buffer_load_dwordx4 acc[240:243], %[v_os_b7], s[20:23], 0 offen \n"
_UK_MFMA_ " %[v_acc_14], acc[116:117], v[68:69], %[v_acc_14] \n"
_UK_MFMA_ " %[v_acc_14], acc[118:119], v[70:71], %[v_acc_14] \n"
_UK_MFMA_ " %[v_acc_14], acc[120:121], v[72:73], %[v_acc_14] \n"
_UK_MFMA_ " %[v_acc_14], acc[122:123], v[74:75], %[v_acc_14] \n"
" buffer_load_dwordx4 acc[244:247], %[v_os_b7], s[20:23], 0 offen offset:1024 \n"
_UK_MFMA_ " %[v_acc_14], acc[124:125], v[76:77], %[v_acc_14] \n"
_UK_MFMA_ " %[v_acc_14], acc[126:127], v[78:79], %[v_acc_14] \n"
_UK_MFMA_ " %[v_acc_15], acc[112:113], v[80:81], %[v_acc_15] \n"
_UK_MFMA_ " %[v_acc_15], acc[114:115], v[82:83], %[v_acc_15] \n"
" buffer_load_dwordx4 acc[248:251], %[v_os_b7], s[20:23], 0 offen offset:2048 \n"
_UK_MFMA_ " %[v_acc_15], acc[116:117], v[84:85], %[v_acc_15] \n"
_UK_MFMA_ " %[v_acc_15], acc[118:119], v[86:87], %[v_acc_15] \n"
_UK_MFMA_ " %[v_acc_15], acc[120:121], v[88:89], %[v_acc_15] \n"
_UK_MFMA_ " %[v_acc_15], acc[122:123], v[90:91], %[v_acc_15] \n"
" buffer_load_dwordx4 acc[252:255], %[v_os_b7], s[20:23], 0 offen offset:3072\n"
_UK_MFMA_ " %[v_acc_15], acc[124:125], v[92:93], %[v_acc_15] \n"
_UK_MFMA_ " %[v_acc_15], acc[126:127], v[94:95], %[v_acc_15] \n"
" s_sub_i32 %[s_loop_cnt], %[s_loop_cnt], 1 \n"
" s_cmp_gt_i32 %[s_loop_cnt] 0 \n"
" s_cbranch_scc0 L_end%= \n"
" s_cmp_gt_i32 %[s_loop_cnt] 2 ; move a with cond \n"
" s_cselect_b32 s86, %[s_tile_os_a], 0 \n"
" s_add_u32 s16, s86, s16 \n"
" s_addc_u32 s17, 0, s17 \n"
" s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n"
" s_cselect_b32 s86, %[s_tile_os_b], 0 \n"
" s_add_u32 s20, s86, s20 \n"
" s_addc_u32 s21, 0, s21 \n"
" ;------------------------------------------ \n"
" s_waitcnt vmcnt(24) & lgkmcnt(0) \n"
" s_barrier \n"
_UK_MFMA_ " %[v_acc_0], acc[128:129], v[96:97], %[v_acc_0] \n"
_UK_MFMA_ " %[v_acc_0], acc[130:131], v[98:99], %[v_acc_0] \n"
" buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[20:23], 0 offen \n"
_UK_MFMA_ " %[v_acc_0], acc[132:133], v[100:101], %[v_acc_0] \n"
_UK_MFMA_ " %[v_acc_0], acc[134:135], v[102:103], %[v_acc_0] \n"
" buffer_load_dword %[v_os_a0], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
_UK_MFMA_ " %[v_acc_0], acc[136:137], v[104:105], %[v_acc_0] \n"
_UK_MFMA_ " %[v_acc_0], acc[138:139], v[106:107], %[v_acc_0] \n"
" buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[20:23], 0 offen offset:1024 \n"
_UK_MFMA_ " %[v_acc_0], acc[140:141], v[108:109], %[v_acc_0] \n"
_UK_MFMA_ " %[v_acc_0], acc[142:143], v[110:111], %[v_acc_0] \n"
" buffer_load_dword %[v_os_a1], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
_UK_MFMA_ " %[v_acc_1], acc[128:129], v[112:113], %[v_acc_1] \n"
_UK_MFMA_ " %[v_acc_1], acc[130:131], v[114:115], %[v_acc_1] \n"
" buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[20:23], 0 offen offset:2048 \n"
_UK_MFMA_ " %[v_acc_1], acc[132:133], v[116:117], %[v_acc_1] \n"
_UK_MFMA_ " %[v_acc_1], acc[134:135], v[118:119], %[v_acc_1] \n"
" buffer_load_dword %[v_os_a2], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
_UK_MFMA_ " %[v_acc_1], acc[136:137], v[120:121], %[v_acc_1] \n"
_UK_MFMA_ " %[v_acc_1], acc[138:139], v[122:123], %[v_acc_1] \n"
" buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[20:23], 0 offen offset:3072 \n"
_UK_MFMA_ " %[v_acc_1], acc[140:141], v[124:125], %[v_acc_1] \n"
_UK_MFMA_ " %[v_acc_1], acc[142:143], v[126:127], %[v_acc_1] \n"
" buffer_load_dword %[v_os_a3], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
_UK_MFMA_ " %[v_acc_2], acc[144:145], v[96:97], %[v_acc_2] \n"
_UK_MFMA_ " %[v_acc_2], acc[146:147], v[98:99], %[v_acc_2] \n"
" buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[20:23], 0 offen \n"
_UK_MFMA_ " %[v_acc_2], acc[148:149], v[100:101], %[v_acc_2] \n"
_UK_MFMA_ " %[v_acc_2], acc[150:151], v[102:103], %[v_acc_2] \n"
" buffer_load_dword %[v_os_a4], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
_UK_MFMA_ " %[v_acc_2], acc[152:153], v[104:105], %[v_acc_2] \n"
_UK_MFMA_ " %[v_acc_2], acc[154:155], v[106:107], %[v_acc_2] \n"
" buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[20:23], 0 offen offset:1024 \n"
_UK_MFMA_ " %[v_acc_2], acc[156:157], v[108:109], %[v_acc_2] \n"
_UK_MFMA_ " %[v_acc_2], acc[158:159], v[110:111], %[v_acc_2] \n"
" buffer_load_dword %[v_os_a5], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
_UK_MFMA_ " %[v_acc_3], acc[144:145], v[112:113], %[v_acc_3] \n"
_UK_MFMA_ " %[v_acc_3], acc[146:147], v[114:115], %[v_acc_3] \n"
" buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[20:23], 0 offen offset:2048 \n"
_UK_MFMA_ " %[v_acc_3], acc[148:149], v[116:117], %[v_acc_3] \n"
_UK_MFMA_ " %[v_acc_3], acc[150:151], v[118:119], %[v_acc_3] \n"
" buffer_load_dword %[v_os_a6], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
_UK_MFMA_ " %[v_acc_3], acc[152:153], v[120:121], %[v_acc_3] \n"
_UK_MFMA_ " %[v_acc_3], acc[154:155], v[122:123], %[v_acc_3] \n"
" buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[20:23], 0 offen offset:3072 \n"
_UK_MFMA_ " %[v_acc_3], acc[156:157], v[124:125], %[v_acc_3] \n"
_UK_MFMA_ " %[v_acc_3], acc[158:159], v[126:127], %[v_acc_3] \n"
" buffer_load_dword %[v_os_a7], s[16:19], 0 offen lds \n"
" s_add_u32 m0, 0, %[s_m0_init] \n"
" s_waitcnt vmcnt(32) \n"
_UK_MFMA_ " %[v_acc_4], acc[160:161], v[96:97], %[v_acc_4] \n"
_UK_MFMA_ " %[v_acc_4], acc[162:163], v[98:99], %[v_acc_4] \n"
" buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[20:23], 0 offen \n"
_UK_MFMA_ " %[v_acc_4], acc[164:165], v[100:101], %[v_acc_4] \n"
_UK_MFMA_ " %[v_acc_4], acc[166:167], v[102:103], %[v_acc_4] \n"
" ds_read_b128 v[64:67], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_0] \n"
_UK_MFMA_ " %[v_acc_4], acc[168:169], v[104:105], %[v_acc_4] \n"
_UK_MFMA_ " %[v_acc_4], acc[170:171], v[106:107], %[v_acc_4] \n"
" buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[20:23], 0 offen offset:1024 \n"
_UK_MFMA_ " %[v_acc_4], acc[172:173], v[108:109], %[v_acc_4] \n"
_UK_MFMA_ " %[v_acc_4], acc[174:175], v[110:111], %[v_acc_4] \n"
" ds_read_b128 v[68:71], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_1] \n"
_UK_MFMA_ " %[v_acc_5], acc[160:161], v[112:113], %[v_acc_5] \n"
_UK_MFMA_ " %[v_acc_5], acc[162:163], v[114:115], %[v_acc_5] \n"
" buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[20:23], 0 offen offset:2048 \n"
_UK_MFMA_ " %[v_acc_5], acc[164:165], v[116:117], %[v_acc_5] \n"
_UK_MFMA_ " %[v_acc_5], acc[166:167], v[118:119], %[v_acc_5] \n"
" ds_read_b128 v[72:75], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_2] \n"
_UK_MFMA_ " %[v_acc_5], acc[168:169], v[120:121], %[v_acc_5] \n"
_UK_MFMA_ " %[v_acc_5], acc[170:171], v[122:123], %[v_acc_5] \n"
" buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[20:23], 0 offen offset:3072 \n"
_UK_MFMA_ " %[v_acc_5], acc[172:173], v[124:125], %[v_acc_5] \n"
_UK_MFMA_ " %[v_acc_5], acc[174:175], v[126:127], %[v_acc_5] \n"
" ds_read_b128 v[76:79], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_3] \n"
_UK_MFMA_ " %[v_acc_6], acc[176:177], v[96:97], %[v_acc_6] \n"
_UK_MFMA_ " %[v_acc_6], acc[178:179], v[98:99], %[v_acc_6] \n"
" buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[20:23], 0 offen \n"
_UK_MFMA_ " %[v_acc_6], acc[180:181], v[100:101], %[v_acc_6] \n"
_UK_MFMA_ " %[v_acc_6], acc[182:183], v[102:103], %[v_acc_6] \n"
" ds_read_b128 v[80:83], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_4] \n"
_UK_MFMA_ " %[v_acc_6], acc[184:185], v[104:105], %[v_acc_6] \n"
_UK_MFMA_ " %[v_acc_6], acc[186:187], v[106:107], %[v_acc_6] \n"
" buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[20:23], 0 offen offset:1024 \n"
_UK_MFMA_ " %[v_acc_6], acc[188:189], v[108:109], %[v_acc_6] \n"
_UK_MFMA_ " %[v_acc_6], acc[190:191], v[110:111], %[v_acc_6] \n"
" ds_read_b128 v[84:87], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_5] \n"
_UK_MFMA_ " %[v_acc_7], acc[176:177], v[112:113], %[v_acc_7] \n"
_UK_MFMA_ " %[v_acc_7], acc[178:179], v[114:115], %[v_acc_7] \n"
" buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[20:23], 0 offen offset:2048 \n"
_UK_MFMA_ " %[v_acc_7], acc[180:181], v[116:117], %[v_acc_7] \n"
_UK_MFMA_ " %[v_acc_7], acc[182:183], v[118:119], %[v_acc_7] \n"
" ds_read_b128 v[88:91], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_6] \n"
_UK_MFMA_ " %[v_acc_7], acc[184:185], v[120:121], %[v_acc_7] \n"
_UK_MFMA_ " %[v_acc_7], acc[186:187], v[122:123], %[v_acc_7] \n"
" buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[20:23], 0 offen offset:3072 \n"
_UK_MFMA_ " %[v_acc_7], acc[188:189], v[124:125], %[v_acc_7] \n"
_UK_MFMA_ " %[v_acc_7], acc[190:191], v[126:127], %[v_acc_7] \n"
" ds_read_b128 v[92:95], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_7] \n"
" s_waitcnt vmcnt(32) \n"
_UK_MFMA_ " %[v_acc_8], acc[192:193], v[96:97], %[v_acc_8] \n"
_UK_MFMA_ " %[v_acc_8], acc[194:195], v[98:99], %[v_acc_8] \n"
" buffer_load_dwordx4 acc[64:67], %[v_os_b4], s[20:23], 0 offen \n"
_UK_MFMA_ " %[v_acc_8], acc[196:197], v[100:101], %[v_acc_8] \n"
_UK_MFMA_ " %[v_acc_8], acc[198:199], v[102:103], %[v_acc_8] \n"
_UK_MFMA_ " %[v_acc_8], acc[200:201], v[104:105], %[v_acc_8] \n"
_UK_MFMA_ " %[v_acc_8], acc[202:203], v[106:107], %[v_acc_8] \n"
" buffer_load_dwordx4 acc[68:71], %[v_os_b4], s[20:23], 0 offen offset:1024 \n"
_UK_MFMA_ " %[v_acc_8], acc[204:205], v[108:109], %[v_acc_8] \n"
_UK_MFMA_ " %[v_acc_8], acc[206:207], v[110:111], %[v_acc_8] \n"
_UK_MFMA_ " %[v_acc_9], acc[192:193], v[112:113], %[v_acc_9] \n"
_UK_MFMA_ " %[v_acc_9], acc[194:195], v[114:115], %[v_acc_9] \n"
" buffer_load_dwordx4 acc[72:75], %[v_os_b4], s[20:23], 0 offen offset:2048 \n"
_UK_MFMA_ " %[v_acc_9], acc[196:197], v[116:117], %[v_acc_9] \n"
_UK_MFMA_ " %[v_acc_9], acc[198:199], v[118:119], %[v_acc_9] \n"
_UK_MFMA_ " %[v_acc_9], acc[200:201], v[120:121], %[v_acc_9] \n"
_UK_MFMA_ " %[v_acc_9], acc[202:203], v[122:123], %[v_acc_9] \n"
" buffer_load_dwordx4 acc[76:79], %[v_os_b4], s[20:23], 0 offen offset:3072 \n"
_UK_MFMA_ " %[v_acc_9], acc[204:205], v[124:125], %[v_acc_9] \n"
_UK_MFMA_ " %[v_acc_9], acc[206:207], v[126:127], %[v_acc_9] \n"
_UK_MFMA_ " %[v_acc_10], acc[208:209], v[96:97], %[v_acc_10] \n"
_UK_MFMA_ " %[v_acc_10], acc[210:211], v[98:99], %[v_acc_10] \n"
" buffer_load_dwordx4 acc[80:83], %[v_os_b5], s[20:23], 0 offen \n"
_UK_MFMA_ " %[v_acc_10], acc[212:213], v[100:101], %[v_acc_10] \n"
_UK_MFMA_ " %[v_acc_10], acc[214:215], v[102:103], %[v_acc_10] \n"
_UK_MFMA_ " %[v_acc_10], acc[216:217], v[104:105], %[v_acc_10] \n"
_UK_MFMA_ " %[v_acc_10], acc[218:219], v[106:107], %[v_acc_10] \n"
" buffer_load_dwordx4 acc[84:87], %[v_os_b5], s[20:23], 0 offen offset:1024 \n"
_UK_MFMA_ " %[v_acc_10], acc[220:221], v[108:109], %[v_acc_10] \n"
_UK_MFMA_ " %[v_acc_10], acc[222:223], v[110:111], %[v_acc_10] \n"
_UK_MFMA_ " %[v_acc_11], acc[208:209], v[112:113], %[v_acc_11] \n"
_UK_MFMA_ " %[v_acc_11], acc[210:211], v[114:115], %[v_acc_11] \n"
" buffer_load_dwordx4 acc[88:91], %[v_os_b5], s[20:23], 0 offen offset:2048 \n"
_UK_MFMA_ " %[v_acc_11], acc[212:213], v[116:117], %[v_acc_11] \n"
_UK_MFMA_ " %[v_acc_11], acc[214:215], v[118:119], %[v_acc_11] \n"
_UK_MFMA_ " %[v_acc_11], acc[216:217], v[120:121], %[v_acc_11] \n"
_UK_MFMA_ " %[v_acc_11], acc[218:219], v[122:123], %[v_acc_11] \n"
" buffer_load_dwordx4 acc[92:95], %[v_os_b5], s[20:23], 0 offen offset:3072 \n"
_UK_MFMA_ " %[v_acc_11], acc[220:221], v[124:125], %[v_acc_11] \n"
_UK_MFMA_ " %[v_acc_11], acc[222:223], v[126:127], %[v_acc_11] \n"
" s_waitcnt vmcnt(32) \n"
_UK_MFMA_ " %[v_acc_12], acc[224:225], v[96:97], %[v_acc_12] \n"
_UK_MFMA_ " %[v_acc_12], acc[226:227], v[98:99], %[v_acc_12] \n"
" buffer_load_dwordx4 acc[96:99], %[v_os_b6], s[20:23], 0 offen \n"
_UK_MFMA_ " %[v_acc_12], acc[228:229], v[100:101], %[v_acc_12] \n"
_UK_MFMA_ " %[v_acc_12], acc[230:231], v[102:103], %[v_acc_12] \n"
_UK_MFMA_ " %[v_acc_12], acc[232:233], v[104:105], %[v_acc_12] \n"
_UK_MFMA_ " %[v_acc_12], acc[234:235], v[106:107], %[v_acc_12] \n"
" buffer_load_dwordx4 acc[100:103], %[v_os_b6], s[20:23], 0 offen offset:1024 \n"
_UK_MFMA_ " %[v_acc_12], acc[236:237], v[108:109], %[v_acc_12] \n"
_UK_MFMA_ " %[v_acc_12], acc[238:239], v[110:111], %[v_acc_12] \n"
_UK_MFMA_ " %[v_acc_13], acc[224:225], v[112:113], %[v_acc_13] \n"
_UK_MFMA_ " %[v_acc_13], acc[226:227], v[114:115], %[v_acc_13] \n"
" buffer_load_dwordx4 acc[104:107], %[v_os_b6], s[20:23], 0 offen offset:2048 \n"
_UK_MFMA_ " %[v_acc_13], acc[228:229], v[116:117], %[v_acc_13] \n"
_UK_MFMA_ " %[v_acc_13], acc[230:231], v[118:119], %[v_acc_13] \n"
_UK_MFMA_ " %[v_acc_13], acc[232:233], v[120:121], %[v_acc_13] \n"
_UK_MFMA_ " %[v_acc_13], acc[234:235], v[122:123], %[v_acc_13] \n"
" buffer_load_dwordx4 acc[108:111], %[v_os_b6], s[20:23], 0 offen offset:3072 \n"
_UK_MFMA_ " %[v_acc_13], acc[236:237], v[124:125], %[v_acc_13] \n"
_UK_MFMA_ " %[v_acc_13], acc[238:239], v[126:127], %[v_acc_13] \n"
_UK_MFMA_ " %[v_acc_14], acc[240:241], v[96:97], %[v_acc_14] \n"
_UK_MFMA_ " %[v_acc_14], acc[242:243], v[98:99], %[v_acc_14] \n"
" buffer_load_dwordx4 acc[112:115], %[v_os_b7], s[20:23], 0 offen \n"
_UK_MFMA_ " %[v_acc_14], acc[244:245], v[100:101], %[v_acc_14] \n"
_UK_MFMA_ " %[v_acc_14], acc[246:247], v[102:103], %[v_acc_14] \n"
_UK_MFMA_ " %[v_acc_14], acc[248:249], v[104:105], %[v_acc_14] \n"
_UK_MFMA_ " %[v_acc_14], acc[250:251], v[106:107], %[v_acc_14] \n"
" buffer_load_dwordx4 acc[116:119], %[v_os_b7], s[20:23], 0 offen offset:1024 \n"
_UK_MFMA_ " %[v_acc_14], acc[252:253], v[108:109], %[v_acc_14] \n"
_UK_MFMA_ " %[v_acc_14], acc[254:255], v[110:111], %[v_acc_14] \n"
_UK_MFMA_ " %[v_acc_15], acc[240:241], v[112:113], %[v_acc_15] \n"
_UK_MFMA_ " %[v_acc_15], acc[242:243], v[114:115], %[v_acc_15] \n"
" buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[20:23], 0 offen offset:2048 \n"
_UK_MFMA_ " %[v_acc_15], acc[244:245], v[116:117], %[v_acc_15] \n"
_UK_MFMA_ " %[v_acc_15], acc[246:247], v[118:119], %[v_acc_15] \n"
_UK_MFMA_ " %[v_acc_15], acc[248:249], v[120:121], %[v_acc_15] \n"
_UK_MFMA_ " %[v_acc_15], acc[250:251], v[122:123], %[v_acc_15] \n"
" buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[20:23], 0 offen offset:3072 \n"
_UK_MFMA_ " %[v_acc_15], acc[252:253], v[124:125], %[v_acc_15] \n"
_UK_MFMA_ " %[v_acc_15], acc[254:255], v[126:127], %[v_acc_15] \n"
" s_sub_i32 %[s_loop_cnt], %[s_loop_cnt], 1 \n"
" s_cmp_gt_i32 %[s_loop_cnt] 0 \n"
" s_cbranch_scc0 L_end%= \n"
" s_cmp_gt_i32 %[s_loop_cnt] 2 ; move a with cond \n"
" s_cselect_b32 s86, %[s_tile_os_a], 0 \n"
" s_add_u32 s16, s86, s16 \n"
" s_addc_u32 s17, 0, s17 \n"
" s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n"
" s_cselect_b32 s86, %[s_tile_os_b], 0 \n"
" s_add_u32 s20, s86, s20 \n"
" s_addc_u32 s21, 0, s21 \n"
" s_branch L_start%= \n"
"L_end%=: \n"
" s_nop 2 \n"
"s_mov_b32 s17, %[s_res_a1] \n"
"s_mov_b32 s18, %[s_res_a2] \n"
"s_mov_b32 s19, %[s_res_a3] \n"
"s_mov_b32 s20, %[s_res_b0] \n"
"s_mov_b32 s21, %[s_res_b1] \n"
"s_mov_b32 s22, %[s_res_b2] \n"
"s_mov_b32 s23, %[s_res_b3] \n"
// "s_nop 4\n"
"; -- prefetch A0\n"
"s_add_u32 m0, 0, %[s_m0_init] \n"
"buffer_load_dword %[v_os_a0], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a1], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a2], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a3], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a4], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a5], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a6], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a7], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[smem_sz], %[s_m0_init] \n"
"s_cmp_gt_i32 %[s_loop_cnt] 1 ; move a with cond \n"
"s_cselect_b32 s86, %[s_tile_os_a], 0 ; move a with cond \n"
"s_add_u32 s16, s86, s16 ; move a with cond \n"
"s_addc_u32 s17, 0, s17 ; move a with cond \n"
"; -- prefetch A1\n"
"buffer_load_dword %[v_os_a0], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a1], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a2], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a3], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a4], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a5], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a6], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a7], s[16:19], 0 offen lds \n"
"s_add_u32 m0, 0, %[s_m0_init] \n"
"s_cmp_gt_i32 %[s_loop_cnt] 2 ; move a with cond \n"
"s_cselect_b32 s86, %[s_tile_os_a], 0 ; move a with cond \n"
"s_add_u32 s16, s86, s16 ; move a with cond \n"
"s_addc_u32 s17, 0, s17 ; move a with cond \n"
"; -- prefetch B0\n"
"buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[20:23], 0 offen \n"
"buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[20:23], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[20:23], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[20:23], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[20:23], 0 offen \n"
"buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[20:23], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[20:23], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[20:23], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[20:23], 0 offen \n"
"buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[20:23], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[20:23], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[20:23], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[20:23], 0 offen \n"
"buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[20:23], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[20:23], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[20:23], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[64:67], %[v_os_b4], s[20:23], 0 offen \n"
"buffer_load_dwordx4 acc[68:71], %[v_os_b4], s[20:23], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[72:75], %[v_os_b4], s[20:23], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[76:79], %[v_os_b4], s[20:23], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[80:83], %[v_os_b5], s[20:23], 0 offen \n"
"buffer_load_dwordx4 acc[84:87], %[v_os_b5], s[20:23], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[88:91], %[v_os_b5], s[20:23], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[92:95], %[v_os_b5], s[20:23], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[96:99], %[v_os_b6], s[20:23], 0 offen \n"
"buffer_load_dwordx4 acc[100:103], %[v_os_b6], s[20:23], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[104:107], %[v_os_b6], s[20:23], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[108:111], %[v_os_b6], s[20:23], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[112:115], %[v_os_b7], s[20:23], 0 offen \n"
"buffer_load_dwordx4 acc[116:119], %[v_os_b7], s[20:23], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[20:23], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[20:23], 0 offen offset:3072 \n"
"s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n"
"s_cselect_b32 s86, %[s_tile_os_b], 0 ; move b with cond \n"
"s_add_u32 s20, s86, s20 ; move b with cond \n"
"s_addc_u32 s21, 0, s21 ; move b with cond \n"
"s_waitcnt vmcnt(40) \n"
"s_barrier \n"
"ds_read_b128 v[64:67], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_0]\n" // 1024: N stride, 64
// K stride
"ds_read_b128 v[68:71], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_1]\n"
"ds_read_b128 v[72:75], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_2]\n"
"ds_read_b128 v[76:79], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_3]\n"
"ds_read_b128 v[80:83], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_4]\n"
"ds_read_b128 v[84:87], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_5]\n"
"ds_read_b128 v[88:91], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_6]\n"
"ds_read_b128 v[92:95], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_7]\n"
"L_start%=: \n"
" s_waitcnt vmcnt(24) & lgkmcnt(0) \n"
" s_barrier \n" _UK_MFMA_
" %[v_acc_0], acc[0:1], v[64:65], %[v_acc_0] \n" _UK_MFMA_
" %[v_acc_0], acc[2:3], v[66:67], %[v_acc_0] \n"
" buffer_load_dwordx4 acc[128:131], %[v_os_b0], s[20:23], 0 offen \n" _UK_MFMA_
" %[v_acc_0], acc[4:5], v[68:69], %[v_acc_0] \n" _UK_MFMA_
" %[v_acc_0], acc[6:7], v[70:71], %[v_acc_0] \n"
" buffer_load_dword %[v_os_a0], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n" _UK_MFMA_
" %[v_acc_0], acc[8:9], v[72:73], %[v_acc_0] \n" _UK_MFMA_
" %[v_acc_0], acc[10:11], v[74:75], %[v_acc_0] \n"
" buffer_load_dwordx4 acc[132:135], %[v_os_b0], s[20:23], 0 offen offset:1024 \n" _UK_MFMA_
" %[v_acc_0], acc[12:13], v[76:77], %[v_acc_0] \n" _UK_MFMA_
" %[v_acc_0], acc[14:15], v[78:79], %[v_acc_0] \n"
" buffer_load_dword %[v_os_a1], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n" _UK_MFMA_
" %[v_acc_1], acc[0:1], v[80:81], %[v_acc_1] \n" _UK_MFMA_
" %[v_acc_1], acc[2:3], v[82:83], %[v_acc_1] \n"
" buffer_load_dwordx4 acc[136:139], %[v_os_b0], s[20:23], 0 offen offset:2048 \n" _UK_MFMA_
" %[v_acc_1], acc[4:5], v[84:85], %[v_acc_1] \n" _UK_MFMA_
" %[v_acc_1], acc[6:7], v[86:87], %[v_acc_1] \n"
" buffer_load_dword %[v_os_a2], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n" _UK_MFMA_
" %[v_acc_1], acc[8:9], v[88:89], %[v_acc_1] \n" _UK_MFMA_
" %[v_acc_1], acc[10:11], v[90:91], %[v_acc_1] \n"
" buffer_load_dwordx4 acc[140:143], %[v_os_b0], s[20:23], 0 offen offset:3072 \n" _UK_MFMA_
" %[v_acc_1], acc[12:13], v[92:93], %[v_acc_1] \n" _UK_MFMA_
" %[v_acc_1], acc[14:15], v[94:95], %[v_acc_1] \n"
" buffer_load_dword %[v_os_a3], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n" _UK_MFMA_
" %[v_acc_2], acc[16:17], v[64:65], %[v_acc_2] \n" _UK_MFMA_
" %[v_acc_2], acc[18:19], v[66:67], %[v_acc_2] \n"
" buffer_load_dwordx4 acc[144:147], %[v_os_b1], s[20:23], 0 offen \n" _UK_MFMA_
" %[v_acc_2], acc[20:21], v[68:69], %[v_acc_2] \n" _UK_MFMA_
" %[v_acc_2], acc[22:23], v[70:71], %[v_acc_2] \n"
" buffer_load_dword %[v_os_a4], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n" _UK_MFMA_
" %[v_acc_2], acc[24:25], v[72:73], %[v_acc_2] \n" _UK_MFMA_
" %[v_acc_2], acc[26:27], v[74:75], %[v_acc_2] \n"
" buffer_load_dwordx4 acc[148:151], %[v_os_b1], s[20:23], 0 offen offset:1024 \n" _UK_MFMA_
" %[v_acc_2], acc[28:29], v[76:77], %[v_acc_2] \n" _UK_MFMA_
" %[v_acc_2], acc[30:31], v[78:79], %[v_acc_2] \n"
" buffer_load_dword %[v_os_a5], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n" _UK_MFMA_
" %[v_acc_3], acc[16:17], v[80:81], %[v_acc_3] \n" _UK_MFMA_
" %[v_acc_3], acc[18:19], v[82:83], %[v_acc_3] \n"
" buffer_load_dwordx4 acc[152:155], %[v_os_b1], s[20:23], 0 offen offset:2048 \n" _UK_MFMA_
" %[v_acc_3], acc[20:21], v[84:85], %[v_acc_3] \n" _UK_MFMA_
" %[v_acc_3], acc[22:23], v[86:87], %[v_acc_3] \n"
" buffer_load_dword %[v_os_a6], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n" _UK_MFMA_
" %[v_acc_3], acc[24:25], v[88:89], %[v_acc_3] \n" _UK_MFMA_
" %[v_acc_3], acc[26:27], v[90:91], %[v_acc_3] \n"
" buffer_load_dwordx4 acc[156:159], %[v_os_b1], s[20:23], 0 offen offset:3072 \n" _UK_MFMA_
" %[v_acc_3], acc[28:29], v[92:93], %[v_acc_3] \n" _UK_MFMA_
" %[v_acc_3], acc[30:31], v[94:95], %[v_acc_3] \n"
" buffer_load_dword %[v_os_a7], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[smem_sz], %[s_m0_init] \n"
" s_waitcnt vmcnt(32) \n" _UK_MFMA_
" %[v_acc_4], acc[32:33], v[64:65], %[v_acc_4] \n" _UK_MFMA_
" %[v_acc_4], acc[34:35], v[66:67], %[v_acc_4] \n"
" buffer_load_dwordx4 acc[160:163], %[v_os_b2], s[20:23], 0 offen \n" _UK_MFMA_
" %[v_acc_4], acc[36:37], v[68:69], %[v_acc_4] \n" _UK_MFMA_
" %[v_acc_4], acc[38:39], v[70:71], %[v_acc_4] \n"
" ds_read_b128 v[96:99], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_0] "
"\n" _UK_MFMA_ " %[v_acc_4], acc[40:41], v[72:73], %[v_acc_4] \n" _UK_MFMA_
" %[v_acc_4], acc[42:43], v[74:75], %[v_acc_4] \n"
" buffer_load_dwordx4 acc[164:167], %[v_os_b2], s[20:23], 0 offen offset:1024 \n" _UK_MFMA_
" %[v_acc_4], acc[44:45], v[76:77], %[v_acc_4] \n" _UK_MFMA_
" %[v_acc_4], acc[46:47], v[78:79], %[v_acc_4] \n"
" ds_read_b128 v[100:103], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_1] "
"\n" _UK_MFMA_ " %[v_acc_5], acc[32:33], v[80:81], %[v_acc_5] \n" _UK_MFMA_
" %[v_acc_5], acc[34:35], v[82:83], %[v_acc_5] \n"
" buffer_load_dwordx4 acc[168:171], %[v_os_b2], s[20:23], 0 offen offset:2048 \n" _UK_MFMA_
" %[v_acc_5], acc[36:37], v[84:85], %[v_acc_5] \n" _UK_MFMA_
" %[v_acc_5], acc[38:39], v[86:87], %[v_acc_5] \n"
" ds_read_b128 v[104:107], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_2] "
"\n" _UK_MFMA_ " %[v_acc_5], acc[40:41], v[88:89], %[v_acc_5] \n" _UK_MFMA_
" %[v_acc_5], acc[42:43], v[90:91], %[v_acc_5] \n"
" buffer_load_dwordx4 acc[172:175], %[v_os_b2], s[20:23], 0 offen offset:3072 \n" _UK_MFMA_
" %[v_acc_5], acc[44:45], v[92:93], %[v_acc_5] \n" _UK_MFMA_
" %[v_acc_5], acc[46:47], v[94:95], %[v_acc_5] \n"
" ds_read_b128 v[108:111], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_3] "
"\n" _UK_MFMA_ " %[v_acc_6], acc[48:49], v[64:65], %[v_acc_6] \n" _UK_MFMA_
" %[v_acc_6], acc[50:51], v[66:67], %[v_acc_6] \n"
" buffer_load_dwordx4 acc[176:179], %[v_os_b3], s[20:23], 0 offen \n" _UK_MFMA_
" %[v_acc_6], acc[52:53], v[68:69], %[v_acc_6] \n" _UK_MFMA_
" %[v_acc_6], acc[54:55], v[70:71], %[v_acc_6] \n"
" ds_read_b128 v[112:115], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_4] "
"\n" _UK_MFMA_ " %[v_acc_6], acc[56:57], v[72:73], %[v_acc_6] \n" _UK_MFMA_
" %[v_acc_6], acc[58:59], v[74:75], %[v_acc_6] \n"
" buffer_load_dwordx4 acc[180:183], %[v_os_b3], s[20:23], 0 offen offset:1024 \n" _UK_MFMA_
" %[v_acc_6], acc[60:61], v[76:77], %[v_acc_6] \n" _UK_MFMA_
" %[v_acc_6], acc[62:63], v[78:79], %[v_acc_6] \n"
" ds_read_b128 v[116:119], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_5] "
"\n" _UK_MFMA_ " %[v_acc_7], acc[48:49], v[80:81], %[v_acc_7] \n" _UK_MFMA_
" %[v_acc_7], acc[50:51], v[82:83], %[v_acc_7] \n"
" buffer_load_dwordx4 acc[184:187], %[v_os_b3], s[20:23], 0 offen offset:2048 \n" _UK_MFMA_
" %[v_acc_7], acc[52:53], v[84:85], %[v_acc_7] \n" _UK_MFMA_
" %[v_acc_7], acc[54:55], v[86:87], %[v_acc_7] \n"
" ds_read_b128 v[120:123], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_6] "
"\n" _UK_MFMA_ " %[v_acc_7], acc[56:57], v[88:89], %[v_acc_7] \n" _UK_MFMA_
" %[v_acc_7], acc[58:59], v[90:91], %[v_acc_7] \n"
" buffer_load_dwordx4 acc[188:191], %[v_os_b3], s[20:23], 0 offen offset:3072 \n" _UK_MFMA_
" %[v_acc_7], acc[60:61], v[92:93], %[v_acc_7] \n" _UK_MFMA_
" %[v_acc_7], acc[62:63], v[94:95], %[v_acc_7] \n"
" ds_read_b128 v[124:127], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_7] \n"
" s_waitcnt vmcnt(32) \n" _UK_MFMA_
" %[v_acc_8], acc[64:65], v[64:65], %[v_acc_8] \n" _UK_MFMA_
" %[v_acc_8], acc[66:67], v[66:67], %[v_acc_8] \n"
" buffer_load_dwordx4 acc[192:195], %[v_os_b4], s[20:23], 0 offen \n" _UK_MFMA_
" %[v_acc_8], acc[68:69], v[68:69], %[v_acc_8] \n" _UK_MFMA_
" %[v_acc_8], acc[70:71], v[70:71], %[v_acc_8] \n" _UK_MFMA_
" %[v_acc_8], acc[72:73], v[72:73], %[v_acc_8] \n" _UK_MFMA_
" %[v_acc_8], acc[74:75], v[74:75], %[v_acc_8] \n"
" buffer_load_dwordx4 acc[196:199], %[v_os_b4], s[20:23], 0 offen offset:1024 \n" _UK_MFMA_
" %[v_acc_8], acc[76:77], v[76:77], %[v_acc_8] \n" _UK_MFMA_
" %[v_acc_8], acc[78:79], v[78:79], %[v_acc_8] \n" _UK_MFMA_
" %[v_acc_9], acc[64:65], v[80:81], %[v_acc_9] \n" _UK_MFMA_
" %[v_acc_9], acc[66:67], v[82:83], %[v_acc_9] \n"
" buffer_load_dwordx4 acc[200:203], %[v_os_b4], s[20:23], 0 offen offset:2048 \n" _UK_MFMA_
" %[v_acc_9], acc[68:69], v[84:85], %[v_acc_9] \n" _UK_MFMA_
" %[v_acc_9], acc[70:71], v[86:87], %[v_acc_9] \n" _UK_MFMA_
" %[v_acc_9], acc[72:73], v[88:89], %[v_acc_9] \n" _UK_MFMA_
" %[v_acc_9], acc[74:75], v[90:91], %[v_acc_9] \n"
" buffer_load_dwordx4 acc[204:207], %[v_os_b4], s[20:23], 0 offen offset:3072 \n" _UK_MFMA_
" %[v_acc_9], acc[76:77], v[92:93], %[v_acc_9] \n" _UK_MFMA_
" %[v_acc_9], acc[78:79], v[94:95], %[v_acc_9] \n" _UK_MFMA_
" %[v_acc_10], acc[80:81], v[64:65], %[v_acc_10] \n" _UK_MFMA_
" %[v_acc_10], acc[82:83], v[66:67], %[v_acc_10] \n"
" buffer_load_dwordx4 acc[208:211], %[v_os_b5], s[20:23], 0 offen \n" _UK_MFMA_
" %[v_acc_10], acc[84:85], v[68:69], %[v_acc_10] \n" _UK_MFMA_
" %[v_acc_10], acc[86:87], v[70:71], %[v_acc_10] \n" _UK_MFMA_
" %[v_acc_10], acc[88:89], v[72:73], %[v_acc_10] \n" _UK_MFMA_
" %[v_acc_10], acc[90:91], v[74:75], %[v_acc_10] \n"
" buffer_load_dwordx4 acc[212:215], %[v_os_b5], s[20:23], 0 offen offset:1024 \n" _UK_MFMA_
" %[v_acc_10], acc[92:93], v[76:77], %[v_acc_10] \n" _UK_MFMA_
" %[v_acc_10], acc[94:95], v[78:79], %[v_acc_10] \n" _UK_MFMA_
" %[v_acc_11], acc[80:81], v[80:81], %[v_acc_11] \n" _UK_MFMA_
" %[v_acc_11], acc[82:83], v[82:83], %[v_acc_11] \n"
" buffer_load_dwordx4 acc[216:219], %[v_os_b5], s[20:23], 0 offen offset:2048 \n" _UK_MFMA_
" %[v_acc_11], acc[84:85], v[84:85], %[v_acc_11] \n" _UK_MFMA_
" %[v_acc_11], acc[86:87], v[86:87], %[v_acc_11] \n" _UK_MFMA_
" %[v_acc_11], acc[88:89], v[88:89], %[v_acc_11] \n" _UK_MFMA_
" %[v_acc_11], acc[90:91], v[90:91], %[v_acc_11] \n"
" buffer_load_dwordx4 acc[220:223], %[v_os_b5], s[20:23], 0 offen offset:3072 \n" _UK_MFMA_
" %[v_acc_11], acc[92:93], v[92:93], %[v_acc_11] \n" _UK_MFMA_
" %[v_acc_11], acc[94:95], v[94:95], %[v_acc_11] \n"
" s_waitcnt vmcnt(32) \n" _UK_MFMA_
" %[v_acc_12], acc[96:97], v[64:65], %[v_acc_12] \n" _UK_MFMA_
" %[v_acc_12], acc[98:99], v[66:67], %[v_acc_12] \n"
" buffer_load_dwordx4 acc[224:227], %[v_os_b6], s[20:23], 0 offen \n" _UK_MFMA_
" %[v_acc_12], acc[100:101], v[68:69], %[v_acc_12] \n" _UK_MFMA_
" %[v_acc_12], acc[102:103], v[70:71], %[v_acc_12] \n" _UK_MFMA_
" %[v_acc_12], acc[104:105], v[72:73], %[v_acc_12] \n" _UK_MFMA_
" %[v_acc_12], acc[106:107], v[74:75], %[v_acc_12] \n"
" buffer_load_dwordx4 acc[228:231], %[v_os_b6], s[20:23], 0 offen offset:1024 \n" _UK_MFMA_
" %[v_acc_12], acc[108:109], v[76:77], %[v_acc_12] \n" _UK_MFMA_
" %[v_acc_12], acc[110:111], v[78:79], %[v_acc_12] \n" _UK_MFMA_
" %[v_acc_13], acc[96:97], v[80:81], %[v_acc_13] \n" _UK_MFMA_
" %[v_acc_13], acc[98:99], v[82:83], %[v_acc_13] \n"
" buffer_load_dwordx4 acc[232:235], %[v_os_b6], s[20:23], 0 offen offset:2048 \n" _UK_MFMA_
" %[v_acc_13], acc[100:101], v[84:85], %[v_acc_13] \n" _UK_MFMA_
" %[v_acc_13], acc[102:103], v[86:87], %[v_acc_13] \n" _UK_MFMA_
" %[v_acc_13], acc[104:105], v[88:89], %[v_acc_13] \n" _UK_MFMA_
" %[v_acc_13], acc[106:107], v[90:91], %[v_acc_13] \n"
" buffer_load_dwordx4 acc[236:239], %[v_os_b6], s[20:23], 0 offen offset:3072 \n" _UK_MFMA_
" %[v_acc_13], acc[108:109], v[92:93], %[v_acc_13] \n" _UK_MFMA_
" %[v_acc_13], acc[110:111], v[94:95], %[v_acc_13] \n" _UK_MFMA_
" %[v_acc_14], acc[112:113], v[64:65], %[v_acc_14] \n" _UK_MFMA_
" %[v_acc_14], acc[114:115], v[66:67], %[v_acc_14] \n"
" buffer_load_dwordx4 acc[240:243], %[v_os_b7], s[20:23], 0 offen \n" _UK_MFMA_
" %[v_acc_14], acc[116:117], v[68:69], %[v_acc_14] \n" _UK_MFMA_
" %[v_acc_14], acc[118:119], v[70:71], %[v_acc_14] \n" _UK_MFMA_
" %[v_acc_14], acc[120:121], v[72:73], %[v_acc_14] \n" _UK_MFMA_
" %[v_acc_14], acc[122:123], v[74:75], %[v_acc_14] \n"
" buffer_load_dwordx4 acc[244:247], %[v_os_b7], s[20:23], 0 offen offset:1024 \n" _UK_MFMA_
" %[v_acc_14], acc[124:125], v[76:77], %[v_acc_14] \n" _UK_MFMA_
" %[v_acc_14], acc[126:127], v[78:79], %[v_acc_14] \n" _UK_MFMA_
" %[v_acc_15], acc[112:113], v[80:81], %[v_acc_15] \n" _UK_MFMA_
" %[v_acc_15], acc[114:115], v[82:83], %[v_acc_15] \n"
" buffer_load_dwordx4 acc[248:251], %[v_os_b7], s[20:23], 0 offen offset:2048 \n" _UK_MFMA_
" %[v_acc_15], acc[116:117], v[84:85], %[v_acc_15] \n" _UK_MFMA_
" %[v_acc_15], acc[118:119], v[86:87], %[v_acc_15] \n" _UK_MFMA_
" %[v_acc_15], acc[120:121], v[88:89], %[v_acc_15] \n" _UK_MFMA_
" %[v_acc_15], acc[122:123], v[90:91], %[v_acc_15] \n"
" buffer_load_dwordx4 acc[252:255], %[v_os_b7], s[20:23], 0 offen offset:3072\n" _UK_MFMA_
" %[v_acc_15], acc[124:125], v[92:93], %[v_acc_15] \n" _UK_MFMA_
" %[v_acc_15], acc[126:127], v[94:95], %[v_acc_15] \n"
" s_sub_i32 %[s_loop_cnt], %[s_loop_cnt], 1 \n"
" s_cmp_gt_i32 %[s_loop_cnt] 0 \n"
" s_cbranch_scc0 L_end%= \n"
" s_cmp_gt_i32 %[s_loop_cnt] 2 ; move a with cond \n"
" s_cselect_b32 s86, %[s_tile_os_a], 0 \n"
" s_add_u32 s16, s86, s16 \n"
" s_addc_u32 s17, 0, s17 \n"
" s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n"
" s_cselect_b32 s86, %[s_tile_os_b], 0 \n"
" s_add_u32 s20, s86, s20 \n"
" s_addc_u32 s21, 0, s21 \n"
" ;------------------------------------------ \n"
" s_waitcnt vmcnt(24) & lgkmcnt(0) \n"
" s_barrier \n" _UK_MFMA_
" %[v_acc_0], acc[128:129], v[96:97], %[v_acc_0] \n" _UK_MFMA_
" %[v_acc_0], acc[130:131], v[98:99], %[v_acc_0] \n"
" buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[20:23], 0 offen \n" _UK_MFMA_
" %[v_acc_0], acc[132:133], v[100:101], %[v_acc_0] \n" _UK_MFMA_
" %[v_acc_0], acc[134:135], v[102:103], %[v_acc_0] \n"
" buffer_load_dword %[v_os_a0], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n" _UK_MFMA_
" %[v_acc_0], acc[136:137], v[104:105], %[v_acc_0] \n" _UK_MFMA_
" %[v_acc_0], acc[138:139], v[106:107], %[v_acc_0] \n"
" buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[20:23], 0 offen offset:1024 \n" _UK_MFMA_
" %[v_acc_0], acc[140:141], v[108:109], %[v_acc_0] \n" _UK_MFMA_
" %[v_acc_0], acc[142:143], v[110:111], %[v_acc_0] \n"
" buffer_load_dword %[v_os_a1], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n" _UK_MFMA_
" %[v_acc_1], acc[128:129], v[112:113], %[v_acc_1] \n" _UK_MFMA_
" %[v_acc_1], acc[130:131], v[114:115], %[v_acc_1] \n"
" buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[20:23], 0 offen offset:2048 \n" _UK_MFMA_
" %[v_acc_1], acc[132:133], v[116:117], %[v_acc_1] \n" _UK_MFMA_
" %[v_acc_1], acc[134:135], v[118:119], %[v_acc_1] \n"
" buffer_load_dword %[v_os_a2], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n" _UK_MFMA_
" %[v_acc_1], acc[136:137], v[120:121], %[v_acc_1] \n" _UK_MFMA_
" %[v_acc_1], acc[138:139], v[122:123], %[v_acc_1] \n"
" buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[20:23], 0 offen offset:3072 \n" _UK_MFMA_
" %[v_acc_1], acc[140:141], v[124:125], %[v_acc_1] \n" _UK_MFMA_
" %[v_acc_1], acc[142:143], v[126:127], %[v_acc_1] \n"
" buffer_load_dword %[v_os_a3], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n" _UK_MFMA_
" %[v_acc_2], acc[144:145], v[96:97], %[v_acc_2] \n" _UK_MFMA_
" %[v_acc_2], acc[146:147], v[98:99], %[v_acc_2] \n"
" buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[20:23], 0 offen \n" _UK_MFMA_
" %[v_acc_2], acc[148:149], v[100:101], %[v_acc_2] \n" _UK_MFMA_
" %[v_acc_2], acc[150:151], v[102:103], %[v_acc_2] \n"
" buffer_load_dword %[v_os_a4], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n" _UK_MFMA_
" %[v_acc_2], acc[152:153], v[104:105], %[v_acc_2] \n" _UK_MFMA_
" %[v_acc_2], acc[154:155], v[106:107], %[v_acc_2] \n"
" buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[20:23], 0 offen offset:1024 \n" _UK_MFMA_
" %[v_acc_2], acc[156:157], v[108:109], %[v_acc_2] \n" _UK_MFMA_
" %[v_acc_2], acc[158:159], v[110:111], %[v_acc_2] \n"
" buffer_load_dword %[v_os_a5], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n" _UK_MFMA_
" %[v_acc_3], acc[144:145], v[112:113], %[v_acc_3] \n" _UK_MFMA_
" %[v_acc_3], acc[146:147], v[114:115], %[v_acc_3] \n"
" buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[20:23], 0 offen offset:2048 \n" _UK_MFMA_
" %[v_acc_3], acc[148:149], v[116:117], %[v_acc_3] \n" _UK_MFMA_
" %[v_acc_3], acc[150:151], v[118:119], %[v_acc_3] \n"
" buffer_load_dword %[v_os_a6], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n" _UK_MFMA_
" %[v_acc_3], acc[152:153], v[120:121], %[v_acc_3] \n" _UK_MFMA_
" %[v_acc_3], acc[154:155], v[122:123], %[v_acc_3] \n"
" buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[20:23], 0 offen offset:3072 \n" _UK_MFMA_
" %[v_acc_3], acc[156:157], v[124:125], %[v_acc_3] \n" _UK_MFMA_
" %[v_acc_3], acc[158:159], v[126:127], %[v_acc_3] \n"
" buffer_load_dword %[v_os_a7], s[16:19], 0 offen lds \n"
" s_add_u32 m0, 0, %[s_m0_init] \n"
" s_waitcnt vmcnt(32) \n" _UK_MFMA_
" %[v_acc_4], acc[160:161], v[96:97], %[v_acc_4] \n" _UK_MFMA_
" %[v_acc_4], acc[162:163], v[98:99], %[v_acc_4] \n"
" buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[20:23], 0 offen \n" _UK_MFMA_
" %[v_acc_4], acc[164:165], v[100:101], %[v_acc_4] \n" _UK_MFMA_
" %[v_acc_4], acc[166:167], v[102:103], %[v_acc_4] \n"
" ds_read_b128 v[64:67], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_0] \n" _UK_MFMA_
" %[v_acc_4], acc[168:169], v[104:105], %[v_acc_4] \n" _UK_MFMA_
" %[v_acc_4], acc[170:171], v[106:107], %[v_acc_4] \n"
" buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[20:23], 0 offen offset:1024 \n" _UK_MFMA_
" %[v_acc_4], acc[172:173], v[108:109], %[v_acc_4] \n" _UK_MFMA_
" %[v_acc_4], acc[174:175], v[110:111], %[v_acc_4] \n"
" ds_read_b128 v[68:71], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_1] \n" _UK_MFMA_
" %[v_acc_5], acc[160:161], v[112:113], %[v_acc_5] \n" _UK_MFMA_
" %[v_acc_5], acc[162:163], v[114:115], %[v_acc_5] \n"
" buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[20:23], 0 offen offset:2048 \n" _UK_MFMA_
" %[v_acc_5], acc[164:165], v[116:117], %[v_acc_5] \n" _UK_MFMA_
" %[v_acc_5], acc[166:167], v[118:119], %[v_acc_5] \n"
" ds_read_b128 v[72:75], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_2] "
"\n" _UK_MFMA_ " %[v_acc_5], acc[168:169], v[120:121], %[v_acc_5] \n" _UK_MFMA_
" %[v_acc_5], acc[170:171], v[122:123], %[v_acc_5] \n"
" buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[20:23], 0 offen offset:3072 \n" _UK_MFMA_
" %[v_acc_5], acc[172:173], v[124:125], %[v_acc_5] \n" _UK_MFMA_
" %[v_acc_5], acc[174:175], v[126:127], %[v_acc_5] \n"
" ds_read_b128 v[76:79], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_3] "
"\n" _UK_MFMA_ " %[v_acc_6], acc[176:177], v[96:97], %[v_acc_6] \n" _UK_MFMA_
" %[v_acc_6], acc[178:179], v[98:99], %[v_acc_6] \n"
" buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[20:23], 0 offen \n" _UK_MFMA_
" %[v_acc_6], acc[180:181], v[100:101], %[v_acc_6] \n" _UK_MFMA_
" %[v_acc_6], acc[182:183], v[102:103], %[v_acc_6] \n"
" ds_read_b128 v[80:83], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_4] "
"\n" _UK_MFMA_ " %[v_acc_6], acc[184:185], v[104:105], %[v_acc_6] \n" _UK_MFMA_
" %[v_acc_6], acc[186:187], v[106:107], %[v_acc_6] \n"
" buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[20:23], 0 offen offset:1024 \n" _UK_MFMA_
" %[v_acc_6], acc[188:189], v[108:109], %[v_acc_6] \n" _UK_MFMA_
" %[v_acc_6], acc[190:191], v[110:111], %[v_acc_6] \n"
" ds_read_b128 v[84:87], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_5] "
"\n" _UK_MFMA_ " %[v_acc_7], acc[176:177], v[112:113], %[v_acc_7] \n" _UK_MFMA_
" %[v_acc_7], acc[178:179], v[114:115], %[v_acc_7] \n"
" buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[20:23], 0 offen offset:2048 \n" _UK_MFMA_
" %[v_acc_7], acc[180:181], v[116:117], %[v_acc_7] \n" _UK_MFMA_
" %[v_acc_7], acc[182:183], v[118:119], %[v_acc_7] \n"
" ds_read_b128 v[88:91], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_6] "
"\n" _UK_MFMA_ " %[v_acc_7], acc[184:185], v[120:121], %[v_acc_7] \n" _UK_MFMA_
" %[v_acc_7], acc[186:187], v[122:123], %[v_acc_7] \n"
" buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[20:23], 0 offen offset:3072 \n" _UK_MFMA_
" %[v_acc_7], acc[188:189], v[124:125], %[v_acc_7] \n" _UK_MFMA_
" %[v_acc_7], acc[190:191], v[126:127], %[v_acc_7] \n"
" ds_read_b128 v[92:95], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_7] \n"
" s_waitcnt vmcnt(32) \n" _UK_MFMA_
" %[v_acc_8], acc[192:193], v[96:97], %[v_acc_8] \n" _UK_MFMA_
" %[v_acc_8], acc[194:195], v[98:99], %[v_acc_8] \n"
" buffer_load_dwordx4 acc[64:67], %[v_os_b4], s[20:23], 0 offen \n" _UK_MFMA_
" %[v_acc_8], acc[196:197], v[100:101], %[v_acc_8] \n" _UK_MFMA_
" %[v_acc_8], acc[198:199], v[102:103], %[v_acc_8] \n" _UK_MFMA_
" %[v_acc_8], acc[200:201], v[104:105], %[v_acc_8] \n" _UK_MFMA_
" %[v_acc_8], acc[202:203], v[106:107], %[v_acc_8] \n"
" buffer_load_dwordx4 acc[68:71], %[v_os_b4], s[20:23], 0 offen offset:1024 \n" _UK_MFMA_
" %[v_acc_8], acc[204:205], v[108:109], %[v_acc_8] \n" _UK_MFMA_
" %[v_acc_8], acc[206:207], v[110:111], %[v_acc_8] \n" _UK_MFMA_
" %[v_acc_9], acc[192:193], v[112:113], %[v_acc_9] \n" _UK_MFMA_
" %[v_acc_9], acc[194:195], v[114:115], %[v_acc_9] \n"
" buffer_load_dwordx4 acc[72:75], %[v_os_b4], s[20:23], 0 offen offset:2048 \n" _UK_MFMA_
" %[v_acc_9], acc[196:197], v[116:117], %[v_acc_9] \n" _UK_MFMA_
" %[v_acc_9], acc[198:199], v[118:119], %[v_acc_9] \n" _UK_MFMA_
" %[v_acc_9], acc[200:201], v[120:121], %[v_acc_9] \n" _UK_MFMA_
" %[v_acc_9], acc[202:203], v[122:123], %[v_acc_9] \n"
" buffer_load_dwordx4 acc[76:79], %[v_os_b4], s[20:23], 0 offen offset:3072 \n" _UK_MFMA_
" %[v_acc_9], acc[204:205], v[124:125], %[v_acc_9] \n" _UK_MFMA_
" %[v_acc_9], acc[206:207], v[126:127], %[v_acc_9] \n" _UK_MFMA_
" %[v_acc_10], acc[208:209], v[96:97], %[v_acc_10] \n" _UK_MFMA_
" %[v_acc_10], acc[210:211], v[98:99], %[v_acc_10] \n"
" buffer_load_dwordx4 acc[80:83], %[v_os_b5], s[20:23], 0 offen \n" _UK_MFMA_
" %[v_acc_10], acc[212:213], v[100:101], %[v_acc_10] \n" _UK_MFMA_
" %[v_acc_10], acc[214:215], v[102:103], %[v_acc_10] \n" _UK_MFMA_
" %[v_acc_10], acc[216:217], v[104:105], %[v_acc_10] \n" _UK_MFMA_
" %[v_acc_10], acc[218:219], v[106:107], %[v_acc_10] \n"
" buffer_load_dwordx4 acc[84:87], %[v_os_b5], s[20:23], 0 offen offset:1024 \n" _UK_MFMA_
" %[v_acc_10], acc[220:221], v[108:109], %[v_acc_10] \n" _UK_MFMA_
" %[v_acc_10], acc[222:223], v[110:111], %[v_acc_10] \n" _UK_MFMA_
" %[v_acc_11], acc[208:209], v[112:113], %[v_acc_11] \n" _UK_MFMA_
" %[v_acc_11], acc[210:211], v[114:115], %[v_acc_11] \n"
" buffer_load_dwordx4 acc[88:91], %[v_os_b5], s[20:23], 0 offen offset:2048 \n" _UK_MFMA_
" %[v_acc_11], acc[212:213], v[116:117], %[v_acc_11] \n" _UK_MFMA_
" %[v_acc_11], acc[214:215], v[118:119], %[v_acc_11] \n" _UK_MFMA_
" %[v_acc_11], acc[216:217], v[120:121], %[v_acc_11] \n" _UK_MFMA_
" %[v_acc_11], acc[218:219], v[122:123], %[v_acc_11] \n"
" buffer_load_dwordx4 acc[92:95], %[v_os_b5], s[20:23], 0 offen offset:3072 \n" _UK_MFMA_
" %[v_acc_11], acc[220:221], v[124:125], %[v_acc_11] \n" _UK_MFMA_
" %[v_acc_11], acc[222:223], v[126:127], %[v_acc_11] \n"
" s_waitcnt vmcnt(32) \n" _UK_MFMA_
" %[v_acc_12], acc[224:225], v[96:97], %[v_acc_12] \n" _UK_MFMA_
" %[v_acc_12], acc[226:227], v[98:99], %[v_acc_12] \n"
" buffer_load_dwordx4 acc[96:99], %[v_os_b6], s[20:23], 0 offen \n" _UK_MFMA_
" %[v_acc_12], acc[228:229], v[100:101], %[v_acc_12] \n" _UK_MFMA_
" %[v_acc_12], acc[230:231], v[102:103], %[v_acc_12] \n" _UK_MFMA_
" %[v_acc_12], acc[232:233], v[104:105], %[v_acc_12] \n" _UK_MFMA_
" %[v_acc_12], acc[234:235], v[106:107], %[v_acc_12] \n"
" buffer_load_dwordx4 acc[100:103], %[v_os_b6], s[20:23], 0 offen offset:1024 \n" _UK_MFMA_
" %[v_acc_12], acc[236:237], v[108:109], %[v_acc_12] \n" _UK_MFMA_
" %[v_acc_12], acc[238:239], v[110:111], %[v_acc_12] \n" _UK_MFMA_
" %[v_acc_13], acc[224:225], v[112:113], %[v_acc_13] \n" _UK_MFMA_
" %[v_acc_13], acc[226:227], v[114:115], %[v_acc_13] \n"
" buffer_load_dwordx4 acc[104:107], %[v_os_b6], s[20:23], 0 offen offset:2048 \n" _UK_MFMA_
" %[v_acc_13], acc[228:229], v[116:117], %[v_acc_13] \n" _UK_MFMA_
" %[v_acc_13], acc[230:231], v[118:119], %[v_acc_13] \n" _UK_MFMA_
" %[v_acc_13], acc[232:233], v[120:121], %[v_acc_13] \n" _UK_MFMA_
" %[v_acc_13], acc[234:235], v[122:123], %[v_acc_13] \n"
" buffer_load_dwordx4 acc[108:111], %[v_os_b6], s[20:23], 0 offen offset:3072 \n" _UK_MFMA_
" %[v_acc_13], acc[236:237], v[124:125], %[v_acc_13] \n" _UK_MFMA_
" %[v_acc_13], acc[238:239], v[126:127], %[v_acc_13] \n" _UK_MFMA_
" %[v_acc_14], acc[240:241], v[96:97], %[v_acc_14] \n" _UK_MFMA_
" %[v_acc_14], acc[242:243], v[98:99], %[v_acc_14] \n"
" buffer_load_dwordx4 acc[112:115], %[v_os_b7], s[20:23], 0 offen \n" _UK_MFMA_
" %[v_acc_14], acc[244:245], v[100:101], %[v_acc_14] \n" _UK_MFMA_
" %[v_acc_14], acc[246:247], v[102:103], %[v_acc_14] \n" _UK_MFMA_
" %[v_acc_14], acc[248:249], v[104:105], %[v_acc_14] \n" _UK_MFMA_
" %[v_acc_14], acc[250:251], v[106:107], %[v_acc_14] \n"
" buffer_load_dwordx4 acc[116:119], %[v_os_b7], s[20:23], 0 offen offset:1024 \n" _UK_MFMA_
" %[v_acc_14], acc[252:253], v[108:109], %[v_acc_14] \n" _UK_MFMA_
" %[v_acc_14], acc[254:255], v[110:111], %[v_acc_14] \n" _UK_MFMA_
" %[v_acc_15], acc[240:241], v[112:113], %[v_acc_15] \n" _UK_MFMA_
" %[v_acc_15], acc[242:243], v[114:115], %[v_acc_15] \n"
" buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[20:23], 0 offen offset:2048 \n" _UK_MFMA_
" %[v_acc_15], acc[244:245], v[116:117], %[v_acc_15] \n" _UK_MFMA_
" %[v_acc_15], acc[246:247], v[118:119], %[v_acc_15] \n" _UK_MFMA_
" %[v_acc_15], acc[248:249], v[120:121], %[v_acc_15] \n" _UK_MFMA_
" %[v_acc_15], acc[250:251], v[122:123], %[v_acc_15] \n"
" buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[20:23], 0 offen offset:3072 \n" _UK_MFMA_
" %[v_acc_15], acc[252:253], v[124:125], %[v_acc_15] \n" _UK_MFMA_
" %[v_acc_15], acc[254:255], v[126:127], %[v_acc_15] \n"
" s_sub_i32 %[s_loop_cnt], %[s_loop_cnt], 1 \n"
" s_cmp_gt_i32 %[s_loop_cnt] 0 \n"
" s_cbranch_scc0 L_end%= \n"
" s_cmp_gt_i32 %[s_loop_cnt] 2 ; move a with cond \n"
" s_cselect_b32 s86, %[s_tile_os_a], 0 \n"
" s_add_u32 s16, s86, s16 \n"
" s_addc_u32 s17, 0, s17 \n"
" s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n"
" s_cselect_b32 s86, %[s_tile_os_b], 0 \n"
" s_add_u32 s20, s86, s20 \n"
" s_addc_u32 s21, 0, s21 \n"
" s_branch L_start%= \n"
"L_end%=: \n"
" s_nop 2 \n"
#undef _UK_MFMA_
......@@ -29,6 +29,8 @@
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
......
......@@ -71,7 +71,8 @@ struct FmhaFwdKernel
using bfs = typename FmhaPipeline::BlockFmhaShape;
using g0br = typename bfs::Gemm0BlockWarps;
using g1br = typename bfs::Gemm1BlockWarps;
using gwt = typename bfs::Gemm0WarpTile;
using g0wt = typename bfs::Gemm0WarpTile;
using g1wt = typename bfs::Gemm1WarpTile;
#define _SS_ std::string
#define _TS_ std::to_string
auto pn = [&] () {
......@@ -88,7 +89,8 @@ struct FmhaFwdKernel
_TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" +
"r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
"r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
"w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
"w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" +
"w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" +
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
......
......@@ -8,9 +8,11 @@ namespace ck_tile {
template <typename TilePartitioner_, typename FmhaPipeline_, typename EpiloguePipeline_>
struct FmhaFwdSplitKVCombineKernel
{
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using FmhaPipeline = remove_cvref_t<FmhaPipeline_>;
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using FmhaPipeline = remove_cvref_t<FmhaPipeline_>;
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
static constexpr index_t kNumWarps = FmhaPipeline::kNumWarps;
static constexpr index_t kBlockSize = FmhaPipeline::kBlockSize;
static constexpr index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
static_assert(kBlockPerCu > 0);
......@@ -50,8 +52,7 @@ struct FmhaFwdSplitKVCombineKernel
return
_SS_("fmha_fwd_splitkv_combine_d") + _TS_(FmhaPipeline::kHeadDimV) + "_" + _SS_(t2s<ODataType>::name) +
"_" + (kIsGroupMode ? "group" : "batch") + "_"
"b" + _TS_(FmhaPipeline::kM0) + "x" +
_TS_(FmhaPipeline::kN1) + "_" +
"b" + _TS_(FmhaPipeline::kN1) + "_" +
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) +
_SS_(FmhaPipeline::name) +
(pn.empty() ? "" : "_" + pn) +
......@@ -339,37 +340,56 @@ struct FmhaFwdSplitKVCombineKernel
number<FmhaPipeline::kAlignmentOacc>{},
number<1>{});
// read 4 * (kM0, kN1) o_acc tiles simultaneously by 4 warps
const auto o_acc_dram_view = pad_tensor_view(
o_acc_dram_naive,
make_tuple(number<1>{}, number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
sequence<false, kPadSeqLenQ, kPadHeadDimV>{});
make_tuple(
number<kNumWarps>{}, number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
sequence<true, kPadSeqLenQ, kPadHeadDimV>{});
const index_t padded_num_splits =
o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<0>{}];
const index_t padded_seqlen_q =
o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<1>{}];
const index_t padded_hdim_v =
o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<2>{}];
return transform_tensor_view(
const index_t num_m_tiles = integer_divide_floor(padded_seqlen_q, FmhaPipeline::kM0);
// transform tensor view by following steps, given shape: (padded_num_splits,
// padded_seqlen_q, padded_hdim_v)
// 1. unmerge to (padded_num_splits, num_m_tiles, kM0, padded_hdim_v)
// 2. transpose to (num_m_tiles, padded_num_splits, kM0, padded_hdim_v)
// 3. merge to (num_m_tiles * padded_num_splits * kM0, padded_hdim_v)
auto transposed = transform_tensor_view(
o_acc_dram_view,
make_tuple(make_merge_transform(make_tuple(kargs.num_splits, padded_seqlen_q)),
make_tuple(make_pass_through_transform(padded_num_splits),
make_unmerge_transform(make_tuple(num_m_tiles, FmhaPipeline::kM0)),
make_pass_through_transform(padded_hdim_v)),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<1>{}, sequence<0, 2>{}, sequence<3>{}));
return transform_tensor_view(
transposed,
make_tuple(make_merge_transform(
make_tuple(num_m_tiles, padded_num_splits, FmhaPipeline::kM0)),
make_pass_through_transform(padded_hdim_v)),
make_tuple(sequence<0, 1, 2>{}, sequence<3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}();
auto lse_acc_dram_window = make_tile_window(
lse_acc_dram,
[&]() {
return make_tuple(number<FmhaPipeline::kMaxSplits>{}, number<FmhaPipeline::kM0>{});
}(),
make_tuple(number<FmhaPipeline::kMaxSplits>{}, number<FmhaPipeline::kM0>{}),
{0, i_m0});
const index_t padded_num_splits =
integer_divide_ceil(kargs.num_splits, kNumWarps) * kNumWarps;
auto o_acc_dram_window = make_tile_window(
o_acc_dram,
[&]() {
return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{});
}(),
{i_m0, i_n1});
make_tuple(number<kNumWarps * FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
{i_tile_m * padded_num_splits * FmhaPipeline::kM0, i_n1});
// LSE DRAM window
auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
......@@ -410,7 +430,6 @@ struct FmhaFwdSplitKVCombineKernel
identity{}, // lse_element_func
composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
kargs.num_splits,
kargs.seqlen_q,
smem_ptr);
}
else
......@@ -419,7 +438,6 @@ struct FmhaFwdSplitKVCombineKernel
o_acc_dram_window,
lse_dram_window,
kargs.num_splits,
kargs.seqlen_q,
smem_ptr);
}
}();
......
......@@ -45,6 +45,7 @@ struct FmhaFwdSplitKVKernel
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV;
......@@ -67,7 +68,8 @@ struct FmhaFwdSplitKVKernel
using bfs = typename FmhaPipeline::BlockFmhaShape;
using g0br = typename bfs::Gemm0BlockWarps;
using g1br = typename bfs::Gemm1BlockWarps;
using gwt = typename bfs::Gemm0WarpTile;
using g0wt = typename bfs::Gemm0WarpTile;
using g1wt = typename bfs::Gemm1WarpTile;
#define _SS_ std::string
#define _TS_ std::to_string
auto pn = [&] () {
......@@ -84,11 +86,12 @@ struct FmhaFwdSplitKVKernel
_TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" +
"r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
"r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
"w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
"w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" +
"w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" +
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kDoFp8StaticQuant ? "_squant" : "") + (kIsPagedKV ? "_pagedkv" : "" );
(kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kDoFp8StaticQuant ? "_squant" : "") + (kIsPagedKV ? "_pagedkv" : "" );
#undef _SS_
#undef _TS_
// clang-format on
......
......@@ -53,6 +53,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
static constexpr index_t kNumWarps = Problem::kNumWarps;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kHeadDimV = Problem::kHeadDimV;
......@@ -117,7 +118,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
const LSEElementFunction& lse_element_func,
const OaccElementFunction& o_acc_element_func,
index_t num_splits,
index_t seqlen_q,
void* smem_ptr) const
{
// lse_acc tile in LDS
......@@ -143,11 +143,12 @@ struct BlockFmhaFwdSplitKVCombinePipeline
// copy lse_acc tile (shape=[kMaxSplits, kM0]) to LDS (shape=[kMaxSplits, kM0]).
auto lse_acc_tile = load_tile(lse_acc_dram_window);
store_tile(lse_acc_lds_write_window, lse_acc_tile);
block_sync_lds();
auto lse_accum = make_static_distributed_tensor<LSEDataType>(
Policy::template MakeLSEaccRegTileDistribution<Problem>());
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
// copy LDS (shape=[kM0, kMaxSplits]) to lse_accum (shape=[kM0, kMaxSplits])
// and fill up -INF values outside the [kM0, num_splits] region.
{
......@@ -264,46 +265,94 @@ struct BlockFmhaFwdSplitKVCombinePipeline
}
});
}
block_sync_lds();
if constexpr(kStoreLSE)
{
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse_logsum));
}
auto o_acc_dist = Policy::template MakeOaccDramTileDistribution<Problem>();
auto o_acc_dram_window =
auto o_acc_4_dist = Policy::template MakeOacc4DramTileDistribution<Problem>();
auto o_acc_4_dram_window =
make_tile_window(o_acc_dram_block_window_tmp.get_bottom_tensor_view(),
o_acc_dram_block_window_tmp.get_window_lengths(),
o_acc_dram_block_window_tmp.get_window_origin(),
o_acc_dist);
auto o_acc = make_static_distributed_tensor<OaccDataType>(o_acc_dist);
clear_tile(o_acc);
o_acc_4_dist);
const index_t padded_seqlen_q = integer_divide_ceil(seqlen_q, kM0) * kM0;
// shape=[4 * KM0, kN1]
auto o_acc_4 = make_static_distributed_tensor<OaccDataType>(o_acc_4_dist);
clear_tile(o_acc_4);
for(index_t i_split = 0; i_split < num_splits; ++i_split)
const index_t padded_num_splits = integer_divide_ceil(num_splits, kNumWarps) * kNumWarps;
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
// each warp handles a [KM0, kN1] tile
for(index_t split_start = 0; split_start < padded_num_splits; split_start += kNumWarps)
{
auto o_tile = load_tile(o_acc_dram_window);
auto o_tile = load_tile(o_acc_4_dram_window);
const index_t i_split = split_start + get_warp_id();
const index_t row_start = kM0 * get_warp_id();
{
constexpr auto spans = decltype(o_acc)::get_distributed_spans();
constexpr auto spans = decltype(o_acc_4)::get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
const auto x_indices = get_x_indices_from_distributed_indices(
o_acc.get_tile_distribution(), i_j_idx);
o_acc_4.get_tile_distribution(), i_j_idx);
const auto row = x_indices.at(number<0>{});
const LSEDataType lse_scale = lse_acc_lds(row, i_split);
o_acc(i_j_idx) += lse_scale * o_tile(i_j_idx);
const LSEDataType lse_scale = lse_acc_lds(row - row_start, i_split);
o_acc_4(i_j_idx) += lse_scale * o_tile(i_j_idx);
});
});
}
move_tile_window(o_acc_dram_window, {padded_seqlen_q, 0});
move_tile_window(o_acc_4_dram_window, {kNumWarps * kM0, 0});
}
// 4 o_acc tiles in LDS. shape=[4 * kM0, kN1]
OaccDataType* o_acc_4_lds_ptr = static_cast<OaccDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeLSEacc<Problem>()));
{
auto o_acc_4_lds_window = [&]() {
auto desc = Policy::template MakeOacc4LdsBlockDescriptor<Problem>();
auto view = make_tensor_view<address_space_enum::lds>(o_acc_4_lds_ptr, desc);
return make_tile_window(view, desc.get_lengths(), {0, 0});
}();
store_tile(o_acc_4_lds_window, o_acc_4);
}
auto o_acc_dist = Policy::template MakeOaccDramTileDistribution<Problem>();
auto o_acc_4_lds_window = [&]() {
auto desc = Policy::template MakeOacc4LdsBlockDescriptor<Problem>();
auto view = make_tensor_view<address_space_enum::lds>(o_acc_4_lds_ptr, desc);
return make_tile_window(view, desc.get_lengths(), {0, 0}, o_acc_dist);
}();
auto o_acc = make_static_distributed_tensor<OaccDataType>(o_acc_dist);
clear_tile(o_acc);
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
static_for<0, kNumWarps, 1>{}([&](auto) {
auto o_acc_in = load_tile(o_acc_4_lds_window);
{
constexpr auto spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
o_acc(i_j_idx) += o_acc_in(i_j_idx);
});
});
}
move_tile_window(o_acc_4_lds_window, {kM0, 0});
});
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
return o_acc;
......@@ -316,7 +365,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
const OaccDramBlockWindow& o_acc_dram_block_window,
LSEDramBlockWindow& lse_dram_block_window,
index_t num_splits,
index_t seqlen_q,
void* smem_ptr) const
{
return operator()(lse_acc_dram_block_window,
......@@ -325,7 +373,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
identity{},
identity{},
num_splits,
seqlen_q,
smem_ptr);
}
};
......
......@@ -10,23 +10,38 @@ namespace ck_tile {
struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
{
template <index_t BlockSize, index_t M, index_t N, typename DataType>
template <index_t NumWarps, index_t M, index_t N, typename DataType>
CK_TILE_HOST_DEVICE static constexpr auto GetMaxNumWarpsForTile()
{
static_assert(NumWarps == 1 || NumWarps == 2 || NumWarps == 4);
constexpr index_t ElemPerThread = (M * N) / (NumWarps * get_warp_size());
if constexpr(0 < ElemPerThread)
{
return NumWarps;
}
else
{ // try dividing tile by smaller # of warps
return GetMaxNumWarpsForTile<NumWarps / 2, M, N, DataType>();
}
}
template <index_t NumWarps, index_t M, index_t N, typename DataType>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeForTile()
{
constexpr index_t PixelsPerThread = (M * N) / BlockSize;
static_assert(0 < PixelsPerThread);
constexpr index_t MaxNumWarps = GetMaxNumWarpsForTile<NumWarps, M, N, DataType>();
constexpr index_t MaxNPerThread = 16 / sizeof(DataType);
constexpr index_t NPerThread = min(MaxNPerThread, PixelsPerThread);
constexpr index_t ElemPerThread = (M * N) / (MaxNumWarps * get_warp_size());
return NPerThread;
constexpr index_t MaxNPerThread = 16 / sizeof(DataType);
return min(MaxNPerThread, ElemPerThread);
}
// alignment for dram lse tile (shape=[kMaxSplits, kM0])
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentLSE()
{
return GetVectorSizeForTile<Problem::kBlockSize,
return GetVectorSizeForTile<Problem::kNumWarps,
Problem::kMaxSplits,
Problem::kM0,
typename Problem::LSEDataType>();
......@@ -56,40 +71,54 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeLSEacc()
{
return sizeof(typename Problem::LSEDataType) *
MakeLSEaccLdsBlockDescriptor<Problem>().get_element_space_size();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeOacc4()
{
return sizeof(typename Problem::OaccDataType) *
MakeOacc4LdsBlockDescriptor<Problem>().get_element_space_size();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return GetSmemSizeLSEacc<Problem>() + GetSmemSizeOacc4<Problem>();
}
// shape=[kMaxSplits, kM0]
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccDramTileDistribution()
{
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNumWarps = Problem::kNumWarps;
constexpr index_t kNPerBlock = Problem::kM0;
constexpr index_t kMPerBlock = Problem::kMaxSplits;
constexpr index_t kNPerBlock = Problem::kM0;
constexpr index_t MaxNumWarps =
GetMaxNumWarpsForTile<Problem::kNumWarps, kNPerBlock, kMPerBlock, LSEDataType>();
constexpr index_t Replicate = Problem::kNumWarps / MaxNumWarps;
constexpr index_t NPerThread =
GetVectorSizeForTile<kBlockSize, kMPerBlock, kNPerBlock, LSEDataType>();
GetVectorSizeForTile<MaxNumWarps, kMPerBlock, kNPerBlock, LSEDataType>();
constexpr index_t NThreads = kNPerBlock / NPerThread;
constexpr index_t MThreadsPerWarp = get_warp_size() / NThreads;
constexpr index_t MPerThread = kMPerBlock / (kNumWarps * MThreadsPerWarp);
constexpr index_t MPerThread = kMPerBlock / (MaxNumWarps * MThreadsPerWarp);
static_assert(MPerThread * MaxNumWarps * MThreadsPerWarp == kMPerBlock);
static_assert(NThreads * NPerThread == kNPerBlock);
static_assert(MPerThread * kNumWarps * MThreadsPerWarp == kMPerBlock);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<MPerThread, kNumWarps, MThreadsPerWarp>,
tile_distribution_encoding<sequence<Replicate>,
tuple<sequence<MPerThread, MaxNumWarps, MThreadsPerWarp>,
sequence<NThreads, NPerThread>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
tuple<sequence<0, 1>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
......@@ -100,17 +129,15 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
{
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::kMaxSplits;
constexpr index_t kNPerBlock = Problem::kM0;
constexpr index_t kMPerBlock = Problem::kM0;
constexpr index_t kNPerBlock = Problem::kMaxSplits;
constexpr index_t NPack =
GetVectorSizeForTile<kBlockSize, kMPerBlock, kNPerBlock, LSEDataType>();
GetVectorSizeForTile<Problem::kNumWarps, kMPerBlock, kNPerBlock, LSEDataType>();
constexpr auto lse_acc_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kNPerBlock / NPack>{}, number<kMPerBlock>{}, number<NPack>{}),
make_tuple(number<(kMPerBlock + 1) * NPack>{}, number<NPack>{}, number<1>{}),
number<8>{},
number<NPack>{},
number<1>{});
constexpr auto lse_acc_lds_block_desc = transform_tensor_descriptor(
......@@ -129,17 +156,15 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
{
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::kMaxSplits;
constexpr index_t kNPerBlock = Problem::kM0;
constexpr index_t kMPerBlock = Problem::kM0;
constexpr index_t kNPerBlock = Problem::kMaxSplits;
constexpr index_t NPack =
GetVectorSizeForTile<kBlockSize, kMPerBlock, kNPerBlock, LSEDataType>();
GetVectorSizeForTile<Problem::kNumWarps, kMPerBlock, kNPerBlock, LSEDataType>();
constexpr auto lse_acc_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kNPerBlock / NPack>{}, number<kMPerBlock>{}, number<NPack>{}),
make_tuple(number<(kMPerBlock + 1) * NPack>{}, number<NPack>{}, number<1>{}),
number<8>{},
number<NPack>{},
number<1>{});
constexpr auto lse_acc_t_lds_block_desc = transform_tensor_descriptor(
......@@ -152,33 +177,86 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
return lse_acc_t_lds_block_desc;
}
// 3d + padding, shape=[4 * kM0, kN1]
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccRegTileDistribution()
CK_TILE_HOST_DEVICE static constexpr auto MakeOacc4LdsBlockDescriptor()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
constexpr index_t kNPerBlock = Problem::kMaxSplits;
constexpr index_t kMPerBlock = 4 * Problem::kM0;
constexpr index_t kNPerBlock = Problem::kN1;
constexpr index_t NPack =
GetVectorSizeForTile<Problem::kNumWarps, kMPerBlock, kNPerBlock, LSEDataType>();
constexpr auto o_acc_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kNPerBlock / NPack>{}, number<kMPerBlock>{}, number<NPack>{}),
make_tuple(number<(kMPerBlock + 1) * NPack>{}, number<NPack>{}, number<1>{}),
number<8>{},
number<1>{});
constexpr auto o_acc_t_lds_block_desc = transform_tensor_descriptor(
o_acc_lds_block_desc_0,
make_tuple(make_pass_through_transform(kMPerBlock),
make_merge_transform(make_tuple(kNPerBlock / NPack, NPack))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
return o_acc_t_lds_block_desc;
}
// shape=[kM0, kMaxSplits]
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccRegTileDistribution()
{
constexpr index_t kMPerBlock = Problem::kM0;
constexpr index_t kNPerBlock = Problem::kMaxSplits;
constexpr index_t NThreads = 4;
constexpr index_t NPerThread = kNPerBlock / NThreads;
constexpr index_t MaxNThreads = 8;
constexpr index_t NThreads = min(kNPerBlock, MaxNThreads);
constexpr index_t NPerThread = kNPerBlock / NThreads;
constexpr index_t MThreads = kBlockSize / NThreads;
constexpr index_t MPerThread = kMPerBlock / MThreads;
constexpr index_t MWarps = kBlockSize / get_warp_size();
constexpr index_t MPerThread = 1;
constexpr index_t MThreads = kMPerBlock / MPerThread;
constexpr index_t MThreadPerWarp = get_warp_size() / NThreads;
constexpr index_t MaxNumWarps = (MThreads * NThreads) / get_warp_size();
constexpr index_t Replicate = Problem::kNumWarps / MaxNumWarps;
static_assert(MaxNumWarps * MThreadPerWarp * MPerThread == kMPerBlock);
static_assert(NThreads * NPerThread == kNPerBlock);
static_assert(MWarps * MThreadPerWarp * MPerThread == kMPerBlock);
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,
tuple<sequence<MWarps, MThreadPerWarp, MPerThread>, sequence<NThreads, NPerThread>>,
tuple<sequence<1>, sequence<2, 1>>,
tuple<sequence<0>, sequence<0, 1>>,
sequence<1, 2>,
sequence<2, 1>>{});
tile_distribution_encoding<sequence<Replicate>,
tuple<sequence<MaxNumWarps, MThreadPerWarp, MPerThread>,
sequence<NThreads, NPerThread>>,
tuple<sequence<0, 1>, sequence<2, 1>>,
tuple<sequence<0, 0>, sequence<0, 1>>,
sequence<1, 2>,
sequence<2, 1>>{});
}
// similar to MakeOaccDramTileDistribution(), but duplicate same 1-warp encoding 4 times on M
// direction
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeOacc4DramTileDistribution()
{
constexpr index_t kMPerBlock = Problem::kM0; // real kMPerBlock we want is (4 * kM0)
constexpr index_t kNPerBlock = Problem::kN1;
static_assert(get_warp_size() <= kMPerBlock * kNPerBlock);
constexpr index_t M1 = 1; // compose encoding base on 1 warp
constexpr index_t M2 = min(kMPerBlock / M1, get_warp_size());
constexpr index_t N0 = get_warp_size() / M2;
constexpr index_t N1 = kNPerBlock / N0;
constexpr index_t M0 = kMPerBlock / (M2 * M1);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<4, M0, M1, M2>, sequence<N0, N1>>,
tuple<sequence<1, 1>, sequence<1, 2>>,
tuple<sequence<0, 2>, sequence<3, 0>>,
sequence<1, 2>,
sequence<1, 1>>{});
}
template <typename Problem>
......@@ -187,6 +265,7 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::kM0;
constexpr index_t kNPerBlock = Problem::kN1;
static_assert(kBlockSize <= kMPerBlock * kNPerBlock);
constexpr index_t M1 = kBlockSize / get_warp_size();
constexpr index_t M2 = min(kMPerBlock / M1, get_warp_size());
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
// This pipeline is qkv all located in LDS
template <typename Problem_,
typename Policy_ = BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVSDefaultPolicy>
struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using QDataType = remove_cvref_t<typename Problem::QDataType>;
using KDataType = remove_cvref_t<typename Problem::KDataType>;
using VDataType = remove_cvref_t<typename Problem::VDataType>;
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
using PDataType = remove_cvref_t<typename Problem::PDataType>;
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once
static_assert(kQLoadOnce == Policy::QLoadOnce);
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = BlockFmhaShape::kM0;
static constexpr index_t kN0 = BlockFmhaShape::kN0;
static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kN1 = BlockFmhaShape::kN1;
static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
static constexpr index_t kAlignmentQ =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
static constexpr index_t kAlignmentK =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV = []() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
else
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
}();
static constexpr index_t kAlignmentOacc =
kPadHeadDimV ? 1 : Policy::template GetAlignmentOacc<Problem>();
static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
static constexpr index_t kBlockPerCu = []() {
if constexpr(Problem::kBlockPerCu != -1)
return Problem::kBlockPerCu;
else
{
if constexpr(kQKHeaddim <= 32)
{
return 2;
}
else if constexpr(kQKHeaddim <= 64)
{
return 3;
}
else if constexpr(kQKHeaddim <= 128)
{
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
return 1;
else
return 2;
}
else if constexpr(kQKHeaddim <= 256)
{
return 1;
}
}
}();
static constexpr const char* name = "qr_nwarp_sshuffle";
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowLengths,
typename KPageBlockNavigator,
typename VDramBlockWindowLengths,
typename VPageBlockNavigator,
typename BiasDramBlockWindowTmp,
typename LSEaccDramBlockWindowTmp,
typename QElementFunction,
typename KElementFunction,
typename VElementFunction,
typename BiasElementFunction,
typename LSEaccElementFunction,
typename SAccElementFunction,
typename PComputeElementFunction,
typename OAccElementFunction,
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const QElementFunction& q_element_func,
const KDramBlockWindowLengths& k_dram_block_window_lengths, // N0*K0 tile
const KPageBlockNavigator& k_page_block_navigator,
const KElementFunction& k_element_func,
const VDramBlockWindowLengths& v_dram_block_window_lengths, // N1*K1 tile
const VPageBlockNavigator& v_page_block_navigator,
const VElementFunction& v_element_func,
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
const BiasElementFunction& bias_element_func,
LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile
const LSEaccElementFunction& lse_acc_element_func,
const SAccElementFunction& s_acc_element_func,
const PComputeElementFunction& p_compute_element_func,
const OAccElementFunction& o_acc_element_func,
index_t num_splits,
index_t i_split,
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
void* smem_ptr) const
{
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
std::is_same_v<KDataType, remove_cvref_t<typename KPageBlockNavigator::DataType>> &&
std::is_same_v<VDataType, remove_cvref_t<typename VPageBlockNavigator::DataType>>,
"wrong!");
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kSubQKHeaddim ==
QDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kN0 == KDramBlockWindowLengths{}[number<0>{}] &&
kK0 == KDramBlockWindowLengths{}[number<1>{}] &&
kN1 == VDramBlockWindowLengths{}[number<0>{}] &&
kK1 == VDramBlockWindowLengths{}[number<1>{}] &&
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
// Q tile in LDS
QDataType* q_lds_ptr =
static_cast<QDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
auto q_lds = make_tensor_view<address_space_enum::lds>(
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
// K tile in LDS
KDataType* k_lds_ptr =
static_cast<KDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
auto k_lds = make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>());
auto k_lds_window =
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
// V tile in LDS
auto v_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<VDataType*>(static_cast<char*>(smem_ptr) +
max(Policy::template GetSmemSizeQ<Problem>(),
Policy::template GetSmemSizeK<Problem>())),
Policy::template MakeVLdsBlockDescriptor<Problem>());
auto v_lds_window = make_tile_window(
v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
// S tile in LDS
auto s_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<SaccDataType*>(reinterpret_cast<char*>(smem_ptr) +
max(Policy::template GetSmemSizeQ<Problem>(),
Policy::template GetSmemSizeK<Problem>())),
Policy::template MakeSLdsBlockDescriptor<Problem>());
auto s_write_lds_window = make_tile_window(
s_lds, Policy::template MakeSLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
auto s_read_lds_window =
make_tile_window(s_lds,
Policy::template MakeSLdsBlockDescriptor<Problem>().get_lengths(),
{0, 0},
Policy::template MakeSRegTileDistribution<Problem>());
// Block GEMM
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
auto q_dram_window =
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
q_dram_block_window_tmp.get_window_lengths(),
q_dram_block_window_tmp.get_window_origin(),
Policy::template MakeQDramTileDistribution<Problem>());
// load Q here, will store Q into LDS to maximize throughput
auto origin_q = load_tile(q_dram_window);
using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
auto s_acc = SaccBlockTileType{};
// reduction function for softmax
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
auto o_acc = OaccBlockTileType{};
// infer Sacc, S, P, M, L, Oacc type
using SBlockTileType = decltype(cast_tile<SMPLComputeDataType>(o_acc));
using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0}));
// init M, L
auto m = MLBlockTileType{};
auto l = MLBlockTileType{};
clear_tile(o_acc);
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
const auto q_origin = q_dram_window.get_window_origin();
const auto [logical_seqlen_k_start, logical_seqlen_k_end] = mask.GetTileRangeAlongX(
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
// check early exit if no work to do
if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits)
{
const index_t logical_num_total_loop =
integer_divide_ceil(logical_seqlen_k_end - logical_seqlen_k_start, kN0);
if(logical_num_total_loop <= 0)
{
if constexpr(kStoreLSE)
{
auto lse_acc =
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
set_tile(lse_acc, -numeric<SMPLComputeDataType>::infinity());
if(get_thread_local_1d_id() < kM0)
{
store_tile(lse_acc_dram_window_tmp,
tile_elementwise_in(lse_acc_element_func, lse_acc));
}
}
// Note: here occ are all cleard, return it
// Note: q loaded but no fence, ignore it.
return o_acc;
}
}
const index_t physical_seqlen_k_start = logical_seqlen_k_start + kv_l2p_offset;
const index_t physical_seqlen_k_end = logical_seqlen_k_end + kv_l2p_offset;
// make sure the first tile is completely located in page-block (page-block size should be
// divisible by kN0)
// relationship between each *_start variables: aligned_physical_seqlen_k_start <=
// physical_seqlen_k_start, logical_seqlen_k_start <= physical_seqlen_k_start
const index_t aligned_physical_seqlen_k_start =
[&, physical_seqlen_k_start_ = physical_seqlen_k_start] {
if constexpr(kIsPagedKV)
{
return kN0 * integer_divide_floor(physical_seqlen_k_start_, kN0);
}
else
{
return physical_seqlen_k_start_;
}
}();
const index_t num_total_loop =
integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0);
auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window(
k_dram_block_window_lengths, {aligned_physical_seqlen_k_start, 0});
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_window =
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(),
{bias_origin.at(number<0>{}),
logical_seqlen_k_start - (physical_seqlen_k_start -
aligned_physical_seqlen_k_start)}, // M/N
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
v_dram_block_window_lengths,
{0, aligned_physical_seqlen_k_start}, // TODO: hdim split?
Policy::template MakeVDramTileDistribution<Problem>());
// store Q into LDS
__builtin_amdgcn_sched_barrier(0);
auto q_lds_window_for_store = make_tile_window(
q_lds, Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
store_tile(q_lds_window_for_store, origin_q);
__builtin_amdgcn_sched_barrier(0);
// load Q from LDS
__builtin_amdgcn_sched_barrier(0);
auto q_lds_window_for_load = make_tile_window(
q_lds,
Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(),
{0, 0},
Policy::template MakeQRegTileDistribution<Problem, decltype(gemm_0)>());
block_sync_lds();
auto q = load_tile(q_lds_window_for_load);
__builtin_amdgcn_sched_barrier(0);
auto q_tile = tile_elementwise_in(q_element_func, q);
// prefetch K tile
index_t i_total_loops = 0;
constexpr index_t k0_loops = kQKHeaddim / kK0;
constexpr index_t k1_loops = kN0 / kK1;
static_assert(2 <= k0_loops);
static_assert(1 <= k1_loops);
auto k_dram_window = make_tile_window(
k_dram_block_window,
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
// load the first tile of the first iteration and store to LDS
auto k_block_tile = load_tile(k_dram_window);
// moving k_dram_window is an in-page-block operation, so there is
// no need to invoke k_page_block_navigator.move_tile_window() here.
move_tile_window(k_dram_window, {0, kK0});
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
do
{
// STAGE 1, QK gemm
clear_tile(s_acc); // initialize C
// load the second tile of the first iteration
k_block_tile = load_tile(k_dram_window);
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
__builtin_amdgcn_sched_barrier(
0); // prevent from messing up the order of global loads
}
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
__builtin_amdgcn_sched_barrier(
0); // prevent from messing up the order of global loads
}
if constexpr(k0_loops > 2)
{
static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) {
block_sync_lds();
gemm_0(s_acc,
get_slice_tile(q_tile,
sequence<0, i_k0 * kK0>{},
sequence<kM0, (i_k0 + 1) * kK0>{}),
k_lds_window);
block_sync_lds();
move_tile_window(k_dram_window, {0, kK0});
store_tile(
k_lds_window,
tile_elementwise_in(k_element_func, k_block_tile)); // LDS write i + 1
k_block_tile = load_tile(k_dram_window); // global read i + 2
});
}
const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
{ // tail
block_sync_lds();
gemm_0(s_acc,
get_slice_tile(q_tile,
sequence<0, (k0_loops - 2) * kK0>{},
sequence<kM0, (k0_loops - 1) * kK0>{}),
k_lds_window);
block_sync_lds();
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
block_sync_lds();
gemm_0(s_acc,
get_slice_tile(q_tile,
sequence<0, (k0_loops - 1) * kK0>{},
sequence<kM0, k0_loops * kK0>{}),
k_lds_window);
}
// STAGE 2, scale_s, add bias, mask, softmax
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
tile_elementwise_inout(
[&](auto& x, const auto& y) {
#if !CK_TILE_FMHA_FWD_FAST_EXP2
x += type_convert<SaccDataType>(bias_element_func(y));
#else
x += log2e_v<SaccDataType> *
type_convert<SaccDataType>(bias_element_func(y));
#endif
},
s_acc,
bias_tile);
}
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
const auto k_origin = k_page_block_navigator.to_global_window_origin(
i_page_block_k, k_dram_block_window.get_window_origin());
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
s_acc(i_j_idx) *= scale_s;
// position_encoding accept only logical coordinates, do conversion here
position_encoding.update(s_acc(i_j_idx), row, col - kv_l2p_offset);
});
});
}
else
{
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
#endif
}
move_tile_window(bias_dram_window, {0, kN0});
/// TODO: only check in first/last iteration without increasing code size
if constexpr(kHasUnevenSplits)
{
const auto k_origin = k_page_block_navigator.to_global_window_origin(
i_page_block_k, k_dram_block_window.get_window_origin());
set_tile_if(
s_acc,
-numeric<SMPLComputeDataType>::infinity(),
[&,
physical_seqlen_k_start_ = physical_seqlen_k_start,
physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) {
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
if constexpr(kIsPagedKV)
{
return col < physical_seqlen_k_start_ || physical_seqlen_k_end_ <= col;
}
else
{
return physical_seqlen_k_end_ <= col;
}
});
}
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{
const auto k_origin = k_page_block_navigator.to_global_window_origin(
i_page_block_k, k_dram_block_window.get_window_origin());
// mask accept only logical coordinates, do conversion here
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
k_origin.at(number<0>{}) - kv_l2p_offset,
number<kM0>{},
number<kN0>{});
if(need_perpixel_check)
{
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col - kv_l2p_offset);
});
}
}
__builtin_amdgcn_sched_barrier(0);
// load the first tile for next iteration
if(i_total_loops < num_total_loop - 1)
{
// move K tile windows
i_page_block_k = k_page_block_navigator.move_tile_window(
i_page_block_k, k_dram_block_window, {kN0, 0});
k_dram_window = make_tile_window(
k_dram_block_window,
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window
// laod the first tile of the first iteration and store to LDS
k_block_tile = load_tile(k_dram_window);
}
__builtin_amdgcn_sched_barrier(0);
const auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
// shuffle through LDS so that the tile layout is consistent with required by Gemm1
store_tile(s_write_lds_window, s);
block_sync_lds();
auto s_new = load_tile(s_read_lds_window);
auto m_local = block_tile_reduce<SMPLComputeDataType>(
s_new,
sequence<1>{},
f_max,
-numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
const auto m_old = m; // m{j-1}
tile_elementwise_inout(
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j}
auto p_compute = make_static_distributed_tensor<SMPLComputeDataType>(
s_new.get_tile_distribution()); // Pcompute{j}
static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
/// NOTICE: bias might be materialized mask including -inf values, need
/// consideration
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
FmhaMask::IsMasking)
{
return raw_m == -numeric<SMPLComputeDataType>::infinity()
? type_convert<SMPLComputeDataType>(0.f)
: raw_m;
}
else
{
return raw_m;
}
};
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
auto row_max = scale_s * get_validated_m(m[i_idx]);
#endif
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
p_compute(i_j_idx) = exp2(s_new[i_j_idx] - get_validated_m(m[i_idx]));
}
else
{
p_compute(i_j_idx) = exp2(scale_s * s_new[i_j_idx] - row_max);
}
#else
p_compute(i_j_idx) = exp(s_new[i_j_idx] - get_validated_m(m[i_idx]));
#endif
});
});
auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
const auto p =
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
// l{j}, Oacc{j}
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
const auto tmp = [&]() {
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
}
else
{
auto row_max = scale_s * get_validated_m(m[i_idx]);
return exp2(scale_s * m_old[i_idx] - row_max);
}
}();
#else
const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx]));
#endif
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
// FIXME: this use different equation from FA v2 paper,
// but produce correc result.
// Is the equation wrong?
o_acc(i_j_idx) *= tmp;
});
});
block_sync_lds();
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, v_prefetch);
store_tile(
v_lds_window,
tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
}
else
{
store_tile(v_lds_window,
tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch
}
i_page_block_v =
v_page_block_navigator.move_tile_window(i_page_block_v, v_dram_window, {0, kK1});
// STAGE 3, KV gemm
if constexpr(k1_loops > 1)
{
static_for<0, k1_loops - 1, 1>{}([&,
&i_page_block_v_ = i_page_block_v,
&v_dram_window_ = v_dram_window](auto i_k1) {
const auto v = load_tile(v_dram_window_); // load next v
block_sync_lds();
gemm_1(o_acc,
get_slice_tile(
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
v_lds_window);
block_sync_lds();
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, v);
store_tile(v_lds_window,
tile_elementwise_in(v_element_func,
v_shuffle_tmp)); // store the prefetch
}
else
{
store_tile(v_lds_window,
tile_elementwise_in(v_element_func, v)); // store next v
}
i_page_block_v_ = v_page_block_navigator.move_tile_window(
i_page_block_v_, v_dram_window_, {0, kK1});
});
}
// tail
{
block_sync_lds();
gemm_1(o_acc,
get_slice_tile(
p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, k1_loops * kK1>{}),
v_lds_window);
block_sync_lds();
}
__builtin_amdgcn_sched_barrier(0);
// load the first tile for next iteration
if(i_total_loops < num_total_loop - 1)
{
// store the first tile for next iteration to LDS
// moving k_dram_window is an in-page-block operation, so there is
// no need to invoke k_page_block_navigator.move_tile_window() here.
move_tile_window(k_dram_window, {0, kK0});
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
}
} while(++i_total_loops < num_total_loop);
if constexpr(kStoreLSE)
{
// store lse acc
auto lse_acc = make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
constexpr auto lse_acc_spans = decltype(lse_acc)::get_distributed_spans();
sweep_tile_span(lse_acc_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
lse_acc(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
}
else
{
lse_acc(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]);
}
#else
lse_acc(i_idx) = m_[i_idx] + log(l_[i_idx]);
#endif
});
if(get_thread_local_1d_id() < kM0)
{
store_tile(lse_acc_dram_window_tmp,
tile_elementwise_in(lse_acc_element_func, lse_acc));
}
}
// finally, O
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
const auto tmp = [&]() {
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
FmhaMask::IsMasking)
{
return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
}
else
return 1 / l[i_idx];
}();
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
o_acc(i_j_idx) *= tmp;
});
});
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
return o_acc;
}
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowLengths,
typename KPageBlockNavigator,
typename VDramBlockWindowLengths,
typename VPageBlockNavigator,
typename BiasDramBlockWindowTmp,
typename LSEaccDramBlockWindowTmp,
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowLengths& k_dram_block_window_lengths, // N0*K0 tile
const KPageBlockNavigator& k_page_block_navigator,
const VDramBlockWindowLengths& v_dram_block_window_lengths, // N1*K1 tile
const VPageBlockNavigator& v_page_block_navigator,
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // M0*1 tile
index_t num_splits,
index_t i_split,
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
void* smem_ptr) const
{
return operator()(q_dram_block_window_tmp,
identity{},
k_dram_block_window_lengths,
k_page_block_navigator,
identity{},
v_dram_block_window_lengths,
v_page_block_navigator,
identity{},
bias_dram_block_window_tmp,
identity{},
lse_acc_dram_block_window_tmp,
identity{},
identity{},
identity{},
identity{},
num_splits,
i_split,
mask,
position_encoding,
scale_s,
kv_l2p_offset,
smem_ptr);
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp"
namespace ck_tile {
// This pipeline is qkv all located in LDS
struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVSDefaultPolicy
: BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
/* AsyncCopyK = */ false,
/* AsyncCopyV = */ false,
/* NumPrefetchK = */ 1,
/* NumPrefetchV = */ 1>
{
using BasePolicy = BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
/* AsyncCopyK = */ false,
/* AsyncCopyV = */ false,
/* NumPrefetchK = */ 1,
/* NumPrefetchV = */ 1>;
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType);
// this should align with MakeQDramTileDistribution()
constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
static_assert(0 < ElemPerThread);
return min(ElemPerThread, MaxVectorSize);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOacc()
{
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
return static_cast<index_t>(16 / sizeof(OaccDataType));
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType);
constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
static_assert(0 < ElemPerThread);
constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
constexpr index_t KPerThread = kMaxVecLoad;
constexpr index_t KThreads = kKPerBlock / KPerThread;
constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<MPerThread, NumWarps, MThreadPerWarp>,
sequence<KThreads, KPerThread>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
template <typename Problem, typename BlockGemm>
CK_TILE_HOST_DEVICE static constexpr auto MakeQRegTileDistribution()
{
return BasePolicy::template MakeQDramTileDistribution<Problem, BlockGemm>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ()
{
// TODO: this is for 3d layout
using QDataType = remove_cvref_t<typename Problem::QDataType>;
return static_cast<index_t>(16 / sizeof(QDataType));
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
static_assert(0 < ElemPerThread);
constexpr index_t kKPack = min(ElemPerThread, GetSmemKPackQ<Problem>());
constexpr auto q_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / kKPack>{}, number<kMPerBlock>{}, number<kKPack>{}),
make_tuple(number<(kMPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto q_lds_block_desc = transform_tensor_descriptor(
q_lds_block_desc_0,
make_tuple(
make_pass_through_transform(number<kMPerBlock>{}),
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return q_lds_block_desc;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemNPackS()
{
using SDataType = remove_cvref_t<typename Problem::SaccDataType>;
return static_cast<index_t>(16 / sizeof(SDataType));
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeSLdsBlockDescriptor()
{
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kNPack = GetSmemNPackS<Problem>();
constexpr auto s_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kNPerBlock / kNPack>{}, number<kMPerBlock>{}, number<kNPack>{}),
make_tuple(number<(kMPerBlock + 1) * kNPack>{}, number<kNPack>{}, number<1>{}),
number<kNPack>{},
number<1>{});
constexpr auto s_lds_block_desc = transform_tensor_descriptor(
s_lds_block_desc_0,
make_tuple(
make_pass_through_transform(number<kMPerBlock>{}),
make_merge_transform(make_tuple(number<kNPerBlock / kNPack>{}, number<kNPack>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return s_lds_block_desc;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeSRegTileDistribution()
{
using BlockGemm = remove_cvref_t<decltype(GetKVBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
static_assert(MWarp == 1, "Check failed!");
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t kTileK = Problem::BlockFmhaShape::kN0;
// K2 is equal to Impl::kABKPerLane * kKIterPerWarpGemm
constexpr index_t K3 = WG::kK / WG::WarpGemmAttribute::Impl::kABKLane;
constexpr index_t K2 = WG::WarpGemmAttribute::Impl::kABKLane;
constexpr index_t K1 = kKPerBlock / (K2 * K3);
constexpr index_t K0 = kTileK / kKPerBlock;
constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t M1 = MWarp;
constexpr index_t M0 = kMPerBlock / (M2 * M1);
constexpr auto s2_block_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2, K3>>,
tuple<sequence<1, 0>, sequence<2, 1>>,
tuple<sequence<1, 0>, sequence<2, 2>>,
sequence<1, 2, 2, 2>,
sequence<0, 0, 1, 3>>{};
constexpr auto s2_block_dstr = make_static_tile_distribution(s2_block_dstr_encoding);
return s2_block_dstr;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQ()
{
return MakeQLdsBlockDescriptor<Problem>().get_element_space_size() *
sizeof(typename Problem::QDataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeK()
{
return MakeKLdsBlockDescriptor<Problem>().get_element_space_size() *
sizeof(typename Problem::KDataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeV()
{
return MakeVLdsBlockDescriptor<Problem>().get_element_space_size() *
sizeof(typename Problem::VDataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeS()
{
return MakeSLdsBlockDescriptor<Problem>().get_element_space_size() *
sizeof(typename Problem::SaccDataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return max(GetSmemSizeQ<Problem>(), GetSmemSizeK<Problem>()) +
max(GetSmemSizeV<Problem>(), GetSmemSizeS<Problem>());
}
};
} // namespace ck_tile
......@@ -106,28 +106,43 @@ struct BlockFmhaFwdSplitKVPipelineProblem
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
};
// extract tile size attributes to remove dependency on traits
template <typename OaccDataType_, ck_tile::index_t kN1_>
struct BlockFmhaSplitKVCombinePipelineTileSizes
{
static constexpr index_t MaxVectorSize = 16 / sizeof(OaccDataType_);
static constexpr index_t kN1 = kN1_;
static constexpr index_t NThreads = kN1 / MaxVectorSize;
static constexpr index_t kM0 = get_warp_size() / NThreads; // MThreadPerWarp
};
template <typename LSEDataType_,
typename OaccDataType_,
typename ODataType_,
index_t HeadDimV_,
index_t kM0_,
index_t kN1_,
bool kIsGroupMode_,
ck_tile::index_t kN1_,
typename Traits_>
struct BlockFmhaSplitKVCombinePipelineProblem
: BlockFmhaSplitKVCombinePipelineTileSizes<OaccDataType_, kN1_>
{
using BaseType = BlockFmhaSplitKVCombinePipelineTileSizes<OaccDataType_, kN1_>;
using LSEDataType = remove_cvref_t<LSEDataType_>;
using OaccDataType = remove_cvref_t<OaccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
using Traits = remove_cvref_t<Traits_>;
static constexpr index_t kNumWarps = kM0_ / (get_warp_size() / 4);
static constexpr index_t kBlockSize = kNumWarps * get_warp_size();
static constexpr bool kIsGroupMode = kIsGroupMode_;
static_assert(std::is_same_v<LSEDataType, OaccDataType>);
static constexpr index_t kHeadDimV = HeadDimV_;
static constexpr index_t kM0 = kM0_;
static constexpr index_t kN1 = kN1_;
static constexpr bool kIsGroupMode = kIsGroupMode_;
using BaseType::kM0;
using BaseType::kN1;
static_assert(kN1 <= kHeadDimV && kHeadDimV % kN1 == 0);
// attributes from traits
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
......@@ -136,6 +151,13 @@ struct BlockFmhaSplitKVCombinePipelineProblem
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
static constexpr index_t kMaxSplits = Traits::kMaxSplits;
static_assert(8 <= kMaxSplits);
static constexpr index_t kNumWarps = 4; // always use 4 warps for each workgroup
static constexpr index_t kBlockSize = kNumWarps * get_warp_size();
static_assert(get_warp_size() <= (kM0 * kMaxSplits) &&
(kM0 * kMaxSplits) % get_warp_size() == 0);
};
template <typename QDataType_,
......
......@@ -41,52 +41,21 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
{
constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType);
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
return WG::kK / WG::WarpGemmAttribute::Impl::kABKLane;
return min(MaxVectorSize, WG::kK / WG::WarpGemmAttribute::Impl::kABKLane);
}
template <typename Problem, typename BlockGemm>
CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution()
{
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
constexpr index_t K2 = WG::kK / WG::WarpGemmAttribute::Impl::kABKLane;
constexpr index_t K1 = WG::WarpGemmAttribute::Impl::kABKLane;
constexpr index_t K0 = kKPerBlock / (K1 * K2);
constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t M1 = MWarp;
constexpr index_t M0 = kMPerBlock / (M2 * M1);
if constexpr(1 < Problem::kNumGemm0Warps)
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2>>,
tuple<sequence<1>, sequence<2, 1>>,
tuple<sequence<1>, sequence<1, 2>>,
sequence<1, 2, 2>,
sequence<0, 0, 2>>{});
}
else
{
static_assert(MWarp == 1);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2>>,
tuple<sequence<2, 1>>,
tuple<sequence<1, 2>>,
sequence<1, 2, 2>,
sequence<0, 0, 2>>{});
}
return BlockGemm::template MakeABlockTileDistribution<
Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kSubQKHeaddim>();
}
template <typename Problem>
......@@ -105,7 +74,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
constexpr auto warp_gemm = []() {
constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
static_assert(WarpGemmM == 16 || WarpGemmM == 32);
static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
std::is_same_v<typename Problem::KDataType, half_t> &&
......@@ -113,8 +82,10 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
{
if constexpr(WarpGemmM == 32)
return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
else // WarpGemmM == 16
else if constexpr(WarpGemmM == 16)
return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{};
else // WarpGemmM == 4
return WarpGemmMfmaF16F16F32M4N64K16{};
}
else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> &&
std::is_same_v<typename Problem::KDataType, bf16_t> &&
......@@ -122,8 +93,10 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
{
if constexpr(WarpGemmM == 32)
return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
else // WarpGemmM == 16
else if constexpr(WarpGemmM == 16)
return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{};
else // WarpGemmM == 4
return WarpGemmMfmaBf16Bf16F32M4N64K16{};
}
else if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> &&
std::is_same_v<typename Problem::KDataType, fp8_t> &&
......
......@@ -43,8 +43,6 @@ struct TileFmhaShape
static constexpr index_t NumWarps = max(NumGemm0Warps, NumGemm1Warps);
static_assert(std::is_same_v<Gemm0WarpTile, Gemm1WarpTile>);
static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen
static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen
static constexpr index_t kK0 = BlockTile::at(number<2>{}); // tile size along qk gemm unroll
......
......@@ -130,7 +130,8 @@ struct MoeSortingKernel
CK_TILE_HOST static constexpr auto GetSmemSize(const Hargs& h)
{
const auto blocks = BlockSize(h);
return ((blocks.x + 1) * h.num_experts + (h.num_experts + 1)) * sizeof(index_t);
// usually num_experts is power of 2, we pad 1 dword here for the row-size
return ((blocks.x + 1) * (h.num_experts + 1) + (h.num_experts + 1)) * sizeof(index_t);
}
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
......@@ -154,6 +155,75 @@ struct MoeSortingKernel
return k;
}
// [a, b, c, d....] -> [a, a+b, a+b+c, a+b+c+d, ....]
template <typename data_t, int wave_size>
__device__ inline void wave_cumsum(data_t& thread_data) const
{
// wave_size must be power of 2
constexpr int row_mask = 0xf;
constexpr int bank_mask = 0xf;
constexpr bool bound_ctrl = true; // ! out-of-bound is zero !
auto reduce_op = [&](auto x_, auto y_) { return x_ + y_; };
if constexpr(wave_size > 1)
{
thread_data = reduce_op(
thread_data,
__builtin_bit_cast(data_t, __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
0x111,
row_mask,
bank_mask,
bound_ctrl))); // row_shr:1
}
if constexpr(wave_size > 2)
{
thread_data = reduce_op(
thread_data,
__builtin_bit_cast(data_t, __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
0x112,
row_mask,
bank_mask,
bound_ctrl))); // row_shr:2
}
if constexpr(wave_size > 4)
{
thread_data =
reduce_op(thread_data,
__builtin_bit_cast(data_t, __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
0x114,
row_mask,
bank_mask,
bound_ctrl))); // row_shr:4
}
if constexpr(wave_size > 8)
{
thread_data =
reduce_op(thread_data,
__builtin_bit_cast(data_t, __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
0x118,
row_mask,
bank_mask,
bound_ctrl))); // row_shr:8
}
if constexpr(wave_size > 16)
{
// now row-0, row-0+row-1, row-1+row-2, row-2+row-3
int v_remote_tmp = __builtin_amdgcn_ds_bpermute(((__lane_id() & 0x30) - 1) << 2, __builtin_bit_cast(int, thread_data));
v_remote_tmp = __lane_id() >= 16 ? v_remote_tmp : 0;
thread_data = reduce_op(thread_data, __builtin_bit_cast(data_t, v_remote_tmp));
}
if constexpr(wave_size > 32)
{
// lane-id 48...63->31
int v_remote_tmp = __builtin_amdgcn_ds_bpermute(((__lane_id() & 0x30) - 17) << 2, __builtin_bit_cast(int, thread_data));
v_remote_tmp = __lane_id() >= 32 ? v_remote_tmp : 0;
thread_data = reduce_op(thread_data, __builtin_bit_cast(data_t, v_remote_tmp));
}
}
CK_TILE_DEVICE index_t calc_index(index_t total_col, index_t row, index_t col) const
{
return row * total_col + col;
......@@ -187,48 +257,124 @@ struct MoeSortingKernel
index_t* shared_mem = reinterpret_cast<index_t*>(smem);
index_t* tokens_cnts = shared_mem; // 2d: (blockDim.x + 1, num_experts)
index_t* cumsum = shared_mem + (blockDim.x + 1) * num_experts; // 1: (num_experts + 1)
index_t* cumsum = shared_mem + (blockDim.x + 1) * (num_experts+1); // 1: (num_experts + 1)
for(int i = 0; i < num_experts; ++i)
{
tokens_cnts[calc_index(num_experts, tid + 1, i)] = 0;
tokens_cnts[calc_index(num_experts+1, tid + 1, i)] = 0;
}
#pragma unroll Problem_::InternalLoadUnroll
for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i)
{
++tokens_cnts[calc_index(num_experts, tid + 1, topk_id[i])];
++tokens_cnts[calc_index(num_experts+1, tid + 1, topk_id[i])];
}
__syncthreads();
#if 1
if(tid < num_experts)
{
tokens_cnts[calc_index(num_experts, 0, tid)] = 0;
for(int i = 1; i <= static_cast<index_t>(blockDim.x); ++i)
tokens_cnts[calc_index(num_experts+1, 0, tid)] = 0;
index_t local_c[8];
index_t prev_c = 0;
// TODO: manually unroll. pragma unroll does not work well when we have dependency
for(int i = 1; i <= static_cast<index_t>(blockDim.x); i+= 8)
{
tokens_cnts[calc_index(num_experts, i, tid)] +=
tokens_cnts[calc_index(num_experts, i - 1, tid)];
local_c[0] = tokens_cnts[calc_index(num_experts+1, i + 0, tid)];
local_c[1] = tokens_cnts[calc_index(num_experts+1, i + 1, tid)];
local_c[2] = tokens_cnts[calc_index(num_experts+1, i + 2, tid)];
local_c[3] = tokens_cnts[calc_index(num_experts+1, i + 3, tid)];
local_c[4] = tokens_cnts[calc_index(num_experts+1, i + 4, tid)];
local_c[5] = tokens_cnts[calc_index(num_experts+1, i + 5, tid)];
local_c[6] = tokens_cnts[calc_index(num_experts+1, i + 6, tid)];
local_c[7] = tokens_cnts[calc_index(num_experts+1, i + 7, tid)];
local_c[0] += prev_c;
local_c[1] += local_c[0];
local_c[2] += local_c[1];
local_c[3] += local_c[2];
local_c[4] += local_c[3];
local_c[5] += local_c[4];
local_c[6] += local_c[5];
local_c[7] += local_c[6];
prev_c = local_c[7];
tokens_cnts[calc_index(num_experts+1, i + 0, tid)] = local_c[0];
tokens_cnts[calc_index(num_experts+1, i + 1, tid)] = local_c[1];
tokens_cnts[calc_index(num_experts+1, i + 2, tid)] = local_c[2];
tokens_cnts[calc_index(num_experts+1, i + 3, tid)] = local_c[3];
tokens_cnts[calc_index(num_experts+1, i + 4, tid)] = local_c[4];
tokens_cnts[calc_index(num_experts+1, i + 5, tid)] = local_c[5];
tokens_cnts[calc_index(num_experts+1, i + 6, tid)] = local_c[6];
tokens_cnts[calc_index(num_experts+1, i + 7, tid)] = local_c[7];
}
}
// __syncthreads();
if(tid == 0)
#else
// TODO: below code still working, but slow in expert=32/topk=5 case. Put here for future heuristic
{
cumsum[0] = 0;
for(int i = 1; i <= num_experts; ++i)
if(tid < num_experts)
tokens_cnts[calc_index(num_experts+1, 0, tid)] = 0;
for(int i = 0; i < num_experts; i+=8) {
index_t local_c[8];
#pragma unroll
for(int j = 0; j < 8; j++) {
local_c[j] = tokens_cnts[calc_index(num_experts+1, tid+1, i+j)];
}
#pragma unroll
for(int j = 0; j < 8; j++) {
wave_cumsum<int, 64>(local_c[j]);
}
#pragma unroll
for(int j = 0; j < 8; j++) {
tokens_cnts[calc_index(num_experts+1, tid+1, i+j)] = local_c[j];
}
}
}
#endif
__syncthreads();
if constexpr (Problem::ExpertTile == 0) {
if(tid == 0)
{
auto current_units = [&]() {
index_t x_ = tokens_cnts[calc_index(num_experts, blockDim.x, i - 1)] +
unit_size_mdiv.divisor - 1;
index_t y_ = unit_size_mdiv.div(x_);
return max(y_, 1) * unit_size_mdiv.divisor;
}();
cumsum[i] = cumsum[i - 1] + current_units;
cumsum[0] = 0;
for(int i = 1; i <= num_experts; ++i)
{
auto current_units = [&]() {
index_t x_ = tokens_cnts[calc_index(num_experts+1, blockDim.x, i - 1)] +
unit_size_mdiv.divisor - 1;
index_t y_ = unit_size_mdiv.div(x_);
return max(y_, 1) * unit_size_mdiv.divisor;
}();
cumsum[i] = cumsum[i - 1] + current_units;
}
*p_total_tokens_post_pad = cumsum[num_experts];
}
} else {
// TODO: we have out-of-bound read here. But result is still OK (will ignore tid >= expert)
// for simplicity, not check experts here.
int local_cnt = tokens_cnts[calc_index(num_experts+1, blockDim.x, tid)];
int blocks_pers_expert = unit_size_mdiv.div(local_cnt + unit_size_mdiv.divisor - 1);
int padded_tokens_per_expert = max(blocks_pers_expert, 1) * unit_size_mdiv.divisor;
int local_cumsum = padded_tokens_per_expert;
wave_cumsum<int, 64>(local_cumsum);
if(tid == (num_experts - 1)) {
cumsum[0] = 0;
*p_total_tokens_post_pad = local_cumsum;
}
if(tid < num_experts) {
cumsum[tid + 1] = local_cumsum;
}
*p_total_tokens_post_pad = cumsum[num_experts];
}
__syncthreads();
if(tid < num_experts)
{
for(int i = cumsum[tid]; i < cumsum[tid + 1]; i += unit_size_mdiv.divisor)
int e_start = cumsum[tid];
int e_end = cumsum[tid + 1];
for(int i = e_start; i < e_end; i += unit_size_mdiv.divisor)
{
p_sorted_expert_ids[unit_size_mdiv.div(i)] = tid;
}
......@@ -238,8 +384,8 @@ struct MoeSortingKernel
for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i)
{
index_t expert_id = topk_id[i];
index_t rank_post_pad =
tokens_cnts[calc_index(num_experts, tid, expert_id)] + cumsum[expert_id];
index_t local_cnt = tokens_cnts[calc_index(num_experts+1, tid, expert_id)];
index_t rank_post_pad = local_cnt + cumsum[expert_id];
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
uint32_t curr_token_id, curr_topk_id;
topk_mdiv.divmod(i, curr_token_id, curr_topk_id);
......@@ -247,27 +393,54 @@ struct MoeSortingKernel
#else
p_sorted_token_ids[rank_post_pad] = topk_mdiv.div(i);
#endif
p_sorted_weights[rank_post_pad] = weights[i];
++tokens_cnts[calc_index(num_experts, tid, expert_id)];
p_sorted_weights[rank_post_pad] = weights[i];
tokens_cnts[calc_index(num_experts+1, tid, expert_id)] = local_cnt+1;
}
const index_t prefill_token = topk_mdiv.div(numel);
if(tid < num_experts)
{
index_t expert_offset =
cumsum[tid] + tokens_cnts[calc_index(num_experts, blockDim.x, tid)];
while(expert_offset < cumsum[tid + 1])
if constexpr (Problem::ExpertTile == 0) {
const index_t prefill_token = topk_mdiv.div(numel);
if(tid < num_experts)
{
index_t expert_offset =
cumsum[tid] + tokens_cnts[calc_index(num_experts+1, blockDim.x, tid)];
index_t expert_end = cumsum[tid + 1];
while(expert_offset < expert_end)
{
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
p_sorted_token_ids[expert_offset] =
MOE_SORTING_MOCK_ID(prefill_token, topk_mdiv.divisor);
p_sorted_token_ids[expert_offset] =
MOE_SORTING_MOCK_ID(prefill_token, topk_mdiv.divisor);
#else
p_sorted_token_ids[expert_offset] = prefill_token;
p_sorted_token_ids[expert_offset] = prefill_token;
#endif
p_sorted_weights[expert_offset] = static_cast<WeightType>(0.0);
expert_offset++;
p_sorted_weights[expert_offset] = static_cast<WeightType>(0.0);
expert_offset++;
}
}
}
else {
const index_t prefill_token = topk_mdiv.div(numel);
// TODO: only support expert-tile like 8, 16, 32
static constexpr index_t experts_per_wave = warpSize / Problem::ExpertTile;
{
index_t eid = tid / experts_per_wave;
index_t expert_offset =
cumsum[eid] + tokens_cnts[calc_index(num_experts+1, blockDim.x, eid)] + tid % experts_per_wave;
index_t expert_end = cumsum[eid + 1];
if(eid < num_experts) {
while(expert_offset < expert_end)
{
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
p_sorted_token_ids[expert_offset] =
MOE_SORTING_MOCK_ID(prefill_token, topk_mdiv.divisor);
#else
p_sorted_token_ids[expert_offset] = prefill_token;
#endif
p_sorted_weights[expert_offset] = static_cast<WeightType>(0.0);
expert_offset+=experts_per_wave;
}
}
}
}
}
CK_TILE_DEVICE void operator()(Kargs kargs) const
......
......@@ -810,21 +810,46 @@ struct FusedMoeGemmPipelineFlatmmPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetUK_1()
{
using S_ = typename Problem::BlockShape;
using T_ = typename Problem::Traits;
if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::TopkWeightDataType, float> &&
S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 512 &&
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32)
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32 &&
T_::PipeInterleave == false)
{
return FlatmmSn_32x128x512_1x4x1_16x16x32_BF16{};
// return FlatmmSn_32x128x512_1x4x1_16x16x32_BF16_itl{};
}
else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::fp16_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::fp16_t> &&
std::is_same_v<typename Problem::TopkWeightDataType, float> &&
S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 512 &&
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32)
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32 &&
T_::PipeInterleave == false)
{
return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16{};
// return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16_itl{};
}
else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::TopkWeightDataType, float> &&
S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 512 &&
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32 &&
T_::PipeInterleave == true)
{
// return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16{};
return FlatmmSn_32x128x512_1x4x1_16x16x32_BF16_itl{};
}
else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::fp16_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::fp16_t> &&
std::is_same_v<typename Problem::TopkWeightDataType, float> &&
S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 512 &&
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32 &&
T_::PipeInterleave == true)
{
// return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16{};
return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16_itl{};
}
}
};
......
......@@ -22,7 +22,8 @@ template <bool IsGateOnly_,
FusedMoeGemmWeightPermuteEnum PermuteEnum_ =
FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten,
bool PadHiddenSize_ = false,
bool PadIntermediateSize_ = false>
bool PadIntermediateSize_ = false,
bool PipeInterleave_ = true>
struct FusedMoeGemmTraits
{
// Gate+Up or Gate only
......@@ -32,6 +33,7 @@ struct FusedMoeGemmTraits
static constexpr FusedMoeGemmWeightPermuteEnum PermuteEnum = PermuteEnum_;
static constexpr bool PadHiddenSize = PadHiddenSize_;
static constexpr bool PadIntermediateSize = PadIntermediateSize_;
static constexpr bool PipeInterleave = PipeInterleave_;
};
// Note: this need to be a bit mask
......
......@@ -9,15 +9,20 @@
namespace ck_tile {
template <typename IndexType_, typename WeightType_, index_t InternalLoadUnroll_>
template <typename IndexType_,
typename WeightType_,
index_t InternalLoadUnroll_,
index_t ExpertTile_ = 0>
struct MoeSortingProblem
{
// TODO: this kernel only support warp per row
using WeightType = remove_cvref_t<WeightType_>;
using IndexType = remove_cvref_t<IndexType_>;
static constexpr index_t WarpSize = get_warp_size();
static constexpr index_t WarpsPerBlock = 1;
static constexpr index_t InternalLoadUnroll = InternalLoadUnroll_;
static constexpr index_t WarpSize = get_warp_size();
static constexpr index_t WarpsPerBlock = 1;
static constexpr index_t InternalLoadUnroll =
InternalLoadUnroll_; // TODO: need better design(like tile size)
static constexpr index_t ExpertTile = ExpertTile_; // TODO: only used in store out
};
} // namespace ck_tile
......@@ -65,14 +65,6 @@ struct BlockGemmARegBSmemCRegOneWarpV1
const index_t iNWarp = 0;
constexpr auto a_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<>,
tuple<sequence<MIterPerWarp>, sequence<NIterPerWarp>>,
......@@ -81,19 +73,14 @@ struct BlockGemmARegBSmemCRegOneWarpV1
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode);
// constrcut from A-block-tensor from A-Block-tensor-tmp
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
// distribution
auto a_block_tensor =
make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(a_block_dstr);
auto a_block_tensor = make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(
MakeABlockTileDistribution());
a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();
......@@ -187,6 +174,33 @@ struct BlockGemmARegBSmemCRegOneWarpV1
});
}
template <index_t MPerBlock = BlockGemmShape::kM, index_t KPerBlock = BlockGemmShape::kK>
CK_TILE_DEVICE static constexpr auto MakeABlockTileDistribution()
{
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr auto a_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
return make_static_tile_distribution(a_block_dstr_encode);
}
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
constexpr index_t MPerBlock = BlockGemmShape::kM;
......
......@@ -59,14 +59,6 @@ struct BlockGemmARegBSmemCRegV2
const index_t iNWarp = get_warp_id() % NWarp;
constexpr auto a_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
......@@ -75,19 +67,14 @@ struct BlockGemmARegBSmemCRegV2
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode);
// constrcut from A-block-tensor from A-Block-tensor-tmp
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
// distribution
auto a_block_tensor =
make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(a_block_dstr);
auto a_block_tensor = make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(
MakeABlockTileDistribution());
a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();
......@@ -182,6 +169,33 @@ struct BlockGemmARegBSmemCRegV2
});
}
template <index_t MPerBlock = BlockGemmShape::kM, index_t KPerBlock = BlockGemmShape::kK>
CK_TILE_DEVICE static constexpr auto MakeABlockTileDistribution()
{
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr auto a_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
return make_static_tile_distribution(a_block_dstr_encode);
}
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
constexpr index_t MPerBlock = BlockGemmShape::kM;
......
......@@ -3,90 +3,93 @@
#pragma once
#include <iostream>
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
namespace ck_tile {
struct BatchedGemmHostArgs
struct BatchedGemmHostArgs : public ck_tile::GemmHostArgs
{
const void* a_ptr;
const void* b_ptr;
void* c_ptr;
index_t M;
index_t N;
index_t K;
index_t stride_A;
index_t stride_B;
index_t stride_C;
index_t batch_stride_A;
index_t batch_stride_B;
index_t batch_stride_C;
index_t batch_count;
CK_TILE_HOST BatchedGemmHostArgs() = default;
CK_TILE_HOST BatchedGemmHostArgs(const void* a_ptr_,
const void* b_ptr_,
void* c_ptr_,
ck_tile::index_t k_batch_,
ck_tile::index_t M_,
ck_tile::index_t N_,
ck_tile::index_t K_,
ck_tile::index_t stride_A_,
ck_tile::index_t stride_B_,
ck_tile::index_t stride_C_,
ck_tile::index_t batch_stride_A_,
ck_tile::index_t batch_stride_B_,
ck_tile::index_t batch_stride_C_,
ck_tile::index_t batch_count_)
: GemmHostArgs(
a_ptr_, b_ptr_, c_ptr_, k_batch_, M_, N_, K_, stride_A_, stride_B_, stride_C_),
batch_stride_A(batch_stride_A_),
batch_stride_B(batch_stride_B_),
batch_stride_C(batch_stride_C_),
batch_count(batch_count_)
{
}
ck_tile::index_t batch_stride_A;
ck_tile::index_t batch_stride_B;
ck_tile::index_t batch_stride_C;
ck_tile::index_t batch_count;
};
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct BatchedGemmKernel
struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>
{
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
using Base = GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
using GemmKernelArgs = typename Base::GemmKernelArgs;
struct BatchedGemmKargs
using ADataType = typename Base::ADataType;
using BDataType = typename Base::BDataType;
using CDataType = typename Base::CDataType;
using TilePartitioner = typename Base::TilePartitioner;
using GemmPipeline = typename Base::GemmPipeline;
using EpiloguePipeline = typename Base::EpiloguePipeline;
using ALayout = typename Base::ALayout;
using BLayout = typename Base::BLayout;
using CLayout = typename Base::CLayout;
struct BatchedGemmKernelArgs : GemmKernelArgs
{
const void* a_ptr;
const void* b_ptr;
void* c_ptr;
index_t M;
index_t N;
index_t K;
index_t stride_A;
index_t stride_B;
index_t stride_C;
index_t batch_stride_A;
index_t batch_stride_B;
index_t batch_stride_C;
index_t batch_count;
};
using Kargs = BatchedGemmKargs;
using Hargs = BatchedGemmHostArgs;
using KernelArgs = BatchedGemmKernelArgs;
__host__ static constexpr auto GridSize(const Hargs& h)
__host__ static constexpr auto GridSize(index_t M, index_t N, index_t batch_count)
{
return TilePartitioner::GridSize(h.M, h.N, h.batch_count);
return TilePartitioner::GridSize(M, N, batch_count);
}
__host__ static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
__host__ static constexpr auto BlockSize() { return dim3(Base::KernelBlockSize); }
CK_TILE_HOST static constexpr BatchedGemmKargs MakeKargs(const Hargs& h)
CK_TILE_HOST static constexpr BatchedGemmKernelArgs
MakeKernelArgs(const BatchedGemmHostArgs& hostArgs)
{
Kargs k;
k.a_ptr = h.a_ptr;
k.b_ptr = h.b_ptr;
k.c_ptr = h.c_ptr;
k.M = h.M;
k.N = h.N;
k.K = h.K;
k.stride_A = h.stride_A;
k.stride_B = h.stride_B;
k.stride_C = h.stride_C;
k.batch_stride_A = h.batch_stride_A;
k.batch_stride_B = h.batch_stride_B;
k.batch_stride_C = h.batch_stride_C;
k.batch_count = h.batch_count;
return k;
return BatchedGemmKernelArgs{{hostArgs.a_ptr,
hostArgs.b_ptr,
hostArgs.c_ptr,
hostArgs.M,
hostArgs.N,
hostArgs.K,
hostArgs.stride_A,
hostArgs.stride_B,
hostArgs.stride_C},
hostArgs.batch_stride_A,
hostArgs.batch_stride_B,
hostArgs.batch_stride_C,
hostArgs.batch_count};
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
......@@ -94,7 +97,7 @@ struct BatchedGemmKernel
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
}
CK_TILE_DEVICE void operator()(Kargs kargs) const
CK_TILE_DEVICE void operator()(BatchedGemmKernelArgs kargs) const
{
const auto [i_m, i_n] = TilePartitioner{}();
const auto i_batch = __builtin_amdgcn_readfirstlane(blockIdx.z);
......@@ -102,156 +105,17 @@ struct BatchedGemmKernel
// options
const auto batch_stride_A = __builtin_amdgcn_readfirstlane(kargs.batch_stride_A);
const auto batch_offset_A = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_A);
const ADataType* a_start = static_cast<const ADataType*>(kargs.a_ptr);
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr) + batch_offset_A;
const auto batch_stride_B = __builtin_amdgcn_readfirstlane(kargs.batch_stride_B);
const auto batch_offset_B = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_B);
const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr);
// Convert pointers to tensor views
auto a_tensor_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
a_start + batch_offset_A,
make_tuple(kargs.M, kargs.K),
make_tuple(kargs.stride_A, 1),
number<GemmPipeline::VectorSizeA>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
a_start + batch_offset_A,
make_tuple(kargs.M, kargs.K),
make_tuple(1, kargs.stride_A),
number<1>{},
number<1>{});
}
}();
auto b_tensor_view = [&]() {
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
b_start + batch_offset_B,
make_tuple(kargs.N, kargs.K),
make_tuple(1, kargs.stride_B),
number<1>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
b_start + batch_offset_B,
make_tuple(kargs.N, kargs.K),
make_tuple(kargs.stride_B, 1),
number<GemmPipeline::VectorSizeB>{},
number<1>{});
}
}();
auto a_pad_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(
a_tensor_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
sequence<false, GemmPipeline::kPadK>{});
}
else
{
return pad_tensor_view(
a_tensor_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
sequence<GemmPipeline::kPadM, false>{});
}
}();
// clang-format on
auto a_block_window = make_tile_window(
a_pad_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
{i_m, 0});
auto b_pad_view = [&]() {
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
{
return pad_tensor_view(
b_tensor_view,
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
sequence<false, GemmPipeline::kPadK>{});
}
else
{
return pad_tensor_view(
b_tensor_view,
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
sequence<GemmPipeline::kPadN, false>{});
}
}();
// clang-format on
auto b_block_window = make_tile_window(
b_pad_view,
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
{i_n, 0});
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K);
// Run GEMM cooperatively by whole wokrgroup.
auto c_block_tile =
GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr);
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr) + batch_offset_B;
const auto batch_stride_C = __builtin_amdgcn_readfirstlane(kargs.batch_stride_C);
const auto batch_offset_C = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_C);
CDataType* c_start = static_cast<CDataType*>(kargs.c_ptr);
auto c_tensor_view = [&]() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
c_start + batch_offset_C,
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_C, 1),
number<GemmPipeline::VectorSizeC>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
c_start + batch_offset_C,
make_tuple(kargs.M, kargs.N),
make_tuple(1, kargs.stride_C),
number<1>{},
number<1>{});
}
}();
auto c_pad_view = [&]() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(
c_tensor_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
sequence<false, GemmPipeline::kPadN>{});
}
else
{
return pad_tensor_view(
c_tensor_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
sequence<GemmPipeline::kPadM, false>{});
}
}();
auto c_block_window = make_tile_window(
c_pad_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
{i_m, i_n});
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr) + batch_offset_C;
EpiloguePipeline{}(c_block_window, c_block_tile);
this->RunGemm(a_ptr, b_ptr, c_ptr, kargs, i_m, i_n);
}
};
......
......@@ -12,6 +12,50 @@
namespace ck_tile {
struct GemmProblem
{
CK_TILE_HOST GemmProblem() = default;
CK_TILE_HOST GemmProblem(
index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, index_t stride_C_)
: M(M_), N(N_), K(K_), stride_A(stride_A_), stride_B(stride_B_), stride_C(stride_C_)
{
}
index_t M;
index_t N;
index_t K;
index_t stride_A;
index_t stride_B;
index_t stride_C;
};
struct GemmHostArgs : public GemmProblem
{
CK_TILE_HOST GemmHostArgs() = default;
CK_TILE_HOST GemmHostArgs(const void* a_ptr_,
const void* b_ptr_,
void* c_ptr_,
index_t k_batch_,
index_t M_,
index_t N_,
index_t K_,
index_t stride_A_,
index_t stride_B_,
index_t stride_C_)
: GemmProblem(M_, N_, K_, stride_A_, stride_B_, stride_C_),
a_ptr(a_ptr_),
b_ptr(b_ptr_),
c_ptr(c_ptr_),
k_batch(k_batch_)
{
}
const void* a_ptr;
const void* b_ptr;
void* c_ptr;
index_t k_batch;
};
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct GemmKernel
{
......@@ -25,9 +69,12 @@ struct GemmKernel
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
// using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>;
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
static constexpr auto I0 = number<0>();
static constexpr auto I1 = number<1>();
static constexpr auto I2 = number<2>();
__host__ static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
{
return TilePartitioner::GridSize(M, N, KBatch);
......@@ -35,7 +82,7 @@ struct GemmKernel
__host__ static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
struct GemmCommonKargs
struct GemmKernelArgs
{
const void* a_ptr;
const void* b_ptr;
......@@ -48,25 +95,37 @@ struct GemmKernel
index_t stride_C;
};
CK_TILE_HOST static constexpr GemmCommonKargs MakeKargs(const void* a_ptr,
const void* b_ptr,
void* c_ptr,
index_t M,
index_t N,
index_t K,
index_t stride_A,
index_t stride_B,
index_t stride_C)
CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const GemmHostArgs& hostArgs)
{
return GemmCommonKargs{a_ptr, b_ptr, c_ptr, M, N, K, stride_A, stride_B, stride_C};
return GemmKernelArgs{hostArgs.a_ptr,
hostArgs.b_ptr,
hostArgs.c_ptr,
hostArgs.M,
hostArgs.N,
hostArgs.K,
hostArgs.stride_A,
hostArgs.stride_B,
hostArgs.stride_C};
}
// CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const void* a_ptr,
// const void* b_ptr,
// void* c_ptr,
// index_t M,
// index_t N,
// index_t K,
// index_t stride_A,
// index_t stride_B,
// index_t stride_C)
// {
// return GemmKernelArgs{a_ptr, b_ptr, c_ptr, M, N, K, stride_A, stride_B, stride_C};
// }
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
}
CK_TILE_HOST static bool IsSupportedArgument(const GemmCommonKargs& kargs)
CK_TILE_HOST static bool IsSupportedArgument(const GemmKernelArgs& kargs)
{
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
......@@ -139,18 +198,16 @@ struct GemmKernel
return true;
}
CK_TILE_DEVICE void operator()(GemmCommonKargs kargs) const
CK_TILE_DEVICE auto MakeGemmTensorViews(const ADataType* a_ptr,
const BDataType* b_ptr,
CDataType* c_ptr,
const GemmKernelArgs& kargs) const
{
const auto [i_m, i_n] = TilePartitioner{}();
// options
const ADataType* a_start = static_cast<const ADataType*>(kargs.a_ptr);
const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr);
// Convert pointers to tensor views
auto a_tensor_view = [&]() {
const auto& a_tensor_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
a_start,
a_ptr,
make_tuple(kargs.M, kargs.K),
make_tuple(kargs.stride_A, 1),
number<GemmPipeline::VectorSizeA>{},
......@@ -159,7 +216,7 @@ struct GemmKernel
else
{
return make_naive_tensor_view<address_space_enum::global>(
a_start,
a_ptr,
make_tuple(kargs.M, kargs.K),
make_tuple(1, kargs.stride_A),
number<1>{},
......@@ -167,11 +224,11 @@ struct GemmKernel
}
}();
auto b_tensor_view = [&]() {
const auto& b_tensor_view = [&]() {
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
b_start,
b_ptr,
make_tuple(kargs.N, kargs.K),
make_tuple(1, kargs.stride_B),
number<1>{},
......@@ -180,7 +237,7 @@ struct GemmKernel
else
{
return make_naive_tensor_view<address_space_enum::global>(
b_start,
b_ptr,
make_tuple(kargs.N, kargs.K),
make_tuple(kargs.stride_B, 1),
number<GemmPipeline::VectorSizeB>{},
......@@ -188,7 +245,35 @@ struct GemmKernel
}
}();
auto a_pad_view = [&]() {
const auto& c_tensor_view = [&]() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
c_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_C, 1),
number<GemmPipeline::VectorSizeC>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
c_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(1, kargs.stride_C),
number<1>{},
number<1>{});
}
}();
return make_tuple(a_tensor_view, b_tensor_view, c_tensor_view);
}
template <typename TensorView>
CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView& views) const
{
const auto& a_pad_view = [&]() {
const auto& a_tensor_view = views.at(I0);
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(
......@@ -204,14 +289,9 @@ struct GemmKernel
sequence<GemmPipeline::kPadM, false>{});
}
}();
// clang-format on
auto a_block_window = make_tile_window(
a_pad_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
{i_m, 0});
auto b_pad_view = [&]() {
const auto& b_pad_view = [&]() {
const auto& b_tensor_view = views.at(I1);
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
{
return pad_tensor_view(
......@@ -228,43 +308,8 @@ struct GemmKernel
}
}();
auto b_block_window = make_tile_window(
b_pad_view,
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
{i_n, 0});
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K);
// Run GEMM cooperatively by whole wokrgroup.
auto c_block_tile =
GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr);
CDataType* c_start = static_cast<CDataType*>(kargs.c_ptr);
auto c_tensor_view = [&]() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
c_start,
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_C, 1),
number<GemmPipeline::VectorSizeC>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
c_start,
make_tuple(kargs.M, kargs.N),
make_tuple(1, kargs.stride_C),
number<1>{},
number<1>{});
}
}();
auto c_pad_view = [&]() {
const auto& c_pad_view = [&]() {
const auto& c_tensor_view = views.at(I2);
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(
......@@ -280,12 +325,82 @@ struct GemmKernel
sequence<GemmPipeline::kPadM, false>{});
}
}();
auto CBlockWindow_pad = make_tile_window(
return make_tuple(a_pad_view, b_pad_view, c_pad_view);
}
template <typename PadView>
CK_TILE_DEVICE auto
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) const
{
const auto& a_pad_view = views.at(I0);
const auto& a_block_window = make_tile_window(
a_pad_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
{i_m, 0});
const auto& b_pad_view = views.at(I1);
const auto& b_block_window = make_tile_window(
b_pad_view,
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
{i_n, 0});
const auto& c_pad_view = views.at(I2);
auto c_block_window = make_tile_window(
c_pad_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
{i_m, i_n});
EpiloguePipeline{}(CBlockWindow_pad, c_block_tile);
return make_tuple(a_block_window, b_block_window, c_block_window);
}
/**
* @brief Runs single GEMM problem cooperatively by whole workgroup.
*
* @param a_ptr input A pointer
* @param b_ptr input B pointer
* @param c_ptr output C pointer
* @param kargs GEMM kernel arguments
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
*/
CK_TILE_DEVICE void RunGemm(const ADataType* a_ptr,
const BDataType* b_ptr,
CDataType* c_ptr,
const GemmKernelArgs& kargs,
const index_t block_idx_m,
const index_t block_idx_n) const
{
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple = MakeGemmTensorViews(a_ptr, b_ptr, c_ptr, kargs);
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K);
// Run GEMM cooperatively by whole workgroup.
const auto& a_block_window = gemm_tile_windows.at(I0);
const auto& b_block_window = gemm_tile_windows.at(I1);
const auto& c_block_tile =
GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr);
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I2);
EpiloguePipeline{}(c_block_window, c_block_tile);
}
CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const
{
const auto [i_m, i_n] = TilePartitioner{}();
// options
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr);
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr);
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
RunGemm(a_ptr, b_ptr, c_ptr, kargs, i_m, i_n);
}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment