Commit f549173b authored by shengnxu's avatar shengnxu
Browse files

Simple gemm2 for gemm1 debugging

parent 811b75d3
......@@ -245,12 +245,13 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
// TODO: need paired with tile_window_linear!
// TODO: need call init_raw() before call this function!
template <typename DQRes, typename GQRes,typename ARes, typename ACoords, typename BRes, typename BCoords>
template <typename DQRes, typename GQRes, typename SMQRes, typename ARes, typename ACoords, typename BRes, typename BCoords>
CK_TILE_DEVICE auto
operator()( index_t row_ids_a_,
const DQes& res_aq
const DQes& res_dq,
const GQRes& res_gq,
const SMQRes& res_smq,
const Res& res_a,
const ACoords& cached_coords_a,
const BRes& res_b,
......@@ -405,6 +406,10 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
[s_res_gq1]"s"(res_gq[1]),
[s_res_gq2]"s"(res_gq[2]),
[s_res_gq3]"s"(res_gq[3]),
[s_res_smq0]"s"(res_smq[0]),
[s_res_smq1]"s"(res_smq[1]),
[s_res_smq2]"s"(res_smq[2]),
[s_res_smq3]"s"(res_smq[3]),
[s_res_a0]"s"(res_a[0]),
[s_res_a1]"s"(res_a[1]),
[s_res_a2]"s"(res_a[2]),
......
......@@ -92,6 +92,7 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
CK_TILE_LDS_ADDR void* smem,
index_t n, // loop along n dim
const ScaleTensor& scale_,
index_t tile_offset_dq,
index_t tile_offset_b, // stride b is fixed to blockKr * blockW, but still can adjust
index_t tile_offset_half_b, //splited load alone K in to 2 part
index_t tile_offset_o)
......@@ -102,6 +103,7 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
const index_t tile_stride_b_bytes = tile_offset_b * sizeof(BDataType);
const index_t tile_offset_half_b_bytes = tile_offset_half_b * sizeof(BDataType);
const index_t tile_stride_o_bytes = tile_offset_o * sizeof(ODataType);
const index_t tile_stride_dq_bytes = tile_offset_dq * sizeof(DScaleDataType);
static_assert(ScaleTensor::size() == 2);
float s0 = scale_[number<0>{}];
......@@ -244,6 +246,7 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
[s_tile_os_o]"s"(tile_stride_o_bytes),
[s_tile_os_b_half]"s"(tile_offset_half_b_bytes),
[s_tile_os_b]"s"(tile_stride_b_bytes),
[s_tile_os_dq]"s"(tile_stride_dq_bytes),
[scale_0]"v"(s0),
[scale_1]"v"(s1),
[v_nan_lo]"v"(nan_lo),
......
......@@ -12,7 +12,26 @@
" v_mul_f32 a[2], v17, a[2] row_newbcast:14 \n" \
" v_mul_f32 a[3], v17, a[3] row_newbcast:15 \n" \
"s_mov_b32 s16, %[s_res_dq0] \n"
"s_mov_b32 s17, %[s_res_dq1] \n"
"s_mov_b32 s18, %[s_res_dq2] \n"
"s_mov_b32 s19, %[s_res_dq3] \n"
"s_mov_b32 s32, %[s_res_gq0] \n"
"s_mov_b32 s33, %[s_res_gq1] \n"
"s_mov_b32 s34, %[s_res_gq2] \n"
"s_mov_b32 s35, %[s_res_gq3] \n"
"s_mov_b32 s36, %[s_res_smq0] \n"
"s_mov_b32 s37, %[s_res_smq1] \n"
"s_mov_b32 s38, %[s_res_smq2] \n"
"s_mov_b32 s39, %[s_res_smq3] \n"
"s_mov_b32 s20, %[s_res_a0] \n"
"s_mov_b32 s21, %[s_res_a1] \n"
"s_mov_b32 s22, %[s_res_a2] \n"
"s_mov_b32 s23, %[s_res_a3] \n"
"s_mov_b32 s24, %[s_res_b0] \n"
"s_mov_b32 s25, %[s_res_b1] \n"
"s_mov_b32 s26, %[s_res_b2] \n"
"s_mov_b32 s27, %[s_res_b3] \n"
//////////GQ/DQ/GsmQ_addr///////////////
//expert weight addr no need
......@@ -84,22 +103,6 @@
" buffer_load_dword v20, v8, s[40:43], 0 offen \n"
" buffer_load_dword v21, v9, s[40:43], 0 offen \n"
"s_mov_b32 s16, %[s_res_dq0] \n"
"s_mov_b32 s17, %[s_res_dq1] \n"
"s_mov_b32 s18, %[s_res_dq2] \n"
"s_mov_b32 s19, %[s_res_dq3] \n"
"s_mov_b32 s32, %[s_res_gq0] \n"
"s_mov_b32 s33, %[s_res_gq1] \n"
"s_mov_b32 s34, %[s_res_gq2] \n"
"s_mov_b32 s35, %[s_res_gq3] \n"
"s_mov_b32 s20, %[s_res_a0] \n"
"s_mov_b32 s21, %[s_res_a1] \n"
"s_mov_b32 s22, %[s_res_a2] \n"
"s_mov_b32 s23, %[s_res_a3] \n"
"s_mov_b32 s24, %[s_res_b0] \n"
"s_mov_b32 s25, %[s_res_b1] \n"
"s_mov_b32 s26, %[s_res_b2] \n"
"s_mov_b32 s27, %[s_res_b3] \n"
" s_mov_b32 s80, 0 \n"
//---------------------v26-33 no need
// "s_nop 4\n"
......
......@@ -180,6 +180,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
/////////////
index_t a_scale_expert_stride_0 = kargs.hidden_size;
index_t g_scale_expert_stride_0 = shared_intermediate_size_0;
index_t smq_scale_expert_stride_0 = shared_intermediate_size_0;
index_t d_scale_expert_stride_1 = kargs.hidden_size;
// nr*kr*w
index_t interm_idx_nr0 = __builtin_amdgcn_readfirstlane(
......@@ -244,12 +245,12 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
number<decltype(g_win)::NumAccess_NonLinear>{});
//////gq
auto gq_win = [&]() {
const GDataType* g_ptr = reinterpret_cast<const GScaleDataType*>(kargs.g_scale_ptr) +
const GScaleDataType* gq_ptr = reinterpret_cast<const GScaleDataType*>(kargs.g_scale_ptr) +
static_cast<long_index_t>(expert_id) * g_scale_expert_stride_0 +
intermediate_tile_id * BlockShape::Block_N0;
// const GDataType* g_ptr = reinterpret_cast<const GScaleDataType*>(kargs.g_scale_ptr);//remember to add expert id for inline
auto g_view_ = make_naive_tensor_view<address_space_enum::global>(
g_ptr,
auto gq_view_ = make_naive_tensor_view<address_space_enum::global>(
gq_ptr,
make_tuple(shared_intermediate_size_1),
number<1>{});
......@@ -257,7 +258,22 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
}();
auto gq_res = gq_win.get_buffer_view().cached_buf_res_;
////
////smQ
auto smq_win = [&]() {
const YSmoothScaleDataType* smq_ptr = reinterpret_cast<const YSmoothScaleDataType*>(kargs.y_smooth_scale_ptr) +
static_cast<long_index_t>(expert_id) * smq_scale_expert_stride_0 +
intermediate_tile_id * BlockShape::Block_N0;
// const GDataType* g_ptr = reinterpret_cast<const GScaleDataType*>(kargs.g_scale_ptr);//remember to add expert id for inline
auto smq_view_ = make_naive_tensor_view<address_space_enum::global>(
smq_ptr,
make_tuple(shared_intermediate_size_1),
number<1>{});
return smq_view_;
}();
auto smq_res = smq_win.get_buffer_view().cached_buf_res_;
/////////////////////
const auto d_win = [&]() {
const DDataType* d_ptr = reinterpret_cast<const DDataType*>(kargs.d_ptr) +
static_cast<long_index_t>(expert_id) * expert_stride_1 +
......@@ -284,8 +300,9 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
auto d_res = d_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
//////gq
auto dq_win = [&]() {
// const GDataType* g_ptr = reinterpret_cast<const GScaleDataType*>(kargs.d_scale_ptr) + static_cast<long_index_t>(expert_id) * d_scale_expert_stride_0;
const GDataType* g_ptr = reinterpret_cast<const GScaleDataType*>(kargs.d_scale_ptr)//remember to add expert_id as expert_idx
const DScaleDataType* g_ptr = reinterpret_cast<const DScaleDataType*>(kargs.d_scale_ptr) +
static_cast<long_index_t>(expert_id) * d_scale_expert_stride_1;
// const GDataType* g_ptr = reinterpret_cast<const GScaleDataType*>(kargs.d_scale_ptr)//remember to add expert_id as expert_idx
auto g_view_ = make_naive_tensor_view<address_space_enum::global>(
g_ptr,
make_tuple(kargs.hidden_size),
......@@ -396,6 +413,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
smem,
kargs.hidden_size, // total n number
w_scale,
BlockShape::Block_N1,
shared_intermediate_size_1 * Block_N1 - kr_1 * BlockShape::Block_W1, // along N
kr_1 * BlockShape::Block_W1,
BlockShape::Block_N1); // along N
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment