Commit 8af425fc authored by so's avatar so
Browse files

fix v0 and wave id calc

parent 339a674b
......@@ -90,7 +90,7 @@ auto create_args(int argc, char* argv[])
arg_parser.insert("t", "32", "num input tokens")
.insert("e", "1", "num of experts")
.insert("k", "1", "topk")
.insert("h", "256", "hidden_size of this model")
.insert("h", "512", "hidden_size of this model")
.insert("i", "4096", "intermediate_size between 2 gemms of FFN")
.insert("stride", "-1", "stride per row, if -1 then equal to hidden_size")
.insert("bm", "32", "blocking factor for sorted tokens")
......
......@@ -27,12 +27,9 @@
# define _UK_ATOMIC_ADD_ "global_atomic_pk_add_f16"
#endif
" v_and_b32 v0, 0x3f, v0 \n"
" v_lshrrev_b32 v3, 6, v0 \n"
" v_readfirstlane_b32 s7, v3 \n"
" s_waitcnt vmcnt(24) \n"
"buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[12:15], 0 offen\n"
"buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[12:15], 0 offen offset:1024\n"
" buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[12:15], 0 offen\n"
" buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[12:15], 0 offen offset:1024\n"
" v_mul_f32 v54, v128, v128 \n"
" v_mul_f32 v55, v129, v129 \n"
" v_mul_f32 v56, v130, v130 \n"
......@@ -65,7 +62,7 @@
" v_mul_f32 v129, v129, v55 \n"
" v_mul_f32 v130, v130, v56 \n"
" v_mul_f32 v131, v131, v57 \n"
"buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[12:15], 0 offen offset:2048\n"
" buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[12:15], 0 offen offset:2048\n"
" v_mul_f32 v54, v132, v132 \n"
" v_mul_f32 v55, v133, v133 \n"
" v_mul_f32 v56, v134, v134 \n"
......@@ -86,7 +83,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
"buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[12:15], 0 offen offset:3072\n"
" buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[12:15], 0 offen offset:3072\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -99,7 +96,7 @@
" v_mul_f32 v133, v133, v55 \n"
" v_mul_f32 v134, v134, v56 \n"
" v_mul_f32 v135, v135, v57 \n"
"buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[12:15], 0 offen\n"
" buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[12:15], 0 offen\n"
" v_mul_f32 v54, v136, v136 \n"
" v_mul_f32 v55, v137, v137 \n"
" v_mul_f32 v56, v138, v138 \n"
......@@ -120,7 +117,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
"buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[12:15], 0 offen offset:1024\n"
" buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[12:15], 0 offen offset:1024\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -133,7 +130,7 @@
" v_mul_f32 v137, v137, v55 \n"
" v_mul_f32 v138, v138, v56 \n"
" v_mul_f32 v139, v139, v57 \n"
"buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[12:15], 0 offen offset:2048\n"
" buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[12:15], 0 offen offset:2048\n"
" v_mul_f32 v54, v140, v140 \n"
" v_mul_f32 v55, v141, v141 \n"
" v_mul_f32 v56, v142, v142 \n"
......@@ -154,7 +151,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
"buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[12:15], 0 offen offset:3072\n"
" buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[12:15], 0 offen offset:3072\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -168,7 +165,7 @@
" v_mul_f32 v142, v142, v56 \n"
" v_mul_f32 v143, v143, v57 \n"
" s_waitcnt vmcnt(24) \n"
"buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[12:15], 0 offen\n"
" buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[12:15], 0 offen\n"
" v_mul_f32 v54, v144, v144 \n"
" v_mul_f32 v55, v145, v145 \n"
" v_mul_f32 v56, v146, v146 \n"
......@@ -189,7 +186,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
"buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[12:15], 0 offen offset:1024\n"
" buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[12:15], 0 offen offset:1024\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -202,7 +199,7 @@
" v_mul_f32 v145, v145, v55 \n"
" v_mul_f32 v146, v146, v56 \n"
" v_mul_f32 v147, v147, v57 \n"
"buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[12:15], 0 offen offset:2048\n"
" buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[12:15], 0 offen offset:2048\n"
" v_mul_f32 v54, v148, v148 \n"
" v_mul_f32 v55, v149, v149 \n"
" v_mul_f32 v56, v150, v150 \n"
......@@ -223,7 +220,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
"buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[12:15], 0 offen offset:3072\n"
" buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[12:15], 0 offen offset:3072\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -236,7 +233,7 @@
" v_mul_f32 v149, v149, v55 \n"
" v_mul_f32 v150, v150, v56 \n"
" v_mul_f32 v151, v151, v57 \n"
"buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[12:15], 0 offen\n"
" buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[12:15], 0 offen\n"
" v_mul_f32 v54, v152, v152 \n"
" v_mul_f32 v55, v153, v153 \n"
" v_mul_f32 v56, v154, v154 \n"
......@@ -257,7 +254,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
"buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[12:15], 0 offen offset:1024\n"
" buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[12:15], 0 offen offset:1024\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -270,7 +267,7 @@
" v_mul_f32 v153, v153, v55 \n"
" v_mul_f32 v154, v154, v56 \n"
" v_mul_f32 v155, v155, v57 \n"
"buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[12:15], 0 offen offset:2048\n"
" buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[12:15], 0 offen offset:2048\n"
" v_mul_f32 v54, v156, v156 \n"
" v_mul_f32 v55, v157, v157 \n"
" v_mul_f32 v56, v158, v158 \n"
......@@ -291,7 +288,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
"buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[12:15], 0 offen offset:3072\n"
" buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[12:15], 0 offen offset:3072\n"
" s_add_u32 s12, %[s_tile_os_b_half], s12 \n"
" s_addc_u32 s13, 0, s13 \n"
" v_add_f32 v54, v54, 1.0 \n"
......@@ -307,7 +304,7 @@
" v_mul_f32 v158, v158, v56 \n"
" v_mul_f32 v159, v159, v57 \n"
" s_waitcnt vmcnt(24) \n"
"buffer_load_dwordx4 acc[64:67], %[v_os_b0], s[12:15], 0 offen\n"
" buffer_load_dwordx4 acc[64:67], %[v_os_b0], s[12:15], 0 offen\n"
" v_mul_f32 v54, v160, v160 \n"
" v_mul_f32 v55, v161, v161 \n"
" v_mul_f32 v56, v162, v162 \n"
......@@ -328,7 +325,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
"buffer_load_dwordx4 acc[68:71], %[v_os_b0], s[12:15], 0 offen offset:1024\n"
" buffer_load_dwordx4 acc[68:71], %[v_os_b0], s[12:15], 0 offen offset:1024\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -341,7 +338,7 @@
" v_mul_f32 v161, v161, v55 \n"
" v_mul_f32 v162, v162, v56 \n"
" v_mul_f32 v163, v163, v57 \n"
"buffer_load_dwordx4 acc[72:75], %[v_os_b0], s[12:15], 0 offen offset:2048\n"
" buffer_load_dwordx4 acc[72:75], %[v_os_b0], s[12:15], 0 offen offset:2048\n"
" v_mul_f32 v54, v164, v164 \n"
" v_mul_f32 v55, v165, v165 \n"
" v_mul_f32 v56, v166, v166 \n"
......@@ -362,7 +359,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
"buffer_load_dwordx4 acc[76:79], %[v_os_b0], s[12:15], 0 offen offset:3072\n"
" buffer_load_dwordx4 acc[76:79], %[v_os_b0], s[12:15], 0 offen offset:3072\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -375,7 +372,7 @@
" v_mul_f32 v165, v165, v55 \n"
" v_mul_f32 v166, v166, v56 \n"
" v_mul_f32 v167, v167, v57 \n"
"buffer_load_dwordx4 acc[80:83], %[v_os_b1], s[12:15], 0 offen\n"
" buffer_load_dwordx4 acc[80:83], %[v_os_b1], s[12:15], 0 offen\n"
" v_mul_f32 v54, v168, v168 \n"
" v_mul_f32 v55, v169, v169 \n"
" v_mul_f32 v56, v170, v170 \n"
......@@ -396,7 +393,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
"buffer_load_dwordx4 acc[84:87], %[v_os_b1], s[12:15], 0 offen offset:1024\n"
" buffer_load_dwordx4 acc[84:87], %[v_os_b1], s[12:15], 0 offen offset:1024\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -409,7 +406,7 @@
" v_mul_f32 v169, v169, v55 \n"
" v_mul_f32 v170, v170, v56 \n"
" v_mul_f32 v171, v171, v57 \n"
"buffer_load_dwordx4 acc[88:91], %[v_os_b1], s[12:15], 0 offen offset:2048\n"
" buffer_load_dwordx4 acc[88:91], %[v_os_b1], s[12:15], 0 offen offset:2048\n"
" v_mul_f32 v54, v172, v172 \n"
" v_mul_f32 v55, v173, v173 \n"
" v_mul_f32 v56, v174, v174 \n"
......@@ -430,7 +427,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
"buffer_load_dwordx4 acc[92:95], %[v_os_b1], s[12:15], 0 offen offset:3072\n"
" buffer_load_dwordx4 acc[92:95], %[v_os_b1], s[12:15], 0 offen offset:3072\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -444,7 +441,7 @@
" v_mul_f32 v174, v174, v56 \n"
" v_mul_f32 v175, v175, v57 \n"
" s_waitcnt vmcnt(24) \n"
"buffer_load_dwordx4 acc[96:99], %[v_os_b2], s[12:15], 0 offen\n"
" buffer_load_dwordx4 acc[96:99], %[v_os_b2], s[12:15], 0 offen\n"
" v_mul_f32 v54, v176, v176 \n"
" v_mul_f32 v55, v177, v177 \n"
" v_mul_f32 v56, v178, v178 \n"
......@@ -465,7 +462,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
"buffer_load_dwordx4 acc[100:103], %[v_os_b2], s[12:15], 0 offen offset:1024\n"
" buffer_load_dwordx4 acc[100:103], %[v_os_b2], s[12:15], 0 offen offset:1024\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -478,7 +475,7 @@
" v_mul_f32 v177, v177, v55 \n"
" v_mul_f32 v178, v178, v56 \n"
" v_mul_f32 v179, v179, v57 \n"
"buffer_load_dwordx4 acc[104:107], %[v_os_b2], s[12:15], 0 offen offset:2048\n"
" buffer_load_dwordx4 acc[104:107], %[v_os_b2], s[12:15], 0 offen offset:2048\n"
" v_mul_f32 v54, v180, v180 \n"
" v_mul_f32 v55, v181, v181 \n"
" v_mul_f32 v56, v182, v182 \n"
......@@ -499,7 +496,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
"buffer_load_dwordx4 acc[108:111], %[v_os_b2], s[12:15], 0 offen offset:3072\n"
" buffer_load_dwordx4 acc[108:111], %[v_os_b2], s[12:15], 0 offen offset:3072\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -512,7 +509,7 @@
" v_mul_f32 v181, v181, v55 \n"
" v_mul_f32 v182, v182, v56 \n"
" v_mul_f32 v183, v183, v57 \n"
"buffer_load_dwordx4 acc[112:115], %[v_os_b3], s[12:15], 0 offen\n"
" buffer_load_dwordx4 acc[112:115], %[v_os_b3], s[12:15], 0 offen\n"
" v_mul_f32 v54, v184, v184 \n"
" v_mul_f32 v55, v185, v185 \n"
" v_mul_f32 v56, v186, v186 \n"
......@@ -533,7 +530,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
"buffer_load_dwordx4 acc[116:119], %[v_os_b3], s[12:15], 0 offen offset:1024\n"
" buffer_load_dwordx4 acc[116:119], %[v_os_b3], s[12:15], 0 offen offset:1024\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -546,7 +543,7 @@
" v_mul_f32 v185, v185, v55 \n"
" v_mul_f32 v186, v186, v56 \n"
" v_mul_f32 v187, v187, v57 \n"
"buffer_load_dwordx4 acc[120:123], %[v_os_b3], s[12:15], 0 offen offset:2048\n"
" buffer_load_dwordx4 acc[120:123], %[v_os_b3], s[12:15], 0 offen offset:2048\n"
" v_mul_f32 v54, v188, v188 \n"
" v_mul_f32 v55, v189, v189 \n"
" v_mul_f32 v56, v190, v190 \n"
......@@ -567,7 +564,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
"buffer_load_dwordx4 acc[124:127], %[v_os_b3], s[12:15], 0 offen offset:3072\n"
" buffer_load_dwordx4 acc[124:127], %[v_os_b3], s[12:15], 0 offen offset:3072\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -644,7 +641,7 @@
" v_mul_f32 v189, %[smq_scale1], v189 row_newbcast:13 \n"
" v_mul_f32 v190, %[smq_scale1], v190 row_newbcast:14 \n"
" v_mul_f32 v191, %[smq_scale1], v191 row_newbcast:15 \n"
";--buffer_load_dword v12, %[v_os_dq], s[16:19], 0 offen \n"
" buffer_load_dword v12, %[v_os_dq], s[16:19], 0 offen \n"
" v_mov_b32 v22, 0x358637bd \n"
" v_mov_b32 v23, 0x358637bd \n"
" v_max3_f32 v22, abs(v128), abs(v129), v22 \n"
......@@ -974,5 +971,3 @@
#undef _UK_PK_CVT_
#undef _UK_ATOMIC_ADD_
......@@ -166,7 +166,7 @@
" buffer_load_dwordx4 acc[224:227], %[v_os_b2], s[12:15], 0 offen \n"
" v_mfma_i32_16x16x32_i8 v[208:211], acc[100:101], v[148:149], v[208:211] \n"
" v_mfma_i32_16x16x32_i8 v[208:211], acc[102:103], v[150:151], v[208:211] \n"
";--- buffer_load_dword v13, %[v_os_dq], s[16:19], 0 offen \n"
";-- buffer_load_dword v13, %[v_os_dq], s[16:19], 0 offen \n"
" v_mfma_i32_16x16x32_i8 v[208:211], acc[104:105], v[152:153], v[208:211] \n"
" v_mfma_i32_16x16x32_i8 v[208:211], acc[106:107], v[154:155], v[208:211] \n"
" buffer_load_dwordx4 acc[228:231], %[v_os_b2], s[12:15], 0 offen offset:1024 \n"
......@@ -483,3 +483,4 @@
#undef _UK_PK_CVT_
#undef _UK_ATOMIC_ADD_
......@@ -214,7 +214,7 @@
" buffer_load_dwordx4 acc[96:99], %[v_os_b2], s[12:15], 0 offen \n"
" v_mfma_i32_16x16x32_i8 v[240:243], acc[228:229], v[148:149], v[240:243] \n"
" v_mfma_i32_16x16x32_i8 v[240:243], acc[230:231], v[150:151], v[240:243] \n"
";-- buffer_load_dword v12, %[v_os_dq], s[16:19], 0 offen \n"
" buffer_load_dword v12, %[v_os_dq], s[16:19], 0 offen \n"
" v_mfma_i32_16x16x32_i8 v[240:243], acc[232:233], v[152:153], v[240:243] \n"
" v_mfma_i32_16x16x32_i8 v[240:243], acc[234:235], v[154:155], v[240:243] \n"
" buffer_load_dwordx4 acc[100:103], %[v_os_b2], s[12:15], 0 offen offset:1024 \n"
......@@ -587,4 +587,3 @@
#undef _UK_PK_CVT_
#undef _UK_ATOMIC_ADD_
......@@ -54,6 +54,10 @@
" v_mov_b32 v53, 0x00007fff \n"
" s_waitcnt 0x0000 \n"
" s_mov_b32 s80, 0 \n"
" v_lshrrev_b32 v3, 6, v0 \n"
" v_and_b32 v0, 0x3f, v0 \n"
" v_readfirstlane_b32 s7, v3 \n"
// Asm comment marker. It must end in "\n": adjacent string literals are concatenated,
// so without the newline the ';' comment would swallow the following instruction.
";--ds write v3 gen\n"
" v_lshrrev_b32 v54, 4, v0 \n"
" v_mul_i32_i24 v3, 34, v54 \n"
" v_and_b32 v54, 15, v0 \n"
......@@ -62,6 +66,7 @@
" s_mul_i32 s60, s7, 0x00000088 \n"
" v_add_u32 v3, s60, v3 \n"
" v_lshlrev_b32 v3, 2, v3 \n"
";--ds read v4 gen\n"
" v_lshrrev_b32 v54, 1, v0 \n"
" v_mul_i32_i24 v4, 34, v54 \n"
" v_and_b32 v55, 1, v0 \n"
......
......@@ -105,9 +105,9 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
// TODO: properly support scatter/gather for load only
CK_TILE_DEVICE auto GetRowCoords_A(index_t base_offset)
{
constexpr index_t KLans = BlockShape::Block_K0 / kAlignmentA;
constexpr index_t MLans = BlockShape::BlockSize / KLans;
constexpr index_t MRepeat = BlockShape::Block_M0 / MLans;
constexpr index_t KLans = BlockShape::Block_K0 / kAlignmentA;//64
constexpr index_t MLans = BlockShape::BlockSize / KLans;//4
constexpr index_t MRepeat = BlockShape::Block_M0 / MLans;//8
auto base_coord = threadIdx.x / KLans + base_offset;
......@@ -424,50 +424,58 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
gqsmq_coords, (reinterpret_cast<const GScaleDataType*>(kargs.g_scale_ptr) + static_cast<long_index_t>(expert_id) * shared_intermediate_size_0));
auto smq_scale = GetSMQScale(
gqsmq_coords, (reinterpret_cast<const YSmoothScaleDataType*>(kargs.y_smooth_scale_ptr) + static_cast<long_index_t>(expert_id) * shared_intermediate_size_0));
if(threadIdx.x == 95 && blockIdx.x == 0 && blockIdx.y == 0)
{
printf("\nblockIdx.x :%x, blockIdx.y :%x, d ptr: %p, wg d ptr :%x%x,gemm0 done\n", blockIdx.x, blockIdx.y, kargs.d_ptr,d_res[1],d_res[0]);
// // printf("\n wg 1 1, wave 1, row_coords_a 0: %x 1: %x, 2: %x, 3:%x, 5: %x 6: %x, 7: %x, 8:%x,, \n", row_coords_a[number<0>{}],row_coords_a[number<1>{}],row_coords_a[number<2>{}],row_coords_a[number<3>{}], row_coords_a[number<4>{}],row_coords_a[number<5>{}],row_coords_a[number<6>{}],row_coords_a[number<7>{}]);
// // printf("\n -------------- -row_ids_a 0: %x 1: %x, 2: %x, 3:%x, 5: %x 6: %x, 7: %x, 8:%x,, \n", row_ids_a[number<0>{}],row_ids_a[number<1>{}],row_ids_a[number<2>{}],row_ids_a[number<3>{}], row_ids_a[number<4>{}],row_ids_a[number<5>{}],row_ids_a[number<6>{}],row_ids_a[number<7>{}]);
// // printf("\n -----------thread id %x--- - token_id 0: %x 1: %x, 2: %x, 3:%x, 5: %x 6: %x, 7: %x, 8:%x,, \n", hipThreadIdx_x , token_id[number<0>{}],token_id[number<1>{}],token_id[number<2>{}],token_id[number<3>{}], token_id[number<4>{}],token_id[number<5>{}],token_id[number<6>{}],token_id[number<7>{}]);
// // printf("\n -----------thread id %x--- - token_id , 7:%x,, \n", hipThreadIdx_x , token_id[number<7>{}]);
// // printf("\n -------------- - exec 0: %x 1: %x, 2: %x, 3:%x, 5: %x 6: %x, 7: %x, 8:%x,, \n", o_flags[number<0>{}][0],o_flags[number<1>{}][0],o_flags[number<2>{}][0],o_flags[number<3>{}][0], o_flags[number<4>{}][0],o_flags[number<5>{}][0],o_flags[number<6>{}][0],o_flags[number<7>{}][0]);
printf("\ntoken id :%x,%x,%x,%x, %x,%x,%x,%x \n d_coords: %x,%x,%x,%x, \n row_idx: %x,%x,%x,%x, %x,%x,%x,%x \n o_flags:%x,%x,%x,%x, %x,%x,%x,%x \n",
token_id[number<0>{}],
token_id[number<1>{}],
token_id[number<2>{}],
token_id[number<3>{}],
token_id[number<4>{}],
token_id[number<5>{}],
token_id[number<6>{}],
token_id[number<7>{}],
d_coords[number<0>{}],
d_coords[number<1>{}],
d_coords[number<2>{}],
d_coords[number<3>{}],
// d_coords[number<4>{}],
// d_coords[number<5>{}],
// d_coords[number<6>{}],
// d_coords[number<7>{}],
row_ids_a[number<0>{}],
row_ids_a[number<1>{}],
row_ids_a[number<2>{}],
row_ids_a[number<3>{}],
row_ids_a[number<4>{}],
row_ids_a[number<5>{}],
row_ids_a[number<6>{}],
row_ids_a[number<7>{}],
o_flags[number<0>{}][0],
o_flags[number<1>{}][0],
o_flags[number<2>{}][0],
o_flags[number<3>{}][0],
o_flags[number<4>{}][0],
o_flags[number<5>{}][0],
o_flags[number<6>{}][0],
o_flags[number<7>{}][0]);
// return;
}
__builtin_amdgcn_sched_barrier(0);
// if(threadIdx.x == 255 && blockIdx.x == 0 && blockIdx.y == 0)
// {
// printf("\nblockIdx.x :%x, blockIdx.y :%x, d ptr: %p, o ptr: %p, wg d ptr :%x%x,gemm0 done\n", blockIdx.x, blockIdx.y,kargs.d_ptr, kargs.o_ptr,d_res[1],d_res[0]);
// // // printf("\n wg 1 1, wave 1, row_coords_a 0: %x 1: %x, 2: %x, 3:%x, 5: %x 6: %x, 7: %x, 8:%x,, \n", row_coords_a[number<0>{}],row_coords_a[number<1>{}],row_coords_a[number<2>{}],row_coords_a[number<3>{}], row_coords_a[number<4>{}],row_coords_a[number<5>{}],row_coords_a[number<6>{}],row_coords_a[number<7>{}]);
// // // printf("\n -------------- -row_ids_a 0: %x 1: %x, 2: %x, 3:%x, 5: %x 6: %x, 7: %x, 8:%x,, \n", row_ids_a[number<0>{}],row_ids_a[number<1>{}],row_ids_a[number<2>{}],row_ids_a[number<3>{}], row_ids_a[number<4>{}],row_ids_a[number<5>{}],row_ids_a[number<6>{}],row_ids_a[number<7>{}]);
// // // printf("\n -----------thread id %x--- - token_id 0: %x 1: %x, 2: %x, 3:%x, 5: %x 6: %x, 7: %x, 8:%x,, \n", hipThreadIdx_x , token_id[number<0>{}],token_id[number<1>{}],token_id[number<2>{}],token_id[number<3>{}], token_id[number<4>{}],token_id[number<5>{}],token_id[number<6>{}],token_id[number<7>{}]);
// // // printf("\n -----------thread id %x--- - token_id , 7:%x,, \n", hipThreadIdx_x , token_id[number<7>{}]);
// // // printf("\n -------------- - exec 0: %x 1: %x, 2: %x, 3:%x, 5: %x 6: %x, 7: %x, 8:%x,, \n", o_flags[number<0>{}][0],o_flags[number<1>{}][0],o_flags[number<2>{}][0],o_flags[number<3>{}][0], o_flags[number<4>{}][0],o_flags[number<5>{}][0],o_flags[number<6>{}][0],o_flags[number<7>{}][0]);
// printf("\ntoken_id :%x,%x,%x,%x, %x,%x,%x,%x \no_coords :%x,%x,%x,%x, %x,%x,%x,%x \n d_coords: %x,%x,%x,%x, \n row_idx: %x,%x,%x,%x, %x,%x,%x,%x \n o_flags:%x,%x,%x,%x, %x,%x,%x,%x \n",
// token_id[number<0>{}],
// token_id[number<1>{}],
// token_id[number<2>{}],
// token_id[number<3>{}],
// token_id[number<4>{}],
// token_id[number<5>{}],
// token_id[number<6>{}],
// token_id[number<7>{}],
// o_coords[number<0>{}],
// o_coords[number<1>{}],
// o_coords[number<2>{}],
// o_coords[number<3>{}],
// o_coords[number<4>{}],
// o_coords[number<5>{}],
// o_coords[number<6>{}],
// o_coords[number<7>{}],
// d_coords[number<0>{}],
// d_coords[number<1>{}],
// d_coords[number<2>{}],
// d_coords[number<3>{}],
// // d_coords[number<4>{}],
// // d_coords[number<5>{}],
// // d_coords[number<6>{}],
// // d_coords[number<7>{}],
// row_ids_a[number<0>{}],
// row_ids_a[number<1>{}],
// row_ids_a[number<2>{}],
// row_ids_a[number<3>{}],
// row_ids_a[number<4>{}],
// row_ids_a[number<5>{}],
// row_ids_a[number<6>{}],
// row_ids_a[number<7>{}],
// o_flags[number<0>{}][0],
// o_flags[number<1>{}][0],
// o_flags[number<2>{}][0],
// o_flags[number<3>{}][0],
// o_flags[number<4>{}][0],
// o_flags[number<5>{}][0],
// o_flags[number<6>{}][0],
// o_flags[number<7>{}][0]);
// // return;
// }
// __builtin_amdgcn_sched_barrier(0);
auto uk_0 = Policy::template GetUK_0<Problem>();
// auto acc_0= uk_0(
......@@ -483,10 +491,10 @@ if(threadIdx.x == 95 && blockIdx.x == 0 && blockIdx.y == 0)
kargs.hidden_size,
BlockShape::Block_K0, // tile offset for B matrix each unroll
16*256,
kargs.num_tokens * kargs.stride_token); // tile offset for B matrix each unroll
__builtin_amdgcn_readfirstlane(kargs.num_tokens * kargs.stride_token)); // tile offset for B matrix each unroll
// return;
__builtin_amdgcn_sched_barrier(0);
// return;
// // sweep_tile(
// acc_0,
// [&](auto idx0, auto idx1) {
......@@ -515,7 +523,7 @@ if(threadIdx.x == 95 && blockIdx.x == 0 && blockIdx.y == 0)
o_coords,
o_flags,
smem,
kargs.hidden_size, // total n number
__builtin_amdgcn_readfirstlane(kargs.hidden_size), // total n number
w_scale,
smq_scale,
BlockShape::Block_N1,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment