Commit 339a674b authored by shengnxu
Browse files

current status: single WG, memory out of bounds

parent 5d00b37e
...@@ -19,14 +19,14 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile: ...@@ -19,14 +19,14 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" && if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1) t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1)
{ {
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>; // using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>;
r = fused_moegemm_<t_>(s, a); // r = fused_moegemm_<t_>(s, a);
} }
else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" && else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1) t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1)
{ {
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>; // using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>;
r = fused_moegemm_<t_>(s, a); // r = fused_moegemm_<t_>(s, a);
} }
else if(t.prec_i == "int8" && t.prec_w == "int8" && t.prec_o == "bf16" && t.prec_st == "fp32" && else if(t.prec_i == "int8" && t.prec_w == "int8" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1) t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1)
......
...@@ -7,8 +7,8 @@ ...@@ -7,8 +7,8 @@
#include "fused_moegemm_api_internal.hpp" #include "fused_moegemm_api_internal.hpp"
// clang-format off // clang-format off
template float fused_moegemm_< // template float fused_moegemm_<
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0> // fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a); // >(const ck_tile::stream_config& s, fused_moegemm_args a);
// clang-format on // clang-format on
...@@ -7,8 +7,8 @@ ...@@ -7,8 +7,8 @@
#include "fused_moegemm_api_internal.hpp" #include "fused_moegemm_api_internal.hpp"
// clang-format off // clang-format off
template float fused_moegemm_< // template float fused_moegemm_<
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0> // fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a); // >(const ck_tile::stream_config& s, fused_moegemm_args a);
// clang-format on // clang-format on
...@@ -87,11 +87,11 @@ void topid_unique_gen( ...@@ -87,11 +87,11 @@ void topid_unique_gen(
auto create_args(int argc, char* argv[]) auto create_args(int argc, char* argv[])
{ {
ck_tile::ArgParser arg_parser; ck_tile::ArgParser arg_parser;
arg_parser.insert("t", "128", "num input tokens") arg_parser.insert("t", "32", "num input tokens")
.insert("e", "32", "num of experts") .insert("e", "1", "num of experts")
.insert("k", "5", "topk") .insert("k", "1", "topk")
.insert("h", "8192", "hidden_size of this model") .insert("h", "256", "hidden_size of this model")
.insert("i", "8192", "intermediate_size between 2 gemms of FFN") .insert("i", "4096", "intermediate_size between 2 gemms of FFN")
.insert("stride", "-1", "stride per row, if -1 then equal to hidden_size") .insert("stride", "-1", "stride per row, if -1 then equal to hidden_size")
.insert("bm", "32", "blocking factor for sorted tokens") .insert("bm", "32", "blocking factor for sorted tokens")
.insert("tp", "8", "tensor parallel size") .insert("tp", "8", "tensor parallel size")
......
...@@ -242,6 +242,7 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16 ...@@ -242,6 +242,7 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
{ {
using ADataType = int8_t; using ADataType = int8_t;
using BDataType = int8_t; using BDataType = int8_t;
using AScaleDataType = float;
// TODO: need paired with tile_window_linear! // TODO: need paired with tile_window_linear!
// TODO: need call init_raw() before call this function! // TODO: need call init_raw() before call this function!
...@@ -258,7 +259,8 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16 ...@@ -258,7 +259,8 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
CK_TILE_LDS_ADDR void* smem, CK_TILE_LDS_ADDR void* smem,
index_t k, index_t k,
index_t tile_offset_a, // for each tile, the offset to move for each unroll index_t tile_offset_a, // for each tile, the offset to move for each unroll
index_t tile_offset_b) // for each tile, the offset to move for each unroll index_t tile_offset_b,
index_t a_bound_) // for each tile, the offset to move for each unroll
{ {
static_assert(ACoords::size() == Block_M * Block_K / BlockSize / 4 /*2x per dword*/); // 8 static_assert(ACoords::size() == Block_M * Block_K / BlockSize / 4 /*2x per dword*/); // 8
static_assert(BCoords::size() == Repeat_N); static_assert(BCoords::size() == Repeat_N);
...@@ -449,7 +451,10 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16 ...@@ -449,7 +451,10 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
[c62]"+v"(v_z62), [c62]"+v"(v_z62),
[c63]"+v"(v_z63), [c63]"+v"(v_z63),
[s_mem_]"+r"(smem) [s_mem_]"+r"(smem)
: [a_scale0]"v"(a_scale_[0]), :
[a_bound]"s"(static_cast<int>(a_bound_ * sizeof(ADataType))),
// [a_scale_bound]"s"(a_scale_bound_ * sizeof(AScaleDataType)),
[a_scale0]"v"(a_scale_[0]),
[a_scale1]"v"(a_scale_[1]), [a_scale1]"v"(a_scale_[1]),
[gq_scale0]"v"(gq_scale_[0]), [gq_scale0]"v"(gq_scale_[0]),
[gq_scale1]"v"(gq_scale_[1]), [gq_scale1]"v"(gq_scale_[1]),
......
...@@ -81,6 +81,7 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x ...@@ -81,6 +81,7 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
template < template <
// typename DQRes, // typename DQRes,
// typename BRes, // typename BRes,
typename Tokenids,
typename DQCoords, typename DQCoords,
typename BCoords, typename BCoords,
typename ORes, typename ORes,
...@@ -92,6 +93,7 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x ...@@ -92,6 +93,7 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
operator()( operator()(
// const DQRes& res_dq, // const DQRes& res_dq,
// const BRes& res_b, // const BRes& res_b,
const Tokenids& token_id_,
const DQCoords& cached_coords_dq, const DQCoords& cached_coords_dq,
const BCoords& cached_coords_b, const BCoords& cached_coords_b,
const ORes& res_o, const ORes& res_o,
...@@ -108,7 +110,6 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x ...@@ -108,7 +110,6 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
{ {
static_assert(BCoords::size() == 4); // 8 static_assert(BCoords::size() == 4); // 8
static_assert(OCoords::size() == 8); static_assert(OCoords::size() == 8);
const index_t tile_stride_b_bytes = tile_offset_b * sizeof(BDataType); const index_t tile_stride_b_bytes = tile_offset_b * sizeof(BDataType);
const index_t tile_offset_half_b_bytes = tile_offset_half_b * sizeof(BDataType); const index_t tile_offset_half_b_bytes = tile_offset_half_b * sizeof(BDataType);
const index_t tile_stride_o_bytes = tile_offset_o * sizeof(ODataType); const index_t tile_stride_o_bytes = tile_offset_o * sizeof(ODataType);
...@@ -155,8 +156,8 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x ...@@ -155,8 +156,8 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_INT8 #define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_INT8
#include "uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_1.inc" #include "uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_1.inc"
#undef CK_TILE_FLATMM_UK_MFMA #undef CK_TILE_FLATMM_UK_MFMA
:[smem_]"+r"(smem), :[smem_]"+r"(smem)
[s_loop_cnt]"+s"(loop_cnt) // [s_loop_cnt]"+s"(loop_cnt)
:[sld_a_base]"n"(0), :[sld_a_base]"n"(0),
// [shfl_base]"n"(0), // [shfl_base]"n"(0),
// [v_sld_y_os]"v"(sld_y_os), // [v_sld_y_os]"v"(sld_y_os),
...@@ -164,8 +165,8 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x ...@@ -164,8 +165,8 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
// [v_sfl_sst]"v"(sfl_sst), // [v_sfl_sst]"v"(sfl_sst),
[smq_scale0]"s"(smq_scale_[0]), [smq_scale0]"s"(smq_scale_[0]),
[smq_scale1]"s"(smq_scale_[1]), [smq_scale1]"s"(smq_scale_[1]),
[s_res_o0]"s"(res_o[0]), // [s_res_o0]"s"(res_o[0]),
[s_res_o1]"s"(res_o[1]), // [s_res_o1]"s"(res_o[1]),
//[s_res_o2]"s"(res_o[2]), //[s_res_o2]"s"(res_o[2]),
//[s_res_o3]"s"(res_o[3]), //[s_res_o3]"s"(res_o[3]),
[v_os_dq]"v"(static_cast<index_t>(cached_coords_dq * sizeof(DScaleDataType))), [v_os_dq]"v"(static_cast<index_t>(cached_coords_dq * sizeof(DScaleDataType))),
...@@ -184,19 +185,19 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x ...@@ -184,19 +185,19 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
[s_tile_os_o]"s"(tile_stride_o_bytes), [s_tile_os_o]"s"(tile_stride_o_bytes),
[s_tile_os_b_half]"s"(tile_offset_half_b_bytes), [s_tile_os_b_half]"s"(tile_offset_half_b_bytes),
[s_tile_os_b]"s"(tile_stride_b_bytes), [s_tile_os_b]"s"(tile_stride_b_bytes),
[s_tile_os_dq]"s"(tile_stride_dq_bytes), [s_tile_os_dq]"s"(tile_stride_dq_bytes)
// [scale_0]"v"(s0), // [scale_0]"v"(s0),
// [scale_1]"v"(s1), // [scale_1]"v"(s1),
// [v_nan_lo]"v"(nan_lo), // [v_nan_lo]"v"(nan_lo),
// [v_nan_hi]"v"(nan_hi), // [v_nan_hi]"v"(nan_hi),
[s_execflag_0]"s"(o_flags[number<0>{}]), // [s_execflag_0]"s"(o_flags[number<0>{}]),
[s_execflag_1]"s"(o_flags[number<1>{}]), // [s_execflag_1]"s"(o_flags[number<1>{}]),
[s_execflag_2]"s"(o_flags[number<2>{}]), // [s_execflag_2]"s"(o_flags[number<2>{}]),
[s_execflag_3]"s"(o_flags[number<3>{}]), // [s_execflag_3]"s"(o_flags[number<3>{}]),
[s_execflag_4]"s"(o_flags[number<4>{}]), // [s_execflag_4]"s"(o_flags[number<4>{}]),
[s_execflag_5]"s"(o_flags[number<5>{}]), // [s_execflag_5]"s"(o_flags[number<5>{}]),
[s_execflag_6]"s"(o_flags[number<6>{}]), // [s_execflag_6]"s"(o_flags[number<6>{}]),
[s_execflag_7]"s"(o_flags[number<7>{}]) // [s_execflag_7]"s"(o_flags[number<7>{}])
: :
"memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9", "memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9",
"a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19", "a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19",
...@@ -228,7 +229,9 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x ...@@ -228,7 +229,9 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
"a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243", "a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243",
"a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251", "a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251",
"a252", "a253", "a254", "a255", "a252", "a253", "a254", "a255",
"s6", "s7", "s40", "s41", "s42", "s43", "s44", "s45", "s6", "s7","s20", "s21", "s22", "s23", "s24", "s25", "s26", "s27",
"s28", "s29", "s30", "s31", "s34", "s35", "s38", "s39",
"s40", "s41", "s42", "s43", "s44", "s45",
"s46", "s47", "s48", "s49", "s50", "s51", "s52", "s53", "s54", "s46", "s47", "s48", "s49", "s50", "s51", "s52", "s53", "s54",
"s55", "s56", "s57", "s58", "s59", "s60", "s61", "s62", "s63", "s55", "s56", "s57", "s58", "s59", "s60", "s61", "s62", "s63",
"s64", "s65", "s66", "s67", "s68", "s69", "s70", "s71", "s72", "s64", "s65", "s66", "s67", "s68", "s69", "s70", "s71", "s72",
...@@ -260,12 +263,14 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x ...@@ -260,12 +263,14 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
"v245", "v246", "v247", "v248", "v249", "v250", "v251", "v252", "v245", "v246", "v247", "v248", "v249", "v250", "v251", "v252",
"v253", "v254", "v255" "v253", "v254", "v255"
); );
if(hipBlockIdx_x == 0 && hipBlockIdx_y == 0 && hipBlockIdx_z == 0 &&
hipThreadIdx_x == 5)
{
printf("\n sn0 done\n");
} // if(blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && threadIdx.x == 0)
// {
// // printf("\n xyz%x,%x,%x,thread idx:%xsn1 done\n",blockIdx.x, blockIdx.y, blockIdx.z ,threadIdx.x );
// printf("\n sn1 done\n");
// }
// return;
asm volatile( asm volatile(
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_INT8 #define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_INT8
#include "uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_2.inc" #include "uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_2.inc"
...@@ -288,6 +293,123 @@ if(hipBlockIdx_x == 0 && hipBlockIdx_y == 0 && hipBlockIdx_z == 0 && ...@@ -288,6 +293,123 @@ if(hipBlockIdx_x == 0 && hipBlockIdx_y == 0 && hipBlockIdx_z == 0 &&
[v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))), [v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))),
[v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))), [v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))),
[v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))), [v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))),
[s_tile_os_o]"s"(tile_stride_o_bytes),
[s_tile_os_b_half]"s"(tile_offset_half_b_bytes),
[s_tile_os_b]"s"(tile_stride_b_bytes),
[s_tile_os_dq]"s"(tile_stride_dq_bytes),
[scale_0]"v"(s0),
[scale_1]"v"(s1)
// [v_nan_lo]"v"(nan_lo),
// [v_nan_hi]"v"(nan_hi),
// [s_execflag_0]"s"(o_flags[number<0>{}]),
// [s_execflag_1]"s"(o_flags[number<1>{}]),
// [s_execflag_2]"s"(o_flags[number<2>{}]),
// [s_execflag_3]"s"(o_flags[number<3>{}]),
// [s_execflag_4]"s"(o_flags[number<4>{}]),
// [s_execflag_5]"s"(o_flags[number<5>{}]),
// [s_execflag_6]"s"(o_flags[number<6>{}]),
// [s_execflag_7]"s"(o_flags[number<7>{}])
:
"memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9",
"a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19",
"a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29",
"a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39",
"a40", "a41", "a42", "a43", "a44", "a45", "a46", "a47", "a48", "a49",
"a50", "a51", "a52", "a53", "a54", "a55", "a56", "a57", "a58", "a59",
"a60", "a61", "a62", "a63", "a64", "a65", "a66", "a67", "a68", "a69",
"a70", "a71", "a72", "a73", "a74", "a75", "a76", "a77", "a78", "a79",
"a80", "a81", "a82", "a83", "a84", "a85", "a86", "a87", "a88", "a89",
"a90", "a91", "a92", "a93", "a94", "a95", "a96", "a97", "a98", "a99",
"a100", "a101", "a102", "a103", "a104", "a105", "a106", "a107",
"a108", "a109", "a110", "a111", "a112", "a113", "a114", "a115",
"a116", "a117", "a118", "a119", "a120", "a121", "a122", "a123",
"a124", "a125", "a126", "a127", "a128", "a129", "a130", "a131",
"a132", "a133", "a134", "a135", "a136", "a137", "a138", "a139",
"a140", "a141", "a142", "a143", "a144", "a145", "a146", "a147",
"a148", "a149", "a150", "a151", "a152", "a153", "a154", "a155",
"a156", "a157", "a158", "a159", "a160", "a161", "a162", "a163",
"a164", "a165", "a166", "a167", "a168", "a169", "a170", "a171",
"a172", "a173", "a174", "a175", "a176", "a177", "a178", "a179",
"a180", "a181", "a182", "a183", "a184", "a185", "a186", "a187",
"a188", "a189", "a190", "a191", "a192", "a193", "a194", "a195",
"a196", "a197", "a198", "a199", "a200", "a201", "a202", "a203",
"a204", "a205", "a206", "a207", "a208", "a209", "a210", "a211",
"a212", "a213", "a214", "a215", "a216", "a217", "a218", "a219",
"a220", "a221", "a222", "a223", "a224", "a225", "a226", "a227",
"a228", "a229", "a230", "a231", "a232", "a233", "a234", "a235",
"a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243",
"a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251",
"a252", "a253", "a254", "a255",
"s6", "s7","s20", "s21", "s22", "s23", "s24", "s25", "s26", "s27",
"s28", "s29", "s30", "s31", "s38", "s39", "s34", "s35", "s40", "s41", "s42", "s43", "s44", "s45",
"s46", "s47", "s48", "s49", "s50", "s51", "s52", "s53", "s54",
"s55", "s56", "s57", "s58", "s59", "s60", "s61", "s62", "s63",
"s64", "s65", "s66", "s67", "s68", "s69", "s70", "s71", "s72",
"s73", "s74", "s75", "s76", "s77", "s78", "s79", "s80", // s86 as tmp
"v1", "v2", "v3", "v4", "v5", "v12", "v13", "v21", "v22", "v23", "v24", "v25", "v50", "v51", "v52", "v53", "v54", "v55",
"v56", "v57", "v64",
"v65", "v66", "v67", "v68", "v69", "v70", "v71", "v72", "v73",
"v74", "v75", "v76", "v77", "v78", "v79", "v80", "v81", "v82",
"v83", "v84", "v85", "v86", "v87", "v88", "v89", "v90", "v91",
"v92", "v93", "v94", "v95", "v96", "v97", "v98", "v99", "v100",
"v101", "v102", "v103", "v104", "v105", "v106", "v107", "v108",
"v109", "v110", "v111", "v112", "v113", "v114", "v115", "v116",
"v117", "v118", "v119", "v120", "v121", "v122", "v123", "v124",
"v125", "v126", "v127", "v128", "v129", "v130", "v131", "v132",
"v133", "v134", "v135", "v136", "v137", "v138", "v139", "v140",
"v141", "v142", "v143", "v144", "v145", "v146", "v147", "v148",
"v149", "v150", "v151", "v152", "v153", "v154", "v155", "v156",
"v157", "v158", "v159", "v160", "v161", "v162", "v163", "v164",
"v165", "v166", "v167", "v168", "v169", "v170", "v171", "v172",
"v173", "v174", "v175", "v176", "v177", "v178", "v179", "v180",
"v181", "v182", "v183", "v184", "v185", "v186", "v187", "v188",
"v189", "v190", "v191", "v192", "v193", "v194", "v195", "v196",
"v197", "v198", "v199", "v200", "v201", "v202", "v203", "v204",
"v205", "v206", "v207", "v208", "v209", "v210", "v211", "v212",
"v213", "v214", "v215", "v216", "v217", "v218", "v219", "v220",
"v221", "v222", "v223", "v224", "v225", "v226", "v227", "v228",
"v229", "v230", "v231", "v232", "v233", "v234", "v235", "v236",
"v237", "v238", "v239", "v240", "v241", "v242", "v243", "v244",
"v245", "v246", "v247", "v248", "v249", "v250", "v251", "v252",
"v253", "v254", "v255"
);
// if(hipBlockIdx_x == 0 && hipBlockIdx_y == 1 && hipBlockIdx_z == 0 &&
// hipThreadIdx_x == 0)
// {
// printf("\n sn2 done\n");
// }
return;
asm volatile(
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_INT8
#include "uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_3.inc"
#undef CK_TILE_FLATMM_UK_MFMA
:[smem_]"+r"(smem),
[s_loop_cnt]"+s"(loop_cnt)
:[sld_a_base]"n"(0),
[s_res_o0]"s"(res_o[0]),
[s_res_o1]"s"(res_o[1]),
[v_os_dq]"v"(static_cast<index_t>(cached_coords_dq * sizeof(DScaleDataType))),
[v_os_o0]"v"(static_cast<index_t>(cached_coords_o[number<0>{}] * sizeof(ODataType))),
[v_os_o1]"v"(static_cast<index_t>(cached_coords_o[number<1>{}] * sizeof(ODataType))),
[v_os_o2]"v"(static_cast<index_t>(cached_coords_o[number<2>{}] * sizeof(ODataType))),
[v_os_o3]"v"(static_cast<index_t>(cached_coords_o[number<3>{}] * sizeof(ODataType))),
[v_os_o4]"v"(static_cast<index_t>(cached_coords_o[number<4>{}] * sizeof(ODataType))),
[v_os_o5]"v"(static_cast<index_t>(cached_coords_o[number<5>{}] * sizeof(ODataType))),
[v_os_o6]"v"(static_cast<index_t>(cached_coords_o[number<6>{}] * sizeof(ODataType))),
[v_os_o7]"v"(static_cast<index_t>(cached_coords_o[number<7>{}] * sizeof(ODataType))),
[v_os_b0]"v"(static_cast<index_t>(cached_coords_b[number<0>{}] * sizeof(BDataType))),
[v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))),
[v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))),
[v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))),
[s_token_id0]"s"(token_id_[number<0>{}]),
[s_token_id1]"s"(token_id_[number<1>{}]),
[s_token_id2]"s"(token_id_[number<2>{}]),
[s_token_id3]"s"(token_id_[number<3>{}]),
[s_token_id4]"s"(token_id_[number<4>{}]),
[s_token_id5]"s"(token_id_[number<5>{}]),
[s_token_id6]"s"(token_id_[number<6>{}]),
[s_token_id7]"s"(token_id_[number<7>{}]),
[s_tile_os_o]"s"(tile_stride_o_bytes), [s_tile_os_o]"s"(tile_stride_o_bytes),
[s_tile_os_b_half]"s"(tile_offset_half_b_bytes), [s_tile_os_b_half]"s"(tile_offset_half_b_bytes),
[s_tile_os_b]"s"(tile_stride_b_bytes), [s_tile_os_b]"s"(tile_stride_b_bytes),
...@@ -335,7 +457,8 @@ if(hipBlockIdx_x == 0 && hipBlockIdx_y == 0 && hipBlockIdx_z == 0 && ...@@ -335,7 +457,8 @@ if(hipBlockIdx_x == 0 && hipBlockIdx_y == 0 && hipBlockIdx_z == 0 &&
"a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243", "a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243",
"a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251", "a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251",
"a252", "a253", "a254", "a255", "a252", "a253", "a254", "a255",
"s6", "s7", "s40", "s41", "s42", "s43", "s44", "s45", "s6", "s7","s20", "s21", "s22", "s23", "s24", "s25", "s26", "s27",
"s28", "s29", "s30", "s31", "s38", "s39", "s34", "s35", "s40", "s41", "s42", "s43", "s44", "s45",
"s46", "s47", "s48", "s49", "s50", "s51", "s52", "s53", "s54", "s46", "s47", "s48", "s49", "s50", "s51", "s52", "s53", "s54",
"s55", "s56", "s57", "s58", "s59", "s60", "s61", "s62", "s63", "s55", "s56", "s57", "s58", "s59", "s60", "s61", "s62", "s63",
"s64", "s65", "s66", "s67", "s68", "s69", "s70", "s71", "s72", "s64", "s65", "s66", "s67", "s68", "s69", "s70", "s71", "s72",
...@@ -367,7 +490,7 @@ if(hipBlockIdx_x == 0 && hipBlockIdx_y == 0 && hipBlockIdx_z == 0 && ...@@ -367,7 +490,7 @@ if(hipBlockIdx_x == 0 && hipBlockIdx_y == 0 && hipBlockIdx_z == 0 &&
"v245", "v246", "v247", "v248", "v249", "v250", "v251", "v252", "v245", "v246", "v247", "v248", "v249", "v250", "v251", "v252",
"v253", "v254", "v255" "v253", "v254", "v255"
); );
#pragma clang diagnostic pop #pragma clang diagnostic pop
// clang-format on // clang-format on
} }
}; };
......
...@@ -31,8 +31,8 @@ ...@@ -31,8 +31,8 @@
" v_lshrrev_b32 v3, 6, v0 \n" " v_lshrrev_b32 v3, 6, v0 \n"
" v_readfirstlane_b32 s7, v3 \n" " v_readfirstlane_b32 s7, v3 \n"
" s_waitcnt vmcnt(24) \n" " s_waitcnt vmcnt(24) \n"
" buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[12:15], 0 offen\n" "buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[12:15], 0 offen\n"
" buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[12:15], 0 offen offset:1024\n" "buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[12:15], 0 offen offset:1024\n"
" v_mul_f32 v54, v128, v128 \n" " v_mul_f32 v54, v128, v128 \n"
" v_mul_f32 v55, v129, v129 \n" " v_mul_f32 v55, v129, v129 \n"
" v_mul_f32 v56, v130, v130 \n" " v_mul_f32 v56, v130, v130 \n"
...@@ -65,7 +65,7 @@ ...@@ -65,7 +65,7 @@
" v_mul_f32 v129, v129, v55 \n" " v_mul_f32 v129, v129, v55 \n"
" v_mul_f32 v130, v130, v56 \n" " v_mul_f32 v130, v130, v56 \n"
" v_mul_f32 v131, v131, v57 \n" " v_mul_f32 v131, v131, v57 \n"
" buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[12:15], 0 offen offset:2048\n" "buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[12:15], 0 offen offset:2048\n"
" v_mul_f32 v54, v132, v132 \n" " v_mul_f32 v54, v132, v132 \n"
" v_mul_f32 v55, v133, v133 \n" " v_mul_f32 v55, v133, v133 \n"
" v_mul_f32 v56, v134, v134 \n" " v_mul_f32 v56, v134, v134 \n"
...@@ -86,7 +86,7 @@ ...@@ -86,7 +86,7 @@
" v_exp_f32 v55, v55 \n" " v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n" " v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n" " v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[12:15], 0 offen offset:3072\n" "buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[12:15], 0 offen offset:3072\n"
" v_add_f32 v54, v54, 1.0 \n" " v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n" " v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n" " v_add_f32 v56, v56, 1.0 \n"
...@@ -99,7 +99,7 @@ ...@@ -99,7 +99,7 @@
" v_mul_f32 v133, v133, v55 \n" " v_mul_f32 v133, v133, v55 \n"
" v_mul_f32 v134, v134, v56 \n" " v_mul_f32 v134, v134, v56 \n"
" v_mul_f32 v135, v135, v57 \n" " v_mul_f32 v135, v135, v57 \n"
" buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[12:15], 0 offen\n" "buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[12:15], 0 offen\n"
" v_mul_f32 v54, v136, v136 \n" " v_mul_f32 v54, v136, v136 \n"
" v_mul_f32 v55, v137, v137 \n" " v_mul_f32 v55, v137, v137 \n"
" v_mul_f32 v56, v138, v138 \n" " v_mul_f32 v56, v138, v138 \n"
...@@ -120,7 +120,7 @@ ...@@ -120,7 +120,7 @@
" v_exp_f32 v55, v55 \n" " v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n" " v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n" " v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[12:15], 0 offen offset:1024\n" "buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[12:15], 0 offen offset:1024\n"
" v_add_f32 v54, v54, 1.0 \n" " v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n" " v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n" " v_add_f32 v56, v56, 1.0 \n"
...@@ -133,7 +133,7 @@ ...@@ -133,7 +133,7 @@
" v_mul_f32 v137, v137, v55 \n" " v_mul_f32 v137, v137, v55 \n"
" v_mul_f32 v138, v138, v56 \n" " v_mul_f32 v138, v138, v56 \n"
" v_mul_f32 v139, v139, v57 \n" " v_mul_f32 v139, v139, v57 \n"
" buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[12:15], 0 offen offset:2048\n" "buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[12:15], 0 offen offset:2048\n"
" v_mul_f32 v54, v140, v140 \n" " v_mul_f32 v54, v140, v140 \n"
" v_mul_f32 v55, v141, v141 \n" " v_mul_f32 v55, v141, v141 \n"
" v_mul_f32 v56, v142, v142 \n" " v_mul_f32 v56, v142, v142 \n"
...@@ -154,7 +154,7 @@ ...@@ -154,7 +154,7 @@
" v_exp_f32 v55, v55 \n" " v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n" " v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n" " v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[12:15], 0 offen offset:3072\n" "buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[12:15], 0 offen offset:3072\n"
" v_add_f32 v54, v54, 1.0 \n" " v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n" " v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n" " v_add_f32 v56, v56, 1.0 \n"
...@@ -168,7 +168,7 @@ ...@@ -168,7 +168,7 @@
" v_mul_f32 v142, v142, v56 \n" " v_mul_f32 v142, v142, v56 \n"
" v_mul_f32 v143, v143, v57 \n" " v_mul_f32 v143, v143, v57 \n"
" s_waitcnt vmcnt(24) \n" " s_waitcnt vmcnt(24) \n"
" buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[12:15], 0 offen\n" "buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[12:15], 0 offen\n"
" v_mul_f32 v54, v144, v144 \n" " v_mul_f32 v54, v144, v144 \n"
" v_mul_f32 v55, v145, v145 \n" " v_mul_f32 v55, v145, v145 \n"
" v_mul_f32 v56, v146, v146 \n" " v_mul_f32 v56, v146, v146 \n"
...@@ -189,7 +189,7 @@ ...@@ -189,7 +189,7 @@
" v_exp_f32 v55, v55 \n" " v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n" " v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n" " v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[12:15], 0 offen offset:1024\n" "buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[12:15], 0 offen offset:1024\n"
" v_add_f32 v54, v54, 1.0 \n" " v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n" " v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n" " v_add_f32 v56, v56, 1.0 \n"
...@@ -202,7 +202,7 @@ ...@@ -202,7 +202,7 @@
" v_mul_f32 v145, v145, v55 \n" " v_mul_f32 v145, v145, v55 \n"
" v_mul_f32 v146, v146, v56 \n" " v_mul_f32 v146, v146, v56 \n"
" v_mul_f32 v147, v147, v57 \n" " v_mul_f32 v147, v147, v57 \n"
" buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[12:15], 0 offen offset:2048\n" "buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[12:15], 0 offen offset:2048\n"
" v_mul_f32 v54, v148, v148 \n" " v_mul_f32 v54, v148, v148 \n"
" v_mul_f32 v55, v149, v149 \n" " v_mul_f32 v55, v149, v149 \n"
" v_mul_f32 v56, v150, v150 \n" " v_mul_f32 v56, v150, v150 \n"
...@@ -223,7 +223,7 @@ ...@@ -223,7 +223,7 @@
" v_exp_f32 v55, v55 \n" " v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n" " v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n" " v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[12:15], 0 offen offset:3072\n" "buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[12:15], 0 offen offset:3072\n"
" v_add_f32 v54, v54, 1.0 \n" " v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n" " v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n" " v_add_f32 v56, v56, 1.0 \n"
...@@ -236,7 +236,7 @@ ...@@ -236,7 +236,7 @@
" v_mul_f32 v149, v149, v55 \n" " v_mul_f32 v149, v149, v55 \n"
" v_mul_f32 v150, v150, v56 \n" " v_mul_f32 v150, v150, v56 \n"
" v_mul_f32 v151, v151, v57 \n" " v_mul_f32 v151, v151, v57 \n"
" buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[12:15], 0 offen\n" "buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[12:15], 0 offen\n"
" v_mul_f32 v54, v152, v152 \n" " v_mul_f32 v54, v152, v152 \n"
" v_mul_f32 v55, v153, v153 \n" " v_mul_f32 v55, v153, v153 \n"
" v_mul_f32 v56, v154, v154 \n" " v_mul_f32 v56, v154, v154 \n"
...@@ -257,7 +257,7 @@ ...@@ -257,7 +257,7 @@
" v_exp_f32 v55, v55 \n" " v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n" " v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n" " v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[12:15], 0 offen offset:1024\n" "buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[12:15], 0 offen offset:1024\n"
" v_add_f32 v54, v54, 1.0 \n" " v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n" " v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n" " v_add_f32 v56, v56, 1.0 \n"
...@@ -270,7 +270,7 @@ ...@@ -270,7 +270,7 @@
" v_mul_f32 v153, v153, v55 \n" " v_mul_f32 v153, v153, v55 \n"
" v_mul_f32 v154, v154, v56 \n" " v_mul_f32 v154, v154, v56 \n"
" v_mul_f32 v155, v155, v57 \n" " v_mul_f32 v155, v155, v57 \n"
" buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[12:15], 0 offen offset:2048\n" "buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[12:15], 0 offen offset:2048\n"
" v_mul_f32 v54, v156, v156 \n" " v_mul_f32 v54, v156, v156 \n"
" v_mul_f32 v55, v157, v157 \n" " v_mul_f32 v55, v157, v157 \n"
" v_mul_f32 v56, v158, v158 \n" " v_mul_f32 v56, v158, v158 \n"
...@@ -291,7 +291,7 @@ ...@@ -291,7 +291,7 @@
" v_exp_f32 v55, v55 \n" " v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n" " v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n" " v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[12:15], 0 offen offset:3072\n" "buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[12:15], 0 offen offset:3072\n"
" s_add_u32 s12, %[s_tile_os_b_half], s12 \n" " s_add_u32 s12, %[s_tile_os_b_half], s12 \n"
" s_addc_u32 s13, 0, s13 \n" " s_addc_u32 s13, 0, s13 \n"
" v_add_f32 v54, v54, 1.0 \n" " v_add_f32 v54, v54, 1.0 \n"
...@@ -307,7 +307,7 @@ ...@@ -307,7 +307,7 @@
" v_mul_f32 v158, v158, v56 \n" " v_mul_f32 v158, v158, v56 \n"
" v_mul_f32 v159, v159, v57 \n" " v_mul_f32 v159, v159, v57 \n"
" s_waitcnt vmcnt(24) \n" " s_waitcnt vmcnt(24) \n"
" buffer_load_dwordx4 acc[64:67], %[v_os_b0], s[12:15], 0 offen\n" "buffer_load_dwordx4 acc[64:67], %[v_os_b0], s[12:15], 0 offen\n"
" v_mul_f32 v54, v160, v160 \n" " v_mul_f32 v54, v160, v160 \n"
" v_mul_f32 v55, v161, v161 \n" " v_mul_f32 v55, v161, v161 \n"
" v_mul_f32 v56, v162, v162 \n" " v_mul_f32 v56, v162, v162 \n"
...@@ -328,7 +328,7 @@ ...@@ -328,7 +328,7 @@
" v_exp_f32 v55, v55 \n" " v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n" " v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n" " v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[68:71], %[v_os_b0], s[12:15], 0 offen offset:1024\n" "buffer_load_dwordx4 acc[68:71], %[v_os_b0], s[12:15], 0 offen offset:1024\n"
" v_add_f32 v54, v54, 1.0 \n" " v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n" " v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n" " v_add_f32 v56, v56, 1.0 \n"
...@@ -341,7 +341,7 @@ ...@@ -341,7 +341,7 @@
" v_mul_f32 v161, v161, v55 \n" " v_mul_f32 v161, v161, v55 \n"
" v_mul_f32 v162, v162, v56 \n" " v_mul_f32 v162, v162, v56 \n"
" v_mul_f32 v163, v163, v57 \n" " v_mul_f32 v163, v163, v57 \n"
" buffer_load_dwordx4 acc[72:75], %[v_os_b0], s[12:15], 0 offen offset:2048\n" "buffer_load_dwordx4 acc[72:75], %[v_os_b0], s[12:15], 0 offen offset:2048\n"
" v_mul_f32 v54, v164, v164 \n" " v_mul_f32 v54, v164, v164 \n"
" v_mul_f32 v55, v165, v165 \n" " v_mul_f32 v55, v165, v165 \n"
" v_mul_f32 v56, v166, v166 \n" " v_mul_f32 v56, v166, v166 \n"
...@@ -362,7 +362,7 @@ ...@@ -362,7 +362,7 @@
" v_exp_f32 v55, v55 \n" " v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n" " v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n" " v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[76:79], %[v_os_b0], s[12:15], 0 offen offset:3072\n" "buffer_load_dwordx4 acc[76:79], %[v_os_b0], s[12:15], 0 offen offset:3072\n"
" v_add_f32 v54, v54, 1.0 \n" " v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n" " v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n" " v_add_f32 v56, v56, 1.0 \n"
...@@ -375,7 +375,7 @@ ...@@ -375,7 +375,7 @@
" v_mul_f32 v165, v165, v55 \n" " v_mul_f32 v165, v165, v55 \n"
" v_mul_f32 v166, v166, v56 \n" " v_mul_f32 v166, v166, v56 \n"
" v_mul_f32 v167, v167, v57 \n" " v_mul_f32 v167, v167, v57 \n"
" buffer_load_dwordx4 acc[80:83], %[v_os_b1], s[12:15], 0 offen\n" "buffer_load_dwordx4 acc[80:83], %[v_os_b1], s[12:15], 0 offen\n"
" v_mul_f32 v54, v168, v168 \n" " v_mul_f32 v54, v168, v168 \n"
" v_mul_f32 v55, v169, v169 \n" " v_mul_f32 v55, v169, v169 \n"
" v_mul_f32 v56, v170, v170 \n" " v_mul_f32 v56, v170, v170 \n"
...@@ -396,7 +396,7 @@ ...@@ -396,7 +396,7 @@
" v_exp_f32 v55, v55 \n" " v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n" " v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n" " v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[84:87], %[v_os_b1], s[12:15], 0 offen offset:1024\n" "buffer_load_dwordx4 acc[84:87], %[v_os_b1], s[12:15], 0 offen offset:1024\n"
" v_add_f32 v54, v54, 1.0 \n" " v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n" " v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n" " v_add_f32 v56, v56, 1.0 \n"
...@@ -409,7 +409,7 @@ ...@@ -409,7 +409,7 @@
" v_mul_f32 v169, v169, v55 \n" " v_mul_f32 v169, v169, v55 \n"
" v_mul_f32 v170, v170, v56 \n" " v_mul_f32 v170, v170, v56 \n"
" v_mul_f32 v171, v171, v57 \n" " v_mul_f32 v171, v171, v57 \n"
" buffer_load_dwordx4 acc[88:91], %[v_os_b1], s[12:15], 0 offen offset:2048\n" "buffer_load_dwordx4 acc[88:91], %[v_os_b1], s[12:15], 0 offen offset:2048\n"
" v_mul_f32 v54, v172, v172 \n" " v_mul_f32 v54, v172, v172 \n"
" v_mul_f32 v55, v173, v173 \n" " v_mul_f32 v55, v173, v173 \n"
" v_mul_f32 v56, v174, v174 \n" " v_mul_f32 v56, v174, v174 \n"
...@@ -430,7 +430,7 @@ ...@@ -430,7 +430,7 @@
" v_exp_f32 v55, v55 \n" " v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n" " v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n" " v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[92:95], %[v_os_b1], s[12:15], 0 offen offset:3072\n" "buffer_load_dwordx4 acc[92:95], %[v_os_b1], s[12:15], 0 offen offset:3072\n"
" v_add_f32 v54, v54, 1.0 \n" " v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n" " v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n" " v_add_f32 v56, v56, 1.0 \n"
...@@ -444,7 +444,7 @@ ...@@ -444,7 +444,7 @@
" v_mul_f32 v174, v174, v56 \n" " v_mul_f32 v174, v174, v56 \n"
" v_mul_f32 v175, v175, v57 \n" " v_mul_f32 v175, v175, v57 \n"
" s_waitcnt vmcnt(24) \n" " s_waitcnt vmcnt(24) \n"
" buffer_load_dwordx4 acc[96:99], %[v_os_b2], s[12:15], 0 offen\n" "buffer_load_dwordx4 acc[96:99], %[v_os_b2], s[12:15], 0 offen\n"
" v_mul_f32 v54, v176, v176 \n" " v_mul_f32 v54, v176, v176 \n"
" v_mul_f32 v55, v177, v177 \n" " v_mul_f32 v55, v177, v177 \n"
" v_mul_f32 v56, v178, v178 \n" " v_mul_f32 v56, v178, v178 \n"
...@@ -465,7 +465,7 @@ ...@@ -465,7 +465,7 @@
" v_exp_f32 v55, v55 \n" " v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n" " v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n" " v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[100:103], %[v_os_b2], s[12:15], 0 offen offset:1024\n" "buffer_load_dwordx4 acc[100:103], %[v_os_b2], s[12:15], 0 offen offset:1024\n"
" v_add_f32 v54, v54, 1.0 \n" " v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n" " v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n" " v_add_f32 v56, v56, 1.0 \n"
...@@ -478,7 +478,7 @@ ...@@ -478,7 +478,7 @@
" v_mul_f32 v177, v177, v55 \n" " v_mul_f32 v177, v177, v55 \n"
" v_mul_f32 v178, v178, v56 \n" " v_mul_f32 v178, v178, v56 \n"
" v_mul_f32 v179, v179, v57 \n" " v_mul_f32 v179, v179, v57 \n"
" buffer_load_dwordx4 acc[104:107], %[v_os_b2], s[12:15], 0 offen offset:2048\n" "buffer_load_dwordx4 acc[104:107], %[v_os_b2], s[12:15], 0 offen offset:2048\n"
" v_mul_f32 v54, v180, v180 \n" " v_mul_f32 v54, v180, v180 \n"
" v_mul_f32 v55, v181, v181 \n" " v_mul_f32 v55, v181, v181 \n"
" v_mul_f32 v56, v182, v182 \n" " v_mul_f32 v56, v182, v182 \n"
...@@ -499,7 +499,7 @@ ...@@ -499,7 +499,7 @@
" v_exp_f32 v55, v55 \n" " v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n" " v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n" " v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[108:111], %[v_os_b2], s[12:15], 0 offen offset:3072\n" "buffer_load_dwordx4 acc[108:111], %[v_os_b2], s[12:15], 0 offen offset:3072\n"
" v_add_f32 v54, v54, 1.0 \n" " v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n" " v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n" " v_add_f32 v56, v56, 1.0 \n"
...@@ -512,7 +512,7 @@ ...@@ -512,7 +512,7 @@
" v_mul_f32 v181, v181, v55 \n" " v_mul_f32 v181, v181, v55 \n"
" v_mul_f32 v182, v182, v56 \n" " v_mul_f32 v182, v182, v56 \n"
" v_mul_f32 v183, v183, v57 \n" " v_mul_f32 v183, v183, v57 \n"
" buffer_load_dwordx4 acc[112:115], %[v_os_b3], s[12:15], 0 offen\n" "buffer_load_dwordx4 acc[112:115], %[v_os_b3], s[12:15], 0 offen\n"
" v_mul_f32 v54, v184, v184 \n" " v_mul_f32 v54, v184, v184 \n"
" v_mul_f32 v55, v185, v185 \n" " v_mul_f32 v55, v185, v185 \n"
" v_mul_f32 v56, v186, v186 \n" " v_mul_f32 v56, v186, v186 \n"
...@@ -533,7 +533,7 @@ ...@@ -533,7 +533,7 @@
" v_exp_f32 v55, v55 \n" " v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n" " v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n" " v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[116:119], %[v_os_b3], s[12:15], 0 offen offset:1024\n" "buffer_load_dwordx4 acc[116:119], %[v_os_b3], s[12:15], 0 offen offset:1024\n"
" v_add_f32 v54, v54, 1.0 \n" " v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n" " v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n" " v_add_f32 v56, v56, 1.0 \n"
...@@ -546,7 +546,7 @@ ...@@ -546,7 +546,7 @@
" v_mul_f32 v185, v185, v55 \n" " v_mul_f32 v185, v185, v55 \n"
" v_mul_f32 v186, v186, v56 \n" " v_mul_f32 v186, v186, v56 \n"
" v_mul_f32 v187, v187, v57 \n" " v_mul_f32 v187, v187, v57 \n"
" buffer_load_dwordx4 acc[120:123], %[v_os_b3], s[12:15], 0 offen offset:2048\n" "buffer_load_dwordx4 acc[120:123], %[v_os_b3], s[12:15], 0 offen offset:2048\n"
" v_mul_f32 v54, v188, v188 \n" " v_mul_f32 v54, v188, v188 \n"
" v_mul_f32 v55, v189, v189 \n" " v_mul_f32 v55, v189, v189 \n"
" v_mul_f32 v56, v190, v190 \n" " v_mul_f32 v56, v190, v190 \n"
...@@ -567,7 +567,7 @@ ...@@ -567,7 +567,7 @@
" v_exp_f32 v55, v55 \n" " v_exp_f32 v55, v55 \n"
" v_exp_f32 v56, v56 \n" " v_exp_f32 v56, v56 \n"
" v_exp_f32 v57, v57 \n" " v_exp_f32 v57, v57 \n"
" buffer_load_dwordx4 acc[124:127], %[v_os_b3], s[12:15], 0 offen offset:3072\n" "buffer_load_dwordx4 acc[124:127], %[v_os_b3], s[12:15], 0 offen offset:3072\n"
" v_add_f32 v54, v54, 1.0 \n" " v_add_f32 v54, v54, 1.0 \n"
" v_add_f32 v55, v55, 1.0 \n" " v_add_f32 v55, v55, 1.0 \n"
" v_add_f32 v56, v56, 1.0 \n" " v_add_f32 v56, v56, 1.0 \n"
...@@ -644,7 +644,7 @@ ...@@ -644,7 +644,7 @@
" v_mul_f32 v189, %[smq_scale1], v189 row_newbcast:13 \n" " v_mul_f32 v189, %[smq_scale1], v189 row_newbcast:13 \n"
" v_mul_f32 v190, %[smq_scale1], v190 row_newbcast:14 \n" " v_mul_f32 v190, %[smq_scale1], v190 row_newbcast:14 \n"
" v_mul_f32 v191, %[smq_scale1], v191 row_newbcast:15 \n" " v_mul_f32 v191, %[smq_scale1], v191 row_newbcast:15 \n"
" buffer_load_dword v12, %[v_os_dq], s[16:19], 0 offen \n" ";--buffer_load_dword v12, %[v_os_dq], s[16:19], 0 offen \n"
" v_mov_b32 v22, 0x358637bd \n" " v_mov_b32 v22, 0x358637bd \n"
" v_mov_b32 v23, 0x358637bd \n" " v_mov_b32 v23, 0x358637bd \n"
" v_max3_f32 v22, abs(v128), abs(v129), v22 \n" " v_max3_f32 v22, abs(v128), abs(v129), v22 \n"
...@@ -974,3 +974,5 @@ ...@@ -974,3 +974,5 @@
#undef _UK_PK_CVT_ #undef _UK_PK_CVT_
#undef _UK_ATOMIC_ADD_ #undef _UK_ATOMIC_ADD_
...@@ -19,9 +19,9 @@ ...@@ -19,9 +19,9 @@
" v_mul_f32 " a2 ", " gq ", " a2 " row_newbcast: " brd2 " \n" \ " v_mul_f32 " a2 ", " gq ", " a2 " row_newbcast: " brd2 " \n" \
" v_mul_f32 " a3 ", " gq ", " a3 " row_newbcast:" brd3 " \n" " v_mul_f32 " a3 ", " gq ", " a3 " row_newbcast:" brd3 " \n"
" s_mov_b32 s22, %[a_bound] \n"
"s_mov_b32 s20, %[s_res_a0] \n" "s_mov_b32 s20, %[s_res_a0] \n"
"s_mov_b32 s21, %[s_res_a1] \n" "s_mov_b32 s21, %[s_res_a1] \n"
"s_mov_b32 s22, %[s_res_a2] \n"
"s_mov_b32 s23, %[s_res_a3] \n" "s_mov_b32 s23, %[s_res_a3] \n"
"s_mov_b32 s24, %[s_res_b0] \n" "s_mov_b32 s24, %[s_res_b0] \n"
"s_mov_b32 s25, %[s_res_b1] \n" "s_mov_b32 s25, %[s_res_b1] \n"
...@@ -110,38 +110,38 @@ ...@@ -110,38 +110,38 @@
" s_add_u32 s20, s57, s20 \n" " s_add_u32 s20, s57, s20 \n"
" s_addc_u32 s21, 0, s21 \n" " s_addc_u32 s21, 0, s21 \n"
"; -- prefetch B0\n" "; -- prefetch B0\n"
"buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[24:27], 0 offen \n" " buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[24:27], 0 offen \n"
"buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[24:27], 0 offen offset:1024 \n" " buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[24:27], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[24:27], 0 offen offset:2048 \n" " buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[24:27], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[24:27], 0 offen offset:3072 \n" " buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[24:27], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[24:27], 0 offen \n" " buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[24:27], 0 offen \n"
"buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[24:27], 0 offen offset:1024 \n" " buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[24:27], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[24:27], 0 offen offset:2048 \n" " buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[24:27], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[24:27], 0 offen offset:3072 \n" " buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[24:27], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[24:27], 0 offen \n" " buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[24:27], 0 offen \n"
"buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[24:27], 0 offen offset:1024 \n" " buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[24:27], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[24:27], 0 offen offset:2048 \n" " buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[24:27], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[24:27], 0 offen offset:3072 \n" " buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[24:27], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[24:27], 0 offen \n" " buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[24:27], 0 offen \n"
"buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[24:27], 0 offen offset:1024 \n" " buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[24:27], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[24:27], 0 offen offset:2048 \n" " buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[24:27], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[24:27], 0 offen offset:3072 \n" " buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[24:27], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[64:67], %[v_os_b4], s[24:27], 0 offen \n" " buffer_load_dwordx4 acc[64:67], %[v_os_b4], s[24:27], 0 offen \n"
"buffer_load_dwordx4 acc[68:71], %[v_os_b4], s[24:27], 0 offen offset:1024 \n" " buffer_load_dwordx4 acc[68:71], %[v_os_b4], s[24:27], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[72:75], %[v_os_b4], s[24:27], 0 offen offset:2048 \n" " buffer_load_dwordx4 acc[72:75], %[v_os_b4], s[24:27], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[76:79], %[v_os_b4], s[24:27], 0 offen offset:3072 \n" " buffer_load_dwordx4 acc[76:79], %[v_os_b4], s[24:27], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[80:83], %[v_os_b5], s[24:27], 0 offen \n" " buffer_load_dwordx4 acc[80:83], %[v_os_b5], s[24:27], 0 offen \n"
"buffer_load_dwordx4 acc[84:87], %[v_os_b5], s[24:27], 0 offen offset:1024 \n" " buffer_load_dwordx4 acc[84:87], %[v_os_b5], s[24:27], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[88:91], %[v_os_b5], s[24:27], 0 offen offset:2048 \n" " buffer_load_dwordx4 acc[88:91], %[v_os_b5], s[24:27], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[92:95], %[v_os_b5], s[24:27], 0 offen offset:3072 \n" " buffer_load_dwordx4 acc[92:95], %[v_os_b5], s[24:27], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[96:99], %[v_os_b6], s[24:27], 0 offen \n" " buffer_load_dwordx4 acc[96:99], %[v_os_b6], s[24:27], 0 offen \n"
"buffer_load_dwordx4 acc[100:103], %[v_os_b6], s[24:27], 0 offen offset:1024 \n" " buffer_load_dwordx4 acc[100:103], %[v_os_b6], s[24:27], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[104:107], %[v_os_b6], s[24:27], 0 offen offset:2048 \n" " buffer_load_dwordx4 acc[104:107], %[v_os_b6], s[24:27], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[108:111], %[v_os_b6], s[24:27], 0 offen offset:3072 \n" " buffer_load_dwordx4 acc[108:111], %[v_os_b6], s[24:27], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[112:115], %[v_os_b7], s[24:27], 0 offen \n" " buffer_load_dwordx4 acc[112:115], %[v_os_b7], s[24:27], 0 offen \n"
"buffer_load_dwordx4 acc[116:119], %[v_os_b7], s[24:27], 0 offen offset:1024 \n" " buffer_load_dwordx4 acc[116:119], %[v_os_b7], s[24:27], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[24:27], 0 offen offset:2048 \n" " buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[24:27], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[24:27], 0 offen offset:3072 \n" " buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[24:27], 0 offen offset:3072 \n"
"s_add_u32 s24, s58, s24 \n" "s_add_u32 s24, s58, s24 \n"
"s_addc_u32 s25, 0, s25 \n" "s_addc_u32 s25, 0, s25 \n"
......
...@@ -237,12 +237,23 @@ struct FusedMoeGemmKernel ...@@ -237,12 +237,23 @@ struct FusedMoeGemmKernel
CK_TILE_DEVICE void operator()(Kargs kargs) const CK_TILE_DEVICE void operator()(Kargs kargs) const
{ {
if constexpr(UseUK) if constexpr(UseUK)
{ {
__shared__ CK_TILE_LDS_ADDR ADataType smem[GetSmemSize()];
__shared__ CK_TILE_LDS_ADDR ADataType smem[65536];
// index_t s_size = GetSmemSize();
// ADataType{}.aaa();
IndexDataType num_sorted_tiles = __builtin_amdgcn_readfirstlane( IndexDataType num_sorted_tiles = __builtin_amdgcn_readfirstlane(
*reinterpret_cast<const IndexDataType*>(kargs.num_sorted_tiles_ptr)); *reinterpret_cast<const IndexDataType*>(kargs.num_sorted_tiles_ptr));
// __builtin_amdgcn_sched_barrier(0);
// if(threadIdx.x == 0){
// printf("num_sorted_tiles %d\n", num_sorted_tiles);
// printf("data type :%s\n", t2s<ADataType>::name);
// printf("\nblockIdx.x :%x, blockIdx.y :%x,\n", blockIdx.x, blockIdx.y);
// __builtin_amdgcn_sched_barrier(0);
// }
// __builtin_amdgcn_sched_barrier(0);
num_sorted_tiles = num_sorted_tiles / BlockShape::Block_M0; num_sorted_tiles = num_sorted_tiles / BlockShape::Block_M0;
const auto [sorted_tile_id, intermediate_tile_id] = const auto [sorted_tile_id, intermediate_tile_id] =
......
...@@ -74,7 +74,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8 ...@@ -74,7 +74,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
constexpr index_t smem_1 = Policy::template GetUK_1<Problem>().GetSmemSize(); constexpr index_t smem_1 = Policy::template GetUK_1<Problem>().GetSmemSize();
constexpr index_t smem_bridge = constexpr index_t smem_bridge =
BlockShape::Block_M0 * BlockShape::Block_N0 * sizeof(YDataType); BlockShape::Block_M0 * BlockShape::Block_N0 * sizeof(YDataType);
return max(smem_0, max(smem_1, smem_bridge)); return max(smem_0, max(smem_1, smem_bridge));
} }
// this is the thread-offset along row/col // this is the thread-offset along row/col
...@@ -159,15 +159,22 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8 ...@@ -159,15 +159,22 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
template <typename ROW_IDS> template <typename ROW_IDS>
CK_TILE_DEVICE auto GetAScale(const ROW_IDS row_ids_mma, CK_TILE_DEVICE auto GetAScale(const ROW_IDS row_ids_mma,
const AScaleDataType* a_scale_ptr) // const AScaleDataType* a_scale_ptr, index_t num_tokens_)
index_t num_tokens_)
{ {
constexpr index_t n_size = row_ids_mma.size(); constexpr index_t n_size = row_ids_mma.size();
array<TopkWeightDataType, n_size> w; array<TopkWeightDataType, n_size> w;
static_for<0, n_size, 1>{}([&](auto i) { static_for<0, n_size, 1>{}([&](auto i) {
auto row_id = row_ids_mma[i] & 0xffffff; auto row_id = row_ids_mma[i] & 0xffffff;
auto itp_k = row_ids_mma[i] >> 24; if (row_id >= num_tokens_)
w.at(i) = a_scale_ptr[row_id * 5+itp_k]; {
w.at(i) = 0.f;
} else {
w.at(i) = 1.f;
// auto itp_k = row_ids_mma[i] >> 24;
// w.at(i) = a_scale_ptr[row_id * 5+itp_k];
}
}); });
return w; return w;
...@@ -247,7 +254,12 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8 ...@@ -247,7 +254,12 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
index_t kr_0 = kargs.hidden_size / BlockShape::Warp_K0; // divide K in W index_t kr_0 = kargs.hidden_size / BlockShape::Warp_K0; // divide K in W
index_t nr_1 = kargs.hidden_size / BlockShape::Warp_N1; index_t nr_1 = kargs.hidden_size / BlockShape::Warp_N1;
index_t kr_1 = shared_intermediate_size_1 / BlockShape::Warp_K1; index_t kr_1 = shared_intermediate_size_1 / BlockShape::Warp_K1;
// if(threadIdx.x == 31 && blockIdx.x == 0 && blockIdx.y == 0)
// {
// printf("\nWarpPerBlock_N0 :%x, WarpPerBlock_M0:%x,\n", BlockShape::WarpPerBlock_N0
// , BlockShape::WarpPerBlock_M0);
// }
const IndexDataType expert_id = __builtin_amdgcn_readfirstlane( const IndexDataType expert_id = __builtin_amdgcn_readfirstlane(
reinterpret_cast<const IndexDataType*>(kargs.sorted_expert_ids_ptr)[sorted_tile_id]); reinterpret_cast<const IndexDataType*>(kargs.sorted_expert_ids_ptr)[sorted_tile_id]);
index_t expert_stride_0 = shared_intermediate_size_0 * kargs.hidden_size; index_t expert_stride_0 = shared_intermediate_size_0 * kargs.hidden_size;
...@@ -271,7 +283,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8 ...@@ -271,7 +283,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
row_coords_a_mma, reinterpret_cast<const IndexDataType*>(kargs.sorted_token_ids_ptr)); row_coords_a_mma, reinterpret_cast<const IndexDataType*>(kargs.sorted_token_ids_ptr));
auto token_id = generate_tuple( auto token_id = generate_tuple(
[&](auto i) { [&](auto i) {
return (row_ids_a[i]) &0xffffff; return (row_ids_a[i] &0xffffff);
}, },
number<row_ids_a.size()>{}); number<row_ids_a.size()>{});
auto a_coords = generate_tuple( auto a_coords = generate_tuple(
...@@ -385,10 +397,16 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8 ...@@ -385,10 +397,16 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
threadIdx.x % (BlockShape::Block_N1/2 / kAlignmentO) * kAlignmentO; threadIdx.x % (BlockShape::Block_N1/2 / kAlignmentO) * kAlignmentO;
}, },
number<row_ids_a.size()>{}); number<row_ids_a.size()>{});
auto o_flags = auto o_flags =
generate_tuple([&](auto i) { return cmp_lt_to_exec(token_id[i], kargs.num_tokens); }, generate_tuple([&](auto i) { return cmp_lt_to_exec(token_id[i], kargs.num_tokens); },
number<row_ids_a.size()>{}); // generate_tuple([&](auto i) {
// if (__builtin_amdgcn_readfirstlane(token_id[i]) < kargs.num_tokens)
// {return 0xffffffffffffffff;}
// else
// {return uint32x2_t 0;}
// },
number<token_id.size()>{});
auto o_res = auto o_res =
make_wave_buffer_resource(reinterpret_cast<const ODataType*>(kargs.o_ptr), make_wave_buffer_resource(reinterpret_cast<const ODataType*>(kargs.o_ptr),
...@@ -398,13 +416,59 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8 ...@@ -398,13 +416,59 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
auto w_scale = GetWeightScale( auto w_scale = GetWeightScale(
row_coords_o, reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr)); row_coords_o, reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr));
auto a_scale = GetAScale( auto a_scale = GetAScale(
row_ids_a_mma, reinterpret_cast<const AScaleDataType*>(kargs.a_scale_ptr)); // row_ids_a_mma, reinterpret_cast<const AScaleDataType*>(kargs.a_scale_ptr), kargs.num_tokens );
row_ids_a_mma, kargs.num_tokens );
auto gqsmq_coords = GetColCoords_GQSMQ(intermediate_tile_id * BlockShape::Block_K1); auto gqsmq_coords = GetColCoords_GQSMQ(intermediate_tile_id * BlockShape::Block_K1);
auto dq_coords = gqsmq_coords[0];//only one for this tiling auto dq_coords = gqsmq_coords[0];//only one for this tiling
auto gq_scale = GetGQScale( auto gq_scale = GetGQScale(
gqsmq_coords, (reinterpret_cast<const GScaleDataType*>(kargs.g_scale_ptr) + static_cast<long_index_t>(expert_id) * shared_intermediate_size_0)); gqsmq_coords, (reinterpret_cast<const GScaleDataType*>(kargs.g_scale_ptr) + static_cast<long_index_t>(expert_id) * shared_intermediate_size_0));
auto smq_scale = GetSMQScale( auto smq_scale = GetSMQScale(
gqsmq_coords, (reinterpret_cast<const YSmoothScaleDataType*>(kargs.y_smooth_scale_ptr) + static_cast<long_index_t>(expert_id) * shared_intermediate_size_0)); gqsmq_coords, (reinterpret_cast<const YSmoothScaleDataType*>(kargs.y_smooth_scale_ptr) + static_cast<long_index_t>(expert_id) * shared_intermediate_size_0));
if(threadIdx.x == 95 && blockIdx.x == 0 && blockIdx.y == 0)
{
printf("\nblockIdx.x :%x, blockIdx.y :%x, d ptr: %p, wg d ptr :%x%x,gemm0 done\n", blockIdx.x, blockIdx.y, kargs.d_ptr,d_res[1],d_res[0]);
// // printf("\n wg 1 1, wave 1, row_coords_a 0: %x 1: %x, 2: %x, 3:%x, 5: %x 6: %x, 7: %x, 8:%x,, \n", row_coords_a[number<0>{}],row_coords_a[number<1>{}],row_coords_a[number<2>{}],row_coords_a[number<3>{}], row_coords_a[number<4>{}],row_coords_a[number<5>{}],row_coords_a[number<6>{}],row_coords_a[number<7>{}]);
// // printf("\n -------------- -row_ids_a 0: %x 1: %x, 2: %x, 3:%x, 5: %x 6: %x, 7: %x, 8:%x,, \n", row_ids_a[number<0>{}],row_ids_a[number<1>{}],row_ids_a[number<2>{}],row_ids_a[number<3>{}], row_ids_a[number<4>{}],row_ids_a[number<5>{}],row_ids_a[number<6>{}],row_ids_a[number<7>{}]);
// // printf("\n -----------thread id %x--- - token_id 0: %x 1: %x, 2: %x, 3:%x, 5: %x 6: %x, 7: %x, 8:%x,, \n", hipThreadIdx_x , token_id[number<0>{}],token_id[number<1>{}],token_id[number<2>{}],token_id[number<3>{}], token_id[number<4>{}],token_id[number<5>{}],token_id[number<6>{}],token_id[number<7>{}]);
// // printf("\n -----------thread id %x--- - token_id , 7:%x,, \n", hipThreadIdx_x , token_id[number<7>{}]);
// // printf("\n -------------- - exec 0: %x 1: %x, 2: %x, 3:%x, 5: %x 6: %x, 7: %x, 8:%x,, \n", o_flags[number<0>{}][0],o_flags[number<1>{}][0],o_flags[number<2>{}][0],o_flags[number<3>{}][0], o_flags[number<4>{}][0],o_flags[number<5>{}][0],o_flags[number<6>{}][0],o_flags[number<7>{}][0]);
printf("\ntoken id :%x,%x,%x,%x, %x,%x,%x,%x \n d_coords: %x,%x,%x,%x, \n row_idx: %x,%x,%x,%x, %x,%x,%x,%x \n o_flags:%x,%x,%x,%x, %x,%x,%x,%x \n",
token_id[number<0>{}],
token_id[number<1>{}],
token_id[number<2>{}],
token_id[number<3>{}],
token_id[number<4>{}],
token_id[number<5>{}],
token_id[number<6>{}],
token_id[number<7>{}],
d_coords[number<0>{}],
d_coords[number<1>{}],
d_coords[number<2>{}],
d_coords[number<3>{}],
// d_coords[number<4>{}],
// d_coords[number<5>{}],
// d_coords[number<6>{}],
// d_coords[number<7>{}],
row_ids_a[number<0>{}],
row_ids_a[number<1>{}],
row_ids_a[number<2>{}],
row_ids_a[number<3>{}],
row_ids_a[number<4>{}],
row_ids_a[number<5>{}],
row_ids_a[number<6>{}],
row_ids_a[number<7>{}],
o_flags[number<0>{}][0],
o_flags[number<1>{}][0],
o_flags[number<2>{}][0],
o_flags[number<3>{}][0],
o_flags[number<4>{}][0],
o_flags[number<5>{}][0],
o_flags[number<6>{}][0],
o_flags[number<7>{}][0]);
// return;
}
__builtin_amdgcn_sched_barrier(0);
auto uk_0 = Policy::template GetUK_0<Problem>(); auto uk_0 = Policy::template GetUK_0<Problem>();
// auto acc_0= uk_0( // auto acc_0= uk_0(
uk_0(a_scale, uk_0(a_scale,
...@@ -418,17 +482,12 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8 ...@@ -418,17 +482,12 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
smem, smem,
kargs.hidden_size, kargs.hidden_size,
BlockShape::Block_K0, // tile offset for B matrix each unroll BlockShape::Block_K0, // tile offset for B matrix each unroll
BlockShape::Block_Kr0 * 16*256,
BlockShape::Block_W0); // tile offset for B matrix each unroll kargs.num_tokens * kargs.stride_token); // tile offset for B matrix each unroll
if(hipBlockIdx_x == 1 && hipBlockIdx_y == 1 && hipBlockIdx_z == 0 && // return;
hipThreadIdx_x == 64) __builtin_amdgcn_sched_barrier(0);
{
printf("\ngemm0 done\n"); // // sweep_tile(
// printf("\n wg 1 1, wave 1, row_coords_a 0: %x 1: %x, 2: %x, 3:%x, 5: %x 6: %x, 7: %x, 8:%x,, \n", row_coords_a[number<0>{}],row_coords_a[number<1>{}],row_coords_a[number<2>{}],row_coords_a[number<3>{}], row_coords_a[number<4>{}],row_coords_a[number<5>{}],row_coords_a[number<6>{}],row_coords_a[number<7>{}]);
// printf("\n -------------- -row_ids_a 0: %x 1: %x, 2: %x, 3:%x, 5: %x 6: %x, 7: %x, 8:%x,, \n", row_ids_a[number<0>{}],row_ids_a[number<1>{}],row_ids_a[number<2>{}],row_ids_a[number<3>{}], row_ids_a[number<4>{}],row_ids_a[number<5>{}],row_ids_a[number<6>{}],row_ids_a[number<7>{}]);
printf("\n -------------- - token_id 0: %x 1: %x, 2: %x, 3:%x, 5: %x 6: %x, 7: %x, 8:%x,, \n", token_id[number<0>{}],token_id[number<1>{}],token_id[number<2>{}],token_id[number<3>{}], token_id[number<4>{}],token_id[number<5>{}],token_id[number<6>{}],token_id[number<7>{}]);
}
// sweep_tile(
// acc_0, // acc_0,
// [&](auto idx0, auto idx1) { // [&](auto idx0, auto idx1) {
// fp32x2_t v_{acc_0(idx0), acc_0(idx1)}; // fp32x2_t v_{acc_0(idx0), acc_0(idx1)};
...@@ -449,6 +508,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8 ...@@ -449,6 +508,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
uk_1( uk_1(
// dq_res, // dq_res,
// d_res, // d_res,
token_id,
dq_coords, dq_coords,
d_coords, d_coords,
o_res, o_res,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment