Commit 339a674b authored by shengnxu
Browse files

Current status: single WG (workgroup); out-of-bounds memory access.

parent 5d00b37e
......@@ -19,14 +19,14 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1)
{
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>;
r = fused_moegemm_<t_>(s, a);
// using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>;
// r = fused_moegemm_<t_>(s, a);
}
else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1)
{
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>;
r = fused_moegemm_<t_>(s, a);
// using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>;
// r = fused_moegemm_<t_>(s, a);
}
else if(t.prec_i == "int8" && t.prec_w == "int8" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1)
......
......@@ -7,8 +7,8 @@
#include "fused_moegemm_api_internal.hpp"
// clang-format off
template float fused_moegemm_<
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
// template float fused_moegemm_<
// fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>
// >(const ck_tile::stream_config& s, fused_moegemm_args a);
// clang-format on
......@@ -7,8 +7,8 @@
#include "fused_moegemm_api_internal.hpp"
// clang-format off
template float fused_moegemm_<
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
// template float fused_moegemm_<
// fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>
// >(const ck_tile::stream_config& s, fused_moegemm_args a);
// clang-format on
......@@ -87,11 +87,11 @@ void topid_unique_gen(
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("t", "128", "num input tokens")
.insert("e", "32", "num of experts")
.insert("k", "5", "topk")
.insert("h", "8192", "hidden_size of this model")
.insert("i", "8192", "intermediate_size between 2 gemms of FFN")
arg_parser.insert("t", "32", "num input tokens")
.insert("e", "1", "num of experts")
.insert("k", "1", "topk")
.insert("h", "256", "hidden_size of this model")
.insert("i", "4096", "intermediate_size between 2 gemms of FFN")
.insert("stride", "-1", "stride per row, if -1 then equal to hidden_size")
.insert("bm", "32", "blocking factor for sorted tokens")
.insert("tp", "8", "tensor parallel size")
......
......@@ -242,6 +242,7 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
{
using ADataType = int8_t;
using BDataType = int8_t;
using AScaleDataType = float;
// TODO: needs to be paired with tile_window_linear!
// TODO: need to call init_raw() before calling this function!
......@@ -258,7 +259,8 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
CK_TILE_LDS_ADDR void* smem,
index_t k,
index_t tile_offset_a, // for each tile, the offset to move for each unroll
index_t tile_offset_b) // for each tile, the offset to move for each unroll
index_t tile_offset_b,
index_t a_bound_) // for each tile, the offset to move for each unroll
{
static_assert(ACoords::size() == Block_M * Block_K / BlockSize / 4 /*2x per dword*/); // 8
static_assert(BCoords::size() == Repeat_N);
......@@ -449,7 +451,10 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
[c62]"+v"(v_z62),
[c63]"+v"(v_z63),
[s_mem_]"+r"(smem)
: [a_scale0]"v"(a_scale_[0]),
:
[a_bound]"s"(static_cast<int>(a_bound_ * sizeof(ADataType))),
// [a_scale_bound]"s"(a_scale_bound_ * sizeof(AScaleDataType)),
[a_scale0]"v"(a_scale_[0]),
[a_scale1]"v"(a_scale_[1]),
[gq_scale0]"v"(gq_scale_[0]),
[gq_scale1]"v"(gq_scale_[1]),
......
......@@ -81,6 +81,7 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
template <
// typename DQRes,
// typename BRes,
typename Tokenids,
typename DQCoords,
typename BCoords,
typename ORes,
......@@ -92,6 +93,7 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
operator()(
// const DQRes& res_dq,
// const BRes& res_b,
const Tokenids& token_id_,
const DQCoords& cached_coords_dq,
const BCoords& cached_coords_b,
const ORes& res_o,
......@@ -108,7 +110,6 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
{
static_assert(BCoords::size() == 4); // 8
static_assert(OCoords::size() == 8);
const index_t tile_stride_b_bytes = tile_offset_b * sizeof(BDataType);
const index_t tile_offset_half_b_bytes = tile_offset_half_b * sizeof(BDataType);
const index_t tile_stride_o_bytes = tile_offset_o * sizeof(ODataType);
......@@ -155,8 +156,8 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_INT8
#include "uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_1.inc"
#undef CK_TILE_FLATMM_UK_MFMA
:[smem_]"+r"(smem),
[s_loop_cnt]"+s"(loop_cnt)
:[smem_]"+r"(smem)
// [s_loop_cnt]"+s"(loop_cnt)
:[sld_a_base]"n"(0),
// [shfl_base]"n"(0),
// [v_sld_y_os]"v"(sld_y_os),
......@@ -164,8 +165,8 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
// [v_sfl_sst]"v"(sfl_sst),
[smq_scale0]"s"(smq_scale_[0]),
[smq_scale1]"s"(smq_scale_[1]),
[s_res_o0]"s"(res_o[0]),
[s_res_o1]"s"(res_o[1]),
// [s_res_o0]"s"(res_o[0]),
// [s_res_o1]"s"(res_o[1]),
//[s_res_o2]"s"(res_o[2]),
//[s_res_o3]"s"(res_o[3]),
[v_os_dq]"v"(static_cast<index_t>(cached_coords_dq * sizeof(DScaleDataType))),
......@@ -184,19 +185,19 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
[s_tile_os_o]"s"(tile_stride_o_bytes),
[s_tile_os_b_half]"s"(tile_offset_half_b_bytes),
[s_tile_os_b]"s"(tile_stride_b_bytes),
[s_tile_os_dq]"s"(tile_stride_dq_bytes),
[s_tile_os_dq]"s"(tile_stride_dq_bytes)
// [scale_0]"v"(s0),
// [scale_1]"v"(s1),
// [v_nan_lo]"v"(nan_lo),
// [v_nan_hi]"v"(nan_hi),
[s_execflag_0]"s"(o_flags[number<0>{}]),
[s_execflag_1]"s"(o_flags[number<1>{}]),
[s_execflag_2]"s"(o_flags[number<2>{}]),
[s_execflag_3]"s"(o_flags[number<3>{}]),
[s_execflag_4]"s"(o_flags[number<4>{}]),
[s_execflag_5]"s"(o_flags[number<5>{}]),
[s_execflag_6]"s"(o_flags[number<6>{}]),
[s_execflag_7]"s"(o_flags[number<7>{}])
// [s_execflag_0]"s"(o_flags[number<0>{}]),
// [s_execflag_1]"s"(o_flags[number<1>{}]),
// [s_execflag_2]"s"(o_flags[number<2>{}]),
// [s_execflag_3]"s"(o_flags[number<3>{}]),
// [s_execflag_4]"s"(o_flags[number<4>{}]),
// [s_execflag_5]"s"(o_flags[number<5>{}]),
// [s_execflag_6]"s"(o_flags[number<6>{}]),
// [s_execflag_7]"s"(o_flags[number<7>{}])
:
"memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9",
"a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19",
......@@ -228,7 +229,9 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
"a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243",
"a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251",
"a252", "a253", "a254", "a255",
"s6", "s7", "s40", "s41", "s42", "s43", "s44", "s45",
"s6", "s7","s20", "s21", "s22", "s23", "s24", "s25", "s26", "s27",
"s28", "s29", "s30", "s31", "s34", "s35", "s38", "s39",
"s40", "s41", "s42", "s43", "s44", "s45",
"s46", "s47", "s48", "s49", "s50", "s51", "s52", "s53", "s54",
"s55", "s56", "s57", "s58", "s59", "s60", "s61", "s62", "s63",
"s64", "s65", "s66", "s67", "s68", "s69", "s70", "s71", "s72",
......@@ -260,12 +263,14 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
"v245", "v246", "v247", "v248", "v249", "v250", "v251", "v252",
"v253", "v254", "v255"
);
if(hipBlockIdx_x == 0 && hipBlockIdx_y == 0 && hipBlockIdx_z == 0 &&
hipThreadIdx_x == 5)
{
printf("\n sn0 done\n");
}
// if(blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && threadIdx.x == 0)
// {
// // printf("\n xyz%x,%x,%x,thread idx:%xsn1 done\n",blockIdx.x, blockIdx.y, blockIdx.z ,threadIdx.x );
// printf("\n sn1 done\n");
// }
// return;
asm volatile(
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_INT8
#include "uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_2.inc"
......@@ -288,6 +293,123 @@ if(hipBlockIdx_x == 0 && hipBlockIdx_y == 0 && hipBlockIdx_z == 0 &&
[v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))),
[v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))),
[v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))),
[s_tile_os_o]"s"(tile_stride_o_bytes),
[s_tile_os_b_half]"s"(tile_offset_half_b_bytes),
[s_tile_os_b]"s"(tile_stride_b_bytes),
[s_tile_os_dq]"s"(tile_stride_dq_bytes),
[scale_0]"v"(s0),
[scale_1]"v"(s1)
// [v_nan_lo]"v"(nan_lo),
// [v_nan_hi]"v"(nan_hi),
// [s_execflag_0]"s"(o_flags[number<0>{}]),
// [s_execflag_1]"s"(o_flags[number<1>{}]),
// [s_execflag_2]"s"(o_flags[number<2>{}]),
// [s_execflag_3]"s"(o_flags[number<3>{}]),
// [s_execflag_4]"s"(o_flags[number<4>{}]),
// [s_execflag_5]"s"(o_flags[number<5>{}]),
// [s_execflag_6]"s"(o_flags[number<6>{}]),
// [s_execflag_7]"s"(o_flags[number<7>{}])
:
"memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9",
"a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19",
"a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29",
"a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39",
"a40", "a41", "a42", "a43", "a44", "a45", "a46", "a47", "a48", "a49",
"a50", "a51", "a52", "a53", "a54", "a55", "a56", "a57", "a58", "a59",
"a60", "a61", "a62", "a63", "a64", "a65", "a66", "a67", "a68", "a69",
"a70", "a71", "a72", "a73", "a74", "a75", "a76", "a77", "a78", "a79",
"a80", "a81", "a82", "a83", "a84", "a85", "a86", "a87", "a88", "a89",
"a90", "a91", "a92", "a93", "a94", "a95", "a96", "a97", "a98", "a99",
"a100", "a101", "a102", "a103", "a104", "a105", "a106", "a107",
"a108", "a109", "a110", "a111", "a112", "a113", "a114", "a115",
"a116", "a117", "a118", "a119", "a120", "a121", "a122", "a123",
"a124", "a125", "a126", "a127", "a128", "a129", "a130", "a131",
"a132", "a133", "a134", "a135", "a136", "a137", "a138", "a139",
"a140", "a141", "a142", "a143", "a144", "a145", "a146", "a147",
"a148", "a149", "a150", "a151", "a152", "a153", "a154", "a155",
"a156", "a157", "a158", "a159", "a160", "a161", "a162", "a163",
"a164", "a165", "a166", "a167", "a168", "a169", "a170", "a171",
"a172", "a173", "a174", "a175", "a176", "a177", "a178", "a179",
"a180", "a181", "a182", "a183", "a184", "a185", "a186", "a187",
"a188", "a189", "a190", "a191", "a192", "a193", "a194", "a195",
"a196", "a197", "a198", "a199", "a200", "a201", "a202", "a203",
"a204", "a205", "a206", "a207", "a208", "a209", "a210", "a211",
"a212", "a213", "a214", "a215", "a216", "a217", "a218", "a219",
"a220", "a221", "a222", "a223", "a224", "a225", "a226", "a227",
"a228", "a229", "a230", "a231", "a232", "a233", "a234", "a235",
"a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243",
"a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251",
"a252", "a253", "a254", "a255",
"s6", "s7","s20", "s21", "s22", "s23", "s24", "s25", "s26", "s27",
"s28", "s29", "s30", "s31", "s38", "s39", "s34", "s35", "s40", "s41", "s42", "s43", "s44", "s45",
"s46", "s47", "s48", "s49", "s50", "s51", "s52", "s53", "s54",
"s55", "s56", "s57", "s58", "s59", "s60", "s61", "s62", "s63",
"s64", "s65", "s66", "s67", "s68", "s69", "s70", "s71", "s72",
"s73", "s74", "s75", "s76", "s77", "s78", "s79", "s80", // s86 as tmp
"v1", "v2", "v3", "v4", "v5", "v12", "v13", "v21", "v22", "v23", "v24", "v25", "v50", "v51", "v52", "v53", "v54", "v55",
"v56", "v57", "v64",
"v65", "v66", "v67", "v68", "v69", "v70", "v71", "v72", "v73",
"v74", "v75", "v76", "v77", "v78", "v79", "v80", "v81", "v82",
"v83", "v84", "v85", "v86", "v87", "v88", "v89", "v90", "v91",
"v92", "v93", "v94", "v95", "v96", "v97", "v98", "v99", "v100",
"v101", "v102", "v103", "v104", "v105", "v106", "v107", "v108",
"v109", "v110", "v111", "v112", "v113", "v114", "v115", "v116",
"v117", "v118", "v119", "v120", "v121", "v122", "v123", "v124",
"v125", "v126", "v127", "v128", "v129", "v130", "v131", "v132",
"v133", "v134", "v135", "v136", "v137", "v138", "v139", "v140",
"v141", "v142", "v143", "v144", "v145", "v146", "v147", "v148",
"v149", "v150", "v151", "v152", "v153", "v154", "v155", "v156",
"v157", "v158", "v159", "v160", "v161", "v162", "v163", "v164",
"v165", "v166", "v167", "v168", "v169", "v170", "v171", "v172",
"v173", "v174", "v175", "v176", "v177", "v178", "v179", "v180",
"v181", "v182", "v183", "v184", "v185", "v186", "v187", "v188",
"v189", "v190", "v191", "v192", "v193", "v194", "v195", "v196",
"v197", "v198", "v199", "v200", "v201", "v202", "v203", "v204",
"v205", "v206", "v207", "v208", "v209", "v210", "v211", "v212",
"v213", "v214", "v215", "v216", "v217", "v218", "v219", "v220",
"v221", "v222", "v223", "v224", "v225", "v226", "v227", "v228",
"v229", "v230", "v231", "v232", "v233", "v234", "v235", "v236",
"v237", "v238", "v239", "v240", "v241", "v242", "v243", "v244",
"v245", "v246", "v247", "v248", "v249", "v250", "v251", "v252",
"v253", "v254", "v255"
);
// if(hipBlockIdx_x == 0 && hipBlockIdx_y == 1 && hipBlockIdx_z == 0 &&
// hipThreadIdx_x == 0)
// {
// printf("\n sn2 done\n");
// }
return;
asm volatile(
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_INT8
#include "uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_3.inc"
#undef CK_TILE_FLATMM_UK_MFMA
:[smem_]"+r"(smem),
[s_loop_cnt]"+s"(loop_cnt)
:[sld_a_base]"n"(0),
[s_res_o0]"s"(res_o[0]),
[s_res_o1]"s"(res_o[1]),
[v_os_dq]"v"(static_cast<index_t>(cached_coords_dq * sizeof(DScaleDataType))),
[v_os_o0]"v"(static_cast<index_t>(cached_coords_o[number<0>{}] * sizeof(ODataType))),
[v_os_o1]"v"(static_cast<index_t>(cached_coords_o[number<1>{}] * sizeof(ODataType))),
[v_os_o2]"v"(static_cast<index_t>(cached_coords_o[number<2>{}] * sizeof(ODataType))),
[v_os_o3]"v"(static_cast<index_t>(cached_coords_o[number<3>{}] * sizeof(ODataType))),
[v_os_o4]"v"(static_cast<index_t>(cached_coords_o[number<4>{}] * sizeof(ODataType))),
[v_os_o5]"v"(static_cast<index_t>(cached_coords_o[number<5>{}] * sizeof(ODataType))),
[v_os_o6]"v"(static_cast<index_t>(cached_coords_o[number<6>{}] * sizeof(ODataType))),
[v_os_o7]"v"(static_cast<index_t>(cached_coords_o[number<7>{}] * sizeof(ODataType))),
[v_os_b0]"v"(static_cast<index_t>(cached_coords_b[number<0>{}] * sizeof(BDataType))),
[v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))),
[v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))),
[v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))),
[s_token_id0]"s"(token_id_[number<0>{}]),
[s_token_id1]"s"(token_id_[number<1>{}]),
[s_token_id2]"s"(token_id_[number<2>{}]),
[s_token_id3]"s"(token_id_[number<3>{}]),
[s_token_id4]"s"(token_id_[number<4>{}]),
[s_token_id5]"s"(token_id_[number<5>{}]),
[s_token_id6]"s"(token_id_[number<6>{}]),
[s_token_id7]"s"(token_id_[number<7>{}]),
[s_tile_os_o]"s"(tile_stride_o_bytes),
[s_tile_os_b_half]"s"(tile_offset_half_b_bytes),
[s_tile_os_b]"s"(tile_stride_b_bytes),
......@@ -335,7 +457,8 @@ if(hipBlockIdx_x == 0 && hipBlockIdx_y == 0 && hipBlockIdx_z == 0 &&
"a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243",
"a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251",
"a252", "a253", "a254", "a255",
"s6", "s7", "s40", "s41", "s42", "s43", "s44", "s45",
"s6", "s7","s20", "s21", "s22", "s23", "s24", "s25", "s26", "s27",
"s28", "s29", "s30", "s31", "s38", "s39", "s34", "s35", "s40", "s41", "s42", "s43", "s44", "s45",
"s46", "s47", "s48", "s49", "s50", "s51", "s52", "s53", "s54",
"s55", "s56", "s57", "s58", "s59", "s60", "s61", "s62", "s63",
"s64", "s65", "s66", "s67", "s68", "s69", "s70", "s71", "s72",
......@@ -367,7 +490,7 @@ if(hipBlockIdx_x == 0 && hipBlockIdx_y == 0 && hipBlockIdx_z == 0 &&
"v245", "v246", "v247", "v248", "v249", "v250", "v251", "v252",
"v253", "v254", "v255"
);
#pragma clang diagnostic pop
#pragma clang diagnostic pop
// clang-format on
}
};
......
......@@ -31,8 +31,8 @@
" v_lshrrev_b32 v3, 6, v0 \n"
" v_readfirstlane_b32 s7, v3 \n"
" s_waitcnt vmcnt(24) \n"
" buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[12:15], 0 offen\n"
" buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[12:15], 0 offen offset:1024\n"
"buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[12:15], 0 offen\n"
"buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[12:15], 0 offen offset:1024\n"
" v_mul_f32 v54, v128, v128 \n"
" v_mul_f32 v55, v129, v129 \n"
" v_mul_f32 v56, v130, v130 \n"
......@@ -65,7 +65,7 @@
" v_mul_f32 v129, v129, v55 \n"
" v_mul_f32 v130, v130, v56 \n"
" v_mul_f32 v131, v131, v57 \n"
" buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[12:15], 0 offen offset:2048\n"
"buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[12:15], 0 offen offset:2048\n"
" v_mul_f32 v54, v132, v132 \n"
" v_mul_f32 v55, v133, v133 \n"
" v_mul_f32 v56, v134, v134 \n"
......@@ -86,7 +86,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[12:15], 0 offen offset:3072\n"
"buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[12:15], 0 offen offset:3072\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -99,7 +99,7 @@
" v_mul_f32 v133, v133, v55 \n"
" v_mul_f32 v134, v134, v56 \n"
" v_mul_f32 v135, v135, v57 \n"
" buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[12:15], 0 offen\n"
"buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[12:15], 0 offen\n"
" v_mul_f32 v54, v136, v136 \n"
" v_mul_f32 v55, v137, v137 \n"
" v_mul_f32 v56, v138, v138 \n"
......@@ -120,7 +120,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[12:15], 0 offen offset:1024\n"
"buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[12:15], 0 offen offset:1024\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -133,7 +133,7 @@
" v_mul_f32 v137, v137, v55 \n"
" v_mul_f32 v138, v138, v56 \n"
" v_mul_f32 v139, v139, v57 \n"
" buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[12:15], 0 offen offset:2048\n"
"buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[12:15], 0 offen offset:2048\n"
" v_mul_f32 v54, v140, v140 \n"
" v_mul_f32 v55, v141, v141 \n"
" v_mul_f32 v56, v142, v142 \n"
......@@ -154,7 +154,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[12:15], 0 offen offset:3072\n"
"buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[12:15], 0 offen offset:3072\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -168,7 +168,7 @@
" v_mul_f32 v142, v142, v56 \n"
" v_mul_f32 v143, v143, v57 \n"
" s_waitcnt vmcnt(24) \n"
" buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[12:15], 0 offen\n"
"buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[12:15], 0 offen\n"
" v_mul_f32 v54, v144, v144 \n"
" v_mul_f32 v55, v145, v145 \n"
" v_mul_f32 v56, v146, v146 \n"
......@@ -189,7 +189,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[12:15], 0 offen offset:1024\n"
"buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[12:15], 0 offen offset:1024\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -202,7 +202,7 @@
" v_mul_f32 v145, v145, v55 \n"
" v_mul_f32 v146, v146, v56 \n"
" v_mul_f32 v147, v147, v57 \n"
" buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[12:15], 0 offen offset:2048\n"
"buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[12:15], 0 offen offset:2048\n"
" v_mul_f32 v54, v148, v148 \n"
" v_mul_f32 v55, v149, v149 \n"
" v_mul_f32 v56, v150, v150 \n"
......@@ -223,7 +223,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[12:15], 0 offen offset:3072\n"
"buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[12:15], 0 offen offset:3072\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -236,7 +236,7 @@
" v_mul_f32 v149, v149, v55 \n"
" v_mul_f32 v150, v150, v56 \n"
" v_mul_f32 v151, v151, v57 \n"
" buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[12:15], 0 offen\n"
"buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[12:15], 0 offen\n"
" v_mul_f32 v54, v152, v152 \n"
" v_mul_f32 v55, v153, v153 \n"
" v_mul_f32 v56, v154, v154 \n"
......@@ -257,7 +257,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[12:15], 0 offen offset:1024\n"
"buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[12:15], 0 offen offset:1024\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -270,7 +270,7 @@
" v_mul_f32 v153, v153, v55 \n"
" v_mul_f32 v154, v154, v56 \n"
" v_mul_f32 v155, v155, v57 \n"
" buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[12:15], 0 offen offset:2048\n"
"buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[12:15], 0 offen offset:2048\n"
" v_mul_f32 v54, v156, v156 \n"
" v_mul_f32 v55, v157, v157 \n"
" v_mul_f32 v56, v158, v158 \n"
......@@ -291,7 +291,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[12:15], 0 offen offset:3072\n"
"buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[12:15], 0 offen offset:3072\n"
" s_add_u32 s12, %[s_tile_os_b_half], s12 \n"
" s_addc_u32 s13, 0, s13 \n"
" v_add_f32 v54, v54, 1.0 \n"
......@@ -307,7 +307,7 @@
" v_mul_f32 v158, v158, v56 \n"
" v_mul_f32 v159, v159, v57 \n"
" s_waitcnt vmcnt(24) \n"
" buffer_load_dwordx4 acc[64:67], %[v_os_b0], s[12:15], 0 offen\n"
"buffer_load_dwordx4 acc[64:67], %[v_os_b0], s[12:15], 0 offen\n"
" v_mul_f32 v54, v160, v160 \n"
" v_mul_f32 v55, v161, v161 \n"
" v_mul_f32 v56, v162, v162 \n"
......@@ -328,7 +328,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[68:71], %[v_os_b0], s[12:15], 0 offen offset:1024\n"
"buffer_load_dwordx4 acc[68:71], %[v_os_b0], s[12:15], 0 offen offset:1024\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -341,7 +341,7 @@
" v_mul_f32 v161, v161, v55 \n"
" v_mul_f32 v162, v162, v56 \n"
" v_mul_f32 v163, v163, v57 \n"
" buffer_load_dwordx4 acc[72:75], %[v_os_b0], s[12:15], 0 offen offset:2048\n"
"buffer_load_dwordx4 acc[72:75], %[v_os_b0], s[12:15], 0 offen offset:2048\n"
" v_mul_f32 v54, v164, v164 \n"
" v_mul_f32 v55, v165, v165 \n"
" v_mul_f32 v56, v166, v166 \n"
......@@ -362,7 +362,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[76:79], %[v_os_b0], s[12:15], 0 offen offset:3072\n"
"buffer_load_dwordx4 acc[76:79], %[v_os_b0], s[12:15], 0 offen offset:3072\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -375,7 +375,7 @@
" v_mul_f32 v165, v165, v55 \n"
" v_mul_f32 v166, v166, v56 \n"
" v_mul_f32 v167, v167, v57 \n"
" buffer_load_dwordx4 acc[80:83], %[v_os_b1], s[12:15], 0 offen\n"
"buffer_load_dwordx4 acc[80:83], %[v_os_b1], s[12:15], 0 offen\n"
" v_mul_f32 v54, v168, v168 \n"
" v_mul_f32 v55, v169, v169 \n"
" v_mul_f32 v56, v170, v170 \n"
......@@ -396,7 +396,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[84:87], %[v_os_b1], s[12:15], 0 offen offset:1024\n"
"buffer_load_dwordx4 acc[84:87], %[v_os_b1], s[12:15], 0 offen offset:1024\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -409,7 +409,7 @@
" v_mul_f32 v169, v169, v55 \n"
" v_mul_f32 v170, v170, v56 \n"
" v_mul_f32 v171, v171, v57 \n"
" buffer_load_dwordx4 acc[88:91], %[v_os_b1], s[12:15], 0 offen offset:2048\n"
"buffer_load_dwordx4 acc[88:91], %[v_os_b1], s[12:15], 0 offen offset:2048\n"
" v_mul_f32 v54, v172, v172 \n"
" v_mul_f32 v55, v173, v173 \n"
" v_mul_f32 v56, v174, v174 \n"
......@@ -430,7 +430,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[92:95], %[v_os_b1], s[12:15], 0 offen offset:3072\n"
"buffer_load_dwordx4 acc[92:95], %[v_os_b1], s[12:15], 0 offen offset:3072\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -444,7 +444,7 @@
" v_mul_f32 v174, v174, v56 \n"
" v_mul_f32 v175, v175, v57 \n"
" s_waitcnt vmcnt(24) \n"
" buffer_load_dwordx4 acc[96:99], %[v_os_b2], s[12:15], 0 offen\n"
"buffer_load_dwordx4 acc[96:99], %[v_os_b2], s[12:15], 0 offen\n"
" v_mul_f32 v54, v176, v176 \n"
" v_mul_f32 v55, v177, v177 \n"
" v_mul_f32 v56, v178, v178 \n"
......@@ -465,7 +465,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[100:103], %[v_os_b2], s[12:15], 0 offen offset:1024\n"
"buffer_load_dwordx4 acc[100:103], %[v_os_b2], s[12:15], 0 offen offset:1024\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -478,7 +478,7 @@
" v_mul_f32 v177, v177, v55 \n"
" v_mul_f32 v178, v178, v56 \n"
" v_mul_f32 v179, v179, v57 \n"
" buffer_load_dwordx4 acc[104:107], %[v_os_b2], s[12:15], 0 offen offset:2048\n"
"buffer_load_dwordx4 acc[104:107], %[v_os_b2], s[12:15], 0 offen offset:2048\n"
" v_mul_f32 v54, v180, v180 \n"
" v_mul_f32 v55, v181, v181 \n"
" v_mul_f32 v56, v182, v182 \n"
......@@ -499,7 +499,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[108:111], %[v_os_b2], s[12:15], 0 offen offset:3072\n"
"buffer_load_dwordx4 acc[108:111], %[v_os_b2], s[12:15], 0 offen offset:3072\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -512,7 +512,7 @@
" v_mul_f32 v181, v181, v55 \n"
" v_mul_f32 v182, v182, v56 \n"
" v_mul_f32 v183, v183, v57 \n"
" buffer_load_dwordx4 acc[112:115], %[v_os_b3], s[12:15], 0 offen\n"
"buffer_load_dwordx4 acc[112:115], %[v_os_b3], s[12:15], 0 offen\n"
" v_mul_f32 v54, v184, v184 \n"
" v_mul_f32 v55, v185, v185 \n"
" v_mul_f32 v56, v186, v186 \n"
......@@ -533,7 +533,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[116:119], %[v_os_b3], s[12:15], 0 offen offset:1024\n"
"buffer_load_dwordx4 acc[116:119], %[v_os_b3], s[12:15], 0 offen offset:1024\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -546,7 +546,7 @@
" v_mul_f32 v185, v185, v55 \n"
" v_mul_f32 v186, v186, v56 \n"
" v_mul_f32 v187, v187, v57 \n"
" buffer_load_dwordx4 acc[120:123], %[v_os_b3], s[12:15], 0 offen offset:2048\n"
"buffer_load_dwordx4 acc[120:123], %[v_os_b3], s[12:15], 0 offen offset:2048\n"
" v_mul_f32 v54, v188, v188 \n"
" v_mul_f32 v55, v189, v189 \n"
" v_mul_f32 v56, v190, v190 \n"
......@@ -567,7 +567,7 @@
" v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[124:127], %[v_os_b3], s[12:15], 0 offen offset:3072\n"
"buffer_load_dwordx4 acc[124:127], %[v_os_b3], s[12:15], 0 offen offset:3072\n"
" v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n"
......@@ -644,7 +644,7 @@
" v_mul_f32 v189, %[smq_scale1], v189 row_newbcast:13 \n"
" v_mul_f32 v190, %[smq_scale1], v190 row_newbcast:14 \n"
" v_mul_f32 v191, %[smq_scale1], v191 row_newbcast:15 \n"
" buffer_load_dword v12, %[v_os_dq], s[16:19], 0 offen \n"
";--buffer_load_dword v12, %[v_os_dq], s[16:19], 0 offen \n"
" v_mov_b32 v22, 0x358637bd \n"
" v_mov_b32 v23, 0x358637bd \n"
" v_max3_f32 v22, abs(v128), abs(v129), v22 \n"
......@@ -974,3 +974,5 @@
#undef _UK_PK_CVT_
#undef _UK_ATOMIC_ADD_
......@@ -19,9 +19,9 @@
" v_mul_f32 " a2 ", " gq ", " a2 " row_newbcast: " brd2 " \n" \
" v_mul_f32 " a3 ", " gq ", " a3 " row_newbcast:" brd3 " \n"
" s_mov_b32 s22, %[a_bound] \n"
"s_mov_b32 s20, %[s_res_a0] \n"
"s_mov_b32 s21, %[s_res_a1] \n"
"s_mov_b32 s22, %[s_res_a2] \n"
"s_mov_b32 s23, %[s_res_a3] \n"
"s_mov_b32 s24, %[s_res_b0] \n"
"s_mov_b32 s25, %[s_res_b1] \n"
......@@ -110,38 +110,38 @@
" s_add_u32 s20, s57, s20 \n"
" s_addc_u32 s21, 0, s21 \n"
"; -- prefetch B0\n"
"buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[24:27], 0 offen \n"
"buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[24:27], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[24:27], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[24:27], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[24:27], 0 offen \n"
"buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[24:27], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[24:27], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[24:27], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[24:27], 0 offen \n"
"buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[24:27], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[24:27], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[24:27], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[24:27], 0 offen \n"
"buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[24:27], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[24:27], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[24:27], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[64:67], %[v_os_b4], s[24:27], 0 offen \n"
"buffer_load_dwordx4 acc[68:71], %[v_os_b4], s[24:27], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[72:75], %[v_os_b4], s[24:27], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[76:79], %[v_os_b4], s[24:27], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[80:83], %[v_os_b5], s[24:27], 0 offen \n"
"buffer_load_dwordx4 acc[84:87], %[v_os_b5], s[24:27], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[88:91], %[v_os_b5], s[24:27], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[92:95], %[v_os_b5], s[24:27], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[96:99], %[v_os_b6], s[24:27], 0 offen \n"
"buffer_load_dwordx4 acc[100:103], %[v_os_b6], s[24:27], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[104:107], %[v_os_b6], s[24:27], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[108:111], %[v_os_b6], s[24:27], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[112:115], %[v_os_b7], s[24:27], 0 offen \n"
"buffer_load_dwordx4 acc[116:119], %[v_os_b7], s[24:27], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[24:27], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[24:27], 0 offen offset:3072 \n"
" buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[24:27], 0 offen \n"
" buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[24:27], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[24:27], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[24:27], 0 offen offset:3072 \n"
" buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[24:27], 0 offen \n"
" buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[24:27], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[24:27], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[24:27], 0 offen offset:3072 \n"
" buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[24:27], 0 offen \n"
" buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[24:27], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[24:27], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[24:27], 0 offen offset:3072 \n"
" buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[24:27], 0 offen \n"
" buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[24:27], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[24:27], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[24:27], 0 offen offset:3072 \n"
" buffer_load_dwordx4 acc[64:67], %[v_os_b4], s[24:27], 0 offen \n"
" buffer_load_dwordx4 acc[68:71], %[v_os_b4], s[24:27], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[72:75], %[v_os_b4], s[24:27], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[76:79], %[v_os_b4], s[24:27], 0 offen offset:3072 \n"
" buffer_load_dwordx4 acc[80:83], %[v_os_b5], s[24:27], 0 offen \n"
" buffer_load_dwordx4 acc[84:87], %[v_os_b5], s[24:27], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[88:91], %[v_os_b5], s[24:27], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[92:95], %[v_os_b5], s[24:27], 0 offen offset:3072 \n"
" buffer_load_dwordx4 acc[96:99], %[v_os_b6], s[24:27], 0 offen \n"
" buffer_load_dwordx4 acc[100:103], %[v_os_b6], s[24:27], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[104:107], %[v_os_b6], s[24:27], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[108:111], %[v_os_b6], s[24:27], 0 offen offset:3072 \n"
" buffer_load_dwordx4 acc[112:115], %[v_os_b7], s[24:27], 0 offen \n"
" buffer_load_dwordx4 acc[116:119], %[v_os_b7], s[24:27], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[24:27], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[24:27], 0 offen offset:3072 \n"
"s_add_u32 s24, s58, s24 \n"
"s_addc_u32 s25, 0, s25 \n"
......
......@@ -237,12 +237,23 @@ struct FusedMoeGemmKernel
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
if constexpr(UseUK)
{
__shared__ CK_TILE_LDS_ADDR ADataType smem[GetSmemSize()];
__shared__ CK_TILE_LDS_ADDR ADataType smem[65536];
// index_t s_size = GetSmemSize();
// ADataType{}.aaa();
IndexDataType num_sorted_tiles = __builtin_amdgcn_readfirstlane(
*reinterpret_cast<const IndexDataType*>(kargs.num_sorted_tiles_ptr));
// __builtin_amdgcn_sched_barrier(0);
// if(threadIdx.x == 0){
// printf("num_sorted_tiles %d\n", num_sorted_tiles);
// printf("data type :%s\n", t2s<ADataType>::name);
// printf("\nblockIdx.x :%x, blockIdx.y :%x,\n", blockIdx.x, blockIdx.y);
// __builtin_amdgcn_sched_barrier(0);
// }
// __builtin_amdgcn_sched_barrier(0);
num_sorted_tiles = num_sorted_tiles / BlockShape::Block_M0;
const auto [sorted_tile_id, intermediate_tile_id] =
......
......@@ -74,7 +74,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
constexpr index_t smem_1 = Policy::template GetUK_1<Problem>().GetSmemSize();
constexpr index_t smem_bridge =
BlockShape::Block_M0 * BlockShape::Block_N0 * sizeof(YDataType);
return max(smem_0, max(smem_1, smem_bridge));
return max(smem_0, max(smem_1, smem_bridge));
}
// this is the thread-offset along row/col
......@@ -159,15 +159,22 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
template <typename ROW_IDS>
CK_TILE_DEVICE auto GetAScale(const ROW_IDS row_ids_mma,
const AScaleDataType* a_scale_ptr)
// const AScaleDataType* a_scale_ptr, index_t num_tokens_)
index_t num_tokens_)
{
constexpr index_t n_size = row_ids_mma.size();
array<TopkWeightDataType, n_size> w;
static_for<0, n_size, 1>{}([&](auto i) {
auto row_id = row_ids_mma[i] & 0xffffff;
auto itp_k = row_ids_mma[i] >> 24;
w.at(i) = a_scale_ptr[row_id * 5+itp_k];
if (row_id >= num_tokens_)
{
w.at(i) = 0.f;
} else {
w.at(i) = 1.f;
// auto itp_k = row_ids_mma[i] >> 24;
// w.at(i) = a_scale_ptr[row_id * 5+itp_k];
}
});
return w;
......@@ -247,7 +254,12 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
index_t kr_0 = kargs.hidden_size / BlockShape::Warp_K0; // divide K in W
index_t nr_1 = kargs.hidden_size / BlockShape::Warp_N1;
index_t kr_1 = shared_intermediate_size_1 / BlockShape::Warp_K1;
// if(threadIdx.x == 31 && blockIdx.x == 0 && blockIdx.y == 0)
// {
// printf("\nWarpPerBlock_N0 :%x, WarpPerBlock_M0:%x,\n", BlockShape::WarpPerBlock_N0
// , BlockShape::WarpPerBlock_M0);
// }
const IndexDataType expert_id = __builtin_amdgcn_readfirstlane(
reinterpret_cast<const IndexDataType*>(kargs.sorted_expert_ids_ptr)[sorted_tile_id]);
index_t expert_stride_0 = shared_intermediate_size_0 * kargs.hidden_size;
......@@ -271,7 +283,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
row_coords_a_mma, reinterpret_cast<const IndexDataType*>(kargs.sorted_token_ids_ptr));
auto token_id = generate_tuple(
[&](auto i) {
return (row_ids_a[i]) &0xffffff;
return (row_ids_a[i] &0xffffff);
},
number<row_ids_a.size()>{});
auto a_coords = generate_tuple(
......@@ -385,10 +397,16 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
threadIdx.x % (BlockShape::Block_N1/2 / kAlignmentO) * kAlignmentO;
},
number<row_ids_a.size()>{});
auto o_flags =
generate_tuple([&](auto i) { return cmp_lt_to_exec(token_id[i], kargs.num_tokens); },
number<row_ids_a.size()>{});
// generate_tuple([&](auto i) {
// if (__builtin_amdgcn_readfirstlane(token_id[i]) < kargs.num_tokens)
// {return 0xffffffffffffffff;}
// else
// {return uint32x2_t 0;}
// },
number<token_id.size()>{});
auto o_res =
make_wave_buffer_resource(reinterpret_cast<const ODataType*>(kargs.o_ptr),
......@@ -398,13 +416,59 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
auto w_scale = GetWeightScale(
row_coords_o, reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr));
auto a_scale = GetAScale(
row_ids_a_mma, reinterpret_cast<const AScaleDataType*>(kargs.a_scale_ptr));
// row_ids_a_mma, reinterpret_cast<const AScaleDataType*>(kargs.a_scale_ptr), kargs.num_tokens );
row_ids_a_mma, kargs.num_tokens );
auto gqsmq_coords = GetColCoords_GQSMQ(intermediate_tile_id * BlockShape::Block_K1);
auto dq_coords = gqsmq_coords[0];//only one for this tiling
auto gq_scale = GetGQScale(
gqsmq_coords, (reinterpret_cast<const GScaleDataType*>(kargs.g_scale_ptr) + static_cast<long_index_t>(expert_id) * shared_intermediate_size_0));
auto smq_scale = GetSMQScale(
gqsmq_coords, (reinterpret_cast<const YSmoothScaleDataType*>(kargs.y_smooth_scale_ptr) + static_cast<long_index_t>(expert_id) * shared_intermediate_size_0));
if(threadIdx.x == 95 && blockIdx.x == 0 && blockIdx.y == 0)
{
printf("\nblockIdx.x :%x, blockIdx.y :%x, d ptr: %p, wg d ptr :%x%x,gemm0 done\n", blockIdx.x, blockIdx.y, kargs.d_ptr,d_res[1],d_res[0]);
// // printf("\n wg 1 1, wave 1, row_coords_a 0: %x 1: %x, 2: %x, 3:%x, 5: %x 6: %x, 7: %x, 8:%x,, \n", row_coords_a[number<0>{}],row_coords_a[number<1>{}],row_coords_a[number<2>{}],row_coords_a[number<3>{}], row_coords_a[number<4>{}],row_coords_a[number<5>{}],row_coords_a[number<6>{}],row_coords_a[number<7>{}]);
// // printf("\n -------------- -row_ids_a 0: %x 1: %x, 2: %x, 3:%x, 5: %x 6: %x, 7: %x, 8:%x,, \n", row_ids_a[number<0>{}],row_ids_a[number<1>{}],row_ids_a[number<2>{}],row_ids_a[number<3>{}], row_ids_a[number<4>{}],row_ids_a[number<5>{}],row_ids_a[number<6>{}],row_ids_a[number<7>{}]);
// // printf("\n -----------thread id %x--- - token_id 0: %x 1: %x, 2: %x, 3:%x, 5: %x 6: %x, 7: %x, 8:%x,, \n", hipThreadIdx_x , token_id[number<0>{}],token_id[number<1>{}],token_id[number<2>{}],token_id[number<3>{}], token_id[number<4>{}],token_id[number<5>{}],token_id[number<6>{}],token_id[number<7>{}]);
// // printf("\n -----------thread id %x--- - token_id , 7:%x,, \n", hipThreadIdx_x , token_id[number<7>{}]);
// // printf("\n -------------- - exec 0: %x 1: %x, 2: %x, 3:%x, 5: %x 6: %x, 7: %x, 8:%x,, \n", o_flags[number<0>{}][0],o_flags[number<1>{}][0],o_flags[number<2>{}][0],o_flags[number<3>{}][0], o_flags[number<4>{}][0],o_flags[number<5>{}][0],o_flags[number<6>{}][0],o_flags[number<7>{}][0]);
printf("\ntoken id :%x,%x,%x,%x, %x,%x,%x,%x \n d_coords: %x,%x,%x,%x, \n row_idx: %x,%x,%x,%x, %x,%x,%x,%x \n o_flags:%x,%x,%x,%x, %x,%x,%x,%x \n",
token_id[number<0>{}],
token_id[number<1>{}],
token_id[number<2>{}],
token_id[number<3>{}],
token_id[number<4>{}],
token_id[number<5>{}],
token_id[number<6>{}],
token_id[number<7>{}],
d_coords[number<0>{}],
d_coords[number<1>{}],
d_coords[number<2>{}],
d_coords[number<3>{}],
// d_coords[number<4>{}],
// d_coords[number<5>{}],
// d_coords[number<6>{}],
// d_coords[number<7>{}],
row_ids_a[number<0>{}],
row_ids_a[number<1>{}],
row_ids_a[number<2>{}],
row_ids_a[number<3>{}],
row_ids_a[number<4>{}],
row_ids_a[number<5>{}],
row_ids_a[number<6>{}],
row_ids_a[number<7>{}],
o_flags[number<0>{}][0],
o_flags[number<1>{}][0],
o_flags[number<2>{}][0],
o_flags[number<3>{}][0],
o_flags[number<4>{}][0],
o_flags[number<5>{}][0],
o_flags[number<6>{}][0],
o_flags[number<7>{}][0]);
// return;
}
__builtin_amdgcn_sched_barrier(0);
auto uk_0 = Policy::template GetUK_0<Problem>();
// auto acc_0= uk_0(
uk_0(a_scale,
......@@ -418,17 +482,12 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
smem,
kargs.hidden_size,
BlockShape::Block_K0, // tile offset for B matrix each unroll
BlockShape::Block_Kr0 *
BlockShape::Block_W0); // tile offset for B matrix each unroll
if(hipBlockIdx_x == 1 && hipBlockIdx_y == 1 && hipBlockIdx_z == 0 &&
hipThreadIdx_x == 64)
{
printf("\ngemm0 done\n");
// printf("\n wg 1 1, wave 1, row_coords_a 0: %x 1: %x, 2: %x, 3:%x, 5: %x 6: %x, 7: %x, 8:%x,, \n", row_coords_a[number<0>{}],row_coords_a[number<1>{}],row_coords_a[number<2>{}],row_coords_a[number<3>{}], row_coords_a[number<4>{}],row_coords_a[number<5>{}],row_coords_a[number<6>{}],row_coords_a[number<7>{}]);
// printf("\n -------------- -row_ids_a 0: %x 1: %x, 2: %x, 3:%x, 5: %x 6: %x, 7: %x, 8:%x,, \n", row_ids_a[number<0>{}],row_ids_a[number<1>{}],row_ids_a[number<2>{}],row_ids_a[number<3>{}], row_ids_a[number<4>{}],row_ids_a[number<5>{}],row_ids_a[number<6>{}],row_ids_a[number<7>{}]);
printf("\n -------------- - token_id 0: %x 1: %x, 2: %x, 3:%x, 5: %x 6: %x, 7: %x, 8:%x,, \n", token_id[number<0>{}],token_id[number<1>{}],token_id[number<2>{}],token_id[number<3>{}], token_id[number<4>{}],token_id[number<5>{}],token_id[number<6>{}],token_id[number<7>{}]);
}
// sweep_tile(
16*256,
kargs.num_tokens * kargs.stride_token); // tile offset for B matrix each unroll
// return;
__builtin_amdgcn_sched_barrier(0);
// // sweep_tile(
// acc_0,
// [&](auto idx0, auto idx1) {
// fp32x2_t v_{acc_0(idx0), acc_0(idx1)};
......@@ -449,6 +508,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
uk_1(
// dq_res,
// d_res,
token_id,
dq_coords,
d_coords,
o_res,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment