Commit 6f7d1272 authored by shengnxu's avatar shengnxu
Browse files

changed all the scale outside except for uq

parent 9a46c0e7
...@@ -245,13 +245,10 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16 ...@@ -245,13 +245,10 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
// TODO: need paired with tile_window_linear! // TODO: need paired with tile_window_linear!
// TODO: need call init_raw() before call this function! // TODO: need call init_raw() before call this function!
template <typename AToken_id, typename AQRes, typename DQRes, typename GQRes, typename SMQRes, typename ARes, typename ACoords, typename BRes, typename BCoords> template <typename Ascale, typename GQscale, typename ARes, typename ACoords, typename BRes, typename BCoords>
CK_TILE_DEVICE auto CK_TILE_DEVICE auto
operator()( const AToken_id& row_ids_a_, operator()( const Ascale& a_scale_,
const AQRes& res_aq, const GQscale& gq_scale_,
const DQRes& res_dq,
const GQRes& res_gq,
const SMQRes& res_smq,
const ARes& res_a, const ARes& res_a,
const ACoords& cached_coords_a, const ACoords& cached_coords_a,
const BRes& res_b, const BRes& res_b,
...@@ -263,7 +260,6 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16 ...@@ -263,7 +260,6 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
{ {
static_assert(ACoords::size() == Block_M * Block_K / BlockSize / 4 /*2x per dword*/); // 8 static_assert(ACoords::size() == Block_M * Block_K / BlockSize / 4 /*2x per dword*/); // 8
static_assert(BCoords::size() == Repeat_N); static_assert(BCoords::size() == Repeat_N);
static_assert(AToken_id::size() == Repeat_M);
static_assert(Ascale::size() == Repeat_M); static_assert(Ascale::size() == Repeat_M);
auto a_sst = make_tile_window( auto a_sst = make_tile_window(
...@@ -372,10 +368,6 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16 ...@@ -372,10 +368,6 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
register int v_z61 asm("v189") = 0; register int v_z61 asm("v189") = 0;
register int v_z62 asm("v190") = 0; register int v_z62 asm("v190") = 0;
register int v_z63 asm("v191") = 0; register int v_z63 asm("v191") = 0;
index_t temp0 = static_cast<index_t>(row_ids_a_[number<0>{}]);
index_t temp1 = static_cast<index_t>(row_ids_a_[number<1>{}]);
// B nr->kr // B nr->kr
#pragma clang diagnostic push #pragma clang diagnostic push
#pragma clang diagnostic ignored "-Winline-asm" #pragma clang diagnostic ignored "-Winline-asm"
...@@ -449,13 +441,11 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16 ...@@ -449,13 +441,11 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
[c61]"+v"(v_z61), [c61]"+v"(v_z61),
[c62]"+v"(v_z62), [c62]"+v"(v_z62),
[c63]"+v"(v_z63), [c63]"+v"(v_z63),
[v_token_id0]"+v"(temp0),
[v_token_id1]"+v"(temp1),
[s_mem_]"+r"(smem) [s_mem_]"+r"(smem)
: [s_res_aq]"s"(res_aq), : [a_scale0]"v"(a_scale_[0]),
[s_res_dq]"s"(res_dq), [a_scale1]"v"(a_scale_[1]),
[s_res_gq]"s"(res_gq), [gq_scale0]"v"(gq_scale_[0]),
[s_res_smq]"s"(res_smq), [gq_scale1]"v"(gq_scale_[1]),
[s_res_a]"s"(res_a), [s_res_a]"s"(res_a),
// [s_res_a1]"s"(res_a[1]), // [s_res_a1]"s"(res_a[1]),
// [s_res_a2]"s"(res_a[2]), // [s_res_a2]"s"(res_a[2]),
......
...@@ -80,21 +80,25 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x ...@@ -80,21 +80,25 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
// template <typename AWindow, typename BWindow, typename OWindow, typename ScaleTensor> // template <typename AWindow, typename BWindow, typename OWindow, typename ScaleTensor>
template <typename DQRes, template <typename DQRes,
typename BRes, typename BRes,
typename DQCoords,
typename BCoords, typename BCoords,
typename ORes, typename ORes,
typename OCoords, typename OCoords,
typename OFlags> typename OFlags,
// typename ScaleTensor> typename ScaleTensor,
typename YScaleTensor>
CK_TILE_DEVICE auto CK_TILE_DEVICE auto
operator()(const DQRes& res_dq, operator()(const DQRes& res_dq,
const BRes& res_b, const BRes& res_b,
const DQCoords& cached_coords_dq,
const BCoords& cached_coords_b, const BCoords& cached_coords_b,
const ORes& res_o, const ORes& res_o,
const OCoords& cached_coords_o, const OCoords& cached_coords_o,
const OFlags& o_flags, // this should be in sgpr const OFlags& o_flags, // this should be in sgpr
CK_TILE_LDS_ADDR void* smem, CK_TILE_LDS_ADDR void* smem,
index_t n, // loop along n dim index_t n, // loop along n dim
// const ScaleTensor& scale_, const ScaleTensor& scale_,
const YScaleTensor& smq_scale_,
index_t tile_offset_dq, index_t tile_offset_dq,
index_t tile_offset_b, // stride b is fixed to blockKr * blockW, but still can adjust index_t tile_offset_b, // stride b is fixed to blockKr * blockW, but still can adjust
index_t tile_offset_half_b, //splited load alone K in to 2 part index_t tile_offset_half_b, //splited load alone K in to 2 part
...@@ -108,9 +112,9 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x ...@@ -108,9 +112,9 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
const index_t tile_stride_o_bytes = tile_offset_o * sizeof(ODataType); const index_t tile_stride_o_bytes = tile_offset_o * sizeof(ODataType);
const index_t tile_stride_dq_bytes = tile_offset_dq * sizeof(DScaleDataType); const index_t tile_stride_dq_bytes = tile_offset_dq * sizeof(DScaleDataType);
// static_assert(ScaleTensor::size() == 2); static_assert(ScaleTensor::size() == 2);
// float s0 = scale_[number<0>{}]; float s0 = scale_[number<0>{}];
// float s1 = scale_[number<1>{}]; float s1 = scale_[number<1>{}];
index_t loop_cnt = n ; index_t loop_cnt = n ;
...@@ -220,6 +224,8 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x ...@@ -220,6 +224,8 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
// [v_sld_y_os]"v"(sld_y_os), // [v_sld_y_os]"v"(sld_y_os),
// [v_sfl_sld]"v"(sfl_sld), // [v_sfl_sld]"v"(sfl_sld),
// [v_sfl_sst]"v"(sfl_sst), // [v_sfl_sst]"v"(sfl_sst),
[smq_scale0]"s"(smq_scale_[0]),
[smq_scale1]"s"(smq_scale_[1]),
[s_res_dq]"s"(res_dq), [s_res_dq]"s"(res_dq),
[s_res_o0]"s"(res_o[0]), [s_res_o0]"s"(res_o[0]),
[s_res_o1]"s"(res_o[1]), [s_res_o1]"s"(res_o[1]),
...@@ -229,6 +235,7 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x ...@@ -229,6 +235,7 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
// [s_res_b1]"s"(res_b[1]), // [s_res_b1]"s"(res_b[1]),
// [s_res_b2]"s"(res_b[2]), // [s_res_b2]"s"(res_b[2]),
// [s_res_b3]"s"(res_b[3]), // [s_res_b3]"s"(res_b[3]),
[v_os_dq]"v"(static_cast<index_t>(cached_coords_dq * sizeof(DScaleDataType))),
[v_os_o0]"v"(static_cast<index_t>(cached_coords_o[number<0>{}] * sizeof(ODataType))), [v_os_o0]"v"(static_cast<index_t>(cached_coords_o[number<0>{}] * sizeof(ODataType))),
[v_os_o1]"v"(static_cast<index_t>(cached_coords_o[number<1>{}] * sizeof(ODataType))), [v_os_o1]"v"(static_cast<index_t>(cached_coords_o[number<1>{}] * sizeof(ODataType))),
[v_os_o2]"v"(static_cast<index_t>(cached_coords_o[number<2>{}] * sizeof(ODataType))), [v_os_o2]"v"(static_cast<index_t>(cached_coords_o[number<2>{}] * sizeof(ODataType))),
...@@ -293,8 +300,7 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x ...@@ -293,8 +300,7 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
"s55", "s56", "s57", "s58", "s59", "s60", "s61", "s62", "s63", "s55", "s56", "s57", "s58", "s59", "s60", "s61", "s62", "s63",
"s64", "s65", "s66", "s67", "s68", "s69", "s70", "s71", "s72", "s64", "s65", "s66", "s67", "s68", "s69", "s70", "s71", "s72",
"s73", "s74", "s75", "s76", "s77", "s78", "s79", "s80", // s86 as tmp "s73", "s74", "s75", "s76", "s77", "s78", "s79", "s80", // s86 as tmp
"v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v1", "v2", "v3", "v4", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v50", "v51", "v52", "v53", "v54", "v55", "v20", "v21", "v22", "v23", "v24", "v25", "v50", "v51", "v52", "v53", "v54", "v55",
"v56", "v57", "v64", "v56", "v57", "v64",
"v65", "v66", "v67", "v68", "v69", "v70", "v71", "v72", "v73", "v65", "v66", "v67", "v68", "v69", "v70", "v71", "v72", "v73",
...@@ -390,8 +396,8 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x ...@@ -390,8 +396,8 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
[s_tile_os_b_half]"s"(tile_offset_half_b_bytes), [s_tile_os_b_half]"s"(tile_offset_half_b_bytes),
[s_tile_os_b]"s"(tile_stride_b_bytes), [s_tile_os_b]"s"(tile_stride_b_bytes),
[s_tile_os_dq]"s"(tile_stride_dq_bytes), [s_tile_os_dq]"s"(tile_stride_dq_bytes),
// [scale_0]"v"(s0), [scale_0]"v"(s0),
// [scale_1]"v"(s1), [scale_1]"v"(s1),
// [v_nan_lo]"v"(nan_lo), // [v_nan_lo]"v"(nan_lo),
// [v_nan_hi]"v"(nan_hi), // [v_nan_hi]"v"(nan_hi),
[s_execflag_0]"s"(o_flags[number<0>{}]), [s_execflag_0]"s"(o_flags[number<0>{}]),
...@@ -438,8 +444,7 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x ...@@ -438,8 +444,7 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
"s55", "s56", "s57", "s58", "s59", "s60", "s61", "s62", "s63", "s55", "s56", "s57", "s58", "s59", "s60", "s61", "s62", "s63",
"s64", "s65", "s66", "s67", "s68", "s69", "s70", "s71", "s72", "s64", "s65", "s66", "s67", "s68", "s69", "s70", "s71", "s72",
"s73", "s74", "s75", "s76", "s77", "s78", "s79", "s80", // s86 as tmp "s73", "s74", "s75", "s76", "s77", "s78", "s79", "s80", // s86 as tmp
"v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v1", "v2", "v3", "v4", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v50", "v51", "v52", "v53", "v54", "v55", "v20", "v21", "v22", "v23", "v24", "v25", "v50", "v51", "v52", "v53", "v54", "v55",
"v56", "v57", "v64", "v56", "v57", "v64",
"v65", "v66", "v67", "v68", "v69", "v70", "v71", "v72", "v73", "v65", "v66", "v67", "v68", "v69", "v70", "v71", "v72", "v73",
......
...@@ -27,8 +27,12 @@ ...@@ -27,8 +27,12 @@
# define _UK_ATOMIC_ADD_ "global_atomic_pk_add_f16" # define _UK_ATOMIC_ADD_ "global_atomic_pk_add_f16"
#endif #endif
" v_and_b32 v0, 0x3f, v0 \n"
" v_lshrrev_b32 v3, 6, v0 \n"
" v_readfirstlane_b32 s7, v3 \n"
" s_waitcnt vmcnt(24) \n" " s_waitcnt vmcnt(24) \n"
" buffer_load_dwordx4 acc[0:3], %[v_os_b0], %[s_res_d], 0 offen\n" " buffer_load_dwordx4 acc[0:3], %[v_os_b0], %[s_res_d], 0 offen\n"
" buffer_load_dwordx4 acc[4:7], %[v_os_b0], %[s_res_d], 0 offen offset:1024\n"
" v_mul_f32 v54, v128, v128 \n" " v_mul_f32 v54, v128, v128 \n"
" v_mul_f32 v55, v129, v129 \n" " v_mul_f32 v55, v129, v129 \n"
" v_mul_f32 v56, v130, v130 \n" " v_mul_f32 v56, v130, v130 \n"
...@@ -49,7 +53,6 @@ ...@@ -49,7 +53,6 @@
" v_exp_f32 v55, v55 \n" " v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n" " v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n" " v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[4:7], %[v_os_b0], %[s_res_d], 0 offen offset:1024\n"
" v_add_f32 v54, v54, 1.0 \n" " v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n" " v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n" " v_add_f32 v56, v56, 1.0 \n"
...@@ -577,71 +580,71 @@ ...@@ -577,71 +580,71 @@
" v_mul_f32 v189, v189, v55 \n" " v_mul_f32 v189, v189, v55 \n"
" v_mul_f32 v190, v190, v56 \n" " v_mul_f32 v190, v190, v56 \n"
" v_mul_f32 v191, v191, v57 \n" " v_mul_f32 v191, v191, v57 \n"
" v_mul_f32 v128, v18, v128 row_newbcast:0 \n" " v_mul_f32 v128, %[smq_scale0], v128 row_newbcast:0 \n"
" v_mul_f32 v129, v18, v129 row_newbcast:1 \n" " v_mul_f32 v129, %[smq_scale0], v129 row_newbcast:1 \n"
" v_mul_f32 v130, v18, v130 row_newbcast:2 \n" " v_mul_f32 v130, %[smq_scale0], v130 row_newbcast:2 \n"
" v_mul_f32 v131, v18, v131 row_newbcast:3 \n" " v_mul_f32 v131, %[smq_scale0], v131 row_newbcast:3 \n"
" v_mul_f32 v132, v18, v132 row_newbcast:0 \n" " v_mul_f32 v132, %[smq_scale0], v132 row_newbcast:0 \n"
" v_mul_f32 v133, v18, v133 row_newbcast:1 \n" " v_mul_f32 v133, %[smq_scale0], v133 row_newbcast:1 \n"
" v_mul_f32 v134, v18, v134 row_newbcast:2 \n" " v_mul_f32 v134, %[smq_scale0], v134 row_newbcast:2 \n"
" v_mul_f32 v135, v18, v135 row_newbcast:3 \n" " v_mul_f32 v135, %[smq_scale0], v135 row_newbcast:3 \n"
" v_mul_f32 v136, v18, v136 row_newbcast:4 \n" " v_mul_f32 v136, %[smq_scale0], v136 row_newbcast:4 \n"
" v_mul_f32 v137, v18, v137 row_newbcast:5 \n" " v_mul_f32 v137, %[smq_scale0], v137 row_newbcast:5 \n"
" v_mul_f32 v138, v18, v138 row_newbcast:6 \n" " v_mul_f32 v138, %[smq_scale0], v138 row_newbcast:6 \n"
" v_mul_f32 v139, v18, v139 row_newbcast:7 \n" " v_mul_f32 v139, %[smq_scale0], v139 row_newbcast:7 \n"
" v_mul_f32 v140, v18, v140 row_newbcast:4 \n" " v_mul_f32 v140, %[smq_scale0], v140 row_newbcast:4 \n"
" v_mul_f32 v141, v18, v141 row_newbcast:5 \n" " v_mul_f32 v141, %[smq_scale0], v141 row_newbcast:5 \n"
" v_mul_f32 v142, v18, v142 row_newbcast:6 \n" " v_mul_f32 v142, %[smq_scale0], v142 row_newbcast:6 \n"
" v_mul_f32 v143, v18, v143 row_newbcast:7 \n" " v_mul_f32 v143, %[smq_scale0], v143 row_newbcast:7 \n"
" v_mul_f32 v144, v18, v144 row_newbcast:8 \n" " v_mul_f32 v144, %[smq_scale0], v144 row_newbcast:8 \n"
" v_mul_f32 v145, v18, v145 row_newbcast:9 \n" " v_mul_f32 v145, %[smq_scale0], v145 row_newbcast:9 \n"
" v_mul_f32 v146, v18, v146 row_newbcast:10 \n" " v_mul_f32 v146, %[smq_scale0], v146 row_newbcast:10 \n"
" v_mul_f32 v147, v18, v147 row_newbcast:11 \n" " v_mul_f32 v147, %[smq_scale0], v147 row_newbcast:11 \n"
" v_mul_f32 v148, v18, v148 row_newbcast:8 \n" " v_mul_f32 v148, %[smq_scale0], v148 row_newbcast:8 \n"
" v_mul_f32 v149, v18, v149 row_newbcast:9 \n" " v_mul_f32 v149, %[smq_scale0], v149 row_newbcast:9 \n"
" v_mul_f32 v150, v18, v150 row_newbcast:10 \n" " v_mul_f32 v150, %[smq_scale0], v150 row_newbcast:10 \n"
" v_mul_f32 v151, v18, v151 row_newbcast:11 \n" " v_mul_f32 v151, %[smq_scale0], v151 row_newbcast:11 \n"
" v_mul_f32 v152, v18, v152 row_newbcast:12 \n" " v_mul_f32 v152, %[smq_scale0], v152 row_newbcast:12 \n"
" v_mul_f32 v153, v18, v153 row_newbcast:13 \n" " v_mul_f32 v153, %[smq_scale0], v153 row_newbcast:13 \n"
" v_mul_f32 v154, v18, v154 row_newbcast:14 \n" " v_mul_f32 v154, %[smq_scale0], v154 row_newbcast:14 \n"
" v_mul_f32 v155, v18, v155 row_newbcast:15 \n" " v_mul_f32 v155, %[smq_scale0], v155 row_newbcast:15 \n"
" v_mul_f32 v156, v18, v156 row_newbcast:12 \n" " v_mul_f32 v156, %[smq_scale0], v156 row_newbcast:12 \n"
" v_mul_f32 v157, v18, v157 row_newbcast:13 \n" " v_mul_f32 v157, %[smq_scale0], v157 row_newbcast:13 \n"
" v_mul_f32 v158, v18, v158 row_newbcast:14 \n" " v_mul_f32 v158, %[smq_scale0], v158 row_newbcast:14 \n"
" v_mul_f32 v159, v18, v159 row_newbcast:15 \n" " v_mul_f32 v159, %[smq_scale0], v159 row_newbcast:15 \n"
" v_mul_f32 v160, v19, v160 row_newbcast:0 \n" " v_mul_f32 v160, %[smq_scale1], v160 row_newbcast:0 \n"
" v_mul_f32 v161, v19, v161 row_newbcast:1 \n" " v_mul_f32 v161, %[smq_scale1], v161 row_newbcast:1 \n"
" v_mul_f32 v162, v19, v162 row_newbcast:2 \n" " v_mul_f32 v162, %[smq_scale1], v162 row_newbcast:2 \n"
" v_mul_f32 v163, v19, v163 row_newbcast:3 \n" " v_mul_f32 v163, %[smq_scale1], v163 row_newbcast:3 \n"
" v_mul_f32 v164, v19, v164 row_newbcast:0 \n" " v_mul_f32 v164, %[smq_scale1], v164 row_newbcast:0 \n"
" v_mul_f32 v165, v19, v165 row_newbcast:1 \n" " v_mul_f32 v165, %[smq_scale1], v165 row_newbcast:1 \n"
" v_mul_f32 v166, v19, v166 row_newbcast:2 \n" " v_mul_f32 v166, %[smq_scale1], v166 row_newbcast:2 \n"
" v_mul_f32 v167, v19, v167 row_newbcast:3 \n" " v_mul_f32 v167, %[smq_scale1], v167 row_newbcast:3 \n"
" v_mul_f32 v168, v19, v168 row_newbcast:4 \n" " v_mul_f32 v168, %[smq_scale1], v168 row_newbcast:4 \n"
" v_mul_f32 v169, v19, v169 row_newbcast:5 \n" " v_mul_f32 v169, %[smq_scale1], v169 row_newbcast:5 \n"
" v_mul_f32 v170, v19, v170 row_newbcast:6 \n" " v_mul_f32 v170, %[smq_scale1], v170 row_newbcast:6 \n"
" v_mul_f32 v171, v19, v171 row_newbcast:7 \n" " v_mul_f32 v171, %[smq_scale1], v171 row_newbcast:7 \n"
" v_mul_f32 v172, v19, v172 row_newbcast:4 \n" " v_mul_f32 v172, %[smq_scale1], v172 row_newbcast:4 \n"
" v_mul_f32 v173, v19, v173 row_newbcast:5 \n" " v_mul_f32 v173, %[smq_scale1], v173 row_newbcast:5 \n"
" v_mul_f32 v174, v19, v174 row_newbcast:6 \n" " v_mul_f32 v174, %[smq_scale1], v174 row_newbcast:6 \n"
" v_mul_f32 v175, v19, v175 row_newbcast:7 \n" " v_mul_f32 v175, %[smq_scale1], v175 row_newbcast:7 \n"
" v_mul_f32 v176, v19, v176 row_newbcast:8 \n" " v_mul_f32 v176, %[smq_scale1], v176 row_newbcast:8 \n"
" v_mul_f32 v177, v19, v177 row_newbcast:9 \n" " v_mul_f32 v177, %[smq_scale1], v177 row_newbcast:9 \n"
" v_mul_f32 v178, v19, v178 row_newbcast:10 \n" " v_mul_f32 v178, %[smq_scale1], v178 row_newbcast:10 \n"
" v_mul_f32 v179, v19, v179 row_newbcast:11 \n" " v_mul_f32 v179, %[smq_scale1], v179 row_newbcast:11 \n"
" v_mul_f32 v180, v19, v180 row_newbcast:8 \n" " v_mul_f32 v180, %[smq_scale1], v180 row_newbcast:8 \n"
" v_mul_f32 v181, v19, v181 row_newbcast:9 \n" " v_mul_f32 v181, %[smq_scale1], v181 row_newbcast:9 \n"
" v_mul_f32 v182, v19, v182 row_newbcast:10 \n" " v_mul_f32 v182, %[smq_scale1], v182 row_newbcast:10 \n"
" v_mul_f32 v183, v19, v183 row_newbcast:11 \n" " v_mul_f32 v183, %[smq_scale1], v183 row_newbcast:11 \n"
" v_mul_f32 v184, v19, v184 row_newbcast:12 \n" " v_mul_f32 v184, %[smq_scale1], v184 row_newbcast:12 \n"
" v_mul_f32 v185, v19, v185 row_newbcast:13 \n" " v_mul_f32 v185, %[smq_scale1], v185 row_newbcast:13 \n"
" v_mul_f32 v186, v19, v186 row_newbcast:14 \n" " v_mul_f32 v186, %[smq_scale1], v186 row_newbcast:14 \n"
" v_mul_f32 v187, v19, v187 row_newbcast:15 \n" " v_mul_f32 v187, %[smq_scale1], v187 row_newbcast:15 \n"
" v_mul_f32 v188, v19, v188 row_newbcast:12 \n" " v_mul_f32 v188, %[smq_scale1], v188 row_newbcast:12 \n"
" v_mul_f32 v189, v19, v189 row_newbcast:13 \n" " v_mul_f32 v189, %[smq_scale1], v189 row_newbcast:13 \n"
" v_mul_f32 v190, v19, v190 row_newbcast:14 \n" " v_mul_f32 v190, %[smq_scale1], v190 row_newbcast:14 \n"
" v_mul_f32 v191, v19, v191 row_newbcast:15 \n" " v_mul_f32 v191, %[smq_scale1], v191 row_newbcast:15 \n"
" buffer_load_dword v12, v5, %[s_res_dq], 0 offen \n" " buffer_load_dword v12, %[v_os_dq], %[s_res_dq], 0 offen \n"
" v_mov_b32 v22, 0x358637bd \n" " v_mov_b32 v22, 0x358637bd \n"
" v_mov_b32 v23, 0x358637bd \n" " v_mov_b32 v23, 0x358637bd \n"
" v_max3_f32 v22, abs(v128), abs(v129), v22 \n" " v_max3_f32 v22, abs(v128), abs(v129), v22 \n"
...@@ -934,9 +937,42 @@ ...@@ -934,9 +937,42 @@
" v_lshlrev_b32 v54, 1, v54 \n" " v_lshlrev_b32 v54, 1, v54 \n"
" v_add_u32 v55, v54, v55 \n" " v_add_u32 v55, v54, v55 \n"
" v_lshlrev_b32 v54, 2, v55 \n" " v_lshlrev_b32 v54, 2, v55 \n"
" ds_read_b64 v[128:129], v54 offset:18688 \n"
" ds_read_b64 v[130:131], v54 offset:18816 \n"
" ds_read_b64 v[132:133], v54 offset:19712 \n"
" ds_read_b64 v[134:135], v54 offset:19840 \n"
" ds_read_b64 v[136:137], v54 offset:20736 \n"
" ds_read_b64 v[138:139], v54 offset:20864 \n"
" ds_read_b64 v[140:141], v54 offset:21760 \n"
" ds_read_b64 v[142:143], v54 offset:21888 \n"
" ds_read_b64 v[144:145], v54 offset:22784 \n"
" ds_read_b64 v[146:147], v54 offset:22912 \n"
" ds_read_b64 v[148:149], v54 offset:23808 \n"
" ds_read_b64 v[150:151], v54 offset:23936 \n"
" ds_read_b64 v[152:153], v54 offset:24832 \n"
" ds_read_b64 v[154:155], v54 offset:24960 \n"
" ds_read_b64 v[156:157], v54 offset:25856 \n"
" ds_read_b64 v[158:159], v54 offset:25984 \n"
" ds_read_b64 v[160:161], v54 offset:26880 \n"
" ds_read_b64 v[162:163], v54 offset:27008 \n"
" ds_read_b64 v[164:165], v54 offset:27904 \n"
" ds_read_b64 v[166:167], v54 offset:28032 \n"
" ds_read_b64 v[168:169], v54 offset:28928 \n"
" ds_read_b64 v[170:171], v54 offset:29056 \n"
" ds_read_b64 v[172:173], v54 offset:29952 \n"
" ds_read_b64 v[174:175], v54 offset:30080 \n"
" ds_read_b64 v[176:177], v54 offset:30976 \n"
" ds_read_b64 v[178:179], v54 offset:31104 \n"
" ds_read_b64 v[180:181], v54 offset:32000 \n"
" ds_read_b64 v[182:183], v54 offset:32128 \n"
" ds_read_b64 v[184:185], v54 offset:33024 \n"
" ds_read_b64 v[186:187], v54 offset:33152 \n"
" ds_read_b64 v[188:189], v54 offset:34048 \n"
" ds_read_b64 v[190:191], v54 offset:34176 \n"
#undef _UK_MFMA_ #undef _UK_MFMA_
#undef _UK_PK_CVT_ #undef _UK_PK_CVT_
#undef _UK_ATOMIC_ADD_ #undef _UK_ATOMIC_ADD_
...@@ -5,36 +5,20 @@ ...@@ -5,36 +5,20 @@
#if CK_TILE_FLATMM_UK_MFMA == CK_TILE_FLATMM_UK_MFMA_INT8 #if CK_TILE_FLATMM_UK_MFMA == CK_TILE_FLATMM_UK_MFMA_INT8
# define _UK_MFMA_ "v_mfma_i32_16x16x32_i8" # define _UK_MFMA_ "v_mfma_i32_16x16x32_i8"
#endif #endif
# define _DEQUAN_CVT_(a0,a1,a2,a3, b, c) \ # define _DEQUAN_CVT_(a0,a1,a2,a3, xq, gq,brd0,brd1,brd2,brd3) \
" v_cvt_f32_i32 a0, a0 \n" \ " v_cvt_f32_i32 " a0 ", " a0 " \n" \
" v_cvt_f32_i32 a1, a1 \n" \ " v_cvt_f32_i32 " a1 ", " a1 " \n" \
" v_cvt_f32_i32 a2, a2 \n" \ " v_cvt_f32_i32 " a2 ", " a2 " \n" \
" v_cvt_f32_i32 a3, a3 \n" \ " v_cvt_f32_i32 " a3 ", " a3" \n" \
" v_mul_f32 a0, v15, a0 \n" \ " v_mul_f32 " a0 ", " xq ", " a0 " \n" \
" v_mul_f32 a1, v15, a1 \n" \ " v_mul_f32 " a1 ", " xq ", " a1 " \n" \
" v_mul_f32 a2, v15, a2 \n" \ " v_mul_f32 " a2 ", " xq ", " a2 " \n" \
" v_mul_f32 a3, v15, a3 \n" \ " v_mul_f32 " a3 ", " xq ", " a3 " \n" \
" v_mul_f32 a0, v17, a0 row_newbcast:12 \n" \ " v_mul_f32 " a0 ", " gq ", " a0 " row_newbcast:" brd0 " \n" \
" v_mul_f32 a1, v17, a1 row_newbcast:13 \n" \ " v_mul_f32 " a1 ", " gq ", " a1 " row_newbcast:" brd1 " \n" \
" v_mul_f32 a2, v17, a2 row_newbcast:14 \n" \ " v_mul_f32 " a2 ", " gq ", " a2 " row_newbcast: " brd2 " \n" \
" v_mul_f32 a3, v17, a3 row_newbcast:15 \n" \ " v_mul_f32 " a3 ", " gq ", " a3 " row_newbcast:" brd3 " \n"
";---------------------------------------------- \n"
" v_lshrrev_b32 v54, 4, v0 \n"
" v_lshlrev_b32 v55, 2, v54 \n"
" v_and_b32 v54, 15, v0 \n"
" v_lshrrev_b32 v56, 2, v54 \n"
" v_lshlrev_b32 v56, 6, v56 \n"
" v_add_u32 v55, v56, v55 \n"
" v_and_b32 v54, 3, v0 \n"
" v_add_u32 v55, v54, v55 \n"
" v_lshlrev_b32 v10, 2, v55 \n"
" v_add_u32 v11, 0x00000400, v10 \n"
" s_mul_i32 s60, %[s_wave_id], 16 \n"
" s_mul_i32 s60, s60, 4 \n"
" v_add_u32 v10, s60, v10 \n"
" v_add_u32 v11, s60, v11 \n"
" v_mov_b32 v5, v10 \n"
";---------------------------------------------- \n" ";---------------------------------------------- \n"
" s_mov_b32 s57, 0x00000100 \n" " s_mov_b32 s57, 0x00000100 \n"
" s_mov_b32 s58, 0x00001000 \n" " s_mov_b32 s58, 0x00001000 \n"
...@@ -53,27 +37,22 @@ ...@@ -53,27 +37,22 @@
" v_mov_b32 v52, 0x7fff0000 \n" " v_mov_b32 v52, 0x7fff0000 \n"
" v_mov_b32 v53, 0x00007fff \n" " v_mov_b32 v53, 0x00007fff \n"
" s_waitcnt 0x0000 \n" " s_waitcnt 0x0000 \n"
";---------------------------------------------- \n"
" v_lshrrev_b32 v54, 24, %[v_token_id0] \n"
" v_mul_i32_i24 v54, s66, v54 \n"
" v_and_b32 v55, 0x00ffffff, %[v_token_id0] \n"
" v_add_u32 %[v_token_id0], v54, v55 \n"
" v_lshrrev_b32 v54, 24, %[v_token_id1] \n"
" v_mul_i32_i24 v54, s66, v54 \n"
" v_and_b32 v55, 0x00ffffff, %[v_token_id1] \n"
" v_add_u32 %[v_token_id1], v54, v55 \n"
" v_lshlrev_b32 %[v_token_id0], 2, %[v_token_id0] \n"
" v_lshlrev_b32 %[v_token_id1], 2, %[v_token_id1] \n"
" buffer_load_dword v14, %[v_token_id0], %[s_res_aq], 0 offen \n"
" buffer_load_dword v15, %[v_token_id1], %[s_res_aq], 0 offen \n"
" buffer_load_dword v16, v10, %[s_res_gq], 0 offen \n"
" buffer_load_dword v17, v11, %[s_res_gq], 0 offen \n"
" buffer_load_dword v18, v10, %[s_res_smq], 0 offen \n"
" buffer_load_dword v19, v11, %[s_res_smq], 0 offen \n"
" buffer_load_dword v20, v8, s[40:43], 0 offen \n"
" buffer_load_dword v21, v9, s[40:43], 0 offen \n"
" s_mov_b32 s80, 0 \n" " s_mov_b32 s80, 0 \n"
" v_lshrrev_b32 v54, 4, v0 \n"
" v_mul_i32_i24 v3, 34, v54 \n"
" v_and_b32 v54, 15, v0 \n"
" v_mul_i32_i24 v55, 2, v54 \n"
" v_add_u32 v3, v55, v3 \n"
" s_mul_i32 s60, s7, 0x00000088 \n"
" v_add_u32 v3, s60, v3 \n"
" v_lshlrev_b32 v3, 2, v3 \n"
" v_lshrrev_b32 v54, 1, v0 \n"
" v_mul_i32_i24 v4, 34, v54 \n"
" v_and_b32 v55, 1, v0 \n"
" v_add_u32 v4, v55, v4 \n"
" s_mul_i32 s60, s7, 2 \n"
" v_add_u32 v4, s60, v4 \n"
" v_lshlrev_b32 v4, 2, v4 \n"
";---------------------------------------------- \n" ";---------------------------------------------- \n"
"; -- prefetch A0\n" "; -- prefetch A0\n"
"s_add_u32 m0, 0, %[s_m0_init] \n" "s_add_u32 m0, 0, %[s_m0_init] \n"
...@@ -570,198 +549,23 @@ ...@@ -570,198 +549,23 @@
" s_branch label_start \n" " s_branch label_start \n"
" label_end : \n" " label_end : \n"
";---------------------------------------------- \n" ";---------------------------------------------- \n"
" v_cvt_f32_i32 v128, v128 \n" _DEQUAN_CVT_("%[c0]","%[c1]","%[c2]","%[c3]","%[a_scale0]"," %[gq_scale0]","0","1","2","3")
" v_cvt_f32_i32 v129, v129 \n" _DEQUAN_CVT_("%[c4]","%[c5]","%[c6]","%[c7]","%[a_scale1]"," %[gq_scale0]","0","1","2","3")
" v_cvt_f32_i32 v130, v130 \n" _DEQUAN_CVT_("%[c8]","%[c9]","%[c10]","%[c11]","%[a_scale0]"," %[gq_scale0]","4","5","6","7")
" v_cvt_f32_i32 v131, v131 \n" _DEQUAN_CVT_("%[c12]","%[c13]","%[c14]","%[c15]","%[a_scale1]"," %[gq_scale0]","4","5","6","7")
" v_mul_f32 v128, v14, v128 \n" _DEQUAN_CVT_("%[c16]","%[c17]","%[c18]","%[c19]","%[a_scale0]"," %[gq_scale0]","8","9","10","11")
" v_mul_f32 v129, v14, v129 \n" _DEQUAN_CVT_("%[c20]","%[c21]","%[c22]","%[c23]","%[a_scale1]"," %[gq_scale0]","8","9","10","11")
" v_mul_f32 v130, v14, v130 \n" _DEQUAN_CVT_("%[c24]","%[c25]","%[c26]","%[c27]","%[a_scale0]"," %[gq_scale0]","12","13","14","15")
" v_mul_f32 v131, v14, v131 \n" _DEQUAN_CVT_("%[c28]","%[c29]","%[c30]","%[c31]","%[a_scale1]"," %[gq_scale0]","12","13","14","15")
" v_mul_f32 v128, v16, v128 row_newbcast:0 \n" _DEQUAN_CVT_("%[c32]","%[c33]","%[c34]","%[c35]","%[a_scale0]"," %[gq_scale1]","0","1","2","3")
" v_mul_f32 v129, v16, v129 row_newbcast:1 \n" _DEQUAN_CVT_("%[c36]","%[c37]","%[c38]","%[c39]","%[a_scale1]"," %[gq_scale1]","0","1","2","3")
" v_mul_f32 v130, v16, v130 row_newbcast:2 \n" _DEQUAN_CVT_("%[c40]","%[c41]","%[c42]","%[c43]","%[a_scale0]"," %[gq_scale1]","4","5","6","7")
" v_mul_f32 v131, v16, v131 row_newbcast:3 \n" _DEQUAN_CVT_("%[c44]","%[c45]","%[c46]","%[c47]","%[a_scale1]"," %[gq_scale1]","4","5","6","7")
" v_cvt_f32_i32 v132, v132 \n" _DEQUAN_CVT_("%[c48]","%[c49]","%[c50]","%[c51]","%[a_scale0]"," %[gq_scale1]","8","9","10","11")
" v_cvt_f32_i32 v133, v133 \n" _DEQUAN_CVT_("%[c52]","%[c53]","%[c54]","%[c55]","%[a_scale1]"," %[gq_scale1]","8","9","10","11")
" v_cvt_f32_i32 v134, v134 \n" _DEQUAN_CVT_("%[c56]","%[c57]","%[c58]","%[c59]","%[a_scale0]"," %[gq_scale1]","12","13","14","15")
" v_cvt_f32_i32 v135, v135 \n" _DEQUAN_CVT_("%[c60]","%[c61]","%[c62]","%[c63]","%[a_scale1]"," %[gq_scale1]","12","13","14","15")
" v_mul_f32 v132, v15, v132 \n"
" v_mul_f32 v133, v15, v133 \n"
" v_mul_f32 v134, v15, v134 \n"
" v_mul_f32 v135, v15, v135 \n"
" v_mul_f32 v132, v16, v132 row_newbcast:0 \n"
" v_mul_f32 v133, v16, v133 row_newbcast:1 \n"
" v_mul_f32 v134, v16, v134 row_newbcast:2 \n"
" v_mul_f32 v135, v16, v135 row_newbcast:3 \n"
" v_cvt_f32_i32 v136, v136 \n"
" v_cvt_f32_i32 v137, v137 \n"
" v_cvt_f32_i32 v138, v138 \n"
" v_cvt_f32_i32 v139, v139 \n"
" v_mul_f32 v136, v14, v136 \n"
" v_mul_f32 v137, v14, v137 \n"
" v_mul_f32 v138, v14, v138 \n"
" v_mul_f32 v139, v14, v139 \n"
" v_mul_f32 v136, v16, v136 row_newbcast:4 \n"
" v_mul_f32 v137, v16, v137 row_newbcast:5 \n"
" v_mul_f32 v138, v16, v138 row_newbcast:6 \n"
" v_mul_f32 v139, v16, v139 row_newbcast:7 \n"
" v_cvt_f32_i32 v140, v140 \n"
" v_cvt_f32_i32 v141, v141 \n"
" v_cvt_f32_i32 v142, v142 \n"
" v_cvt_f32_i32 v143, v143 \n"
" v_mul_f32 v140, v15, v140 \n"
" v_mul_f32 v141, v15, v141 \n"
" v_mul_f32 v142, v15, v142 \n"
" v_mul_f32 v143, v15, v143 \n"
" v_mul_f32 v140, v16, v140 row_newbcast:4 \n"
" v_mul_f32 v141, v16, v141 row_newbcast:5 \n"
" v_mul_f32 v142, v16, v142 row_newbcast:6 \n"
" v_mul_f32 v143, v16, v143 row_newbcast:7 \n"
" v_cvt_f32_i32 v144, v144 \n"
" v_cvt_f32_i32 v145, v145 \n"
" v_cvt_f32_i32 v146, v146 \n"
" v_cvt_f32_i32 v147, v147 \n"
" v_mul_f32 v144, v14, v144 \n"
" v_mul_f32 v145, v14, v145 \n"
" v_mul_f32 v146, v14, v146 \n"
" v_mul_f32 v147, v14, v147 \n"
" v_mul_f32 v144, v16, v144 row_newbcast:8 \n"
" v_mul_f32 v145, v16, v145 row_newbcast:9 \n"
" v_mul_f32 v146, v16, v146 row_newbcast:10 \n"
" v_mul_f32 v147, v16, v147 row_newbcast:11 \n"
" v_cvt_f32_i32 v148, v148 \n"
" v_cvt_f32_i32 v149, v149 \n"
" v_cvt_f32_i32 v150, v150 \n"
" v_cvt_f32_i32 v151, v151 \n"
" v_mul_f32 v148, v15, v148 \n"
" v_mul_f32 v149, v15, v149 \n"
" v_mul_f32 v150, v15, v150 \n"
" v_mul_f32 v151, v15, v151 \n"
" v_mul_f32 v148, v16, v148 row_newbcast:8 \n"
" v_mul_f32 v149, v16, v149 row_newbcast:9 \n"
" v_mul_f32 v150, v16, v150 row_newbcast:10 \n"
" v_mul_f32 v151, v16, v151 row_newbcast:11 \n"
" v_cvt_f32_i32 v152, v152 \n"
" v_cvt_f32_i32 v153, v153 \n"
" v_cvt_f32_i32 v154, v154 \n"
" v_cvt_f32_i32 v155, v155 \n"
" v_mul_f32 v152, v14, v152 \n"
" v_mul_f32 v153, v14, v153 \n"
" v_mul_f32 v154, v14, v154 \n"
" v_mul_f32 v155, v14, v155 \n"
" v_mul_f32 v152, v16, v152 row_newbcast:12 \n"
" v_mul_f32 v153, v16, v153 row_newbcast:13 \n"
" v_mul_f32 v154, v16, v154 row_newbcast:14 \n"
" v_mul_f32 v155, v16, v155 row_newbcast:15 \n"
" v_cvt_f32_i32 v156, v156 \n"
" v_cvt_f32_i32 v157, v157 \n"
" v_cvt_f32_i32 v158, v158 \n"
" v_cvt_f32_i32 v159, v159 \n"
" v_mul_f32 v156, v15, v156 \n"
" v_mul_f32 v157, v15, v157 \n"
" v_mul_f32 v158, v15, v158 \n"
" v_mul_f32 v159, v15, v159 \n"
" v_mul_f32 v156, v16, v156 row_newbcast:12 \n"
" v_mul_f32 v157, v16, v157 row_newbcast:13 \n"
" v_mul_f32 v158, v16, v158 row_newbcast:14 \n"
" v_mul_f32 v159, v16, v159 row_newbcast:15 \n"
" v_cvt_f32_i32 v160, v160 \n"
" v_cvt_f32_i32 v161, v161 \n"
" v_cvt_f32_i32 v162, v162 \n"
" v_cvt_f32_i32 v163, v163 \n"
" v_mul_f32 v160, v14, v160 \n"
" v_mul_f32 v161, v14, v161 \n"
" v_mul_f32 v162, v14, v162 \n"
" v_mul_f32 v163, v14, v163 \n"
" v_mul_f32 v160, v17, v160 row_newbcast:0 \n"
" v_mul_f32 v161, v17, v161 row_newbcast:1 \n"
" v_mul_f32 v162, v17, v162 row_newbcast:2 \n"
" v_mul_f32 v163, v17, v163 row_newbcast:3 \n"
" v_cvt_f32_i32 v164, v164 \n"
" v_cvt_f32_i32 v165, v165 \n"
" v_cvt_f32_i32 v166, v166 \n"
" v_cvt_f32_i32 v167, v167 \n"
" v_mul_f32 v164, v15, v164 \n"
" v_mul_f32 v165, v15, v165 \n"
" v_mul_f32 v166, v15, v166 \n"
" v_mul_f32 v167, v15, v167 \n"
" v_mul_f32 v164, v17, v164 row_newbcast:0 \n"
" v_mul_f32 v165, v17, v165 row_newbcast:1 \n"
" v_mul_f32 v166, v17, v166 row_newbcast:2 \n"
" v_mul_f32 v167, v17, v167 row_newbcast:3 \n"
" v_cvt_f32_i32 v168, v168 \n"
" v_cvt_f32_i32 v169, v169 \n"
" v_cvt_f32_i32 v170, v170 \n"
" v_cvt_f32_i32 v171, v171 \n"
" v_mul_f32 v168, v14, v168 \n"
" v_mul_f32 v169, v14, v169 \n"
" v_mul_f32 v170, v14, v170 \n"
" v_mul_f32 v171, v14, v171 \n"
" v_mul_f32 v168, v17, v168 row_newbcast:4 \n"
" v_mul_f32 v169, v17, v169 row_newbcast:5 \n"
" v_mul_f32 v170, v17, v170 row_newbcast:6 \n"
" v_mul_f32 v171, v17, v171 row_newbcast:7 \n"
" v_cvt_f32_i32 v172, v172 \n"
" v_cvt_f32_i32 v173, v173 \n"
" v_cvt_f32_i32 v174, v174 \n"
" v_cvt_f32_i32 v175, v175 \n"
" v_mul_f32 v172, v15, v172 \n"
" v_mul_f32 v173, v15, v173 \n"
" v_mul_f32 v174, v15, v174 \n"
" v_mul_f32 v175, v15, v175 \n"
" v_mul_f32 v172, v17, v172 row_newbcast:4 \n"
" v_mul_f32 v173, v17, v173 row_newbcast:5 \n"
" v_mul_f32 v174, v17, v174 row_newbcast:6 \n"
" v_mul_f32 v175, v17, v175 row_newbcast:7 \n"
" v_cvt_f32_i32 v176, v176 \n"
" v_cvt_f32_i32 v177, v177 \n"
" v_cvt_f32_i32 v178, v178 \n"
" v_cvt_f32_i32 v179, v179 \n"
" v_mul_f32 v176, v14, v176 \n"
" v_mul_f32 v177, v14, v177 \n"
" v_mul_f32 v178, v14, v178 \n"
" v_mul_f32 v179, v14, v179 \n"
" v_mul_f32 v176, v17, v176 row_newbcast:8 \n"
" v_mul_f32 v177, v17, v177 row_newbcast:9 \n"
" v_mul_f32 v178, v17, v178 row_newbcast:10 \n"
" v_mul_f32 v179, v17, v179 row_newbcast:11 \n"
" v_cvt_f32_i32 v180, v180 \n"
" v_cvt_f32_i32 v181, v181 \n"
" v_cvt_f32_i32 v182, v182 \n"
" v_cvt_f32_i32 v183, v183 \n"
" v_mul_f32 v180, v15, v180 \n"
" v_mul_f32 v181, v15, v181 \n"
" v_mul_f32 v182, v15, v182 \n"
" v_mul_f32 v183, v15, v183 \n"
" v_mul_f32 v180, v17, v180 row_newbcast:8 \n"
" v_mul_f32 v181, v17, v181 row_newbcast:9 \n"
" v_mul_f32 v182, v17, v182 row_newbcast:10 \n"
" v_mul_f32 v183, v17, v183 row_newbcast:11 \n"
" v_cvt_f32_i32 v184, v184 \n"
" v_cvt_f32_i32 v185, v185 \n"
" v_cvt_f32_i32 v186, v186 \n"
" v_cvt_f32_i32 v187, v187 \n"
" v_mul_f32 v184, v14, v184 \n"
" v_mul_f32 v185, v14, v185 \n"
" v_mul_f32 v186, v14, v186 \n"
" v_mul_f32 v187, v14, v187 \n"
" v_mul_f32 v184, v17, v184 row_newbcast:12 \n"
" v_mul_f32 v185, v17, v185 row_newbcast:13 \n"
" v_mul_f32 v186, v17, v186 row_newbcast:14 \n"
" v_mul_f32 v187, v17, v187 row_newbcast:15 \n"
" v_cvt_f32_i32 v188, v188 \n"
" v_cvt_f32_i32 v189, v189 \n"
" v_cvt_f32_i32 v190, v190 \n"
" v_cvt_f32_i32 v191, v191 \n"
" v_mul_f32 v188, v15, v188 \n"
" v_mul_f32 v189, v15, v189 \n"
" v_mul_f32 v190, v15, v190 \n"
" v_mul_f32 v191, v15, v191 \n"
" v_mul_f32 v188, v17, v188 row_newbcast:12 \n"
" v_mul_f32 v189, v17, v189 row_newbcast:13 \n"
" v_mul_f32 v190, v17, v190 row_newbcast:14 \n"
" v_mul_f32 v191, v17, v191 row_newbcast:15 \n"
#undef _UK_MFMA_ #undef _UK_MFMA_
#undef _DEQUAN_CVT_ #undef _DEQUAN_CVT_
...@@ -186,6 +186,50 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8 ...@@ -186,6 +186,50 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
return coords; return coords;
} }
// TODO: this row id is before shuffle atomic, need use acc distribution
//this calculation shared by G and SMQ
/// Builds the per-thread element offsets used to gather the G-scale and
/// smooth-quant (SMQ) scale values; the same offsets serve both scale arrays.
///
/// @param base_offset  Offset added to every returned coordinate.
/// @return array<index_t, 2>; entry i equals
///         base_offset + (threadIdx.x / MLanes) * 4
///                     + (threadIdx.x & 0xffff) / 4 * 64
///                     + (threadIdx.x & 0x3)
///                     + i * 256.
CK_TILE_DEVICE auto GetColCoords_GQSMQ(index_t base_offset)
{
    constexpr index_t MLanes   = BlockShape::Warp_M1;
    constexpr index_t Repeat_N = 2; // different: this load is partitioned along N

    const index_t tid = static_cast<index_t>(threadIdx.x);

    // Position inside a group of four consecutive threads. The original body
    // referenced `q_id` without declaring it (its definition was commented out
    // and misspelled as `q_is`).
    const index_t q_id = tid & 0x3;

    // NOTE(review): the 0xffff mask leaves any realistic threadIdx.x unchanged;
    // it is kept exactly as written -- confirm whether a narrower mask was intended.
    const index_t thread_offset = (tid / MLanes) * 4 + (tid & 0xffff) / 4 * 64 + q_id;

    array<index_t, Repeat_N> coords;
    // The original body added an undeclared `base_coord`; the parameter is `base_offset`.
    static_for<0, Repeat_N, 1>{}(
        [&](auto i) { coords.at(i) = base_offset + thread_offset + i * 256; });
    return coords;
}
//this calculation shared by G and SMQ
/// Gathers one G-scale value per coordinate.
///
/// @tparam COL_IDS     Indexable coordinate container exposing a compile-time size().
///                     Declared as a template parameter here because no such type
///                     name is otherwise declared for this function.
/// @param coords       Element offsets into g_scale_ptr.
/// @param g_scale_ptr  Base pointer of the scale values.
/// @return array<GScaleDataType, coords.size()> whose entry i is g_scale_ptr[coords[i]].
template <typename COL_IDS>
CK_TILE_DEVICE auto GetGQScale(const COL_IDS coords,
                               const GScaleDataType* g_scale_ptr)
{
    constexpr index_t n_size = coords.size();

    array<GScaleDataType, n_size> g_scale_value;
    static_for<0, n_size, 1>{}(
        [&](auto i) { g_scale_value.at(i) = g_scale_ptr[coords[i]]; });
    return g_scale_value;
}
/// Gathers one smooth-quant (SMQ) scale value per coordinate.
///
/// @tparam COL_IDS     Indexable coordinate container exposing a compile-time size().
///                     Declared as a template parameter here because no such type
///                     name is otherwise declared for this function.
/// @param coords       Element offsets into y_scale_ptr.
/// @param y_scale_ptr  Base pointer of the smooth-scale values.
/// @return array<YSmoothScaleDataType, coords.size()> whose entry i is y_scale_ptr[coords[i]].
template <typename COL_IDS>
CK_TILE_DEVICE auto GetSMQScale(const COL_IDS coords,
                                const YSmoothScaleDataType* y_scale_ptr)
{
    constexpr index_t n_size = coords.size();

    array<YSmoothScaleDataType, n_size> y_scale_value;
    static_for<0, n_size, 1>{}(
        [&](auto i) { y_scale_value.at(i) = y_scale_ptr[coords[i]]; });
    return y_scale_value;
}
template <typename Karg> template <typename Karg>
CK_TILE_DEVICE auto operator()(const Karg& kargs, CK_TILE_DEVICE auto operator()(const Karg& kargs,
...@@ -230,12 +274,6 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8 ...@@ -230,12 +274,6 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
return (row_ids_a[i]) &0xffffff; return (row_ids_a[i]) &0xffffff;
}, },
number<row_ids_a.size()>{}); number<row_ids_a.size()>{});
// auto token_id_mma = generate_tuple(
// [&](auto i) {
// return (row_ids_a_mma[i]) &0xffffff;
// },
// number<row_ids_a_mma.size()>{});
//addr in fact
auto a_coords = generate_tuple( auto a_coords = generate_tuple(
[&](auto i) { [&](auto i) {
return ((row_ids_a[i])&0xffffff) * kargs.stride_token + return ((row_ids_a[i])&0xffffff) * kargs.stride_token +
...@@ -306,7 +344,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8 ...@@ -306,7 +344,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
auto smq_win = [&]() { auto smq_win = [&]() {
const YSmoothScaleDataType* smq_ptr = reinterpret_cast<const YSmoothScaleDataType*>(kargs.y_smooth_scale_ptr) + const YSmoothScaleDataType* smq_ptr = reinterpret_cast<const YSmoothScaleDataType*>(kargs.y_smooth_scale_ptr) +
static_cast<long_index_t>(expert_id) * smq_scale_expert_stride_0 + static_cast<long_index_t>(expert_id) * smq_scale_expert_stride_0 +
intermediate_tile_id * BlockShape::Block_N0; intermediate_tile_id * BlockShape::Block_K1;
// const GDataType* g_ptr = reinterpret_cast<const GScaleDataType*>(kargs.g_scale_ptr);//remember to add expert id for inline // const GDataType* g_ptr = reinterpret_cast<const GScaleDataType*>(kargs.g_scale_ptr);//remember to add expert id for inline
auto smq_view_ = make_naive_tensor_view<address_space_enum::global>( auto smq_view_ = make_naive_tensor_view<address_space_enum::global>(
smq_ptr, smq_ptr,
...@@ -346,15 +384,15 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8 ...@@ -346,15 +384,15 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
auto d_res = d_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_; auto d_res = d_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
//////gq //////gq
auto dq_win = [&]() { auto dq_win = [&]() {
const DScaleDataType* g_ptr = reinterpret_cast<const DScaleDataType*>(kargs.d_scale_ptr) + const DScaleDataType* dq_ptr = reinterpret_cast<const DScaleDataType*>(kargs.d_scale_ptr) +
static_cast<long_index_t>(expert_id) * d_scale_expert_stride_1; static_cast<long_index_t>(expert_id) * d_scale_expert_stride_1;
// const GDataType* g_ptr = reinterpret_cast<const GScaleDataType*>(kargs.d_scale_ptr)//remember to add expert_id as expert_idx // const GDataType* g_ptr = reinterpret_cast<const GScaleDataType*>(kargs.d_scale_ptr)//remember to add expert_id as expert_idx
auto g_view_ = make_naive_tensor_view_packed<address_space_enum::global>( auto dq_view_ = make_naive_tensor_view_packed<address_space_enum::global>(
g_ptr, dq_ptr,
make_tuple(kargs.hidden_size), make_tuple(kargs.hidden_size),
number<1>{}); number<1>{});
return g_view_; return dq_view_;
}(); }();
auto dq_res = dq_win.get_buffer_view().cached_buf_res_; auto dq_res = dq_win.get_buffer_view().cached_buf_res_;
...@@ -400,15 +438,6 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8 ...@@ -400,15 +438,6 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
generate_tuple([&](auto i) { return cmp_lt_to_exec(token_id[i], kargs.num_tokens); }, generate_tuple([&](auto i) { return cmp_lt_to_exec(token_id[i], kargs.num_tokens); },
number<row_ids_a.size()>{}); number<row_ids_a.size()>{});
// auto bridge_sst_win = [&]() {
// constexpr auto desc_ = Policy::template MakeBridgeLdsStoreForUKDesc<Problem>();
// constexpr auto dist_ = Policy::template GetUK_0<Problem>().MakeCBlockDist();
// return make_tile_window_linear(make_tensor_view<address_space_enum::lds>(
// reinterpret_cast<YDataType*>(smem), desc_),
// desc_.get_lengths(),
// {0, 0},
// dist_);
// }();
auto o_res = auto o_res =
make_wave_buffer_resource(reinterpret_cast<const ODataType*>(kargs.o_ptr), make_wave_buffer_resource(reinterpret_cast<const ODataType*>(kargs.o_ptr),
kargs.num_tokens * kargs.stride_token * sizeof(ODataType)); kargs.num_tokens * kargs.stride_token * sizeof(ODataType));
...@@ -417,16 +446,17 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8 ...@@ -417,16 +446,17 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
auto w_scale = GetWeightScale( auto w_scale = GetWeightScale(
row_coords_o, reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr)); row_coords_o, reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr));
auto a_scale = GetAScale( auto a_scale = GetAScale(
row_coords_o, reinterpret_cast<const TopkWeightDataType*>(kargs.a_scale_ptr)); row_ids_a_mma, reinterpret_cast<const AScaleDataType*>(kargs.a_scale_ptr));
auto gqsmq_coords = GetColCoords_GQSMQ(intermediated_tile_id * BlockShape::Block_K1);
auto dq_coords = gqsmq_coords[0];//only one for this tiling
auto gq_scale = GetGQScale(
gqsmq_coords, reinterpret_cast<const GScaleDataType*>(kargs.g_scale_ptr + static_cast<long_index_t>(expert_id) * shared_intermediate_size_0));
auto smq_scale = GetSMQScale(
gqsmq_coords, reinterpret_cast<const YSmoothScaleDataType*>(kargs.y_smooth_scale_ptr + static_cast<long_index_t>(expert_id) * shared_intermediate_size_0));
auto uk_0 = Policy::template GetUK_0<Problem>(); auto uk_0 = Policy::template GetUK_0<Problem>();
// auto acc_0= uk_0( // auto acc_0= uk_0(
uk_0( uk_0( a_scale,
row_ids_a_mma,//fake token id, 2D index for X scale gq_scale,
a_scale,
dq_res,
gq_res,
smq_res,
a_res, a_res,
a_coords, a_coords,
g_res, g_res,
...@@ -457,6 +487,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8 ...@@ -457,6 +487,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
auto uk_1 = Policy::template GetUK_1<Problem>(); auto uk_1 = Policy::template GetUK_1<Problem>();
uk_1(dq_res, uk_1(dq_res,
d_res, d_res,
dq_coords,
d_coords, d_coords,
o_res, o_res,
o_coords, o_coords,
...@@ -464,6 +495,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8 ...@@ -464,6 +495,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
smem, smem,
kargs.hidden_size, // total n number kargs.hidden_size, // total n number
w_scale, w_scale,
smq_scale,
BlockShape::Block_N1, BlockShape::Block_N1,
shared_intermediate_size_1 * BlockShape::Block_N1 - kr_1 * BlockShape::Block_W1, // along N shared_intermediate_size_1 * BlockShape::Block_N1 - kr_1 * BlockShape::Block_W1, // along N
kr_1 * BlockShape::Block_W1, kr_1 * BlockShape::Block_W1,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment