Commit 339a674b authored by shengnxu's avatar shengnxu
Browse files

current status: single WG, memory out of bound

parent 5d00b37e
...@@ -19,14 +19,14 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile: ...@@ -19,14 +19,14 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" && if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1) t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1)
{ {
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>; // using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>;
r = fused_moegemm_<t_>(s, a); // r = fused_moegemm_<t_>(s, a);
} }
else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" && else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1) t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1)
{ {
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>; // using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>;
r = fused_moegemm_<t_>(s, a); // r = fused_moegemm_<t_>(s, a);
} }
else if(t.prec_i == "int8" && t.prec_w == "int8" && t.prec_o == "bf16" && t.prec_st == "fp32" && else if(t.prec_i == "int8" && t.prec_w == "int8" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1) t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1)
......
...@@ -7,8 +7,8 @@ ...@@ -7,8 +7,8 @@
#include "fused_moegemm_api_internal.hpp" #include "fused_moegemm_api_internal.hpp"
// clang-format off // clang-format off
template float fused_moegemm_< // template float fused_moegemm_<
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0> // fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a); // >(const ck_tile::stream_config& s, fused_moegemm_args a);
// clang-format on // clang-format on
...@@ -7,8 +7,8 @@ ...@@ -7,8 +7,8 @@
#include "fused_moegemm_api_internal.hpp" #include "fused_moegemm_api_internal.hpp"
// clang-format off // clang-format off
template float fused_moegemm_< // template float fused_moegemm_<
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0> // fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a); // >(const ck_tile::stream_config& s, fused_moegemm_args a);
// clang-format on // clang-format on
...@@ -87,11 +87,11 @@ void topid_unique_gen( ...@@ -87,11 +87,11 @@ void topid_unique_gen(
auto create_args(int argc, char* argv[]) auto create_args(int argc, char* argv[])
{ {
ck_tile::ArgParser arg_parser; ck_tile::ArgParser arg_parser;
arg_parser.insert("t", "128", "num input tokens") arg_parser.insert("t", "32", "num input tokens")
.insert("e", "32", "num of experts") .insert("e", "1", "num of experts")
.insert("k", "5", "topk") .insert("k", "1", "topk")
.insert("h", "8192", "hidden_size of this model") .insert("h", "256", "hidden_size of this model")
.insert("i", "8192", "intermediate_size between 2 gemms of FFN") .insert("i", "4096", "intermediate_size between 2 gemms of FFN")
.insert("stride", "-1", "stride per row, if -1 then equal to hidden_size") .insert("stride", "-1", "stride per row, if -1 then equal to hidden_size")
.insert("bm", "32", "blocking factor for sorted tokens") .insert("bm", "32", "blocking factor for sorted tokens")
.insert("tp", "8", "tensor parallel size") .insert("tp", "8", "tensor parallel size")
......
...@@ -242,6 +242,7 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16 ...@@ -242,6 +242,7 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
{ {
using ADataType = int8_t; using ADataType = int8_t;
using BDataType = int8_t; using BDataType = int8_t;
using AScaleDataType = float;
// TODO: need paired with tile_window_linear! // TODO: need paired with tile_window_linear!
// TODO: need call init_raw() before call this function! // TODO: need call init_raw() before call this function!
...@@ -258,7 +259,8 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16 ...@@ -258,7 +259,8 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
CK_TILE_LDS_ADDR void* smem, CK_TILE_LDS_ADDR void* smem,
index_t k, index_t k,
index_t tile_offset_a, // for each tile, the offset to move for each unroll index_t tile_offset_a, // for each tile, the offset to move for each unroll
index_t tile_offset_b) // for each tile, the offset to move for each unroll index_t tile_offset_b,
index_t a_bound_) // for each tile, the offset to move for each unroll
{ {
static_assert(ACoords::size() == Block_M * Block_K / BlockSize / 4 /*2x per dword*/); // 8 static_assert(ACoords::size() == Block_M * Block_K / BlockSize / 4 /*2x per dword*/); // 8
static_assert(BCoords::size() == Repeat_N); static_assert(BCoords::size() == Repeat_N);
...@@ -449,7 +451,10 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16 ...@@ -449,7 +451,10 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
[c62]"+v"(v_z62), [c62]"+v"(v_z62),
[c63]"+v"(v_z63), [c63]"+v"(v_z63),
[s_mem_]"+r"(smem) [s_mem_]"+r"(smem)
: [a_scale0]"v"(a_scale_[0]), :
[a_bound]"s"(static_cast<int>(a_bound_ * sizeof(ADataType))),
// [a_scale_bound]"s"(a_scale_bound_ * sizeof(AScaleDataType)),
[a_scale0]"v"(a_scale_[0]),
[a_scale1]"v"(a_scale_[1]), [a_scale1]"v"(a_scale_[1]),
[gq_scale0]"v"(gq_scale_[0]), [gq_scale0]"v"(gq_scale_[0]),
[gq_scale1]"v"(gq_scale_[1]), [gq_scale1]"v"(gq_scale_[1]),
......
...@@ -81,6 +81,7 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x ...@@ -81,6 +81,7 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
template < template <
// typename DQRes, // typename DQRes,
// typename BRes, // typename BRes,
typename Tokenids,
typename DQCoords, typename DQCoords,
typename BCoords, typename BCoords,
typename ORes, typename ORes,
...@@ -92,6 +93,7 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x ...@@ -92,6 +93,7 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
operator()( operator()(
// const DQRes& res_dq, // const DQRes& res_dq,
// const BRes& res_b, // const BRes& res_b,
const Tokenids& token_id_,
const DQCoords& cached_coords_dq, const DQCoords& cached_coords_dq,
const BCoords& cached_coords_b, const BCoords& cached_coords_b,
const ORes& res_o, const ORes& res_o,
...@@ -108,7 +110,6 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x ...@@ -108,7 +110,6 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
{ {
static_assert(BCoords::size() == 4); // 8 static_assert(BCoords::size() == 4); // 8
static_assert(OCoords::size() == 8); static_assert(OCoords::size() == 8);
const index_t tile_stride_b_bytes = tile_offset_b * sizeof(BDataType); const index_t tile_stride_b_bytes = tile_offset_b * sizeof(BDataType);
const index_t tile_offset_half_b_bytes = tile_offset_half_b * sizeof(BDataType); const index_t tile_offset_half_b_bytes = tile_offset_half_b * sizeof(BDataType);
const index_t tile_stride_o_bytes = tile_offset_o * sizeof(ODataType); const index_t tile_stride_o_bytes = tile_offset_o * sizeof(ODataType);
...@@ -155,8 +156,8 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x ...@@ -155,8 +156,8 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_INT8 #define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_INT8
#include "uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_1.inc" #include "uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_1.inc"
#undef CK_TILE_FLATMM_UK_MFMA #undef CK_TILE_FLATMM_UK_MFMA
:[smem_]"+r"(smem), :[smem_]"+r"(smem)
[s_loop_cnt]"+s"(loop_cnt) // [s_loop_cnt]"+s"(loop_cnt)
:[sld_a_base]"n"(0), :[sld_a_base]"n"(0),
// [shfl_base]"n"(0), // [shfl_base]"n"(0),
// [v_sld_y_os]"v"(sld_y_os), // [v_sld_y_os]"v"(sld_y_os),
...@@ -164,8 +165,8 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x ...@@ -164,8 +165,8 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
// [v_sfl_sst]"v"(sfl_sst), // [v_sfl_sst]"v"(sfl_sst),
[smq_scale0]"s"(smq_scale_[0]), [smq_scale0]"s"(smq_scale_[0]),
[smq_scale1]"s"(smq_scale_[1]), [smq_scale1]"s"(smq_scale_[1]),
[s_res_o0]"s"(res_o[0]), // [s_res_o0]"s"(res_o[0]),
[s_res_o1]"s"(res_o[1]), // [s_res_o1]"s"(res_o[1]),
//[s_res_o2]"s"(res_o[2]), //[s_res_o2]"s"(res_o[2]),
//[s_res_o3]"s"(res_o[3]), //[s_res_o3]"s"(res_o[3]),
[v_os_dq]"v"(static_cast<index_t>(cached_coords_dq * sizeof(DScaleDataType))), [v_os_dq]"v"(static_cast<index_t>(cached_coords_dq * sizeof(DScaleDataType))),
...@@ -184,19 +185,19 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x ...@@ -184,19 +185,19 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
[s_tile_os_o]"s"(tile_stride_o_bytes), [s_tile_os_o]"s"(tile_stride_o_bytes),
[s_tile_os_b_half]"s"(tile_offset_half_b_bytes), [s_tile_os_b_half]"s"(tile_offset_half_b_bytes),
[s_tile_os_b]"s"(tile_stride_b_bytes), [s_tile_os_b]"s"(tile_stride_b_bytes),
[s_tile_os_dq]"s"(tile_stride_dq_bytes), [s_tile_os_dq]"s"(tile_stride_dq_bytes)
// [scale_0]"v"(s0), // [scale_0]"v"(s0),
// [scale_1]"v"(s1), // [scale_1]"v"(s1),
// [v_nan_lo]"v"(nan_lo), // [v_nan_lo]"v"(nan_lo),
// [v_nan_hi]"v"(nan_hi), // [v_nan_hi]"v"(nan_hi),
[s_execflag_0]"s"(o_flags[number<0>{}]), // [s_execflag_0]"s"(o_flags[number<0>{}]),
[s_execflag_1]"s"(o_flags[number<1>{}]), // [s_execflag_1]"s"(o_flags[number<1>{}]),
[s_execflag_2]"s"(o_flags[number<2>{}]), // [s_execflag_2]"s"(o_flags[number<2>{}]),
[s_execflag_3]"s"(o_flags[number<3>{}]), // [s_execflag_3]"s"(o_flags[number<3>{}]),
[s_execflag_4]"s"(o_flags[number<4>{}]), // [s_execflag_4]"s"(o_flags[number<4>{}]),
[s_execflag_5]"s"(o_flags[number<5>{}]), // [s_execflag_5]"s"(o_flags[number<5>{}]),
[s_execflag_6]"s"(o_flags[number<6>{}]), // [s_execflag_6]"s"(o_flags[number<6>{}]),
[s_execflag_7]"s"(o_flags[number<7>{}]) // [s_execflag_7]"s"(o_flags[number<7>{}])
: :
"memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9", "memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9",
"a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19", "a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19",
...@@ -228,7 +229,9 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x ...@@ -228,7 +229,9 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
"a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243", "a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243",
"a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251", "a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251",
"a252", "a253", "a254", "a255", "a252", "a253", "a254", "a255",
"s6", "s7", "s40", "s41", "s42", "s43", "s44", "s45", "s6", "s7","s20", "s21", "s22", "s23", "s24", "s25", "s26", "s27",
"s28", "s29", "s30", "s31", "s34", "s35", "s38", "s39",
"s40", "s41", "s42", "s43", "s44", "s45",
"s46", "s47", "s48", "s49", "s50", "s51", "s52", "s53", "s54", "s46", "s47", "s48", "s49", "s50", "s51", "s52", "s53", "s54",
"s55", "s56", "s57", "s58", "s59", "s60", "s61", "s62", "s63", "s55", "s56", "s57", "s58", "s59", "s60", "s61", "s62", "s63",
"s64", "s65", "s66", "s67", "s68", "s69", "s70", "s71", "s72", "s64", "s65", "s66", "s67", "s68", "s69", "s70", "s71", "s72",
...@@ -260,12 +263,14 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x ...@@ -260,12 +263,14 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
"v245", "v246", "v247", "v248", "v249", "v250", "v251", "v252", "v245", "v246", "v247", "v248", "v249", "v250", "v251", "v252",
"v253", "v254", "v255" "v253", "v254", "v255"
); );
if(hipBlockIdx_x == 0 && hipBlockIdx_y == 0 && hipBlockIdx_z == 0 &&
hipThreadIdx_x == 5)
{
printf("\n sn0 done\n");
} // if(blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && threadIdx.x == 0)
// {
// // printf("\n xyz%x,%x,%x,thread idx:%xsn1 done\n",blockIdx.x, blockIdx.y, blockIdx.z ,threadIdx.x );
// printf("\n sn1 done\n");
// }
// return;
asm volatile( asm volatile(
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_INT8 #define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_INT8
#include "uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_2.inc" #include "uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_2.inc"
...@@ -288,6 +293,123 @@ if(hipBlockIdx_x == 0 && hipBlockIdx_y == 0 && hipBlockIdx_z == 0 && ...@@ -288,6 +293,123 @@ if(hipBlockIdx_x == 0 && hipBlockIdx_y == 0 && hipBlockIdx_z == 0 &&
[v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))), [v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))),
[v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))), [v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))),
[v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))), [v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))),
[s_tile_os_o]"s"(tile_stride_o_bytes),
[s_tile_os_b_half]"s"(tile_offset_half_b_bytes),
[s_tile_os_b]"s"(tile_stride_b_bytes),
[s_tile_os_dq]"s"(tile_stride_dq_bytes),
[scale_0]"v"(s0),
[scale_1]"v"(s1)
// [v_nan_lo]"v"(nan_lo),
// [v_nan_hi]"v"(nan_hi),
// [s_execflag_0]"s"(o_flags[number<0>{}]),
// [s_execflag_1]"s"(o_flags[number<1>{}]),
// [s_execflag_2]"s"(o_flags[number<2>{}]),
// [s_execflag_3]"s"(o_flags[number<3>{}]),
// [s_execflag_4]"s"(o_flags[number<4>{}]),
// [s_execflag_5]"s"(o_flags[number<5>{}]),
// [s_execflag_6]"s"(o_flags[number<6>{}]),
// [s_execflag_7]"s"(o_flags[number<7>{}])
:
"memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9",
"a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19",
"a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29",
"a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39",
"a40", "a41", "a42", "a43", "a44", "a45", "a46", "a47", "a48", "a49",
"a50", "a51", "a52", "a53", "a54", "a55", "a56", "a57", "a58", "a59",
"a60", "a61", "a62", "a63", "a64", "a65", "a66", "a67", "a68", "a69",
"a70", "a71", "a72", "a73", "a74", "a75", "a76", "a77", "a78", "a79",
"a80", "a81", "a82", "a83", "a84", "a85", "a86", "a87", "a88", "a89",
"a90", "a91", "a92", "a93", "a94", "a95", "a96", "a97", "a98", "a99",
"a100", "a101", "a102", "a103", "a104", "a105", "a106", "a107",
"a108", "a109", "a110", "a111", "a112", "a113", "a114", "a115",
"a116", "a117", "a118", "a119", "a120", "a121", "a122", "a123",
"a124", "a125", "a126", "a127", "a128", "a129", "a130", "a131",
"a132", "a133", "a134", "a135", "a136", "a137", "a138", "a139",
"a140", "a141", "a142", "a143", "a144", "a145", "a146", "a147",
"a148", "a149", "a150", "a151", "a152", "a153", "a154", "a155",
"a156", "a157", "a158", "a159", "a160", "a161", "a162", "a163",
"a164", "a165", "a166", "a167", "a168", "a169", "a170", "a171",
"a172", "a173", "a174", "a175", "a176", "a177", "a178", "a179",
"a180", "a181", "a182", "a183", "a184", "a185", "a186", "a187",
"a188", "a189", "a190", "a191", "a192", "a193", "a194", "a195",
"a196", "a197", "a198", "a199", "a200", "a201", "a202", "a203",
"a204", "a205", "a206", "a207", "a208", "a209", "a210", "a211",
"a212", "a213", "a214", "a215", "a216", "a217", "a218", "a219",
"a220", "a221", "a222", "a223", "a224", "a225", "a226", "a227",
"a228", "a229", "a230", "a231", "a232", "a233", "a234", "a235",
"a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243",
"a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251",
"a252", "a253", "a254", "a255",
"s6", "s7","s20", "s21", "s22", "s23", "s24", "s25", "s26", "s27",
"s28", "s29", "s30", "s31", "s38", "s39", "s34", "s35", "s40", "s41", "s42", "s43", "s44", "s45",
"s46", "s47", "s48", "s49", "s50", "s51", "s52", "s53", "s54",
"s55", "s56", "s57", "s58", "s59", "s60", "s61", "s62", "s63",
"s64", "s65", "s66", "s67", "s68", "s69", "s70", "s71", "s72",
"s73", "s74", "s75", "s76", "s77", "s78", "s79", "s80", // s86 as tmp
"v1", "v2", "v3", "v4", "v5", "v12", "v13", "v21", "v22", "v23", "v24", "v25", "v50", "v51", "v52", "v53", "v54", "v55",
"v56", "v57", "v64",
"v65", "v66", "v67", "v68", "v69", "v70", "v71", "v72", "v73",
"v74", "v75", "v76", "v77", "v78", "v79", "v80", "v81", "v82",
"v83", "v84", "v85", "v86", "v87", "v88", "v89", "v90", "v91",
"v92", "v93", "v94", "v95", "v96", "v97", "v98", "v99", "v100",
"v101", "v102", "v103", "v104", "v105", "v106", "v107", "v108",
"v109", "v110", "v111", "v112", "v113", "v114", "v115", "v116",
"v117", "v118", "v119", "v120", "v121", "v122", "v123", "v124",
"v125", "v126", "v127", "v128", "v129", "v130", "v131", "v132",
"v133", "v134", "v135", "v136", "v137", "v138", "v139", "v140",
"v141", "v142", "v143", "v144", "v145", "v146", "v147", "v148",
"v149", "v150", "v151", "v152", "v153", "v154", "v155", "v156",
"v157", "v158", "v159", "v160", "v161", "v162", "v163", "v164",
"v165", "v166", "v167", "v168", "v169", "v170", "v171", "v172",
"v173", "v174", "v175", "v176", "v177", "v178", "v179", "v180",
"v181", "v182", "v183", "v184", "v185", "v186", "v187", "v188",
"v189", "v190", "v191", "v192", "v193", "v194", "v195", "v196",
"v197", "v198", "v199", "v200", "v201", "v202", "v203", "v204",
"v205", "v206", "v207", "v208", "v209", "v210", "v211", "v212",
"v213", "v214", "v215", "v216", "v217", "v218", "v219", "v220",
"v221", "v222", "v223", "v224", "v225", "v226", "v227", "v228",
"v229", "v230", "v231", "v232", "v233", "v234", "v235", "v236",
"v237", "v238", "v239", "v240", "v241", "v242", "v243", "v244",
"v245", "v246", "v247", "v248", "v249", "v250", "v251", "v252",
"v253", "v254", "v255"
);
// if(hipBlockIdx_x == 0 && hipBlockIdx_y == 1 && hipBlockIdx_z == 0 &&
// hipThreadIdx_x == 0)
// {
// printf("\n sn2 done\n");
// }
return;
asm volatile(
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_INT8
#include "uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_3.inc"
#undef CK_TILE_FLATMM_UK_MFMA
:[smem_]"+r"(smem),
[s_loop_cnt]"+s"(loop_cnt)
:[sld_a_base]"n"(0),
[s_res_o0]"s"(res_o[0]),
[s_res_o1]"s"(res_o[1]),
[v_os_dq]"v"(static_cast<index_t>(cached_coords_dq * sizeof(DScaleDataType))),
[v_os_o0]"v"(static_cast<index_t>(cached_coords_o[number<0>{}] * sizeof(ODataType))),
[v_os_o1]"v"(static_cast<index_t>(cached_coords_o[number<1>{}] * sizeof(ODataType))),
[v_os_o2]"v"(static_cast<index_t>(cached_coords_o[number<2>{}] * sizeof(ODataType))),
[v_os_o3]"v"(static_cast<index_t>(cached_coords_o[number<3>{}] * sizeof(ODataType))),
[v_os_o4]"v"(static_cast<index_t>(cached_coords_o[number<4>{}] * sizeof(ODataType))),
[v_os_o5]"v"(static_cast<index_t>(cached_coords_o[number<5>{}] * sizeof(ODataType))),
[v_os_o6]"v"(static_cast<index_t>(cached_coords_o[number<6>{}] * sizeof(ODataType))),
[v_os_o7]"v"(static_cast<index_t>(cached_coords_o[number<7>{}] * sizeof(ODataType))),
[v_os_b0]"v"(static_cast<index_t>(cached_coords_b[number<0>{}] * sizeof(BDataType))),
[v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))),
[v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))),
[v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))),
[s_token_id0]"s"(token_id_[number<0>{}]),
[s_token_id1]"s"(token_id_[number<1>{}]),
[s_token_id2]"s"(token_id_[number<2>{}]),
[s_token_id3]"s"(token_id_[number<3>{}]),
[s_token_id4]"s"(token_id_[number<4>{}]),
[s_token_id5]"s"(token_id_[number<5>{}]),
[s_token_id6]"s"(token_id_[number<6>{}]),
[s_token_id7]"s"(token_id_[number<7>{}]),
[s_tile_os_o]"s"(tile_stride_o_bytes), [s_tile_os_o]"s"(tile_stride_o_bytes),
[s_tile_os_b_half]"s"(tile_offset_half_b_bytes), [s_tile_os_b_half]"s"(tile_offset_half_b_bytes),
[s_tile_os_b]"s"(tile_stride_b_bytes), [s_tile_os_b]"s"(tile_stride_b_bytes),
...@@ -335,7 +457,8 @@ if(hipBlockIdx_x == 0 && hipBlockIdx_y == 0 && hipBlockIdx_z == 0 && ...@@ -335,7 +457,8 @@ if(hipBlockIdx_x == 0 && hipBlockIdx_y == 0 && hipBlockIdx_z == 0 &&
"a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243", "a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243",
"a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251", "a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251",
"a252", "a253", "a254", "a255", "a252", "a253", "a254", "a255",
"s6", "s7", "s40", "s41", "s42", "s43", "s44", "s45", "s6", "s7","s20", "s21", "s22", "s23", "s24", "s25", "s26", "s27",
"s28", "s29", "s30", "s31", "s38", "s39", "s34", "s35", "s40", "s41", "s42", "s43", "s44", "s45",
"s46", "s47", "s48", "s49", "s50", "s51", "s52", "s53", "s54", "s46", "s47", "s48", "s49", "s50", "s51", "s52", "s53", "s54",
"s55", "s56", "s57", "s58", "s59", "s60", "s61", "s62", "s63", "s55", "s56", "s57", "s58", "s59", "s60", "s61", "s62", "s63",
"s64", "s65", "s66", "s67", "s68", "s69", "s70", "s71", "s72", "s64", "s65", "s66", "s67", "s68", "s69", "s70", "s71", "s72",
...@@ -367,7 +490,7 @@ if(hipBlockIdx_x == 0 && hipBlockIdx_y == 0 && hipBlockIdx_z == 0 && ...@@ -367,7 +490,7 @@ if(hipBlockIdx_x == 0 && hipBlockIdx_y == 0 && hipBlockIdx_z == 0 &&
"v245", "v246", "v247", "v248", "v249", "v250", "v251", "v252", "v245", "v246", "v247", "v248", "v249", "v250", "v251", "v252",
"v253", "v254", "v255" "v253", "v254", "v255"
); );
#pragma clang diagnostic pop #pragma clang diagnostic pop
// clang-format on // clang-format on
} }
}; };
......
...@@ -31,8 +31,8 @@ ...@@ -31,8 +31,8 @@
" v_lshrrev_b32 v3, 6, v0 \n" " v_lshrrev_b32 v3, 6, v0 \n"
" v_readfirstlane_b32 s7, v3 \n" " v_readfirstlane_b32 s7, v3 \n"
" s_waitcnt vmcnt(24) \n" " s_waitcnt vmcnt(24) \n"
" buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[12:15], 0 offen\n" "buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[12:15], 0 offen\n"
" buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[12:15], 0 offen offset:1024\n" "buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[12:15], 0 offen offset:1024\n"
" v_mul_f32 v54, v128, v128 \n" " v_mul_f32 v54, v128, v128 \n"
" v_mul_f32 v55, v129, v129 \n" " v_mul_f32 v55, v129, v129 \n"
" v_mul_f32 v56, v130, v130 \n" " v_mul_f32 v56, v130, v130 \n"
...@@ -65,7 +65,7 @@ ...@@ -65,7 +65,7 @@
" v_mul_f32 v129, v129, v55 \n" " v_mul_f32 v129, v129, v55 \n"
" v_mul_f32 v130, v130, v56 \n" " v_mul_f32 v130, v130, v56 \n"
" v_mul_f32 v131, v131, v57 \n" " v_mul_f32 v131, v131, v57 \n"
" buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[12:15], 0 offen offset:2048\n" "buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[12:15], 0 offen offset:2048\n"
" v_mul_f32 v54, v132, v132 \n" " v_mul_f32 v54, v132, v132 \n"
" v_mul_f32 v55, v133, v133 \n" " v_mul_f32 v55, v133, v133 \n"
" v_mul_f32 v56, v134, v134 \n" " v_mul_f32 v56, v134, v134 \n"
...@@ -86,7 +86,7 @@ ...@@ -86,7 +86,7 @@
" v_exp_f32 v55, v55 \n" " v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n" " v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n" " v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[12:15], 0 offen offset:3072\n" "buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[12:15], 0 offen offset:3072\n"
" v_add_f32 v54, v54, 1.0 \n" " v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n" " v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n" " v_add_f32 v56, v56, 1.0 \n"
...@@ -99,7 +99,7 @@ ...@@ -99,7 +99,7 @@
" v_mul_f32 v133, v133, v55 \n" " v_mul_f32 v133, v133, v55 \n"
" v_mul_f32 v134, v134, v56 \n" " v_mul_f32 v134, v134, v56 \n"
" v_mul_f32 v135, v135, v57 \n" " v_mul_f32 v135, v135, v57 \n"
" buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[12:15], 0 offen\n" "buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[12:15], 0 offen\n"
" v_mul_f32 v54, v136, v136 \n" " v_mul_f32 v54, v136, v136 \n"
" v_mul_f32 v55, v137, v137 \n" " v_mul_f32 v55, v137, v137 \n"
" v_mul_f32 v56, v138, v138 \n" " v_mul_f32 v56, v138, v138 \n"
...@@ -120,7 +120,7 @@ ...@@ -120,7 +120,7 @@
" v_exp_f32 v55, v55 \n" " v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n" " v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n" " v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[12:15], 0 offen offset:1024\n" "buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[12:15], 0 offen offset:1024\n"
" v_add_f32 v54, v54, 1.0 \n" " v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n" " v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n" " v_add_f32 v56, v56, 1.0 \n"
...@@ -133,7 +133,7 @@ ...@@ -133,7 +133,7 @@
" v_mul_f32 v137, v137, v55 \n" " v_mul_f32 v137, v137, v55 \n"
" v_mul_f32 v138, v138, v56 \n" " v_mul_f32 v138, v138, v56 \n"
" v_mul_f32 v139, v139, v57 \n" " v_mul_f32 v139, v139, v57 \n"
" buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[12:15], 0 offen offset:2048\n" "buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[12:15], 0 offen offset:2048\n"
" v_mul_f32 v54, v140, v140 \n" " v_mul_f32 v54, v140, v140 \n"
" v_mul_f32 v55, v141, v141 \n" " v_mul_f32 v55, v141, v141 \n"
" v_mul_f32 v56, v142, v142 \n" " v_mul_f32 v56, v142, v142 \n"
...@@ -154,7 +154,7 @@ ...@@ -154,7 +154,7 @@
" v_exp_f32 v55, v55 \n" " v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n" " v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n" " v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[12:15], 0 offen offset:3072\n" "buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[12:15], 0 offen offset:3072\n"
" v_add_f32 v54, v54, 1.0 \n" " v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n" " v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n" " v_add_f32 v56, v56, 1.0 \n"
...@@ -168,7 +168,7 @@ ...@@ -168,7 +168,7 @@
" v_mul_f32 v142, v142, v56 \n" " v_mul_f32 v142, v142, v56 \n"
" v_mul_f32 v143, v143, v57 \n" " v_mul_f32 v143, v143, v57 \n"
" s_waitcnt vmcnt(24) \n" " s_waitcnt vmcnt(24) \n"
" buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[12:15], 0 offen\n" "buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[12:15], 0 offen\n"
" v_mul_f32 v54, v144, v144 \n" " v_mul_f32 v54, v144, v144 \n"
" v_mul_f32 v55, v145, v145 \n" " v_mul_f32 v55, v145, v145 \n"
" v_mul_f32 v56, v146, v146 \n" " v_mul_f32 v56, v146, v146 \n"
...@@ -189,7 +189,7 @@ ...@@ -189,7 +189,7 @@
" v_exp_f32 v55, v55 \n" " v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n" " v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n" " v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[12:15], 0 offen offset:1024\n" "buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[12:15], 0 offen offset:1024\n"
" v_add_f32 v54, v54, 1.0 \n" " v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n" " v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n" " v_add_f32 v56, v56, 1.0 \n"
...@@ -202,7 +202,7 @@ ...@@ -202,7 +202,7 @@
" v_mul_f32 v145, v145, v55 \n" " v_mul_f32 v145, v145, v55 \n"
" v_mul_f32 v146, v146, v56 \n" " v_mul_f32 v146, v146, v56 \n"
" v_mul_f32 v147, v147, v57 \n" " v_mul_f32 v147, v147, v57 \n"
" buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[12:15], 0 offen offset:2048\n" "buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[12:15], 0 offen offset:2048\n"
" v_mul_f32 v54, v148, v148 \n" " v_mul_f32 v54, v148, v148 \n"
" v_mul_f32 v55, v149, v149 \n" " v_mul_f32 v55, v149, v149 \n"
" v_mul_f32 v56, v150, v150 \n" " v_mul_f32 v56, v150, v150 \n"
...@@ -223,7 +223,7 @@ ...@@ -223,7 +223,7 @@
" v_exp_f32 v55, v55 \n" " v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n" " v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n" " v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[12:15], 0 offen offset:3072\n" "buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[12:15], 0 offen offset:3072\n"
" v_add_f32 v54, v54, 1.0 \n" " v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n" " v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n" " v_add_f32 v56, v56, 1.0 \n"
...@@ -236,7 +236,7 @@ ...@@ -236,7 +236,7 @@
" v_mul_f32 v149, v149, v55 \n" " v_mul_f32 v149, v149, v55 \n"
" v_mul_f32 v150, v150, v56 \n" " v_mul_f32 v150, v150, v56 \n"
" v_mul_f32 v151, v151, v57 \n" " v_mul_f32 v151, v151, v57 \n"
" buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[12:15], 0 offen\n" "buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[12:15], 0 offen\n"
" v_mul_f32 v54, v152, v152 \n" " v_mul_f32 v54, v152, v152 \n"
" v_mul_f32 v55, v153, v153 \n" " v_mul_f32 v55, v153, v153 \n"
" v_mul_f32 v56, v154, v154 \n" " v_mul_f32 v56, v154, v154 \n"
...@@ -257,7 +257,7 @@ ...@@ -257,7 +257,7 @@
" v_exp_f32 v55, v55 \n" " v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n" " v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n" " v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[12:15], 0 offen offset:1024\n" "buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[12:15], 0 offen offset:1024\n"
" v_add_f32 v54, v54, 1.0 \n" " v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n" " v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n" " v_add_f32 v56, v56, 1.0 \n"
...@@ -270,7 +270,7 @@ ...@@ -270,7 +270,7 @@
" v_mul_f32 v153, v153, v55 \n" " v_mul_f32 v153, v153, v55 \n"
" v_mul_f32 v154, v154, v56 \n" " v_mul_f32 v154, v154, v56 \n"
" v_mul_f32 v155, v155, v57 \n" " v_mul_f32 v155, v155, v57 \n"
" buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[12:15], 0 offen offset:2048\n" "buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[12:15], 0 offen offset:2048\n"
" v_mul_f32 v54, v156, v156 \n" " v_mul_f32 v54, v156, v156 \n"
" v_mul_f32 v55, v157, v157 \n" " v_mul_f32 v55, v157, v157 \n"
" v_mul_f32 v56, v158, v158 \n" " v_mul_f32 v56, v158, v158 \n"
...@@ -291,7 +291,7 @@ ...@@ -291,7 +291,7 @@
" v_exp_f32 v55, v55 \n" " v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n" " v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n" " v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[12:15], 0 offen offset:3072\n" "buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[12:15], 0 offen offset:3072\n"
" s_add_u32 s12, %[s_tile_os_b_half], s12 \n" " s_add_u32 s12, %[s_tile_os_b_half], s12 \n"
" s_addc_u32 s13, 0, s13 \n" " s_addc_u32 s13, 0, s13 \n"
" v_add_f32 v54, v54, 1.0 \n" " v_add_f32 v54, v54, 1.0 \n"
...@@ -307,7 +307,7 @@ ...@@ -307,7 +307,7 @@
" v_mul_f32 v158, v158, v56 \n" " v_mul_f32 v158, v158, v56 \n"
" v_mul_f32 v159, v159, v57 \n" " v_mul_f32 v159, v159, v57 \n"
" s_waitcnt vmcnt(24) \n" " s_waitcnt vmcnt(24) \n"
" buffer_load_dwordx4 acc[64:67], %[v_os_b0], s[12:15], 0 offen\n" "buffer_load_dwordx4 acc[64:67], %[v_os_b0], s[12:15], 0 offen\n"
" v_mul_f32 v54, v160, v160 \n" " v_mul_f32 v54, v160, v160 \n"
" v_mul_f32 v55, v161, v161 \n" " v_mul_f32 v55, v161, v161 \n"
" v_mul_f32 v56, v162, v162 \n" " v_mul_f32 v56, v162, v162 \n"
...@@ -328,7 +328,7 @@ ...@@ -328,7 +328,7 @@
" v_exp_f32 v55, v55 \n" " v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n" " v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n" " v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[68:71], %[v_os_b0], s[12:15], 0 offen offset:1024\n" "buffer_load_dwordx4 acc[68:71], %[v_os_b0], s[12:15], 0 offen offset:1024\n"
" v_add_f32 v54, v54, 1.0 \n" " v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n" " v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n" " v_add_f32 v56, v56, 1.0 \n"
...@@ -341,7 +341,7 @@ ...@@ -341,7 +341,7 @@
" v_mul_f32 v161, v161, v55 \n" " v_mul_f32 v161, v161, v55 \n"
" v_mul_f32 v162, v162, v56 \n" " v_mul_f32 v162, v162, v56 \n"
" v_mul_f32 v163, v163, v57 \n" " v_mul_f32 v163, v163, v57 \n"
" buffer_load_dwordx4 acc[72:75], %[v_os_b0], s[12:15], 0 offen offset:2048\n" "buffer_load_dwordx4 acc[72:75], %[v_os_b0], s[12:15], 0 offen offset:2048\n"
" v_mul_f32 v54, v164, v164 \n" " v_mul_f32 v54, v164, v164 \n"
" v_mul_f32 v55, v165, v165 \n" " v_mul_f32 v55, v165, v165 \n"
" v_mul_f32 v56, v166, v166 \n" " v_mul_f32 v56, v166, v166 \n"
...@@ -362,7 +362,7 @@ ...@@ -362,7 +362,7 @@
" v_exp_f32 v55, v55 \n" " v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n" " v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n" " v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[76:79], %[v_os_b0], s[12:15], 0 offen offset:3072\n" "buffer_load_dwordx4 acc[76:79], %[v_os_b0], s[12:15], 0 offen offset:3072\n"
" v_add_f32 v54, v54, 1.0 \n" " v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n" " v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n" " v_add_f32 v56, v56, 1.0 \n"
...@@ -375,7 +375,7 @@ ...@@ -375,7 +375,7 @@
" v_mul_f32 v165, v165, v55 \n" " v_mul_f32 v165, v165, v55 \n"
" v_mul_f32 v166, v166, v56 \n" " v_mul_f32 v166, v166, v56 \n"
" v_mul_f32 v167, v167, v57 \n" " v_mul_f32 v167, v167, v57 \n"
" buffer_load_dwordx4 acc[80:83], %[v_os_b1], s[12:15], 0 offen\n" "buffer_load_dwordx4 acc[80:83], %[v_os_b1], s[12:15], 0 offen\n"
" v_mul_f32 v54, v168, v168 \n" " v_mul_f32 v54, v168, v168 \n"
" v_mul_f32 v55, v169, v169 \n" " v_mul_f32 v55, v169, v169 \n"
" v_mul_f32 v56, v170, v170 \n" " v_mul_f32 v56, v170, v170 \n"
...@@ -396,7 +396,7 @@ ...@@ -396,7 +396,7 @@
" v_exp_f32 v55, v55 \n" " v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n" " v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n" " v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[84:87], %[v_os_b1], s[12:15], 0 offen offset:1024\n" "buffer_load_dwordx4 acc[84:87], %[v_os_b1], s[12:15], 0 offen offset:1024\n"
" v_add_f32 v54, v54, 1.0 \n" " v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n" " v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n" " v_add_f32 v56, v56, 1.0 \n"
...@@ -409,7 +409,7 @@ ...@@ -409,7 +409,7 @@
" v_mul_f32 v169, v169, v55 \n" " v_mul_f32 v169, v169, v55 \n"
" v_mul_f32 v170, v170, v56 \n" " v_mul_f32 v170, v170, v56 \n"
" v_mul_f32 v171, v171, v57 \n" " v_mul_f32 v171, v171, v57 \n"
" buffer_load_dwordx4 acc[88:91], %[v_os_b1], s[12:15], 0 offen offset:2048\n" "buffer_load_dwordx4 acc[88:91], %[v_os_b1], s[12:15], 0 offen offset:2048\n"
" v_mul_f32 v54, v172, v172 \n" " v_mul_f32 v54, v172, v172 \n"
" v_mul_f32 v55, v173, v173 \n" " v_mul_f32 v55, v173, v173 \n"
" v_mul_f32 v56, v174, v174 \n" " v_mul_f32 v56, v174, v174 \n"
...@@ -430,7 +430,7 @@ ...@@ -430,7 +430,7 @@
" v_exp_f32 v55, v55 \n" " v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n" " v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n" " v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[92:95], %[v_os_b1], s[12:15], 0 offen offset:3072\n" "buffer_load_dwordx4 acc[92:95], %[v_os_b1], s[12:15], 0 offen offset:3072\n"
" v_add_f32 v54, v54, 1.0 \n" " v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n" " v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n" " v_add_f32 v56, v56, 1.0 \n"
...@@ -444,7 +444,7 @@ ...@@ -444,7 +444,7 @@
" v_mul_f32 v174, v174, v56 \n" " v_mul_f32 v174, v174, v56 \n"
" v_mul_f32 v175, v175, v57 \n" " v_mul_f32 v175, v175, v57 \n"
" s_waitcnt vmcnt(24) \n" " s_waitcnt vmcnt(24) \n"
" buffer_load_dwordx4 acc[96:99], %[v_os_b2], s[12:15], 0 offen\n" "buffer_load_dwordx4 acc[96:99], %[v_os_b2], s[12:15], 0 offen\n"
" v_mul_f32 v54, v176, v176 \n" " v_mul_f32 v54, v176, v176 \n"
" v_mul_f32 v55, v177, v177 \n" " v_mul_f32 v55, v177, v177 \n"
" v_mul_f32 v56, v178, v178 \n" " v_mul_f32 v56, v178, v178 \n"
...@@ -465,7 +465,7 @@ ...@@ -465,7 +465,7 @@
" v_exp_f32 v55, v55 \n" " v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n" " v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n" " v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[100:103], %[v_os_b2], s[12:15], 0 offen offset:1024\n" "buffer_load_dwordx4 acc[100:103], %[v_os_b2], s[12:15], 0 offen offset:1024\n"
" v_add_f32 v54, v54, 1.0 \n" " v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n" " v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n" " v_add_f32 v56, v56, 1.0 \n"
...@@ -478,7 +478,7 @@ ...@@ -478,7 +478,7 @@
" v_mul_f32 v177, v177, v55 \n" " v_mul_f32 v177, v177, v55 \n"
" v_mul_f32 v178, v178, v56 \n" " v_mul_f32 v178, v178, v56 \n"
" v_mul_f32 v179, v179, v57 \n" " v_mul_f32 v179, v179, v57 \n"
" buffer_load_dwordx4 acc[104:107], %[v_os_b2], s[12:15], 0 offen offset:2048\n" "buffer_load_dwordx4 acc[104:107], %[v_os_b2], s[12:15], 0 offen offset:2048\n"
" v_mul_f32 v54, v180, v180 \n" " v_mul_f32 v54, v180, v180 \n"
" v_mul_f32 v55, v181, v181 \n" " v_mul_f32 v55, v181, v181 \n"
" v_mul_f32 v56, v182, v182 \n" " v_mul_f32 v56, v182, v182 \n"
...@@ -499,7 +499,7 @@ ...@@ -499,7 +499,7 @@
" v_exp_f32 v55, v55 \n" " v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n" " v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n" " v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[108:111], %[v_os_b2], s[12:15], 0 offen offset:3072\n" "buffer_load_dwordx4 acc[108:111], %[v_os_b2], s[12:15], 0 offen offset:3072\n"
" v_add_f32 v54, v54, 1.0 \n" " v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n" " v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n" " v_add_f32 v56, v56, 1.0 \n"
...@@ -512,7 +512,7 @@ ...@@ -512,7 +512,7 @@
" v_mul_f32 v181, v181, v55 \n" " v_mul_f32 v181, v181, v55 \n"
" v_mul_f32 v182, v182, v56 \n" " v_mul_f32 v182, v182, v56 \n"
" v_mul_f32 v183, v183, v57 \n" " v_mul_f32 v183, v183, v57 \n"
" buffer_load_dwordx4 acc[112:115], %[v_os_b3], s[12:15], 0 offen\n" "buffer_load_dwordx4 acc[112:115], %[v_os_b3], s[12:15], 0 offen\n"
" v_mul_f32 v54, v184, v184 \n" " v_mul_f32 v54, v184, v184 \n"
" v_mul_f32 v55, v185, v185 \n" " v_mul_f32 v55, v185, v185 \n"
" v_mul_f32 v56, v186, v186 \n" " v_mul_f32 v56, v186, v186 \n"
...@@ -533,7 +533,7 @@ ...@@ -533,7 +533,7 @@
" v_exp_f32 v55, v55 \n" " v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n" " v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n" " v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[116:119], %[v_os_b3], s[12:15], 0 offen offset:1024\n" "buffer_load_dwordx4 acc[116:119], %[v_os_b3], s[12:15], 0 offen offset:1024\n"
" v_add_f32 v54, v54, 1.0 \n" " v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n" " v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n" " v_add_f32 v56, v56, 1.0 \n"
...@@ -546,7 +546,7 @@ ...@@ -546,7 +546,7 @@
" v_mul_f32 v185, v185, v55 \n" " v_mul_f32 v185, v185, v55 \n"
" v_mul_f32 v186, v186, v56 \n" " v_mul_f32 v186, v186, v56 \n"
" v_mul_f32 v187, v187, v57 \n" " v_mul_f32 v187, v187, v57 \n"
" buffer_load_dwordx4 acc[120:123], %[v_os_b3], s[12:15], 0 offen offset:2048\n" "buffer_load_dwordx4 acc[120:123], %[v_os_b3], s[12:15], 0 offen offset:2048\n"
" v_mul_f32 v54, v188, v188 \n" " v_mul_f32 v54, v188, v188 \n"
" v_mul_f32 v55, v189, v189 \n" " v_mul_f32 v55, v189, v189 \n"
" v_mul_f32 v56, v190, v190 \n" " v_mul_f32 v56, v190, v190 \n"
...@@ -567,7 +567,7 @@ ...@@ -567,7 +567,7 @@
" v_exp_f32 v55, v55 \n" " v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n" " v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n" " v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[124:127], %[v_os_b3], s[12:15], 0 offen offset:3072\n" "buffer_load_dwordx4 acc[124:127], %[v_os_b3], s[12:15], 0 offen offset:3072\n"
" v_add_f32 v54, v54, 1.0 \n" " v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n" " v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n" " v_add_f32 v56, v56, 1.0 \n"
...@@ -644,7 +644,7 @@ ...@@ -644,7 +644,7 @@
" v_mul_f32 v189, %[smq_scale1], v189 row_newbcast:13 \n" " v_mul_f32 v189, %[smq_scale1], v189 row_newbcast:13 \n"
" v_mul_f32 v190, %[smq_scale1], v190 row_newbcast:14 \n" " v_mul_f32 v190, %[smq_scale1], v190 row_newbcast:14 \n"
" v_mul_f32 v191, %[smq_scale1], v191 row_newbcast:15 \n" " v_mul_f32 v191, %[smq_scale1], v191 row_newbcast:15 \n"
" buffer_load_dword v12, %[v_os_dq], s[16:19], 0 offen \n" ";--buffer_load_dword v12, %[v_os_dq], s[16:19], 0 offen \n"
" v_mov_b32 v22, 0x358637bd \n" " v_mov_b32 v22, 0x358637bd \n"
" v_mov_b32 v23, 0x358637bd \n" " v_mov_b32 v23, 0x358637bd \n"
" v_max3_f32 v22, abs(v128), abs(v129), v22 \n" " v_max3_f32 v22, abs(v128), abs(v129), v22 \n"
...@@ -974,3 +974,5 @@ ...@@ -974,3 +974,5 @@
#undef _UK_PK_CVT_ #undef _UK_PK_CVT_
#undef _UK_ATOMIC_ADD_ #undef _UK_ATOMIC_ADD_
...@@ -27,8 +27,6 @@ ...@@ -27,8 +27,6 @@
# define _UK_ATOMIC_ADD_ "global_atomic_pk_add_f16" # define _UK_ATOMIC_ADD_ "global_atomic_pk_add_f16"
#endif #endif
" s_mov_b32 s36, -1 \n"
" s_mov_b32 s37, -1 \n"
" s_add_u32 s12, %[s_tile_os_b], s12 \n" " s_add_u32 s12, %[s_tile_os_b], s12 \n"
" s_addc_u32 s13, 0, s13 \n" " s_addc_u32 s13, 0, s13 \n"
" s_add_u32 s16, %[s_tile_os_dq], s16 \n" " s_add_u32 s16, %[s_tile_os_dq], s16 \n"
...@@ -168,7 +166,7 @@ ...@@ -168,7 +166,7 @@
" buffer_load_dwordx4 acc[224:227], %[v_os_b2], s[12:15], 0 offen \n" " buffer_load_dwordx4 acc[224:227], %[v_os_b2], s[12:15], 0 offen \n"
" v_mfma_i32_16x16x32_i8 v[208:211], acc[100:101], v[148:149], v[208:211] \n" " v_mfma_i32_16x16x32_i8 v[208:211], acc[100:101], v[148:149], v[208:211] \n"
" v_mfma_i32_16x16x32_i8 v[208:211], acc[102:103], v[150:151], v[208:211] \n" " v_mfma_i32_16x16x32_i8 v[208:211], acc[102:103], v[150:151], v[208:211] \n"
" buffer_load_dword v13, %[v_os_dq], s[16:19], 0 offen \n" ";--- buffer_load_dword v13, %[v_os_dq], s[16:19], 0 offen \n"
" v_mfma_i32_16x16x32_i8 v[208:211], acc[104:105], v[152:153], v[208:211] \n" " v_mfma_i32_16x16x32_i8 v[208:211], acc[104:105], v[152:153], v[208:211] \n"
" v_mfma_i32_16x16x32_i8 v[208:211], acc[106:107], v[154:155], v[208:211] \n" " v_mfma_i32_16x16x32_i8 v[208:211], acc[106:107], v[154:155], v[208:211] \n"
" buffer_load_dwordx4 acc[228:231], %[v_os_b2], s[12:15], 0 offen offset:1024 \n" " buffer_load_dwordx4 acc[228:231], %[v_os_b2], s[12:15], 0 offen offset:1024 \n"
...@@ -480,560 +478,8 @@ ...@@ -480,560 +478,8 @@
" ds_read_b32 v78, v4 offset:43872 \n" " ds_read_b32 v78, v4 offset:43872 \n"
" ds_read_b32 v79, v4 offset:48224 \n" " ds_read_b32 v79, v4 offset:48224 \n"
" s_waitcnt lgkmcnt(0) \n" " s_waitcnt lgkmcnt(0) \n"
" s_mov_b64 exec, %[s_execflag_0] \n"
" global_atomic_pk_add_bf16 %[v_os_o0], v64, [%[s_res_o0],%[s_res_o1]] \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_0] \n"
" global_atomic_pk_add_bf16 %[v_os_o0], v65, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_1] \n"
" global_atomic_pk_add_bf16 %[v_os_o1], v66, [%[s_res_o0],%[s_res_o1]] \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_1] \n"
" global_atomic_pk_add_bf16 %[v_os_o1], v67, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_2] \n"
" global_atomic_pk_add_bf16 %[v_os_o2], v68, [%[s_res_o0],%[s_res_o1]] \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_2] \n"
" global_atomic_pk_add_bf16 %[v_os_o2], v69, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_3] \n"
" global_atomic_pk_add_bf16 %[v_os_o3], v70, [%[s_res_o0],%[s_res_o1]] \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_3] \n"
" global_atomic_pk_add_bf16 %[v_os_o3], v71, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_4] \n"
" global_atomic_pk_add_bf16 %[v_os_o4], v72, [%[s_res_o0],%[s_res_o1]] \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_4] \n"
" global_atomic_pk_add_bf16 %[v_os_o4], v73, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_5] \n"
" global_atomic_pk_add_bf16 %[v_os_o5], v74, [%[s_res_o0],%[s_res_o1]] \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_5] \n"
" global_atomic_pk_add_bf16 %[v_os_o5], v75, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_6] \n"
" global_atomic_pk_add_bf16 %[v_os_o6], v76, [%[s_res_o0],%[s_res_o1]] \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_6] \n"
" global_atomic_pk_add_bf16 %[v_os_o6], v77, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_7] \n"
" global_atomic_pk_add_bf16 %[v_os_o7], v78, [%[s_res_o0],%[s_res_o1]] \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_7] \n"
" global_atomic_pk_add_bf16 %[v_os_o7], v79, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n"
" s_mov_b64 exec, s[36:37] \n"
" s_add_u32 %[s_res_o0], %[s_tile_os_o], %[s_res_o0] \n"
" s_addc_u32 %[s_res_o1], 0, %[s_res_o1] \n"
" s_addk_i32 s80, 0x0100 \n"
" s_cmp_lt_i32 s80, %[s_loop_cnt] \n"
" s_cbranch_scc0 label_end_gemm2 \n"
" s_waitcnt vmcnt(41) \n"
" s_barrier \n"
" v_mfma_i32_16x16x32_i8 v[224:227], acc[128:129], v[128:129], 0 \n"
" v_mfma_i32_16x16x32_i8 v[224:227], acc[130:131], v[130:131], v[224:227] \n"
" buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[12:15], 0 offen \n"
" v_mfma_i32_16x16x32_i8 v[224:227], acc[132:133], v[132:133], v[224:227] \n"
" v_mfma_i32_16x16x32_i8 v[224:227], acc[134:135], v[134:135], v[224:227] \n"
" v_mfma_i32_16x16x32_i8 v[224:227], acc[136:137], v[136:137], v[224:227] \n"
" v_mfma_i32_16x16x32_i8 v[224:227], acc[138:139], v[138:139], v[224:227] \n"
" buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[12:15], 0 offen offset:1024 \n"
" v_mfma_i32_16x16x32_i8 v[224:227], acc[140:141], v[140:141], v[224:227] \n"
" v_mfma_i32_16x16x32_i8 v[224:227], acc[142:143], v[142:143], v[224:227] \n"
" v_mfma_i32_16x16x32_i8 v[228:231], acc[128:129], v[160:161], 0 \n"
" v_mfma_i32_16x16x32_i8 v[228:231], acc[130:131], v[162:163], v[228:231] \n"
" buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[12:15], 0 offen offset:2048 \n"
" v_mfma_i32_16x16x32_i8 v[228:231], acc[132:133], v[164:165], v[228:231] \n"
" v_mfma_i32_16x16x32_i8 v[228:231], acc[134:135], v[166:167], v[228:231] \n"
" v_mfma_i32_16x16x32_i8 v[228:231], acc[136:137], v[168:169], v[228:231] \n"
" v_mfma_i32_16x16x32_i8 v[228:231], acc[138:139], v[170:171], v[228:231] \n"
" buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[12:15], 0 offen offset:3072 \n"
" v_mfma_i32_16x16x32_i8 v[228:231], acc[140:141], v[172:173], v[228:231] \n"
" v_mfma_i32_16x16x32_i8 v[228:231], acc[142:143], v[174:175], v[228:231] \n"
" v_mfma_i32_16x16x32_i8 v[232:235], acc[144:145], v[128:129], 0 \n"
" v_mfma_i32_16x16x32_i8 v[232:235], acc[146:147], v[130:131], v[232:235] \n"
" buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[12:15], 0 offen \n"
" v_mfma_i32_16x16x32_i8 v[232:235], acc[148:149], v[132:133], v[232:235] \n"
" v_mfma_i32_16x16x32_i8 v[232:235], acc[150:151], v[134:135], v[232:235] \n"
" v_mfma_i32_16x16x32_i8 v[232:235], acc[152:153], v[136:137], v[232:235] \n"
" v_mfma_i32_16x16x32_i8 v[232:235], acc[154:155], v[138:139], v[232:235] \n"
" buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[12:15], 0 offen offset:1024 \n"
" v_mfma_i32_16x16x32_i8 v[232:235], acc[156:157], v[140:141], v[232:235] \n"
" v_mfma_i32_16x16x32_i8 v[232:235], acc[158:159], v[142:143], v[232:235] \n"
" v_mfma_i32_16x16x32_i8 v[236:239], acc[144:145], v[160:161], 0 \n"
" v_mfma_i32_16x16x32_i8 v[236:239], acc[146:147], v[162:163], v[236:239] \n"
" buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[12:15], 0 offen offset:2048 \n"
" v_mfma_i32_16x16x32_i8 v[236:239], acc[148:149], v[164:165], v[236:239] \n"
" v_mfma_i32_16x16x32_i8 v[236:239], acc[150:151], v[166:167], v[236:239] \n"
" v_mfma_i32_16x16x32_i8 v[236:239], acc[152:153], v[168:169], v[236:239] \n"
" v_mfma_i32_16x16x32_i8 v[236:239], acc[154:155], v[170:171], v[236:239] \n"
" buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[12:15], 0 offen offset:3072 \n"
" v_mfma_i32_16x16x32_i8 v[236:239], acc[156:157], v[172:173], v[236:239] \n"
" v_mfma_i32_16x16x32_i8 v[236:239], acc[158:159], v[174:175], v[236:239] \n"
" s_waitcnt vmcnt(41) \n"
" v_mfma_i32_16x16x32_i8 v[240:243], acc[160:161], v[128:129], 0 \n"
" v_mfma_i32_16x16x32_i8 v[240:243], acc[162:163], v[130:131], v[240:243] \n"
" buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[12:15], 0 offen \n"
" v_mfma_i32_16x16x32_i8 v[240:243], acc[164:165], v[132:133], v[240:243] \n"
" v_mfma_i32_16x16x32_i8 v[240:243], acc[166:167], v[134:135], v[240:243] \n"
" v_mfma_i32_16x16x32_i8 v[240:243], acc[168:169], v[136:137], v[240:243] \n"
" v_mfma_i32_16x16x32_i8 v[240:243], acc[170:171], v[138:139], v[240:243] \n"
" buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[12:15], 0 offen offset:1024 \n"
" v_mfma_i32_16x16x32_i8 v[240:243], acc[172:173], v[140:141], v[240:243] \n"
" v_mfma_i32_16x16x32_i8 v[240:243], acc[174:175], v[142:143], v[240:243] \n"
" v_mfma_i32_16x16x32_i8 v[244:247], acc[160:161], v[160:161], 0 \n"
" v_mfma_i32_16x16x32_i8 v[244:247], acc[162:163], v[162:163], v[244:247] \n"
" buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[12:15], 0 offen offset:2048 \n"
" v_mfma_i32_16x16x32_i8 v[244:247], acc[164:165], v[164:165], v[244:247] \n"
" v_mfma_i32_16x16x32_i8 v[244:247], acc[166:167], v[166:167], v[244:247] \n"
" v_mfma_i32_16x16x32_i8 v[244:247], acc[168:169], v[168:169], v[244:247] \n"
" v_mfma_i32_16x16x32_i8 v[244:247], acc[170:171], v[170:171], v[244:247] \n"
" buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[12:15], 0 offen offset:3072 \n"
" v_mfma_i32_16x16x32_i8 v[244:247], acc[172:173], v[172:173], v[244:247] \n"
" v_mfma_i32_16x16x32_i8 v[244:247], acc[174:175], v[174:175], v[244:247] \n"
" v_mfma_i32_16x16x32_i8 v[248:251], acc[176:177], v[128:129], 0 \n"
" v_mfma_i32_16x16x32_i8 v[248:251], acc[178:179], v[130:131], v[248:251] \n"
" buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[12:15], 0 offen \n"
" v_mfma_i32_16x16x32_i8 v[248:251], acc[180:181], v[132:133], v[248:251] \n"
" v_mfma_i32_16x16x32_i8 v[248:251], acc[182:183], v[134:135], v[248:251] \n"
" v_mfma_i32_16x16x32_i8 v[248:251], acc[184:185], v[136:137], v[248:251] \n"
" v_mfma_i32_16x16x32_i8 v[248:251], acc[186:187], v[138:139], v[248:251] \n"
" buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[12:15], 0 offen offset:1024 \n"
" v_mfma_i32_16x16x32_i8 v[248:251], acc[188:189], v[140:141], v[248:251] \n"
" v_mfma_i32_16x16x32_i8 v[248:251], acc[190:191], v[142:143], v[248:251] \n"
" v_mfma_i32_16x16x32_i8 v[252:255], acc[176:177], v[160:161], 0 \n"
" v_mfma_i32_16x16x32_i8 v[252:255], acc[178:179], v[162:163], v[252:255] \n"
" buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[12:15], 0 offen offset:2048 \n"
" v_mfma_i32_16x16x32_i8 v[252:255], acc[180:181], v[164:165], v[252:255] \n"
" v_mfma_i32_16x16x32_i8 v[252:255], acc[182:183], v[166:167], v[252:255] \n"
" v_mfma_i32_16x16x32_i8 v[252:255], acc[184:185], v[168:169], v[252:255] \n"
" v_mfma_i32_16x16x32_i8 v[252:255], acc[186:187], v[170:171], v[252:255] \n"
" buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[12:15], 0 offen offset:3072 \n"
" s_add_u32 s12, %[s_tile_os_b_half], s12 \n"
" s_addc_u32 s13, 0, s13 \n"
" v_mfma_i32_16x16x32_i8 v[252:255], acc[188:189], v[172:173], v[252:255] \n"
" v_mfma_i32_16x16x32_i8 v[252:255], acc[190:191], v[174:175], v[252:255] \n"
" s_waitcnt vmcnt(41) \n"
" v_mfma_i32_16x16x32_i8 v[224:227], acc[192:193], v[144:145], v[224:227] \n"
" v_mfma_i32_16x16x32_i8 v[224:227], acc[194:195], v[146:147], v[224:227] \n"
" buffer_load_dwordx4 acc[64:67], %[v_os_b0], s[12:15], 0 offen \n"
" v_mfma_i32_16x16x32_i8 v[224:227], acc[196:197], v[148:149], v[224:227] \n"
" v_mfma_i32_16x16x32_i8 v[224:227], acc[198:199], v[150:151], v[224:227] \n"
" v_mfma_i32_16x16x32_i8 v[224:227], acc[200:201], v[152:153], v[224:227] \n"
" v_mfma_i32_16x16x32_i8 v[224:227], acc[202:203], v[154:155], v[224:227] \n"
" buffer_load_dwordx4 acc[68:71], %[v_os_b0], s[12:15], 0 offen offset:1024 \n"
" v_mfma_i32_16x16x32_i8 v[224:227], acc[204:205], v[156:157], v[224:227] \n"
" v_mfma_i32_16x16x32_i8 v[224:227], acc[206:207], v[158:159], v[224:227] \n"
" v_mfma_i32_16x16x32_i8 v[228:231], acc[192:193], v[176:177], v[228:231] \n"
" v_mfma_i32_16x16x32_i8 v[228:231], acc[194:195], v[178:179], v[228:231] \n"
" buffer_load_dwordx4 acc[72:75], %[v_os_b0], s[12:15], 0 offen offset:2048 \n"
" v_mfma_i32_16x16x32_i8 v[228:231], acc[196:197], v[180:181], v[228:231] \n"
" v_mfma_i32_16x16x32_i8 v[228:231], acc[198:199], v[182:183], v[228:231] \n"
" v_mfma_i32_16x16x32_i8 v[228:231], acc[200:201], v[184:185], v[228:231] \n"
" v_mfma_i32_16x16x32_i8 v[228:231], acc[202:203], v[186:187], v[228:231] \n"
" buffer_load_dwordx4 acc[76:79], %[v_os_b0], s[12:15], 0 offen offset:3072 \n"
" v_mfma_i32_16x16x32_i8 v[228:231], acc[204:205], v[188:189], v[228:231] \n"
" v_mfma_i32_16x16x32_i8 v[228:231], acc[206:207], v[190:191], v[228:231] \n"
" v_mfma_i32_16x16x32_i8 v[232:235], acc[208:209], v[144:145], v[232:235] \n"
" v_mfma_i32_16x16x32_i8 v[232:235], acc[210:211], v[146:147], v[232:235] \n"
" buffer_load_dwordx4 acc[80:83], %[v_os_b1], s[12:15], 0 offen \n"
" v_mfma_i32_16x16x32_i8 v[232:235], acc[212:213], v[148:149], v[232:235] \n"
" v_mfma_i32_16x16x32_i8 v[232:235], acc[214:215], v[150:151], v[232:235] \n"
" v_mfma_i32_16x16x32_i8 v[232:235], acc[216:217], v[152:153], v[232:235] \n"
" v_mfma_i32_16x16x32_i8 v[232:235], acc[218:219], v[154:155], v[232:235] \n"
" buffer_load_dwordx4 acc[84:87], %[v_os_b1], s[12:15], 0 offen offset:1024 \n"
" v_mfma_i32_16x16x32_i8 v[232:235], acc[220:221], v[156:157], v[232:235] \n"
" v_mfma_i32_16x16x32_i8 v[232:235], acc[222:223], v[158:159], v[232:235] \n"
" v_mfma_i32_16x16x32_i8 v[236:239], acc[208:209], v[176:177], v[236:239] \n"
" v_mfma_i32_16x16x32_i8 v[236:239], acc[210:211], v[178:179], v[236:239] \n"
" buffer_load_dwordx4 acc[88:91], %[v_os_b1], s[12:15], 0 offen offset:2048 \n"
" v_mfma_i32_16x16x32_i8 v[236:239], acc[212:213], v[180:181], v[236:239] \n"
" v_mfma_i32_16x16x32_i8 v[236:239], acc[214:215], v[182:183], v[236:239] \n"
" v_mfma_i32_16x16x32_i8 v[236:239], acc[216:217], v[184:185], v[236:239] \n"
" v_mfma_i32_16x16x32_i8 v[236:239], acc[218:219], v[186:187], v[236:239] \n"
" buffer_load_dwordx4 acc[92:95], %[v_os_b1], s[12:15], 0 offen offset:3072 \n"
" v_mfma_i32_16x16x32_i8 v[236:239], acc[220:221], v[188:189], v[236:239] \n"
" v_mfma_i32_16x16x32_i8 v[236:239], acc[222:223], v[190:191], v[236:239] \n"
" s_waitcnt vmcnt(40) \n"
" v_mfma_i32_16x16x32_i8 v[240:243], acc[224:225], v[144:145], v[240:243] \n"
" v_mfma_i32_16x16x32_i8 v[240:243], acc[226:227], v[146:147], v[240:243] \n"
" buffer_load_dwordx4 acc[96:99], %[v_os_b2], s[12:15], 0 offen \n"
" v_mfma_i32_16x16x32_i8 v[240:243], acc[228:229], v[148:149], v[240:243] \n"
" v_mfma_i32_16x16x32_i8 v[240:243], acc[230:231], v[150:151], v[240:243] \n"
" buffer_load_dword v12, %[v_os_dq], s[16:19], 0 offen \n"
" v_mfma_i32_16x16x32_i8 v[240:243], acc[232:233], v[152:153], v[240:243] \n"
" v_mfma_i32_16x16x32_i8 v[240:243], acc[234:235], v[154:155], v[240:243] \n"
" buffer_load_dwordx4 acc[100:103], %[v_os_b2], s[12:15], 0 offen offset:1024 \n"
" v_mfma_i32_16x16x32_i8 v[240:243], acc[236:237], v[156:157], v[240:243] \n"
" v_mfma_i32_16x16x32_i8 v[240:243], acc[238:239], v[158:159], v[240:243] \n"
" v_mfma_i32_16x16x32_i8 v[244:247], acc[224:225], v[176:177], v[244:247] \n"
" v_mfma_i32_16x16x32_i8 v[244:247], acc[226:227], v[178:179], v[244:247] \n"
" buffer_load_dwordx4 acc[104:107], %[v_os_b2], s[12:15], 0 offen offset:2048 \n"
" v_mfma_i32_16x16x32_i8 v[244:247], acc[228:229], v[180:181], v[244:247] \n"
" v_mfma_i32_16x16x32_i8 v[244:247], acc[230:231], v[182:183], v[244:247] \n"
" v_mfma_i32_16x16x32_i8 v[244:247], acc[232:233], v[184:185], v[244:247] \n"
" v_mfma_i32_16x16x32_i8 v[244:247], acc[234:235], v[186:187], v[244:247] \n"
" buffer_load_dwordx4 acc[108:111], %[v_os_b2], s[12:15], 0 offen offset:3072 \n"
" v_mfma_i32_16x16x32_i8 v[244:247], acc[236:237], v[188:189], v[244:247] \n"
" v_mfma_i32_16x16x32_i8 v[244:247], acc[238:239], v[190:191], v[244:247] \n"
" v_mfma_i32_16x16x32_i8 v[248:251], acc[240:241], v[144:145], v[248:251] \n"
" v_mfma_i32_16x16x32_i8 v[248:251], acc[242:243], v[146:147], v[248:251] \n"
" buffer_load_dwordx4 acc[112:115], %[v_os_b3], s[12:15], 0 offen \n"
" v_mfma_i32_16x16x32_i8 v[248:251], acc[244:245], v[148:149], v[248:251] \n"
" v_mfma_i32_16x16x32_i8 v[248:251], acc[246:247], v[150:151], v[248:251] \n"
" v_mfma_i32_16x16x32_i8 v[248:251], acc[248:249], v[152:153], v[248:251] \n"
" v_mfma_i32_16x16x32_i8 v[248:251], acc[250:251], v[154:155], v[248:251] \n"
" buffer_load_dwordx4 acc[116:119], %[v_os_b3], s[12:15], 0 offen offset:1024 \n"
" v_mfma_i32_16x16x32_i8 v[248:251], acc[252:253], v[156:157], v[248:251] \n"
" v_mfma_i32_16x16x32_i8 v[248:251], acc[254:255], v[158:159], v[248:251] \n"
" v_mfma_i32_16x16x32_i8 v[252:255], acc[240:241], v[176:177], v[252:255] \n"
" v_mfma_i32_16x16x32_i8 v[252:255], acc[242:243], v[178:179], v[252:255] \n"
" buffer_load_dwordx4 acc[120:123], %[v_os_b3], s[12:15], 0 offen offset:2048 \n"
" v_mfma_i32_16x16x32_i8 v[252:255], acc[244:245], v[180:181], v[252:255] \n"
" v_mfma_i32_16x16x32_i8 v[252:255], acc[246:247], v[182:183], v[252:255] \n"
" v_mfma_i32_16x16x32_i8 v[252:255], acc[248:249], v[184:185], v[252:255] \n"
" v_mfma_i32_16x16x32_i8 v[252:255], acc[250:251], v[186:187], v[252:255] \n"
" buffer_load_dwordx4 acc[124:127], %[v_os_b3], s[12:15], 0 offen offset:3072 \n"
" v_mfma_i32_16x16x32_i8 v[252:255], acc[252:253], v[188:189], v[252:255] \n"
" v_mfma_i32_16x16x32_i8 v[252:255], acc[254:255], v[190:191], v[252:255] \n"
" s_add_u32 s60, 0x00000200, s80 \n"
" s_cmp_lt_u32 s60, %[s_loop_cnt] \n"
" s_cselect_b32 %[s_tile_os_b], %[s_tile_os_b], 0 \n"
" s_cselect_b32 %[s_tile_os_b_half], %[s_tile_os_b_half], 0 \n"
" s_cselect_b32 %[s_tile_os_dq], %[s_tile_os_dq], 0 \n"
" s_add_u32 s12, %[s_tile_os_b], s12 \n"
" s_addc_u32 s13, 0, s13 \n"
" s_add_u32 s16, %[s_tile_os_dq], s16 \n"
" s_addc_u32 s17, 0, s17 \n"
" v_cvt_f32_i32 v224, v224 \n"
" v_cvt_f32_i32 v225, v225 \n"
" v_cvt_f32_i32 v226, v226 \n"
" v_cvt_f32_i32 v227, v227 \n"
" v_mul_f32 v224, v24, v224 \n"
" v_mul_f32 v225, v24, v225 \n"
" v_mul_f32 v226, v24, v226 \n"
" v_mul_f32 v227, v24, v227 \n"
" v_mul_f32 v224, v13, v224 row_newbcast:0 \n"
" v_mul_f32 v225, v13, v225 row_newbcast:1 \n"
" v_mul_f32 v226, v13, v226 row_newbcast:2 \n"
" v_mul_f32 v227, v13, v227 row_newbcast:3 \n"
" v_mul_f32 v224, %[scale_0], v224 \n"
" v_mul_f32 v225, %[scale_0], v225 \n"
" v_mul_f32 v226, %[scale_0], v226 \n"
" v_mul_f32 v227, %[scale_0], v227 \n"
" v_cvt_f32_i32 v228, v228 \n"
" v_cvt_f32_i32 v229, v229 \n"
" v_cvt_f32_i32 v230, v230 \n"
" v_cvt_f32_i32 v231, v231 \n"
" v_mul_f32 v228, v25, v228 \n"
" v_mul_f32 v229, v25, v229 \n"
" v_mul_f32 v230, v25, v230 \n"
" v_mul_f32 v231, v25, v231 \n"
" v_mul_f32 v228, v13, v228 row_newbcast:0 \n"
" v_mul_f32 v229, v13, v229 row_newbcast:1 \n"
" v_mul_f32 v230, v13, v230 row_newbcast:2 \n"
" v_mul_f32 v231, v13, v231 row_newbcast:3 \n"
" v_mul_f32 v228, %[scale_1], v228 \n"
" v_mul_f32 v229, %[scale_1], v229 \n"
" v_mul_f32 v230, %[scale_1], v230 \n"
" v_mul_f32 v231, %[scale_1], v231 \n"
" v_cvt_f32_i32 v232, v232 \n"
" v_cvt_f32_i32 v233, v233 \n"
" v_cvt_f32_i32 v234, v234 \n"
" v_cvt_f32_i32 v235, v235 \n"
" v_mul_f32 v232, v24, v232 \n"
" v_mul_f32 v233, v24, v233 \n"
" v_mul_f32 v234, v24, v234 \n"
" v_mul_f32 v235, v24, v235 \n"
" v_mul_f32 v232, v13, v232 row_newbcast:4 \n"
" v_mul_f32 v233, v13, v233 row_newbcast:5 \n"
" v_mul_f32 v234, v13, v234 row_newbcast:6 \n"
" v_mul_f32 v235, v13, v235 row_newbcast:7 \n"
" v_mul_f32 v232, %[scale_0], v232 \n"
" v_mul_f32 v233, %[scale_0], v233 \n"
" v_mul_f32 v234, %[scale_0], v234 \n"
" v_mul_f32 v235, %[scale_0], v235 \n"
" v_cvt_f32_i32 v236, v236 \n"
" v_cvt_f32_i32 v237, v237 \n"
" v_cvt_f32_i32 v238, v238 \n"
" v_cvt_f32_i32 v239, v239 \n"
" v_mul_f32 v236, v25, v236 \n"
" v_mul_f32 v237, v25, v237 \n"
" v_mul_f32 v238, v25, v238 \n"
" v_mul_f32 v239, v25, v239 \n"
" v_mul_f32 v236, v13, v236 row_newbcast:4 \n"
" v_mul_f32 v237, v13, v237 row_newbcast:5 \n"
" v_mul_f32 v238, v13, v238 row_newbcast:6 \n"
" v_mul_f32 v239, v13, v239 row_newbcast:7 \n"
" v_mul_f32 v236, %[scale_1], v236 \n"
" v_mul_f32 v237, %[scale_1], v237 \n"
" v_mul_f32 v238, %[scale_1], v238 \n"
" v_mul_f32 v239, %[scale_1], v239 \n"
" v_cvt_f32_i32 v240, v240 \n"
" v_cvt_f32_i32 v241, v241 \n"
" v_cvt_f32_i32 v242, v242 \n"
" v_cvt_f32_i32 v243, v243 \n"
" v_mul_f32 v240, v24, v240 \n"
" v_mul_f32 v241, v24, v241 \n"
" v_mul_f32 v242, v24, v242 \n"
" v_mul_f32 v243, v24, v243 \n"
" v_mul_f32 v240, v13, v240 row_newbcast:8 \n"
" v_mul_f32 v241, v13, v241 row_newbcast:9 \n"
" v_mul_f32 v242, v13, v242 row_newbcast:10 \n"
" v_mul_f32 v243, v13, v243 row_newbcast:11 \n"
" v_mul_f32 v240, %[scale_0], v240 \n"
" v_mul_f32 v241, %[scale_0], v241 \n"
" v_mul_f32 v242, %[scale_0], v242 \n"
" v_mul_f32 v243, %[scale_0], v243 \n"
" v_cvt_f32_i32 v244, v244 \n"
" v_cvt_f32_i32 v245, v245 \n"
" v_cvt_f32_i32 v246, v246 \n"
" v_cvt_f32_i32 v247, v247 \n"
" v_mul_f32 v244, v25, v244 \n"
" v_mul_f32 v245, v25, v245 \n"
" v_mul_f32 v246, v25, v246 \n"
" v_mul_f32 v247, v25, v247 \n"
" v_mul_f32 v244, v13, v244 row_newbcast:8 \n"
" v_mul_f32 v245, v13, v245 row_newbcast:9 \n"
" v_mul_f32 v246, v13, v246 row_newbcast:10 \n"
" v_mul_f32 v247, v13, v247 row_newbcast:11 \n"
" v_mul_f32 v244, %[scale_1], v244 \n"
" v_mul_f32 v245, %[scale_1], v245 \n"
" v_mul_f32 v246, %[scale_1], v246 \n"
" v_mul_f32 v247, %[scale_1], v247 \n"
" v_cvt_f32_i32 v248, v248 \n"
" v_cvt_f32_i32 v249, v249 \n"
" v_cvt_f32_i32 v250, v250 \n"
" v_cvt_f32_i32 v251, v251 \n"
" v_mul_f32 v248, v24, v248 \n"
" v_mul_f32 v249, v24, v249 \n"
" v_mul_f32 v250, v24, v250 \n"
" v_mul_f32 v251, v24, v251 \n"
" v_mul_f32 v248, v13, v248 row_newbcast:12 \n"
" v_mul_f32 v249, v13, v249 row_newbcast:13 \n"
" v_mul_f32 v250, v13, v250 row_newbcast:14 \n"
" v_mul_f32 v251, v13, v251 row_newbcast:15 \n"
" v_mul_f32 v248, %[scale_0], v248 \n"
" v_mul_f32 v249, %[scale_0], v249 \n"
" v_mul_f32 v250, %[scale_0], v250 \n"
" v_mul_f32 v251, %[scale_0], v251 \n"
" v_cvt_f32_i32 v252, v252 \n"
" v_cvt_f32_i32 v253, v253 \n"
" v_cvt_f32_i32 v254, v254 \n"
" v_cvt_f32_i32 v255, v255 \n"
" v_mul_f32 v252, v25, v252 \n"
" v_mul_f32 v253, v25, v253 \n"
" v_mul_f32 v254, v25, v254 \n"
" v_mul_f32 v255, v25, v255 \n"
" v_mul_f32 v252, v13, v252 row_newbcast:12 \n"
" v_mul_f32 v253, v13, v253 row_newbcast:13 \n"
" v_mul_f32 v254, v13, v254 row_newbcast:14 \n"
" v_mul_f32 v255, v13, v255 row_newbcast:15 \n"
" v_mul_f32 v252, %[scale_1], v252 \n"
" v_mul_f32 v253, %[scale_1], v253 \n"
" v_mul_f32 v254, %[scale_1], v254 \n"
" v_mul_f32 v255, %[scale_1], v255 \n"
" v_cmp_u_f32 s[48:49], v224, v224 \n"
" v_add3_u32 v50, v224, v53, 1 \n"
" v_cndmask_b32 v54, v50, v52, s[48:49] \n"
" v_cmp_u_f32 s[48:49], v225, v225 \n"
" v_add3_u32 v50, v225, v53, 1 \n"
" v_cndmask_b32 v55, v50, v52, s[48:49] \n"
" v_perm_b32 v224, v55, v54, s52 \n"
" v_cmp_u_f32 s[48:49], v226, v226 \n"
" v_add3_u32 v50, v226, v53, 1 \n"
" v_cndmask_b32 v54, v50, v52, s[48:49] \n"
" v_cmp_u_f32 s[48:49], v227, v227 \n"
" v_add3_u32 v50, v227, v53, 1 \n"
" v_cndmask_b32 v55, v50, v52, s[48:49] \n"
" v_perm_b32 v225, v55, v54, s52 \n"
" v_cmp_u_f32 s[48:49], v228, v228 \n"
" v_add3_u32 v50, v228, v53, 1 \n"
" v_cndmask_b32 v54, v50, v52, s[48:49] \n"
" v_cmp_u_f32 s[48:49], v229, v229 \n"
" v_add3_u32 v50, v229, v53, 1 \n"
" v_cndmask_b32 v55, v50, v52, s[48:49] \n"
" v_perm_b32 v226, v55, v54, s52 \n"
" v_cmp_u_f32 s[48:49], v230, v230 \n"
" v_add3_u32 v50, v230, v53, 1 \n"
" v_cndmask_b32 v54, v50, v52, s[48:49] \n"
" v_cmp_u_f32 s[48:49], v231, v231 \n"
" v_add3_u32 v50, v231, v53, 1 \n"
" v_cndmask_b32 v55, v50, v52, s[48:49] \n"
" v_perm_b32 v227, v55, v54, s52 \n"
" v_cmp_u_f32 s[48:49], v232, v232 \n"
" v_add3_u32 v50, v232, v53, 1 \n"
" v_cndmask_b32 v54, v50, v52, s[48:49] \n"
" v_cmp_u_f32 s[48:49], v233, v233 \n"
" v_add3_u32 v50, v233, v53, 1 \n"
" v_cndmask_b32 v55, v50, v52, s[48:49] \n"
" v_perm_b32 v228, v55, v54, s52 \n"
" v_cmp_u_f32 s[48:49], v234, v234 \n"
" v_add3_u32 v50, v234, v53, 1 \n"
" v_cndmask_b32 v54, v50, v52, s[48:49] \n"
" v_cmp_u_f32 s[48:49], v235, v235 \n"
" v_add3_u32 v50, v235, v53, 1 \n"
" v_cndmask_b32 v55, v50, v52, s[48:49] \n"
" v_perm_b32 v229, v55, v54, s52 \n"
" v_cmp_u_f32 s[48:49], v236, v236 \n"
" v_add3_u32 v50, v236, v53, 1 \n"
" v_cndmask_b32 v54, v50, v52, s[48:49] \n"
" v_cmp_u_f32 s[48:49], v237, v237 \n"
" v_add3_u32 v50, v237, v53, 1 \n"
" v_cndmask_b32 v55, v50, v52, s[48:49] \n"
" v_perm_b32 v230, v55, v54, s52 \n"
" v_cmp_u_f32 s[48:49], v238, v238 \n"
" v_add3_u32 v50, v238, v53, 1 \n"
" v_cndmask_b32 v54, v50, v52, s[48:49] \n"
" v_cmp_u_f32 s[48:49], v239, v239 \n"
" v_add3_u32 v50, v239, v53, 1 \n"
" v_cndmask_b32 v55, v50, v52, s[48:49] \n"
" v_perm_b32 v231, v55, v54, s52 \n"
" v_cmp_u_f32 s[48:49], v240, v240 \n"
" v_add3_u32 v50, v240, v53, 1 \n"
" v_cndmask_b32 v54, v50, v52, s[48:49] \n"
" v_cmp_u_f32 s[48:49], v241, v241 \n"
" v_add3_u32 v50, v241, v53, 1 \n"
" v_cndmask_b32 v55, v50, v52, s[48:49] \n"
" v_perm_b32 v232, v55, v54, s52 \n"
" v_cmp_u_f32 s[48:49], v242, v242 \n"
" v_add3_u32 v50, v242, v53, 1 \n"
" v_cndmask_b32 v54, v50, v52, s[48:49] \n"
" v_cmp_u_f32 s[48:49], v243, v243 \n"
" v_add3_u32 v50, v243, v53, 1 \n"
" v_cndmask_b32 v55, v50, v52, s[48:49] \n"
" v_perm_b32 v233, v55, v54, s52 \n"
" v_cmp_u_f32 s[48:49], v244, v244 \n"
" v_add3_u32 v50, v244, v53, 1 \n"
" v_cndmask_b32 v54, v50, v52, s[48:49] \n"
" v_cmp_u_f32 s[48:49], v245, v245 \n"
" v_add3_u32 v50, v245, v53, 1 \n"
" v_cndmask_b32 v55, v50, v52, s[48:49] \n"
" v_perm_b32 v234, v55, v54, s52 \n"
" v_cmp_u_f32 s[48:49], v246, v246 \n"
" v_add3_u32 v50, v246, v53, 1 \n"
" v_cndmask_b32 v54, v50, v52, s[48:49] \n"
" v_cmp_u_f32 s[48:49], v247, v247 \n"
" v_add3_u32 v50, v247, v53, 1 \n"
" v_cndmask_b32 v55, v50, v52, s[48:49] \n"
" v_perm_b32 v235, v55, v54, s52 \n"
" v_cmp_u_f32 s[48:49], v248, v248 \n"
" v_add3_u32 v50, v248, v53, 1 \n"
" v_cndmask_b32 v54, v50, v52, s[48:49] \n"
" v_cmp_u_f32 s[48:49], v249, v249 \n"
" v_add3_u32 v50, v249, v53, 1 \n"
" v_cndmask_b32 v55, v50, v52, s[48:49] \n"
" v_perm_b32 v236, v55, v54, s52 \n"
" v_cmp_u_f32 s[48:49], v250, v250 \n"
" v_add3_u32 v50, v250, v53, 1 \n"
" v_cndmask_b32 v54, v50, v52, s[48:49] \n"
" v_cmp_u_f32 s[48:49], v251, v251 \n"
" v_add3_u32 v50, v251, v53, 1 \n"
" v_cndmask_b32 v55, v50, v52, s[48:49] \n"
" v_perm_b32 v237, v55, v54, s52 \n"
" v_cmp_u_f32 s[48:49], v252, v252 \n"
" v_add3_u32 v50, v252, v53, 1 \n"
" v_cndmask_b32 v54, v50, v52, s[48:49] \n"
" v_cmp_u_f32 s[48:49], v253, v253 \n"
" v_add3_u32 v50, v253, v53, 1 \n"
" v_cndmask_b32 v55, v50, v52, s[48:49] \n"
" v_perm_b32 v238, v55, v54, s52 \n"
" v_cmp_u_f32 s[48:49], v254, v254 \n"
" v_add3_u32 v50, v254, v53, 1 \n"
" v_cndmask_b32 v54, v50, v52, s[48:49] \n"
" v_cmp_u_f32 s[48:49], v255, v255 \n"
" v_add3_u32 v50, v255, v53, 1 \n"
" v_cndmask_b32 v55, v50, v52, s[48:49] \n"
" v_perm_b32 v239, v55, v54, s52 \n"
" ds_write_b64 v3, v[224:225] offset:35072 \n"
" ds_write_b64 v3, v[226:227] offset:43776 \n"
" ds_write_b64 v3, v[228:229] offset:37248 \n"
" ds_write_b64 v3, v[230:231] offset:45952 \n"
" ds_write_b64 v3, v[232:233] offset:39424 \n"
" ds_write_b64 v3, v[234:235] offset:48128 \n"
" ds_write_b64 v3, v[236:237] offset:41600 \n"
" ds_write_b64 v3, v[238:239] offset:50304 \n"
" s_waitcnt lgkmcnt(0) \n"
" s_barrier \n"
" ds_read_b32 v64, v4 offset:35072 \n"
" ds_read_b32 v65, v4 offset:39424 \n"
" ds_read_b32 v66, v4 offset:35104 \n"
" ds_read_b32 v67, v4 offset:39456 \n"
" ds_read_b32 v68, v4 offset:35136 \n"
" ds_read_b32 v69, v4 offset:39488 \n"
" ds_read_b32 v70, v4 offset:35168 \n"
" ds_read_b32 v71, v4 offset:39520 \n"
" ds_read_b32 v72, v4 offset:43776 \n"
" ds_read_b32 v73, v4 offset:48128 \n"
" ds_read_b32 v74, v4 offset:43808 \n"
" ds_read_b32 v75, v4 offset:48160 \n"
" ds_read_b32 v76, v4 offset:43840 \n"
" ds_read_b32 v77, v4 offset:48192 \n"
" ds_read_b32 v78, v4 offset:43872 \n"
" ds_read_b32 v79, v4 offset:48224 \n"
" s_waitcnt lgkmcnt(0) \n"
" s_mov_b64 exec, %[s_execflag_0] \n"
" global_atomic_pk_add_bf16 %[v_os_o0], v64, [%[s_res_o0],%[s_res_o1]] \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_0] \n"
" global_atomic_pk_add_bf16 %[v_os_o0], v65, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_1] \n"
" global_atomic_pk_add_bf16 %[v_os_o1], v66, [%[s_res_o0],%[s_res_o1]] \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_1] \n"
" global_atomic_pk_add_bf16 %[v_os_o1], v67, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_2] \n"
" global_atomic_pk_add_bf16 %[v_os_o2], v68, [%[s_res_o0],%[s_res_o1]] \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_2] \n"
" global_atomic_pk_add_bf16 %[v_os_o2], v69, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_3] \n"
" global_atomic_pk_add_bf16 %[v_os_o3], v70, [%[s_res_o0],%[s_res_o1]] \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_3] \n"
" global_atomic_pk_add_bf16 %[v_os_o3], v71, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_4] \n"
" global_atomic_pk_add_bf16 %[v_os_o4], v72, [%[s_res_o0],%[s_res_o1]] \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_4] \n"
" global_atomic_pk_add_bf16 %[v_os_o4], v73, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_5] \n"
" global_atomic_pk_add_bf16 %[v_os_o5], v74, [%[s_res_o0],%[s_res_o1]] \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_5] \n"
" global_atomic_pk_add_bf16 %[v_os_o5], v75, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_6] \n"
" global_atomic_pk_add_bf16 %[v_os_o6], v76, [%[s_res_o0],%[s_res_o1]] \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_6] \n"
" global_atomic_pk_add_bf16 %[v_os_o6], v77, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_7] \n"
" global_atomic_pk_add_bf16 %[v_os_o7], v78, [%[s_res_o0],%[s_res_o1]] \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_7] \n"
" global_atomic_pk_add_bf16 %[v_os_o7], v79, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n"
" s_mov_b64 exec, s[36:37] \n"
" s_add_u32 %[s_res_o0], %[s_tile_os_o], %[s_res_o0] \n"
" s_addc_u32 %[s_res_o1], 0, %[s_res_o1] \n"
" s_addk_i32 s80, 0x0100 \n"
" s_cmp_lt_i32 s80, %[s_loop_cnt] \n"
" s_cbranch_scc0 label_end_gemm2 \n"
" s_branch label_startgemm2 \n"
" label_end_gemm2: \n"
" s_waitcnt 0x0000 \n"
" s_endpgm \n"
#undef _UK_MFMA_ #undef _UK_MFMA_
#undef _UK_PK_CVT_ #undef _UK_PK_CVT_
#undef _UK_ATOMIC_ADD_ #undef _UK_ATOMIC_ADD_
#ifndef CK_TILE_FLATMM_UK_MFMA
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_INT8
#endif
#if CK_TILE_FLATMM_UK_MFMA == CK_TILE_FLATMM_UK_MFMA_INT8
# define _UK_MFMA_ "v_mfma_i32_16x16x32_i8"
# define _UK_PK_CVT_(x0_, x1_, y_) \
" v_cmp_u_f32 s[36:37], " x0_ ", " x0_ " \n" \
" v_add3_u32 v50, " x0_ ", %[v_nan_lo], 1 \n" \
" v_cndmask_b32 v54, v50, %[v_nan_hi], s[36:37] \n" \
" v_cmp_u_f32 s[36:37], " x1_ ", " x1_ " \n" \
" v_add3_u32 v50, " x1_ ", %[v_nan_lo], 1 \n" \
" v_cndmask_b32 v55, v50, %[v_nan_hi], s[36:37] \n" \
" v_perm_b32 " y_ ", v55, v54, s52 \n"
# define _UK_ATOMIC_ADD_ "global_atomic_pk_add_bf16"
#elif CK_TILE_FLATMM_UK_MFMA == CK_TILE_FLATMM_UK_MFMA_FP16
#define _UK_MFMA_ "v_mfma_f32_16x16x16_f16"
# define _UK_PK_CVT_(x0_, x1_, y_) \
" v_cvt_f16_f32 v54, " x0_ " \n" \
" v_cvt_f16_f32 v55, " x1_ " \n" \
" v_pack_b32_f16 " y_ ", v54, v55 \n"
# define _UK_ATOMIC_ADD_ "global_atomic_pk_add_f16"
#endif
" s_mov_b32 s36, -1 \n"
" s_mov_b32 s37, -1 \n"
" s_mov_b64 exec, %[s_execflag_0] \n"
" global_atomic_pk_add_bf16 %[v_os_o0], v64, [%[s_res_o0],%[s_res_o1]] \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_0] \n"
" global_atomic_pk_add_bf16 %[v_os_o0], v65, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_1] \n"
" global_atomic_pk_add_bf16 %[v_os_o1], v66, [%[s_res_o0],%[s_res_o1]] \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_1] \n"
" global_atomic_pk_add_bf16 %[v_os_o1], v67, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_2] \n"
" global_atomic_pk_add_bf16 %[v_os_o2], v68, [%[s_res_o0],%[s_res_o1]] \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_2] \n"
" global_atomic_pk_add_bf16 %[v_os_o2], v69, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_3] \n"
" global_atomic_pk_add_bf16 %[v_os_o3], v70, [%[s_res_o0],%[s_res_o1]] \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_3] \n"
" global_atomic_pk_add_bf16 %[v_os_o3], v71, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_4] \n"
" global_atomic_pk_add_bf16 %[v_os_o4], v72, [%[s_res_o0],%[s_res_o1]] \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_4] \n"
" global_atomic_pk_add_bf16 %[v_os_o4], v73, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_5] \n"
" global_atomic_pk_add_bf16 %[v_os_o5], v74, [%[s_res_o0],%[s_res_o1]] \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_5] \n"
" global_atomic_pk_add_bf16 %[v_os_o5], v75, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_6] \n"
" global_atomic_pk_add_bf16 %[v_os_o6], v76, [%[s_res_o0],%[s_res_o1]] \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_6] \n"
" global_atomic_pk_add_bf16 %[v_os_o6], v77, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_7] \n"
" global_atomic_pk_add_bf16 %[v_os_o7], v78, [%[s_res_o0],%[s_res_o1]] \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_7] \n"
" global_atomic_pk_add_bf16 %[v_os_o7], v79, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n"
" s_mov_b64 exec, s[36:37] \n"
" s_add_u32 %[s_res_o0], %[s_tile_os_o], %[s_res_o0] \n"
" s_addc_u32 %[s_res_o1], 0, %[s_res_o1] \n"
" s_addk_i32 s80, 0x0100 \n"
" s_cmp_lt_i32 s80, %[s_loop_cnt] \n"
" s_cbranch_scc0 label_end_gemm2 \n"
" s_waitcnt vmcnt(41) \n"
" s_barrier \n"
" v_mfma_i32_16x16x32_i8 v[224:227], acc[128:129], v[128:129], 0 \n"
" v_mfma_i32_16x16x32_i8 v[224:227], acc[130:131], v[130:131], v[224:227] \n"
" buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[12:15], 0 offen \n"
" v_mfma_i32_16x16x32_i8 v[224:227], acc[132:133], v[132:133], v[224:227] \n"
" v_mfma_i32_16x16x32_i8 v[224:227], acc[134:135], v[134:135], v[224:227] \n"
" v_mfma_i32_16x16x32_i8 v[224:227], acc[136:137], v[136:137], v[224:227] \n"
" v_mfma_i32_16x16x32_i8 v[224:227], acc[138:139], v[138:139], v[224:227] \n"
" buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[12:15], 0 offen offset:1024 \n"
" v_mfma_i32_16x16x32_i8 v[224:227], acc[140:141], v[140:141], v[224:227] \n"
" v_mfma_i32_16x16x32_i8 v[224:227], acc[142:143], v[142:143], v[224:227] \n"
" v_mfma_i32_16x16x32_i8 v[228:231], acc[128:129], v[160:161], 0 \n"
" v_mfma_i32_16x16x32_i8 v[228:231], acc[130:131], v[162:163], v[228:231] \n"
" buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[12:15], 0 offen offset:2048 \n"
" v_mfma_i32_16x16x32_i8 v[228:231], acc[132:133], v[164:165], v[228:231] \n"
" v_mfma_i32_16x16x32_i8 v[228:231], acc[134:135], v[166:167], v[228:231] \n"
" v_mfma_i32_16x16x32_i8 v[228:231], acc[136:137], v[168:169], v[228:231] \n"
" v_mfma_i32_16x16x32_i8 v[228:231], acc[138:139], v[170:171], v[228:231] \n"
" buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[12:15], 0 offen offset:3072 \n"
" v_mfma_i32_16x16x32_i8 v[228:231], acc[140:141], v[172:173], v[228:231] \n"
" v_mfma_i32_16x16x32_i8 v[228:231], acc[142:143], v[174:175], v[228:231] \n"
" v_mfma_i32_16x16x32_i8 v[232:235], acc[144:145], v[128:129], 0 \n"
" v_mfma_i32_16x16x32_i8 v[232:235], acc[146:147], v[130:131], v[232:235] \n"
" buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[12:15], 0 offen \n"
" v_mfma_i32_16x16x32_i8 v[232:235], acc[148:149], v[132:133], v[232:235] \n"
" v_mfma_i32_16x16x32_i8 v[232:235], acc[150:151], v[134:135], v[232:235] \n"
" v_mfma_i32_16x16x32_i8 v[232:235], acc[152:153], v[136:137], v[232:235] \n"
" v_mfma_i32_16x16x32_i8 v[232:235], acc[154:155], v[138:139], v[232:235] \n"
" buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[12:15], 0 offen offset:1024 \n"
" v_mfma_i32_16x16x32_i8 v[232:235], acc[156:157], v[140:141], v[232:235] \n"
" v_mfma_i32_16x16x32_i8 v[232:235], acc[158:159], v[142:143], v[232:235] \n"
" v_mfma_i32_16x16x32_i8 v[236:239], acc[144:145], v[160:161], 0 \n"
" v_mfma_i32_16x16x32_i8 v[236:239], acc[146:147], v[162:163], v[236:239] \n"
" buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[12:15], 0 offen offset:2048 \n"
" v_mfma_i32_16x16x32_i8 v[236:239], acc[148:149], v[164:165], v[236:239] \n"
" v_mfma_i32_16x16x32_i8 v[236:239], acc[150:151], v[166:167], v[236:239] \n"
" v_mfma_i32_16x16x32_i8 v[236:239], acc[152:153], v[168:169], v[236:239] \n"
" v_mfma_i32_16x16x32_i8 v[236:239], acc[154:155], v[170:171], v[236:239] \n"
" buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[12:15], 0 offen offset:3072 \n"
" v_mfma_i32_16x16x32_i8 v[236:239], acc[156:157], v[172:173], v[236:239] \n"
" v_mfma_i32_16x16x32_i8 v[236:239], acc[158:159], v[174:175], v[236:239] \n"
" s_waitcnt vmcnt(41) \n"
" v_mfma_i32_16x16x32_i8 v[240:243], acc[160:161], v[128:129], 0 \n"
" v_mfma_i32_16x16x32_i8 v[240:243], acc[162:163], v[130:131], v[240:243] \n"
" buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[12:15], 0 offen \n"
" v_mfma_i32_16x16x32_i8 v[240:243], acc[164:165], v[132:133], v[240:243] \n"
" v_mfma_i32_16x16x32_i8 v[240:243], acc[166:167], v[134:135], v[240:243] \n"
" v_mfma_i32_16x16x32_i8 v[240:243], acc[168:169], v[136:137], v[240:243] \n"
" v_mfma_i32_16x16x32_i8 v[240:243], acc[170:171], v[138:139], v[240:243] \n"
" buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[12:15], 0 offen offset:1024 \n"
" v_mfma_i32_16x16x32_i8 v[240:243], acc[172:173], v[140:141], v[240:243] \n"
" v_mfma_i32_16x16x32_i8 v[240:243], acc[174:175], v[142:143], v[240:243] \n"
" v_mfma_i32_16x16x32_i8 v[244:247], acc[160:161], v[160:161], 0 \n"
" v_mfma_i32_16x16x32_i8 v[244:247], acc[162:163], v[162:163], v[244:247] \n"
" buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[12:15], 0 offen offset:2048 \n"
" v_mfma_i32_16x16x32_i8 v[244:247], acc[164:165], v[164:165], v[244:247] \n"
" v_mfma_i32_16x16x32_i8 v[244:247], acc[166:167], v[166:167], v[244:247] \n"
" v_mfma_i32_16x16x32_i8 v[244:247], acc[168:169], v[168:169], v[244:247] \n"
" v_mfma_i32_16x16x32_i8 v[244:247], acc[170:171], v[170:171], v[244:247] \n"
" buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[12:15], 0 offen offset:3072 \n"
" v_mfma_i32_16x16x32_i8 v[244:247], acc[172:173], v[172:173], v[244:247] \n"
" v_mfma_i32_16x16x32_i8 v[244:247], acc[174:175], v[174:175], v[244:247] \n"
" v_mfma_i32_16x16x32_i8 v[248:251], acc[176:177], v[128:129], 0 \n"
" v_mfma_i32_16x16x32_i8 v[248:251], acc[178:179], v[130:131], v[248:251] \n"
" buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[12:15], 0 offen \n"
" v_mfma_i32_16x16x32_i8 v[248:251], acc[180:181], v[132:133], v[248:251] \n"
" v_mfma_i32_16x16x32_i8 v[248:251], acc[182:183], v[134:135], v[248:251] \n"
" v_mfma_i32_16x16x32_i8 v[248:251], acc[184:185], v[136:137], v[248:251] \n"
" v_mfma_i32_16x16x32_i8 v[248:251], acc[186:187], v[138:139], v[248:251] \n"
" buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[12:15], 0 offen offset:1024 \n"
" v_mfma_i32_16x16x32_i8 v[248:251], acc[188:189], v[140:141], v[248:251] \n"
" v_mfma_i32_16x16x32_i8 v[248:251], acc[190:191], v[142:143], v[248:251] \n"
" v_mfma_i32_16x16x32_i8 v[252:255], acc[176:177], v[160:161], 0 \n"
" v_mfma_i32_16x16x32_i8 v[252:255], acc[178:179], v[162:163], v[252:255] \n"
" buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[12:15], 0 offen offset:2048 \n"
" v_mfma_i32_16x16x32_i8 v[252:255], acc[180:181], v[164:165], v[252:255] \n"
" v_mfma_i32_16x16x32_i8 v[252:255], acc[182:183], v[166:167], v[252:255] \n"
" v_mfma_i32_16x16x32_i8 v[252:255], acc[184:185], v[168:169], v[252:255] \n"
" v_mfma_i32_16x16x32_i8 v[252:255], acc[186:187], v[170:171], v[252:255] \n"
" buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[12:15], 0 offen offset:3072 \n"
" s_add_u32 s12, %[s_tile_os_b_half], s12 \n"
" s_addc_u32 s13, 0, s13 \n"
" v_mfma_i32_16x16x32_i8 v[252:255], acc[188:189], v[172:173], v[252:255] \n"
" v_mfma_i32_16x16x32_i8 v[252:255], acc[190:191], v[174:175], v[252:255] \n"
" s_waitcnt vmcnt(41) \n"
" v_mfma_i32_16x16x32_i8 v[224:227], acc[192:193], v[144:145], v[224:227] \n"
" v_mfma_i32_16x16x32_i8 v[224:227], acc[194:195], v[146:147], v[224:227] \n"
" buffer_load_dwordx4 acc[64:67], %[v_os_b0], s[12:15], 0 offen \n"
" v_mfma_i32_16x16x32_i8 v[224:227], acc[196:197], v[148:149], v[224:227] \n"
" v_mfma_i32_16x16x32_i8 v[224:227], acc[198:199], v[150:151], v[224:227] \n"
" v_mfma_i32_16x16x32_i8 v[224:227], acc[200:201], v[152:153], v[224:227] \n"
" v_mfma_i32_16x16x32_i8 v[224:227], acc[202:203], v[154:155], v[224:227] \n"
" buffer_load_dwordx4 acc[68:71], %[v_os_b0], s[12:15], 0 offen offset:1024 \n"
" v_mfma_i32_16x16x32_i8 v[224:227], acc[204:205], v[156:157], v[224:227] \n"
" v_mfma_i32_16x16x32_i8 v[224:227], acc[206:207], v[158:159], v[224:227] \n"
" v_mfma_i32_16x16x32_i8 v[228:231], acc[192:193], v[176:177], v[228:231] \n"
" v_mfma_i32_16x16x32_i8 v[228:231], acc[194:195], v[178:179], v[228:231] \n"
" buffer_load_dwordx4 acc[72:75], %[v_os_b0], s[12:15], 0 offen offset:2048 \n"
" v_mfma_i32_16x16x32_i8 v[228:231], acc[196:197], v[180:181], v[228:231] \n"
" v_mfma_i32_16x16x32_i8 v[228:231], acc[198:199], v[182:183], v[228:231] \n"
" v_mfma_i32_16x16x32_i8 v[228:231], acc[200:201], v[184:185], v[228:231] \n"
" v_mfma_i32_16x16x32_i8 v[228:231], acc[202:203], v[186:187], v[228:231] \n"
" buffer_load_dwordx4 acc[76:79], %[v_os_b0], s[12:15], 0 offen offset:3072 \n"
" v_mfma_i32_16x16x32_i8 v[228:231], acc[204:205], v[188:189], v[228:231] \n"
" v_mfma_i32_16x16x32_i8 v[228:231], acc[206:207], v[190:191], v[228:231] \n"
" v_mfma_i32_16x16x32_i8 v[232:235], acc[208:209], v[144:145], v[232:235] \n"
" v_mfma_i32_16x16x32_i8 v[232:235], acc[210:211], v[146:147], v[232:235] \n"
" buffer_load_dwordx4 acc[80:83], %[v_os_b1], s[12:15], 0 offen \n"
" v_mfma_i32_16x16x32_i8 v[232:235], acc[212:213], v[148:149], v[232:235] \n"
" v_mfma_i32_16x16x32_i8 v[232:235], acc[214:215], v[150:151], v[232:235] \n"
" v_mfma_i32_16x16x32_i8 v[232:235], acc[216:217], v[152:153], v[232:235] \n"
" v_mfma_i32_16x16x32_i8 v[232:235], acc[218:219], v[154:155], v[232:235] \n"
" buffer_load_dwordx4 acc[84:87], %[v_os_b1], s[12:15], 0 offen offset:1024 \n"
" v_mfma_i32_16x16x32_i8 v[232:235], acc[220:221], v[156:157], v[232:235] \n"
" v_mfma_i32_16x16x32_i8 v[232:235], acc[222:223], v[158:159], v[232:235] \n"
" v_mfma_i32_16x16x32_i8 v[236:239], acc[208:209], v[176:177], v[236:239] \n"
" v_mfma_i32_16x16x32_i8 v[236:239], acc[210:211], v[178:179], v[236:239] \n"
" buffer_load_dwordx4 acc[88:91], %[v_os_b1], s[12:15], 0 offen offset:2048 \n"
" v_mfma_i32_16x16x32_i8 v[236:239], acc[212:213], v[180:181], v[236:239] \n"
" v_mfma_i32_16x16x32_i8 v[236:239], acc[214:215], v[182:183], v[236:239] \n"
" v_mfma_i32_16x16x32_i8 v[236:239], acc[216:217], v[184:185], v[236:239] \n"
" v_mfma_i32_16x16x32_i8 v[236:239], acc[218:219], v[186:187], v[236:239] \n"
" buffer_load_dwordx4 acc[92:95], %[v_os_b1], s[12:15], 0 offen offset:3072 \n"
" v_mfma_i32_16x16x32_i8 v[236:239], acc[220:221], v[188:189], v[236:239] \n"
" v_mfma_i32_16x16x32_i8 v[236:239], acc[222:223], v[190:191], v[236:239] \n"
" s_waitcnt vmcnt(40) \n"
" v_mfma_i32_16x16x32_i8 v[240:243], acc[224:225], v[144:145], v[240:243] \n"
" v_mfma_i32_16x16x32_i8 v[240:243], acc[226:227], v[146:147], v[240:243] \n"
" buffer_load_dwordx4 acc[96:99], %[v_os_b2], s[12:15], 0 offen \n"
" v_mfma_i32_16x16x32_i8 v[240:243], acc[228:229], v[148:149], v[240:243] \n"
" v_mfma_i32_16x16x32_i8 v[240:243], acc[230:231], v[150:151], v[240:243] \n"
";-- buffer_load_dword v12, %[v_os_dq], s[16:19], 0 offen \n"
" v_mfma_i32_16x16x32_i8 v[240:243], acc[232:233], v[152:153], v[240:243] \n"
" v_mfma_i32_16x16x32_i8 v[240:243], acc[234:235], v[154:155], v[240:243] \n"
" buffer_load_dwordx4 acc[100:103], %[v_os_b2], s[12:15], 0 offen offset:1024 \n"
" v_mfma_i32_16x16x32_i8 v[240:243], acc[236:237], v[156:157], v[240:243] \n"
" v_mfma_i32_16x16x32_i8 v[240:243], acc[238:239], v[158:159], v[240:243] \n"
" v_mfma_i32_16x16x32_i8 v[244:247], acc[224:225], v[176:177], v[244:247] \n"
" v_mfma_i32_16x16x32_i8 v[244:247], acc[226:227], v[178:179], v[244:247] \n"
" buffer_load_dwordx4 acc[104:107], %[v_os_b2], s[12:15], 0 offen offset:2048 \n"
" v_mfma_i32_16x16x32_i8 v[244:247], acc[228:229], v[180:181], v[244:247] \n"
" v_mfma_i32_16x16x32_i8 v[244:247], acc[230:231], v[182:183], v[244:247] \n"
" v_mfma_i32_16x16x32_i8 v[244:247], acc[232:233], v[184:185], v[244:247] \n"
" v_mfma_i32_16x16x32_i8 v[244:247], acc[234:235], v[186:187], v[244:247] \n"
" buffer_load_dwordx4 acc[108:111], %[v_os_b2], s[12:15], 0 offen offset:3072 \n"
" v_mfma_i32_16x16x32_i8 v[244:247], acc[236:237], v[188:189], v[244:247] \n"
" v_mfma_i32_16x16x32_i8 v[244:247], acc[238:239], v[190:191], v[244:247] \n"
" v_mfma_i32_16x16x32_i8 v[248:251], acc[240:241], v[144:145], v[248:251] \n"
" v_mfma_i32_16x16x32_i8 v[248:251], acc[242:243], v[146:147], v[248:251] \n"
" buffer_load_dwordx4 acc[112:115], %[v_os_b3], s[12:15], 0 offen \n"
" v_mfma_i32_16x16x32_i8 v[248:251], acc[244:245], v[148:149], v[248:251] \n"
" v_mfma_i32_16x16x32_i8 v[248:251], acc[246:247], v[150:151], v[248:251] \n"
" v_mfma_i32_16x16x32_i8 v[248:251], acc[248:249], v[152:153], v[248:251] \n"
" v_mfma_i32_16x16x32_i8 v[248:251], acc[250:251], v[154:155], v[248:251] \n"
" buffer_load_dwordx4 acc[116:119], %[v_os_b3], s[12:15], 0 offen offset:1024 \n"
" v_mfma_i32_16x16x32_i8 v[248:251], acc[252:253], v[156:157], v[248:251] \n"
" v_mfma_i32_16x16x32_i8 v[248:251], acc[254:255], v[158:159], v[248:251] \n"
" v_mfma_i32_16x16x32_i8 v[252:255], acc[240:241], v[176:177], v[252:255] \n"
" v_mfma_i32_16x16x32_i8 v[252:255], acc[242:243], v[178:179], v[252:255] \n"
" buffer_load_dwordx4 acc[120:123], %[v_os_b3], s[12:15], 0 offen offset:2048 \n"
" v_mfma_i32_16x16x32_i8 v[252:255], acc[244:245], v[180:181], v[252:255] \n"
" v_mfma_i32_16x16x32_i8 v[252:255], acc[246:247], v[182:183], v[252:255] \n"
" v_mfma_i32_16x16x32_i8 v[252:255], acc[248:249], v[184:185], v[252:255] \n"
" v_mfma_i32_16x16x32_i8 v[252:255], acc[250:251], v[186:187], v[252:255] \n"
" buffer_load_dwordx4 acc[124:127], %[v_os_b3], s[12:15], 0 offen offset:3072 \n"
" v_mfma_i32_16x16x32_i8 v[252:255], acc[252:253], v[188:189], v[252:255] \n"
" v_mfma_i32_16x16x32_i8 v[252:255], acc[254:255], v[190:191], v[252:255] \n"
" s_add_u32 s60, 0x00000200, s80 \n"
" s_cmp_lt_u32 s60, %[s_loop_cnt] \n"
" s_cselect_b32 %[s_tile_os_b], %[s_tile_os_b], 0 \n"
" s_cselect_b32 %[s_tile_os_b_half], %[s_tile_os_b_half], 0 \n"
" s_cselect_b32 %[s_tile_os_dq], %[s_tile_os_dq], 0 \n"
" s_add_u32 s12, %[s_tile_os_b], s12 \n"
" s_addc_u32 s13, 0, s13 \n"
" s_add_u32 s16, %[s_tile_os_dq], s16 \n"
" s_addc_u32 s17, 0, s17 \n"
" v_cvt_f32_i32 v224, v224 \n"
" v_cvt_f32_i32 v225, v225 \n"
" v_cvt_f32_i32 v226, v226 \n"
" v_cvt_f32_i32 v227, v227 \n"
" v_mul_f32 v224, v24, v224 \n"
" v_mul_f32 v225, v24, v225 \n"
" v_mul_f32 v226, v24, v226 \n"
" v_mul_f32 v227, v24, v227 \n"
" v_mul_f32 v224, v13, v224 row_newbcast:0 \n"
" v_mul_f32 v225, v13, v225 row_newbcast:1 \n"
" v_mul_f32 v226, v13, v226 row_newbcast:2 \n"
" v_mul_f32 v227, v13, v227 row_newbcast:3 \n"
" v_mul_f32 v224, %[scale_0], v224 \n"
" v_mul_f32 v225, %[scale_0], v225 \n"
" v_mul_f32 v226, %[scale_0], v226 \n"
" v_mul_f32 v227, %[scale_0], v227 \n"
" v_cvt_f32_i32 v228, v228 \n"
" v_cvt_f32_i32 v229, v229 \n"
" v_cvt_f32_i32 v230, v230 \n"
" v_cvt_f32_i32 v231, v231 \n"
" v_mul_f32 v228, v25, v228 \n"
" v_mul_f32 v229, v25, v229 \n"
" v_mul_f32 v230, v25, v230 \n"
" v_mul_f32 v231, v25, v231 \n"
" v_mul_f32 v228, v13, v228 row_newbcast:0 \n"
" v_mul_f32 v229, v13, v229 row_newbcast:1 \n"
" v_mul_f32 v230, v13, v230 row_newbcast:2 \n"
" v_mul_f32 v231, v13, v231 row_newbcast:3 \n"
" v_mul_f32 v228, %[scale_1], v228 \n"
" v_mul_f32 v229, %[scale_1], v229 \n"
" v_mul_f32 v230, %[scale_1], v230 \n"
" v_mul_f32 v231, %[scale_1], v231 \n"
" v_cvt_f32_i32 v232, v232 \n"
" v_cvt_f32_i32 v233, v233 \n"
" v_cvt_f32_i32 v234, v234 \n"
" v_cvt_f32_i32 v235, v235 \n"
" v_mul_f32 v232, v24, v232 \n"
" v_mul_f32 v233, v24, v233 \n"
" v_mul_f32 v234, v24, v234 \n"
" v_mul_f32 v235, v24, v235 \n"
" v_mul_f32 v232, v13, v232 row_newbcast:4 \n"
" v_mul_f32 v233, v13, v233 row_newbcast:5 \n"
" v_mul_f32 v234, v13, v234 row_newbcast:6 \n"
" v_mul_f32 v235, v13, v235 row_newbcast:7 \n"
" v_mul_f32 v232, %[scale_0], v232 \n"
" v_mul_f32 v233, %[scale_0], v233 \n"
" v_mul_f32 v234, %[scale_0], v234 \n"
" v_mul_f32 v235, %[scale_0], v235 \n"
" v_cvt_f32_i32 v236, v236 \n"
" v_cvt_f32_i32 v237, v237 \n"
" v_cvt_f32_i32 v238, v238 \n"
" v_cvt_f32_i32 v239, v239 \n"
" v_mul_f32 v236, v25, v236 \n"
" v_mul_f32 v237, v25, v237 \n"
" v_mul_f32 v238, v25, v238 \n"
" v_mul_f32 v239, v25, v239 \n"
" v_mul_f32 v236, v13, v236 row_newbcast:4 \n"
" v_mul_f32 v237, v13, v237 row_newbcast:5 \n"
" v_mul_f32 v238, v13, v238 row_newbcast:6 \n"
" v_mul_f32 v239, v13, v239 row_newbcast:7 \n"
" v_mul_f32 v236, %[scale_1], v236 \n"
" v_mul_f32 v237, %[scale_1], v237 \n"
" v_mul_f32 v238, %[scale_1], v238 \n"
" v_mul_f32 v239, %[scale_1], v239 \n"
" v_cvt_f32_i32 v240, v240 \n"
" v_cvt_f32_i32 v241, v241 \n"
" v_cvt_f32_i32 v242, v242 \n"
" v_cvt_f32_i32 v243, v243 \n"
" v_mul_f32 v240, v24, v240 \n"
" v_mul_f32 v241, v24, v241 \n"
" v_mul_f32 v242, v24, v242 \n"
" v_mul_f32 v243, v24, v243 \n"
" v_mul_f32 v240, v13, v240 row_newbcast:8 \n"
" v_mul_f32 v241, v13, v241 row_newbcast:9 \n"
" v_mul_f32 v242, v13, v242 row_newbcast:10 \n"
" v_mul_f32 v243, v13, v243 row_newbcast:11 \n"
" v_mul_f32 v240, %[scale_0], v240 \n"
" v_mul_f32 v241, %[scale_0], v241 \n"
" v_mul_f32 v242, %[scale_0], v242 \n"
" v_mul_f32 v243, %[scale_0], v243 \n"
" v_cvt_f32_i32 v244, v244 \n"
" v_cvt_f32_i32 v245, v245 \n"
" v_cvt_f32_i32 v246, v246 \n"
" v_cvt_f32_i32 v247, v247 \n"
" v_mul_f32 v244, v25, v244 \n"
" v_mul_f32 v245, v25, v245 \n"
" v_mul_f32 v246, v25, v246 \n"
" v_mul_f32 v247, v25, v247 \n"
" v_mul_f32 v244, v13, v244 row_newbcast:8 \n"
" v_mul_f32 v245, v13, v245 row_newbcast:9 \n"
" v_mul_f32 v246, v13, v246 row_newbcast:10 \n"
" v_mul_f32 v247, v13, v247 row_newbcast:11 \n"
" v_mul_f32 v244, %[scale_1], v244 \n"
" v_mul_f32 v245, %[scale_1], v245 \n"
" v_mul_f32 v246, %[scale_1], v246 \n"
" v_mul_f32 v247, %[scale_1], v247 \n"
" v_cvt_f32_i32 v248, v248 \n"
" v_cvt_f32_i32 v249, v249 \n"
" v_cvt_f32_i32 v250, v250 \n"
" v_cvt_f32_i32 v251, v251 \n"
" v_mul_f32 v248, v24, v248 \n"
" v_mul_f32 v249, v24, v249 \n"
" v_mul_f32 v250, v24, v250 \n"
" v_mul_f32 v251, v24, v251 \n"
" v_mul_f32 v248, v13, v248 row_newbcast:12 \n"
" v_mul_f32 v249, v13, v249 row_newbcast:13 \n"
" v_mul_f32 v250, v13, v250 row_newbcast:14 \n"
" v_mul_f32 v251, v13, v251 row_newbcast:15 \n"
" v_mul_f32 v248, %[scale_0], v248 \n"
" v_mul_f32 v249, %[scale_0], v249 \n"
" v_mul_f32 v250, %[scale_0], v250 \n"
" v_mul_f32 v251, %[scale_0], v251 \n"
" v_cvt_f32_i32 v252, v252 \n"
" v_cvt_f32_i32 v253, v253 \n"
" v_cvt_f32_i32 v254, v254 \n"
" v_cvt_f32_i32 v255, v255 \n"
" v_mul_f32 v252, v25, v252 \n"
" v_mul_f32 v253, v25, v253 \n"
" v_mul_f32 v254, v25, v254 \n"
" v_mul_f32 v255, v25, v255 \n"
" v_mul_f32 v252, v13, v252 row_newbcast:12 \n"
" v_mul_f32 v253, v13, v253 row_newbcast:13 \n"
" v_mul_f32 v254, v13, v254 row_newbcast:14 \n"
" v_mul_f32 v255, v13, v255 row_newbcast:15 \n"
" v_mul_f32 v252, %[scale_1], v252 \n"
" v_mul_f32 v253, %[scale_1], v253 \n"
" v_mul_f32 v254, %[scale_1], v254 \n"
" v_mul_f32 v255, %[scale_1], v255 \n"
" v_cmp_u_f32 s[48:49], v224, v224 \n"
" v_add3_u32 v50, v224, v53, 1 \n"
" v_cndmask_b32 v54, v50, v52, s[48:49] \n"
" v_cmp_u_f32 s[48:49], v225, v225 \n"
" v_add3_u32 v50, v225, v53, 1 \n"
" v_cndmask_b32 v55, v50, v52, s[48:49] \n"
" v_perm_b32 v224, v55, v54, s52 \n"
" v_cmp_u_f32 s[48:49], v226, v226 \n"
" v_add3_u32 v50, v226, v53, 1 \n"
" v_cndmask_b32 v54, v50, v52, s[48:49] \n"
" v_cmp_u_f32 s[48:49], v227, v227 \n"
" v_add3_u32 v50, v227, v53, 1 \n"
" v_cndmask_b32 v55, v50, v52, s[48:49] \n"
" v_perm_b32 v225, v55, v54, s52 \n"
" v_cmp_u_f32 s[48:49], v228, v228 \n"
" v_add3_u32 v50, v228, v53, 1 \n"
" v_cndmask_b32 v54, v50, v52, s[48:49] \n"
" v_cmp_u_f32 s[48:49], v229, v229 \n"
" v_add3_u32 v50, v229, v53, 1 \n"
" v_cndmask_b32 v55, v50, v52, s[48:49] \n"
" v_perm_b32 v226, v55, v54, s52 \n"
" v_cmp_u_f32 s[48:49], v230, v230 \n"
" v_add3_u32 v50, v230, v53, 1 \n"
" v_cndmask_b32 v54, v50, v52, s[48:49] \n"
" v_cmp_u_f32 s[48:49], v231, v231 \n"
" v_add3_u32 v50, v231, v53, 1 \n"
" v_cndmask_b32 v55, v50, v52, s[48:49] \n"
" v_perm_b32 v227, v55, v54, s52 \n"
" v_cmp_u_f32 s[48:49], v232, v232 \n"
" v_add3_u32 v50, v232, v53, 1 \n"
" v_cndmask_b32 v54, v50, v52, s[48:49] \n"
" v_cmp_u_f32 s[48:49], v233, v233 \n"
" v_add3_u32 v50, v233, v53, 1 \n"
" v_cndmask_b32 v55, v50, v52, s[48:49] \n"
" v_perm_b32 v228, v55, v54, s52 \n"
" v_cmp_u_f32 s[48:49], v234, v234 \n"
" v_add3_u32 v50, v234, v53, 1 \n"
" v_cndmask_b32 v54, v50, v52, s[48:49] \n"
" v_cmp_u_f32 s[48:49], v235, v235 \n"
" v_add3_u32 v50, v235, v53, 1 \n"
" v_cndmask_b32 v55, v50, v52, s[48:49] \n"
" v_perm_b32 v229, v55, v54, s52 \n"
" v_cmp_u_f32 s[48:49], v236, v236 \n"
" v_add3_u32 v50, v236, v53, 1 \n"
" v_cndmask_b32 v54, v50, v52, s[48:49] \n"
" v_cmp_u_f32 s[48:49], v237, v237 \n"
" v_add3_u32 v50, v237, v53, 1 \n"
" v_cndmask_b32 v55, v50, v52, s[48:49] \n"
" v_perm_b32 v230, v55, v54, s52 \n"
" v_cmp_u_f32 s[48:49], v238, v238 \n"
" v_add3_u32 v50, v238, v53, 1 \n"
" v_cndmask_b32 v54, v50, v52, s[48:49] \n"
" v_cmp_u_f32 s[48:49], v239, v239 \n"
" v_add3_u32 v50, v239, v53, 1 \n"
" v_cndmask_b32 v55, v50, v52, s[48:49] \n"
" v_perm_b32 v231, v55, v54, s52 \n"
" v_cmp_u_f32 s[48:49], v240, v240 \n"
" v_add3_u32 v50, v240, v53, 1 \n"
" v_cndmask_b32 v54, v50, v52, s[48:49] \n"
" v_cmp_u_f32 s[48:49], v241, v241 \n"
" v_add3_u32 v50, v241, v53, 1 \n"
" v_cndmask_b32 v55, v50, v52, s[48:49] \n"
" v_perm_b32 v232, v55, v54, s52 \n"
" v_cmp_u_f32 s[48:49], v242, v242 \n"
" v_add3_u32 v50, v242, v53, 1 \n"
" v_cndmask_b32 v54, v50, v52, s[48:49] \n"
" v_cmp_u_f32 s[48:49], v243, v243 \n"
" v_add3_u32 v50, v243, v53, 1 \n"
" v_cndmask_b32 v55, v50, v52, s[48:49] \n"
" v_perm_b32 v233, v55, v54, s52 \n"
" v_cmp_u_f32 s[48:49], v244, v244 \n"
" v_add3_u32 v50, v244, v53, 1 \n"
" v_cndmask_b32 v54, v50, v52, s[48:49] \n"
" v_cmp_u_f32 s[48:49], v245, v245 \n"
" v_add3_u32 v50, v245, v53, 1 \n"
" v_cndmask_b32 v55, v50, v52, s[48:49] \n"
" v_perm_b32 v234, v55, v54, s52 \n"
" v_cmp_u_f32 s[48:49], v246, v246 \n"
" v_add3_u32 v50, v246, v53, 1 \n"
" v_cndmask_b32 v54, v50, v52, s[48:49] \n"
" v_cmp_u_f32 s[48:49], v247, v247 \n"
" v_add3_u32 v50, v247, v53, 1 \n"
" v_cndmask_b32 v55, v50, v52, s[48:49] \n"
" v_perm_b32 v235, v55, v54, s52 \n"
" v_cmp_u_f32 s[48:49], v248, v248 \n"
" v_add3_u32 v50, v248, v53, 1 \n"
" v_cndmask_b32 v54, v50, v52, s[48:49] \n"
" v_cmp_u_f32 s[48:49], v249, v249 \n"
" v_add3_u32 v50, v249, v53, 1 \n"
" v_cndmask_b32 v55, v50, v52, s[48:49] \n"
" v_perm_b32 v236, v55, v54, s52 \n"
" v_cmp_u_f32 s[48:49], v250, v250 \n"
" v_add3_u32 v50, v250, v53, 1 \n"
" v_cndmask_b32 v54, v50, v52, s[48:49] \n"
" v_cmp_u_f32 s[48:49], v251, v251 \n"
" v_add3_u32 v50, v251, v53, 1 \n"
" v_cndmask_b32 v55, v50, v52, s[48:49] \n"
" v_perm_b32 v237, v55, v54, s52 \n"
" v_cmp_u_f32 s[48:49], v252, v252 \n"
" v_add3_u32 v50, v252, v53, 1 \n"
" v_cndmask_b32 v54, v50, v52, s[48:49] \n"
" v_cmp_u_f32 s[48:49], v253, v253 \n"
" v_add3_u32 v50, v253, v53, 1 \n"
" v_cndmask_b32 v55, v50, v52, s[48:49] \n"
" v_perm_b32 v238, v55, v54, s52 \n"
" v_cmp_u_f32 s[48:49], v254, v254 \n"
" v_add3_u32 v50, v254, v53, 1 \n"
" v_cndmask_b32 v54, v50, v52, s[48:49] \n"
" v_cmp_u_f32 s[48:49], v255, v255 \n"
" v_add3_u32 v50, v255, v53, 1 \n"
" v_cndmask_b32 v55, v50, v52, s[48:49] \n"
" v_perm_b32 v239, v55, v54, s52 \n"
" ds_write_b64 v3, v[224:225] offset:35072 \n"
" ds_write_b64 v3, v[226:227] offset:43776 \n"
" ds_write_b64 v3, v[228:229] offset:37248 \n"
" ds_write_b64 v3, v[230:231] offset:45952 \n"
" ds_write_b64 v3, v[232:233] offset:39424 \n"
" ds_write_b64 v3, v[234:235] offset:48128 \n"
" ds_write_b64 v3, v[236:237] offset:41600 \n"
" ds_write_b64 v3, v[238:239] offset:50304 \n"
" s_waitcnt lgkmcnt(0) \n"
" s_barrier \n"
" ds_read_b32 v64, v4 offset:35072 \n"
" ds_read_b32 v65, v4 offset:39424 \n"
" ds_read_b32 v66, v4 offset:35104 \n"
" ds_read_b32 v67, v4 offset:39456 \n"
" ds_read_b32 v68, v4 offset:35136 \n"
" ds_read_b32 v69, v4 offset:39488 \n"
" ds_read_b32 v70, v4 offset:35168 \n"
" ds_read_b32 v71, v4 offset:39520 \n"
" ds_read_b32 v72, v4 offset:43776 \n"
" ds_read_b32 v73, v4 offset:48128 \n"
" ds_read_b32 v74, v4 offset:43808 \n"
" ds_read_b32 v75, v4 offset:48160 \n"
" ds_read_b32 v76, v4 offset:43840 \n"
" ds_read_b32 v77, v4 offset:48192 \n"
" ds_read_b32 v78, v4 offset:43872 \n"
" ds_read_b32 v79, v4 offset:48224 \n"
" s_waitcnt lgkmcnt(0) \n"
" s_mov_b64 exec, %[s_execflag_0] \n"
" global_atomic_pk_add_bf16 %[v_os_o0], v64, [%[s_res_o0],%[s_res_o1]] \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_0] \n"
" global_atomic_pk_add_bf16 %[v_os_o0], v65, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_1] \n"
" global_atomic_pk_add_bf16 %[v_os_o1], v66, [%[s_res_o0],%[s_res_o1]] \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_1] \n"
" global_atomic_pk_add_bf16 %[v_os_o1], v67, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_2] \n"
" global_atomic_pk_add_bf16 %[v_os_o2], v68, [%[s_res_o0],%[s_res_o1]] \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_2] \n"
" global_atomic_pk_add_bf16 %[v_os_o2], v69, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_3] \n"
" global_atomic_pk_add_bf16 %[v_os_o3], v70, [%[s_res_o0],%[s_res_o1]] \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_3] \n"
" global_atomic_pk_add_bf16 %[v_os_o3], v71, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_4] \n"
" global_atomic_pk_add_bf16 %[v_os_o4], v72, [%[s_res_o0],%[s_res_o1]] \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_4] \n"
" global_atomic_pk_add_bf16 %[v_os_o4], v73, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_5] \n"
" global_atomic_pk_add_bf16 %[v_os_o5], v74, [%[s_res_o0],%[s_res_o1]] \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_5] \n"
" global_atomic_pk_add_bf16 %[v_os_o5], v75, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_6] \n"
" global_atomic_pk_add_bf16 %[v_os_o6], v76, [%[s_res_o0],%[s_res_o1]] \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_6] \n"
" global_atomic_pk_add_bf16 %[v_os_o6], v77, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_7] \n"
" global_atomic_pk_add_bf16 %[v_os_o7], v78, [%[s_res_o0],%[s_res_o1]] \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_7] \n"
" global_atomic_pk_add_bf16 %[v_os_o7], v79, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n"
" s_mov_b64 exec, s[36:37] \n"
" s_add_u32 %[s_res_o0], %[s_tile_os_o], %[s_res_o0] \n"
" s_addc_u32 %[s_res_o1], 0, %[s_res_o1] \n"
" s_addk_i32 s80, 0x0100 \n"
" s_cmp_lt_i32 s80, %[s_loop_cnt] \n"
" s_cbranch_scc0 label_end_gemm2 \n"
" s_branch label_startgemm2 \n"
" label_end_gemm2: \n"
" s_waitcnt 0x0000 \n"
" s_endpgm \n"
#undef _UK_MFMA_
#undef _UK_PK_CVT_
#undef _UK_ATOMIC_ADD_
...@@ -19,9 +19,9 @@ ...@@ -19,9 +19,9 @@
" v_mul_f32 " a2 ", " gq ", " a2 " row_newbcast: " brd2 " \n" \ " v_mul_f32 " a2 ", " gq ", " a2 " row_newbcast: " brd2 " \n" \
" v_mul_f32 " a3 ", " gq ", " a3 " row_newbcast:" brd3 " \n" " v_mul_f32 " a3 ", " gq ", " a3 " row_newbcast:" brd3 " \n"
" s_mov_b32 s22, %[a_bound] \n"
"s_mov_b32 s20, %[s_res_a0] \n" "s_mov_b32 s20, %[s_res_a0] \n"
"s_mov_b32 s21, %[s_res_a1] \n" "s_mov_b32 s21, %[s_res_a1] \n"
"s_mov_b32 s22, %[s_res_a2] \n"
"s_mov_b32 s23, %[s_res_a3] \n" "s_mov_b32 s23, %[s_res_a3] \n"
"s_mov_b32 s24, %[s_res_b0] \n" "s_mov_b32 s24, %[s_res_b0] \n"
"s_mov_b32 s25, %[s_res_b1] \n" "s_mov_b32 s25, %[s_res_b1] \n"
...@@ -110,38 +110,38 @@ ...@@ -110,38 +110,38 @@
" s_add_u32 s20, s57, s20 \n" " s_add_u32 s20, s57, s20 \n"
" s_addc_u32 s21, 0, s21 \n" " s_addc_u32 s21, 0, s21 \n"
"; -- prefetch B0\n" "; -- prefetch B0\n"
"buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[24:27], 0 offen \n" " buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[24:27], 0 offen \n"
"buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[24:27], 0 offen offset:1024 \n" " buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[24:27], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[24:27], 0 offen offset:2048 \n" " buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[24:27], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[24:27], 0 offen offset:3072 \n" " buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[24:27], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[24:27], 0 offen \n" " buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[24:27], 0 offen \n"
"buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[24:27], 0 offen offset:1024 \n" " buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[24:27], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[24:27], 0 offen offset:2048 \n" " buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[24:27], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[24:27], 0 offen offset:3072 \n" " buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[24:27], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[24:27], 0 offen \n" " buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[24:27], 0 offen \n"
"buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[24:27], 0 offen offset:1024 \n" " buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[24:27], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[24:27], 0 offen offset:2048 \n" " buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[24:27], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[24:27], 0 offen offset:3072 \n" " buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[24:27], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[24:27], 0 offen \n" " buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[24:27], 0 offen \n"
"buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[24:27], 0 offen offset:1024 \n" " buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[24:27], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[24:27], 0 offen offset:2048 \n" " buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[24:27], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[24:27], 0 offen offset:3072 \n" " buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[24:27], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[64:67], %[v_os_b4], s[24:27], 0 offen \n" " buffer_load_dwordx4 acc[64:67], %[v_os_b4], s[24:27], 0 offen \n"
"buffer_load_dwordx4 acc[68:71], %[v_os_b4], s[24:27], 0 offen offset:1024 \n" " buffer_load_dwordx4 acc[68:71], %[v_os_b4], s[24:27], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[72:75], %[v_os_b4], s[24:27], 0 offen offset:2048 \n" " buffer_load_dwordx4 acc[72:75], %[v_os_b4], s[24:27], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[76:79], %[v_os_b4], s[24:27], 0 offen offset:3072 \n" " buffer_load_dwordx4 acc[76:79], %[v_os_b4], s[24:27], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[80:83], %[v_os_b5], s[24:27], 0 offen \n" " buffer_load_dwordx4 acc[80:83], %[v_os_b5], s[24:27], 0 offen \n"
"buffer_load_dwordx4 acc[84:87], %[v_os_b5], s[24:27], 0 offen offset:1024 \n" " buffer_load_dwordx4 acc[84:87], %[v_os_b5], s[24:27], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[88:91], %[v_os_b5], s[24:27], 0 offen offset:2048 \n" " buffer_load_dwordx4 acc[88:91], %[v_os_b5], s[24:27], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[92:95], %[v_os_b5], s[24:27], 0 offen offset:3072 \n" " buffer_load_dwordx4 acc[92:95], %[v_os_b5], s[24:27], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[96:99], %[v_os_b6], s[24:27], 0 offen \n" " buffer_load_dwordx4 acc[96:99], %[v_os_b6], s[24:27], 0 offen \n"
"buffer_load_dwordx4 acc[100:103], %[v_os_b6], s[24:27], 0 offen offset:1024 \n" " buffer_load_dwordx4 acc[100:103], %[v_os_b6], s[24:27], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[104:107], %[v_os_b6], s[24:27], 0 offen offset:2048 \n" " buffer_load_dwordx4 acc[104:107], %[v_os_b6], s[24:27], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[108:111], %[v_os_b6], s[24:27], 0 offen offset:3072 \n" " buffer_load_dwordx4 acc[108:111], %[v_os_b6], s[24:27], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[112:115], %[v_os_b7], s[24:27], 0 offen \n" " buffer_load_dwordx4 acc[112:115], %[v_os_b7], s[24:27], 0 offen \n"
"buffer_load_dwordx4 acc[116:119], %[v_os_b7], s[24:27], 0 offen offset:1024 \n" " buffer_load_dwordx4 acc[116:119], %[v_os_b7], s[24:27], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[24:27], 0 offen offset:2048 \n" " buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[24:27], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[24:27], 0 offen offset:3072 \n" " buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[24:27], 0 offen offset:3072 \n"
"s_add_u32 s24, s58, s24 \n" "s_add_u32 s24, s58, s24 \n"
"s_addc_u32 s25, 0, s25 \n" "s_addc_u32 s25, 0, s25 \n"
......
...@@ -237,12 +237,23 @@ struct FusedMoeGemmKernel ...@@ -237,12 +237,23 @@ struct FusedMoeGemmKernel
CK_TILE_DEVICE void operator()(Kargs kargs) const CK_TILE_DEVICE void operator()(Kargs kargs) const
{ {
if constexpr(UseUK) if constexpr(UseUK)
{ {
__shared__ CK_TILE_LDS_ADDR ADataType smem[GetSmemSize()];
__shared__ CK_TILE_LDS_ADDR ADataType smem[65536];
// index_t s_size = GetSmemSize();
// ADataType{}.aaa();
IndexDataType num_sorted_tiles = __builtin_amdgcn_readfirstlane( IndexDataType num_sorted_tiles = __builtin_amdgcn_readfirstlane(
*reinterpret_cast<const IndexDataType*>(kargs.num_sorted_tiles_ptr)); *reinterpret_cast<const IndexDataType*>(kargs.num_sorted_tiles_ptr));
// __builtin_amdgcn_sched_barrier(0);
// if(threadIdx.x == 0){
// printf("num_sorted_tiles %d\n", num_sorted_tiles);
// printf("data type :%s\n", t2s<ADataType>::name);
// printf("\nblockIdx.x :%x, blockIdx.y :%x,\n", blockIdx.x, blockIdx.y);
// __builtin_amdgcn_sched_barrier(0);
// }
// __builtin_amdgcn_sched_barrier(0);
num_sorted_tiles = num_sorted_tiles / BlockShape::Block_M0; num_sorted_tiles = num_sorted_tiles / BlockShape::Block_M0;
const auto [sorted_tile_id, intermediate_tile_id] = const auto [sorted_tile_id, intermediate_tile_id] =
......
...@@ -74,7 +74,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8 ...@@ -74,7 +74,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
constexpr index_t smem_1 = Policy::template GetUK_1<Problem>().GetSmemSize(); constexpr index_t smem_1 = Policy::template GetUK_1<Problem>().GetSmemSize();
constexpr index_t smem_bridge = constexpr index_t smem_bridge =
BlockShape::Block_M0 * BlockShape::Block_N0 * sizeof(YDataType); BlockShape::Block_M0 * BlockShape::Block_N0 * sizeof(YDataType);
return max(smem_0, max(smem_1, smem_bridge)); return max(smem_0, max(smem_1, smem_bridge));
} }
// this is the thread-offset along row/col // this is the thread-offset along row/col
...@@ -159,15 +159,22 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8 ...@@ -159,15 +159,22 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
template <typename ROW_IDS> template <typename ROW_IDS>
CK_TILE_DEVICE auto GetAScale(const ROW_IDS row_ids_mma, CK_TILE_DEVICE auto GetAScale(const ROW_IDS row_ids_mma,
const AScaleDataType* a_scale_ptr) // const AScaleDataType* a_scale_ptr, index_t num_tokens_)
index_t num_tokens_)
{ {
constexpr index_t n_size = row_ids_mma.size(); constexpr index_t n_size = row_ids_mma.size();
array<TopkWeightDataType, n_size> w; array<TopkWeightDataType, n_size> w;
static_for<0, n_size, 1>{}([&](auto i) { static_for<0, n_size, 1>{}([&](auto i) {
auto row_id = row_ids_mma[i] & 0xffffff; auto row_id = row_ids_mma[i] & 0xffffff;
auto itp_k = row_ids_mma[i] >> 24; if (row_id >= num_tokens_)
w.at(i) = a_scale_ptr[row_id * 5+itp_k]; {
w.at(i) = 0.f;
} else {
w.at(i) = 1.f;
// auto itp_k = row_ids_mma[i] >> 24;
// w.at(i) = a_scale_ptr[row_id * 5+itp_k];
}
}); });
return w; return w;
...@@ -247,7 +254,12 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8 ...@@ -247,7 +254,12 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
index_t kr_0 = kargs.hidden_size / BlockShape::Warp_K0; // divide K in W index_t kr_0 = kargs.hidden_size / BlockShape::Warp_K0; // divide K in W
index_t nr_1 = kargs.hidden_size / BlockShape::Warp_N1; index_t nr_1 = kargs.hidden_size / BlockShape::Warp_N1;
index_t kr_1 = shared_intermediate_size_1 / BlockShape::Warp_K1; index_t kr_1 = shared_intermediate_size_1 / BlockShape::Warp_K1;
// if(threadIdx.x == 31 && blockIdx.x == 0 && blockIdx.y == 0)
// {
// printf("\nWarpPerBlock_N0 :%x, WarpPerBlock_M0:%x,\n", BlockShape::WarpPerBlock_N0
// , BlockShape::WarpPerBlock_M0);
// }
const IndexDataType expert_id = __builtin_amdgcn_readfirstlane( const IndexDataType expert_id = __builtin_amdgcn_readfirstlane(
reinterpret_cast<const IndexDataType*>(kargs.sorted_expert_ids_ptr)[sorted_tile_id]); reinterpret_cast<const IndexDataType*>(kargs.sorted_expert_ids_ptr)[sorted_tile_id]);
index_t expert_stride_0 = shared_intermediate_size_0 * kargs.hidden_size; index_t expert_stride_0 = shared_intermediate_size_0 * kargs.hidden_size;
...@@ -271,7 +283,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8 ...@@ -271,7 +283,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
row_coords_a_mma, reinterpret_cast<const IndexDataType*>(kargs.sorted_token_ids_ptr)); row_coords_a_mma, reinterpret_cast<const IndexDataType*>(kargs.sorted_token_ids_ptr));
auto token_id = generate_tuple( auto token_id = generate_tuple(
[&](auto i) { [&](auto i) {
return (row_ids_a[i]) &0xffffff; return (row_ids_a[i] &0xffffff);
}, },
number<row_ids_a.size()>{}); number<row_ids_a.size()>{});
auto a_coords = generate_tuple( auto a_coords = generate_tuple(
...@@ -385,10 +397,16 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8 ...@@ -385,10 +397,16 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
threadIdx.x % (BlockShape::Block_N1/2 / kAlignmentO) * kAlignmentO; threadIdx.x % (BlockShape::Block_N1/2 / kAlignmentO) * kAlignmentO;
}, },
number<row_ids_a.size()>{}); number<row_ids_a.size()>{});
auto o_flags = auto o_flags =
generate_tuple([&](auto i) { return cmp_lt_to_exec(token_id[i], kargs.num_tokens); }, generate_tuple([&](auto i) { return cmp_lt_to_exec(token_id[i], kargs.num_tokens); },
number<row_ids_a.size()>{}); // generate_tuple([&](auto i) {
// if (__builtin_amdgcn_readfirstlane(token_id[i]) < kargs.num_tokens)
// {return 0xffffffffffffffff;}
// else
// {return uint32x2_t 0;}
// },
number<token_id.size()>{});
auto o_res = auto o_res =
make_wave_buffer_resource(reinterpret_cast<const ODataType*>(kargs.o_ptr), make_wave_buffer_resource(reinterpret_cast<const ODataType*>(kargs.o_ptr),
...@@ -398,13 +416,59 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8 ...@@ -398,13 +416,59 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
auto w_scale = GetWeightScale( auto w_scale = GetWeightScale(
row_coords_o, reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr)); row_coords_o, reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr));
auto a_scale = GetAScale( auto a_scale = GetAScale(
row_ids_a_mma, reinterpret_cast<const AScaleDataType*>(kargs.a_scale_ptr)); // row_ids_a_mma, reinterpret_cast<const AScaleDataType*>(kargs.a_scale_ptr), kargs.num_tokens );
row_ids_a_mma, kargs.num_tokens );
auto gqsmq_coords = GetColCoords_GQSMQ(intermediate_tile_id * BlockShape::Block_K1); auto gqsmq_coords = GetColCoords_GQSMQ(intermediate_tile_id * BlockShape::Block_K1);
auto dq_coords = gqsmq_coords[0];//only one for this tiling auto dq_coords = gqsmq_coords[0];//only one for this tiling
auto gq_scale = GetGQScale( auto gq_scale = GetGQScale(
gqsmq_coords, (reinterpret_cast<const GScaleDataType*>(kargs.g_scale_ptr) + static_cast<long_index_t>(expert_id) * shared_intermediate_size_0)); gqsmq_coords, (reinterpret_cast<const GScaleDataType*>(kargs.g_scale_ptr) + static_cast<long_index_t>(expert_id) * shared_intermediate_size_0));
auto smq_scale = GetSMQScale( auto smq_scale = GetSMQScale(
gqsmq_coords, (reinterpret_cast<const YSmoothScaleDataType*>(kargs.y_smooth_scale_ptr) + static_cast<long_index_t>(expert_id) * shared_intermediate_size_0)); gqsmq_coords, (reinterpret_cast<const YSmoothScaleDataType*>(kargs.y_smooth_scale_ptr) + static_cast<long_index_t>(expert_id) * shared_intermediate_size_0));
if(threadIdx.x == 95 && blockIdx.x == 0 && blockIdx.y == 0)
{
printf("\nblockIdx.x :%x, blockIdx.y :%x, d ptr: %p, wg d ptr :%x%x,gemm0 done\n", blockIdx.x, blockIdx.y, kargs.d_ptr,d_res[1],d_res[0]);
// // printf("\n wg 1 1, wave 1, row_coords_a 0: %x 1: %x, 2: %x, 3:%x, 5: %x 6: %x, 7: %x, 8:%x,, \n", row_coords_a[number<0>{}],row_coords_a[number<1>{}],row_coords_a[number<2>{}],row_coords_a[number<3>{}], row_coords_a[number<4>{}],row_coords_a[number<5>{}],row_coords_a[number<6>{}],row_coords_a[number<7>{}]);
// // printf("\n -------------- -row_ids_a 0: %x 1: %x, 2: %x, 3:%x, 5: %x 6: %x, 7: %x, 8:%x,, \n", row_ids_a[number<0>{}],row_ids_a[number<1>{}],row_ids_a[number<2>{}],row_ids_a[number<3>{}], row_ids_a[number<4>{}],row_ids_a[number<5>{}],row_ids_a[number<6>{}],row_ids_a[number<7>{}]);
// // printf("\n -----------thread id %x--- - token_id 0: %x 1: %x, 2: %x, 3:%x, 5: %x 6: %x, 7: %x, 8:%x,, \n", hipThreadIdx_x , token_id[number<0>{}],token_id[number<1>{}],token_id[number<2>{}],token_id[number<3>{}], token_id[number<4>{}],token_id[number<5>{}],token_id[number<6>{}],token_id[number<7>{}]);
// // printf("\n -----------thread id %x--- - token_id , 7:%x,, \n", hipThreadIdx_x , token_id[number<7>{}]);
// // printf("\n -------------- - exec 0: %x 1: %x, 2: %x, 3:%x, 5: %x 6: %x, 7: %x, 8:%x,, \n", o_flags[number<0>{}][0],o_flags[number<1>{}][0],o_flags[number<2>{}][0],o_flags[number<3>{}][0], o_flags[number<4>{}][0],o_flags[number<5>{}][0],o_flags[number<6>{}][0],o_flags[number<7>{}][0]);
printf("\ntoken id :%x,%x,%x,%x, %x,%x,%x,%x \n d_coords: %x,%x,%x,%x, \n row_idx: %x,%x,%x,%x, %x,%x,%x,%x \n o_flags:%x,%x,%x,%x, %x,%x,%x,%x \n",
token_id[number<0>{}],
token_id[number<1>{}],
token_id[number<2>{}],
token_id[number<3>{}],
token_id[number<4>{}],
token_id[number<5>{}],
token_id[number<6>{}],
token_id[number<7>{}],
d_coords[number<0>{}],
d_coords[number<1>{}],
d_coords[number<2>{}],
d_coords[number<3>{}],
// d_coords[number<4>{}],
// d_coords[number<5>{}],
// d_coords[number<6>{}],
// d_coords[number<7>{}],
row_ids_a[number<0>{}],
row_ids_a[number<1>{}],
row_ids_a[number<2>{}],
row_ids_a[number<3>{}],
row_ids_a[number<4>{}],
row_ids_a[number<5>{}],
row_ids_a[number<6>{}],
row_ids_a[number<7>{}],
o_flags[number<0>{}][0],
o_flags[number<1>{}][0],
o_flags[number<2>{}][0],
o_flags[number<3>{}][0],
o_flags[number<4>{}][0],
o_flags[number<5>{}][0],
o_flags[number<6>{}][0],
o_flags[number<7>{}][0]);
// return;
}
__builtin_amdgcn_sched_barrier(0);
auto uk_0 = Policy::template GetUK_0<Problem>(); auto uk_0 = Policy::template GetUK_0<Problem>();
// auto acc_0= uk_0( // auto acc_0= uk_0(
uk_0(a_scale, uk_0(a_scale,
...@@ -418,17 +482,12 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8 ...@@ -418,17 +482,12 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
smem, smem,
kargs.hidden_size, kargs.hidden_size,
BlockShape::Block_K0, // tile offset for B matrix each unroll BlockShape::Block_K0, // tile offset for B matrix each unroll
BlockShape::Block_Kr0 * 16*256,
BlockShape::Block_W0); // tile offset for B matrix each unroll kargs.num_tokens * kargs.stride_token); // tile offset for B matrix each unroll
if(hipBlockIdx_x == 1 && hipBlockIdx_y == 1 && hipBlockIdx_z == 0 && // return;
hipThreadIdx_x == 64) __builtin_amdgcn_sched_barrier(0);
{
printf("\ngemm0 done\n"); // // sweep_tile(
// printf("\n wg 1 1, wave 1, row_coords_a 0: %x 1: %x, 2: %x, 3:%x, 5: %x 6: %x, 7: %x, 8:%x,, \n", row_coords_a[number<0>{}],row_coords_a[number<1>{}],row_coords_a[number<2>{}],row_coords_a[number<3>{}], row_coords_a[number<4>{}],row_coords_a[number<5>{}],row_coords_a[number<6>{}],row_coords_a[number<7>{}]);
// printf("\n -------------- -row_ids_a 0: %x 1: %x, 2: %x, 3:%x, 5: %x 6: %x, 7: %x, 8:%x,, \n", row_ids_a[number<0>{}],row_ids_a[number<1>{}],row_ids_a[number<2>{}],row_ids_a[number<3>{}], row_ids_a[number<4>{}],row_ids_a[number<5>{}],row_ids_a[number<6>{}],row_ids_a[number<7>{}]);
printf("\n -------------- - token_id 0: %x 1: %x, 2: %x, 3:%x, 5: %x 6: %x, 7: %x, 8:%x,, \n", token_id[number<0>{}],token_id[number<1>{}],token_id[number<2>{}],token_id[number<3>{}], token_id[number<4>{}],token_id[number<5>{}],token_id[number<6>{}],token_id[number<7>{}]);
}
// sweep_tile(
// acc_0, // acc_0,
// [&](auto idx0, auto idx1) { // [&](auto idx0, auto idx1) {
// fp32x2_t v_{acc_0(idx0), acc_0(idx1)}; // fp32x2_t v_{acc_0(idx0), acc_0(idx1)};
...@@ -449,6 +508,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8 ...@@ -449,6 +508,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
uk_1( uk_1(
// dq_res, // dq_res,
// d_res, // d_res,
token_id,
dq_coords, dq_coords,
d_coords, d_coords,
o_res, o_res,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment