// k_reg: two-dimensional array of union_vec2_f16x2<Element>; the inner extent is fixed at 2.
// The outer extent is a compile-time expression evaluated left to right in integer arithmetic:
//     (K / kBlockK_)                          K divided by kBlockK_ (truncating)
//   * ((WARP_N_ * kBlockK_) / (32 * 32))      count of 32*32 units in WARP_N_ * kBlockK_ (truncating)
//   * 2
//   / ((K_prefetch_level == 3) ? 1 : 2)       unchanged when K_prefetch_level is 3, halved otherwise
// NOTE(review): if WARP_N_ * kBlockK_ is smaller than 32*32 the second factor truncates to 0
// and the outer extent becomes 0 -- confirm that configuration cannot occur.
// Original author's note: the minimum ds_read size is 32*32; "2 is seq, 4 is head dim".
// NOTE(review): which literals the "2" and "4" refer to is not evident from this declaration -- confirm.
union_vec2_f16x2<Element>k_reg[(K/kBlockK_)*((WARP_N_*kBlockK_)/(32*32))*2/((K_prefetch_level==3)?1:2)][2];
// k_reg: one-dimensional array of union_vec4_f16x2<Element>.
// The extent is a compile-time expression evaluated left to right in integer arithmetic:
//     (K / kBlockK_)                          K divided by kBlockK_ (truncating)
//   * ((WARP_N_ * kBlockK_) / (32 * 32))      count of 32*32 units in WARP_N_ * kBlockK_ (truncating)
//   * 2
//   / ((K_prefetch_level == 3) ? 1 : 2)       unchanged when K_prefetch_level is 3, halved otherwise
// NOTE(review): if WARP_N_ * kBlockK_ is smaller than 32*32 the second factor truncates to 0
// and the extent becomes 0 -- confirm that configuration cannot occur.
// Original author's note: the minimum ds_read size is 32*32; "2 is seq, 4 is head dim".
// NOTE(review): which literals the "2" and "4" refer to is not evident from this declaration -- confirm.
union_vec4_f16x2<Element>k_reg[(K/kBlockK_)*((WARP_N_*kBlockK_)/(32*32))*2/((K_prefetch_level==3)?1:2)];
// k_reg: one-dimensional array of union_vec4_f16x2<Element>.
// The extent is a compile-time expression evaluated left to right in integer arithmetic:
//     (K / kBlockK_)                          K divided by kBlockK_ (truncating)
//   * ((WARP_N_ * kBlockK_) / (32 * 32))      count of 32*32 units in WARP_N_ * kBlockK_ (truncating)
//   * 2
//   / ((K_prefetch_level == 3) ? 1 : 2)       unchanged when K_prefetch_level is 3, halved otherwise
// NOTE(review): if WARP_N_ * kBlockK_ is smaller than 32*32 the second factor truncates to 0
// and the extent becomes 0 -- confirm that configuration cannot occur.
// Original author's note: the minimum ds_read size is 32*32; "2 is seq, 4 is head dim".
// NOTE(review): which literals the "2" and "4" refer to is not evident from this declaration -- confirm.
union_vec4_f16x2<Element>k_reg[(K/kBlockK_)*((WARP_N_*kBlockK_)/(32*32))*2/((K_prefetch_level==3)?1:2)];