"test/tests.cpp" did not exist on "2fe7e8d5257a3661e67e7c124b9ce708dff89185"
Commit 2a66e080 authored by shengnxu's avatar shengnxu
Browse files

fix some issue, next step, res recalc,

parent 7cc808f2
......@@ -295,6 +295,11 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
const auto [m0_init_value, size_per_issue] = get_async_store_smem_info(a_sst);
constexpr auto smem_buf_size =
MakeLdsLoadDesc_A().get_element_space_size() * sizeof(ADataType);
// if(threadIdx.x%64 == 0 ){
// printf("wave id:%d, m0_init_value:%d, size_per_issue:%d\n",
// int(threadIdx.x/64),int(m0_init_value), int(size_per_issue));
// }
static_assert(a_sld.get_num_of_access() == 8);
constexpr auto sld_os = generate_tuple(
[&](auto i_access) {
......@@ -533,9 +538,8 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
"s55", "s56", "s57", "s58", "s59", "s60", "s61", "s62", "s63",
"s64", "s65", "s66", "s67", "s68", "s69", "s70", "s71", "s72",
"s73", "s74", "s75", "s76", "s77", "s78", "s79", "s80", // s86 as tmp
"v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v50", "v51", "v52", "v53", "v54", "v55",
"v1", "v2", "v3", "v4", "v5", "v12", "v13", "v21", "v22", "v23",
"v24", "v25", "v50", "v51", "v52", "v53", "v54", "v55",
"v56", "v57", "v64",
"v65", "v66", "v67", "v68", "v69", "v70", "v71", "v72", "v73",
"v74", "v75", "v76", "v77", "v78", "v79", "v80", "v81", "v82",
......
......@@ -233,8 +233,7 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
"s55", "s56", "s57", "s58", "s59", "s60", "s61", "s62", "s63",
"s64", "s65", "s66", "s67", "s68", "s69", "s70", "s71", "s72",
"s73", "s74", "s75", "s76", "s77", "s78", "s79", "s80", // s86 as tmp
"v1", "v2", "v3", "v4", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v50", "v51", "v52", "v53", "v54", "v55",
"v1", "v2", "v3", "v4","v5", "v12", "v13", "v21", "v22", "v23", "v24", "v25", "v50", "v51", "v52", "v53", "v54", "v55",
"v56", "v57", "v64",
"v65", "v66", "v67", "v68", "v69", "v70", "v71", "v72", "v73",
"v74", "v75", "v76", "v77", "v78", "v79", "v80", "v81", "v82",
......@@ -261,6 +260,12 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
"v245", "v246", "v247", "v248", "v249", "v250", "v251", "v252",
"v253", "v254", "v255"
);
if(hipBlockIdx_x == 0 && hipBlockIdx_y == 0 && hipBlockIdx_z == 0 &&
hipThreadIdx_x == 5)
{
printf("\n sn0 done\n");
}
asm volatile(
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_INT8
#include "uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_2.inc"
......@@ -335,8 +340,7 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
"s55", "s56", "s57", "s58", "s59", "s60", "s61", "s62", "s63",
"s64", "s65", "s66", "s67", "s68", "s69", "s70", "s71", "s72",
"s73", "s74", "s75", "s76", "s77", "s78", "s79", "s80", // s86 as tmp
"v1", "v2", "v3", "v4", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v50", "v51", "v52", "v53", "v54", "v55",
"v1", "v2", "v3", "v4", "v5", "v12", "v13", "v21", "v22", "v23", "v24", "v25", "v50", "v51", "v52", "v53", "v54", "v55",
"v56", "v57", "v64",
"v65", "v66", "v67", "v68", "v69", "v70", "v71", "v72", "v73",
"v74", "v75", "v76", "v77", "v78", "v79", "v80", "v81", "v82",
......
......@@ -27,6 +27,8 @@
# define _UK_ATOMIC_ADD_ "global_atomic_pk_add_f16"
#endif
" s_mov_b32 s36, -1 \n"
" s_mov_b32 s37, -1 \n"
" s_add_u32 s12, %[s_tile_os_b], s12 \n"
" s_addc_u32 s13, 0, s13 \n"
" s_add_u32 s16, %[s_tile_os_dq], s16 \n"
......@@ -478,52 +480,52 @@
" ds_read_b32 v78, v4 offset:43872 \n"
" ds_read_b32 v79, v4 offset:48224 \n"
" s_waitcnt lgkmcnt(0) \n"
" s_mov_b64 exec, s[20:21] \n"
" s_mov_b64 exec, %[s_execflag_0] \n"
" global_atomic_pk_add_bf16 %[v_os_o0], v64, [%[s_res_o0],%[s_res_o1]] \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, s[20:21] \n"
" s_mov_b64 exec, %[s_execflag_0] \n"
" global_atomic_pk_add_bf16 %[v_os_o0], v65, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, s[22:23] \n"
" s_mov_b64 exec, %[s_execflag_1] \n"
" global_atomic_pk_add_bf16 %[v_os_o1], v66, [%[s_res_o0],%[s_res_o1]] \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, s[22:23] \n"
" s_mov_b64 exec, %[s_execflag_1] \n"
" global_atomic_pk_add_bf16 %[v_os_o1], v67, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, s[24:25] \n"
" s_mov_b64 exec, %[s_execflag_2] \n"
" global_atomic_pk_add_bf16 %[v_os_o2], v68, [%[s_res_o0],%[s_res_o1]] \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, s[24:25] \n"
" s_mov_b64 exec, %[s_execflag_2] \n"
" global_atomic_pk_add_bf16 %[v_os_o2], v69, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, s[26:27] \n"
" s_mov_b64 exec, %[s_execflag_3] \n"
" global_atomic_pk_add_bf16 %[v_os_o3], v70, [%[s_res_o0],%[s_res_o1]] \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, s[26:27] \n"
" s_mov_b64 exec, %[s_execflag_3] \n"
" global_atomic_pk_add_bf16 %[v_os_o3], v71, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, s[28:29] \n"
" s_mov_b64 exec, %[s_execflag_4] \n"
" global_atomic_pk_add_bf16 %[v_os_o4], v72, [%[s_res_o0],%[s_res_o1]] \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, s[28:29] \n"
" s_mov_b64 exec, %[s_execflag_4] \n"
" global_atomic_pk_add_bf16 %[v_os_o4], v73, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, s[30:31] \n"
" s_mov_b64 exec, %[s_execflag_5] \n"
" global_atomic_pk_add_bf16 %[v_os_o5], v74, [%[s_res_o0],%[s_res_o1]] \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, s[30:31] \n"
" s_mov_b64 exec, %[s_execflag_5] \n"
" global_atomic_pk_add_bf16 %[v_os_o5], v75, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, s[32:33] \n"
" s_mov_b64 exec, %[s_execflag_6] \n"
" global_atomic_pk_add_bf16 %[v_os_o6], v76, [%[s_res_o0],%[s_res_o1]] \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, s[32:33] \n"
" s_mov_b64 exec, %[s_execflag_6] \n"
" global_atomic_pk_add_bf16 %[v_os_o6], v77, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, s[34:35] \n"
" s_mov_b64 exec, %[s_execflag_7] \n"
" global_atomic_pk_add_bf16 %[v_os_o7], v78, [%[s_res_o0],%[s_res_o1]] \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, s[34:35] \n"
" s_mov_b64 exec, %[s_execflag_7] \n"
" global_atomic_pk_add_bf16 %[v_os_o7], v79, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n"
" s_mov_b64 exec, s[36:37] \n"
" s_add_u32 %[s_res_o0], s59, %[s_res_o0] \n"
......@@ -975,52 +977,52 @@
" ds_read_b32 v78, v4 offset:43872 \n"
" ds_read_b32 v79, v4 offset:48224 \n"
" s_waitcnt lgkmcnt(0) \n"
" s_mov_b64 exec, s[20:21] \n"
" s_mov_b64 exec, %[s_execflag_0] \n"
" global_atomic_pk_add_bf16 %[v_os_o0], v64, [%[s_res_o0],%[s_res_o1]] \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, s[20:21] \n"
" s_mov_b64 exec, %[s_execflag_0] \n"
" global_atomic_pk_add_bf16 %[v_os_o0], v65, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, s[22:23] \n"
" s_mov_b64 exec, %[s_execflag_1] \n"
" global_atomic_pk_add_bf16 %[v_os_o1], v66, [%[s_res_o0],%[s_res_o1]] \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, s[22:23] \n"
" s_mov_b64 exec, %[s_execflag_1] \n"
" global_atomic_pk_add_bf16 %[v_os_o1], v67, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, s[24:25] \n"
" s_mov_b64 exec, %[s_execflag_2] \n"
" global_atomic_pk_add_bf16 %[v_os_o2], v68, [%[s_res_o0],%[s_res_o1]] \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, s[24:25] \n"
" s_mov_b64 exec, %[s_execflag_2] \n"
" global_atomic_pk_add_bf16 %[v_os_o2], v69, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, s[26:27] \n"
" s_mov_b64 exec, %[s_execflag_3] \n"
" global_atomic_pk_add_bf16 %[v_os_o3], v70, [%[s_res_o0],%[s_res_o1]] \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, s[26:27] \n"
" s_mov_b64 exec, %[s_execflag_3] \n"
" global_atomic_pk_add_bf16 %[v_os_o3], v71, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, s[28:29] \n"
" s_mov_b64 exec, %[s_execflag_4] \n"
" global_atomic_pk_add_bf16 %[v_os_o4], v72, [%[s_res_o0],%[s_res_o1]] \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, s[28:29] \n"
" s_mov_b64 exec, %[s_execflag_4] \n"
" global_atomic_pk_add_bf16 %[v_os_o4], v73, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, s[30:31] \n"
" s_mov_b64 exec, %[s_execflag_5] \n"
" global_atomic_pk_add_bf16 %[v_os_o5], v74, [%[s_res_o0],%[s_res_o1]] \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, s[30:31] \n"
" s_mov_b64 exec, %[s_execflag_5] \n"
" global_atomic_pk_add_bf16 %[v_os_o5], v75, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, s[32:33] \n"
" s_mov_b64 exec, %[s_execflag_6] \n"
" global_atomic_pk_add_bf16 %[v_os_o6], v76, [%[s_res_o0],%[s_res_o1]] \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, s[32:33] \n"
" s_mov_b64 exec, %[s_execflag_6] \n"
" global_atomic_pk_add_bf16 %[v_os_o6], v77, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, s[34:35] \n"
" s_mov_b64 exec, %[s_execflag_7] \n"
" global_atomic_pk_add_bf16 %[v_os_o7], v78, [%[s_res_o0],%[s_res_o1]] \n"
" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, s[34:35] \n"
" s_mov_b64 exec, %[s_execflag_7] \n"
" global_atomic_pk_add_bf16 %[v_os_o7], v79, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n"
" s_mov_b64 exec, s[36:37] \n"
" s_add_u32 %[s_res_o0], s59, %[s_res_o0] \n"
......@@ -1038,5 +1040,3 @@
......@@ -29,6 +29,8 @@
"s_mov_b32 s27, %[s_res_b3] \n"
"s_mov_b32 s16, %[s_res_dq0] \n"
"s_mov_b32 s17, %[s_res_dq1] \n"
"s_mov_b32 s18, %[s_res_dq2] \n"
"s_mov_b32 s19, %[s_res_dq3] \n"
"s_mov_b32 s12, %[s_res_d0] \n"
"s_mov_b32 s13, %[s_res_d1] \n"
"s_mov_b32 s14, %[s_res_d2] \n"
......@@ -584,3 +586,4 @@ _DEQUAN_CVT_("%[c60]","%[c61]","%[c62]","%[c63]","%[a_scale1]"," %[gq_scale1]","
#undef _DEQUAN_CVT_
......@@ -276,10 +276,10 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
number<row_ids_a.size()>{});
auto a_coords = generate_tuple(
[&](auto i) {
return ((row_ids_a[i])&0xffffff) * kargs.stride_token +
return (token_id[i]) * kargs.stride_token +
threadIdx.x % (BlockShape::Block_K0 / kAlignmentA) * kAlignmentA;
},
number<row_ids_a.size()>{});
number<token_id.size()>{});
auto a_res =
make_wave_buffer_resource(reinterpret_cast<const ADataType*>(kargs.a_ptr),
kargs.num_tokens * kargs.stride_token * sizeof(ADataType));
......@@ -407,7 +407,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
gqsmq_coords, (reinterpret_cast<const YSmoothScaleDataType*>(kargs.y_smooth_scale_ptr) + static_cast<long_index_t>(expert_id) * shared_intermediate_size_0));
auto uk_0 = Policy::template GetUK_0<Problem>();
// auto acc_0= uk_0(
uk_0( a_scale,
uk_0(a_scale,
gq_scale,
d_res,
dq_res,
......@@ -420,7 +420,12 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
BlockShape::Block_K0, // tile offset for B matrix each unroll
BlockShape::Block_Kr0 *
BlockShape::Block_W0); // tile offset for B matrix each unroll
if(hipBlockIdx_x == 0 && hipBlockIdx_y == 0 && hipBlockIdx_z == 0 &&
hipThreadIdx_x == 5)
{
printf("\ngemm0 done\n");
}
// sweep_tile(
// acc_0,
// [&](auto idx0, auto idx1) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment