Commit a759277d authored by shengnxu's avatar shengnxu
Browse files

fix some error

parent f549173b
......@@ -97,14 +97,14 @@ auto create_args(int argc, char* argv[])
.insert("tp", "8", "tensor parallel size")
.insert("v", "1", "cpu validation or not")
.insert("kname", "1", "print kernel name or not")
.insert("prec_i", "bf16", "input precision")
.insert("prec_w", "bf16", "weight precision")
.insert("prec_i", "int8", "input precision")
.insert("prec_w", "int8", "weight precision")
.insert("prec_o", "bf16", "output precision")
.insert("prec_st", "auto", "token scale data type. auto will set to fp32")
.insert("prec_sw", "auto", "weight scale data type. auto will set to fp32")
.insert("prec_sq", "auto", "(dynamic) smooth quant data type. auto will set to fp32")
.insert("prec_kw", "auto", "topk-weight data type. auto will set to fp32")
.insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant")
.insert("fquant", "1", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant")
.insert(
"gate_only", "1", "w0(gate/up) style, 0:gate+up will double interm size, 1:only gate")
.insert("api", "0", "benchmark api set: 0:fused-moe(moe-gemm+moe-sorting), 1:moe-gemm")
......@@ -218,10 +218,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<GDataType> g_host({experts, shared_intermediate_size_0, hidden_size});
ck_tile::HostTensor<DDataType> d_host({experts, hidden_size, shared_intermediate_size_1});
ck_tile::HostTensor<ODataType> o_host({tokens, hidden_size}, {stride, 1});
ck_tile::HostTensor<AScaleDataType> sa_host({tokens});
ck_tile::HostTensor<GScaleDataType> sg_host({shared_intermediate_size_0});
ck_tile::HostTensor<DScaleDataType> sd_host({shared_intermediate_size_1});
ck_tile::HostTensor<YSmoothScaleDataType> sy_host({shared_intermediate_size_1}); // smooth-quant
if (fused_quant == 1)
{
ck_tile::HostTensor<AScaleDataType> sa_host({tokens, topk});
} else{
ck_tile::HostTensor<AScaleDataType> sa_host({tokens});
}
ck_tile::HostTensor<GScaleDataType> sg_host({experts, shared_intermediate_size_0});
ck_tile::HostTensor<DScaleDataType> sd_host({experts, shared_intermediate_size_1});
ck_tile::HostTensor<YSmoothScaleDataType> sy_host({experts, shared_intermediate_size_1}); // smooth-quant
ck_tile::HostTensor<IndexDataType> topk_ids_host({tokens, topk}); // to be sort
ck_tile::HostTensor<TopkWeightDataType> topk_weight_host({tokens, topk}); // to be sort
......@@ -440,7 +445,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
hidden_size,
shared_intermediate_size_0,
topk,
gate_only);
gate_only,
fused_quant);
auto o_dev = o_buf.ToHost<ODataType>();
// o_dev.savetxt("gpu-out.txt", "float");
......
......@@ -75,7 +75,8 @@ void reference_fused_moe(
ck_tile::index_t hidden_size,
ck_tile::index_t intermediate_size, // this size is for gate/up
ck_tile::index_t topk,
ck_tile::index_t gate_only)
ck_tile::index_t gate_only,
ck_tile::index_t fquant)
{
assert(sorted_token_ids_host.get_num_of_dimension() == 1);
assert(sorted_weight_host.get_num_of_dimension() == 1);
......@@ -106,22 +107,40 @@ void reference_fused_moe(
return;
ck_tile::index_t i_expert = sorted_expert_ids_host.mData[i_tile];
ck_tile::index_t i_token = sorted_token_ids_host.mData[i_flatten];
if(i_token >= tokens)
ck_tile::index_t i_weight_idx;
if(fquant == 1)
{
i_weight_idx = i_token >> 24;
i_token = i_token & 0xffffff;
}
if (i_token >= tokens)
return;
ck_tile::index_t i_topk = get_topk_id(i_token, i_expert); // TODO: ugly
auto weight = sorted_weight_host.mData[i_flatten];
auto weight = sorted_weight_host.mData[i_flatten];//top k ratio?
ck_tile::HostTensor<AccDataType> acc_0({1, intermediate_size_0});
ck_tile::HostTensor<float> acc_0({1, intermediate_size_0});
// first gemm
for(ck_tile::index_t i_n = 0; i_n < intermediate_size_0; i_n++)
{
AccDataType acc = static_cast<AccDataType>(0);
for(ck_tile::index_t i_k = 0; i_k < hidden_size; i_k++)
{
acc += type_convert<AccDataType>(a_host(i_token, i_k)) *
type_convert<AccDataType>(g_host(i_expert, i_n, i_k));
acc += type_convert<float>(a_host(i_token, i_k)) *
type_convert<float>(g_host(i_expert, i_n, i_k));
}
if (fquant == 1)
{ //smooth
acc_0(0, i_n) = acc * sa_host(i_token, i_weight_idx) * sg_host(i_expert, i_n);
} else if( fquant == 2 )
{
//dynamic
acc_0(0, i_n) = acc * sa_host(i_token) * sg_host(i_expert, i_n);
}
else
{
//no quant
acc_0(0, i_n) = acc;
}
acc_0(0, i_n) = acc;
// printf("ie:%2d, it:%3d, in:%d, %f\n", i_expert, i_token, i_n, acc);
}
......@@ -158,10 +177,14 @@ void reference_fused_moe(
{
AccDataType acc = static_cast<AccDataType>(0);
for(ck_tile::index_t i_k = 0; i_k < intermediate_size_1; i_k++)
{
acc += y(0, i_k) * type_convert<AccDataType>(d_host(i_expert, i_n, i_k));
{ if (fquant == 1)
{
acc += y(0, i_k) * sy_host(i_expert, i_k)* type_convert<float>(d_host(i_expert, i_n, i_k));
} else {
acc += y(0, i_k) * type_convert<float>(d_host(i_expert, i_n, i_k));
}
}
acc_1(0, i_n) = acc * weight; // multiple weight here
acc_1(0, i_n) = acc * type_convert<float>(weight); // multiple weight here
}
for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
......@@ -177,7 +200,7 @@ void reference_fused_moe(
auto r = [&](auto i_token) {
for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
{
AccDataType acc = type_convert<AccDataType>(0);
AccDataType acc = type_convert<float>(0);
for(ck_tile::index_t i_topk = 0; i_topk < topk; i_topk++)
{
acc += out_topk_tokens(i_token, i_topk, i_n);
......
......@@ -4,7 +4,9 @@
#pragma once
#include "ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_32x512x256_1x4x1_16x16x64_int8.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x256x512_1x4x1_16x16x64_int8.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
......
# define _DEQUAN_CVT_(a, b, c) \
" v_cvt_f32_i32 a[0], a[0] \n" \
" v_cvt_f32_i32 a[1], a[1] \n" \
" v_cvt_f32_i32 a[2], a[2] \n" \
" v_cvt_f32_i32 a[3], a[3] \n" \
" v_mul_f32 a[0], v15, a[0] \n" \
" v_mul_f32 a[1], v15, a[1] \n" \
" v_mul_f32 a[2], v15, a[2] \n" \
" v_mul_f32 a[3], v15, a[3] \n" \
" v_mul_f32 a[0], v17, a[0] row_newbcast:12 \n" \
" v_mul_f32 a[1], v17, a[1] row_newbcast:13 \n" \
" v_mul_f32 a[2], v17, a[2] row_newbcast:14 \n" \
" v_mul_f32 a[3], v17, a[3] row_newbcast:15 \n" \
#ifndef CK_TILE_FLATMM_UK_MFMA
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_INT8
#endif
#if CK_TILE_FLATMM_UK_MFMA == CK_TILE_FLATMM_UK_MFMA_INT8
# define _UK_MFMA_ "v_mfma_i32_16x16x32_i8"
#endif
# define _DEQUAN_CVT_(a0,a1,a2,a3, b, c) \
" v_cvt_f32_i32 a0, a0 \n" \
" v_cvt_f32_i32 a1, a1 \n" \
" v_cvt_f32_i32 a2, a2 \n" \
" v_cvt_f32_i32 a3, a3 \n" \
" v_mul_f32 a0, v15, a0 \n" \
" v_mul_f32 a1, v15, a1 \n" \
" v_mul_f32 a2, v15, a2 \n" \
" v_mul_f32 a3, v15, a3 \n" \
" v_mul_f32 a0, v17, a0 row_newbcast:12 \n" \
" v_mul_f32 a1, v17, a1 row_newbcast:13 \n" \
" v_mul_f32 a2, v17, a2 row_newbcast:14 \n" \
" v_mul_f32 a3, v17, a3 row_newbcast:15 \n" \
";-------------------------------\n"
"s_mov_b32 s28, %[s_res_aq0] \n"
"s_mov_b32 s29, %[s_res_aq1] \n"
"s_mov_b32 s30, %[s_res_aq2] \n"
"s_mov_b32 s31, %[s_res_aq3] \n"
"s_mov_b32 s16, %[s_res_dq0] \n"
"s_mov_b32 s17, %[s_res_dq1] \n"
"s_mov_b32 s18, %[s_res_dq2] \n"
......@@ -32,19 +43,7 @@
"s_mov_b32 s25, %[s_res_b1] \n"
"s_mov_b32 s26, %[s_res_b2] \n"
"s_mov_b32 s27, %[s_res_b3] \n"
//////////GQ/DQ/GsmQ_addr///////////////
//expert weight addr no need
// s_mul_i32 s60, s3, 32 // 00000000056C: 923CA003 s3 s_tg_idy
// s_mul_i32 s60, 4, s60 // 000000000570: 923C3C84
// s_add_u32 s40, s60, s40 // 000000000574: 8028283C s40 sw_ptr
// s_addc_u32 s41, 0, s41 // 000000000578: 82292980 s41 sw_ptr
// v_and_b32 v54, 15, v0 // 00000000057C: 266C008F
// v_lshlrev_b32 v8, 2, v54 // 000000000580: 24106C82 v8/9 w addr
// v_add_u32 v9, 64, v8 // 000000000584: 681210C0
//GQDQ addr function kkkkkkkkkkkkkk
";---------------------------------------------- \n"
" v_lshrrev_b32 v54, 4, v0 \n"
" v_lshlrev_b32 v55, 2, v54 \n"
" v_and_b32 v54, 15, v0 \n"
......@@ -55,21 +54,17 @@
" v_add_u32 v55, v54, v55 \n"
" v_lshlrev_b32 v10, 2, v55 \n"
" v_add_u32 v11, 0x00000400, v10 \n"
" s_mul_i32 s60, %[s_wave_id], 16 \n"
" s_mul_i32 s60, %[s_wave_id], 16 \n"
" s_mul_i32 s60, s60, 4 \n"
" v_add_u32 v10, s60, v10 \n"
" v_add_u32 v11, s60, v11 \n"
" v_mov_b32 v5, v10 \n"
//////////////////////////////
";---------------------------------------------- \n"
" s_mov_b32 s57, 0x00000100 \n"
" s_mov_b32 s58, 0x00001000 \n"
" s_mov_b32 s79, 0x00000400 \n"
" s_mov_b32 s59, 0x00000200 \n"
////////
//" s_mul_i32 s60, s70, 0x00000100 \n"
//" s_sub_u32 s56, s60, 0x00001000 \n"
///////////////
";---------------------------------------------- \n"
" s_mov_b32 s78, 0x00001000 \n"
" s_mov_b32 s52, 0x07060302 \n"
" s_mov_b32 s53, 0x00000400 \n"
......@@ -82,7 +77,7 @@
" v_mov_b32 v52, 0x7fff0000 \n"
" v_mov_b32 v53, 0x00007fff \n"
" s_waitcnt 0x0000 \n"
///XQ ADDR, fake token id
";---------------------------------------------- \n"
" v_mov_b32 %[v_token_id], %[v_token_id] \n"
" v_lshrrev_b32 v54, 24, %[v_token_id] \n"
" v_mul_i32_i24 v54, s66, v54 \n"
......@@ -104,8 +99,7 @@
" buffer_load_dword v21, v9, s[40:43], 0 offen \n"
" s_mov_b32 s80, 0 \n"
//---------------------v26-33 no need
// "s_nop 4\n"
";---------------------------------------------- \n"
"; -- prefetch A0\n"
"s_add_u32 m0, 0, %[s_m0_init] \n"
"buffer_load_dword %[v_os_a0], s[20:23], 0 offen lds \n"
......@@ -183,18 +177,17 @@
" s_waitcnt vmcnt(40) \n"
" s_barrier \n"
///////////////////////////////
"ds_read_b128 v[192:195], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_0]\n" // 1024: N stride, 64 K stride
"ds_read_b128 v[196:199], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_1]\n"
"ds_read_b128 v[200:203], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_2]\n"
"ds_read_b128 v[204:207], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_3]\n"
"ds_read_b128 v[208:211], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_4]\n"
"ds_read_b128 v[212:215], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_5]\n"
"ds_read_b128 v[216:219], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_6]\n"
"ds_read_b128 v[220:223], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_7]\n"
////////////////////////////
"label_start:
";---------------------------------------------- \n"
"ds_read_b128 v[192:195], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_0]\n" // 1024: N stride, 64 K stride
"ds_read_b128 v[196:199], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_1]\n"
"ds_read_b128 v[200:203], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_2]\n"
"ds_read_b128 v[204:207], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_3]\n"
"ds_read_b128 v[208:211], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_4]\n"
"ds_read_b128 v[212:215], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_5]\n"
"ds_read_b128 v[216:219], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_6]\n"
"ds_read_b128 v[220:223], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_7]\n"
";---------------------------------------------- \n"
" label_start: \n"
" s_waitcnt vmcnt(24) & lgkmcnt(0) \n"
" s_barrier \n"
_UK_MFMA_ " v[128:131], acc[0:1], v[192:193], v[128:131] \n"
......@@ -400,7 +393,7 @@
" s_waitcnt vmcnt(24) & lgkmcnt(0) \n"
" s_barrier \n"
_UK_MFMA_ " v[128:131], acc[128:129], v[224:225], v[128:131] \n"
_UK_MFMA_ " v[128:131], acc[130:131], v[226:227], v[128:131] \n"
_UK_MFMA_ " v[128:131], acc[130:131], v[226:227], v[128:131] \n"
" buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[24:27], 0 offen \n"
_UK_MFMA_ " v[128:131], acc[132:133], v[228:229], v[128:131] \n"
_UK_MFMA_ " v[128:131], acc[134:135], v[230:231], v[128:131] \n"
......@@ -461,49 +454,49 @@
" buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[24:27], 0 offen \n"
_UK_MFMA_ " v[144:147], acc[164:165], v[228:229], v[144:147] \n"
_UK_MFMA_ " v[144:147], acc[166:167], v[230:231], v[144:147] \n"
" ds_read_b128 v[192:195], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_0]
" ds_read_b128 v[192:195], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_0] \n"
_UK_MFMA_ " v[144:147], acc[168:169], v[232:233], v[144:147] \n"
_UK_MFMA_ " v[144:147], acc[170:171], v[234:235], v[144:147] \n"
" buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[24:27], 0 offen offset:1024 \n"
_UK_MFMA_ " v[144:147], acc[172:173], v[236:237], v[144:147] \n"
_UK_MFMA_ " v[144:147], acc[174:175], v[238:239], v[144:147] \n"
" ds_read_b128 v[196:199], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_1]
" ds_read_b128 v[196:199], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_1] \n"
_UK_MFMA_ " v[148:151], acc[160:161], v[240:241], v[148:151] \n"
_UK_MFMA_ " v[148:151], acc[162:163], v[242:243], v[148:151] \n"
" buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[24:27], 0 offen offset:2048 \n"
_UK_MFMA_ " v[148:151], acc[164:165], v[244:245], v[148:151] \n"
_UK_MFMA_ " v[148:151], acc[166:167], v[246:247], v[148:151] \n"
" ds_read_b128 v[200:203], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_2]
" ds_read_b128 v[200:203], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_2] \n"
_UK_MFMA_ " v[148:151], acc[168:169], v[248:249], v[148:151] \n"
_UK_MFMA_ " v[148:151], acc[170:171], v[250:251], v[148:151] \n"
" buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[24:27], 0 offen offset:3072 \n"
_UK_MFMA_ " v[148:151], acc[172:173], v[252:253], v[148:151] \n"
_UK_MFMA_ " v[148:151], acc[174:175], v[254:255], v[148:151] \n"
" ds_read_b128 v[204:207], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_3]
" ds_read_b128 v[204:207], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_3] \n"
_UK_MFMA_ " v[152:155], acc[176:177], v[224:225], v[152:155] \n"
_UK_MFMA_ " v[152:155], acc[178:179], v[226:227], v[152:155] \n"
" buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[24:27], 0 offen \n"
_UK_MFMA_ " v[152:155], acc[180:181], v[228:229], v[152:155] \n"
_UK_MFMA_ " v[152:155], acc[182:183], v[230:231], v[152:155] \n"
" ds_read_b128 v[208:211], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_4]
" ds_read_b128 v[208:211], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_4] \n"
_UK_MFMA_ " v[152:155], acc[184:185], v[232:233], v[152:155] \n"
_UK_MFMA_ " v[152:155], acc[186:187], v[234:235], v[152:155] \n"
" buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[24:27], 0 offen offset:1024 \n"
_UK_MFMA_ " v[152:155], acc[188:189], v[236:237], v[152:155] \n"
_UK_MFMA_ " v[152:155], acc[190:191], v[238:239], v[152:155] \n"
" ds_read_b128 v[212:215], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_5]
" ds_read_b128 v[212:215], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_5] \n"
_UK_MFMA_ " v[156:159], acc[176:177], v[240:241], v[156:159] \n"
_UK_MFMA_ " v[156:159], acc[178:179], v[242:243], v[156:159] \n"
" buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[24:27], 0 offen offset:2048 \n"
_UK_MFMA_ " v[156:159], acc[180:181], v[244:245], v[156:159] \n"
_UK_MFMA_ " v[156:159], acc[182:183], v[246:247], v[156:159] \n"
" ds_read_b128 v[216:219], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_6]
" ds_read_b128 v[216:219], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_6] \n"
_UK_MFMA_ " v[156:159], acc[184:185], v[248:249], v[156:159] \n"
_UK_MFMA_ " v[156:159], acc[186:187], v[250:251], v[156:159] \n"
" buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[24:27], 0 offen offset:3072 \n"
_UK_MFMA_ " v[156:159], acc[188:189], v[252:253], v[156:159] \n"
_UK_MFMA_ " v[156:159], acc[190:191], v[254:255], v[156:159] \n"
" ds_read_b128 v[220:223], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_7]
" ds_read_b128 v[220:223], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_7] \n"
" s_waitcnt vmcnt(32) \n"
_UK_MFMA_ " v[160:163], acc[192:193], v[224:225], v[160:163] \n"
_UK_MFMA_ " v[160:163], acc[194:195], v[226:227], v[160:163] \n"
......@@ -601,7 +594,7 @@
" s_cbranch_scc0 label_end \n"
" s_branch label_start%= \n"
" label_end : \n"
//dequant
";---------------------------------------------- \n"
" v_cvt_f32_i32 v128, v128 \n"
" v_cvt_f32_i32 v129, v129 \n"
" v_cvt_f32_i32 v130, v130 \n"
......@@ -794,7 +787,7 @@
" v_mul_f32 v189, v17, v189 row_newbcast:13 \n"
" v_mul_f32 v190, v17, v190 row_newbcast:14 \n"
" v_mul_f32 v191, v17, v191 row_newbcast:15 \n"
#undef _UK_MFMA_
//dequant end
#undef _UK_MFMA_
#undef _DEQUAN_CVT_
......@@ -10,6 +10,7 @@
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk_int8.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp"
......
......@@ -198,7 +198,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
//addr in fact
auto a_coords = generate_tuple(
[&](auto i) {
return (token_id) * kargs.stride_token +
return (token_id[i]) * kargs.stride_token +
threadIdx.x % (BlockShape::Block_K0 / kAlignmentA) * kAlignmentA;
},
number<row_ids_a.size()>{});
......@@ -254,7 +254,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
make_tuple(shared_intermediate_size_1),
number<1>{});
return g_view_;
return gq_view_;
}();
auto gq_res = gq_win.get_buffer_view().cached_buf_res_;
......@@ -345,7 +345,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
auto o_coords = generate_tuple(
[&](auto i) {
return token_id * kargs.stride_token +
return token_id[i] * kargs.stride_token +
threadIdx.x % (BlockShape::Block_N1 / kAlignmentO) * kAlignmentO;
},
number<row_ids_a.size()>{});
......@@ -376,6 +376,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
row_ids_a,//fake token id, 2D index for X scale
aq_res,
gq_res,
gq_res,
dq_res,
a_res,
a_coords,
......
......@@ -143,7 +143,7 @@ using WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution =
// int8
using WarpGemmMfma_i32_16x16x64_int8_int8_CTransposed =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImpl_i32_16x16x32_i8<WGAttrCtlEnum::Default_>
WarpGemmAttributeMfmaImpl_i32_16x16x32_i8<WGAttrCtlEnum::Default_>,
2>>;
} // namespace ck_tile
......@@ -655,7 +655,7 @@ struct WarpGemmAttributeMfmaImpl_i32_16x16x32_i8
else
{
#if defined(__gfx94__)
c_vec = __builtin_amdgcn_mfma_i32_16x16x32i8(
c_vec = __builtin_amdgcn_mfma_i32_16x16x32_i8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
#elif defined(__gfx908__) || defined(__gfx90a__)
static_for<0, 8, 1>{}([&](auto k) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment