Commit 5d00b37e authored by shengnxu
Browse files

fix loop cnt and half d buffer size

parent 2a66e080
...@@ -72,7 +72,7 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_Base ...@@ -72,7 +72,7 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_Base
struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x1_16x16x64_Base struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x1_16x16x64_Base
{ {
using BDataType = int8_t; using BDataType = int8_t;
using ODataType = int8_t; using ODataType = bf16_t;
using DScaleDataType = float_t; using DScaleDataType = float_t;
// TODO: need paired with tile_window_linear! // TODO: need paired with tile_window_linear!
......
...@@ -205,7 +205,7 @@ ...@@ -205,7 +205,7 @@
" v_mfma_i32_16x16x32_i8 v[220:223], acc[124:125], v[188:189], v[220:223] \n" " v_mfma_i32_16x16x32_i8 v[220:223], acc[124:125], v[188:189], v[220:223] \n"
" v_mfma_i32_16x16x32_i8 v[220:223], acc[126:127], v[190:191], v[220:223] \n" " v_mfma_i32_16x16x32_i8 v[220:223], acc[126:127], v[190:191], v[220:223] \n"
" s_add_u32 s60, 0x00000200, s80 \n" " s_add_u32 s60, 0x00000200, s80 \n"
" s_cmp_lt_u32 s60, s81 \n" " s_cmp_lt_u32 s60, %[s_loop_cnt] \n"
" s_cselect_b32 %[s_tile_os_b], %[s_tile_os_b], 0 \n" " s_cselect_b32 %[s_tile_os_b], %[s_tile_os_b], 0 \n"
" s_cselect_b32 %[s_tile_os_b_half], %[s_tile_os_b_half], 0 \n" " s_cselect_b32 %[s_tile_os_b_half], %[s_tile_os_b_half], 0 \n"
" s_cselect_b32 %[s_tile_os_dq], %[s_tile_os_dq], 0 \n" " s_cselect_b32 %[s_tile_os_dq], %[s_tile_os_dq], 0 \n"
...@@ -528,10 +528,10 @@ ...@@ -528,10 +528,10 @@
" s_mov_b64 exec, %[s_execflag_7] \n" " s_mov_b64 exec, %[s_execflag_7] \n"
" global_atomic_pk_add_bf16 %[v_os_o7], v79, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n" " global_atomic_pk_add_bf16 %[v_os_o7], v79, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n"
" s_mov_b64 exec, s[36:37] \n" " s_mov_b64 exec, s[36:37] \n"
" s_add_u32 %[s_res_o0], s59, %[s_res_o0] \n" " s_add_u32 %[s_res_o0], %[s_tile_os_o], %[s_res_o0] \n"
" s_addc_u32 %[s_res_o1], 0, %[s_res_o1] \n" " s_addc_u32 %[s_res_o1], 0, %[s_res_o1] \n"
" s_addk_i32 s80, 0x0100 \n" " s_addk_i32 s80, 0x0100 \n"
" s_cmp_lt_i32 s80, s81 \n" " s_cmp_lt_i32 s80, %[s_loop_cnt] \n"
" s_cbranch_scc0 label_end_gemm2 \n" " s_cbranch_scc0 label_end_gemm2 \n"
" s_waitcnt vmcnt(41) \n" " s_waitcnt vmcnt(41) \n"
" s_barrier \n" " s_barrier \n"
...@@ -702,7 +702,7 @@ ...@@ -702,7 +702,7 @@
" v_mfma_i32_16x16x32_i8 v[252:255], acc[252:253], v[188:189], v[252:255] \n" " v_mfma_i32_16x16x32_i8 v[252:255], acc[252:253], v[188:189], v[252:255] \n"
" v_mfma_i32_16x16x32_i8 v[252:255], acc[254:255], v[190:191], v[252:255] \n" " v_mfma_i32_16x16x32_i8 v[252:255], acc[254:255], v[190:191], v[252:255] \n"
" s_add_u32 s60, 0x00000200, s80 \n" " s_add_u32 s60, 0x00000200, s80 \n"
" s_cmp_lt_u32 s60, s81 \n" " s_cmp_lt_u32 s60, %[s_loop_cnt] \n"
" s_cselect_b32 %[s_tile_os_b], %[s_tile_os_b], 0 \n" " s_cselect_b32 %[s_tile_os_b], %[s_tile_os_b], 0 \n"
" s_cselect_b32 %[s_tile_os_b_half], %[s_tile_os_b_half], 0 \n" " s_cselect_b32 %[s_tile_os_b_half], %[s_tile_os_b_half], 0 \n"
" s_cselect_b32 %[s_tile_os_dq], %[s_tile_os_dq], 0 \n" " s_cselect_b32 %[s_tile_os_dq], %[s_tile_os_dq], 0 \n"
...@@ -1025,10 +1025,10 @@ ...@@ -1025,10 +1025,10 @@
" s_mov_b64 exec, %[s_execflag_7] \n" " s_mov_b64 exec, %[s_execflag_7] \n"
" global_atomic_pk_add_bf16 %[v_os_o7], v79, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n" " global_atomic_pk_add_bf16 %[v_os_o7], v79, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n"
" s_mov_b64 exec, s[36:37] \n" " s_mov_b64 exec, s[36:37] \n"
" s_add_u32 %[s_res_o0], s59, %[s_res_o0] \n" " s_add_u32 %[s_res_o0], %[s_tile_os_o], %[s_res_o0] \n"
" s_addc_u32 %[s_res_o1], 0, %[s_res_o1] \n" " s_addc_u32 %[s_res_o1], 0, %[s_res_o1] \n"
" s_addk_i32 s80, 0x0100 \n" " s_addk_i32 s80, 0x0100 \n"
" s_cmp_lt_i32 s80, s81 \n" " s_cmp_lt_i32 s80, %[s_loop_cnt] \n"
" s_cbranch_scc0 label_end_gemm2 \n" " s_cbranch_scc0 label_end_gemm2 \n"
" s_branch label_startgemm2 \n" " s_branch label_startgemm2 \n"
" label_end_gemm2: \n" " label_end_gemm2: \n"
...@@ -1037,6 +1037,3 @@ ...@@ -1037,6 +1037,3 @@
#undef _UK_MFMA_ #undef _UK_MFMA_
#undef _UK_PK_CVT_ #undef _UK_PK_CVT_
#undef _UK_ATOMIC_ADD_ #undef _UK_ATOMIC_ADD_
...@@ -372,8 +372,8 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8 ...@@ -372,8 +372,8 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
Nl_; // Kr0_ * Kr1_ * W_; Nl_; // Kr0_ * Kr1_ * W_;
return generate_tuple( return generate_tuple(
[&](auto i) { [&](auto i) {
constexpr auto i_nr_ = number<i % Nr_>{}; // constexpr auto i_nr_ = number<i % Nr_>{};
return i_nr_ * shared_intermediate_size_1 * Nw_ * Nl_ + return i * shared_intermediate_size_1 * Nw_ * Nl_ +
base_os_; base_os_;
}, },
number<num_offsets_>{}); number<num_offsets_>{});
...@@ -382,7 +382,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8 ...@@ -382,7 +382,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
auto o_coords = generate_tuple( auto o_coords = generate_tuple(
[&](auto i) { [&](auto i) {
return token_id[i] * kargs.stride_token + return token_id[i] * kargs.stride_token +
threadIdx.x % (BlockShape::Block_N1 / kAlignmentO) * kAlignmentO; threadIdx.x % (BlockShape::Block_N1/2 / kAlignmentO) * kAlignmentO;
}, },
number<row_ids_a.size()>{}); number<row_ids_a.size()>{});
...@@ -420,11 +420,13 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8 ...@@ -420,11 +420,13 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
BlockShape::Block_K0, // tile offset for B matrix each unroll BlockShape::Block_K0, // tile offset for B matrix each unroll
BlockShape::Block_Kr0 * BlockShape::Block_Kr0 *
BlockShape::Block_W0); // tile offset for B matrix each unroll BlockShape::Block_W0); // tile offset for B matrix each unroll
if(hipBlockIdx_x == 0 && hipBlockIdx_y == 0 && hipBlockIdx_z == 0 && if(hipBlockIdx_x == 1 && hipBlockIdx_y == 1 && hipBlockIdx_z == 0 &&
hipThreadIdx_x == 5) hipThreadIdx_x == 64)
{ {
printf("\ngemm0 done\n"); printf("\ngemm0 done\n");
// printf("\n wg 1 1, wave 1, row_coords_a 0: %x 1: %x, 2: %x, 3:%x, 5: %x 6: %x, 7: %x, 8:%x,, \n", row_coords_a[number<0>{}],row_coords_a[number<1>{}],row_coords_a[number<2>{}],row_coords_a[number<3>{}], row_coords_a[number<4>{}],row_coords_a[number<5>{}],row_coords_a[number<6>{}],row_coords_a[number<7>{}]);
// printf("\n -------------- -row_ids_a 0: %x 1: %x, 2: %x, 3:%x, 5: %x 6: %x, 7: %x, 8:%x,, \n", row_ids_a[number<0>{}],row_ids_a[number<1>{}],row_ids_a[number<2>{}],row_ids_a[number<3>{}], row_ids_a[number<4>{}],row_ids_a[number<5>{}],row_ids_a[number<6>{}],row_ids_a[number<7>{}]);
printf("\n -------------- - token_id 0: %x 1: %x, 2: %x, 3:%x, 5: %x 6: %x, 7: %x, 8:%x,, \n", token_id[number<0>{}],token_id[number<1>{}],token_id[number<2>{}],token_id[number<3>{}], token_id[number<4>{}],token_id[number<5>{}],token_id[number<6>{}],token_id[number<7>{}]);
} }
// sweep_tile( // sweep_tile(
// acc_0, // acc_0,
...@@ -457,8 +459,8 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8 ...@@ -457,8 +459,8 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
w_scale, w_scale,
smq_scale, smq_scale,
BlockShape::Block_N1, BlockShape::Block_N1,
shared_intermediate_size_1 * BlockShape::Block_N1 - kr_1 * BlockShape::Block_W1, // along N shared_intermediate_size_1 * BlockShape::Block_N1 - 256 * 16, // along N
kr_1 * BlockShape::Block_W1, 256 * 16,
BlockShape::Block_N1); // along N BlockShape::Block_N1); // along N
} }
}; };
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment