Commit 9a46c0e7 authored by shengnxu's avatar shengnxu
Browse files

move a scale out inline

parent 26d84960
......@@ -5,7 +5,7 @@
#include "fused_moegemm_api_traits.hpp"
#include "ck_tile/ops/fused_moe.hpp"
#include "fused_moegemm_api.cpp"
// #include "fused_moegemm_api.cpp"
#include <iostream>
template <ck_tile::index_t... Is>
......
......@@ -264,6 +264,7 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
static_assert(ACoords::size() == Block_M * Block_K / BlockSize / 4 /*2x per dword*/); // 8
static_assert(BCoords::size() == Repeat_N);
static_assert(AToken_id::size() == Repeat_M);
static_assert(Ascale::size() == Repeat_M);
auto a_sst = make_tile_window(
make_tensor_view<address_space_enum::lds>(
......@@ -451,30 +452,18 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
[v_token_id0]"+v"(temp0),
[v_token_id1]"+v"(temp1),
[s_mem_]"+r"(smem)
: [s_res_aq0]"s"(res_aq[0]),
[s_res_aq1]"s"(res_aq[1]),
[s_res_aq2]"s"(res_aq[2]),
[s_res_aq3]"s"(res_aq[3]),
[s_res_dq0]"s"(res_dq[0]),
[s_res_dq1]"s"(res_dq[1]),
[s_res_dq2]"s"(res_dq[2]),
[s_res_dq3]"s"(res_dq[3]),
[s_res_gq0]"s"(res_gq[0]),
[s_res_gq1]"s"(res_gq[1]),
[s_res_gq2]"s"(res_gq[2]),
[s_res_gq3]"s"(res_gq[3]),
[s_res_smq0]"s"(res_smq[0]),
[s_res_smq1]"s"(res_smq[1]),
[s_res_smq2]"s"(res_smq[2]),
[s_res_smq3]"s"(res_smq[3]),
[s_res_a0]"s"(res_a[0]),
[s_res_a1]"s"(res_a[1]),
[s_res_a2]"s"(res_a[2]),
[s_res_a3]"s"(res_a[3]),
[s_res_b0]"s"(res_b[0]),
[s_res_b1]"s"(res_b[1]),
[s_res_b2]"s"(res_b[2]),
[s_res_b3]"s"(res_b[3]),
: [s_res_aq]"s"(res_aq),
[s_res_dq]"s"(res_dq),
[s_res_gq]"s"(res_gq),
[s_res_smq]"s"(res_smq),
[s_res_a]"s"(res_a),
// [s_res_a1]"s"(res_a[1]),
// [s_res_a2]"s"(res_a[2]),
// [s_res_a3]"s"(res_a[3]),
[s_res_b]"s"(res_b),
// [s_res_b1]"s"(res_b[1]),
// [s_res_b2]"s"(res_b[2]),
// [s_res_b3]"s"(res_b[3]),
[v_os_a0]"v"(static_cast<index_t>(cached_coords_a[number<0>{}] * sizeof(ADataType))),
[v_os_a1]"v"(static_cast<index_t>(cached_coords_a[number<1>{}] * sizeof(ADataType))),
[v_os_a2]"v"(static_cast<index_t>(cached_coords_a[number<2>{}] * sizeof(ADataType))),
......@@ -539,21 +528,15 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
"a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243",
"a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251",
"a252", "a253", "a254", "a255",
"s6", "s7", "s8", "s9", "s10", "s11", "s12", "s13", "s14", "s15",
"s16", "s17", "s18", "s19", "s20", "s21", "s22", "s23", "s24", "s25",
"s26", "s27", "s28", "s29", "s30", "s31", "s32", "s33", "s34", "s35",
"s36", "s37", "s38", "s39", "s40", "s41", "s42", "s43", "s44", "s45",
"s6", "s7", "s40", "s41", "s42", "s43", "s44", "s45",
"s46", "s47", "s48", "s49", "s50", "s51", "s52", "s53", "s54",
"s55", "s56", "s57", "s58", "s59", "s60", "s61", "s62", "s63",
"s64", "s65", "s66", "s67", "s68", "s69", "s70", "s71", "s72",
"s73", "s74", "s75", "s76", "s77", "s78", "s79", "s80", // s86 as tmp
"v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
"v29", "v30", "v31", "v32", "v33", "v34", "v35", "v36", "v37",
"v38", "v39", "v40", "v41", "v42", "v43", "v44", "v45", "v46",
"v47", "v48", "v49", "v50", "v51", "v52", "v53", "v54", "v55",
"v56", "v57", "v58", "v59", "v60", "v61", "v62", "v63", "v64",
"v20", "v21", "v22", "v23", "v24", "v25", "v50", "v51", "v52", "v53", "v54", "v55",
"v56", "v57", "v64",
"v65", "v66", "v67", "v68", "v69", "v70", "v71", "v72", "v73",
"v74", "v75", "v76", "v77", "v78", "v79", "v80", "v81", "v82",
"v83", "v84", "v85", "v86", "v87", "v88", "v89", "v90", "v91",
......
......@@ -78,21 +78,23 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
// TODO: need paired with tile_window_linear!
// TODO: need call init_raw() before call this function!
// template <typename AWindow, typename BWindow, typename OWindow, typename ScaleTensor>
template <typename BRes,
template <typename DQRes,
typename BRes,
typename BCoords,
typename ORes,
typename OCoords,
typename OFlags,
typename ScaleTensor>
typename OFlags>
// typename ScaleTensor>
CK_TILE_DEVICE auto
operator()(const BRes& res_b,
operator()(const DQRes& res_dq,
const BRes& res_b,
const BCoords& cached_coords_b,
const ORes& res_o,
const OCoords& cached_coords_o,
const OFlags& o_flags, // this should be in sgpr
CK_TILE_LDS_ADDR void* smem,
index_t n, // loop along n dim
const ScaleTensor& scale_,
// const ScaleTensor& scale_,
index_t tile_offset_dq,
index_t tile_offset_b, // stride b is fixed to blockKr * blockW, but still can adjust
index_t tile_offset_half_b, //splited load alone K in to 2 part
......@@ -106,11 +108,11 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
const index_t tile_stride_o_bytes = tile_offset_o * sizeof(ODataType);
const index_t tile_stride_dq_bytes = tile_offset_dq * sizeof(DScaleDataType);
static_assert(ScaleTensor::size() == 2);
float s0 = scale_[number<0>{}];
float s1 = scale_[number<1>{}];
// static_assert(ScaleTensor::size() == 2);
// float s0 = scale_[number<0>{}];
// float s1 = scale_[number<1>{}];
index_t loop_cnt = n / Block_N;
index_t loop_cnt = n ;
// register float v_c0 asm("v64");
// register float v_c1 asm("v65");
......@@ -144,15 +146,15 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
// register float v_c29 asm("v93");
// register float v_c30 asm("v94");
// register float v_c31 asm("v95");
int32_t nan_hi = 0x7fff0000;
int32_t nan_lo = 0x00007fff;
// int32_t nan_hi = 0x7fff0000;
// int32_t nan_lo = 0x00007fff;
// in smem, the layout is M0(2)*K0(128)*M1(16)*K1(4)
// every threads need 8xK in contiguous register
// ... and every wave need the same data
int lane_id = threadIdx.x % 64;
int sld_y_os = (lane_id % 16) * 4 + (lane_id / 16) * 128;
sld_y_os *= 2;
// int lane_id = threadIdx.x % 64;
// int sld_y_os = (lane_id % 16) * 4 + (lane_id / 16) * 128;
// sld_y_os *= 2;
// y y p p p y
// reg before shfl M0(2)*N0(2)*Nl(4)*Nw(4)*Mw(16)*Nv(4)
......@@ -161,15 +163,15 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
// M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16)*Nv(4) + 4)
// y y wave-id lid/16 lid%16 v
// sst(v3) = (v0/16*34 + v0%16 * 2 + wid*136) * 4
int sfl_sst = (threadIdx.x % 16 * 4) + (threadIdx.x / 16) * (64 + 4);
sfl_sst *= 2;
// int sfl_sst = (threadIdx.x % 16 * 4) + (threadIdx.x / 16) * (64 + 4);
// sfl_sst *= 2;
// from LDS we need load as
// M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16) * Nv(4) + 4)
// ( 2 issue) (rem 32-lane) (4 wave*4issue) 2lane*1ussue(pk2)
// sld(v4) = v0/2 *34*4 + v0 % 2 *4 + wid*2 *4
int sfl_sld = (lane_id % 2) * 2 + (lane_id / 2) * (64 + 4) + (threadIdx.x / 64) * 4;
sfl_sld *= 2;
// int sfl_sld = (lane_id % 2) * 2 + (lane_id / 2) * (64 + 4) + (threadIdx.x / 64) * 4;
// sfl_sld *= 2;
// B nr->kr
// clang-format off
......@@ -214,18 +216,19 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
// [c30]"+v"(v_c30),
// [c31]"+v"(v_c31)
:[sld_a_base]"n"(0),
[shfl_base]"n"(0),
[v_sld_y_os]"v"(sld_y_os),
[v_sfl_sld]"v"(sfl_sld),
[v_sfl_sst]"v"(sfl_sst),
// [shfl_base]"n"(0),
// [v_sld_y_os]"v"(sld_y_os),
// [v_sfl_sld]"v"(sfl_sld),
// [v_sfl_sst]"v"(sfl_sst),
[s_res_dq]"s"(res_dq),
[s_res_o0]"s"(res_o[0]),
[s_res_o1]"s"(res_o[1]),
//[s_res_o2]"s"(res_o[2]),
//[s_res_o3]"s"(res_o[3]),
[s_res_b0]"s"(res_b[0]),
[s_res_b1]"s"(res_b[1]),
[s_res_b2]"s"(res_b[2]),
[s_res_b3]"s"(res_b[3]),
[s_res_d]"s"(res_b),
// [s_res_b1]"s"(res_b[1]),
// [s_res_b2]"s"(res_b[2]),
// [s_res_b3]"s"(res_b[3]),
[v_os_o0]"v"(static_cast<index_t>(cached_coords_o[number<0>{}] * sizeof(ODataType))),
[v_os_o1]"v"(static_cast<index_t>(cached_coords_o[number<1>{}] * sizeof(ODataType))),
[v_os_o2]"v"(static_cast<index_t>(cached_coords_o[number<2>{}] * sizeof(ODataType))),
......@@ -242,10 +245,10 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
[s_tile_os_b_half]"s"(tile_offset_half_b_bytes),
[s_tile_os_b]"s"(tile_stride_b_bytes),
[s_tile_os_dq]"s"(tile_stride_dq_bytes),
[scale_0]"v"(s0),
[scale_1]"v"(s1),
[v_nan_lo]"v"(nan_lo),
[v_nan_hi]"v"(nan_hi),
// [scale_0]"v"(s0),
// [scale_1]"v"(s1),
// [v_nan_lo]"v"(nan_lo),
// [v_nan_hi]"v"(nan_hi),
[s_execflag_0]"s"(o_flags[number<0>{}]),
[s_execflag_1]"s"(o_flags[number<1>{}]),
[s_execflag_2]"s"(o_flags[number<2>{}]),
......@@ -285,21 +288,15 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
"a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243",
"a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251",
"a252", "a253", "a254", "a255",
"s6", "s7", "s8", "s9", "s10", "s11", "s12", "s13", "s14", "s15",
"s16", "s17", "s18", "s19", "s20", "s21", "s22", "s23", "s24", "s25",
"s26", "s27", "s28", "s29", "s30", "s31", "s32", "s33", "s34", "s35",
"s36", "s37", "s38", "s39", "s40", "s41", "s42", "s43", "s44", "s45",
"s6", "s7", "s40", "s41", "s42", "s43", "s44", "s45",
"s46", "s47", "s48", "s49", "s50", "s51", "s52", "s53", "s54",
"s55", "s56", "s57", "s58", "s59", "s60", "s61", "s62", "s63",
"s64", "s65", "s66", "s67", "s68", "s69", "s70", "s71", "s72",
"s73", "s74", "s75", "s76", "s77", "s78", "s79", "s80", // s86 as tmp
"v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
"v29", "v30", "v31", "v32", "v33", "v34", "v35", "v36", "v37",
"v38", "v39", "v40", "v41", "v42", "v43", "v44", "v45", "v46",
"v47", "v48", "v49", "v50", "v51", "v52", "v53", "v54", "v55",
"v56", "v57", "v58", "v59", "v60", "v61", "v62", "v63", "v64",
"v20", "v21", "v22", "v23", "v24", "v25", "v50", "v51", "v52", "v53", "v54", "v55",
"v56", "v57", "v64",
"v65", "v66", "v67", "v68", "v69", "v70", "v71", "v72", "v73",
"v74", "v75", "v76", "v77", "v78", "v79", "v80", "v81", "v82",
"v83", "v84", "v85", "v86", "v87", "v88", "v89", "v90", "v91",
......@@ -364,18 +361,19 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
// [c30]"+v"(v_c30),
// [c31]"+v"(v_c31)
:[sld_a_base]"n"(0),
[shfl_base]"n"(0),
[v_sld_y_os]"v"(sld_y_os),
[v_sfl_sld]"v"(sfl_sld),
[v_sfl_sst]"v"(sfl_sst),
// [shfl_base]"n"(0),
// [v_sld_y_os]"v"(sld_y_os),
// [v_sfl_sld]"v"(sfl_sld),
// [v_sfl_sst]"v"(sfl_sst),
[s_res_dq]"s"(res_dq),
[s_res_o0]"s"(res_o[0]),
[s_res_o1]"s"(res_o[1]),
//[s_res_o2]"s"(res_o[2]),
//[s_res_o3]"s"(res_o[3]),
[s_res_b0]"s"(res_b[0]),
[s_res_b1]"s"(res_b[1]),
[s_res_b2]"s"(res_b[2]),
[s_res_b3]"s"(res_b[3]),
[s_res_d]"s"(res_b),
// [s_res_b1]"s"(res_b[1]),
// [s_res_b2]"s"(res_b[2]),
// [s_res_b3]"s"(res_b[3]),
[v_os_o0]"v"(static_cast<index_t>(cached_coords_o[number<0>{}] * sizeof(ODataType))),
[v_os_o1]"v"(static_cast<index_t>(cached_coords_o[number<1>{}] * sizeof(ODataType))),
[v_os_o2]"v"(static_cast<index_t>(cached_coords_o[number<2>{}] * sizeof(ODataType))),
......@@ -392,10 +390,10 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
[s_tile_os_b_half]"s"(tile_offset_half_b_bytes),
[s_tile_os_b]"s"(tile_stride_b_bytes),
[s_tile_os_dq]"s"(tile_stride_dq_bytes),
[scale_0]"v"(s0),
[scale_1]"v"(s1),
[v_nan_lo]"v"(nan_lo),
[v_nan_hi]"v"(nan_hi),
// [scale_0]"v"(s0),
// [scale_1]"v"(s1),
// [v_nan_lo]"v"(nan_lo),
// [v_nan_hi]"v"(nan_hi),
[s_execflag_0]"s"(o_flags[number<0>{}]),
[s_execflag_1]"s"(o_flags[number<1>{}]),
[s_execflag_2]"s"(o_flags[number<2>{}]),
......@@ -435,21 +433,15 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
"a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243",
"a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251",
"a252", "a253", "a254", "a255",
"s6", "s7", "s8", "s9", "s10", "s11", "s12", "s13", "s14", "s15",
"s16", "s17", "s18", "s19", "s20", "s21", "s22", "s23", "s24", "s25",
"s26", "s27", "s28", "s29", "s30", "s31", "s32", "s33", "s34", "s35",
"s36", "s37", "s38", "s39", "s40", "s41", "s42", "s43", "s44", "s45",
"s6", "s7", "s40", "s41", "s42", "s43", "s44", "s45",
"s46", "s47", "s48", "s49", "s50", "s51", "s52", "s53", "s54",
"s55", "s56", "s57", "s58", "s59", "s60", "s61", "s62", "s63",
"s64", "s65", "s66", "s67", "s68", "s69", "s70", "s71", "s72",
"s73", "s74", "s75", "s76", "s77", "s78", "s79", "s80", // s86 as tmp
"v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
"v29", "v30", "v31", "v32", "v33", "v34", "v35", "v36", "v37",
"v38", "v39", "v40", "v41", "v42", "v43", "v44", "v45", "v46",
"v47", "v48", "v49", "v50", "v51", "v52", "v53", "v54", "v55",
"v56", "v57", "v58", "v59", "v60", "v61", "v62", "v63", "v64",
"v20", "v21", "v22", "v23", "v24", "v25", "v50", "v51", "v52", "v53", "v54", "v55",
"v56", "v57", "v64",
"v65", "v66", "v67", "v68", "v69", "v70", "v71", "v72", "v73",
"v74", "v75", "v76", "v77", "v78", "v79", "v80", "v81", "v82",
"v83", "v84", "v85", "v86", "v87", "v88", "v89", "v90", "v91",
......
......@@ -160,7 +160,7 @@
" ds_write_b64 %[v_sfl_sst], [%[c6],%[c7]] offset:23168 \n"
" s_mov_b32 s80, 0 \n"
" s_waitcnt vmcnt(24) \n"
"label_0AA6: \n"
"L_start%=: \n"
" s_waitcnt vmcnt(30) & lgkmcnt(0) \n"
" s_barrier \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[0:1], v[128:129], 0 \n"
......@@ -398,7 +398,7 @@ _UK_PK_CVT_("%[c12]", "%[c13]", "%[c6]")
_UK_PK_CVT_("%[c14]", "%[c15]", "%[c7]")
" s_addk_i32 s80, 0x0080 \n"
" s_cmp_lt_i32 s80, %[s_loop_cnt] \n"
" s_cbranch_scc0 label_0EC1 \n"
" s_cbranch_scc0 L_end%= \n"
" s_waitcnt vmcnt(30) & lgkmcnt(0) \n"
" s_barrier \n"
_UK_MFMA_ " [%[c16], %[c17], %[c18], %[c19]], acc[128:129], v[128:129], 0 \n"
......@@ -636,9 +636,9 @@ _UK_PK_CVT_("%[c28]", "%[c29]", "%[c22]")
_UK_PK_CVT_("%[c30]", "%[c31]", "%[c23]")
" s_addk_i32 s80, 0x0080 \n"
" s_cmp_lt_i32 s80, %[s_loop_cnt] \n"
" s_cbranch_scc0 label_0EC1 \n"
" s_branch label_0AA6 \n"
" label_0EC1: \n"
" s_cbranch_scc0 L_end%= \n"
" s_branch L_start%= \n"
" L_end%=: \n"
" s_waitcnt lgkmcnt(0) \n"
" s_barrier \n"
" ds_read_b32 v10, %[v_sfl_sld] offset:16640 \n"
......
......@@ -27,14 +27,8 @@
# define _UK_ATOMIC_ADD_ "global_atomic_pk_add_f16"
#endif
" s_mov_b32 s8, %[s_res_o0] \n"
" s_mov_b32 s9, %[s_res_o1] \n"
" s_mov_b32 s12, %[s_res_b0] \n"
" s_mov_b32 s13, %[s_res_b1] \n"
" s_mov_b32 s14, %[s_res_b2] \n"
" s_mov_b32 s15, %[s_res_b3] \n"
" s_waitcnt vmcnt(24) \n"
" buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[12:15], 0 offen\n"
" buffer_load_dwordx4 acc[0:3], %[v_os_b0], %[s_res_d], 0 offen\n"
" v_mul_f32 v54, v128, v128 \n"
" v_mul_f32 v55, v129, v129 \n"
" v_mul_f32 v56, v130, v130 \n"
......@@ -55,7 +49,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[12:15], 0 offen offset:1024\n"
" buffer_load_dwordx4 acc[4:7], %[v_os_b0], %[s_res_d], 0 offen offset:1024\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -68,7 +62,7 @@
" v_mul_f32 v129, v129, v55 \n"
" v_mul_f32 v130, v130, v56 \n"
" v_mul_f32 v131, v131, v57 \n"
" buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[12:15], 0 offen offset:2048\n"
" buffer_load_dwordx4 acc[8:11], %[v_os_b0], %[s_res_d], 0 offen offset:2048\n"
" v_mul_f32 v54, v132, v132 \n"
" v_mul_f32 v55, v133, v133 \n"
" v_mul_f32 v56, v134, v134 \n"
......@@ -89,7 +83,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[12:15], 0 offen offset:3072\n"
" buffer_load_dwordx4 acc[12:15], %[v_os_b0], %[s_res_d], 0 offen offset:3072\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -102,7 +96,7 @@
" v_mul_f32 v133, v133, v55 \n"
" v_mul_f32 v134, v134, v56 \n"
" v_mul_f32 v135, v135, v57 \n"
" buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[12:15], 0 offen\n"
" buffer_load_dwordx4 acc[16:19], %[v_os_b1], %[s_res_d], 0 offen\n"
" v_mul_f32 v54, v136, v136 \n"
" v_mul_f32 v55, v137, v137 \n"
" v_mul_f32 v56, v138, v138 \n"
......@@ -123,7 +117,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[12:15], 0 offen offset:1024\n"
" buffer_load_dwordx4 acc[20:23], %[v_os_b1], %[s_res_d], 0 offen offset:1024\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -136,7 +130,7 @@
" v_mul_f32 v137, v137, v55 \n"
" v_mul_f32 v138, v138, v56 \n"
" v_mul_f32 v139, v139, v57 \n"
" buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[12:15], 0 offen offset:2048\n"
" buffer_load_dwordx4 acc[24:27], %[v_os_b1], %[s_res_d], 0 offen offset:2048\n"
" v_mul_f32 v54, v140, v140 \n"
" v_mul_f32 v55, v141, v141 \n"
" v_mul_f32 v56, v142, v142 \n"
......@@ -157,7 +151,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[12:15], 0 offen offset:3072\n"
" buffer_load_dwordx4 acc[28:31], %[v_os_b1], %[s_res_d], 0 offen offset:3072\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -171,7 +165,7 @@
" v_mul_f32 v142, v142, v56 \n"
" v_mul_f32 v143, v143, v57 \n"
" s_waitcnt vmcnt(24) \n"
" buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[12:15], 0 offen\n"
" buffer_load_dwordx4 acc[32:35], %[v_os_b2], %[s_res_d], 0 offen\n"
" v_mul_f32 v54, v144, v144 \n"
" v_mul_f32 v55, v145, v145 \n"
" v_mul_f32 v56, v146, v146 \n"
......@@ -192,7 +186,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[12:15], 0 offen offset:1024\n"
" buffer_load_dwordx4 acc[36:39], %[v_os_b2], %[s_res_d], 0 offen offset:1024\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -205,7 +199,7 @@
" v_mul_f32 v145, v145, v55 \n"
" v_mul_f32 v146, v146, v56 \n"
" v_mul_f32 v147, v147, v57 \n"
" buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[12:15], 0 offen offset:2048\n"
" buffer_load_dwordx4 acc[40:43], %[v_os_b2], %[s_res_d], 0 offen offset:2048\n"
" v_mul_f32 v54, v148, v148 \n"
" v_mul_f32 v55, v149, v149 \n"
" v_mul_f32 v56, v150, v150 \n"
......@@ -226,7 +220,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[12:15], 0 offen offset:3072\n"
" buffer_load_dwordx4 acc[44:47], %[v_os_b2], %[s_res_d], 0 offen offset:3072\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -239,7 +233,7 @@
" v_mul_f32 v149, v149, v55 \n"
" v_mul_f32 v150, v150, v56 \n"
" v_mul_f32 v151, v151, v57 \n"
" buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[12:15], 0 offen\n"
" buffer_load_dwordx4 acc[48:51], %[v_os_b3], %[s_res_d], 0 offen\n"
" v_mul_f32 v54, v152, v152 \n"
" v_mul_f32 v55, v153, v153 \n"
" v_mul_f32 v56, v154, v154 \n"
......@@ -260,7 +254,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[12:15], 0 offen offset:1024\n"
" buffer_load_dwordx4 acc[52:55], %[v_os_b3], %[s_res_d], 0 offen offset:1024\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -273,7 +267,7 @@
" v_mul_f32 v153, v153, v55 \n"
" v_mul_f32 v154, v154, v56 \n"
" v_mul_f32 v155, v155, v57 \n"
" buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[12:15], 0 offen offset:2048\n"
" buffer_load_dwordx4 acc[56:59], %[v_os_b3], %[s_res_d], 0 offen offset:2048\n"
" v_mul_f32 v54, v156, v156 \n"
" v_mul_f32 v55, v157, v157 \n"
" v_mul_f32 v56, v158, v158 \n"
......@@ -294,7 +288,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[12:15], 0 offen offset:3072\n"
" buffer_load_dwordx4 acc[60:63], %[v_os_b3], %[s_res_d], 0 offen offset:3072\n"
" s_add_u32 s12, %[s_tile_os_b_half], s12 \n"
" s_addc_u32 s13, 0, s13 \n"
" v_add_f32 v54, v54, 1.0 \n"
......@@ -310,7 +304,7 @@
" v_mul_f32 v158, v158, v56 \n"
" v_mul_f32 v159, v159, v57 \n"
" s_waitcnt vmcnt(24) \n"
" buffer_load_dwordx4 acc[64:67], %[v_os_b0], s[12:15], 0 offen\n"
" buffer_load_dwordx4 acc[64:67], %[v_os_b0], %[s_res_d], 0 offen\n"
" v_mul_f32 v54, v160, v160 \n"
" v_mul_f32 v55, v161, v161 \n"
" v_mul_f32 v56, v162, v162 \n"
......@@ -331,7 +325,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[68:71], %[v_os_b0], s[12:15], 0 offen offset:1024\n"
" buffer_load_dwordx4 acc[68:71], %[v_os_b0], %[s_res_d], 0 offen offset:1024\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -344,7 +338,7 @@
" v_mul_f32 v161, v161, v55 \n"
" v_mul_f32 v162, v162, v56 \n"
" v_mul_f32 v163, v163, v57 \n"
" buffer_load_dwordx4 acc[72:75], %[v_os_b0], s[12:15], 0 offen offset:2048\n"
" buffer_load_dwordx4 acc[72:75], %[v_os_b0], %[s_res_d], 0 offen offset:2048\n"
" v_mul_f32 v54, v164, v164 \n"
" v_mul_f32 v55, v165, v165 \n"
" v_mul_f32 v56, v166, v166 \n"
......@@ -365,7 +359,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[76:79], %[v_os_b0], s[12:15], 0 offen offset:3072\n"
" buffer_load_dwordx4 acc[76:79], %[v_os_b0], %[s_res_d], 0 offen offset:3072\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -378,7 +372,7 @@
" v_mul_f32 v165, v165, v55 \n"
" v_mul_f32 v166, v166, v56 \n"
" v_mul_f32 v167, v167, v57 \n"
" buffer_load_dwordx4 acc[80:83], %[v_os_b1], s[12:15], 0 offen\n"
" buffer_load_dwordx4 acc[80:83], %[v_os_b1], %[s_res_d], 0 offen\n"
" v_mul_f32 v54, v168, v168 \n"
" v_mul_f32 v55, v169, v169 \n"
" v_mul_f32 v56, v170, v170 \n"
......@@ -399,7 +393,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[84:87], %[v_os_b1], s[12:15], 0 offen offset:1024\n"
" buffer_load_dwordx4 acc[84:87], %[v_os_b1], %[s_res_d], 0 offen offset:1024\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -412,7 +406,7 @@
" v_mul_f32 v169, v169, v55 \n"
" v_mul_f32 v170, v170, v56 \n"
" v_mul_f32 v171, v171, v57 \n"
" buffer_load_dwordx4 acc[88:91], %[v_os_b1], s[12:15], 0 offen offset:2048\n"
" buffer_load_dwordx4 acc[88:91], %[v_os_b1], %[s_res_d], 0 offen offset:2048\n"
" v_mul_f32 v54, v172, v172 \n"
" v_mul_f32 v55, v173, v173 \n"
" v_mul_f32 v56, v174, v174 \n"
......@@ -433,7 +427,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[92:95], %[v_os_b1], s[12:15], 0 offen offset:3072\n"
" buffer_load_dwordx4 acc[92:95], %[v_os_b1], %[s_res_d], 0 offen offset:3072\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -447,7 +441,7 @@
" v_mul_f32 v174, v174, v56 \n"
" v_mul_f32 v175, v175, v57 \n"
" s_waitcnt vmcnt(24) \n"
" buffer_load_dwordx4 acc[96:99], %[v_os_b2], s[12:15], 0 offen\n"
" buffer_load_dwordx4 acc[96:99], %[v_os_b2], %[s_res_d], 0 offen\n"
" v_mul_f32 v54, v176, v176 \n"
" v_mul_f32 v55, v177, v177 \n"
" v_mul_f32 v56, v178, v178 \n"
......@@ -468,7 +462,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[100:103], %[v_os_b2], s[12:15], 0 offen offset:1024\n"
" buffer_load_dwordx4 acc[100:103], %[v_os_b2], %[s_res_d], 0 offen offset:1024\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -481,7 +475,7 @@
" v_mul_f32 v177, v177, v55 \n"
" v_mul_f32 v178, v178, v56 \n"
" v_mul_f32 v179, v179, v57 \n"
" buffer_load_dwordx4 acc[104:107], %[v_os_b2], s[12:15], 0 offen offset:2048\n"
" buffer_load_dwordx4 acc[104:107], %[v_os_b2], %[s_res_d], 0 offen offset:2048\n"
" v_mul_f32 v54, v180, v180 \n"
" v_mul_f32 v55, v181, v181 \n"
" v_mul_f32 v56, v182, v182 \n"
......@@ -502,7 +496,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[108:111], %[v_os_b2], s[12:15], 0 offen offset:3072\n"
" buffer_load_dwordx4 acc[108:111], %[v_os_b2], %[s_res_d], 0 offen offset:3072\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -515,7 +509,7 @@
" v_mul_f32 v181, v181, v55 \n"
" v_mul_f32 v182, v182, v56 \n"
" v_mul_f32 v183, v183, v57 \n"
" buffer_load_dwordx4 acc[112:115], %[v_os_b3], s[12:15], 0 offen\n"
" buffer_load_dwordx4 acc[112:115], %[v_os_b3], %[s_res_d], 0 offen\n"
" v_mul_f32 v54, v184, v184 \n"
" v_mul_f32 v55, v185, v185 \n"
" v_mul_f32 v56, v186, v186 \n"
......@@ -536,7 +530,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[116:119], %[v_os_b3], s[12:15], 0 offen offset:1024\n"
" buffer_load_dwordx4 acc[116:119], %[v_os_b3], %[s_res_d], 0 offen offset:1024\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -549,7 +543,7 @@
" v_mul_f32 v185, v185, v55 \n"
" v_mul_f32 v186, v186, v56 \n"
" v_mul_f32 v187, v187, v57 \n"
" buffer_load_dwordx4 acc[120:123], %[v_os_b3], s[12:15], 0 offen offset:2048\n"
" buffer_load_dwordx4 acc[120:123], %[v_os_b3], %[s_res_d], 0 offen offset:2048\n"
" v_mul_f32 v54, v188, v188 \n"
" v_mul_f32 v55, v189, v189 \n"
" v_mul_f32 v56, v190, v190 \n"
......@@ -570,7 +564,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[124:127], %[v_os_b3], s[12:15], 0 offen offset:3072\n"
" buffer_load_dwordx4 acc[124:127], %[v_os_b3], %[s_res_d], 0 offen offset:3072\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -647,7 +641,7 @@
" v_mul_f32 v189, v19, v189 row_newbcast:13 \n"
" v_mul_f32 v190, v19, v190 row_newbcast:14 \n"
" v_mul_f32 v191, v19, v191 row_newbcast:15 \n"
" buffer_load_dword v12, v5, s[16:19], 0 offen \n"
" buffer_load_dword v12, v5, %[s_res_dq], 0 offen \n"
" v_mov_b32 v22, 0x358637bd \n"
" v_mov_b32 v23, 0x358637bd \n"
" v_max3_f32 v22, abs(v128), abs(v129), v22 \n"
......@@ -945,3 +939,4 @@
#undef _UK_PK_CVT_
#undef _UK_ATOMIC_ADD_
......@@ -102,7 +102,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
return MRepeat;
}
// TODO: properlly support scatter/gather
// TODO: properlly support scatter/gather for load only
CK_TILE_DEVICE auto GetRowCoords_A(index_t base_offset)
{
constexpr index_t KLans = BlockShape::Block_K0 / kAlignmentA;
......@@ -116,6 +116,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
return coords;
}
//for mma and A scale
CK_TILE_DEVICE auto GetRowCoords_A_mma(index_t base_offset)
{
// constexpr index_t KLans = 2;
......@@ -156,6 +157,22 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
return w;
}
template <typename ROW_COORDS>
CK_TILE_DEVICE auto GetAScale(const ROW_IDS row_ids_mma,
const AScaleDataType* a_scale_ptr)
{
constexpr index_t n_size = row_ids_mma.size();
array<TopkWeightDataType, n_size> w;
static_for<0, n_size, 1>{}([&](auto i) {
auto row_id = row_idx_mma[i] & 0xffffff;
auto itp_k = row_idx_mma[i] >> 24;
w.at(i) = sorted_weight_ptr[row_id *kargs.topk+itp_k];
});
return w;
}
// TODO: this row id is before shuffle atomic, need use acc distribution
CK_TILE_DEVICE auto GetRowCoords_O(index_t base_offset)
{
......@@ -203,7 +220,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
BlockShape::Block_Kr1); // intermediate_tile_id * Block_N / (N in W)
auto row_coords_a = GetRowCoords_A(sorted_tile_id * BlockShape::Block_M0);
auto row_coords_a_mma = GetRowCoords_A_mma(sorted_tile_id * BlockShape::Block_M0);
auto row_coords_a_mma = GetRowCoords_O(sorted_tile_id * BlockShape::Block_M0);
auto row_ids_a = GetRowID(
row_coords_a, reinterpret_cast<const IndexDataType*>(kargs.sorted_token_ids_ptr));
auto row_ids_a_mma = GetRowID(
......@@ -221,7 +238,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
//addr in fact
auto a_coords = generate_tuple(
[&](auto i) {
return (row_ids_a[i]) * kargs.stride_token +
return ((row_ids_a[i])&0xffffff) * kargs.stride_token +
threadIdx.x % (BlockShape::Block_K0 / kAlignmentA) * kAlignmentA;
},
number<row_ids_a.size()>{});
......@@ -231,9 +248,11 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
//////aq
auto aq_win = [&]() {
const AScaleDataType* aq_ptr = reinterpret_cast<const AScaleDataType*>(kargs.a_scale_ptr);
auto aq_view_ = make_naive_tensor_view_packed<address_space_enum::global>(
auto aq_view_ = make_naive_tensor_view<address_space_enum::global>(
aq_ptr,
make_tuple(kargs.num_tokens * kargs.topk),
make_tuple(1),
number<1>{},
number<1>{});
return aq_view_;
......@@ -272,9 +291,11 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
static_cast<long_index_t>(expert_id) * g_scale_expert_stride_0 +
intermediate_tile_id * BlockShape::Block_N0;
// const GDataType* g_ptr = reinterpret_cast<const GScaleDataType*>(kargs.g_scale_ptr);//remember to add expert id for inline
auto gq_view_ = make_naive_tensor_view_packed<address_space_enum::global>(
auto gq_view_ = make_naive_tensor_view<address_space_enum::global>(
gq_ptr,
make_tuple(shared_intermediate_size_1),
make_tuple(1),
number<1>{},
number<1>{});
return gq_view_;
......@@ -287,9 +308,11 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
static_cast<long_index_t>(expert_id) * smq_scale_expert_stride_0 +
intermediate_tile_id * BlockShape::Block_N0;
// const GDataType* g_ptr = reinterpret_cast<const GScaleDataType*>(kargs.g_scale_ptr);//remember to add expert id for inline
auto smq_view_ = make_naive_tensor_view_packed<address_space_enum::global>(
auto smq_view_ = make_naive_tensor_view<address_space_enum::global>(
smq_ptr,
make_tuple(shared_intermediate_size_1),
make_tuple(1),
number<1>{},
number<1>{});
return smq_view_;
......@@ -393,12 +416,14 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
auto row_coords_o = GetRowCoords_O(sorted_tile_id * BlockShape::Block_M0);
auto w_scale = GetWeightScale(
row_coords_o, reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr));
auto a_scale = GetAScale(
row_coords_o, reinterpret_cast<const TopkWeightDataType*>(kargs.a_scale_ptr));
auto uk_0 = Policy::template GetUK_0<Problem>();
// auto acc_0= uk_0(
uk_0(
row_ids_a_mma,//fake token id, 2D index for X scale
aq_res,
a_scale,
dq_res,
gq_res,
smq_res,
......@@ -430,7 +455,8 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
// block_sync_lds();
auto uk_1 = Policy::template GetUK_1<Problem>();
uk_1(d_res,
uk_1(dq_res,
d_res,
d_coords,
o_res,
o_coords,
......
......@@ -17,7 +17,7 @@ fi
cmake \
-D CMAKE_PREFIX_PATH=/opt/rocm \
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-D CMAKE_CXX_FLAGS="-Xclang -mllvm -Xclang -enable-post-misched=0 -std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \
-D CMAKE_CXX_FLAGS="-Xclang -mllvm -Xclang -enable-post-misched=0 -std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker -save-temps=$PWD" \
-D CMAKE_BUILD_TYPE=Release \
-D BUILD_DEV=ON \
-D GPU_TARGETS=$GPU_TARGETS \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment