Commit eca84f93 authored by root's avatar root
Browse files

Merge branch 'gemm_bf16_sk_muozturk' of...

Merge branch 'gemm_bf16_sk_muozturk' of https://github.com/ROCm/composable_kernel into gemm_bf16_sk_muozturk
parents 6f210155 c256f018
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#define CK_TILE_FLATMM_UK_MFMA_FP16 0
#define CK_TILE_FLATMM_UK_MFMA_BF16 1
#define CK_TILE_FLATMM_UK_MFMA_INT8 2
#define CK_TILE_FLATMM_UK_MFMA_FP8 3
#define CK_TILE_FLATMM_UK_MFMA_BF8 4
the files under this folder should not be included directly!
\ No newline at end of file
#ifndef CK_TILE_FLATMM_UK_MFMA
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
#endif
#if CK_TILE_FLATMM_UK_MFMA == CK_TILE_FLATMM_UK_MFMA_BF16
#define _UK_MFMA_ "v_mfma_f32_16x16x16_bf16"
#define _UK_PK_CVT_(x0_, x1_, y_) \
" v_cmp_u_f32 s[36:37], " x0_ ", " x0_ " \n" \
" v_add3_u32 v50, " x0_ ", %[v_nan_lo], 1 \n" \
" v_cndmask_b32 v54, v50, %[v_nan_hi], s[36:37] \n" \
" v_cmp_u_f32 s[36:37], " x1_ ", " x1_ " \n" \
" v_add3_u32 v50, " x1_ ", %[v_nan_lo], 1 \n" \
" v_cndmask_b32 v55, v50, %[v_nan_hi], s[36:37] \n" \
" v_perm_b32 " y_ ", v55, v54, s52 \n"
#define _UK_ATOMIC_ADD_ "global_atomic_pk_add_bf16"
#elif CK_TILE_FLATMM_UK_MFMA == CK_TILE_FLATMM_UK_MFMA_FP16
#define _UK_MFMA_ "v_mfma_f32_16x16x16_f16"
#define _UK_PK_CVT_(x0_, x1_, y_) \
" v_cvt_f16_f32 v54, " x0_ " \n" \
" v_cvt_f16_f32 v55, " x1_ " \n" \
" v_pack_b32_f16 " y_ ", v54, v55 \n"
#define _UK_ATOMIC_ADD_ "global_atomic_pk_add_f16"
#endif
";-------------------------------------------------------------\n"
" s_mov_b32 s52, 0x07060302 ; v_perm\n"
" s_mov_b64 s[38:39], exec ; save current exec\n"
" s_mov_b32 s8, %[s_res_o0] \n"
" s_mov_b32 s9, %[s_res_o1] \n"
" s_mov_b32 s12, %[s_res_b0] \n"
" s_mov_b32 s13, %[s_res_b1] \n"
" s_mov_b32 s14, %[s_res_b2] \n"
" s_mov_b32 s15, %[s_res_b3] \n"
" ds_read_b64 v[128:129], %[v_sld_y_os] offset:0 + %[sld_a_base] \n"
" ds_read_b64 v[130:131], %[v_sld_y_os] offset:128 + %[sld_a_base] \n"
" ds_read_b64 v[132:133], %[v_sld_y_os] offset:1024 + %[sld_a_base] \n"
" ds_read_b64 v[134:135], %[v_sld_y_os] offset:1152 + %[sld_a_base] \n"
" ds_read_b64 v[136:137], %[v_sld_y_os] offset:2048 + %[sld_a_base] \n"
" ds_read_b64 v[138:139], %[v_sld_y_os] offset:2176 + %[sld_a_base] \n"
" ds_read_b64 v[140:141], %[v_sld_y_os] offset:3072 + %[sld_a_base] \n"
" ds_read_b64 v[142:143], %[v_sld_y_os] offset:3200 + %[sld_a_base] \n"
" ds_read_b64 v[144:145], %[v_sld_y_os] offset:4096 + %[sld_a_base] \n"
" ds_read_b64 v[146:147], %[v_sld_y_os] offset:4224 + %[sld_a_base] \n"
" ds_read_b64 v[148:149], %[v_sld_y_os] offset:5120 + %[sld_a_base] \n"
" ds_read_b64 v[150:151], %[v_sld_y_os] offset:5248 + %[sld_a_base] \n"
" ds_read_b64 v[152:153], %[v_sld_y_os] offset:6144 + %[sld_a_base] \n"
" ds_read_b64 v[154:155], %[v_sld_y_os] offset:6272 + %[sld_a_base] \n"
" ds_read_b64 v[156:157], %[v_sld_y_os] offset:7168 + %[sld_a_base] \n"
" ds_read_b64 v[158:159], %[v_sld_y_os] offset:7296 + %[sld_a_base] \n"
" ds_read_b64 v[160:161], %[v_sld_y_os] offset:8192 + %[sld_a_base] \n"
" ds_read_b64 v[162:163], %[v_sld_y_os] offset:8320 + %[sld_a_base] \n"
" ds_read_b64 v[164:165], %[v_sld_y_os] offset:9216 + %[sld_a_base] \n"
" ds_read_b64 v[166:167], %[v_sld_y_os] offset:9344 + %[sld_a_base] \n"
" ds_read_b64 v[168:169], %[v_sld_y_os] offset:10240 + %[sld_a_base] \n"
" ds_read_b64 v[170:171], %[v_sld_y_os] offset:10368 + %[sld_a_base] \n"
" ds_read_b64 v[172:173], %[v_sld_y_os] offset:11264 + %[sld_a_base] \n"
" ds_read_b64 v[174:175], %[v_sld_y_os] offset:11392 + %[sld_a_base] \n"
" ds_read_b64 v[176:177], %[v_sld_y_os] offset:12288 + %[sld_a_base] \n"
" ds_read_b64 v[178:179], %[v_sld_y_os] offset:12416 + %[sld_a_base] \n"
" ds_read_b64 v[180:181], %[v_sld_y_os] offset:13312 + %[sld_a_base] \n"
" ds_read_b64 v[182:183], %[v_sld_y_os] offset:13440 + %[sld_a_base] \n"
" ds_read_b64 v[184:185], %[v_sld_y_os] offset:14336 + %[sld_a_base] \n"
" ds_read_b64 v[186:187], %[v_sld_y_os] offset:14464 + %[sld_a_base] \n"
" ds_read_b64 v[188:189], %[v_sld_y_os] offset:15360 + %[sld_a_base] \n"
" ds_read_b64 v[190:191], %[v_sld_y_os] offset:15488 + %[sld_a_base] \n"
" ds_read_b64 v[192:193], %[v_sld_y_os] offset:16384 + %[sld_a_base] \n"
" ds_read_b64 v[194:195], %[v_sld_y_os] offset:16512 + %[sld_a_base] \n"
" ds_read_b64 v[196:197], %[v_sld_y_os] offset:17408 + %[sld_a_base] \n"
" ds_read_b64 v[198:199], %[v_sld_y_os] offset:17536 + %[sld_a_base] \n"
" ds_read_b64 v[200:201], %[v_sld_y_os] offset:18432 + %[sld_a_base] \n"
" ds_read_b64 v[202:203], %[v_sld_y_os] offset:18560 + %[sld_a_base] \n"
" ds_read_b64 v[204:205], %[v_sld_y_os] offset:19456 + %[sld_a_base] \n"
" ds_read_b64 v[206:207], %[v_sld_y_os] offset:19584 + %[sld_a_base] \n"
" ds_read_b64 v[208:209], %[v_sld_y_os] offset:20480 + %[sld_a_base] \n"
" ds_read_b64 v[210:211], %[v_sld_y_os] offset:20608 + %[sld_a_base] \n"
" ds_read_b64 v[212:213], %[v_sld_y_os] offset:21504 + %[sld_a_base] \n"
" ds_read_b64 v[214:215], %[v_sld_y_os] offset:21632 + %[sld_a_base] \n"
" ds_read_b64 v[216:217], %[v_sld_y_os] offset:22528 + %[sld_a_base] \n"
" ds_read_b64 v[218:219], %[v_sld_y_os] offset:22656 + %[sld_a_base] \n"
" ds_read_b64 v[220:221], %[v_sld_y_os] offset:23552 + %[sld_a_base] \n"
" ds_read_b64 v[222:223], %[v_sld_y_os] offset:23680 + %[sld_a_base] \n"
" ds_read_b64 v[224:225], %[v_sld_y_os] offset:24576 + %[sld_a_base] \n"
" ds_read_b64 v[226:227], %[v_sld_y_os] offset:24704 + %[sld_a_base] \n"
" ds_read_b64 v[228:229], %[v_sld_y_os] offset:25600 + %[sld_a_base] \n"
" ds_read_b64 v[230:231], %[v_sld_y_os] offset:25728 + %[sld_a_base] \n"
" ds_read_b64 v[232:233], %[v_sld_y_os] offset:26624 + %[sld_a_base] \n"
" ds_read_b64 v[234:235], %[v_sld_y_os] offset:26752 + %[sld_a_base] \n"
" ds_read_b64 v[236:237], %[v_sld_y_os] offset:27648 + %[sld_a_base] \n"
" ds_read_b64 v[238:239], %[v_sld_y_os] offset:27776 + %[sld_a_base] \n"
" ds_read_b64 v[240:241], %[v_sld_y_os] offset:28672 + %[sld_a_base] \n"
" ds_read_b64 v[242:243], %[v_sld_y_os] offset:28800 + %[sld_a_base] \n"
" ds_read_b64 v[244:245], %[v_sld_y_os] offset:29696 + %[sld_a_base] \n"
" ds_read_b64 v[246:247], %[v_sld_y_os] offset:29824 + %[sld_a_base] \n"
" ds_read_b64 v[248:249], %[v_sld_y_os] offset:30720 + %[sld_a_base] \n"
" ds_read_b64 v[250:251], %[v_sld_y_os] offset:30848 + %[sld_a_base] \n"
" ds_read_b64 v[252:253], %[v_sld_y_os] offset:31744 + %[sld_a_base] \n"
" ds_read_b64 v[254:255], %[v_sld_y_os] offset:31872 + %[sld_a_base] \n"
" s_waitcnt 0 \n"
" buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[12:15], 0 offen \n"
" buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[12:15], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[12:15], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[12:15], 0 offen offset:3072 \n"
" buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[12:15], 0 offen \n"
" buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[12:15], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[12:15], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[12:15], 0 offen offset:3072 \n"
" buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[12:15], 0 offen \n"
" buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[12:15], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[12:15], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[12:15], 0 offen offset:3072 \n"
" buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[12:15], 0 offen \n"
" buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[12:15], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[12:15], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[12:15], 0 offen offset:3072 \n"
" buffer_load_dwordx4 acc[64:67], %[v_os_b4], s[12:15], 0 offen \n"
" buffer_load_dwordx4 acc[68:71], %[v_os_b4], s[12:15], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[72:75], %[v_os_b4], s[12:15], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[76:79], %[v_os_b4], s[12:15], 0 offen offset:3072 \n"
" buffer_load_dwordx4 acc[80:83], %[v_os_b5], s[12:15], 0 offen \n"
" buffer_load_dwordx4 acc[84:87], %[v_os_b5], s[12:15], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[88:91], %[v_os_b5], s[12:15], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[92:95], %[v_os_b5], s[12:15], 0 offen offset:3072 \n"
" buffer_load_dwordx4 acc[96:99], %[v_os_b6], s[12:15], 0 offen \n"
" buffer_load_dwordx4 acc[100:103], %[v_os_b6], s[12:15], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[104:107], %[v_os_b6], s[12:15], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[108:111], %[v_os_b6], s[12:15], 0 offen offset:3072 \n"
" buffer_load_dwordx4 acc[112:115], %[v_os_b7], s[12:15], 0 offen \n"
" buffer_load_dwordx4 acc[116:119], %[v_os_b7], s[12:15], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[12:15], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[12:15], 0 offen offset:3072 \n"
" s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n"
" s_cselect_b32 s86, %[s_tile_os_b], 0 \n"
" s_add_u32 s12, s86, s12 \n"
" s_addc_u32 s13, 0, s13 \n"
" s_waitcnt 0 \n"
"L_start%=: \n"
" s_waitcnt vmcnt(32) \n"
" s_barrier \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[0:1], v[128:129], 0 \n"
" buffer_load_dwordx4 acc[128:131], %[v_os_b0], s[12:15], 0 offen \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[2:3], v[130:131], [%[c0], %[c1], %[c2], %[c3]] "
"\n" _UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[4:5], v[132:133], [%[c0], %[c1], %[c2], "
"%[c3]] \n" _UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[6:7], v[134:135], "
"[%[c0], %[c1], %[c2], %[c3]] \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[8:9], v[136:137], [%[c0], %[c1], %[c2], %[c3]] \n"
" buffer_load_dwordx4 acc[132:135], %[v_os_b0], s[12:15], 0 offen offset:1024 \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[10:11], v[138:139], [%[c0], %[c1], %[c2], %[c3]] "
"\n" _UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[12:13], v[140:141], [%[c0], %[c1], %[c2], "
"%[c3]] \n" _UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[14:15], v[142:143], "
"[%[c0], %[c1], %[c2], %[c3]] \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[0:1], v[192:193], 0 \n"
" buffer_load_dwordx4 acc[136:139], %[v_os_b0], s[12:15], 0 offen offset:2048 \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[2:3], v[194:195], [%[c4], %[c5], %[c6], %[c7]] "
"\n" _UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[4:5], v[196:197], [%[c4], %[c5], %[c6], "
"%[c7]] \n" _UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[6:7], v[198:199], "
"[%[c4], %[c5], %[c6], %[c7]] \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[8:9], v[200:201], [%[c4], %[c5], %[c6], %[c7]] \n"
" buffer_load_dwordx4 acc[140:143], %[v_os_b0], s[12:15], 0 offen offset:3072 \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[10:11], v[202:203], [%[c4], %[c5], %[c6], %[c7]] "
"\n" _UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[12:13], v[204:205], [%[c4], %[c5], %[c6], "
"%[c7]] \n" _UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[14:15], v[206:207], "
"[%[c4], %[c5], %[c6], %[c7]] \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[16:17], v[128:129], 0 \n"
" buffer_load_dwordx4 acc[144:147], %[v_os_b1], s[12:15], 0 offen \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[18:19], v[130:131], [%[c8], %[c9], %[c10], %[c11]] "
"\n" _UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[20:21], v[132:133], [%[c8], %[c9], "
"%[c10], %[c11]] \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[22:23], v[134:135], [%[c8], %[c9], %[c10], %[c11]] "
"\n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[24:25], v[136:137], [%[c8], %[c9], %[c10], %[c11]] \n"
" buffer_load_dwordx4 acc[148:151], %[v_os_b1], s[12:15], 0 offen offset:1024 \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[26:27], v[138:139], [%[c8], %[c9], %[c10], %[c11]] "
"\n" _UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[28:29], v[140:141], [%[c8], %[c9], "
"%[c10], %[c11]] \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[30:31], v[142:143], [%[c8], %[c9], %[c10], %[c11]] "
"\n" _UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[16:17], v[192:193], 0 \n"
" buffer_load_dwordx4 acc[152:155], %[v_os_b1], s[12:15], 0 offen offset:2048 \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[18:19], v[194:195], [%[c12], %[c13], %[c14], %[c15]] "
"\n" _UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[20:21], v[196:197], [%[c12], %[c13], "
"%[c14], %[c15]] \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[22:23], v[198:199], [%[c12], %[c13], %[c14], %[c15]] "
"\n" _UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[24:25], v[200:201], [%[c12], %[c13], "
"%[c14], %[c15]] \n"
" buffer_load_dwordx4 acc[156:159], %[v_os_b1], s[12:15], 0 offen offset:3072 \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[26:27], v[202:203], [%[c12], %[c13], %[c14], %[c15]] "
"\n" _UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[28:29], v[204:205], [%[c12], %[c13], "
"%[c14], %[c15]] \n" _UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[30:31], "
"v[206:207], [%[c12], %[c13], %[c14], %[c15]] \n"
" s_waitcnt vmcnt(32) \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[32:33], v[144:145], [%[c0], %[c1], %[c2], %[c3]] \n"
" buffer_load_dwordx4 acc[160:163], %[v_os_b2], s[12:15], 0 offen \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[34:35], v[146:147], [%[c0], %[c1], %[c2], %[c3]] "
"\n" _UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[36:37], v[148:149], [%[c0], %[c1], %[c2], "
"%[c3]] \n" _UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[38:39], v[150:151], "
"[%[c0], %[c1], %[c2], %[c3]] \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[40:41], v[152:153], [%[c0], %[c1], %[c2], %[c3]] \n"
" buffer_load_dwordx4 acc[164:167], %[v_os_b2], s[12:15], 0 offen offset:1024 \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[42:43], v[154:155], [%[c0], %[c1], %[c2], %[c3]] "
"\n" _UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[44:45], v[156:157], [%[c0], %[c1], %[c2], "
"%[c3]] \n" _UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[46:47], v[158:159], "
"[%[c0], %[c1], %[c2], %[c3]] \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[32:33], v[208:209], [%[c4], %[c5], %[c6], %[c7]] \n"
" buffer_load_dwordx4 acc[168:171], %[v_os_b2], s[12:15], 0 offen offset:2048 \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[34:35], v[210:211], [%[c4], %[c5], %[c6], %[c7]] "
"\n" _UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[36:37], v[212:213], [%[c4], %[c5], %[c6], "
"%[c7]] \n" _UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[38:39], v[214:215], "
"[%[c4], %[c5], %[c6], %[c7]] \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[40:41], v[216:217], [%[c4], %[c5], %[c6], %[c7]] \n"
" buffer_load_dwordx4 acc[172:175], %[v_os_b2], s[12:15], 0 offen offset:3072 \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[42:43], v[218:219], [%[c4], %[c5], %[c6], %[c7]] "
"\n" _UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[44:45], v[220:221], [%[c4], %[c5], %[c6], "
"%[c7]] \n" _UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[46:47], v[222:223], "
"[%[c4], %[c5], %[c6], %[c7]] \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[48:49], v[144:145], [%[c8], %[c9], %[c10], %[c11]] \n"
" buffer_load_dwordx4 acc[176:179], %[v_os_b3], s[12:15], 0 offen \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[50:51], v[146:147], [%[c8], %[c9], %[c10], %[c11]] "
"\n" _UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[52:53], v[148:149], [%[c8], %[c9], "
"%[c10], %[c11]] \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[54:55], v[150:151], [%[c8], %[c9], %[c10], %[c11]] "
"\n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[56:57], v[152:153], [%[c8], %[c9], %[c10], %[c11]] \n"
" buffer_load_dwordx4 acc[180:183], %[v_os_b3], s[12:15], 0 offen offset:1024 \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[58:59], v[154:155], [%[c8], %[c9], %[c10], %[c11]] "
"\n" _UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[60:61], v[156:157], [%[c8], %[c9], "
"%[c10], %[c11]] \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[62:63], v[158:159], [%[c8], %[c9], %[c10], %[c11]] "
"\n" _UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[48:49], v[208:209], [%[c12], %[c13], "
"%[c14], %[c15]] \n"
" buffer_load_dwordx4 acc[184:187], %[v_os_b3], s[12:15], 0 offen offset:2048 \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[50:51], v[210:211], [%[c12], %[c13], %[c14], %[c15]] "
"\n" _UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[52:53], v[212:213], [%[c12], %[c13], "
"%[c14], %[c15]] \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[54:55], v[214:215], [%[c12], %[c13], %[c14], %[c15]] "
"\n" _UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[56:57], v[216:217], [%[c12], %[c13], "
"%[c14], %[c15]] \n"
" buffer_load_dwordx4 acc[188:191], %[v_os_b3], s[12:15], 0 offen offset:3072 \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[58:59], v[218:219], [%[c12], %[c13], %[c14], %[c15]] "
"\n" _UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[60:61], v[220:221], [%[c12], %[c13], "
"%[c14], %[c15]] \n" _UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[62:63], "
"v[222:223], [%[c12], %[c13], %[c14], %[c15]] \n"
" s_waitcnt vmcnt(32) \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[64:65], v[160:161], [%[c0], %[c1], %[c2], %[c3]] \n"
" buffer_load_dwordx4 acc[192:195], %[v_os_b4], s[12:15], 0 offen \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[66:67], v[162:163], [%[c0], %[c1], %[c2], %[c3]] "
"\n" _UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[68:69], v[164:165], [%[c0], %[c1], %[c2], "
"%[c3]] \n" _UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[70:71], v[166:167], "
"[%[c0], %[c1], %[c2], %[c3]] \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[72:73], v[168:169], [%[c0], %[c1], %[c2], %[c3]] \n"
" buffer_load_dwordx4 acc[196:199], %[v_os_b4], s[12:15], 0 offen offset:1024 \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[74:75], v[170:171], [%[c0], %[c1], %[c2], %[c3]] "
"\n" _UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[76:77], v[172:173], [%[c0], %[c1], %[c2], "
"%[c3]] \n" _UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[78:79], v[174:175], "
"[%[c0], %[c1], %[c2], %[c3]] \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[64:65], v[224:225], [%[c4], %[c5], %[c6], %[c7]] \n"
" buffer_load_dwordx4 acc[200:203], %[v_os_b4], s[12:15], 0 offen offset:2048 \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[66:67], v[226:227], [%[c4], %[c5], %[c6], %[c7]] "
"\n" _UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[68:69], v[228:229], [%[c4], %[c5], %[c6], "
"%[c7]] \n" _UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[70:71], v[230:231], "
"[%[c4], %[c5], %[c6], %[c7]] \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[72:73], v[232:233], [%[c4], %[c5], %[c6], %[c7]] \n"
" buffer_load_dwordx4 acc[204:207], %[v_os_b4], s[12:15], 0 offen offset:3072 \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[74:75], v[234:235], [%[c4], %[c5], %[c6], %[c7]] "
"\n" _UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[76:77], v[236:237], [%[c4], %[c5], %[c6], "
"%[c7]] \n" _UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[78:79], v[238:239], "
"[%[c4], %[c5], %[c6], %[c7]] \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[80:81], v[160:161], [%[c8], %[c9], %[c10], %[c11]] \n"
" buffer_load_dwordx4 acc[208:211], %[v_os_b5], s[12:15], 0 offen \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[82:83], v[162:163], [%[c8], %[c9], %[c10], %[c11]] "
"\n" _UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[84:85], v[164:165], [%[c8], %[c9], "
"%[c10], %[c11]] \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[86:87], v[166:167], [%[c8], %[c9], %[c10], %[c11]] "
"\n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[88:89], v[168:169], [%[c8], %[c9], %[c10], %[c11]] \n"
" buffer_load_dwordx4 acc[212:215], %[v_os_b5], s[12:15], 0 offen offset:1024 \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[90:91], v[170:171], [%[c8], %[c9], %[c10], %[c11]] "
"\n" _UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[92:93], v[172:173], [%[c8], %[c9], "
"%[c10], %[c11]] \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[94:95], v[174:175], [%[c8], %[c9], %[c10], %[c11]] "
"\n" _UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[80:81], v[224:225], [%[c12], %[c13], "
"%[c14], %[c15]] \n"
" buffer_load_dwordx4 acc[216:219], %[v_os_b5], s[12:15], 0 offen offset:2048 \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[82:83], v[226:227], [%[c12], %[c13], %[c14], %[c15]] "
"\n" _UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[84:85], v[228:229], [%[c12], %[c13], "
"%[c14], %[c15]] \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[86:87], v[230:231], [%[c12], %[c13], %[c14], %[c15]] "
"\n" _UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[88:89], v[232:233], [%[c12], %[c13], "
"%[c14], %[c15]] \n"
" buffer_load_dwordx4 acc[220:223], %[v_os_b5], s[12:15], 0 offen offset:3072 \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[90:91], v[234:235], [%[c12], %[c13], %[c14], %[c15]] "
"\n" _UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[92:93], v[236:237], [%[c12], %[c13], "
"%[c14], %[c15]] \n" _UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[94:95], "
"v[238:239], [%[c12], %[c13], %[c14], %[c15]] \n"
" s_waitcnt vmcnt(32) \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[96:97], v[176:177], [%[c0], %[c1], %[c2], %[c3]] \n"
" buffer_load_dwordx4 acc[224:227], %[v_os_b6], s[12:15], 0 offen \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[98:99], v[178:179], [%[c0], %[c1], %[c2], %[c3]] "
"\n" _UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[100:101], v[180:181], [%[c0], %[c1], "
"%[c2], %[c3]] \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[102:103], v[182:183], [%[c0], %[c1], %[c2], %[c3]] "
"\n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[104:105], v[184:185], [%[c0], %[c1], %[c2], %[c3]] \n"
" buffer_load_dwordx4 acc[228:231], %[v_os_b6], s[12:15], 0 offen offset:1024 \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[106:107], v[186:187], [%[c0], %[c1], %[c2], %[c3]] "
"\n" _UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[108:109], v[188:189], [%[c0], %[c1], "
"%[c2], %[c3]] \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[110:111], v[190:191], [%[c0], %[c1], %[c2], %[c3]] "
"\n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[96:97], v[240:241], [%[c4], %[c5], %[c6], %[c7]] \n"
" buffer_load_dwordx4 acc[232:235], %[v_os_b6], s[12:15], 0 offen offset:2048 \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[98:99], v[242:243], [%[c4], %[c5], %[c6], %[c7]] "
"\n" _UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[100:101], v[244:245], [%[c4], %[c5], "
"%[c6], %[c7]] \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[102:103], v[246:247], [%[c4], %[c5], %[c6], %[c7]] "
"\n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[104:105], v[248:249], [%[c4], %[c5], %[c6], %[c7]] \n"
" buffer_load_dwordx4 acc[236:239], %[v_os_b6], s[12:15], 0 offen offset:3072 \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[106:107], v[250:251], [%[c4], %[c5], %[c6], %[c7]] "
"\n" _UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[108:109], v[252:253], [%[c4], %[c5], "
"%[c6], %[c7]] \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[110:111], v[254:255], [%[c4], %[c5], %[c6], %[c7]] "
"\n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[112:113], v[176:177], [%[c8], %[c9], %[c10], %[c11]] \n"
" buffer_load_dwordx4 acc[240:243], %[v_os_b7], s[12:15], 0 offen \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[114:115], v[178:179], [%[c8], %[c9], %[c10], %[c11]] "
"\n" _UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[116:117], v[180:181], [%[c8], %[c9], "
"%[c10], %[c11]] \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[118:119], v[182:183], [%[c8], %[c9], %[c10], %[c11]] "
"\n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[120:121], v[184:185], [%[c8], %[c9], %[c10], %[c11]] \n"
" buffer_load_dwordx4 acc[244:247], %[v_os_b7], s[12:15], 0 offen offset:1024 \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[122:123], v[186:187], [%[c8], %[c9], %[c10], %[c11]] "
"\n" _UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[124:125], v[188:189], [%[c8], %[c9], "
"%[c10], %[c11]] \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[126:127], v[190:191], [%[c8], %[c9], %[c10], %[c11]] "
"\n" _UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[112:113], v[240:241], [%[c12], %[c13], "
"%[c14], %[c15]] \n"
" buffer_load_dwordx4 acc[248:251], %[v_os_b7], s[12:15], 0 offen offset:2048 \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[114:115], v[242:243], [%[c12], %[c13], %[c14], "
"%[c15]] \n" _UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[116:117], v[244:245], [%[c12], "
"%[c13], %[c14], %[c15]] \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[118:119], v[246:247], [%[c12], %[c13], %[c14], "
"%[c15]] \n" _UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[120:121], v[248:249], [%[c12], "
"%[c13], %[c14], %[c15]] \n"
" buffer_load_dwordx4 acc[252:255], %[v_os_b7], s[12:15], 0 offen offset:3072 \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[122:123], v[250:251], [%[c12], %[c13], %[c14], "
"%[c15]] \n" _UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[124:125], v[252:253], [%[c12], "
"%[c13], %[c14], %[c15]] \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[126:127], v[254:255], [%[c12], %[c13], %[c14], "
"%[c15]]\n"
" v_mul_f32 %[c0], %[scale_0], %[c0] \n"
" v_mul_f32 %[c1], %[scale_0], %[c1] \n"
" v_mul_f32 %[c2], %[scale_0], %[c2] \n"
" v_mul_f32 %[c3], %[scale_0], %[c3] \n"
" v_mul_f32 %[c4], %[scale_1], %[c4] \n"
" v_mul_f32 %[c5], %[scale_1], %[c5] \n"
" v_mul_f32 %[c6], %[scale_1], %[c6] \n"
" v_mul_f32 %[c7], %[scale_1], %[c7] \n"
" v_mul_f32 %[c8], %[scale_0], %[c8] \n"
" v_mul_f32 %[c9], %[scale_0], %[c9] \n"
" v_mul_f32 %[c10], %[scale_0], %[c10] \n"
" v_mul_f32 %[c11], %[scale_0], %[c11] \n"
" v_mul_f32 %[c12], %[scale_1], %[c12] \n"
" v_mul_f32 %[c13], %[scale_1], %[c13] \n"
" v_mul_f32 %[c14], %[scale_1], %[c14] \n"
" v_mul_f32 %[c15], %[scale_1], %[c15] \n" _UK_PK_CVT_(
"%[c0]", "%[c1]", "%[c0]") _UK_PK_CVT_("%[c2]", "%[c3]", "%[c1]")
_UK_PK_CVT_("%[c4]", "%[c5]", "%[c2]") _UK_PK_CVT_("%[c6]", "%[c7]", "%[c3]") _UK_PK_CVT_(
"%[c8]", "%[c9]", "%[c4]") _UK_PK_CVT_("%[c10]", "%[c11]", "%[c5]")
_UK_PK_CVT_("%[c12]", "%[c13]", "%[c6]") _UK_PK_CVT_(
"%[c14]",
"%[c15]",
"%[c7]") " ;------------------------------ \n"
" ds_write_b64 %[v_sfl_sst], [%[c0],%[c1]] offset:0 + %[shfl_base] "
" \n"
" ds_write_b64 %[v_sfl_sst], [%[c2],%[c3]] offset:4352 + %[shfl_base] "
" \n"
" ds_write_b64 %[v_sfl_sst], [%[c4],%[c5]] offset:2176 + %[shfl_base] "
" \n"
" ds_write_b64 %[v_sfl_sst], [%[c6],%[c7]] offset:6528 + %[shfl_base] "
" \n"
" s_waitcnt lgkmcnt(0) \n"
" s_barrier \n"
" ds_read_b32 %[c0], %[v_sfl_sld] offset:0 + %[shfl_base] "
" \n"
" ds_read_b32 %[c1], %[v_sfl_sld] offset:32 + %[shfl_base] "
" \n"
" ds_read_b32 %[c2], %[v_sfl_sld] offset:64 + %[shfl_base] "
" \n"
" ds_read_b32 %[c3], %[v_sfl_sld] offset:96 + %[shfl_base] "
" \n"
" ds_read_b32 %[c4], %[v_sfl_sld] offset:4352 + %[shfl_base] "
" \n"
" ds_read_b32 %[c5], %[v_sfl_sld] offset:4384 + %[shfl_base] "
" \n"
" ds_read_b32 %[c6], %[v_sfl_sld] offset:4416 + %[shfl_base] "
" \n"
" ds_read_b32 %[c7], %[v_sfl_sld] offset:4448 + %[shfl_base] "
" \n"
" s_waitcnt lgkmcnt(0) \n"
" s_mov_b64 exec, %[s_execflag_0] "
"\n" _UK_ATOMIC_ADD_ " %[v_os_o0], %[c0], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_1] "
"\n" _UK_ATOMIC_ADD_ " %[v_os_o1], %[c1], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_2] "
"\n" _UK_ATOMIC_ADD_ " %[v_os_o2], %[c2], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_3] "
"\n" _UK_ATOMIC_ADD_ " %[v_os_o3], %[c3], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_4] "
"\n" _UK_ATOMIC_ADD_ " %[v_os_o4], %[c4], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_5] "
"\n" _UK_ATOMIC_ADD_ " %[v_os_o5], %[c5], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_6] "
"\n" _UK_ATOMIC_ADD_ " %[v_os_o6], %[c6], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_7] "
"\n" _UK_ATOMIC_ADD_ " %[v_os_o7], %[c7], s[8:9] \n"
" s_mov_b64 exec, s[38:39] \n"
" s_sub_i32 %[s_loop_cnt], %[s_loop_cnt], 1 ; k-- \n"
" s_cmp_gt_i32 %[s_loop_cnt] 0 \n"
" s_cbranch_scc0 L_end%= \n"
" s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n"
" s_cselect_b32 s86, %[s_tile_os_b], 0 \n"
" s_add_u32 s12, s86, s12 \n"
" s_addc_u32 s13, 0, s13 \n"
" s_add_u32 s8, %[s_tile_os_o], s8 \n"
" s_addc_u32 s9, 0, s9 \n"
" s_waitcnt vmcnt(32) \n"
" s_barrier \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[128:129], v[128:129], 0 \n"
" buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[12:15], 0 offen "
"\n" _UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[130:131], "
"v[130:131], [%[c16],%[c17],%[c18],%[c19]] \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[132:133], v[132:133], "
"[%[c16],%[c17],%[c18],%[c19]] \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[134:135], v[134:135], "
"[%[c16],%[c17],%[c18],%[c19]] \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[136:137], v[136:137], "
"[%[c16],%[c17],%[c18],%[c19]] \n"
" buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[12:15], 0 offen "
"offset:1024 \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[138:139], v[138:139], "
"[%[c16],%[c17],%[c18],%[c19]] \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[140:141], v[140:141], "
"[%[c16],%[c17],%[c18],%[c19]] \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[142:143], v[142:143], "
"[%[c16],%[c17],%[c18],%[c19]] \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[128:129], v[192:193], 0 \n"
" buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[12:15], 0 offen "
"offset:2048 \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[130:131], v[194:195], "
"[%[c20],%[c21],%[c22],%[c23]] \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[132:133], v[196:197], "
"[%[c20],%[c21],%[c22],%[c23]] \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[134:135], v[198:199], "
"[%[c20],%[c21],%[c22],%[c23]] \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[136:137], v[200:201], "
"[%[c20],%[c21],%[c22],%[c23]] \n"
" buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[12:15], 0 offen "
"offset:3072 \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[138:139], v[202:203], "
"[%[c20],%[c21],%[c22],%[c23]] \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[140:141], v[204:205], "
"[%[c20],%[c21],%[c22],%[c23]] \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[142:143], v[206:207], "
"[%[c20],%[c21],%[c22],%[c23]] \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[144:145], v[128:129], 0 \n"
" buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[12:15], 0 offen "
"\n" _UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[146:147], "
"v[130:131], [%[c24],%[c25],%[c26],%[c27]] \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[148:149], v[132:133], "
"[%[c24],%[c25],%[c26],%[c27]] \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[150:151], v[134:135], "
"[%[c24],%[c25],%[c26],%[c27]] \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[152:153], v[136:137], "
"[%[c24],%[c25],%[c26],%[c27]] \n"
" buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[12:15], 0 offen "
"offset:1024 \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[154:155], v[138:139], "
"[%[c24],%[c25],%[c26],%[c27]] \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[156:157], v[140:141], "
"[%[c24],%[c25],%[c26],%[c27]] \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[158:159], v[142:143], "
"[%[c24],%[c25],%[c26],%[c27]] \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[144:145], v[192:193], 0 \n"
" buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[12:15], 0 offen "
"offset:2048 \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[146:147], v[194:195], "
"[%[c28],%[c29],%[c30],%[c31]] \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[148:149], v[196:197], "
"[%[c28],%[c29],%[c30],%[c31]] \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[150:151], v[198:199], "
"[%[c28],%[c29],%[c30],%[c31]] \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[152:153], v[200:201], "
"[%[c28],%[c29],%[c30],%[c31]] \n"
" buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[12:15], 0 offen "
"offset:3072 \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[154:155], v[202:203], "
"[%[c28],%[c29],%[c30],%[c31]] \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[156:157], v[204:205], "
"[%[c28],%[c29],%[c30],%[c31]] \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[158:159], v[206:207], "
"[%[c28],%[c29],%[c30],%[c31]] \n"
" s_waitcnt vmcnt(32) \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[160:161], v[144:145], "
"[%[c16],%[c17],%[c18],%[c19]] \n"
" buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[12:15], 0 offen "
"\n" _UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[162:163], "
"v[146:147], [%[c16],%[c17],%[c18],%[c19]] \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[164:165], v[148:149], "
"[%[c16],%[c17],%[c18],%[c19]] \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[166:167], v[150:151], "
"[%[c16],%[c17],%[c18],%[c19]] \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[168:169], v[152:153], "
"[%[c16],%[c17],%[c18],%[c19]] \n"
" buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[12:15], 0 offen "
"offset:1024 \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[170:171], v[154:155], "
"[%[c16],%[c17],%[c18],%[c19]] \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[172:173], v[156:157], "
"[%[c16],%[c17],%[c18],%[c19]] \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[174:175], v[158:159], "
"[%[c16],%[c17],%[c18],%[c19]] \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[160:161], v[208:209], "
"[%[c20],%[c21],%[c22],%[c23]] \n"
" buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[12:15], 0 offen "
"offset:2048 \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[162:163], v[210:211], "
"[%[c20],%[c21],%[c22],%[c23]] \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[164:165], v[212:213], "
"[%[c20],%[c21],%[c22],%[c23]] \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[166:167], v[214:215], "
"[%[c20],%[c21],%[c22],%[c23]] \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[168:169], v[216:217], "
"[%[c20],%[c21],%[c22],%[c23]] \n"
" buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[12:15], 0 offen "
"offset:3072 \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[170:171], v[218:219], "
"[%[c20],%[c21],%[c22],%[c23]] \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[172:173], v[220:221], "
"[%[c20],%[c21],%[c22],%[c23]] \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[174:175], v[222:223], "
"[%[c20],%[c21],%[c22],%[c23]] \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[176:177], v[144:145], "
"[%[c24],%[c25],%[c26],%[c27]] \n"
" buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[12:15], 0 offen "
"\n" _UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[178:179], "
"v[146:147], [%[c24],%[c25],%[c26],%[c27]] \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[180:181], v[148:149], "
"[%[c24],%[c25],%[c26],%[c27]] \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[182:183], v[150:151], "
"[%[c24],%[c25],%[c26],%[c27]] \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[184:185], v[152:153], "
"[%[c24],%[c25],%[c26],%[c27]] \n"
" buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[12:15], 0 offen "
"offset:1024 \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[186:187], v[154:155], "
"[%[c24],%[c25],%[c26],%[c27]] \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[188:189], v[156:157], "
"[%[c24],%[c25],%[c26],%[c27]] \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[190:191], v[158:159], "
"[%[c24],%[c25],%[c26],%[c27]] \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[176:177], v[208:209], "
"[%[c28],%[c29],%[c30],%[c31]] \n"
" buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[12:15], 0 offen "
"offset:2048 \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[178:179], v[210:211], "
"[%[c28],%[c29],%[c30],%[c31]] \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[180:181], v[212:213], "
"[%[c28],%[c29],%[c30],%[c31]] \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[182:183], v[214:215], "
"[%[c28],%[c29],%[c30],%[c31]] \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[184:185], v[216:217], "
"[%[c28],%[c29],%[c30],%[c31]] \n"
" buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[12:15], 0 offen "
"offset:3072 \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[186:187], v[218:219], "
"[%[c28],%[c29],%[c30],%[c31]] \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[188:189], v[220:221], "
"[%[c28],%[c29],%[c30],%[c31]] \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[190:191], v[222:223], "
"[%[c28],%[c29],%[c30],%[c31]] \n"
" s_waitcnt vmcnt(32) \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[192:193], v[160:161], "
"[%[c16],%[c17],%[c18],%[c19]] \n"
" buffer_load_dwordx4 acc[64:67], %[v_os_b4], s[12:15], 0 offen "
"\n" _UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[194:195], "
"v[162:163], [%[c16],%[c17],%[c18],%[c19]] \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[196:197], v[164:165], "
"[%[c16],%[c17],%[c18],%[c19]] \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[198:199], v[166:167], "
"[%[c16],%[c17],%[c18],%[c19]] \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[200:201], v[168:169], "
"[%[c16],%[c17],%[c18],%[c19]] \n"
" buffer_load_dwordx4 acc[68:71], %[v_os_b4], s[12:15], 0 offen "
"offset:1024 \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[202:203], v[170:171], "
"[%[c16],%[c17],%[c18],%[c19]] \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[204:205], v[172:173], "
"[%[c16],%[c17],%[c18],%[c19]] \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[206:207], v[174:175], "
"[%[c16],%[c17],%[c18],%[c19]] \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[192:193], v[224:225], "
"[%[c20],%[c21],%[c22],%[c23]] \n"
" buffer_load_dwordx4 acc[72:75], %[v_os_b4], s[12:15], 0 offen "
"offset:2048 \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[194:195], v[226:227], "
"[%[c20],%[c21],%[c22],%[c23]] \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[196:197], v[228:229], "
"[%[c20],%[c21],%[c22],%[c23]] \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[198:199], v[230:231], "
"[%[c20],%[c21],%[c22],%[c23]] \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[200:201], v[232:233], "
"[%[c20],%[c21],%[c22],%[c23]] \n"
" buffer_load_dwordx4 acc[76:79], %[v_os_b4], s[12:15], 0 offen "
"offset:3072 \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[202:203], v[234:235], "
"[%[c20],%[c21],%[c22],%[c23]] \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[204:205], v[236:237], "
"[%[c20],%[c21],%[c22],%[c23]] \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[206:207], v[238:239], "
"[%[c20],%[c21],%[c22],%[c23]] \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[208:209], v[160:161], "
"[%[c24],%[c25],%[c26],%[c27]] \n"
" buffer_load_dwordx4 acc[80:83], %[v_os_b5], s[12:15], 0 offen "
"\n" _UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[210:211], "
"v[162:163], [%[c24],%[c25],%[c26],%[c27]] \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[212:213], v[164:165], "
"[%[c24],%[c25],%[c26],%[c27]] \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[214:215], v[166:167], "
"[%[c24],%[c25],%[c26],%[c27]] \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[216:217], v[168:169], "
"[%[c24],%[c25],%[c26],%[c27]] \n"
" buffer_load_dwordx4 acc[84:87], %[v_os_b5], s[12:15], 0 offen "
"offset:1024 \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[218:219], v[170:171], "
"[%[c24],%[c25],%[c26],%[c27]] \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[220:221], v[172:173], "
"[%[c24],%[c25],%[c26],%[c27]] \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[222:223], v[174:175], "
"[%[c24],%[c25],%[c26],%[c27]] \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[208:209], v[224:225], "
"[%[c28],%[c29],%[c30],%[c31]] \n"
" buffer_load_dwordx4 acc[88:91], %[v_os_b5], s[12:15], 0 offen "
"offset:2048 \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[210:211], v[226:227], "
"[%[c28],%[c29],%[c30],%[c31]] \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[212:213], v[228:229], "
"[%[c28],%[c29],%[c30],%[c31]] \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[214:215], v[230:231], "
"[%[c28],%[c29],%[c30],%[c31]] \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[216:217], v[232:233], "
"[%[c28],%[c29],%[c30],%[c31]] \n"
" buffer_load_dwordx4 acc[92:95], %[v_os_b5], s[12:15], 0 offen "
"offset:3072 \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[218:219], v[234:235], "
"[%[c28],%[c29],%[c30],%[c31]] \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[220:221], v[236:237], "
"[%[c28],%[c29],%[c30],%[c31]] \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[222:223], v[238:239], "
"[%[c28],%[c29],%[c30],%[c31]] \n"
" s_waitcnt vmcnt(32) \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[224:225], v[176:177], "
"[%[c16],%[c17],%[c18],%[c19]] \n"
" buffer_load_dwordx4 acc[96:99], %[v_os_b6], s[12:15], 0 offen "
"\n" _UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[226:227], "
"v[178:179], [%[c16],%[c17],%[c18],%[c19]] \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[228:229], v[180:181], "
"[%[c16],%[c17],%[c18],%[c19]] \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[230:231], v[182:183], "
"[%[c16],%[c17],%[c18],%[c19]] \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[232:233], v[184:185], "
"[%[c16],%[c17],%[c18],%[c19]] \n"
" buffer_load_dwordx4 acc[100:103], %[v_os_b6], s[12:15], 0 offen "
"offset:1024 \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[234:235], v[186:187], "
"[%[c16],%[c17],%[c18],%[c19]] \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[236:237], v[188:189], "
"[%[c16],%[c17],%[c18],%[c19]] \n" _UK_MFMA_
" [%[c16],%[c17],%[c18],%[c19]], acc[238:239], v[190:191], "
"[%[c16],%[c17],%[c18],%[c19]] \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[224:225], v[240:241], "
"[%[c20],%[c21],%[c22],%[c23]] \n"
" buffer_load_dwordx4 acc[104:107], %[v_os_b6], s[12:15], 0 offen "
"offset:2048 \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[226:227], v[242:243], "
"[%[c20],%[c21],%[c22],%[c23]] \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[228:229], v[244:245], "
"[%[c20],%[c21],%[c22],%[c23]] \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[230:231], v[246:247], "
"[%[c20],%[c21],%[c22],%[c23]] \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[232:233], v[248:249], "
"[%[c20],%[c21],%[c22],%[c23]] \n"
" buffer_load_dwordx4 acc[108:111], %[v_os_b6], s[12:15], 0 offen "
"offset:3072 \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[234:235], v[250:251], "
"[%[c20],%[c21],%[c22],%[c23]] \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[236:237], v[252:253], "
"[%[c20],%[c21],%[c22],%[c23]] \n" _UK_MFMA_
" [%[c20],%[c21],%[c22],%[c23]], acc[238:239], v[254:255], "
"[%[c20],%[c21],%[c22],%[c23]] \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[240:241], v[176:177], "
"[%[c24],%[c25],%[c26],%[c27]] \n"
" buffer_load_dwordx4 acc[112:115], %[v_os_b7], s[12:15], 0 offen "
"\n" _UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[242:243], "
"v[178:179], [%[c24],%[c25],%[c26],%[c27]] \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[244:245], v[180:181], "
"[%[c24],%[c25],%[c26],%[c27]] \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[246:247], v[182:183], "
"[%[c24],%[c25],%[c26],%[c27]] \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[248:249], v[184:185], "
"[%[c24],%[c25],%[c26],%[c27]] \n"
" buffer_load_dwordx4 acc[116:119], %[v_os_b7], s[12:15], 0 offen "
"offset:1024 \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[250:251], v[186:187], "
"[%[c24],%[c25],%[c26],%[c27]] \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[252:253], v[188:189], "
"[%[c24],%[c25],%[c26],%[c27]] \n" _UK_MFMA_
" [%[c24],%[c25],%[c26],%[c27]], acc[254:255], v[190:191], "
"[%[c24],%[c25],%[c26],%[c27]] \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[240:241], v[240:241], "
"[%[c28],%[c29],%[c30],%[c31]] \n"
" buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[12:15], 0 offen "
"offset:2048 \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[242:243], v[242:243], "
"[%[c28],%[c29],%[c30],%[c31]] \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[244:245], v[244:245], "
"[%[c28],%[c29],%[c30],%[c31]] \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[246:247], v[246:247], "
"[%[c28],%[c29],%[c30],%[c31]] \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[248:249], v[248:249], "
"[%[c28],%[c29],%[c30],%[c31]] \n"
" buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[12:15], 0 offen "
"offset:3072 \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[250:251], v[250:251], "
"[%[c28],%[c29],%[c30],%[c31]] \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[252:253], v[252:253], "
"[%[c28],%[c29],%[c30],%[c31]] \n" _UK_MFMA_
" [%[c28],%[c29],%[c30],%[c31]], acc[254:255], v[254:255], "
"[%[c28],%[c29],%[c30],%[c31]]\n"
" v_mul_f32 %[c16], %[scale_0], %[c16] \n"
" v_mul_f32 %[c17], %[scale_0], %[c17] \n"
" v_mul_f32 %[c18], %[scale_0], %[c18] \n"
" v_mul_f32 %[c19], %[scale_0], %[c19] \n"
" v_mul_f32 %[c20], %[scale_1], %[c20] \n"
" v_mul_f32 %[c21], %[scale_1], %[c21] \n"
" v_mul_f32 %[c22], %[scale_1], %[c22] \n"
" v_mul_f32 %[c23], %[scale_1], %[c23] \n"
" v_mul_f32 %[c24], %[scale_0], %[c24] \n"
" v_mul_f32 %[c25], %[scale_0], %[c25] \n"
" v_mul_f32 %[c26], %[scale_0], %[c26] \n"
" v_mul_f32 %[c27], %[scale_0], %[c27] \n"
" v_mul_f32 %[c28], %[scale_1], %[c28] \n"
" v_mul_f32 %[c29], %[scale_1], %[c29] \n"
" v_mul_f32 %[c30], %[scale_1], %[c30] \n"
" v_mul_f32 %[c31], %[scale_1], %[c31] \n"
_UK_PK_CVT_("%[c16]", "%[c17]", "%[c16]") _UK_PK_CVT_("%[c18]", "%[c19]", "%[c17]") _UK_PK_CVT_(
"%[c20]", "%[c21]", "%[c18]") _UK_PK_CVT_("%[c22]", "%[c23]", "%[c19]")
_UK_PK_CVT_("%[c24]", "%[c25]", "%[c20]") _UK_PK_CVT_(
"%[c26]", "%[c27]", "%[c21]") _UK_PK_CVT_("%[c28]",
"%[c29]",
"%[c22]") _UK_PK_CVT_("%[c30]",
"%[c31]",
"%[c23]")
" ;------------------------------ \n"
" ds_write_b64 %[v_sfl_sst], [%[c16],%[c17]] offset:0 + %[shfl_base] \n"
" ds_write_b64 %[v_sfl_sst], [%[c18],%[c19]] offset:4352 + %[shfl_base] \n"
" ds_write_b64 %[v_sfl_sst], [%[c20],%[c21]] offset:2176 + %[shfl_base] \n"
" ds_write_b64 %[v_sfl_sst], [%[c22],%[c23]] offset:6528 + %[shfl_base] \n"
" s_waitcnt lgkmcnt(0) \n"
" s_barrier \n"
" ds_read_b32 %[c16], %[v_sfl_sld] offset:0 + %[shfl_base] \n"
" ds_read_b32 %[c17], %[v_sfl_sld] offset:32 + %[shfl_base] \n"
" ds_read_b32 %[c18], %[v_sfl_sld] offset:64 + %[shfl_base] \n"
" ds_read_b32 %[c19], %[v_sfl_sld] offset:96 + %[shfl_base] \n"
" ds_read_b32 %[c20], %[v_sfl_sld] offset:4352 + %[shfl_base] \n"
" ds_read_b32 %[c21], %[v_sfl_sld] offset:4384 + %[shfl_base] \n"
" ds_read_b32 %[c22], %[v_sfl_sld] offset:4416 + %[shfl_base] \n"
" ds_read_b32 %[c23], %[v_sfl_sld] offset:4448 + %[shfl_base] \n"
" s_waitcnt lgkmcnt(0) \n"
" s_mov_b64 exec, %[s_execflag_0] \n" _UK_ATOMIC_ADD_
" %[v_os_o0], %[c16], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_1] \n" _UK_ATOMIC_ADD_
" %[v_os_o1], %[c17], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_2] \n" _UK_ATOMIC_ADD_
" %[v_os_o2], %[c18], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_3] \n" _UK_ATOMIC_ADD_
" %[v_os_o3], %[c19], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_4] \n" _UK_ATOMIC_ADD_
" %[v_os_o4], %[c20], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_5] \n" _UK_ATOMIC_ADD_
" %[v_os_o5], %[c21], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_6] \n" _UK_ATOMIC_ADD_
" %[v_os_o6], %[c22], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_7] \n" _UK_ATOMIC_ADD_
" %[v_os_o7], %[c23], s[8:9] \n"
" s_mov_b64 exec, s[38:39] \n"
" s_sub_i32 %[s_loop_cnt], %[s_loop_cnt], 1 ; k-- \n"
" s_cmp_gt_i32 %[s_loop_cnt] 0 \n"
" s_cbranch_scc0 L_end%= \n"
" s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n"
" s_cselect_b32 s86, %[s_tile_os_b], 0 \n"
" s_add_u32 s12, s86, s12 \n"
" s_addc_u32 s13, 0, s13 \n"
" s_add_u32 s8, %[s_tile_os_o], s8 \n"
" s_addc_u32 s9, 0, s9 \n"
" s_branch L_start%= \n"
"L_end%=: \n"
#undef _UK_MFMA_
#undef _UK_PK_CVT_
#undef _UK_ATOMIC_ADD_
#ifndef CK_TILE_FLATMM_UK_MFMA
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
#endif
#if CK_TILE_FLATMM_UK_MFMA == CK_TILE_FLATMM_UK_MFMA_BF16
#define _UK_MFMA_ "v_mfma_f32_16x16x16_bf16"
#define _UK_PK_CVT_(x0_, x1_, y_) \
" v_cmp_u_f32 s[36:37], " x0_ ", " x0_ " \n" \
" v_add3_u32 v50, " x0_ ", %[v_nan_lo], 1 \n" \
" v_cndmask_b32 v54, v50, %[v_nan_hi], s[36:37] \n" \
" v_cmp_u_f32 s[36:37], " x1_ ", " x1_ " \n" \
" v_add3_u32 v50, " x1_ ", %[v_nan_lo], 1 \n" \
" v_cndmask_b32 v55, v50, %[v_nan_hi], s[36:37] \n" \
" v_perm_b32 " y_ ", v55, v54, s52 \n"
#define _UK_ATOMIC_ADD_ "global_atomic_pk_add_bf16"
#elif CK_TILE_FLATMM_UK_MFMA == CK_TILE_FLATMM_UK_MFMA_FP16
#define _UK_MFMA_ "v_mfma_f32_16x16x16_f16"
#define _UK_PK_CVT_(x0_, x1_, y_) \
" v_cvt_f16_f32 v54, " x0_ " \n" \
" v_cvt_f16_f32 v55, " x1_ " \n" \
" v_pack_b32_f16 " y_ ", v54, v55 \n"
#define _UK_ATOMIC_ADD_ "global_atomic_pk_add_f16"
#endif
";-------------------------------------------------------------\n"
" s_mov_b32 s52, 0x07060302 ; v_perm\n"
" s_mov_b64 s[38:39], exec ; save current exec\n"
" s_mov_b32 s8, %[s_res_o0] \n"
" s_mov_b32 s9, %[s_res_o1] \n"
" s_mov_b32 s12, %[s_res_b0] \n"
" s_mov_b32 s13, %[s_res_b1] \n"
" s_mov_b32 s14, %[s_res_b2] \n"
" s_mov_b32 s15, %[s_res_b3] \n"
" s_mov_b32 s59, 0 \n"
" ds_read_b64 v[128:129], %[v_sld_y_os] offset:0 + %[sld_a_base] \n"
" ds_read_b64 v[130:131], %[v_sld_y_os] offset:128 + %[sld_a_base] \n"
" ds_read_b64 v[132:133], %[v_sld_y_os] offset:1024 + %[sld_a_base] \n"
" ds_read_b64 v[134:135], %[v_sld_y_os] offset:1152 + %[sld_a_base] \n"
" ds_read_b64 v[136:137], %[v_sld_y_os] offset:2048 + %[sld_a_base] \n"
" ds_read_b64 v[138:139], %[v_sld_y_os] offset:2176 + %[sld_a_base] \n"
" ds_read_b64 v[140:141], %[v_sld_y_os] offset:3072 + %[sld_a_base] \n"
" ds_read_b64 v[142:143], %[v_sld_y_os] offset:3200 + %[sld_a_base] \n"
" ds_read_b64 v[144:145], %[v_sld_y_os] offset:4096 + %[sld_a_base] \n"
" ds_read_b64 v[146:147], %[v_sld_y_os] offset:4224 + %[sld_a_base] \n"
" ds_read_b64 v[148:149], %[v_sld_y_os] offset:5120 + %[sld_a_base] \n"
" ds_read_b64 v[150:151], %[v_sld_y_os] offset:5248 + %[sld_a_base] \n"
" ds_read_b64 v[152:153], %[v_sld_y_os] offset:6144 + %[sld_a_base] \n"
" ds_read_b64 v[154:155], %[v_sld_y_os] offset:6272 + %[sld_a_base] \n"
" ds_read_b64 v[156:157], %[v_sld_y_os] offset:7168 + %[sld_a_base] \n"
" ds_read_b64 v[158:159], %[v_sld_y_os] offset:7296 + %[sld_a_base] \n"
" ds_read_b64 v[160:161], %[v_sld_y_os] offset:8192 + %[sld_a_base] \n"
" ds_read_b64 v[162:163], %[v_sld_y_os] offset:8320 + %[sld_a_base] \n"
" ds_read_b64 v[164:165], %[v_sld_y_os] offset:9216 + %[sld_a_base] \n"
" ds_read_b64 v[166:167], %[v_sld_y_os] offset:9344 + %[sld_a_base] \n"
" ds_read_b64 v[168:169], %[v_sld_y_os] offset:10240 + %[sld_a_base] \n"
" ds_read_b64 v[170:171], %[v_sld_y_os] offset:10368 + %[sld_a_base] \n"
" ds_read_b64 v[172:173], %[v_sld_y_os] offset:11264 + %[sld_a_base] \n"
" ds_read_b64 v[174:175], %[v_sld_y_os] offset:11392 + %[sld_a_base] \n"
" ds_read_b64 v[176:177], %[v_sld_y_os] offset:12288 + %[sld_a_base] \n"
" ds_read_b64 v[178:179], %[v_sld_y_os] offset:12416 + %[sld_a_base] \n"
" ds_read_b64 v[180:181], %[v_sld_y_os] offset:13312 + %[sld_a_base] \n"
" ds_read_b64 v[182:183], %[v_sld_y_os] offset:13440 + %[sld_a_base] \n"
" ds_read_b64 v[184:185], %[v_sld_y_os] offset:14336 + %[sld_a_base] \n"
" ds_read_b64 v[186:187], %[v_sld_y_os] offset:14464 + %[sld_a_base] \n"
" ds_read_b64 v[188:189], %[v_sld_y_os] offset:15360 + %[sld_a_base] \n"
" ds_read_b64 v[190:191], %[v_sld_y_os] offset:15488 + %[sld_a_base] \n"
" ds_read_b64 v[192:193], %[v_sld_y_os] offset:16384 + %[sld_a_base] \n"
" ds_read_b64 v[194:195], %[v_sld_y_os] offset:16512 + %[sld_a_base] \n"
" ds_read_b64 v[196:197], %[v_sld_y_os] offset:17408 + %[sld_a_base] \n"
" ds_read_b64 v[198:199], %[v_sld_y_os] offset:17536 + %[sld_a_base] \n"
" ds_read_b64 v[200:201], %[v_sld_y_os] offset:18432 + %[sld_a_base] \n"
" ds_read_b64 v[202:203], %[v_sld_y_os] offset:18560 + %[sld_a_base] \n"
" ds_read_b64 v[204:205], %[v_sld_y_os] offset:19456 + %[sld_a_base] \n"
" ds_read_b64 v[206:207], %[v_sld_y_os] offset:19584 + %[sld_a_base] \n"
" ds_read_b64 v[208:209], %[v_sld_y_os] offset:20480 + %[sld_a_base] \n"
" ds_read_b64 v[210:211], %[v_sld_y_os] offset:20608 + %[sld_a_base] \n"
" ds_read_b64 v[212:213], %[v_sld_y_os] offset:21504 + %[sld_a_base] \n"
" ds_read_b64 v[214:215], %[v_sld_y_os] offset:21632 + %[sld_a_base] \n"
" ds_read_b64 v[216:217], %[v_sld_y_os] offset:22528 + %[sld_a_base] \n"
" ds_read_b64 v[218:219], %[v_sld_y_os] offset:22656 + %[sld_a_base] \n"
" ds_read_b64 v[220:221], %[v_sld_y_os] offset:23552 + %[sld_a_base] \n"
" ds_read_b64 v[222:223], %[v_sld_y_os] offset:23680 + %[sld_a_base] \n"
" ds_read_b64 v[224:225], %[v_sld_y_os] offset:24576 + %[sld_a_base] \n"
" ds_read_b64 v[226:227], %[v_sld_y_os] offset:24704 + %[sld_a_base] \n"
" ds_read_b64 v[228:229], %[v_sld_y_os] offset:25600 + %[sld_a_base] \n"
" ds_read_b64 v[230:231], %[v_sld_y_os] offset:25728 + %[sld_a_base] \n"
" ds_read_b64 v[232:233], %[v_sld_y_os] offset:26624 + %[sld_a_base] \n"
" ds_read_b64 v[234:235], %[v_sld_y_os] offset:26752 + %[sld_a_base] \n"
" ds_read_b64 v[236:237], %[v_sld_y_os] offset:27648 + %[sld_a_base] \n"
" ds_read_b64 v[238:239], %[v_sld_y_os] offset:27776 + %[sld_a_base] \n"
" ds_read_b64 v[240:241], %[v_sld_y_os] offset:28672 + %[sld_a_base] \n"
" ds_read_b64 v[242:243], %[v_sld_y_os] offset:28800 + %[sld_a_base] \n"
" ds_read_b64 v[244:245], %[v_sld_y_os] offset:29696 + %[sld_a_base] \n"
" ds_read_b64 v[246:247], %[v_sld_y_os] offset:29824 + %[sld_a_base] \n"
" ds_read_b64 v[248:249], %[v_sld_y_os] offset:30720 + %[sld_a_base] \n"
" ds_read_b64 v[250:251], %[v_sld_y_os] offset:30848 + %[sld_a_base] \n"
" ds_read_b64 v[252:253], %[v_sld_y_os] offset:31744 + %[sld_a_base] \n"
" ds_read_b64 v[254:255], %[v_sld_y_os] offset:31872 + %[sld_a_base] \n"
" s_waitcnt 0 \n"
" buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[12:15], 0 offen \n"
" buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[12:15], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[12:15], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[12:15], 0 offen offset:3072 \n"
" buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[12:15], 0 offen \n"
" buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[12:15], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[12:15], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[12:15], 0 offen offset:3072 \n"
" buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[12:15], 0 offen \n"
" buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[12:15], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[12:15], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[12:15], 0 offen offset:3072 \n"
" buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[12:15], 0 offen \n"
" buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[12:15], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[12:15], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[12:15], 0 offen offset:3072 \n"
" buffer_load_dwordx4 acc[64:67], %[v_os_b4], s[12:15], 0 offen \n"
" buffer_load_dwordx4 acc[68:71], %[v_os_b4], s[12:15], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[72:75], %[v_os_b4], s[12:15], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[76:79], %[v_os_b4], s[12:15], 0 offen offset:3072 \n"
" buffer_load_dwordx4 acc[80:83], %[v_os_b5], s[12:15], 0 offen \n"
" buffer_load_dwordx4 acc[84:87], %[v_os_b5], s[12:15], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[88:91], %[v_os_b5], s[12:15], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[92:95], %[v_os_b5], s[12:15], 0 offen offset:3072 \n"
" buffer_load_dwordx4 acc[96:99], %[v_os_b6], s[12:15], 0 offen \n"
" buffer_load_dwordx4 acc[100:103], %[v_os_b6], s[12:15], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[104:107], %[v_os_b6], s[12:15], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[108:111], %[v_os_b6], s[12:15], 0 offen offset:3072 \n"
" buffer_load_dwordx4 acc[112:115], %[v_os_b7], s[12:15], 0 offen \n"
" buffer_load_dwordx4 acc[116:119], %[v_os_b7], s[12:15], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[12:15], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[12:15], 0 offen offset:3072 \n"
" s_add_u32 s12, %[s_tile_os_b], s12 \n"
" s_addc_u32 s13, 0, s13 \n"
" v_mov_b32 v64, 0 \n"
" v_mov_b32 v80, 0 \n"
" v_mov_b32 v65, 0 \n"
" v_mov_b32 v81, 0 \n"
" v_mov_b32 v66, 0 \n"
" v_mov_b32 v82, 0 \n"
" v_mov_b32 v67, 0 \n"
" v_mov_b32 v83, 0 \n"
" v_mov_b32 v68, 0 \n"
" v_mov_b32 v84, 0 \n"
" v_mov_b32 v69, 0 \n"
" v_mov_b32 v85, 0 \n"
" v_mov_b32 v70, 0 \n"
" v_mov_b32 v86, 0 \n"
" v_mov_b32 v71, 0 \n"
" v_mov_b32 v87, 0 \n"
" ds_write_b64 %[v_sfl_sst], [%[c0],%[c1]] offset:16640 \n"
" ds_write_b64 %[v_sfl_sst], [%[c2],%[c3]] offset:20992 \n"
" ds_write_b64 %[v_sfl_sst], [%[c4],%[c5]] offset:18816 \n"
" ds_write_b64 %[v_sfl_sst], [%[c6],%[c7]] offset:23168 \n"
" s_mov_b32 s80, 0 \n"
" s_waitcnt vmcnt(24) \n"
"label_0AA6: \n"
" s_waitcnt vmcnt(30) & lgkmcnt(0) \n"
" s_barrier \n" _UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[0:1], v[128:129], 0 \n"
" ds_read_b32 v10, %[v_sfl_sld] offset:16640 \n"
" ds_read_b32 v11, %[v_sfl_sld] offset:16672 \n"
" ds_write_b64 %[v_sfl_sst], [%[c16],%[c17]] offset:25344 \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[2:3], v[130:131], v[64:67] \n"
" buffer_load_dwordx4 acc[128:131], %[v_os_b0], s[12:15], 0 offen \n"
" ds_write_b64 %[v_sfl_sst], [%[c18],%[c19]] offset:29696 \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[4:5], v[132:133], v[64:67] \n"
" ds_read_b32 v12, %[v_sfl_sld] offset:16704 \n"
" ds_read_b32 v13, %[v_sfl_sld] offset:16736 \n"
" ds_write_b64 %[v_sfl_sst], [%[c20],%[c21]] offset:27520 \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[6:7], v[134:135], v[64:67] \n"
" ds_write_b64 %[v_sfl_sst], [%[c22],%[c23]] offset:31872 \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[8:9], v[136:137], v[64:67] \n"
" ds_read_b32 v14, %[v_sfl_sld] offset:20992 \n"
" ds_read_b32 v15, %[v_sfl_sld] offset:21024 \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[10:11], v[138:139], v[64:67] \n"
" buffer_load_dwordx4 acc[132:135], %[v_os_b0], s[12:15], 0 offen offset:1024 \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[12:13], v[140:141], v[64:67] \n"
" ds_read_b32 v16, %[v_sfl_sld] offset:21056 \n"
" ds_read_b32 v17, %[v_sfl_sld] offset:21088 \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[14:15], v[142:143], v[64:67] \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[0:1], v[192:193], 0 \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[2:3], v[194:195], v[68:71] \n"
" buffer_load_dwordx4 acc[136:139], %[v_os_b0], s[12:15], 0 offen offset:2048 \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[4:5], v[196:197], v[68:71] \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[6:7], v[198:199], v[68:71] \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[8:9], v[200:201], v[68:71] \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[10:11], v[202:203], v[68:71] \n"
" buffer_load_dwordx4 acc[140:143], %[v_os_b0], s[12:15], 0 offen offset:3072 \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[12:13], v[204:205], v[68:71] \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[14:15], v[206:207], v[68:71] \n"
" s_waitcnt lgkmcnt(0) \n"
" s_mov_b64 exec, %[s_execflag_0] \n" _UK_ATOMIC_ADD_ " %[v_os_o0], v10, s[8:9] \n"
" s_mov_b64 exec, s[38:39] \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[16:17], v[128:129], 0 \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[18:19], v[130:131], v[72:75] \n"
" buffer_load_dwordx4 acc[144:147], %[v_os_b1], s[12:15], 0 offen \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[20:21], v[132:133], v[72:75] \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[22:23], v[134:135], v[72:75] \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[24:25], v[136:137], v[72:75] \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[26:27], v[138:139], v[72:75] \n"
" buffer_load_dwordx4 acc[148:151], %[v_os_b1], s[12:15], 0 offen offset:1024 \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[28:29], v[140:141], v[72:75] \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[30:31], v[142:143], v[72:75] \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[16:17], v[192:193], 0 \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[18:19], v[194:195], v[76:79] \n"
" buffer_load_dwordx4 acc[152:155], %[v_os_b1], s[12:15], 0 offen offset:2048 \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[20:21], v[196:197], v[76:79] \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[22:23], v[198:199], v[76:79] \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[24:25], v[200:201], v[76:79] \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[26:27], v[202:203], v[76:79] \n"
" buffer_load_dwordx4 acc[156:159], %[v_os_b1], s[12:15], 0 offen offset:3072 \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[28:29], v[204:205], v[76:79] \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[30:31], v[206:207], v[76:79] \n"
" s_mov_b64 exec, %[s_execflag_1] \n" _UK_ATOMIC_ADD_ " %[v_os_o1], v11, s[8:9] \n"
" s_mov_b64 exec, s[38:39] \n"
" s_waitcnt vmcnt(30) \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[32:33], v[144:145], v[64:67] \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[34:35], v[146:147], v[64:67] \n"
" buffer_load_dwordx4 acc[160:163], %[v_os_b2], s[12:15], 0 offen \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[36:37], v[148:149], v[64:67] \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[38:39], v[150:151], v[64:67] \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[40:41], v[152:153], v[64:67] \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[42:43], v[154:155], v[64:67] \n"
" buffer_load_dwordx4 acc[164:167], %[v_os_b2], s[12:15], 0 offen offset:1024 \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[44:45], v[156:157], v[64:67] \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[46:47], v[158:159], v[64:67] \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[32:33], v[208:209], v[68:71] \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[34:35], v[210:211], v[68:71] \n"
" buffer_load_dwordx4 acc[168:171], %[v_os_b2], s[12:15], 0 offen offset:2048 \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[36:37], v[212:213], v[68:71] \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[38:39], v[214:215], v[68:71] \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[40:41], v[216:217], v[68:71] \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[42:43], v[218:219], v[68:71] \n"
" buffer_load_dwordx4 acc[172:175], %[v_os_b2], s[12:15], 0 offen offset:3072 \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[44:45], v[220:221], v[68:71] \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[46:47], v[222:223], v[68:71] \n"
" s_mov_b64 exec, %[s_execflag_2] \n" _UK_ATOMIC_ADD_ " %[v_os_o2], v12, s[8:9] \n"
" s_mov_b64 exec, s[38:39] \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[48:49], v[144:145], v[72:75] \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[50:51], v[146:147], v[72:75] \n"
" buffer_load_dwordx4 acc[176:179], %[v_os_b3], s[12:15], 0 offen \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[52:53], v[148:149], v[72:75] \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[54:55], v[150:151], v[72:75] \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[56:57], v[152:153], v[72:75] \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[58:59], v[154:155], v[72:75] \n"
" buffer_load_dwordx4 acc[180:183], %[v_os_b3], s[12:15], 0 offen offset:1024 \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[60:61], v[156:157], v[72:75] \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[62:63], v[158:159], v[72:75] \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[48:49], v[208:209], v[76:79] \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[50:51], v[210:211], v[76:79] \n"
" buffer_load_dwordx4 acc[184:187], %[v_os_b3], s[12:15], 0 offen offset:2048 \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[52:53], v[212:213], v[76:79] \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[54:55], v[214:215], v[76:79] \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[56:57], v[216:217], v[76:79] \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[58:59], v[218:219], v[76:79] \n"
" buffer_load_dwordx4 acc[188:191], %[v_os_b3], s[12:15], 0 offen offset:3072 \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[60:61], v[220:221], v[76:79] \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[62:63], v[222:223], v[76:79] \n"
" s_mov_b64 exec, %[s_execflag_3] \n" _UK_ATOMIC_ADD_ " %[v_os_o3], v13, s[8:9] \n"
" s_mov_b64 exec, s[38:39] \n"
" s_waitcnt vmcnt(30) \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[64:65], v[160:161], v[64:67] \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[66:67], v[162:163], v[64:67] \n"
" buffer_load_dwordx4 acc[192:195], %[v_os_b4], s[12:15], 0 offen \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[68:69], v[164:165], v[64:67] \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[70:71], v[166:167], v[64:67] \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[72:73], v[168:169], v[64:67] \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[74:75], v[170:171], v[64:67] \n"
" buffer_load_dwordx4 acc[196:199], %[v_os_b4], s[12:15], 0 offen offset:1024 \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[76:77], v[172:173], v[64:67] \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[78:79], v[174:175], v[64:67] \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[64:65], v[224:225], v[68:71] \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[66:67], v[226:227], v[68:71] \n"
" buffer_load_dwordx4 acc[200:203], %[v_os_b4], s[12:15], 0 offen offset:2048 \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[68:69], v[228:229], v[68:71] \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[70:71], v[230:231], v[68:71] \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[72:73], v[232:233], v[68:71] \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[74:75], v[234:235], v[68:71] \n"
" buffer_load_dwordx4 acc[204:207], %[v_os_b4], s[12:15], 0 offen offset:3072 \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[76:77], v[236:237], v[68:71] \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[78:79], v[238:239], v[68:71] \n"
" s_mov_b64 exec, %[s_execflag_4] \n" _UK_ATOMIC_ADD_ " %[v_os_o4], v14, s[8:9] \n"
" s_mov_b64 exec, s[38:39] \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[80:81], v[160:161], v[72:75] \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[82:83], v[162:163], v[72:75] \n"
" buffer_load_dwordx4 acc[208:211], %[v_os_b5], s[12:15], 0 offen \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[84:85], v[164:165], v[72:75] \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[86:87], v[166:167], v[72:75] \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[88:89], v[168:169], v[72:75] \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[90:91], v[170:171], v[72:75] \n"
" buffer_load_dwordx4 acc[212:215], %[v_os_b5], s[12:15], 0 offen offset:1024 \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[92:93], v[172:173], v[72:75] \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[94:95], v[174:175], v[72:75] \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[80:81], v[224:225], v[76:79] \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[82:83], v[226:227], v[76:79] \n"
" buffer_load_dwordx4 acc[216:219], %[v_os_b5], s[12:15], 0 offen offset:2048 \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[84:85], v[228:229], v[76:79] \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[86:87], v[230:231], v[76:79] \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[88:89], v[232:233], v[76:79] \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[90:91], v[234:235], v[76:79] \n"
" buffer_load_dwordx4 acc[220:223], %[v_os_b5], s[12:15], 0 offen offset:3072 \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[92:93], v[236:237], v[76:79] \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[94:95], v[238:239], v[76:79] \n"
" s_mov_b64 exec, %[s_execflag_5] \n" _UK_ATOMIC_ADD_ " %[v_os_o5], v15, s[8:9] \n"
" s_mov_b64 exec, s[38:39] \n"
" s_waitcnt vmcnt(30) \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[96:97], v[176:177], v[64:67] \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[98:99], v[178:179], v[64:67] \n"
" buffer_load_dwordx4 acc[224:227], %[v_os_b6], s[12:15], 0 offen \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[100:101], v[180:181], v[64:67] \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[102:103], v[182:183], v[64:67] \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[104:105], v[184:185], v[64:67] \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[106:107], v[186:187], v[64:67] \n"
" buffer_load_dwordx4 acc[228:231], %[v_os_b6], s[12:15], 0 offen offset:1024 \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[108:109], v[188:189], v[64:67] \n" _UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[110:111], v[190:191], v[64:67] \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[96:97], v[240:241], v[68:71] \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[98:99], v[242:243], v[68:71] \n"
" buffer_load_dwordx4 acc[232:235], %[v_os_b6], s[12:15], 0 offen offset:2048 \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[100:101], v[244:245], v[68:71] \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[102:103], v[246:247], v[68:71] \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[104:105], v[248:249], v[68:71] \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[106:107], v[250:251], v[68:71] \n"
" buffer_load_dwordx4 acc[236:239], %[v_os_b6], s[12:15], 0 offen offset:3072 \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[108:109], v[252:253], v[68:71] \n" _UK_MFMA_
" [%[c4], %[c5], %[c6], %[c7]], acc[110:111], v[254:255], v[68:71] \n"
" s_mov_b64 exec, %[s_execflag_6] \n" _UK_ATOMIC_ADD_ " %[v_os_o6], v16, s[8:9] \n"
" s_mov_b64 exec, s[38:39] \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[112:113], v[176:177], v[72:75] \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[114:115], v[178:179], v[72:75] \n"
" buffer_load_dwordx4 acc[240:243], %[v_os_b7], s[12:15], 0 offen \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[116:117], v[180:181], v[72:75] \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[118:119], v[182:183], v[72:75] \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[120:121], v[184:185], v[72:75] \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[122:123], v[186:187], v[72:75] \n"
" buffer_load_dwordx4 acc[244:247], %[v_os_b7], s[12:15], 0 offen offset:1024 \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[124:125], v[188:189], v[72:75] \n" _UK_MFMA_
" [%[c8], %[c9], %[c10], %[c11]], acc[126:127], v[190:191], v[72:75] \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[112:113], v[240:241], v[76:79] \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[114:115], v[242:243], v[76:79] \n"
" buffer_load_dwordx4 acc[248:251], %[v_os_b7], s[12:15], 0 offen offset:2048 \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[116:117], v[244:245], v[76:79] \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[118:119], v[246:247], v[76:79] \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[120:121], v[248:249], v[76:79] \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[122:123], v[250:251], v[76:79] \n"
" buffer_load_dwordx4 acc[252:255], %[v_os_b7], s[12:15], 0 offen offset:3072 \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[124:125], v[252:253], v[76:79] \n" _UK_MFMA_
" [%[c12], %[c13], %[c14], %[c15]], acc[126:127], v[254:255], v[76:79] \n"
" s_mov_b64 exec, %[s_execflag_7] \n" _UK_ATOMIC_ADD_ " %[v_os_o7], v17, s[8:9] \n"
" s_mov_b64 exec, s[38:39] \n"
" s_add_u32 s60, 0x00000100, s80 \n"
" s_cmp_lt_u32 s60, %[s_loop_cnt] \n"
" s_cselect_b32 s56, %[s_tile_os_b], 0 \n"
" s_add_u32 s12, s56, s12 \n"
" s_addc_u32 s13, 0, s13 \n"
" s_cmp_ge_u32 s80, 0x00000100 \n"
" s_cselect_b32 s59, %[s_tile_os_o], s59 \n"
" s_add_u32 s8, s59, s8 \n"
" s_addc_u32 s9, 0, s9 \n"
" v_mul_f32 %[c0], %[scale_0], %[c0] \n"
" v_mul_f32 %[c1], %[scale_0], %[c1] \n"
" v_mul_f32 %[c2], %[scale_0], %[c2] \n"
" v_mul_f32 %[c3], %[scale_0], %[c3] \n"
" v_mul_f32 %[c4], %[scale_1], %[c4] \n"
" v_mul_f32 %[c5], %[scale_1], %[c5] \n"
" v_mul_f32 %[c6], %[scale_1], %[c6] \n"
" v_mul_f32 %[c7], %[scale_1], %[c7] \n"
" v_mul_f32 %[c8], %[scale_0], %[c8] \n"
" v_mul_f32 %[c9], %[scale_0], %[c9] \n"
" v_mul_f32 %[c10], %[scale_0], %[c10] \n"
" v_mul_f32 %[c11], %[scale_0], %[c11] \n"
" v_mul_f32 %[c12], %[scale_1], %[c12] \n"
" v_mul_f32 %[c13], %[scale_1], %[c13] \n"
" v_mul_f32 %[c14], %[scale_1], %[c14] \n"
" v_mul_f32 %[c15], %[scale_1], %[c15] \n" _UK_PK_CVT_("%[c0]", "%[c1]", "%[c0]") _UK_PK_CVT_(
"%[c2]",
"%[c3]",
"%[c1]") _UK_PK_CVT_("%[c4]",
"%[c5]",
"%[c2]") _UK_PK_CVT_("%[c6]",
"%[c7]",
"%[c3]") _UK_PK_CVT_("%[c8]",
"%[c9]",
"%[c4]") _UK_PK_CVT_("%["
"c10]",
"%["
"c11]",
"%[c5]")
_UK_PK_CVT_("%[c12]", "%[c13]", "%[c6]") _UK_PK_CVT_(
"%[c14]",
"%[c15]",
"%[c7]") " s_addk_i32 s80, 0x0080 \n"
" s_cmp_lt_i32 s80, %[s_loop_cnt] \n"
" s_cbranch_scc0 label_0EC1 \n"
" s_waitcnt vmcnt(30) & lgkmcnt(0) \n"
" s_barrier \n" _UK_MFMA_
" [%[c16], %[c17], %[c18], %[c19]], acc[128:129], v[128:129], 0 \n"
" ds_read_b32 v10, %[v_sfl_sld] offset:25344 \n"
" ds_read_b32 v11, %[v_sfl_sld] offset:25376 \n"
" ds_write_b64 v3, v[64:65] offset:16640 \n" _UK_MFMA_
" [%[c16], %[c17], %[c18], %[c19]], acc[130:131], v[130:131], v[80:83] \n"
" buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[12:15], 0 offen \n"
" ds_write_b64 v3, v[66:67] offset:20992 \n" _UK_MFMA_
" [%[c16], %[c17], %[c18], %[c19]], acc[132:133], v[132:133], v[80:83] \n"
" ds_read_b32 v12, %[v_sfl_sld] offset:25408 \n"
" ds_read_b32 v13, %[v_sfl_sld] offset:25440 \n"
" ds_write_b64 v3, v[68:69] offset:18816 \n" _UK_MFMA_
" [%[c16], %[c17], %[c18], %[c19]], acc[134:135], v[134:135], v[80:83] \n"
" ds_write_b64 v3, v[70:71] offset:23168 \n" _UK_MFMA_
" [%[c16], %[c17], %[c18], %[c19]], acc[136:137], v[136:137], v[80:83] \n"
" ds_read_b32 v14, %[v_sfl_sld] offset:29696 \n"
" ds_read_b32 v15, %[v_sfl_sld] offset:29728 \n" _UK_MFMA_
" [%[c16], %[c17], %[c18], %[c19]], acc[138:139], v[138:139], v[80:83] \n"
" buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[12:15], 0 offen offset:1024 "
"\n" _UK_MFMA_
" [%[c16], %[c17], %[c18], %[c19]], acc[140:141], v[140:141], v[80:83] \n"
" ds_read_b32 v16, %[v_sfl_sld] offset:29760 \n"
" ds_read_b32 v17, %[v_sfl_sld] offset:29792 \n" _UK_MFMA_
" [%[c16], %[c17], %[c18], %[c19]], acc[142:143], v[142:143], v[80:83] "
"\n" _UK_MFMA_
" [%[c20], %[c21], %[c22], %[c23]], acc[128:129], v[192:193], 0 \n" _UK_MFMA_
" [%[c20], %[c21], %[c22], %[c23]], acc[130:131], v[194:195], v[84:87] \n"
" buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[12:15], 0 offen offset:2048 "
"\n" _UK_MFMA_ " [%[c20], %[c21], %[c22], %[c23]], acc[132:133], v[196:197], "
"v[84:87] \n" _UK_MFMA_
" [%[c20], %[c21], %[c22], %[c23]], acc[134:135], v[198:199], v[84:87] "
"\n" _UK_MFMA_ " [%[c20], %[c21], %[c22], %[c23]], acc[136:137], v[200:201], "
"v[84:87] \n" _UK_MFMA_
" [%[c20], %[c21], %[c22], %[c23]], acc[138:139], v[202:203], v[84:87] \n"
" buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[12:15], 0 offen offset:3072 "
"\n" _UK_MFMA_ " [%[c20], %[c21], %[c22], %[c23]], acc[140:141], v[204:205], "
"v[84:87] \n" _UK_MFMA_
" [%[c20], %[c21], %[c22], %[c23]], acc[142:143], v[206:207], v[84:87] \n"
" s_waitcnt lgkmcnt(0) \n"
" s_mov_b64 exec, %[s_execflag_0] \n" _UK_ATOMIC_ADD_
" %[v_os_o0], v10, s[8:9] \n"
" s_mov_b64 exec, s[38:39] \n" _UK_MFMA_
" [%[c24], %[c25], %[c26], %[c27]], acc[144:145], v[128:129], 0 \n" _UK_MFMA_
" [%[c24], %[c25], %[c26], %[c27]], acc[146:147], v[130:131], v[88:91] \n"
" buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[12:15], 0 offen \n" _UK_MFMA_
" [%[c24], %[c25], %[c26], %[c27]], acc[148:149], v[132:133], v[88:91] "
"\n" _UK_MFMA_ " [%[c24], %[c25], %[c26], %[c27]], acc[150:151], v[134:135], "
"v[88:91] \n" _UK_MFMA_
" [%[c24], %[c25], %[c26], %[c27]], acc[152:153], v[136:137], v[88:91] "
"\n" _UK_MFMA_
" [%[c24], %[c25], %[c26], %[c27]], acc[154:155], v[138:139], v[88:91] \n"
" buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[12:15], 0 offen offset:1024 "
"\n" _UK_MFMA_ " [%[c24], %[c25], %[c26], %[c27]], acc[156:157], v[140:141], "
"v[88:91] \n" _UK_MFMA_
" [%[c24], %[c25], %[c26], %[c27]], acc[158:159], v[142:143], v[88:91] "
"\n" _UK_MFMA_
" [%[c28], %[c29], %[c30], %[c31]], acc[144:145], v[192:193], 0 \n" _UK_MFMA_
" [%[c28], %[c29], %[c30], %[c31]], acc[146:147], v[194:195], v[92:95] \n"
" buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[12:15], 0 offen offset:2048 "
"\n" _UK_MFMA_ " [%[c28], %[c29], %[c30], %[c31]], acc[148:149], v[196:197], "
"v[92:95] \n" _UK_MFMA_
" [%[c28], %[c29], %[c30], %[c31]], acc[150:151], v[198:199], v[92:95] "
"\n" _UK_MFMA_ " [%[c28], %[c29], %[c30], %[c31]], acc[152:153], v[200:201], "
"v[92:95] \n" _UK_MFMA_
" [%[c28], %[c29], %[c30], %[c31]], acc[154:155], v[202:203], v[92:95] \n"
" buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[12:15], 0 offen offset:3072 "
"\n" _UK_MFMA_ " [%[c28], %[c29], %[c30], %[c31]], acc[156:157], v[204:205], "
"v[92:95] \n" _UK_MFMA_
" [%[c28], %[c29], %[c30], %[c31]], acc[158:159], v[206:207], v[92:95] \n"
" s_mov_b64 exec, %[s_execflag_1] \n" _UK_ATOMIC_ADD_
" %[v_os_o1], v11, s[8:9] \n"
" s_mov_b64 exec, s[38:39] \n"
" s_waitcnt vmcnt(30) \n" _UK_MFMA_
" [%[c16], %[c17], %[c18], %[c19]], acc[160:161], v[144:145], v[80:83] "
"\n" _UK_MFMA_
" [%[c16], %[c17], %[c18], %[c19]], acc[162:163], v[146:147], v[80:83] \n"
" buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[12:15], 0 offen \n" _UK_MFMA_
" [%[c16], %[c17], %[c18], %[c19]], acc[164:165], v[148:149], v[80:83] "
"\n" _UK_MFMA_ " [%[c16], %[c17], %[c18], %[c19]], acc[166:167], v[150:151], "
"v[80:83] \n" _UK_MFMA_
" [%[c16], %[c17], %[c18], %[c19]], acc[168:169], v[152:153], v[80:83] "
"\n" _UK_MFMA_
" [%[c16], %[c17], %[c18], %[c19]], acc[170:171], v[154:155], v[80:83] \n"
" buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[12:15], 0 offen offset:1024 "
"\n" _UK_MFMA_ " [%[c16], %[c17], %[c18], %[c19]], acc[172:173], v[156:157], "
"v[80:83] \n" _UK_MFMA_
" [%[c16], %[c17], %[c18], %[c19]], acc[174:175], v[158:159], v[80:83] "
"\n" _UK_MFMA_ " [%[c20], %[c21], %[c22], %[c23]], acc[160:161], v[208:209], "
"v[84:87] \n" _UK_MFMA_
" [%[c20], %[c21], %[c22], %[c23]], acc[162:163], v[210:211], v[84:87] \n"
" buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[12:15], 0 offen offset:2048 "
"\n" _UK_MFMA_ " [%[c20], %[c21], %[c22], %[c23]], acc[164:165], v[212:213], "
"v[84:87] \n" _UK_MFMA_
" [%[c20], %[c21], %[c22], %[c23]], acc[166:167], v[214:215], v[84:87] "
"\n" _UK_MFMA_ " [%[c20], %[c21], %[c22], %[c23]], acc[168:169], v[216:217], "
"v[84:87] \n" _UK_MFMA_
" [%[c20], %[c21], %[c22], %[c23]], acc[170:171], v[218:219], v[84:87] \n"
" buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[12:15], 0 offen offset:3072 "
"\n" _UK_MFMA_ " [%[c20], %[c21], %[c22], %[c23]], acc[172:173], v[220:221], "
"v[84:87] \n" _UK_MFMA_
" [%[c20], %[c21], %[c22], %[c23]], acc[174:175], v[222:223], v[84:87] \n"
" s_mov_b64 exec, %[s_execflag_2] \n" _UK_ATOMIC_ADD_
" %[v_os_o2], v12, s[8:9] \n"
" s_mov_b64 exec, s[38:39] \n" _UK_MFMA_
" [%[c24], %[c25], %[c26], %[c27]], acc[176:177], v[144:145], v[88:91] "
"\n" _UK_MFMA_
" [%[c24], %[c25], %[c26], %[c27]], acc[178:179], v[146:147], v[88:91] \n"
" buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[12:15], 0 offen \n" _UK_MFMA_
" [%[c24], %[c25], %[c26], %[c27]], acc[180:181], v[148:149], v[88:91] "
"\n" _UK_MFMA_ " [%[c24], %[c25], %[c26], %[c27]], acc[182:183], v[150:151], "
"v[88:91] \n" _UK_MFMA_
" [%[c24], %[c25], %[c26], %[c27]], acc[184:185], v[152:153], v[88:91] "
"\n" _UK_MFMA_
" [%[c24], %[c25], %[c26], %[c27]], acc[186:187], v[154:155], v[88:91] \n"
" buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[12:15], 0 offen offset:1024 "
"\n" _UK_MFMA_ " [%[c24], %[c25], %[c26], %[c27]], acc[188:189], v[156:157], "
"v[88:91] \n" _UK_MFMA_
" [%[c24], %[c25], %[c26], %[c27]], acc[190:191], v[158:159], v[88:91] "
"\n" _UK_MFMA_ " [%[c28], %[c29], %[c30], %[c31]], acc[176:177], v[208:209], "
"v[92:95] \n" _UK_MFMA_
" [%[c28], %[c29], %[c30], %[c31]], acc[178:179], v[210:211], v[92:95] \n"
" buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[12:15], 0 offen offset:2048 "
"\n" _UK_MFMA_ " [%[c28], %[c29], %[c30], %[c31]], acc[180:181], v[212:213], "
"v[92:95] \n" _UK_MFMA_
" [%[c28], %[c29], %[c30], %[c31]], acc[182:183], v[214:215], v[92:95] "
"\n" _UK_MFMA_ " [%[c28], %[c29], %[c30], %[c31]], acc[184:185], v[216:217], "
"v[92:95] \n" _UK_MFMA_
" [%[c28], %[c29], %[c30], %[c31]], acc[186:187], v[218:219], v[92:95] \n"
" buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[12:15], 0 offen offset:3072 "
"\n" _UK_MFMA_ " [%[c28], %[c29], %[c30], %[c31]], acc[188:189], v[220:221], "
"v[92:95] \n" _UK_MFMA_
" [%[c28], %[c29], %[c30], %[c31]], acc[190:191], v[222:223], v[92:95] \n"
" s_mov_b64 exec, %[s_execflag_3] \n" _UK_ATOMIC_ADD_
" %[v_os_o3], v13, s[8:9] \n"
" s_mov_b64 exec, s[38:39] \n"
" s_waitcnt vmcnt(30) \n" _UK_MFMA_
" [%[c16], %[c17], %[c18], %[c19]], acc[192:193], v[160:161], v[80:83] "
"\n" _UK_MFMA_
" [%[c16], %[c17], %[c18], %[c19]], acc[194:195], v[162:163], v[80:83] \n"
" buffer_load_dwordx4 acc[64:67], %[v_os_b4], s[12:15], 0 offen \n" _UK_MFMA_
" [%[c16], %[c17], %[c18], %[c19]], acc[196:197], v[164:165], v[80:83] "
"\n" _UK_MFMA_ " [%[c16], %[c17], %[c18], %[c19]], acc[198:199], v[166:167], "
"v[80:83] \n" _UK_MFMA_
" [%[c16], %[c17], %[c18], %[c19]], acc[200:201], v[168:169], v[80:83] "
"\n" _UK_MFMA_
" [%[c16], %[c17], %[c18], %[c19]], acc[202:203], v[170:171], v[80:83] \n"
" buffer_load_dwordx4 acc[68:71], %[v_os_b4], s[12:15], 0 offen offset:1024 "
"\n" _UK_MFMA_ " [%[c16], %[c17], %[c18], %[c19]], acc[204:205], v[172:173], "
"v[80:83] \n" _UK_MFMA_
" [%[c16], %[c17], %[c18], %[c19]], acc[206:207], v[174:175], v[80:83] "
"\n" _UK_MFMA_ " [%[c20], %[c21], %[c22], %[c23]], acc[192:193], v[224:225], "
"v[84:87] \n" _UK_MFMA_
" [%[c20], %[c21], %[c22], %[c23]], acc[194:195], v[226:227], v[84:87] \n"
" buffer_load_dwordx4 acc[72:75], %[v_os_b4], s[12:15], 0 offen offset:2048 "
"\n" _UK_MFMA_ " [%[c20], %[c21], %[c22], %[c23]], acc[196:197], v[228:229], "
"v[84:87] \n" _UK_MFMA_
" [%[c20], %[c21], %[c22], %[c23]], acc[198:199], v[230:231], v[84:87] "
"\n" _UK_MFMA_ " [%[c20], %[c21], %[c22], %[c23]], acc[200:201], v[232:233], "
"v[84:87] \n" _UK_MFMA_
" [%[c20], %[c21], %[c22], %[c23]], acc[202:203], v[234:235], v[84:87] \n"
" buffer_load_dwordx4 acc[76:79], %[v_os_b4], s[12:15], 0 offen offset:3072 "
"\n" _UK_MFMA_ " [%[c20], %[c21], %[c22], %[c23]], acc[204:205], v[236:237], "
"v[84:87] \n" _UK_MFMA_
" [%[c20], %[c21], %[c22], %[c23]], acc[206:207], v[238:239], v[84:87] \n"
" s_mov_b64 exec, %[s_execflag_4] \n" _UK_ATOMIC_ADD_
" %[v_os_o4], v14, s[8:9] \n"
" s_mov_b64 exec, s[38:39] \n" _UK_MFMA_
" [%[c24], %[c25], %[c26], %[c27]], acc[208:209], v[160:161], v[88:91] "
"\n" _UK_MFMA_
" [%[c24], %[c25], %[c26], %[c27]], acc[210:211], v[162:163], v[88:91] \n"
" buffer_load_dwordx4 acc[80:83], %[v_os_b5], s[12:15], 0 offen \n" _UK_MFMA_
" [%[c24], %[c25], %[c26], %[c27]], acc[212:213], v[164:165], v[88:91] "
"\n" _UK_MFMA_ " [%[c24], %[c25], %[c26], %[c27]], acc[214:215], v[166:167], "
"v[88:91] \n" _UK_MFMA_
" [%[c24], %[c25], %[c26], %[c27]], acc[216:217], v[168:169], v[88:91] "
"\n" _UK_MFMA_
" [%[c24], %[c25], %[c26], %[c27]], acc[218:219], v[170:171], v[88:91] \n"
" buffer_load_dwordx4 acc[84:87], %[v_os_b5], s[12:15], 0 offen offset:1024 "
"\n" _UK_MFMA_ " [%[c24], %[c25], %[c26], %[c27]], acc[220:221], v[172:173], "
"v[88:91] \n" _UK_MFMA_
" [%[c24], %[c25], %[c26], %[c27]], acc[222:223], v[174:175], v[88:91] "
"\n" _UK_MFMA_ " [%[c28], %[c29], %[c30], %[c31]], acc[208:209], v[224:225], "
"v[92:95] \n" _UK_MFMA_
" [%[c28], %[c29], %[c30], %[c31]], acc[210:211], v[226:227], v[92:95] \n"
" buffer_load_dwordx4 acc[88:91], %[v_os_b5], s[12:15], 0 offen offset:2048 "
"\n" _UK_MFMA_ " [%[c28], %[c29], %[c30], %[c31]], acc[212:213], v[228:229], "
"v[92:95] \n" _UK_MFMA_
" [%[c28], %[c29], %[c30], %[c31]], acc[214:215], v[230:231], v[92:95] "
"\n" _UK_MFMA_ " [%[c28], %[c29], %[c30], %[c31]], acc[216:217], v[232:233], "
"v[92:95] \n" _UK_MFMA_
" [%[c28], %[c29], %[c30], %[c31]], acc[218:219], v[234:235], v[92:95] \n"
" buffer_load_dwordx4 acc[92:95], %[v_os_b5], s[12:15], 0 offen offset:3072 "
"\n" _UK_MFMA_ " [%[c28], %[c29], %[c30], %[c31]], acc[220:221], v[236:237], "
"v[92:95] \n" _UK_MFMA_
" [%[c28], %[c29], %[c30], %[c31]], acc[222:223], v[238:239], v[92:95] \n"
" s_mov_b64 exec, %[s_execflag_5] \n" _UK_ATOMIC_ADD_
" %[v_os_o5], v15, s[8:9] \n"
" s_mov_b64 exec, s[38:39] \n"
" s_waitcnt vmcnt(30) \n" _UK_MFMA_
" [%[c16], %[c17], %[c18], %[c19]], acc[224:225], v[176:177], v[80:83] "
"\n" _UK_MFMA_
" [%[c16], %[c17], %[c18], %[c19]], acc[226:227], v[178:179], v[80:83] \n"
" buffer_load_dwordx4 acc[96:99], %[v_os_b6], s[12:15], 0 offen \n" _UK_MFMA_
" [%[c16], %[c17], %[c18], %[c19]], acc[228:229], v[180:181], v[80:83] "
"\n" _UK_MFMA_ " [%[c16], %[c17], %[c18], %[c19]], acc[230:231], v[182:183], "
"v[80:83] \n" _UK_MFMA_
" [%[c16], %[c17], %[c18], %[c19]], acc[232:233], v[184:185], v[80:83] "
"\n" _UK_MFMA_
" [%[c16], %[c17], %[c18], %[c19]], acc[234:235], v[186:187], v[80:83] \n"
" buffer_load_dwordx4 acc[100:103], %[v_os_b6], s[12:15], 0 offen "
"offset:1024 \n" _UK_MFMA_ " [%[c16], %[c17], %[c18], %[c19]], acc[236:237], "
"v[188:189], v[80:83] \n" _UK_MFMA_
" [%[c16], %[c17], %[c18], %[c19]], acc[238:239], v[190:191], v[80:83] "
"\n" _UK_MFMA_ " [%[c20], %[c21], %[c22], %[c23]], acc[224:225], v[240:241], "
"v[84:87] \n" _UK_MFMA_
" [%[c20], %[c21], %[c22], %[c23]], acc[226:227], v[242:243], v[84:87] \n"
" buffer_load_dwordx4 acc[104:107], %[v_os_b6], s[12:15], 0 offen "
"offset:2048 \n" _UK_MFMA_ " [%[c20], %[c21], %[c22], %[c23]], acc[228:229], "
"v[244:245], v[84:87] \n" _UK_MFMA_
" [%[c20], %[c21], %[c22], %[c23]], acc[230:231], v[246:247], v[84:87] "
"\n" _UK_MFMA_ " [%[c20], %[c21], %[c22], %[c23]], acc[232:233], v[248:249], "
"v[84:87] \n" _UK_MFMA_
" [%[c20], %[c21], %[c22], %[c23]], acc[234:235], v[250:251], v[84:87] \n"
" buffer_load_dwordx4 acc[108:111], %[v_os_b6], s[12:15], 0 offen "
"offset:3072 \n" _UK_MFMA_ " [%[c20], %[c21], %[c22], %[c23]], acc[236:237], "
"v[252:253], v[84:87] \n" _UK_MFMA_
" [%[c20], %[c21], %[c22], %[c23]], acc[238:239], v[254:255], v[84:87] \n"
" s_mov_b64 exec, %[s_execflag_6] \n" _UK_ATOMIC_ADD_
" %[v_os_o6], v16, s[8:9] \n"
" s_mov_b64 exec, s[38:39] \n" _UK_MFMA_
" [%[c24], %[c25], %[c26], %[c27]], acc[240:241], v[176:177], v[88:91] "
"\n" _UK_MFMA_
" [%[c24], %[c25], %[c26], %[c27]], acc[242:243], v[178:179], v[88:91] \n"
" buffer_load_dwordx4 acc[112:115], %[v_os_b7], s[12:15], 0 offen "
"\n" _UK_MFMA_ " [%[c24], %[c25], %[c26], %[c27]], acc[244:245], v[180:181], "
"v[88:91] \n" _UK_MFMA_
" [%[c24], %[c25], %[c26], %[c27]], acc[246:247], v[182:183], v[88:91] "
"\n" _UK_MFMA_ " [%[c24], %[c25], %[c26], %[c27]], acc[248:249], v[184:185], "
"v[88:91] \n" _UK_MFMA_
" [%[c24], %[c25], %[c26], %[c27]], acc[250:251], v[186:187], v[88:91] \n"
" buffer_load_dwordx4 acc[116:119], %[v_os_b7], s[12:15], 0 offen "
"offset:1024 \n" _UK_MFMA_ " [%[c24], %[c25], %[c26], %[c27]], acc[252:253], "
"v[188:189], v[88:91] \n" _UK_MFMA_
" [%[c24], %[c25], %[c26], %[c27]], acc[254:255], v[190:191], v[88:91] "
"\n" _UK_MFMA_ " [%[c28], %[c29], %[c30], %[c31]], acc[240:241], v[240:241], "
"v[92:95] \n" _UK_MFMA_
" [%[c28], %[c29], %[c30], %[c31]], acc[242:243], v[242:243], v[92:95] \n"
" buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[12:15], 0 offen "
"offset:2048 \n" _UK_MFMA_ " [%[c28], %[c29], %[c30], %[c31]], acc[244:245], "
"v[244:245], v[92:95] \n" _UK_MFMA_
" [%[c28], %[c29], %[c30], %[c31]], acc[246:247], v[246:247], v[92:95] "
"\n" _UK_MFMA_ " [%[c28], %[c29], %[c30], %[c31]], acc[248:249], v[248:249], "
"v[92:95] \n" _UK_MFMA_
" [%[c28], %[c29], %[c30], %[c31]], acc[250:251], v[250:251], v[92:95] \n"
" buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[12:15], 0 offen "
"offset:3072 \n" _UK_MFMA_ " [%[c28], %[c29], %[c30], %[c31]], acc[252:253], "
"v[252:253], v[92:95] \n" _UK_MFMA_
" [%[c28], %[c29], %[c30], %[c31]], acc[254:255], v[254:255], v[92:95] \n"
" s_mov_b64 exec, %[s_execflag_7] \n" _UK_ATOMIC_ADD_
" %[v_os_o7], v17, s[8:9] \n"
" s_mov_b64 exec, s[38:39] \n"
" s_add_u32 s60, 0x00000100, s80 \n"
" s_cmp_lt_u32 s60, %[s_loop_cnt] \n"
" s_cselect_b32 s56, s56, 0 \n"
" s_add_u32 s12, s56, s12 \n"
" s_addc_u32 s13, 0, s13 \n"
" s_cmp_ge_u32 s80, 0x00000100 \n"
" s_cselect_b32 s59, 0x00000100, s59 \n"
" s_add_u32 s8, s59, s8 \n"
" s_addc_u32 s9, 0, s9 \n"
" v_mul_f32 %[c16], %[scale_0], %[c16] \n"
" v_mul_f32 %[c17], %[scale_0], %[c17] \n"
" v_mul_f32 %[c18], %[scale_0], %[c18] \n"
" v_mul_f32 %[c19], %[scale_0], %[c19] \n"
" v_mul_f32 %[c20], %[scale_1], %[c20] \n"
" v_mul_f32 %[c21], %[scale_1], %[c21] \n"
" v_mul_f32 %[c22], %[scale_1], %[c22] \n"
" v_mul_f32 %[c23], %[scale_1], %[c23] \n"
" v_mul_f32 %[c24], %[scale_0], %[c24] \n"
" v_mul_f32 %[c25], %[scale_0], %[c25] \n"
" v_mul_f32 %[c26], %[scale_0], %[c26] \n"
" v_mul_f32 %[c27], %[scale_0], %[c27] \n"
" v_mul_f32 %[c28], %[scale_1], %[c28] \n"
" v_mul_f32 %[c29], %[scale_1], %[c29] \n"
" v_mul_f32 %[c30], %[scale_1], %[c30] \n"
" v_mul_f32 %[c31], %[scale_1], %[c31] \n" _UK_PK_CVT_(
"%[c16]", "%[c17]", "%[c16]") _UK_PK_CVT_("%[c18]", "%[c19]", "%[c17]")
_UK_PK_CVT_("%[c20]", "%[c21]", "%[c18]") _UK_PK_CVT_(
"%[c22]", "%[c23]", "%[c19]") _UK_PK_CVT_("%[c24]", "%[c25]", "%[c20]")
_UK_PK_CVT_("%[c26]", "%[c27]", "%[c21]")
_UK_PK_CVT_("%[c28]", "%[c29]", "%[c22]") _UK_PK_CVT_(
"%[c30]",
"%[c31]",
"%[c23]") " s_addk_i32 s80, 0x0080 \n"
" s_cmp_lt_i32 s80, %[s_loop_cnt] \n"
" s_cbranch_scc0 label_0EC1 \n"
" s_branch label_0AA6 \n"
" label_0EC1: \n"
" s_waitcnt lgkmcnt(0) \n"
" s_barrier \n"
" ds_read_b32 v10, %[v_sfl_sld] offset:16640 \n"
" ds_read_b32 v11, %[v_sfl_sld] offset:16672 \n"
" ds_read_b32 v12, %[v_sfl_sld] offset:16704 \n"
" ds_read_b32 v13, %[v_sfl_sld] offset:16736 \n"
" ds_read_b32 v14, %[v_sfl_sld] offset:20992 \n"
" ds_read_b32 v15, %[v_sfl_sld] offset:21024 \n"
" ds_read_b32 v16, %[v_sfl_sld] offset:21056 \n"
" ds_read_b32 v17, %[v_sfl_sld] offset:21088 \n"
" s_waitcnt lgkmcnt(0) \n"
" s_mov_b64 exec, %[s_execflag_0] \n" _UK_ATOMIC_ADD_
" %[v_os_o0], v10, s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_1] \n" _UK_ATOMIC_ADD_
" %[v_os_o1], v11, s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_2] \n" _UK_ATOMIC_ADD_
" %[v_os_o2], v12, s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_3] \n" _UK_ATOMIC_ADD_
" %[v_os_o3], v13, s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_4] \n" _UK_ATOMIC_ADD_
" %[v_os_o4], v14, s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_5] \n" _UK_ATOMIC_ADD_
" %[v_os_o5], v15, s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_6] \n" _UK_ATOMIC_ADD_
" %[v_os_o6], v16, s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_7] \n" _UK_ATOMIC_ADD_
" %[v_os_o7], v17, s[8:9] \n"
" s_mov_b64 exec, s[38:39] "
" \n"
" s_add_u32 s8, s59, s8 \n"
" s_addc_u32 s9, 0, s9 \n"
" ds_write_b64 %[v_sfl_sst], [%[c16],%[c17]] "
"offset:25344 \n"
" ds_write_b64 %[v_sfl_sst], [%[c18],%[c19]] "
"offset:29696 \n"
" ds_write_b64 %[v_sfl_sst], [%[c20],%[c21]] "
"offset:27520 \n"
" ds_write_b64 %[v_sfl_sst], [%[c22],%[c23]] "
"offset:31872 \n"
" s_waitcnt lgkmcnt(0) \n"
" s_barrier \n"
" ds_read_b32 v10, %[v_sfl_sld] offset:25344 \n"
" ds_read_b32 v11, %[v_sfl_sld] offset:25376 \n"
" ds_read_b32 v12, %[v_sfl_sld] offset:25408 \n"
" ds_read_b32 v13, %[v_sfl_sld] offset:25440 \n"
" ds_read_b32 v14, %[v_sfl_sld] offset:29696 \n"
" ds_read_b32 v15, %[v_sfl_sld] offset:29728 \n"
" ds_read_b32 v16, %[v_sfl_sld] offset:29760 \n"
" ds_read_b32 v17, %[v_sfl_sld] offset:29792 \n"
" s_waitcnt lgkmcnt(0) \n"
" s_mov_b64 exec, %[s_execflag_0] \n" _UK_ATOMIC_ADD_
" %[v_os_o0], v10, s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_1] \n" _UK_ATOMIC_ADD_
" %[v_os_o1], v11, s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_2] \n" _UK_ATOMIC_ADD_
" %[v_os_o2], v12, s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_3] \n" _UK_ATOMIC_ADD_
" %[v_os_o3], v13, s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_4] \n" _UK_ATOMIC_ADD_
" %[v_os_o4], v14, s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_5] \n" _UK_ATOMIC_ADD_
" %[v_os_o5], v15, s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_6] \n" _UK_ATOMIC_ADD_
" %[v_os_o6], v16, s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_7] \n" _UK_ATOMIC_ADD_
" %[v_os_o7], v17, s[8:9] \n"
" s_mov_b64 exec, s[38:39] \n"
#undef _UK_MFMA_
#undef _UK_PK_CVT_
#undef _UK_ATOMIC_ADD_
#ifndef CK_TILE_FLATMM_UK_MFMA
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
#endif
#if CK_TILE_FLATMM_UK_MFMA == CK_TILE_FLATMM_UK_MFMA_BF16
#define _UK_MFMA_ "v_mfma_f32_16x16x16_bf16"
#elif CK_TILE_FLATMM_UK_MFMA == CK_TILE_FLATMM_UK_MFMA_FP16
#define _UK_MFMA_ "v_mfma_f32_16x16x16_f16"
#endif
"s_mov_b32 s16, %[s_res_a0] \n"
"s_mov_b32 s17, %[s_res_a1] \n"
"s_mov_b32 s18, %[s_res_a2] \n"
"s_mov_b32 s19, %[s_res_a3] \n"
"s_mov_b32 s20, %[s_res_b0] \n"
"s_mov_b32 s21, %[s_res_b1] \n"
"s_mov_b32 s22, %[s_res_b2] \n"
"s_mov_b32 s23, %[s_res_b3] \n"
// "s_nop 4\n"
"; -- prefetch A0\n"
"s_add_u32 m0, 0, %[s_m0_init] \n"
"buffer_load_dword %[v_os_a0], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a1], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a2], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a3], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a4], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a5], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a6], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a7], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[smem_sz], %[s_m0_init] \n"
"s_cmp_gt_i32 %[s_loop_cnt] 1 ; move a with cond \n"
"s_cselect_b32 s86, %[s_tile_os_a], 0 ; move a with cond \n"
"s_add_u32 s16, s86, s16 ; move a with cond \n"
"s_addc_u32 s17, 0, s17 ; move a with cond \n"
"; -- prefetch A1\n"
"buffer_load_dword %[v_os_a0], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a1], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a2], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a3], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a4], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a5], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a6], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a7], s[16:19], 0 offen lds \n"
"s_add_u32 m0, 0, %[s_m0_init] \n"
"s_cmp_gt_i32 %[s_loop_cnt] 2 ; move a with cond \n"
"s_cselect_b32 s86, %[s_tile_os_a], 0 ; move a with cond \n"
"s_add_u32 s16, s86, s16 ; move a with cond \n"
"s_addc_u32 s17, 0, s17 ; move a with cond \n"
"; -- prefetch B0\n"
"buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[20:23], 0 offen \n"
"buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[20:23], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[20:23], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[20:23], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[20:23], 0 offen \n"
"buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[20:23], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[20:23], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[20:23], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[20:23], 0 offen \n"
"buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[20:23], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[20:23], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[20:23], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[20:23], 0 offen \n"
"buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[20:23], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[20:23], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[20:23], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[64:67], %[v_os_b4], s[20:23], 0 offen \n"
"buffer_load_dwordx4 acc[68:71], %[v_os_b4], s[20:23], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[72:75], %[v_os_b4], s[20:23], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[76:79], %[v_os_b4], s[20:23], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[80:83], %[v_os_b5], s[20:23], 0 offen \n"
"buffer_load_dwordx4 acc[84:87], %[v_os_b5], s[20:23], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[88:91], %[v_os_b5], s[20:23], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[92:95], %[v_os_b5], s[20:23], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[96:99], %[v_os_b6], s[20:23], 0 offen \n"
"buffer_load_dwordx4 acc[100:103], %[v_os_b6], s[20:23], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[104:107], %[v_os_b6], s[20:23], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[108:111], %[v_os_b6], s[20:23], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[112:115], %[v_os_b7], s[20:23], 0 offen \n"
"buffer_load_dwordx4 acc[116:119], %[v_os_b7], s[20:23], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[20:23], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[20:23], 0 offen offset:3072 \n"
"s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n"
"s_cselect_b32 s86, %[s_tile_os_b], 0 ; move b with cond \n"
"s_add_u32 s20, s86, s20 ; move b with cond \n"
"s_addc_u32 s21, 0, s21 ; move b with cond \n"
"s_waitcnt vmcnt(40) \n"
"s_barrier \n"
"ds_read_b128 v[64:67], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_0]\n" // 1024: N stride, 64
// K stride
"ds_read_b128 v[68:71], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_1]\n"
"ds_read_b128 v[72:75], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_2]\n"
"ds_read_b128 v[76:79], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_3]\n"
"ds_read_b128 v[80:83], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_4]\n"
"ds_read_b128 v[84:87], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_5]\n"
"ds_read_b128 v[88:91], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_6]\n"
"ds_read_b128 v[92:95], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_7]\n"
"L_start%=: \n"
" s_waitcnt vmcnt(24) & lgkmcnt(0) \n"
" s_barrier \n" _UK_MFMA_
" %[v_acc_0], acc[0:1], v[64:65], %[v_acc_0] \n" _UK_MFMA_
" %[v_acc_0], acc[2:3], v[66:67], %[v_acc_0] \n"
" buffer_load_dwordx4 acc[128:131], %[v_os_b0], s[20:23], 0 offen \n" _UK_MFMA_
" %[v_acc_0], acc[4:5], v[68:69], %[v_acc_0] \n" _UK_MFMA_
" %[v_acc_0], acc[6:7], v[70:71], %[v_acc_0] \n"
" buffer_load_dword %[v_os_a0], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n" _UK_MFMA_
" %[v_acc_0], acc[8:9], v[72:73], %[v_acc_0] \n" _UK_MFMA_
" %[v_acc_0], acc[10:11], v[74:75], %[v_acc_0] \n"
" buffer_load_dwordx4 acc[132:135], %[v_os_b0], s[20:23], 0 offen offset:1024 \n" _UK_MFMA_
" %[v_acc_0], acc[12:13], v[76:77], %[v_acc_0] \n" _UK_MFMA_
" %[v_acc_0], acc[14:15], v[78:79], %[v_acc_0] \n"
" buffer_load_dword %[v_os_a1], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n" _UK_MFMA_
" %[v_acc_1], acc[0:1], v[80:81], %[v_acc_1] \n" _UK_MFMA_
" %[v_acc_1], acc[2:3], v[82:83], %[v_acc_1] \n"
" buffer_load_dwordx4 acc[136:139], %[v_os_b0], s[20:23], 0 offen offset:2048 \n" _UK_MFMA_
" %[v_acc_1], acc[4:5], v[84:85], %[v_acc_1] \n" _UK_MFMA_
" %[v_acc_1], acc[6:7], v[86:87], %[v_acc_1] \n"
" buffer_load_dword %[v_os_a2], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n" _UK_MFMA_
" %[v_acc_1], acc[8:9], v[88:89], %[v_acc_1] \n" _UK_MFMA_
" %[v_acc_1], acc[10:11], v[90:91], %[v_acc_1] \n"
" buffer_load_dwordx4 acc[140:143], %[v_os_b0], s[20:23], 0 offen offset:3072 \n" _UK_MFMA_
" %[v_acc_1], acc[12:13], v[92:93], %[v_acc_1] \n" _UK_MFMA_
" %[v_acc_1], acc[14:15], v[94:95], %[v_acc_1] \n"
" buffer_load_dword %[v_os_a3], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n" _UK_MFMA_
" %[v_acc_2], acc[16:17], v[64:65], %[v_acc_2] \n" _UK_MFMA_
" %[v_acc_2], acc[18:19], v[66:67], %[v_acc_2] \n"
" buffer_load_dwordx4 acc[144:147], %[v_os_b1], s[20:23], 0 offen \n" _UK_MFMA_
" %[v_acc_2], acc[20:21], v[68:69], %[v_acc_2] \n" _UK_MFMA_
" %[v_acc_2], acc[22:23], v[70:71], %[v_acc_2] \n"
" buffer_load_dword %[v_os_a4], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n" _UK_MFMA_
" %[v_acc_2], acc[24:25], v[72:73], %[v_acc_2] \n" _UK_MFMA_
" %[v_acc_2], acc[26:27], v[74:75], %[v_acc_2] \n"
" buffer_load_dwordx4 acc[148:151], %[v_os_b1], s[20:23], 0 offen offset:1024 \n" _UK_MFMA_
" %[v_acc_2], acc[28:29], v[76:77], %[v_acc_2] \n" _UK_MFMA_
" %[v_acc_2], acc[30:31], v[78:79], %[v_acc_2] \n"
" buffer_load_dword %[v_os_a5], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n" _UK_MFMA_
" %[v_acc_3], acc[16:17], v[80:81], %[v_acc_3] \n" _UK_MFMA_
" %[v_acc_3], acc[18:19], v[82:83], %[v_acc_3] \n"
" buffer_load_dwordx4 acc[152:155], %[v_os_b1], s[20:23], 0 offen offset:2048 \n" _UK_MFMA_
" %[v_acc_3], acc[20:21], v[84:85], %[v_acc_3] \n" _UK_MFMA_
" %[v_acc_3], acc[22:23], v[86:87], %[v_acc_3] \n"
" buffer_load_dword %[v_os_a6], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n" _UK_MFMA_
" %[v_acc_3], acc[24:25], v[88:89], %[v_acc_3] \n" _UK_MFMA_
" %[v_acc_3], acc[26:27], v[90:91], %[v_acc_3] \n"
" buffer_load_dwordx4 acc[156:159], %[v_os_b1], s[20:23], 0 offen offset:3072 \n" _UK_MFMA_
" %[v_acc_3], acc[28:29], v[92:93], %[v_acc_3] \n" _UK_MFMA_
" %[v_acc_3], acc[30:31], v[94:95], %[v_acc_3] \n"
" buffer_load_dword %[v_os_a7], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[smem_sz], %[s_m0_init] \n"
" s_waitcnt vmcnt(32) \n" _UK_MFMA_
" %[v_acc_4], acc[32:33], v[64:65], %[v_acc_4] \n" _UK_MFMA_
" %[v_acc_4], acc[34:35], v[66:67], %[v_acc_4] \n"
" buffer_load_dwordx4 acc[160:163], %[v_os_b2], s[20:23], 0 offen \n" _UK_MFMA_
" %[v_acc_4], acc[36:37], v[68:69], %[v_acc_4] \n" _UK_MFMA_
" %[v_acc_4], acc[38:39], v[70:71], %[v_acc_4] \n"
" ds_read_b128 v[96:99], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_0] "
"\n" _UK_MFMA_ " %[v_acc_4], acc[40:41], v[72:73], %[v_acc_4] \n" _UK_MFMA_
" %[v_acc_4], acc[42:43], v[74:75], %[v_acc_4] \n"
" buffer_load_dwordx4 acc[164:167], %[v_os_b2], s[20:23], 0 offen offset:1024 \n" _UK_MFMA_
" %[v_acc_4], acc[44:45], v[76:77], %[v_acc_4] \n" _UK_MFMA_
" %[v_acc_4], acc[46:47], v[78:79], %[v_acc_4] \n"
" ds_read_b128 v[100:103], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_1] "
"\n" _UK_MFMA_ " %[v_acc_5], acc[32:33], v[80:81], %[v_acc_5] \n" _UK_MFMA_
" %[v_acc_5], acc[34:35], v[82:83], %[v_acc_5] \n"
" buffer_load_dwordx4 acc[168:171], %[v_os_b2], s[20:23], 0 offen offset:2048 \n" _UK_MFMA_
" %[v_acc_5], acc[36:37], v[84:85], %[v_acc_5] \n" _UK_MFMA_
" %[v_acc_5], acc[38:39], v[86:87], %[v_acc_5] \n"
" ds_read_b128 v[104:107], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_2] "
"\n" _UK_MFMA_ " %[v_acc_5], acc[40:41], v[88:89], %[v_acc_5] \n" _UK_MFMA_
" %[v_acc_5], acc[42:43], v[90:91], %[v_acc_5] \n"
" buffer_load_dwordx4 acc[172:175], %[v_os_b2], s[20:23], 0 offen offset:3072 \n" _UK_MFMA_
" %[v_acc_5], acc[44:45], v[92:93], %[v_acc_5] \n" _UK_MFMA_
" %[v_acc_5], acc[46:47], v[94:95], %[v_acc_5] \n"
" ds_read_b128 v[108:111], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_3] "
"\n" _UK_MFMA_ " %[v_acc_6], acc[48:49], v[64:65], %[v_acc_6] \n" _UK_MFMA_
" %[v_acc_6], acc[50:51], v[66:67], %[v_acc_6] \n"
" buffer_load_dwordx4 acc[176:179], %[v_os_b3], s[20:23], 0 offen \n" _UK_MFMA_
" %[v_acc_6], acc[52:53], v[68:69], %[v_acc_6] \n" _UK_MFMA_
" %[v_acc_6], acc[54:55], v[70:71], %[v_acc_6] \n"
" ds_read_b128 v[112:115], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_4] "
"\n" _UK_MFMA_ " %[v_acc_6], acc[56:57], v[72:73], %[v_acc_6] \n" _UK_MFMA_
" %[v_acc_6], acc[58:59], v[74:75], %[v_acc_6] \n"
" buffer_load_dwordx4 acc[180:183], %[v_os_b3], s[20:23], 0 offen offset:1024 \n" _UK_MFMA_
" %[v_acc_6], acc[60:61], v[76:77], %[v_acc_6] \n" _UK_MFMA_
" %[v_acc_6], acc[62:63], v[78:79], %[v_acc_6] \n"
" ds_read_b128 v[116:119], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_5] "
"\n" _UK_MFMA_ " %[v_acc_7], acc[48:49], v[80:81], %[v_acc_7] \n" _UK_MFMA_
" %[v_acc_7], acc[50:51], v[82:83], %[v_acc_7] \n"
" buffer_load_dwordx4 acc[184:187], %[v_os_b3], s[20:23], 0 offen offset:2048 \n" _UK_MFMA_
" %[v_acc_7], acc[52:53], v[84:85], %[v_acc_7] \n" _UK_MFMA_
" %[v_acc_7], acc[54:55], v[86:87], %[v_acc_7] \n"
" ds_read_b128 v[120:123], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_6] "
"\n" _UK_MFMA_ " %[v_acc_7], acc[56:57], v[88:89], %[v_acc_7] \n" _UK_MFMA_
" %[v_acc_7], acc[58:59], v[90:91], %[v_acc_7] \n"
" buffer_load_dwordx4 acc[188:191], %[v_os_b3], s[20:23], 0 offen offset:3072 \n" _UK_MFMA_
" %[v_acc_7], acc[60:61], v[92:93], %[v_acc_7] \n" _UK_MFMA_
" %[v_acc_7], acc[62:63], v[94:95], %[v_acc_7] \n"
" ds_read_b128 v[124:127], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_7] \n"
" s_waitcnt vmcnt(32) \n" _UK_MFMA_
" %[v_acc_8], acc[64:65], v[64:65], %[v_acc_8] \n" _UK_MFMA_
" %[v_acc_8], acc[66:67], v[66:67], %[v_acc_8] \n"
" buffer_load_dwordx4 acc[192:195], %[v_os_b4], s[20:23], 0 offen \n" _UK_MFMA_
" %[v_acc_8], acc[68:69], v[68:69], %[v_acc_8] \n" _UK_MFMA_
" %[v_acc_8], acc[70:71], v[70:71], %[v_acc_8] \n" _UK_MFMA_
" %[v_acc_8], acc[72:73], v[72:73], %[v_acc_8] \n" _UK_MFMA_
" %[v_acc_8], acc[74:75], v[74:75], %[v_acc_8] \n"
" buffer_load_dwordx4 acc[196:199], %[v_os_b4], s[20:23], 0 offen offset:1024 \n" _UK_MFMA_
" %[v_acc_8], acc[76:77], v[76:77], %[v_acc_8] \n" _UK_MFMA_
" %[v_acc_8], acc[78:79], v[78:79], %[v_acc_8] \n" _UK_MFMA_
" %[v_acc_9], acc[64:65], v[80:81], %[v_acc_9] \n" _UK_MFMA_
" %[v_acc_9], acc[66:67], v[82:83], %[v_acc_9] \n"
" buffer_load_dwordx4 acc[200:203], %[v_os_b4], s[20:23], 0 offen offset:2048 \n" _UK_MFMA_
" %[v_acc_9], acc[68:69], v[84:85], %[v_acc_9] \n" _UK_MFMA_
" %[v_acc_9], acc[70:71], v[86:87], %[v_acc_9] \n" _UK_MFMA_
" %[v_acc_9], acc[72:73], v[88:89], %[v_acc_9] \n" _UK_MFMA_
" %[v_acc_9], acc[74:75], v[90:91], %[v_acc_9] \n"
" buffer_load_dwordx4 acc[204:207], %[v_os_b4], s[20:23], 0 offen offset:3072 \n" _UK_MFMA_
" %[v_acc_9], acc[76:77], v[92:93], %[v_acc_9] \n" _UK_MFMA_
" %[v_acc_9], acc[78:79], v[94:95], %[v_acc_9] \n" _UK_MFMA_
" %[v_acc_10], acc[80:81], v[64:65], %[v_acc_10] \n" _UK_MFMA_
" %[v_acc_10], acc[82:83], v[66:67], %[v_acc_10] \n"
" buffer_load_dwordx4 acc[208:211], %[v_os_b5], s[20:23], 0 offen \n" _UK_MFMA_
" %[v_acc_10], acc[84:85], v[68:69], %[v_acc_10] \n" _UK_MFMA_
" %[v_acc_10], acc[86:87], v[70:71], %[v_acc_10] \n" _UK_MFMA_
" %[v_acc_10], acc[88:89], v[72:73], %[v_acc_10] \n" _UK_MFMA_
" %[v_acc_10], acc[90:91], v[74:75], %[v_acc_10] \n"
" buffer_load_dwordx4 acc[212:215], %[v_os_b5], s[20:23], 0 offen offset:1024 \n" _UK_MFMA_
" %[v_acc_10], acc[92:93], v[76:77], %[v_acc_10] \n" _UK_MFMA_
" %[v_acc_10], acc[94:95], v[78:79], %[v_acc_10] \n" _UK_MFMA_
" %[v_acc_11], acc[80:81], v[80:81], %[v_acc_11] \n" _UK_MFMA_
" %[v_acc_11], acc[82:83], v[82:83], %[v_acc_11] \n"
" buffer_load_dwordx4 acc[216:219], %[v_os_b5], s[20:23], 0 offen offset:2048 \n" _UK_MFMA_
" %[v_acc_11], acc[84:85], v[84:85], %[v_acc_11] \n" _UK_MFMA_
" %[v_acc_11], acc[86:87], v[86:87], %[v_acc_11] \n" _UK_MFMA_
" %[v_acc_11], acc[88:89], v[88:89], %[v_acc_11] \n" _UK_MFMA_
" %[v_acc_11], acc[90:91], v[90:91], %[v_acc_11] \n"
" buffer_load_dwordx4 acc[220:223], %[v_os_b5], s[20:23], 0 offen offset:3072 \n" _UK_MFMA_
" %[v_acc_11], acc[92:93], v[92:93], %[v_acc_11] \n" _UK_MFMA_
" %[v_acc_11], acc[94:95], v[94:95], %[v_acc_11] \n"
" s_waitcnt vmcnt(32) \n" _UK_MFMA_
" %[v_acc_12], acc[96:97], v[64:65], %[v_acc_12] \n" _UK_MFMA_
" %[v_acc_12], acc[98:99], v[66:67], %[v_acc_12] \n"
" buffer_load_dwordx4 acc[224:227], %[v_os_b6], s[20:23], 0 offen \n" _UK_MFMA_
" %[v_acc_12], acc[100:101], v[68:69], %[v_acc_12] \n" _UK_MFMA_
" %[v_acc_12], acc[102:103], v[70:71], %[v_acc_12] \n" _UK_MFMA_
" %[v_acc_12], acc[104:105], v[72:73], %[v_acc_12] \n" _UK_MFMA_
" %[v_acc_12], acc[106:107], v[74:75], %[v_acc_12] \n"
" buffer_load_dwordx4 acc[228:231], %[v_os_b6], s[20:23], 0 offen offset:1024 \n" _UK_MFMA_
" %[v_acc_12], acc[108:109], v[76:77], %[v_acc_12] \n" _UK_MFMA_
" %[v_acc_12], acc[110:111], v[78:79], %[v_acc_12] \n" _UK_MFMA_
" %[v_acc_13], acc[96:97], v[80:81], %[v_acc_13] \n" _UK_MFMA_
" %[v_acc_13], acc[98:99], v[82:83], %[v_acc_13] \n"
" buffer_load_dwordx4 acc[232:235], %[v_os_b6], s[20:23], 0 offen offset:2048 \n" _UK_MFMA_
" %[v_acc_13], acc[100:101], v[84:85], %[v_acc_13] \n" _UK_MFMA_
" %[v_acc_13], acc[102:103], v[86:87], %[v_acc_13] \n" _UK_MFMA_
" %[v_acc_13], acc[104:105], v[88:89], %[v_acc_13] \n" _UK_MFMA_
" %[v_acc_13], acc[106:107], v[90:91], %[v_acc_13] \n"
" buffer_load_dwordx4 acc[236:239], %[v_os_b6], s[20:23], 0 offen offset:3072 \n" _UK_MFMA_
" %[v_acc_13], acc[108:109], v[92:93], %[v_acc_13] \n" _UK_MFMA_
" %[v_acc_13], acc[110:111], v[94:95], %[v_acc_13] \n" _UK_MFMA_
" %[v_acc_14], acc[112:113], v[64:65], %[v_acc_14] \n" _UK_MFMA_
" %[v_acc_14], acc[114:115], v[66:67], %[v_acc_14] \n"
" buffer_load_dwordx4 acc[240:243], %[v_os_b7], s[20:23], 0 offen \n" _UK_MFMA_
" %[v_acc_14], acc[116:117], v[68:69], %[v_acc_14] \n" _UK_MFMA_
" %[v_acc_14], acc[118:119], v[70:71], %[v_acc_14] \n" _UK_MFMA_
" %[v_acc_14], acc[120:121], v[72:73], %[v_acc_14] \n" _UK_MFMA_
" %[v_acc_14], acc[122:123], v[74:75], %[v_acc_14] \n"
" buffer_load_dwordx4 acc[244:247], %[v_os_b7], s[20:23], 0 offen offset:1024 \n" _UK_MFMA_
" %[v_acc_14], acc[124:125], v[76:77], %[v_acc_14] \n" _UK_MFMA_
" %[v_acc_14], acc[126:127], v[78:79], %[v_acc_14] \n" _UK_MFMA_
" %[v_acc_15], acc[112:113], v[80:81], %[v_acc_15] \n" _UK_MFMA_
" %[v_acc_15], acc[114:115], v[82:83], %[v_acc_15] \n"
" buffer_load_dwordx4 acc[248:251], %[v_os_b7], s[20:23], 0 offen offset:2048 \n" _UK_MFMA_
" %[v_acc_15], acc[116:117], v[84:85], %[v_acc_15] \n" _UK_MFMA_
" %[v_acc_15], acc[118:119], v[86:87], %[v_acc_15] \n" _UK_MFMA_
" %[v_acc_15], acc[120:121], v[88:89], %[v_acc_15] \n" _UK_MFMA_
" %[v_acc_15], acc[122:123], v[90:91], %[v_acc_15] \n"
" buffer_load_dwordx4 acc[252:255], %[v_os_b7], s[20:23], 0 offen offset:3072\n" _UK_MFMA_
" %[v_acc_15], acc[124:125], v[92:93], %[v_acc_15] \n" _UK_MFMA_
" %[v_acc_15], acc[126:127], v[94:95], %[v_acc_15] \n"
" s_sub_i32 %[s_loop_cnt], %[s_loop_cnt], 1 \n"
" s_cmp_gt_i32 %[s_loop_cnt] 0 \n"
" s_cbranch_scc0 L_end%= \n"
" s_cmp_gt_i32 %[s_loop_cnt] 2 ; move a with cond \n"
" s_cselect_b32 s86, %[s_tile_os_a], 0 \n"
" s_add_u32 s16, s86, s16 \n"
" s_addc_u32 s17, 0, s17 \n"
" s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n"
" s_cselect_b32 s86, %[s_tile_os_b], 0 \n"
" s_add_u32 s20, s86, s20 \n"
" s_addc_u32 s21, 0, s21 \n"
" ;------------------------------------------ \n"
" s_waitcnt vmcnt(24) & lgkmcnt(0) \n"
" s_barrier \n" _UK_MFMA_
" %[v_acc_0], acc[128:129], v[96:97], %[v_acc_0] \n" _UK_MFMA_
" %[v_acc_0], acc[130:131], v[98:99], %[v_acc_0] \n"
" buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[20:23], 0 offen \n" _UK_MFMA_
" %[v_acc_0], acc[132:133], v[100:101], %[v_acc_0] \n" _UK_MFMA_
" %[v_acc_0], acc[134:135], v[102:103], %[v_acc_0] \n"
" buffer_load_dword %[v_os_a0], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n" _UK_MFMA_
" %[v_acc_0], acc[136:137], v[104:105], %[v_acc_0] \n" _UK_MFMA_
" %[v_acc_0], acc[138:139], v[106:107], %[v_acc_0] \n"
" buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[20:23], 0 offen offset:1024 \n" _UK_MFMA_
" %[v_acc_0], acc[140:141], v[108:109], %[v_acc_0] \n" _UK_MFMA_
" %[v_acc_0], acc[142:143], v[110:111], %[v_acc_0] \n"
" buffer_load_dword %[v_os_a1], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n" _UK_MFMA_
" %[v_acc_1], acc[128:129], v[112:113], %[v_acc_1] \n" _UK_MFMA_
" %[v_acc_1], acc[130:131], v[114:115], %[v_acc_1] \n"
" buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[20:23], 0 offen offset:2048 \n" _UK_MFMA_
" %[v_acc_1], acc[132:133], v[116:117], %[v_acc_1] \n" _UK_MFMA_
" %[v_acc_1], acc[134:135], v[118:119], %[v_acc_1] \n"
" buffer_load_dword %[v_os_a2], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n" _UK_MFMA_
" %[v_acc_1], acc[136:137], v[120:121], %[v_acc_1] \n" _UK_MFMA_
" %[v_acc_1], acc[138:139], v[122:123], %[v_acc_1] \n"
" buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[20:23], 0 offen offset:3072 \n" _UK_MFMA_
" %[v_acc_1], acc[140:141], v[124:125], %[v_acc_1] \n" _UK_MFMA_
" %[v_acc_1], acc[142:143], v[126:127], %[v_acc_1] \n"
" buffer_load_dword %[v_os_a3], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n" _UK_MFMA_
" %[v_acc_2], acc[144:145], v[96:97], %[v_acc_2] \n" _UK_MFMA_
" %[v_acc_2], acc[146:147], v[98:99], %[v_acc_2] \n"
" buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[20:23], 0 offen \n" _UK_MFMA_
" %[v_acc_2], acc[148:149], v[100:101], %[v_acc_2] \n" _UK_MFMA_
" %[v_acc_2], acc[150:151], v[102:103], %[v_acc_2] \n"
" buffer_load_dword %[v_os_a4], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n" _UK_MFMA_
" %[v_acc_2], acc[152:153], v[104:105], %[v_acc_2] \n" _UK_MFMA_
" %[v_acc_2], acc[154:155], v[106:107], %[v_acc_2] \n"
" buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[20:23], 0 offen offset:1024 \n" _UK_MFMA_
" %[v_acc_2], acc[156:157], v[108:109], %[v_acc_2] \n" _UK_MFMA_
" %[v_acc_2], acc[158:159], v[110:111], %[v_acc_2] \n"
" buffer_load_dword %[v_os_a5], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n" _UK_MFMA_
" %[v_acc_3], acc[144:145], v[112:113], %[v_acc_3] \n" _UK_MFMA_
" %[v_acc_3], acc[146:147], v[114:115], %[v_acc_3] \n"
" buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[20:23], 0 offen offset:2048 \n" _UK_MFMA_
" %[v_acc_3], acc[148:149], v[116:117], %[v_acc_3] \n" _UK_MFMA_
" %[v_acc_3], acc[150:151], v[118:119], %[v_acc_3] \n"
" buffer_load_dword %[v_os_a6], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n" _UK_MFMA_
" %[v_acc_3], acc[152:153], v[120:121], %[v_acc_3] \n" _UK_MFMA_
" %[v_acc_3], acc[154:155], v[122:123], %[v_acc_3] \n"
" buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[20:23], 0 offen offset:3072 \n" _UK_MFMA_
" %[v_acc_3], acc[156:157], v[124:125], %[v_acc_3] \n" _UK_MFMA_
" %[v_acc_3], acc[158:159], v[126:127], %[v_acc_3] \n"
" buffer_load_dword %[v_os_a7], s[16:19], 0 offen lds \n"
" s_add_u32 m0, 0, %[s_m0_init] \n"
" s_waitcnt vmcnt(32) \n" _UK_MFMA_
" %[v_acc_4], acc[160:161], v[96:97], %[v_acc_4] \n" _UK_MFMA_
" %[v_acc_4], acc[162:163], v[98:99], %[v_acc_4] \n"
" buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[20:23], 0 offen \n" _UK_MFMA_
" %[v_acc_4], acc[164:165], v[100:101], %[v_acc_4] \n" _UK_MFMA_
" %[v_acc_4], acc[166:167], v[102:103], %[v_acc_4] \n"
" ds_read_b128 v[64:67], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_0] \n" _UK_MFMA_
" %[v_acc_4], acc[168:169], v[104:105], %[v_acc_4] \n" _UK_MFMA_
" %[v_acc_4], acc[170:171], v[106:107], %[v_acc_4] \n"
" buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[20:23], 0 offen offset:1024 \n" _UK_MFMA_
" %[v_acc_4], acc[172:173], v[108:109], %[v_acc_4] \n" _UK_MFMA_
" %[v_acc_4], acc[174:175], v[110:111], %[v_acc_4] \n"
" ds_read_b128 v[68:71], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_1] \n" _UK_MFMA_
" %[v_acc_5], acc[160:161], v[112:113], %[v_acc_5] \n" _UK_MFMA_
" %[v_acc_5], acc[162:163], v[114:115], %[v_acc_5] \n"
" buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[20:23], 0 offen offset:2048 \n" _UK_MFMA_
" %[v_acc_5], acc[164:165], v[116:117], %[v_acc_5] \n" _UK_MFMA_
" %[v_acc_5], acc[166:167], v[118:119], %[v_acc_5] \n"
" ds_read_b128 v[72:75], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_2] "
"\n" _UK_MFMA_ " %[v_acc_5], acc[168:169], v[120:121], %[v_acc_5] \n" _UK_MFMA_
" %[v_acc_5], acc[170:171], v[122:123], %[v_acc_5] \n"
" buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[20:23], 0 offen offset:3072 \n" _UK_MFMA_
" %[v_acc_5], acc[172:173], v[124:125], %[v_acc_5] \n" _UK_MFMA_
" %[v_acc_5], acc[174:175], v[126:127], %[v_acc_5] \n"
" ds_read_b128 v[76:79], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_3] "
"\n" _UK_MFMA_ " %[v_acc_6], acc[176:177], v[96:97], %[v_acc_6] \n" _UK_MFMA_
" %[v_acc_6], acc[178:179], v[98:99], %[v_acc_6] \n"
" buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[20:23], 0 offen \n" _UK_MFMA_
" %[v_acc_6], acc[180:181], v[100:101], %[v_acc_6] \n" _UK_MFMA_
" %[v_acc_6], acc[182:183], v[102:103], %[v_acc_6] \n"
" ds_read_b128 v[80:83], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_4] "
"\n" _UK_MFMA_ " %[v_acc_6], acc[184:185], v[104:105], %[v_acc_6] \n" _UK_MFMA_
" %[v_acc_6], acc[186:187], v[106:107], %[v_acc_6] \n"
" buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[20:23], 0 offen offset:1024 \n" _UK_MFMA_
" %[v_acc_6], acc[188:189], v[108:109], %[v_acc_6] \n" _UK_MFMA_
" %[v_acc_6], acc[190:191], v[110:111], %[v_acc_6] \n"
" ds_read_b128 v[84:87], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_5] "
"\n" _UK_MFMA_ " %[v_acc_7], acc[176:177], v[112:113], %[v_acc_7] \n" _UK_MFMA_
" %[v_acc_7], acc[178:179], v[114:115], %[v_acc_7] \n"
" buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[20:23], 0 offen offset:2048 \n" _UK_MFMA_
" %[v_acc_7], acc[180:181], v[116:117], %[v_acc_7] \n" _UK_MFMA_
" %[v_acc_7], acc[182:183], v[118:119], %[v_acc_7] \n"
" ds_read_b128 v[88:91], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_6] "
"\n" _UK_MFMA_ " %[v_acc_7], acc[184:185], v[120:121], %[v_acc_7] \n" _UK_MFMA_
" %[v_acc_7], acc[186:187], v[122:123], %[v_acc_7] \n"
" buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[20:23], 0 offen offset:3072 \n" _UK_MFMA_
" %[v_acc_7], acc[188:189], v[124:125], %[v_acc_7] \n" _UK_MFMA_
" %[v_acc_7], acc[190:191], v[126:127], %[v_acc_7] \n"
" ds_read_b128 v[92:95], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_7] \n"
" s_waitcnt vmcnt(32) \n" _UK_MFMA_
" %[v_acc_8], acc[192:193], v[96:97], %[v_acc_8] \n" _UK_MFMA_
" %[v_acc_8], acc[194:195], v[98:99], %[v_acc_8] \n"
" buffer_load_dwordx4 acc[64:67], %[v_os_b4], s[20:23], 0 offen \n" _UK_MFMA_
" %[v_acc_8], acc[196:197], v[100:101], %[v_acc_8] \n" _UK_MFMA_
" %[v_acc_8], acc[198:199], v[102:103], %[v_acc_8] \n" _UK_MFMA_
" %[v_acc_8], acc[200:201], v[104:105], %[v_acc_8] \n" _UK_MFMA_
" %[v_acc_8], acc[202:203], v[106:107], %[v_acc_8] \n"
" buffer_load_dwordx4 acc[68:71], %[v_os_b4], s[20:23], 0 offen offset:1024 \n" _UK_MFMA_
" %[v_acc_8], acc[204:205], v[108:109], %[v_acc_8] \n" _UK_MFMA_
" %[v_acc_8], acc[206:207], v[110:111], %[v_acc_8] \n" _UK_MFMA_
" %[v_acc_9], acc[192:193], v[112:113], %[v_acc_9] \n" _UK_MFMA_
" %[v_acc_9], acc[194:195], v[114:115], %[v_acc_9] \n"
" buffer_load_dwordx4 acc[72:75], %[v_os_b4], s[20:23], 0 offen offset:2048 \n" _UK_MFMA_
" %[v_acc_9], acc[196:197], v[116:117], %[v_acc_9] \n" _UK_MFMA_
" %[v_acc_9], acc[198:199], v[118:119], %[v_acc_9] \n" _UK_MFMA_
" %[v_acc_9], acc[200:201], v[120:121], %[v_acc_9] \n" _UK_MFMA_
" %[v_acc_9], acc[202:203], v[122:123], %[v_acc_9] \n"
" buffer_load_dwordx4 acc[76:79], %[v_os_b4], s[20:23], 0 offen offset:3072 \n" _UK_MFMA_
" %[v_acc_9], acc[204:205], v[124:125], %[v_acc_9] \n" _UK_MFMA_
" %[v_acc_9], acc[206:207], v[126:127], %[v_acc_9] \n" _UK_MFMA_
" %[v_acc_10], acc[208:209], v[96:97], %[v_acc_10] \n" _UK_MFMA_
" %[v_acc_10], acc[210:211], v[98:99], %[v_acc_10] \n"
" buffer_load_dwordx4 acc[80:83], %[v_os_b5], s[20:23], 0 offen \n" _UK_MFMA_
" %[v_acc_10], acc[212:213], v[100:101], %[v_acc_10] \n" _UK_MFMA_
" %[v_acc_10], acc[214:215], v[102:103], %[v_acc_10] \n" _UK_MFMA_
" %[v_acc_10], acc[216:217], v[104:105], %[v_acc_10] \n" _UK_MFMA_
" %[v_acc_10], acc[218:219], v[106:107], %[v_acc_10] \n"
" buffer_load_dwordx4 acc[84:87], %[v_os_b5], s[20:23], 0 offen offset:1024 \n" _UK_MFMA_
" %[v_acc_10], acc[220:221], v[108:109], %[v_acc_10] \n" _UK_MFMA_
" %[v_acc_10], acc[222:223], v[110:111], %[v_acc_10] \n" _UK_MFMA_
" %[v_acc_11], acc[208:209], v[112:113], %[v_acc_11] \n" _UK_MFMA_
" %[v_acc_11], acc[210:211], v[114:115], %[v_acc_11] \n"
" buffer_load_dwordx4 acc[88:91], %[v_os_b5], s[20:23], 0 offen offset:2048 \n" _UK_MFMA_
" %[v_acc_11], acc[212:213], v[116:117], %[v_acc_11] \n" _UK_MFMA_
" %[v_acc_11], acc[214:215], v[118:119], %[v_acc_11] \n" _UK_MFMA_
" %[v_acc_11], acc[216:217], v[120:121], %[v_acc_11] \n" _UK_MFMA_
" %[v_acc_11], acc[218:219], v[122:123], %[v_acc_11] \n"
" buffer_load_dwordx4 acc[92:95], %[v_os_b5], s[20:23], 0 offen offset:3072 \n" _UK_MFMA_
" %[v_acc_11], acc[220:221], v[124:125], %[v_acc_11] \n" _UK_MFMA_
" %[v_acc_11], acc[222:223], v[126:127], %[v_acc_11] \n"
" s_waitcnt vmcnt(32) \n" _UK_MFMA_
" %[v_acc_12], acc[224:225], v[96:97], %[v_acc_12] \n" _UK_MFMA_
" %[v_acc_12], acc[226:227], v[98:99], %[v_acc_12] \n"
" buffer_load_dwordx4 acc[96:99], %[v_os_b6], s[20:23], 0 offen \n" _UK_MFMA_
" %[v_acc_12], acc[228:229], v[100:101], %[v_acc_12] \n" _UK_MFMA_
" %[v_acc_12], acc[230:231], v[102:103], %[v_acc_12] \n" _UK_MFMA_
" %[v_acc_12], acc[232:233], v[104:105], %[v_acc_12] \n" _UK_MFMA_
" %[v_acc_12], acc[234:235], v[106:107], %[v_acc_12] \n"
" buffer_load_dwordx4 acc[100:103], %[v_os_b6], s[20:23], 0 offen offset:1024 \n" _UK_MFMA_
" %[v_acc_12], acc[236:237], v[108:109], %[v_acc_12] \n" _UK_MFMA_
" %[v_acc_12], acc[238:239], v[110:111], %[v_acc_12] \n" _UK_MFMA_
" %[v_acc_13], acc[224:225], v[112:113], %[v_acc_13] \n" _UK_MFMA_
" %[v_acc_13], acc[226:227], v[114:115], %[v_acc_13] \n"
" buffer_load_dwordx4 acc[104:107], %[v_os_b6], s[20:23], 0 offen offset:2048 \n" _UK_MFMA_
" %[v_acc_13], acc[228:229], v[116:117], %[v_acc_13] \n" _UK_MFMA_
" %[v_acc_13], acc[230:231], v[118:119], %[v_acc_13] \n" _UK_MFMA_
" %[v_acc_13], acc[232:233], v[120:121], %[v_acc_13] \n" _UK_MFMA_
" %[v_acc_13], acc[234:235], v[122:123], %[v_acc_13] \n"
" buffer_load_dwordx4 acc[108:111], %[v_os_b6], s[20:23], 0 offen offset:3072 \n" _UK_MFMA_
" %[v_acc_13], acc[236:237], v[124:125], %[v_acc_13] \n" _UK_MFMA_
" %[v_acc_13], acc[238:239], v[126:127], %[v_acc_13] \n" _UK_MFMA_
" %[v_acc_14], acc[240:241], v[96:97], %[v_acc_14] \n" _UK_MFMA_
" %[v_acc_14], acc[242:243], v[98:99], %[v_acc_14] \n"
" buffer_load_dwordx4 acc[112:115], %[v_os_b7], s[20:23], 0 offen \n" _UK_MFMA_
" %[v_acc_14], acc[244:245], v[100:101], %[v_acc_14] \n" _UK_MFMA_
" %[v_acc_14], acc[246:247], v[102:103], %[v_acc_14] \n" _UK_MFMA_
" %[v_acc_14], acc[248:249], v[104:105], %[v_acc_14] \n" _UK_MFMA_
" %[v_acc_14], acc[250:251], v[106:107], %[v_acc_14] \n"
" buffer_load_dwordx4 acc[116:119], %[v_os_b7], s[20:23], 0 offen offset:1024 \n" _UK_MFMA_
" %[v_acc_14], acc[252:253], v[108:109], %[v_acc_14] \n" _UK_MFMA_
" %[v_acc_14], acc[254:255], v[110:111], %[v_acc_14] \n" _UK_MFMA_
" %[v_acc_15], acc[240:241], v[112:113], %[v_acc_15] \n" _UK_MFMA_
" %[v_acc_15], acc[242:243], v[114:115], %[v_acc_15] \n"
" buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[20:23], 0 offen offset:2048 \n" _UK_MFMA_
" %[v_acc_15], acc[244:245], v[116:117], %[v_acc_15] \n" _UK_MFMA_
" %[v_acc_15], acc[246:247], v[118:119], %[v_acc_15] \n" _UK_MFMA_
" %[v_acc_15], acc[248:249], v[120:121], %[v_acc_15] \n" _UK_MFMA_
" %[v_acc_15], acc[250:251], v[122:123], %[v_acc_15] \n"
" buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[20:23], 0 offen offset:3072 \n" _UK_MFMA_
" %[v_acc_15], acc[252:253], v[124:125], %[v_acc_15] \n" _UK_MFMA_
" %[v_acc_15], acc[254:255], v[126:127], %[v_acc_15] \n"
" s_sub_i32 %[s_loop_cnt], %[s_loop_cnt], 1 \n"
" s_cmp_gt_i32 %[s_loop_cnt] 0 \n"
" s_cbranch_scc0 L_end%= \n"
" s_cmp_gt_i32 %[s_loop_cnt] 2 ; move a with cond \n"
" s_cselect_b32 s86, %[s_tile_os_a], 0 \n"
" s_add_u32 s16, s86, s16 \n"
" s_addc_u32 s17, 0, s17 \n"
" s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n"
" s_cselect_b32 s86, %[s_tile_os_b], 0 \n"
" s_add_u32 s20, s86, s20 \n"
" s_addc_u32 s21, 0, s21 \n"
" s_branch L_start%= \n"
"L_end%=: \n"
" s_nop 2 \n"
#undef _UK_MFMA_
......@@ -43,4 +43,5 @@
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp"
#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
......@@ -230,7 +230,15 @@ struct PageBlockNavigator
CK_TILE_HOST_DEVICE
DataType* get_block_ptr(index_t block_index) const
{
return physical_blocks + physical_block_indices[block_index] * block_stride + fixed_offset;
if(block_index < num_blocks)
{
return physical_blocks + physical_block_indices[block_index] * block_stride +
fixed_offset;
}
else
{
return nullptr;
}
}
CK_TILE_HOST_DEVICE int32_t get_block_index(const WindowOrigin& global_window_origin) const
......
......@@ -304,64 +304,64 @@ struct FmhaBwdDQDKDVKernel
template <bool Cond = !kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
const void* lse_ptr,
const void* do_ptr,
const void* d_ptr,
void* rand_val_ptr,
void* dk_ptr,
void* dv_ptr,
void* dbias_ptr,
void* dq_acc_ptr,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_do,
ck_tile::index_t stride_dq_acc,
ck_tile::index_t stride_dk,
ck_tile::index_t stride_dv,
ck_tile::index_t stride_dbias,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_do,
ck_tile::index_t nhead_stride_lsed,
ck_tile::index_t nhead_stride_dq_acc,
ck_tile::index_t nhead_stride_dk,
ck_tile::index_t nhead_stride_dv,
ck_tile::index_t nhead_stride_dbias,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_bias,
ck_tile::index_t batch_stride_randval,
ck_tile::index_t batch_stride_do,
ck_tile::index_t batch_stride_lsed,
ck_tile::index_t batch_stride_dq_acc,
ck_tile::index_t batch_stride_dk,
ck_tile::index_t batch_stride_dv,
ck_tile::index_t batch_stride_dbias,
ck_tile::index_t split_stride_dq_acc,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
MakeKargsImpl(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
const void* lse_ptr,
const void* do_ptr,
const void* d_ptr,
void* rand_val_ptr,
void* dk_ptr,
void* dv_ptr,
void* dbias_ptr,
void* dq_acc_ptr,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_do,
ck_tile::index_t stride_dq_acc,
ck_tile::index_t stride_dk,
ck_tile::index_t stride_dv,
ck_tile::index_t stride_dbias,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_do,
ck_tile::index_t nhead_stride_lsed,
ck_tile::index_t nhead_stride_dq_acc,
ck_tile::index_t nhead_stride_dk,
ck_tile::index_t nhead_stride_dv,
ck_tile::index_t nhead_stride_dbias,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_bias,
ck_tile::index_t batch_stride_randval,
ck_tile::index_t batch_stride_do,
ck_tile::index_t batch_stride_lsed,
ck_tile::index_t batch_stride_dq_acc,
ck_tile::index_t batch_stride_dk,
ck_tile::index_t batch_stride_dv,
ck_tile::index_t batch_stride_dbias,
ck_tile::index_t split_stride_dq_acc,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
{
Kargs kargs{{q_ptr,
k_ptr,
......@@ -470,7 +470,8 @@ struct FmhaBwdDQDKDVKernel
return kargs;
}
template <bool Cond = kIsGroupMode>
// std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = !kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
......@@ -484,9 +485,129 @@ struct FmhaBwdDQDKDVKernel
void* dv_ptr,
void* dbias_ptr,
void* dq_acc_ptr,
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seqlen_k_ptr,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_do,
ck_tile::index_t stride_dq_acc,
ck_tile::index_t stride_dk,
ck_tile::index_t stride_dv,
ck_tile::index_t stride_dbias,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_do,
ck_tile::index_t nhead_stride_lsed,
ck_tile::index_t nhead_stride_dq_acc,
ck_tile::index_t nhead_stride_dk,
ck_tile::index_t nhead_stride_dv,
ck_tile::index_t nhead_stride_dbias,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_bias,
ck_tile::index_t batch_stride_randval,
ck_tile::index_t batch_stride_do,
ck_tile::index_t batch_stride_lsed,
ck_tile::index_t batch_stride_dq_acc,
ck_tile::index_t batch_stride_dk,
ck_tile::index_t batch_stride_dv,
ck_tile::index_t batch_stride_dbias,
ck_tile::index_t split_stride_dq_acc,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
{
return MakeKargsImpl(
q_ptr,
k_ptr,
v_ptr,
bias_ptr,
lse_ptr,
do_ptr,
d_ptr,
rand_val_ptr,
dk_ptr,
dv_ptr,
dbias_ptr,
dq_acc_ptr,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
num_head_q,
nhead_ratio_qk,
scale,
stride_q,
stride_k,
stride_v,
stride_bias,
stride_randval,
stride_do,
stride_dq_acc,
stride_dk,
stride_dv,
stride_dbias,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_randval,
nhead_stride_do,
nhead_stride_lsed,
nhead_stride_dq_acc,
nhead_stride_dk,
nhead_stride_dv,
nhead_stride_dbias,
batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_bias,
batch_stride_randval,
batch_stride_do,
batch_stride_lsed,
batch_stride_dq_acc,
batch_stride_dk,
batch_stride_dv,
batch_stride_dbias,
split_stride_dq_acc,
window_size_left,
window_size_right,
mask_type,
p_drop,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
}
// std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = !kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
const void* lse_ptr,
const void* do_ptr,
const void* d_ptr,
void* rand_val_ptr,
void* dk_ptr,
void* dv_ptr,
void* dbias_ptr,
void* dq_acc_ptr,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
......@@ -513,13 +634,134 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t nhead_stride_dk,
ck_tile::index_t nhead_stride_dv,
ck_tile::index_t nhead_stride_dbias,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_bias,
ck_tile::index_t batch_stride_randval,
ck_tile::index_t batch_stride_do,
ck_tile::index_t batch_stride_lsed,
ck_tile::index_t batch_stride_dq_acc,
ck_tile::index_t batch_stride_dk,
ck_tile::index_t batch_stride_dv,
ck_tile::index_t batch_stride_dbias,
ck_tile::index_t split_stride_dq_acc,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
const std::tuple<const void*, const void*>& drop_seed_offset)
{
return MakeKargsImpl(
q_ptr,
k_ptr,
v_ptr,
bias_ptr,
lse_ptr,
do_ptr,
d_ptr,
rand_val_ptr,
dk_ptr,
dv_ptr,
dbias_ptr,
dq_acc_ptr,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
num_head_q,
nhead_ratio_qk,
scale,
stride_q,
stride_k,
stride_v,
stride_bias,
stride_randval,
stride_do,
stride_dq_acc,
stride_dk,
stride_dv,
stride_dbias,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_randval,
nhead_stride_do,
nhead_stride_lsed,
nhead_stride_dq_acc,
nhead_stride_dk,
nhead_stride_dv,
nhead_stride_dbias,
batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_bias,
batch_stride_randval,
batch_stride_do,
batch_stride_lsed,
batch_stride_dq_acc,
batch_stride_dk,
batch_stride_dv,
batch_stride_dbias,
split_stride_dq_acc,
window_size_left,
window_size_right,
mask_type,
p_drop,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
}
template <bool Cond = kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargsImpl(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
const void* lse_ptr,
const void* do_ptr,
const void* d_ptr,
void* rand_val_ptr,
void* dk_ptr,
void* dv_ptr,
void* dbias_ptr,
void* dq_acc_ptr,
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seqlen_k_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_do,
ck_tile::index_t stride_dq_acc,
ck_tile::index_t stride_dk,
ck_tile::index_t stride_dv,
ck_tile::index_t stride_dbias,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_do,
ck_tile::index_t nhead_stride_lsed,
ck_tile::index_t nhead_stride_dq_acc,
ck_tile::index_t nhead_stride_dk,
ck_tile::index_t nhead_stride_dv,
ck_tile::index_t nhead_stride_dbias,
ck_tile::index_t split_stride_dq_acc,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
{
Kargs kargs{{q_ptr,
k_ptr,
......@@ -616,6 +858,208 @@ struct FmhaBwdDQDKDVKernel
return kargs;
}
// std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
const void* lse_ptr,
const void* do_ptr,
const void* d_ptr,
void* rand_val_ptr,
void* dk_ptr,
void* dv_ptr,
void* dbias_ptr,
void* dq_acc_ptr,
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seqlen_k_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_do,
ck_tile::index_t stride_dq_acc,
ck_tile::index_t stride_dk,
ck_tile::index_t stride_dv,
ck_tile::index_t stride_dbias,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_do,
ck_tile::index_t nhead_stride_lsed,
ck_tile::index_t nhead_stride_dq_acc,
ck_tile::index_t nhead_stride_dk,
ck_tile::index_t nhead_stride_dv,
ck_tile::index_t nhead_stride_dbias,
ck_tile::index_t split_stride_dq_acc,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
{
return MakeKargsImpl(
q_ptr,
k_ptr,
v_ptr,
bias_ptr,
lse_ptr,
do_ptr,
d_ptr,
rand_val_ptr,
dk_ptr,
dv_ptr,
dbias_ptr,
dq_acc_ptr,
seqstart_q_ptr,
seqstart_k_ptr,
seqlen_k_ptr,
hdim_q,
hdim_v,
num_head_q,
nhead_ratio_qk,
scale,
stride_q,
stride_k,
stride_v,
stride_bias,
stride_randval,
stride_do,
stride_dq_acc,
stride_dk,
stride_dv,
stride_dbias,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_randval,
nhead_stride_do,
nhead_stride_lsed,
nhead_stride_dq_acc,
nhead_stride_dk,
nhead_stride_dv,
nhead_stride_dbias,
split_stride_dq_acc,
window_size_left,
window_size_right,
mask_type,
p_drop,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
}
// std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
const void* lse_ptr,
const void* do_ptr,
const void* d_ptr,
void* rand_val_ptr,
void* dk_ptr,
void* dv_ptr,
void* dbias_ptr,
void* dq_acc_ptr,
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seqlen_k_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_do,
ck_tile::index_t stride_dq_acc,
ck_tile::index_t stride_dk,
ck_tile::index_t stride_dv,
ck_tile::index_t stride_dbias,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_do,
ck_tile::index_t nhead_stride_lsed,
ck_tile::index_t nhead_stride_dq_acc,
ck_tile::index_t nhead_stride_dk,
ck_tile::index_t nhead_stride_dv,
ck_tile::index_t nhead_stride_dbias,
ck_tile::index_t split_stride_dq_acc,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
const std::tuple<const void*, const void*>& drop_seed_offset)
{
return MakeKargsImpl(
q_ptr,
k_ptr,
v_ptr,
bias_ptr,
lse_ptr,
do_ptr,
d_ptr,
rand_val_ptr,
dk_ptr,
dv_ptr,
dbias_ptr,
dq_acc_ptr,
seqstart_q_ptr,
seqstart_k_ptr,
seqlen_k_ptr,
hdim_q,
hdim_v,
num_head_q,
nhead_ratio_qk,
scale,
stride_q,
stride_k,
stride_v,
stride_bias,
stride_randval,
stride_do,
stride_dq_acc,
stride_dk,
stride_dv,
stride_dbias,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_randval,
nhead_stride_do,
nhead_stride_lsed,
nhead_stride_dq_acc,
nhead_stride_dk,
nhead_stride_dv,
nhead_stride_dbias,
split_stride_dq_acc,
window_size_left,
window_size_right,
mask_type,
p_drop,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
}
CK_TILE_HOST static constexpr auto
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_)
{
......
......@@ -64,12 +64,13 @@ struct FmhaFwdKernel
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
// clang-format on
__host__ static std::string GetName()
CK_TILE_HOST static std::string GetName()
{
// sync with generate.py
// clang-format off
using bfs = typename FmhaPipeline::BlockFmhaShape;
using gbr = typename bfs::Gemm0BlockWarps;
using g0br = typename bfs::Gemm0BlockWarps;
using g1br = typename bfs::Gemm1BlockWarps;
using gwt = typename bfs::Gemm0WarpTile;
#define _SS_ std::string
#define _TS_ std::to_string
......@@ -81,11 +82,12 @@ struct FmhaFwdKernel
if (kPadHeadDimV) n += "dv";
return n.empty() ? n : std::string("p") + n; }();
return
_SS_("fmha_fwd_d") + _TS_(bfs::kK0BlockLength) + "_" + _SS_(t2s<QDataType>::name) +
_SS_("fmha_fwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType>::name) +
"_" + (kIsGroupMode ? "group" : "batch") + "_" + _SS_(TilePartitioner::name) + "_"
"b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
_TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK0BlockLength) + "_" +
"r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" +
_TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" +
"r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
"r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
"w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
......@@ -265,50 +267,50 @@ struct FmhaFwdKernel
using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
template <bool Cond = !kIsGroupMode>
__host__ static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
void* rand_val_ptr,
void* lse_ptr,
void* o_ptr,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale_s,
float scale_p,
float scale_o,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_bias,
ck_tile::index_t batch_stride_randval,
ck_tile::index_t batch_stride_lse,
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargsImpl(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
void* rand_val_ptr,
void* lse_ptr,
void* o_ptr,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale_s,
float scale_p,
float scale_o,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_bias,
ck_tile::index_t batch_stride_randval,
ck_tile::index_t batch_stride_lse,
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
{
Kargs kargs{{q_ptr,
k_ptr,
......@@ -397,8 +399,9 @@ struct FmhaFwdKernel
return kargs;
}
template <bool Cond = kIsGroupMode>
__host__ static constexpr std::enable_if_t<Cond, Kargs>
// std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = !kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
......@@ -406,9 +409,99 @@ struct FmhaFwdKernel
void* rand_val_ptr,
void* lse_ptr,
void* o_ptr,
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seqlen_k_ptr,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale_s,
float scale_p,
float scale_o,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_bias,
ck_tile::index_t batch_stride_randval,
ck_tile::index_t batch_stride_lse,
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
{
return MakeKargsImpl(
q_ptr,
k_ptr,
v_ptr,
bias_ptr,
rand_val_ptr,
lse_ptr,
o_ptr,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
num_head_q,
nhead_ratio_qk,
scale_s,
scale_p,
scale_o,
stride_q,
stride_k,
stride_v,
stride_bias,
stride_randval,
stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_randval,
nhead_stride_lse,
nhead_stride_o,
batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_bias,
batch_stride_randval,
batch_stride_lse,
batch_stride_o,
window_size_left,
window_size_right,
mask_type,
p_drop,
s_randval,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
}
// std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = !kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
void* rand_val_ptr,
void* lse_ptr,
void* o_ptr,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
......@@ -429,13 +522,104 @@ struct FmhaFwdKernel
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_bias,
ck_tile::index_t batch_stride_randval,
ck_tile::index_t batch_stride_lse,
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
const std::tuple<const void*, const void*>& drop_seed_offset)
{
return MakeKargsImpl(
q_ptr,
k_ptr,
v_ptr,
bias_ptr,
rand_val_ptr,
lse_ptr,
o_ptr,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
num_head_q,
nhead_ratio_qk,
scale_s,
scale_p,
scale_o,
stride_q,
stride_k,
stride_v,
stride_bias,
stride_randval,
stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_randval,
nhead_stride_lse,
nhead_stride_o,
batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_bias,
batch_stride_randval,
batch_stride_lse,
batch_stride_o,
window_size_left,
window_size_right,
mask_type,
p_drop,
s_randval,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
}
template <bool Cond = kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargsImpl(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
void* rand_val_ptr,
void* lse_ptr,
void* o_ptr,
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seqlen_k_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale_s,
float scale_p,
float scale_o,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
{
Kargs kargs{{q_ptr,
k_ptr,
......@@ -520,15 +704,173 @@ struct FmhaFwdKernel
return kargs;
}
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size_,
ck_tile::index_t nhead_,
ck_tile::index_t seqlen_q_,
ck_tile::index_t hdim_v_)
// std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
void* rand_val_ptr,
void* lse_ptr,
void* o_ptr,
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seqlen_k_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale_s,
float scale_p,
float scale_o,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
{
return MakeKargsImpl(
q_ptr,
k_ptr,
v_ptr,
bias_ptr,
rand_val_ptr,
lse_ptr,
o_ptr,
seqstart_q_ptr,
seqstart_k_ptr,
seqlen_k_ptr,
hdim_q,
hdim_v,
num_head_q,
nhead_ratio_qk,
scale_s,
scale_p,
scale_o,
stride_q,
stride_k,
stride_v,
stride_bias,
stride_randval,
stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_randval,
nhead_stride_lse,
nhead_stride_o,
window_size_left,
window_size_right,
mask_type,
p_drop,
s_randval,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
}
// std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
void* rand_val_ptr,
void* lse_ptr,
void* o_ptr,
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seqlen_k_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale_s,
float scale_p,
float scale_o,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
const std::tuple<const void*, const void*>& drop_seed_offset)
{
return MakeKargsImpl(
q_ptr,
k_ptr,
v_ptr,
bias_ptr,
rand_val_ptr,
lse_ptr,
o_ptr,
seqstart_q_ptr,
seqstart_k_ptr,
seqlen_k_ptr,
hdim_q,
hdim_v,
num_head_q,
nhead_ratio_qk,
scale_s,
scale_p,
scale_o,
stride_q,
stride_k,
stride_v,
stride_bias,
stride_randval,
stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_randval,
nhead_stride_lse,
nhead_stride_o,
window_size_left,
window_size_right,
mask_type,
p_drop,
s_randval,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
}
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
ck_tile::index_t nhead_,
ck_tile::index_t seqlen_q_,
ck_tile::index_t hdim_v_)
{
return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_, hdim_v_);
}
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
......@@ -655,7 +997,7 @@ struct FmhaFwdKernel
{
return pad_tensor_view(
q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0BlockLength>{}),
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kSubQKHeaddim>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
}
else
......@@ -722,7 +1064,7 @@ struct FmhaFwdKernel
[&]() {
if constexpr(FmhaPipeline::kQLoadOnce)
return make_tuple(number<FmhaPipeline::kM0>{},
number<FmhaPipeline::kK0BlockLength>{});
number<FmhaPipeline::kSubQKHeaddim>{});
else
return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{});
}(),
......
......@@ -339,7 +339,7 @@ struct FmhaFwdSplitKVCombineKernel
number<FmhaPipeline::kAlignmentOacc>{},
number<1>{});
auto o_acc_dram_view = pad_tensor_view(
const auto o_acc_dram_view = pad_tensor_view(
o_acc_dram_naive,
make_tuple(number<1>{}, number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
sequence<false, kPadSeqLenQ, kPadHeadDimV>{});
......
......@@ -35,6 +35,7 @@ struct FmhaFwdSplitKVKernel
using LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>;
using SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType>;
using OaccDataType = remove_cvref_t<typename FmhaPipeline::OaccDataType>;
using ODataType = remove_cvref_t<typename FmhaPipeline::ODataType>;
using VLayout = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout>;
......@@ -46,8 +47,7 @@ struct FmhaFwdSplitKVKernel
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV;
static_assert(!kIsGroupMode || (kIsGroupMode && !kIsPagedKV),
"paged-kvcache only supported by batch mode kernels");
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
static constexpr bool kHasMask = FmhaMask::IsMasking;
......@@ -65,7 +65,8 @@ struct FmhaFwdSplitKVKernel
// sync with generate.py
// clang-format off
using bfs = typename FmhaPipeline::BlockFmhaShape;
using gbr = typename bfs::Gemm0BlockWarps;
using g0br = typename bfs::Gemm0BlockWarps;
using g1br = typename bfs::Gemm1BlockWarps;
using gwt = typename bfs::Gemm0WarpTile;
#define _SS_ std::string
#define _TS_ std::to_string
......@@ -77,11 +78,12 @@ struct FmhaFwdSplitKVKernel
if (kPadHeadDimV) n += "dv";
return n.empty() ? n : std::string("p") + n; }();
return
_SS_("fmha_fwd_splitkv_d") + _TS_(bfs::kK0BlockLength) + "_" + _SS_(t2s<QDataType>::name) +
_SS_("fmha_fwd_splitkv_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType>::name) +
"_" + (kIsGroupMode ? "group" : "batch") + "_"
"b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
_TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK0BlockLength) + "_" +
"r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" +
_TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" +
"r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
"r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
"w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
......@@ -170,13 +172,18 @@ struct FmhaFwdSplitKVKernel
float scale_p;
};
struct PageBlockTableKargs
struct CommonPageBlockTableKargs
{
const int32_t* block_table_ptr;
ck_tile::index_t batch_stride_block_table;
ck_tile::index_t page_block_size;
};
struct GroupModePageBlockTableKargs : CommonPageBlockTableKargs
{
bool is_gappy = false;
};
struct CacheBatchIdxKargs
{
const int32_t* cache_batch_idx;
......@@ -191,13 +198,15 @@ struct FmhaFwdSplitKVKernel
EmptyKargs<0>>>,
std::conditional_t<kHasMask, MaskKargs, EmptyKargs<1>>,
std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<2>>,
std::conditional_t<kIsPagedKV, PageBlockTableKargs, CacheBatchIdxKargs>
std::conditional_t<kIsPagedKV, CommonPageBlockTableKargs, CacheBatchIdxKargs>
{
const int32_t* seqlen_k_ptr;
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_k; // when using paged-kvcache, this will be stride/size for
// single kcache page-block
ck_tile::index_t batch_stride_v; // when using paged-kvcache, this will be stride/size for
// single vcache page-block
ck_tile::index_t batch_stride_lse_acc;
ck_tile::index_t batch_stride_o_acc;
};
......@@ -210,14 +219,17 @@ struct FmhaFwdSplitKVKernel
AlibiKargs,
EmptyKargs<0>>>,
std::conditional_t<kHasMask, MaskKargs, EmptyKargs<1>>,
std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<2>>
std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<2>>,
std::conditional_t<kIsPagedKV, GroupModePageBlockTableKargs, EmptyKargs<3>>
{
const int32_t* seqstart_q_ptr;
const int32_t* seqstart_k_ptr;
const int32_t* seqlen_k_ptr;
ck_tile::index_t batch_stride_k; // only used for paged-kvcache
ck_tile::index_t batch_stride_v; // only used for paged-kvcache
ck_tile::index_t batch_stride_k; // only used for paged-kvcache, this will be stride/size
// for single kcache page-block
ck_tile::index_t batch_stride_v; // only used for paged-kvcache, this will be stride/size
// for single vcache page-block
};
using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>;
......@@ -228,8 +240,10 @@ struct FmhaFwdSplitKVKernel
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
void* lse_acc_ptr,
void* o_acc_ptr,
void* lse_acc_ptr, /* workspace for lse accumulation when num_splits > 1, otherwise
final lse */
void* o_acc_ptr, /* workspace for o accumulation when num_splits > 1, otherwise final
o */
ck_tile::index_t batch,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k, // only used if 'seqlen_k_ptr' is not specified
......@@ -350,8 +364,10 @@ struct FmhaFwdSplitKVKernel
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
void* lse_acc_ptr,
void* o_acc_ptr,
void* lse_acc_ptr, /* workspace for lse accumulation when num_splits > 1, otherwise
final lse */
void* o_acc_ptr, /* workspace for o accumulation when num_splits > 1, otherwise final
o */
ck_tile::index_t batch,
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
......@@ -361,6 +377,10 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
ck_tile::index_t num_splits,
const void* block_table_ptr,
ck_tile::index_t batch_stride_block_table,
ck_tile::index_t page_block_size,
bool is_gappy,
float scale_s,
float scale_p,
ck_tile::index_t stride_q,
......@@ -414,6 +434,7 @@ struct FmhaFwdSplitKVKernel
{}, // placeholder for bias
{}, // placeholder for mask
{}, // placeholder for fp8_static_quant args
{}, // placeholder for paged-block table
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
reinterpret_cast<const int32_t*>(seqlen_k_ptr),
......@@ -441,6 +462,13 @@ struct FmhaFwdSplitKVKernel
{
kargs.scale_p = scale_p;
}
if constexpr(kIsPagedKV)
{
kargs.block_table_ptr = reinterpret_cast<const int32_t*>(block_table_ptr);
kargs.batch_stride_block_table = batch_stride_block_table;
kargs.page_block_size = page_block_size;
kargs.is_gappy = is_gappy;
}
return kargs;
}
......@@ -474,11 +502,13 @@ struct FmhaFwdSplitKVKernel
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
long_index_t batch_offset_q = 0;
long_index_t batch_offset_k = 0;
long_index_t batch_offset_v = 0;
long_index_t batch_offset_k = 0; // unused for paged-kvcache
long_index_t batch_offset_v = 0; // unused for paged-kvcache
long_index_t batch_offset_bias = 0;
long_index_t batch_offset_lse_acc = 0;
long_index_t batch_offset_o_acc = 0;
index_t kv_l2p_offset =
0; // logical-to-physical offset of seqlen_k coordinate. only used for paged-kvcache
if constexpr(kIsGroupMode)
{
......@@ -488,7 +518,6 @@ struct FmhaFwdSplitKVKernel
batch_offset_q = query_start * kargs.stride_q;
batch_offset_k = key_start * kargs.stride_k;
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
batch_offset_v = key_start * kargs.stride_v;
......@@ -523,6 +552,15 @@ struct FmhaFwdSplitKVKernel
{
kargs.seqlen_k = kargs.seqstart_k_ptr[i_batch + 1] - kargs.seqstart_k_ptr[i_batch];
}
if constexpr(kIsPagedKV)
{
if(kargs.is_gappy)
{
// seqstart_k_ptr has different meaning in this case
kv_l2p_offset = kargs.seqstart_k_ptr[i_batch];
}
}
}
else
{
......@@ -568,9 +606,9 @@ struct FmhaFwdSplitKVKernel
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
batch_offset_v;
OaccDataType* o_acc_ptr = reinterpret_cast<OaccDataType*>(kargs.o_acc_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o_acc +
batch_offset_o_acc + i_split * kargs.split_stride_o_acc;
ODataType* o_acc_ptr = reinterpret_cast<ODataType*>(kargs.o_acc_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o_acc +
batch_offset_o_acc + i_split * kargs.split_stride_o_acc;
// Q/K/V DRAM and DRAM window
const auto q_dram = [&]() {
......@@ -584,15 +622,15 @@ struct FmhaFwdSplitKVKernel
{
return pad_tensor_view(
q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0BlockLength>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kSubQKHeaddim>{}),
sequence<false, kPadHeadDimQ>{});
}
else
{
return pad_tensor_view(
q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
sequence<false, kPadHeadDimQ>{});
}
}();
......@@ -607,7 +645,7 @@ struct FmhaFwdSplitKVKernel
return pad_tensor_view(
k_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
sequence<kPadSeqLenK, kPadHeadDimQ>{});
sequence<false, kPadHeadDimQ>{});
};
const auto k_dram = [&]() {
if constexpr(kIsPagedKV)
......@@ -640,7 +678,7 @@ struct FmhaFwdSplitKVKernel
return pad_tensor_view(
v_dram_transposed,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
sequence<kPadHeadDimV, kPadSeqLenK>{});
sequence<kPadHeadDimV, false>{});
}
else
{
......@@ -654,7 +692,7 @@ struct FmhaFwdSplitKVKernel
return pad_tensor_view(
v_dram_naive,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
sequence<kPadHeadDimV, kPadSeqLenK>{});
sequence<false, kPadSeqLenK>{});
}
};
const auto v_dram = [&]() {
......@@ -675,7 +713,7 @@ struct FmhaFwdSplitKVKernel
reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
i_batch_ * kargs.batch_stride_block_table;
const index_t num_blocks =
integer_divide_ceil(kargs.seqlen_k, kargs.page_block_size);
integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size);
const long_index_t fixed_offset =
static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
......@@ -683,14 +721,15 @@ struct FmhaFwdSplitKVKernel
return make_page_block_navigator<const KDataType, 0>(
kargs.k_ptr,
kargs.batch_stride_k,
kargs.batch_stride_k, // kcache page-block stride/size
fixed_offset,
block_indices,
num_blocks,
kargs.page_block_size,
k_dram,
make_k_dram(nullptr,
kargs.seqlen_k - (num_blocks - 1) * kargs.page_block_size));
(kv_l2p_offset + kargs.seqlen_k) -
(num_blocks - 1) * kargs.page_block_size));
}
else
{
......@@ -705,7 +744,7 @@ struct FmhaFwdSplitKVKernel
reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
i_batch_ * kargs.batch_stride_block_table;
const index_t num_blocks =
integer_divide_ceil(kargs.seqlen_k, kargs.page_block_size);
integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size);
const long_index_t fixed_offset =
static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
......@@ -713,14 +752,15 @@ struct FmhaFwdSplitKVKernel
return make_page_block_navigator<const VDataType, 1>(
kargs.v_ptr,
kargs.batch_stride_v,
kargs.batch_stride_v, // vcache page-block stride/size
fixed_offset,
block_indices,
num_blocks,
kargs.page_block_size,
v_dram,
make_v_dram(nullptr,
kargs.seqlen_k - (num_blocks - 1) * kargs.page_block_size));
(kv_l2p_offset + kargs.seqlen_k) -
(num_blocks - 1) * kargs.page_block_size));
}
else
{
......@@ -733,7 +773,7 @@ struct FmhaFwdSplitKVKernel
[&]() {
if constexpr(FmhaPipeline::kQLoadOnce)
return make_tuple(number<FmhaPipeline::kM0>{},
number<FmhaPipeline::kK0BlockLength>{});
number<FmhaPipeline::kSubQKHeaddim>{});
else
return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{});
}(),
......@@ -764,9 +804,8 @@ struct FmhaFwdSplitKVKernel
number<FmhaPipeline::kAlignmentBias>{},
number<1>{});
return pad_tensor_view(bias_dram_naive,
bias_dram_window_lengths,
sequence<kPadSeqLenQ, kPadSeqLenK>{});
return pad_tensor_view(
bias_dram_naive, bias_dram_window_lengths, sequence<false, kPadSeqLenK>{});
}();
return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0});
......@@ -868,6 +907,7 @@ struct FmhaFwdSplitKVKernel
mask,
position_encoding,
kargs.scale_s,
kv_l2p_offset,
smem_ptr);
}
else
......@@ -884,6 +924,7 @@ struct FmhaFwdSplitKVKernel
mask,
position_encoding,
kargs.scale_s,
kv_l2p_offset,
smem_ptr);
}
}();
......@@ -894,7 +935,7 @@ struct FmhaFwdSplitKVKernel
o_acc_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_v),
make_tuple(kargs.stride_o_acc, 1),
number<1>{},
number<FmhaPipeline::kAlignmentOacc>{},
number<1>{});
return pad_tensor_view(
......
......@@ -18,16 +18,16 @@ struct FmhaFwdSplitKVTilePartitioner
static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1;
static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1;
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size,
ck_tile::index_t nhead,
ck_tile::index_t max_seqlen_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_splits)
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
ck_tile::index_t nhead,
ck_tile::index_t max_seqlen_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_splits)
{
// TODO: this may need tuning
return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, kM0) *
ck_tile::integer_divide_ceil(hdim_v, kN1),
nhead * num_splits,
ck_tile::integer_divide_ceil(hdim_v, kN1) * num_splits,
nhead,
batch_size);
}
......@@ -42,8 +42,9 @@ struct FmhaFwdSplitKVTilePartitioner
return ck_tile::make_tuple(quotient, modulus);
};
const auto [i_tile_m, i_tile_n] = f(blockIdx.x, num_tile_n1);
const auto [i_nhead, i_split] = f(blockIdx.y, num_splits);
const auto [mn, i_split] = f(blockIdx.x, num_splits);
const auto [i_tile_m, i_tile_n] = f(mn, num_tile_n1);
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch);
......
......@@ -178,13 +178,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor<Problem>());
auto k_lds_write_window =
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
auto k_lds_read_window =
make_tile_window(k_lds_write_window.get_bottom_tensor_view(),
make_tuple(number<kN0>{}, number<kK0>{}),
k_lds_write_window.get_window_origin(),
Policy::template MakeKRegSliceBlockDescriptor<Problem>());
Policy::template MakeKRegBlockDescriptor<Problem>());
auto k_reg_tensor = make_static_distributed_tensor<KDataType>(
Policy::template MakeKRegBlockDescriptor<Problem>());
......@@ -204,16 +204,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor<Problem>());
auto v_lds_write_window =
make_tile_window(v_lds, make_tuple(number<kN0>{}, number<kK2>{}), {0, 0});
make_tile_window(v_lds, make_tuple(number<kN0>{}, number<kVHeaddim>{}), {0, 0});
auto v_lds_read_window =
make_tile_window(v_lds_write_window.get_bottom_tensor_view(),
make_tuple(number<kN0>{}, number<kK2>{}),
v_lds_write_window.get_window_origin(),
Policy::template MakeVRegSliceBlockDescriptor<Problem>());
auto v_reg_tensor = make_static_distributed_tensor<VDataType>(
Policy::template MakeVRegBlockDescriptor<Problem>());
Policy::template MakeVRegBlockDescriptor<Problem>());
//------------------------------------------------------------------
// KT, Reg ->LDS ->Reg
......@@ -227,7 +224,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
kt_lds_ptr, Policy::template MakeShuffledKLdsWriteBlockDescriptor<Problem>());
auto shuffled_k_lds_write_window = make_tile_window(
shuffled_k_lds_write, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
shuffled_k_lds_write, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
auto kt_lds_read = make_tensor_view<address_space_enum::lds>(
kt_lds_ptr, Policy::template MakeKTLdsReadBlockDescriptor<Problem>());
......@@ -257,7 +254,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
block_sync_lds();
v_reg_tensor = load_tile(v_lds_read_window);
auto v_reg_tensor = load_tile(v_lds_read_window);
block_sync_lds();
//---------------------------- Loop Load in ----------------------------//
// Q: HBM ->Reg ->LDS
......@@ -276,7 +273,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
auto q_lds_window =
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {0, 0});
auto q_lds_read_window =
make_tile_window(q_lds_window.get_bottom_tensor_view(),
......@@ -297,7 +294,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
qt_lds_ptr, Policy::template MakeShuffledQLdsWriteBlockDescriptor<Problem>());
auto shuffled_q_lds_write_window = make_tile_window(
shuffled_q_lds_write, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
shuffled_q_lds_write, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {0, 0});
auto qt_lds_read = make_tensor_view<address_space_enum::lds>(
qt_lds_ptr, Policy::template MakeQTLdsReadBlockDescriptor<Problem>());
......@@ -322,7 +319,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
auto do_lds_window =
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {0, 0});
auto do_lds_read_window =
make_tile_window(do_lds_window.get_bottom_tensor_view(),
......@@ -341,7 +338,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
dot_lds_ptr, Policy::template MakeShuffledOGradLdsWriteBlockDescriptor<Problem>());
auto shuffled_do_lds_write_window = make_tile_window(
shuffled_do_lds_write, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
shuffled_do_lds_write, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {0, 0});
auto dot_read_lds = make_tensor_view<address_space_enum::lds>(
dot_lds_ptr, Policy::template MakeOGradTLdsReadBlockDescriptor<Problem>());
......@@ -483,9 +480,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
index_t i_total_loops = 0;
index_t seqlen_q_step = seqlen_q_start;
static_assert(kQKHeaddim == kK0, "kQKHeaddim should equal to kK0");
static_assert(kQKHeaddim >= kK0, "kQKHeaddim should be equal or greater than kK0");
static_assert(kM0 == kK1, "kM0 should equal to kK1");
static_assert(kVHeaddim == kK2, "kVHeaddim should equal to kK2");
static_assert(kVHeaddim >= kK2, "kVHeaddim should be equal or greater than kK2");
static_assert(kM0 == kK3, "kM0 should equal to kK3");
constexpr index_t k4_loops = kN0 / kK4;
......
......@@ -178,13 +178,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor<Problem>());
auto k_lds_write_window =
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
auto k_lds_read_window =
make_tile_window(k_lds_write_window.get_bottom_tensor_view(),
make_tuple(number<kN0>{}, number<kK0>{}),
k_lds_write_window.get_window_origin(),
Policy::template MakeKRegSliceBlockDescriptor<Problem>());
Policy::template MakeKRegBlockDescriptor<Problem>());
auto k_reg_tensor = make_static_distributed_tensor<KDataType>(
Policy::template MakeKRegBlockDescriptor<Problem>());
......@@ -204,16 +204,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor<Problem>());
auto v_lds_write_window =
make_tile_window(v_lds, make_tuple(number<kN0>{}, number<kK2>{}), {0, 0});
make_tile_window(v_lds, make_tuple(number<kN0>{}, number<kVHeaddim>{}), {0, 0});
auto v_lds_read_window =
make_tile_window(v_lds_write_window.get_bottom_tensor_view(),
make_tuple(number<kN0>{}, number<kK2>{}),
v_lds_write_window.get_window_origin(),
Policy::template MakeVRegSliceBlockDescriptor<Problem>());
auto v_reg_tensor = make_static_distributed_tensor<VDataType>(
Policy::template MakeVRegBlockDescriptor<Problem>());
Policy::template MakeVRegBlockDescriptor<Problem>());
//------------------------------------------------------------------
// KT, Reg ->LDS ->Reg
......@@ -227,7 +224,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
kt_lds_ptr, Policy::template MakeShuffledKLdsWriteBlockDescriptor<Problem>());
auto shuffled_k_lds_write_window = make_tile_window(
shuffled_k_lds_write, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
shuffled_k_lds_write, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
auto kt_lds_read = make_tensor_view<address_space_enum::lds>(
kt_lds_ptr, Policy::template MakeKTLdsReadBlockDescriptor<Problem>());
......@@ -257,7 +254,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
block_sync_lds();
v_reg_tensor = load_tile(v_lds_read_window);
auto v_reg_tensor = load_tile(v_lds_read_window);
//---------------------------- Loop Load in ----------------------------//
// Q: HBM ->Reg ->LDS
auto q_dram_window =
......@@ -275,7 +272,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
auto q_lds_window =
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {0, 0});
auto q_lds_read_window =
make_tile_window(q_lds_window.get_bottom_tensor_view(),
......@@ -296,7 +293,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
qt_lds_ptr, Policy::template MakeShuffledQLdsWriteBlockDescriptor<Problem>());
auto shuffled_q_lds_write_window = make_tile_window(
shuffled_q_lds_write, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
shuffled_q_lds_write, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {0, 0});
auto qt_lds_read = make_tensor_view<address_space_enum::lds>(
qt_lds_ptr, Policy::template MakeQTLdsReadBlockDescriptor<Problem>());
......@@ -321,7 +318,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
auto do_lds_window =
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {0, 0});
auto do_lds_read_window =
make_tile_window(do_lds_window.get_bottom_tensor_view(),
......@@ -340,7 +337,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
dot_lds_ptr, Policy::template MakeShuffledOGradLdsWriteBlockDescriptor<Problem>());
auto shuffled_do_lds_write_window = make_tile_window(
shuffled_do_lds_write, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
shuffled_do_lds_write, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {0, 0});
auto dot_read_lds = make_tensor_view<address_space_enum::lds>(
dot_lds_ptr, Policy::template MakeOGradTLdsReadBlockDescriptor<Problem>());
......@@ -482,9 +479,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
index_t i_total_loops = 0;
index_t seqlen_q_step = seqlen_q_start;
static_assert(kQKHeaddim == kK0, "kQKHeaddim should equal to kK0");
static_assert(kQKHeaddim >= kK0, "kQKHeaddim should be equal or greater than kK0");
static_assert(kM0 == kK1, "kM0 should equal to kK1");
static_assert(kVHeaddim == kK2, "kVHeaddim should equal to kK2");
static_assert(kVHeaddim >= kK2, "kVHeaddim should be equal or greater than kK2");
static_assert(kM0 == kK3, "kM0 should equal to kK3");
constexpr index_t k4_loops = kN0 / kK4;
......
......@@ -5,9 +5,8 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
......@@ -27,20 +26,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{
using GemmProblem =
GemmPipelineProblem<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>,
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
typename Problem::BlockFmhaShape::Gemm0WarpTile>,
TileGemmTraits<Problem::kPadSeqLenQ,
Problem::kPadSeqLenK,
Problem::kPadHeadDimQ,
typename tensor_layout::gemm::RowMajor,
typename tensor_layout::gemm::ColumnMajor,
typename tensor_layout::gemm::RowMajor>>;
BlockGemmProblem<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>,
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
using WarpGemm = WarpGemmMfmaDispatcher<
typename Problem::QDataType,
......@@ -66,20 +60,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm()
{
using GemmProblem =
GemmPipelineProblem<typename Problem::GemmDataType,
typename Problem::OGradDataType,
typename Problem::AccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kVHeaddim,
Problem::BlockFmhaShape::kK1>,
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
typename Problem::BlockFmhaShape::Gemm1WarpTile>,
TileGemmTraits<Problem::kPadSeqLenQ,
Problem::kPadHeadDimV,
Problem::kPadHeadDimV,
typename tensor_layout::gemm::RowMajor,
typename tensor_layout::gemm::ColumnMajor,
typename tensor_layout::gemm::RowMajor>>;
BlockGemmProblem<typename Problem::GemmDataType,
typename Problem::OGradDataType,
typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<sequence<Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kVHeaddim,
Problem::BlockFmhaShape::kK1>,
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
......@@ -104,20 +93,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm()
{
using GemmProblem =
GemmPipelineProblem<typename Problem::OGradDataType,
typename Problem::VDataType,
typename Problem::AccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK2>,
typename Problem::BlockFmhaShape::Gemm2BlockWarps,
typename Problem::BlockFmhaShape::Gemm2WarpTile>,
TileGemmTraits<Problem::kPadSeqLenQ,
Problem::kPadSeqLenK,
Problem::kPadHeadDimQ,
typename tensor_layout::gemm::RowMajor,
typename tensor_layout::gemm::ColumnMajor,
typename tensor_layout::gemm::RowMajor>>;
BlockGemmProblem<typename Problem::OGradDataType,
typename Problem::VDataType,
typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK2>,
typename Problem::BlockFmhaShape::Gemm2BlockWarps,
typename Problem::BlockFmhaShape::Gemm2WarpTile>>;
using WarpGemm = WarpGemmMfmaDispatcher<
typename Problem::OGradDataType,
......@@ -143,20 +127,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetSGradTQTBlockGemm()
{
using GemmProblem =
GemmPipelineProblem<typename Problem::GemmDataType,
typename Problem::QDataType,
typename Problem::AccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kQKHeaddim,
Problem::BlockFmhaShape::kK3>,
typename Problem::BlockFmhaShape::Gemm3BlockWarps,
typename Problem::BlockFmhaShape::Gemm3WarpTile>,
TileGemmTraits<Problem::kPadSeqLenK,
Problem::kPadHeadDimQ,
Problem::kPadSeqLenK,
typename tensor_layout::gemm::RowMajor,
typename tensor_layout::gemm::ColumnMajor,
typename tensor_layout::gemm::RowMajor>>;
BlockGemmProblem<typename Problem::GemmDataType,
typename Problem::QDataType,
typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<sequence<Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kQKHeaddim,
Problem::BlockFmhaShape::kK3>,
typename Problem::BlockFmhaShape::Gemm3BlockWarps,
typename Problem::BlockFmhaShape::Gemm3WarpTile>>;
using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
......@@ -181,20 +160,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetSGradKTBlockGemm()
{
using GemmProblem =
GemmPipelineProblem<typename Problem::GemmDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kQKHeaddim,
Problem::BlockFmhaShape::kK4>,
typename Problem::BlockFmhaShape::Gemm4BlockWarps,
typename Problem::BlockFmhaShape::Gemm4WarpTile>,
TileGemmTraits<Problem::kPadSeqLenQ,
Problem::kPadHeadDimQ,
Problem::kPadSeqLenK,
typename tensor_layout::gemm::RowMajor,
typename tensor_layout::gemm::ColumnMajor,
typename tensor_layout::gemm::RowMajor>>;
BlockGemmProblem<typename Problem::GemmDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kQKHeaddim,
Problem::BlockFmhaShape::kK4>,
typename Problem::BlockFmhaShape::Gemm4BlockWarps,
typename Problem::BlockFmhaShape::Gemm4WarpTile>>;
using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
......@@ -222,7 +196,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
using QDataType = remove_cvref_t<typename Problem::QDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kMaxVecLoad = 16 / sizeof(QDataType);
constexpr index_t kMinVecLoad = 4 / sizeof(QDataType);
......@@ -241,7 +215,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
using KDataType = remove_cvref_t<typename Problem::KDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kMaxVecLoad = 16 / sizeof(KDataType);
constexpr index_t kMinVecLoad = 4 / sizeof(KDataType);
......@@ -260,7 +234,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
using VDataType = remove_cvref_t<typename Problem::VDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t kMaxVecLoad = 16 / sizeof(VDataType);
constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize;
......@@ -280,7 +254,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t kMaxVecLoad = 16 / sizeof(OGradDataType);
constexpr index_t kMinVecLoad = 4 / sizeof(OGradDataType);
......@@ -341,7 +315,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
......@@ -353,7 +327,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
return total_pixels / GetAlignmentK<Problem>();
......@@ -364,7 +338,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
......@@ -402,7 +376,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t K1 = GetAlignmentK<Problem>();
constexpr index_t K0 = kKPerBlock / K1;
......@@ -425,7 +399,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t K1 = GetAlignmentV<Problem>();
constexpr index_t K0 = kKPerBlock / K1;
......@@ -448,7 +422,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t K1 = GetAlignmentQ<Problem>();
constexpr index_t K0 = kKPerBlock / K1;
......@@ -471,7 +445,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t K1 = GetAlignmentOGrad<Problem>();
constexpr index_t K0 = kKPerBlock / K1;
......@@ -842,44 +816,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsWriteBlockDescriptor()
{
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kKPack = GetSmemKPackK<Problem>();
return MakeXLdsBlockDescriptor<kNPerBlock, kKPerBlock, kKPack>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKRegSliceBlockDescriptor()
{
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
constexpr auto k_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto k_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
k_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
constexpr auto k_block_dstr = make_static_tile_distribution(k_block_dstr_encode);
return k_block_dstr;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKRegBlockDescriptor()
{
......@@ -891,7 +833,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
......@@ -916,45 +858,13 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsWriteBlockDescriptor()
{
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t kVPack = GetSmemKPackV<Problem>();
return MakeXLdsBlockDescriptor<kNPerBlock, kKPerBlock, kVPack>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeVRegSliceBlockDescriptor()
{
using BlockGemm = remove_cvref_t<decltype(GetOGradVBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{});
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
constexpr auto v_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
v_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
constexpr auto v_block_dstr = make_static_tile_distribution(v_block_dstr_encode);
return v_block_dstr;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeVRegBlockDescriptor()
{
......@@ -966,7 +876,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{});
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
......@@ -992,7 +902,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t K1 = GetAlignmentK<Problem>();
constexpr index_t K0 = kKPerBlock / K1;
......@@ -1074,7 +984,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor()
{
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
......@@ -1118,7 +1028,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t K1 = GetAlignmentQ<Problem>();
constexpr index_t K0 = kKPerBlock / K1;
......@@ -1281,7 +1191,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
// Hold full block data
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
......@@ -1325,7 +1235,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t K1 = GetAlignmentOGrad<Problem>();
constexpr index_t K0 = kKPerBlock / K1;
......@@ -1885,6 +1795,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy
static constexpr index_t kN0 = Problem::BlockFmhaShape::kN0;
static constexpr index_t kQKHeaddim = Problem::BlockFmhaShape::kQKHeaddim;
static constexpr index_t kVHeaddim = Problem::BlockFmhaShape::kVHeaddim;
static constexpr index_t kK0 = Problem::BlockFmhaShape::kK0;
static constexpr index_t kK2 = Problem::BlockFmhaShape::kK2;
static constexpr index_t kK4 = Problem::BlockFmhaShape::kK4;
static constexpr index_t WarpGemmM =
......@@ -1899,14 +1811,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy
// Compute
static constexpr index_t Gemm0MFMA =
kM0 * kN0 * kQKHeaddim /
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
kM0 * kN0 * kK0 / (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
static constexpr index_t Gemm1MFMA =
kM0 * kN0 * kVHeaddim /
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
static constexpr index_t Gemm2MFMA =
kN0 * kVHeaddim * kM0 /
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
static constexpr index_t Gemm2MFMA =
kM0 * kN0 * kK2 / (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
static constexpr index_t Gemm3MFMA =
kN0 * kQKHeaddim * kM0 /
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
......@@ -1929,13 +1839,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy
kM0 * kQKHeaddim / get_warp_size() / GetTransposedAlignmentQ<Problem>();
static constexpr index_t SGradT_LDS_READ_P1 =
kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
static constexpr index_t Q_LDS_READ =
kM0 * kQKHeaddim / kBlockSize / GetAlignmentQ<Problem>();
static constexpr index_t Q_LDS_READ = kM0 * kK0 / kBlockSize / GetAlignmentQ<Problem>();
static constexpr index_t LSE_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4);
static constexpr index_t SGradT_LDS_READ_P2 =
kM0 * (kN0 - kK4) / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
static constexpr index_t OGrad_LDS_READ =
kM0 * kVHeaddim / kBlockSize / GetAlignmentOGrad<Problem>();
kM0 * kK2 / kBlockSize / GetAlignmentOGrad<Problem>();
static constexpr index_t D_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4);
// LDS Write
......
......@@ -12,6 +12,16 @@ namespace detail {
template <index_t N>
struct log2;
template <>
struct log2<4> : std::integral_constant<index_t, 2>
{
};
template <>
struct log2<8> : std::integral_constant<index_t, 3>
{
};
template <>
struct log2<16> : std::integral_constant<index_t, 4>
{
......@@ -72,18 +82,18 @@ struct BlockFmhaFwdSplitKVCombinePipeline
{
if constexpr(kHeadDimV <= 32)
{
constexpr std::array<int, 4> occupancy{3, 3, 3, 1};
return occupancy[detail::log2<kMaxSplits>::value - 4];
constexpr std::array occupancy{3, 3, 3, 3, 3, 1};
return occupancy[detail::log2<kMaxSplits>::value - 2];
}
else if constexpr(kHeadDimV <= 128)
{
constexpr std::array<int, 4> occupancy{3, 3, 2, 1};
return occupancy[detail::log2<kMaxSplits>::value - 4];
constexpr std::array occupancy{3, 3, 3, 3, 2, 1};
return occupancy[detail::log2<kMaxSplits>::value - 2];
}
else if constexpr(kHeadDimV <= 256)
{
constexpr std::array<int, 4> occupancy{2, 2, 2, 1};
return occupancy[detail::log2<kMaxSplits>::value - 4];
constexpr std::array occupancy{2, 2, 2, 2, 2, 1};
return occupancy[detail::log2<kMaxSplits>::value - 2];
}
}
}();
......@@ -138,9 +148,8 @@ struct BlockFmhaFwdSplitKVCombinePipeline
auto lse_accum = make_static_distributed_tensor<LSEDataType>(
Policy::template MakeLSEaccRegTileDistribution<Problem>());
// copy LDS (shape=[kM0, kMaxSplits]) to lse_accum (shape=[kM0, max(kMaxSplits, warp_size)])
// this will extend the distributed tensor width so that each thread in wave have data to
// reduce.
// copy LDS (shape=[kM0, kMaxSplits]) to lse_accum (shape=[kM0, kMaxSplits])
// and fill up -INF values outside the [kM0, num_splits] region.
{
constexpr auto spans = decltype(lse_accum)::get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
......
......@@ -10,11 +10,26 @@ namespace ck_tile {
struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
{
template <index_t BlockSize, index_t M, index_t N, typename DataType>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeForTile()
{
constexpr index_t PixelsPerThread = (M * N) / BlockSize;
static_assert(0 < PixelsPerThread);
constexpr index_t MaxNPerThread = 16 / sizeof(DataType);
constexpr index_t NPerThread = min(MaxNPerThread, PixelsPerThread);
return NPerThread;
}
// alignment for dram lse tile (shape=[kMaxSplits, kM0])
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentLSE()
{
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
return 16 / sizeof(LSEDataType);
return GetVectorSizeForTile<Problem::kBlockSize,
Problem::kMaxSplits,
Problem::kM0,
typename Problem::LSEDataType>();
}
template <typename Problem>
......@@ -47,29 +62,31 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
MakeLSEaccLdsBlockDescriptor<Problem>().get_element_space_size();
}
// shape=[kMaxSplits, kM0]
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccDramTileDistribution()
{
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNumWarps = Problem::kNumWarps;
constexpr index_t kNPerBlock = Problem::kM0;
constexpr index_t kMPerBlock = Problem::kMaxSplits;
constexpr index_t NPerThread = 16 / sizeof(LSEDataType);
constexpr index_t NThreads = kNPerBlock / NPerThread;
constexpr index_t NPerThread =
GetVectorSizeForTile<kBlockSize, kMPerBlock, kNPerBlock, LSEDataType>();
constexpr index_t NThreads = kNPerBlock / NPerThread;
constexpr index_t MThreadsPerWarp = get_warp_size() / NThreads;
constexpr index_t TotalWarps = kBlockSize / get_warp_size();
constexpr index_t MPerThread = kMPerBlock / (TotalWarps * MThreadsPerWarp);
constexpr index_t MPerThread = kMPerBlock / (kNumWarps * MThreadsPerWarp);
static_assert(NThreads * NPerThread == kNPerBlock);
static_assert(MPerThread * TotalWarps * MThreadsPerWarp == kMPerBlock);
static_assert(MPerThread * kNumWarps * MThreadsPerWarp == kMPerBlock);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<MPerThread, TotalWarps, MThreadsPerWarp>,
tuple<sequence<MPerThread, kNumWarps, MThreadsPerWarp>,
sequence<NThreads, NPerThread>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
......@@ -77,15 +94,18 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
sequence<0, 1>>{});
}
// 3d + padding, [kMaxSplits, kM0]
// 3d + padding, shape=[kMaxSplits, kM0]
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccLdsStoreBlockDescriptor()
{
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::kMaxSplits;
constexpr index_t kNPerBlock = Problem::kM0;
constexpr index_t NPack = 16 / sizeof(LSEDataType);
constexpr index_t NPack =
GetVectorSizeForTile<kBlockSize, kMPerBlock, kNPerBlock, LSEDataType>();
constexpr auto lse_acc_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kNPerBlock / NPack>{}, number<kMPerBlock>{}, number<NPack>{}),
......@@ -103,15 +123,18 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
return lse_acc_lds_block_desc;
}
// 3d + padding, [kM0, kMaxSplits]
// 3d + padding, shape=[kM0, kMaxSplits]
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccLdsBlockDescriptor()
{
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::kMaxSplits;
constexpr index_t kNPerBlock = Problem::kM0;
constexpr index_t NPack = 16 / sizeof(LSEDataType);
constexpr index_t NPack =
GetVectorSizeForTile<kBlockSize, kMPerBlock, kNPerBlock, LSEDataType>();
constexpr auto lse_acc_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kNPerBlock / NPack>{}, number<kMPerBlock>{}, number<NPack>{}),
......@@ -134,26 +157,28 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = max(Problem::kMaxSplits, get_warp_size());
constexpr index_t kNPerBlock = Problem::kMaxSplits;
constexpr index_t kMPerBlock = Problem::kM0;
constexpr index_t NThreads = get_warp_size();
constexpr index_t NThreads = 4;
constexpr index_t NPerThread = kNPerBlock / NThreads;
constexpr index_t MThreads = kBlockSize / NThreads;
constexpr index_t MPerThread = kMPerBlock / MThreads;
constexpr index_t MThreads = kBlockSize / NThreads;
constexpr index_t MPerThread = kMPerBlock / MThreads;
constexpr index_t MWarps = kBlockSize / get_warp_size();
constexpr index_t MThreadPerWarp = get_warp_size() / NThreads;
static_assert(NThreads * NPerThread == kNPerBlock);
static_assert(MThreads * MPerThread == kMPerBlock);
static_assert(MWarps * MThreadPerWarp * MPerThread == kMPerBlock);
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,
tuple<sequence<MThreads, MPerThread>, sequence<NThreads, NPerThread>>,
tuple<sequence<1>, sequence<2>>,
tuple<sequence<0>, sequence<0>>,
tuple<sequence<MWarps, MThreadPerWarp, MPerThread>, sequence<NThreads, NPerThread>>,
tuple<sequence<1>, sequence<2, 1>>,
tuple<sequence<0>, sequence<0, 1>>,
sequence<1, 2>,
sequence<1, 1>>{});
sequence<2, 1>>{});
}
template <typename Problem>
......
......@@ -25,6 +25,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
using PDataType = remove_cvref_t<typename Problem::PDataType>;
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
......@@ -34,12 +35,13 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = BlockFmhaShape::kM0;
static constexpr index_t kN0 = BlockFmhaShape::kN0;
static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kN1 = BlockFmhaShape::kN1;
static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength;
static constexpr index_t kM0 = BlockFmhaShape::kM0;
static constexpr index_t kN0 = BlockFmhaShape::kN0;
static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kN1 = BlockFmhaShape::kN1;
static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
......@@ -47,7 +49,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = true; // always store LSE (acc)
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits;
......@@ -64,6 +66,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
}();
static constexpr index_t kAlignmentOacc =
kPadHeadDimV ? 1 : Policy::template GetAlignmentOacc<Problem>();
static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
......@@ -72,22 +77,22 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
return Problem::kBlockPerCu;
else
{
if constexpr(kK0BlockLength <= 32)
if constexpr(kQKHeaddim <= 32)
{
return 2;
}
else if constexpr(kK0BlockLength <= 64)
else if constexpr(kQKHeaddim <= 64)
{
return 3;
}
else if constexpr(kK0BlockLength <= 128)
else if constexpr(kQKHeaddim <= 128)
{
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
return 1;
else
return 2;
}
else if constexpr(kK0BlockLength <= 256)
else if constexpr(kQKHeaddim <= 256)
{
return 1;
}
......@@ -138,6 +143,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
void* smem_ptr) const
{
static_assert(
......@@ -206,16 +212,16 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
const auto q_origin = q_dram_window.get_window_origin();
const auto [seqlen_k_start, seqlen_k_end] = mask.GetTileRangeAlongX(
const auto q_origin = q_dram_window.get_window_origin();
const auto [logical_seqlen_k_start, logical_seqlen_k_end] = mask.GetTileRangeAlongX(
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
// check early exit if no work to do
if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits)
{
const index_t original_num_total_loop =
integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
if(original_num_total_loop <= 0)
const index_t logical_num_total_loop =
integer_divide_ceil(logical_seqlen_k_end - logical_seqlen_k_start, kN0);
if(logical_num_total_loop <= 0)
{
if constexpr(kStoreLSE)
{
......@@ -234,40 +240,48 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
}
}
// make sure the first tile is completely located in page-block
const index_t adjusted_seqlen_k_start = [&, seqlen_k_start_ = seqlen_k_start] {
if constexpr(kIsPagedKV)
{
return kN0 * integer_divide_floor(seqlen_k_start_, kN0);
}
else
{
return seqlen_k_start_;
}
}();
const index_t physical_seqlen_k_start = logical_seqlen_k_start + kv_l2p_offset;
const index_t physical_seqlen_k_end = logical_seqlen_k_end + kv_l2p_offset;
// make sure the first tile is completely located in page-block (page-block size should be
// divisible by kN0)
// relationship between each *_start variables: aligned_physical_seqlen_k_start <=
// physical_seqlen_k_start, logical_seqlen_k_start <= physical_seqlen_k_start
const index_t aligned_physical_seqlen_k_start =
[&, physical_seqlen_k_start_ = physical_seqlen_k_start] {
if constexpr(kIsPagedKV)
{
return kN0 * integer_divide_floor(physical_seqlen_k_start_, kN0);
}
else
{
return physical_seqlen_k_start_;
}
}();
const index_t num_total_loop =
integer_divide_ceil(seqlen_k_end - adjusted_seqlen_k_start, kN0);
integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0);
auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window(
k_dram_block_window_lengths, {adjusted_seqlen_k_start, 0});
k_dram_block_window_lengths, {aligned_physical_seqlen_k_start, 0});
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_window = make_tile_window(
bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(),
{bias_origin.at(number<0>{}), adjusted_seqlen_k_start}, // M/N
Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
auto bias_dram_window =
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(),
{bias_origin.at(number<0>{}),
logical_seqlen_k_start - (physical_seqlen_k_start -
aligned_physical_seqlen_k_start)}, // M/N
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
v_dram_block_window_lengths,
{0, adjusted_seqlen_k_start}, // TODO: hdim split?
{0, aligned_physical_seqlen_k_start}, // TODO: hdim split?
Policy::template MakeVDramTileDistribution<Problem>());
auto q_tile = tile_elementwise_in(q_element_func, q);
// prefetch K tile
index_t i_total_loops = 0;
constexpr index_t k0_loops = kK0BlockLength / kK0;
constexpr index_t k0_loops = kQKHeaddim / kK0;
constexpr index_t k1_loops = kN0 / kK1;
static_assert(2 <= k0_loops);
......@@ -374,7 +388,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
constexpr auto i_j_idx = make_tuple(idx0, idx1);
s_acc(i_j_idx) *= scale_s;
position_encoding.update(s_acc(i_j_idx), row, col);
// position_encoding accept only logical coordinates, do conversion here
position_encoding.update(s_acc(i_j_idx), row, col - kv_l2p_offset);
});
});
}
......@@ -392,29 +407,31 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
{
const auto k_origin = k_page_block_navigator.to_global_window_origin(
i_page_block_k, k_dram_block_window.get_window_origin());
set_tile_if(s_acc,
-numeric<SMPLComputeDataType>::infinity(),
[&, seqlen_k_start_ = seqlen_k_start, seqlen_k_end_ = seqlen_k_end](
auto tile_idx) {
const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
if constexpr(kIsPagedKV)
{
return col < seqlen_k_start_ || seqlen_k_end_ <= col;
}
else
{
return seqlen_k_end_ <= col;
}
});
set_tile_if(
s_acc,
-numeric<SMPLComputeDataType>::infinity(),
[&,
physical_seqlen_k_start_ = physical_seqlen_k_start,
physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) {
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
if constexpr(kIsPagedKV)
{
return col < physical_seqlen_k_start_ || physical_seqlen_k_end_ <= col;
}
else
{
return physical_seqlen_k_end_ <= col;
}
});
}
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{
const auto k_origin = k_page_block_navigator.to_global_window_origin(
i_page_block_k, k_dram_block_window.get_window_origin());
// mask accept only logical coordinates, do conversion here
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
k_origin.at(number<0>{}),
k_origin.at(number<0>{}) - kv_l2p_offset,
number<kM0>{},
number<kN0>{});
if(need_perpixel_check)
......@@ -423,7 +440,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
return mask.IsOutOfBound(row, col - kv_l2p_offset);
});
}
}
......@@ -654,6 +671,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
void* smem_ptr) const
{
return operator()(q_dram_block_window_tmp,
......@@ -676,6 +694,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
mask,
position_encoding,
scale_s,
kv_l2p_offset,
smem_ptr);
}
};
......
......@@ -9,11 +9,20 @@
namespace ck_tile {
// This pipeline is qkv all located in LDS
using BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy =
BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
/* AsyncCopyK = */ false,
/* AsyncCopyV = */ false,
/* NumPrefetchK = */ 1,
/* NumPrefetchV = */ 1>;
struct BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy
: BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
/* AsyncCopyK = */ false,
/* AsyncCopyV = */ false,
/* NumPrefetchK = */ 1,
/* NumPrefetchV = */ 1>
{
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOacc()
{
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
return static_cast<index_t>(16 / sizeof(OaccDataType));
}
};
} // namespace ck_tile
......@@ -39,8 +39,11 @@ struct BlockFmhaPipelineProblem
using FmhaMask = remove_cvref_t<FmhaMask_>;
using Traits = remove_cvref_t<Traits_>;
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr index_t kNumGemm0Warps = BlockFmhaShape::NumGemm0Warps;
static constexpr index_t kNumGemm1Warps = BlockFmhaShape::NumGemm1Warps;
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
static constexpr bool kIsGroupMode = kIsGroupMode_;
// attributes from traits
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
......@@ -84,8 +87,11 @@ struct BlockFmhaFwdSplitKVPipelineProblem
using FmhaMask = remove_cvref_t<FmhaMask_>;
using Traits = remove_cvref_t<Traits_>;
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr index_t kNumGemm0Warps = BlockFmhaShape::NumGemm0Warps;
static constexpr index_t kNumGemm1Warps = BlockFmhaShape::NumGemm1Warps;
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
static constexpr bool kIsGroupMode = kIsGroupMode_;
// attributes from traits
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
......@@ -115,7 +121,8 @@ struct BlockFmhaSplitKVCombinePipelineProblem
using ODataType = remove_cvref_t<ODataType_>;
using Traits = remove_cvref_t<Traits_>;
static constexpr index_t kBlockSize = 256;
static constexpr index_t kNumWarps = kM0_ / (get_warp_size() / 4);
static constexpr index_t kBlockSize = kNumWarps * get_warp_size();
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr index_t kHeadDimV = HeadDimV_;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment