Commit 7cc808f2 authored by shengnxu's avatar shengnxu
Browse files

fix some codes

parent 6f7d1272
......@@ -245,10 +245,12 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
// TODO: need paired with tile_window_linear!
// TODO: need call init_raw() before call this function!
template <typename Ascale, typename GQscale, typename ARes, typename ACoords, typename BRes, typename BCoords>
template <typename Ascale, typename GQscale, typename DRes, typename DQRes, typename ARes, typename ACoords, typename BRes, typename BCoords>
CK_TILE_DEVICE auto
operator()( const Ascale& a_scale_,
const GQscale& gq_scale_,
const DRes& res_d,
const DQRes& res_dq,
const ARes& res_a,
const ACoords& cached_coords_a,
const BRes& res_b,
......@@ -446,14 +448,22 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
[a_scale1]"v"(a_scale_[1]),
[gq_scale0]"v"(gq_scale_[0]),
[gq_scale1]"v"(gq_scale_[1]),
[s_res_a]"s"(res_a),
// [s_res_a1]"s"(res_a[1]),
// [s_res_a2]"s"(res_a[2]),
// [s_res_a3]"s"(res_a[3]),
[s_res_b]"s"(res_b),
// [s_res_b1]"s"(res_b[1]),
// [s_res_b2]"s"(res_b[2]),
// [s_res_b3]"s"(res_b[3]),
[s_res_a0]"s"(res_a[0]),
[s_res_a1]"s"(res_a[1]),
[s_res_a2]"s"(res_a[2]),
[s_res_a3]"s"(res_a[3]),
[s_res_b0]"s"(res_b[0]),
[s_res_b1]"s"(res_b[1]),
[s_res_b2]"s"(res_b[2]),
[s_res_b3]"s"(res_b[3]),
[s_res_d0]"s"(res_d[0]),
[s_res_d1]"s"(res_d[1]),
[s_res_d2]"s"(res_d[2]),
[s_res_d3]"s"(res_d[3]),
[s_res_dq0]"s"(res_dq[0]),
[s_res_dq1]"s"(res_dq[1]),
[s_res_dq2]"s"(res_dq[2]),
[s_res_dq3]"s"(res_dq[3]),
[v_os_a0]"v"(static_cast<index_t>(cached_coords_a[number<0>{}] * sizeof(ADataType))),
[v_os_a1]"v"(static_cast<index_t>(cached_coords_a[number<1>{}] * sizeof(ADataType))),
[v_os_a2]"v"(static_cast<index_t>(cached_coords_a[number<2>{}] * sizeof(ADataType))),
......
......@@ -78,8 +78,9 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
// TODO: need paired with tile_window_linear!
// TODO: need call init_raw() before call this function!
// template <typename AWindow, typename BWindow, typename OWindow, typename ScaleTensor>
template <typename DQRes,
typename BRes,
template <
// typename DQRes,
// typename BRes,
typename DQCoords,
typename BCoords,
typename ORes,
......@@ -88,8 +89,9 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
typename ScaleTensor,
typename YScaleTensor>
CK_TILE_DEVICE auto
operator()(const DQRes& res_dq,
const BRes& res_b,
operator()(
// const DQRes& res_dq,
// const BRes& res_b,
const DQCoords& cached_coords_dq,
const BCoords& cached_coords_b,
const ORes& res_o,
......@@ -118,38 +120,6 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
index_t loop_cnt = n ;
// register float v_c0 asm("v64");
// register float v_c1 asm("v65");
// register float v_c2 asm("v66");
// register float v_c3 asm("v67");
// register float v_c4 asm("v68");
// register float v_c5 asm("v69");
// register float v_c6 asm("v70");
// register float v_c7 asm("v71");
// register float v_c8 asm("v72");
// register float v_c9 asm("v73");
// register float v_c10 asm("v74");
// register float v_c11 asm("v75");
// register float v_c12 asm("v76");
// register float v_c13 asm("v77");
// register float v_c14 asm("v78");
// register float v_c15 asm("v79");
// register float v_c16 asm("v80");
// register float v_c17 asm("v81");
// register float v_c18 asm("v82");
// register float v_c19 asm("v83");
// register float v_c20 asm("v84");
// register float v_c21 asm("v85");
// register float v_c22 asm("v86");
// register float v_c23 asm("v87");
// register float v_c24 asm("v88");
// register float v_c25 asm("v89");
// register float v_c26 asm("v90");
// register float v_c27 asm("v91");
// register float v_c28 asm("v92");
// register float v_c29 asm("v93");
// register float v_c30 asm("v94");
// register float v_c31 asm("v95");
// int32_t nan_hi = 0x7fff0000;
// int32_t nan_lo = 0x00007fff;
......@@ -187,38 +157,6 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
#undef CK_TILE_FLATMM_UK_MFMA
:[smem_]"+r"(smem),
[s_loop_cnt]"+s"(loop_cnt)
// [c0]"+v" (v_c0),
// [c1]"+v" (v_c1),
// [c2]"+v" (v_c2),
// [c3]"+v" (v_c3),
// [c4]"+v" (v_c4),
// [c5]"+v" (v_c5),
// [c6]"+v" (v_c6),
// [c7]"+v" (v_c7),
// [c8]"+v" (v_c8),
// [c9]"+v" (v_c9),
// [c10]"+v"(v_c10),
// [c11]"+v"(v_c11),
// [c12]"+v"(v_c12),
// [c13]"+v"(v_c13),
// [c14]"+v"(v_c14),
// [c15]"+v"(v_c15),
// [c16]"+v"(v_c16),
// [c17]"+v"(v_c17),
// [c18]"+v"(v_c18),
// [c19]"+v"(v_c19),
// [c20]"+v"(v_c20),
// [c21]"+v"(v_c21),
// [c22]"+v"(v_c22),
// [c23]"+v"(v_c23),
// [c24]"+v"(v_c24),
// [c25]"+v"(v_c25),
// [c26]"+v"(v_c26),
// [c27]"+v"(v_c27),
// [c28]"+v"(v_c28),
// [c29]"+v"(v_c29),
// [c30]"+v"(v_c30),
// [c31]"+v"(v_c31)
:[sld_a_base]"n"(0),
// [shfl_base]"n"(0),
// [v_sld_y_os]"v"(sld_y_os),
......@@ -226,15 +164,10 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
// [v_sfl_sst]"v"(sfl_sst),
[smq_scale0]"s"(smq_scale_[0]),
[smq_scale1]"s"(smq_scale_[1]),
[s_res_dq]"s"(res_dq),
[s_res_o0]"s"(res_o[0]),
[s_res_o1]"s"(res_o[1]),
//[s_res_o2]"s"(res_o[2]),
//[s_res_o3]"s"(res_o[3]),
[s_res_d]"s"(res_b),
// [s_res_b1]"s"(res_b[1]),
// [s_res_b2]"s"(res_b[2]),
// [s_res_b3]"s"(res_b[3]),
[v_os_dq]"v"(static_cast<index_t>(cached_coords_dq * sizeof(DScaleDataType))),
[v_os_o0]"v"(static_cast<index_t>(cached_coords_o[number<0>{}] * sizeof(ODataType))),
[v_os_o1]"v"(static_cast<index_t>(cached_coords_o[number<1>{}] * sizeof(ODataType))),
......@@ -334,52 +267,10 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
#undef CK_TILE_FLATMM_UK_MFMA
:[smem_]"+r"(smem),
[s_loop_cnt]"+s"(loop_cnt)
// [c0]"+v" (v_c0),
// [c1]"+v" (v_c1),
// [c2]"+v" (v_c2),
// [c3]"+v" (v_c3),
// [c4]"+v" (v_c4),
// [c5]"+v" (v_c5),
// [c6]"+v" (v_c6),
// [c7]"+v" (v_c7),
// [c8]"+v" (v_c8),
// [c9]"+v" (v_c9),
// [c10]"+v"(v_c10),
// [c11]"+v"(v_c11),
// [c12]"+v"(v_c12),
// [c13]"+v"(v_c13),
// [c14]"+v"(v_c14),
// [c15]"+v"(v_c15),
// [c16]"+v"(v_c16),
// [c17]"+v"(v_c17),
// [c18]"+v"(v_c18),
// [c19]"+v"(v_c19),
// [c20]"+v"(v_c20),
// [c21]"+v"(v_c21),
// [c22]"+v"(v_c22),
// [c23]"+v"(v_c23),
// [c24]"+v"(v_c24),
// [c25]"+v"(v_c25),
// [c26]"+v"(v_c26),
// [c27]"+v"(v_c27),
// [c28]"+v"(v_c28),
// [c29]"+v"(v_c29),
// [c30]"+v"(v_c30),
// [c31]"+v"(v_c31)
:[sld_a_base]"n"(0),
// [shfl_base]"n"(0),
// [v_sld_y_os]"v"(sld_y_os),
// [v_sfl_sld]"v"(sfl_sld),
// [v_sfl_sst]"v"(sfl_sst),
[s_res_dq]"s"(res_dq),
[s_res_o0]"s"(res_o[0]),
[s_res_o1]"s"(res_o[1]),
//[s_res_o2]"s"(res_o[2]),
//[s_res_o3]"s"(res_o[3]),
[s_res_d]"s"(res_b),
// [s_res_b1]"s"(res_b[1]),
// [s_res_b2]"s"(res_b[2]),
// [s_res_b3]"s"(res_b[3]),
[v_os_dq]"v"(static_cast<index_t>(cached_coords_dq * sizeof(DScaleDataType))),
[v_os_o0]"v"(static_cast<index_t>(cached_coords_o[number<0>{}] * sizeof(ODataType))),
[v_os_o1]"v"(static_cast<index_t>(cached_coords_o[number<1>{}] * sizeof(ODataType))),
[v_os_o2]"v"(static_cast<index_t>(cached_coords_o[number<2>{}] * sizeof(ODataType))),
......
......@@ -31,8 +31,8 @@
" v_lshrrev_b32 v3, 6, v0 \n"
" v_readfirstlane_b32 s7, v3 \n"
" s_waitcnt vmcnt(24) \n"
" buffer_load_dwordx4 acc[0:3], %[v_os_b0], %[s_res_d], 0 offen\n"
" buffer_load_dwordx4 acc[4:7], %[v_os_b0], %[s_res_d], 0 offen offset:1024\n"
" buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[12:15], 0 offen\n"
" buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[12:15], 0 offen offset:1024\n"
" v_mul_f32 v54, v128, v128 \n"
" v_mul_f32 v55, v129, v129 \n"
" v_mul_f32 v56, v130, v130 \n"
......@@ -65,7 +65,7 @@
" v_mul_f32 v129, v129, v55 \n"
" v_mul_f32 v130, v130, v56 \n"
" v_mul_f32 v131, v131, v57 \n"
" buffer_load_dwordx4 acc[8:11], %[v_os_b0], %[s_res_d], 0 offen offset:2048\n"
" buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[12:15], 0 offen offset:2048\n"
" v_mul_f32 v54, v132, v132 \n"
" v_mul_f32 v55, v133, v133 \n"
" v_mul_f32 v56, v134, v134 \n"
......@@ -86,7 +86,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[12:15], %[v_os_b0], %[s_res_d], 0 offen offset:3072\n"
" buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[12:15], 0 offen offset:3072\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -99,7 +99,7 @@
" v_mul_f32 v133, v133, v55 \n"
" v_mul_f32 v134, v134, v56 \n"
" v_mul_f32 v135, v135, v57 \n"
" buffer_load_dwordx4 acc[16:19], %[v_os_b1], %[s_res_d], 0 offen\n"
" buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[12:15], 0 offen\n"
" v_mul_f32 v54, v136, v136 \n"
" v_mul_f32 v55, v137, v137 \n"
" v_mul_f32 v56, v138, v138 \n"
......@@ -120,7 +120,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[20:23], %[v_os_b1], %[s_res_d], 0 offen offset:1024\n"
" buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[12:15], 0 offen offset:1024\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -133,7 +133,7 @@
" v_mul_f32 v137, v137, v55 \n"
" v_mul_f32 v138, v138, v56 \n"
" v_mul_f32 v139, v139, v57 \n"
" buffer_load_dwordx4 acc[24:27], %[v_os_b1], %[s_res_d], 0 offen offset:2048\n"
" buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[12:15], 0 offen offset:2048\n"
" v_mul_f32 v54, v140, v140 \n"
" v_mul_f32 v55, v141, v141 \n"
" v_mul_f32 v56, v142, v142 \n"
......@@ -154,7 +154,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[28:31], %[v_os_b1], %[s_res_d], 0 offen offset:3072\n"
" buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[12:15], 0 offen offset:3072\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -168,7 +168,7 @@
" v_mul_f32 v142, v142, v56 \n"
" v_mul_f32 v143, v143, v57 \n"
" s_waitcnt vmcnt(24) \n"
" buffer_load_dwordx4 acc[32:35], %[v_os_b2], %[s_res_d], 0 offen\n"
" buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[12:15], 0 offen\n"
" v_mul_f32 v54, v144, v144 \n"
" v_mul_f32 v55, v145, v145 \n"
" v_mul_f32 v56, v146, v146 \n"
......@@ -189,7 +189,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[36:39], %[v_os_b2], %[s_res_d], 0 offen offset:1024\n"
" buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[12:15], 0 offen offset:1024\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -202,7 +202,7 @@
" v_mul_f32 v145, v145, v55 \n"
" v_mul_f32 v146, v146, v56 \n"
" v_mul_f32 v147, v147, v57 \n"
" buffer_load_dwordx4 acc[40:43], %[v_os_b2], %[s_res_d], 0 offen offset:2048\n"
" buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[12:15], 0 offen offset:2048\n"
" v_mul_f32 v54, v148, v148 \n"
" v_mul_f32 v55, v149, v149 \n"
" v_mul_f32 v56, v150, v150 \n"
......@@ -223,7 +223,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[44:47], %[v_os_b2], %[s_res_d], 0 offen offset:3072\n"
" buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[12:15], 0 offen offset:3072\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -236,7 +236,7 @@
" v_mul_f32 v149, v149, v55 \n"
" v_mul_f32 v150, v150, v56 \n"
" v_mul_f32 v151, v151, v57 \n"
" buffer_load_dwordx4 acc[48:51], %[v_os_b3], %[s_res_d], 0 offen\n"
" buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[12:15], 0 offen\n"
" v_mul_f32 v54, v152, v152 \n"
" v_mul_f32 v55, v153, v153 \n"
" v_mul_f32 v56, v154, v154 \n"
......@@ -257,7 +257,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[52:55], %[v_os_b3], %[s_res_d], 0 offen offset:1024\n"
" buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[12:15], 0 offen offset:1024\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -270,7 +270,7 @@
" v_mul_f32 v153, v153, v55 \n"
" v_mul_f32 v154, v154, v56 \n"
" v_mul_f32 v155, v155, v57 \n"
" buffer_load_dwordx4 acc[56:59], %[v_os_b3], %[s_res_d], 0 offen offset:2048\n"
" buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[12:15], 0 offen offset:2048\n"
" v_mul_f32 v54, v156, v156 \n"
" v_mul_f32 v55, v157, v157 \n"
" v_mul_f32 v56, v158, v158 \n"
......@@ -291,7 +291,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[60:63], %[v_os_b3], %[s_res_d], 0 offen offset:3072\n"
" buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[12:15], 0 offen offset:3072\n"
" s_add_u32 s12, %[s_tile_os_b_half], s12 \n"
" s_addc_u32 s13, 0, s13 \n"
" v_add_f32 v54, v54, 1.0 \n"
......@@ -307,7 +307,7 @@
" v_mul_f32 v158, v158, v56 \n"
" v_mul_f32 v159, v159, v57 \n"
" s_waitcnt vmcnt(24) \n"
" buffer_load_dwordx4 acc[64:67], %[v_os_b0], %[s_res_d], 0 offen\n"
" buffer_load_dwordx4 acc[64:67], %[v_os_b0], s[12:15], 0 offen\n"
" v_mul_f32 v54, v160, v160 \n"
" v_mul_f32 v55, v161, v161 \n"
" v_mul_f32 v56, v162, v162 \n"
......@@ -328,7 +328,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[68:71], %[v_os_b0], %[s_res_d], 0 offen offset:1024\n"
" buffer_load_dwordx4 acc[68:71], %[v_os_b0], s[12:15], 0 offen offset:1024\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -341,7 +341,7 @@
" v_mul_f32 v161, v161, v55 \n"
" v_mul_f32 v162, v162, v56 \n"
" v_mul_f32 v163, v163, v57 \n"
" buffer_load_dwordx4 acc[72:75], %[v_os_b0], %[s_res_d], 0 offen offset:2048\n"
" buffer_load_dwordx4 acc[72:75], %[v_os_b0], s[12:15], 0 offen offset:2048\n"
" v_mul_f32 v54, v164, v164 \n"
" v_mul_f32 v55, v165, v165 \n"
" v_mul_f32 v56, v166, v166 \n"
......@@ -362,7 +362,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[76:79], %[v_os_b0], %[s_res_d], 0 offen offset:3072\n"
" buffer_load_dwordx4 acc[76:79], %[v_os_b0], s[12:15], 0 offen offset:3072\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -375,7 +375,7 @@
" v_mul_f32 v165, v165, v55 \n"
" v_mul_f32 v166, v166, v56 \n"
" v_mul_f32 v167, v167, v57 \n"
" buffer_load_dwordx4 acc[80:83], %[v_os_b1], %[s_res_d], 0 offen\n"
" buffer_load_dwordx4 acc[80:83], %[v_os_b1], s[12:15], 0 offen\n"
" v_mul_f32 v54, v168, v168 \n"
" v_mul_f32 v55, v169, v169 \n"
" v_mul_f32 v56, v170, v170 \n"
......@@ -396,7 +396,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[84:87], %[v_os_b1], %[s_res_d], 0 offen offset:1024\n"
" buffer_load_dwordx4 acc[84:87], %[v_os_b1], s[12:15], 0 offen offset:1024\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -409,7 +409,7 @@
" v_mul_f32 v169, v169, v55 \n"
" v_mul_f32 v170, v170, v56 \n"
" v_mul_f32 v171, v171, v57 \n"
" buffer_load_dwordx4 acc[88:91], %[v_os_b1], %[s_res_d], 0 offen offset:2048\n"
" buffer_load_dwordx4 acc[88:91], %[v_os_b1], s[12:15], 0 offen offset:2048\n"
" v_mul_f32 v54, v172, v172 \n"
" v_mul_f32 v55, v173, v173 \n"
" v_mul_f32 v56, v174, v174 \n"
......@@ -430,7 +430,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[92:95], %[v_os_b1], %[s_res_d], 0 offen offset:3072\n"
" buffer_load_dwordx4 acc[92:95], %[v_os_b1], s[12:15], 0 offen offset:3072\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -444,7 +444,7 @@
" v_mul_f32 v174, v174, v56 \n"
" v_mul_f32 v175, v175, v57 \n"
" s_waitcnt vmcnt(24) \n"
" buffer_load_dwordx4 acc[96:99], %[v_os_b2], %[s_res_d], 0 offen\n"
" buffer_load_dwordx4 acc[96:99], %[v_os_b2], s[12:15], 0 offen\n"
" v_mul_f32 v54, v176, v176 \n"
" v_mul_f32 v55, v177, v177 \n"
" v_mul_f32 v56, v178, v178 \n"
......@@ -465,7 +465,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[100:103], %[v_os_b2], %[s_res_d], 0 offen offset:1024\n"
" buffer_load_dwordx4 acc[100:103], %[v_os_b2], s[12:15], 0 offen offset:1024\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -478,7 +478,7 @@
" v_mul_f32 v177, v177, v55 \n"
" v_mul_f32 v178, v178, v56 \n"
" v_mul_f32 v179, v179, v57 \n"
" buffer_load_dwordx4 acc[104:107], %[v_os_b2], %[s_res_d], 0 offen offset:2048\n"
" buffer_load_dwordx4 acc[104:107], %[v_os_b2], s[12:15], 0 offen offset:2048\n"
" v_mul_f32 v54, v180, v180 \n"
" v_mul_f32 v55, v181, v181 \n"
" v_mul_f32 v56, v182, v182 \n"
......@@ -499,7 +499,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[108:111], %[v_os_b2], %[s_res_d], 0 offen offset:3072\n"
" buffer_load_dwordx4 acc[108:111], %[v_os_b2], s[12:15], 0 offen offset:3072\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -512,7 +512,7 @@
" v_mul_f32 v181, v181, v55 \n"
" v_mul_f32 v182, v182, v56 \n"
" v_mul_f32 v183, v183, v57 \n"
" buffer_load_dwordx4 acc[112:115], %[v_os_b3], %[s_res_d], 0 offen\n"
" buffer_load_dwordx4 acc[112:115], %[v_os_b3], s[12:15], 0 offen\n"
" v_mul_f32 v54, v184, v184 \n"
" v_mul_f32 v55, v185, v185 \n"
" v_mul_f32 v56, v186, v186 \n"
......@@ -533,7 +533,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[116:119], %[v_os_b3], %[s_res_d], 0 offen offset:1024\n"
" buffer_load_dwordx4 acc[116:119], %[v_os_b3], s[12:15], 0 offen offset:1024\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -546,7 +546,7 @@
" v_mul_f32 v185, v185, v55 \n"
" v_mul_f32 v186, v186, v56 \n"
" v_mul_f32 v187, v187, v57 \n"
" buffer_load_dwordx4 acc[120:123], %[v_os_b3], %[s_res_d], 0 offen offset:2048\n"
" buffer_load_dwordx4 acc[120:123], %[v_os_b3], s[12:15], 0 offen offset:2048\n"
" v_mul_f32 v54, v188, v188 \n"
" v_mul_f32 v55, v189, v189 \n"
" v_mul_f32 v56, v190, v190 \n"
......@@ -567,7 +567,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[124:127], %[v_os_b3], %[s_res_d], 0 offen offset:3072\n"
" buffer_load_dwordx4 acc[124:127], %[v_os_b3], s[12:15], 0 offen offset:3072\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -644,7 +644,7 @@
" v_mul_f32 v189, %[smq_scale1], v189 row_newbcast:13 \n"
" v_mul_f32 v190, %[smq_scale1], v190 row_newbcast:14 \n"
" v_mul_f32 v191, %[smq_scale1], v191 row_newbcast:15 \n"
" buffer_load_dword v12, %[v_os_dq], %[s_res_dq], 0 offen \n"
" buffer_load_dword v12, %[v_os_dq], s[16:19], 0 offen \n"
" v_mov_b32 v22, 0x358637bd \n"
" v_mov_b32 v23, 0x358637bd \n"
" v_max3_f32 v22, abs(v128), abs(v129), v22 \n"
......@@ -974,5 +974,3 @@
#undef _UK_PK_CVT_
#undef _UK_ATOMIC_ADD_
......@@ -157,7 +157,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
return w;
}
template <typename ROW_COORDS>
template <typename ROW_IDS>
CK_TILE_DEVICE auto GetAScale(const ROW_IDS row_ids_mma,
const AScaleDataType* a_scale_ptr)
{
......@@ -165,9 +165,9 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
array<TopkWeightDataType, n_size> w;
static_for<0, n_size, 1>{}([&](auto i) {
auto row_id = row_idx_mma[i] & 0xffffff;
auto itp_k = row_idx_mma[i] >> 24;
w.at(i) = sorted_weight_ptr[row_id *kargs.topk+itp_k];
auto row_id = row_ids_mma[i] & 0xffffff;
auto itp_k = row_ids_mma[i] >> 24;
w.at(i) = a_scale_ptr[row_id * 5+itp_k];
});
return w;
......@@ -199,13 +199,14 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
// auto q_is = threadIdx.x & 0x3;
array<index_t, Repeat_N> coords;
static_for<0, Repeat_N, 1>{}([&](auto i) { coords.at(i) = base_coord + (threadIdx.x / MLanes) * 4 +
static_for<0, Repeat_N, 1>{}([&](auto i) { coords.at(i) = base_offset + (threadIdx.x / MLanes) * 4 +
(threadIdx.x & 0xffff)/4 * 64 +
q_id +
threadIdx.x & 0x3 +
i * 256 ; });
return coords;
}
//this calculation shared by G and SMQ
template <typename COL_IDS>
CK_TILE_DEVICE auto GetGQScale(const COL_IDS coords,
const GScaleDataType* g_scale_ptr)
{
......@@ -218,6 +219,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
return g_scale_value;
}
template <typename COL_IDS>
CK_TILE_DEVICE auto GetSMQScale(const COL_IDS coords,
const YSmoothScaleDataType * y_scale_ptr)
{
......@@ -251,8 +253,6 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
index_t expert_stride_0 = shared_intermediate_size_0 * kargs.hidden_size;
index_t expert_stride_1 = shared_intermediate_size_1 * kargs.hidden_size;
/////////////
index_t g_scale_expert_stride_0 = shared_intermediate_size_0;
index_t smq_scale_expert_stride_0 = shared_intermediate_size_0;
index_t d_scale_expert_stride_1 = kargs.hidden_size;
// nr*kr*w
index_t interm_idx_nr0 = __builtin_amdgcn_readfirstlane(
......@@ -283,20 +283,6 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
auto a_res =
make_wave_buffer_resource(reinterpret_cast<const ADataType*>(kargs.a_ptr),
kargs.num_tokens * kargs.stride_token * sizeof(ADataType));
//////aq
auto aq_win = [&]() {
const AScaleDataType* aq_ptr = reinterpret_cast<const AScaleDataType*>(kargs.a_scale_ptr);
auto aq_view_ = make_naive_tensor_view<address_space_enum::global>(
aq_ptr,
make_tuple(kargs.num_tokens * kargs.topk),
make_tuple(1),
number<1>{},
number<1>{});
return aq_view_;
}();
auto aq_res = aq_win.get_buffer_view().cached_buf_res_;
////////
auto g_win = [&]() {
const GDataType* g_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) +
......@@ -323,40 +309,6 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
auto g_res = g_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
auto g_coords = generate_tuple([&](auto i) { return g_win.cached_coords_[i].get_offset(); },
number<decltype(g_win)::NumAccess_NonLinear>{});
//////gq
auto gq_win = [&]() {
const GScaleDataType* gq_ptr = reinterpret_cast<const GScaleDataType*>(kargs.g_scale_ptr) +
static_cast<long_index_t>(expert_id) * g_scale_expert_stride_0 +
intermediate_tile_id * BlockShape::Block_N0;
// const GDataType* g_ptr = reinterpret_cast<const GScaleDataType*>(kargs.g_scale_ptr);//remember to add expert id for inline
auto gq_view_ = make_naive_tensor_view<address_space_enum::global>(
gq_ptr,
make_tuple(shared_intermediate_size_1),
make_tuple(1),
number<1>{},
number<1>{});
return gq_view_;
}();
auto gq_res = gq_win.get_buffer_view().cached_buf_res_;
////smQ
auto smq_win = [&]() {
const YSmoothScaleDataType* smq_ptr = reinterpret_cast<const YSmoothScaleDataType*>(kargs.y_smooth_scale_ptr) +
static_cast<long_index_t>(expert_id) * smq_scale_expert_stride_0 +
intermediate_tile_id * BlockShape::Block_K1;
// const GDataType* g_ptr = reinterpret_cast<const GScaleDataType*>(kargs.g_scale_ptr);//remember to add expert id for inline
auto smq_view_ = make_naive_tensor_view<address_space_enum::global>(
smq_ptr,
make_tuple(shared_intermediate_size_1),
make_tuple(1),
number<1>{},
number<1>{});
return smq_view_;
}();
auto smq_res = smq_win.get_buffer_view().cached_buf_res_;
/////////////////////
const auto d_win = [&]() {
const DDataType* d_ptr = reinterpret_cast<const DDataType*>(kargs.d_ptr) +
......@@ -447,16 +399,18 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
row_coords_o, reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr));
auto a_scale = GetAScale(
row_ids_a_mma, reinterpret_cast<const AScaleDataType*>(kargs.a_scale_ptr));
auto gqsmq_coords = GetColCoords_GQSMQ(intermediated_tile_id * BlockShape::Block_K1);
auto gqsmq_coords = GetColCoords_GQSMQ(intermediate_tile_id * BlockShape::Block_K1);
auto dq_coords = gqsmq_coords[0];//only one for this tiling
auto gq_scale = GetGQScale(
gqsmq_coords, reinterpret_cast<const GScaleDataType*>(kargs.g_scale_ptr + static_cast<long_index_t>(expert_id) * shared_intermediate_size_0));
gqsmq_coords, (reinterpret_cast<const GScaleDataType*>(kargs.g_scale_ptr) + static_cast<long_index_t>(expert_id) * shared_intermediate_size_0));
auto smq_scale = GetSMQScale(
gqsmq_coords, reinterpret_cast<const YSmoothScaleDataType*>(kargs.y_smooth_scale_ptr + static_cast<long_index_t>(expert_id) * shared_intermediate_size_0));
gqsmq_coords, (reinterpret_cast<const YSmoothScaleDataType*>(kargs.y_smooth_scale_ptr) + static_cast<long_index_t>(expert_id) * shared_intermediate_size_0));
auto uk_0 = Policy::template GetUK_0<Problem>();
// auto acc_0= uk_0(
uk_0( a_scale,
gq_scale,
d_res,
dq_res,
a_res,
a_coords,
g_res,
......@@ -485,8 +439,9 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
// block_sync_lds();
auto uk_1 = Policy::template GetUK_1<Problem>();
uk_1(dq_res,
d_res,
uk_1(
// dq_res,
// d_res,
dq_coords,
d_coords,
o_res,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment