"tests/pipelines/controlnet/test_controlnet_blip_diffusion.py" did not exist on "24563ca654f6574dae93aeece8eeef69e39097e5"
Commit a759277d authored by shengnxu's avatar shengnxu
Browse files

fix some error

parent f549173b
...@@ -97,14 +97,14 @@ auto create_args(int argc, char* argv[]) ...@@ -97,14 +97,14 @@ auto create_args(int argc, char* argv[])
.insert("tp", "8", "tensor parallel size") .insert("tp", "8", "tensor parallel size")
.insert("v", "1", "cpu validation or not") .insert("v", "1", "cpu validation or not")
.insert("kname", "1", "print kernel name or not") .insert("kname", "1", "print kernel name or not")
.insert("prec_i", "bf16", "input precision") .insert("prec_i", "int8", "input precision")
.insert("prec_w", "bf16", "weight precision") .insert("prec_w", "int8", "weight precision")
.insert("prec_o", "bf16", "output precision") .insert("prec_o", "bf16", "output precision")
.insert("prec_st", "auto", "token scale data type. auto will set to fp32") .insert("prec_st", "auto", "token scale data type. auto will set to fp32")
.insert("prec_sw", "auto", "weight scale data type. auto will set to fp32") .insert("prec_sw", "auto", "weight scale data type. auto will set to fp32")
.insert("prec_sq", "auto", "(dynamic) smooth quant data type. auto will set to fp32") .insert("prec_sq", "auto", "(dynamic) smooth quant data type. auto will set to fp32")
.insert("prec_kw", "auto", "topk-weight data type. auto will set to fp32") .insert("prec_kw", "auto", "topk-weight data type. auto will set to fp32")
.insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant") .insert("fquant", "1", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant")
.insert( .insert(
"gate_only", "1", "w0(gate/up) style, 0:gate+up will double interm size, 1:only gate") "gate_only", "1", "w0(gate/up) style, 0:gate+up will double interm size, 1:only gate")
.insert("api", "0", "benchmark api set: 0:fused-moe(moe-gemm+moe-sorting), 1:moe-gemm") .insert("api", "0", "benchmark api set: 0:fused-moe(moe-gemm+moe-sorting), 1:moe-gemm")
...@@ -218,10 +218,15 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -218,10 +218,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<GDataType> g_host({experts, shared_intermediate_size_0, hidden_size}); ck_tile::HostTensor<GDataType> g_host({experts, shared_intermediate_size_0, hidden_size});
ck_tile::HostTensor<DDataType> d_host({experts, hidden_size, shared_intermediate_size_1}); ck_tile::HostTensor<DDataType> d_host({experts, hidden_size, shared_intermediate_size_1});
ck_tile::HostTensor<ODataType> o_host({tokens, hidden_size}, {stride, 1}); ck_tile::HostTensor<ODataType> o_host({tokens, hidden_size}, {stride, 1});
ck_tile::HostTensor<AScaleDataType> sa_host({tokens}); if (fused_quant == 1)
ck_tile::HostTensor<GScaleDataType> sg_host({shared_intermediate_size_0}); {
ck_tile::HostTensor<DScaleDataType> sd_host({shared_intermediate_size_1}); ck_tile::HostTensor<AScaleDataType> sa_host({tokens, topk});
ck_tile::HostTensor<YSmoothScaleDataType> sy_host({shared_intermediate_size_1}); // smooth-quant } else{
ck_tile::HostTensor<AScaleDataType> sa_host({tokens});
}
ck_tile::HostTensor<GScaleDataType> sg_host({experts, shared_intermediate_size_0});
ck_tile::HostTensor<DScaleDataType> sd_host({experts, shared_intermediate_size_1});
ck_tile::HostTensor<YSmoothScaleDataType> sy_host({experts, shared_intermediate_size_1}); // smooth-quant
ck_tile::HostTensor<IndexDataType> topk_ids_host({tokens, topk}); // to be sort ck_tile::HostTensor<IndexDataType> topk_ids_host({tokens, topk}); // to be sort
ck_tile::HostTensor<TopkWeightDataType> topk_weight_host({tokens, topk}); // to be sort ck_tile::HostTensor<TopkWeightDataType> topk_weight_host({tokens, topk}); // to be sort
...@@ -440,7 +445,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -440,7 +445,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
hidden_size, hidden_size,
shared_intermediate_size_0, shared_intermediate_size_0,
topk, topk,
gate_only); gate_only,
fused_quant);
auto o_dev = o_buf.ToHost<ODataType>(); auto o_dev = o_buf.ToHost<ODataType>();
// o_dev.savetxt("gpu-out.txt", "float"); // o_dev.savetxt("gpu-out.txt", "float");
......
...@@ -75,7 +75,8 @@ void reference_fused_moe( ...@@ -75,7 +75,8 @@ void reference_fused_moe(
ck_tile::index_t hidden_size, ck_tile::index_t hidden_size,
ck_tile::index_t intermediate_size, // this size is for gate/up ck_tile::index_t intermediate_size, // this size is for gate/up
ck_tile::index_t topk, ck_tile::index_t topk,
ck_tile::index_t gate_only) ck_tile::index_t gate_only,
ck_tile::index_t fquant)
{ {
assert(sorted_token_ids_host.get_num_of_dimension() == 1); assert(sorted_token_ids_host.get_num_of_dimension() == 1);
assert(sorted_weight_host.get_num_of_dimension() == 1); assert(sorted_weight_host.get_num_of_dimension() == 1);
...@@ -106,22 +107,40 @@ void reference_fused_moe( ...@@ -106,22 +107,40 @@ void reference_fused_moe(
return; return;
ck_tile::index_t i_expert = sorted_expert_ids_host.mData[i_tile]; ck_tile::index_t i_expert = sorted_expert_ids_host.mData[i_tile];
ck_tile::index_t i_token = sorted_token_ids_host.mData[i_flatten]; ck_tile::index_t i_token = sorted_token_ids_host.mData[i_flatten];
if(i_token >= tokens) ck_tile::index_t i_weight_idx;
if(fquant == 1)
{
i_weight_idx = i_token >> 24;
i_token = i_token & 0xffffff;
}
if (i_token >= tokens)
return; return;
ck_tile::index_t i_topk = get_topk_id(i_token, i_expert); // TODO: ugly ck_tile::index_t i_topk = get_topk_id(i_token, i_expert); // TODO: ugly
auto weight = sorted_weight_host.mData[i_flatten]; auto weight = sorted_weight_host.mData[i_flatten];//top k ratio?
ck_tile::HostTensor<AccDataType> acc_0({1, intermediate_size_0}); ck_tile::HostTensor<float> acc_0({1, intermediate_size_0});
// first gemm // first gemm
for(ck_tile::index_t i_n = 0; i_n < intermediate_size_0; i_n++) for(ck_tile::index_t i_n = 0; i_n < intermediate_size_0; i_n++)
{ {
AccDataType acc = static_cast<AccDataType>(0); AccDataType acc = static_cast<AccDataType>(0);
for(ck_tile::index_t i_k = 0; i_k < hidden_size; i_k++) for(ck_tile::index_t i_k = 0; i_k < hidden_size; i_k++)
{ {
acc += type_convert<AccDataType>(a_host(i_token, i_k)) * acc += type_convert<float>(a_host(i_token, i_k)) *
type_convert<AccDataType>(g_host(i_expert, i_n, i_k)); type_convert<float>(g_host(i_expert, i_n, i_k));
}
if (fquant == 1)
{ //smooth
acc_0(0, i_n) = acc * sa_host(i_token, i_weight_idx) * sg_host(i_expert, i_n);
} else if( fquant == 2 )
{
//dynamic
acc_0(0, i_n) = acc * sa_host(i_token) * sg_host(i_expert, i_n);
}
else
{
//no quant
acc_0(0, i_n) = acc;
} }
acc_0(0, i_n) = acc;
// printf("ie:%2d, it:%3d, in:%d, %f\n", i_expert, i_token, i_n, acc); // printf("ie:%2d, it:%3d, in:%d, %f\n", i_expert, i_token, i_n, acc);
} }
...@@ -158,10 +177,14 @@ void reference_fused_moe( ...@@ -158,10 +177,14 @@ void reference_fused_moe(
{ {
AccDataType acc = static_cast<AccDataType>(0); AccDataType acc = static_cast<AccDataType>(0);
for(ck_tile::index_t i_k = 0; i_k < intermediate_size_1; i_k++) for(ck_tile::index_t i_k = 0; i_k < intermediate_size_1; i_k++)
{ { if (fquant == 1)
acc += y(0, i_k) * type_convert<AccDataType>(d_host(i_expert, i_n, i_k)); {
acc += y(0, i_k) * sy_host(i_expert, i_k)* type_convert<float>(d_host(i_expert, i_n, i_k));
} else {
acc += y(0, i_k) * type_convert<float>(d_host(i_expert, i_n, i_k));
}
} }
acc_1(0, i_n) = acc * weight; // multiple weight here acc_1(0, i_n) = acc * type_convert<float>(weight); // multiple weight here
} }
for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++) for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
...@@ -177,7 +200,7 @@ void reference_fused_moe( ...@@ -177,7 +200,7 @@ void reference_fused_moe(
auto r = [&](auto i_token) { auto r = [&](auto i_token) {
for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++) for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
{ {
AccDataType acc = type_convert<AccDataType>(0); AccDataType acc = type_convert<float>(0);
for(ck_tile::index_t i_topk = 0; i_topk < topk; i_topk++) for(ck_tile::index_t i_topk = 0; i_topk < topk; i_topk++)
{ {
acc += out_topk_tokens(i_token, i_topk, i_n); acc += out_topk_tokens(i_token, i_topk, i_n);
......
...@@ -4,7 +4,9 @@ ...@@ -4,7 +4,9 @@
#pragma once #pragma once
#include "ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp" #include "ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_32x512x256_1x4x1_16x16x64_int8.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp" #include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x256x512_1x4x1_16x16x64_int8.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp" #include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp" #include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp"
......
# define _DEQUAN_CVT_(a, b, c) \ #ifndef CK_TILE_FLATMM_UK_MFMA
" v_cvt_f32_i32 a[0], a[0] \n" \ #define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_INT8
" v_cvt_f32_i32 a[1], a[1] \n" \ #endif
" v_cvt_f32_i32 a[2], a[2] \n" \
" v_cvt_f32_i32 a[3], a[3] \n" \
" v_mul_f32 a[0], v15, a[0] \n" \
" v_mul_f32 a[1], v15, a[1] \n" \
" v_mul_f32 a[2], v15, a[2] \n" \
" v_mul_f32 a[3], v15, a[3] \n" \
" v_mul_f32 a[0], v17, a[0] row_newbcast:12 \n" \
" v_mul_f32 a[1], v17, a[1] row_newbcast:13 \n" \
" v_mul_f32 a[2], v17, a[2] row_newbcast:14 \n" \
" v_mul_f32 a[3], v17, a[3] row_newbcast:15 \n" \
#if CK_TILE_FLATMM_UK_MFMA == CK_TILE_FLATMM_UK_MFMA_INT8
# define _UK_MFMA_ "v_mfma_i32_16x16x32_i8"
#endif
# define _DEQUAN_CVT_(a0,a1,a2,a3, b, c) \
" v_cvt_f32_i32 a0, a0 \n" \
" v_cvt_f32_i32 a1, a1 \n" \
" v_cvt_f32_i32 a2, a2 \n" \
" v_cvt_f32_i32 a3, a3 \n" \
" v_mul_f32 a0, v15, a0 \n" \
" v_mul_f32 a1, v15, a1 \n" \
" v_mul_f32 a2, v15, a2 \n" \
" v_mul_f32 a3, v15, a3 \n" \
" v_mul_f32 a0, v17, a0 row_newbcast:12 \n" \
" v_mul_f32 a1, v17, a1 row_newbcast:13 \n" \
" v_mul_f32 a2, v17, a2 row_newbcast:14 \n" \
" v_mul_f32 a3, v17, a3 row_newbcast:15 \n" \
";-------------------------------\n"
"s_mov_b32 s28, %[s_res_aq0] \n"
"s_mov_b32 s29, %[s_res_aq1] \n"
"s_mov_b32 s30, %[s_res_aq2] \n"
"s_mov_b32 s31, %[s_res_aq3] \n"
"s_mov_b32 s16, %[s_res_dq0] \n" "s_mov_b32 s16, %[s_res_dq0] \n"
"s_mov_b32 s17, %[s_res_dq1] \n" "s_mov_b32 s17, %[s_res_dq1] \n"
"s_mov_b32 s18, %[s_res_dq2] \n" "s_mov_b32 s18, %[s_res_dq2] \n"
...@@ -32,19 +43,7 @@ ...@@ -32,19 +43,7 @@
"s_mov_b32 s25, %[s_res_b1] \n" "s_mov_b32 s25, %[s_res_b1] \n"
"s_mov_b32 s26, %[s_res_b2] \n" "s_mov_b32 s26, %[s_res_b2] \n"
"s_mov_b32 s27, %[s_res_b3] \n" "s_mov_b32 s27, %[s_res_b3] \n"
";---------------------------------------------- \n"
//////////GQ/DQ/GsmQ_addr///////////////
//expert weight addr no need
// s_mul_i32 s60, s3, 32 // 00000000056C: 923CA003 s3 s_tg_idy
// s_mul_i32 s60, 4, s60 // 000000000570: 923C3C84
// s_add_u32 s40, s60, s40 // 000000000574: 8028283C s40 sw_ptr
// s_addc_u32 s41, 0, s41 // 000000000578: 82292980 s41 sw_ptr
// v_and_b32 v54, 15, v0 // 00000000057C: 266C008F
// v_lshlrev_b32 v8, 2, v54 // 000000000580: 24106C82 v8/9 w addr
// v_add_u32 v9, 64, v8 // 000000000584: 681210C0
//GQDQ addr function kkkkkkkkkkkkkk
" v_lshrrev_b32 v54, 4, v0 \n" " v_lshrrev_b32 v54, 4, v0 \n"
" v_lshlrev_b32 v55, 2, v54 \n" " v_lshlrev_b32 v55, 2, v54 \n"
" v_and_b32 v54, 15, v0 \n" " v_and_b32 v54, 15, v0 \n"
...@@ -55,21 +54,17 @@ ...@@ -55,21 +54,17 @@
" v_add_u32 v55, v54, v55 \n" " v_add_u32 v55, v54, v55 \n"
" v_lshlrev_b32 v10, 2, v55 \n" " v_lshlrev_b32 v10, 2, v55 \n"
" v_add_u32 v11, 0x00000400, v10 \n" " v_add_u32 v11, 0x00000400, v10 \n"
" s_mul_i32 s60, %[s_wave_id], 16 \n" " s_mul_i32 s60, %[s_wave_id], 16 \n"
" s_mul_i32 s60, s60, 4 \n" " s_mul_i32 s60, s60, 4 \n"
" v_add_u32 v10, s60, v10 \n" " v_add_u32 v10, s60, v10 \n"
" v_add_u32 v11, s60, v11 \n" " v_add_u32 v11, s60, v11 \n"
" v_mov_b32 v5, v10 \n" " v_mov_b32 v5, v10 \n"
";---------------------------------------------- \n"
//////////////////////////////
" s_mov_b32 s57, 0x00000100 \n" " s_mov_b32 s57, 0x00000100 \n"
" s_mov_b32 s58, 0x00001000 \n" " s_mov_b32 s58, 0x00001000 \n"
" s_mov_b32 s79, 0x00000400 \n" " s_mov_b32 s79, 0x00000400 \n"
" s_mov_b32 s59, 0x00000200 \n" " s_mov_b32 s59, 0x00000200 \n"
//////// ";---------------------------------------------- \n"
//" s_mul_i32 s60, s70, 0x00000100 \n"
//" s_sub_u32 s56, s60, 0x00001000 \n"
///////////////
" s_mov_b32 s78, 0x00001000 \n" " s_mov_b32 s78, 0x00001000 \n"
" s_mov_b32 s52, 0x07060302 \n" " s_mov_b32 s52, 0x07060302 \n"
" s_mov_b32 s53, 0x00000400 \n" " s_mov_b32 s53, 0x00000400 \n"
...@@ -82,7 +77,7 @@ ...@@ -82,7 +77,7 @@
" v_mov_b32 v52, 0x7fff0000 \n" " v_mov_b32 v52, 0x7fff0000 \n"
" v_mov_b32 v53, 0x00007fff \n" " v_mov_b32 v53, 0x00007fff \n"
" s_waitcnt 0x0000 \n" " s_waitcnt 0x0000 \n"
///XQ ADDR, fake token id ";---------------------------------------------- \n"
" v_mov_b32 %[v_token_id], %[v_token_id] \n" " v_mov_b32 %[v_token_id], %[v_token_id] \n"
" v_lshrrev_b32 v54, 24, %[v_token_id] \n" " v_lshrrev_b32 v54, 24, %[v_token_id] \n"
" v_mul_i32_i24 v54, s66, v54 \n" " v_mul_i32_i24 v54, s66, v54 \n"
...@@ -104,8 +99,7 @@ ...@@ -104,8 +99,7 @@
" buffer_load_dword v21, v9, s[40:43], 0 offen \n" " buffer_load_dword v21, v9, s[40:43], 0 offen \n"
" s_mov_b32 s80, 0 \n" " s_mov_b32 s80, 0 \n"
//---------------------v26-33 no need ";---------------------------------------------- \n"
// "s_nop 4\n"
"; -- prefetch A0\n" "; -- prefetch A0\n"
"s_add_u32 m0, 0, %[s_m0_init] \n" "s_add_u32 m0, 0, %[s_m0_init] \n"
"buffer_load_dword %[v_os_a0], s[20:23], 0 offen lds \n" "buffer_load_dword %[v_os_a0], s[20:23], 0 offen lds \n"
...@@ -183,18 +177,17 @@ ...@@ -183,18 +177,17 @@
" s_waitcnt vmcnt(40) \n" " s_waitcnt vmcnt(40) \n"
" s_barrier \n" " s_barrier \n"
/////////////////////////////// ";---------------------------------------------- \n"
"ds_read_b128 v[192:195], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_0]\n" // 1024: N stride, 64 K stride "ds_read_b128 v[192:195], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_0]\n" // 1024: N stride, 64 K stride
"ds_read_b128 v[196:199], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_1]\n" "ds_read_b128 v[196:199], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_1]\n"
"ds_read_b128 v[200:203], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_2]\n" "ds_read_b128 v[200:203], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_2]\n"
"ds_read_b128 v[204:207], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_3]\n" "ds_read_b128 v[204:207], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_3]\n"
"ds_read_b128 v[208:211], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_4]\n" "ds_read_b128 v[208:211], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_4]\n"
"ds_read_b128 v[212:215], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_5]\n" "ds_read_b128 v[212:215], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_5]\n"
"ds_read_b128 v[216:219], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_6]\n" "ds_read_b128 v[216:219], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_6]\n"
"ds_read_b128 v[220:223], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_7]\n" "ds_read_b128 v[220:223], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_7]\n"
//////////////////////////// ";---------------------------------------------- \n"
" label_start: \n"
"label_start:
" s_waitcnt vmcnt(24) & lgkmcnt(0) \n" " s_waitcnt vmcnt(24) & lgkmcnt(0) \n"
" s_barrier \n" " s_barrier \n"
_UK_MFMA_ " v[128:131], acc[0:1], v[192:193], v[128:131] \n" _UK_MFMA_ " v[128:131], acc[0:1], v[192:193], v[128:131] \n"
...@@ -400,7 +393,7 @@ ...@@ -400,7 +393,7 @@
" s_waitcnt vmcnt(24) & lgkmcnt(0) \n" " s_waitcnt vmcnt(24) & lgkmcnt(0) \n"
" s_barrier \n" " s_barrier \n"
_UK_MFMA_ " v[128:131], acc[128:129], v[224:225], v[128:131] \n" _UK_MFMA_ " v[128:131], acc[128:129], v[224:225], v[128:131] \n"
_UK_MFMA_ " v[128:131], acc[130:131], v[226:227], v[128:131] \n" _UK_MFMA_ " v[128:131], acc[130:131], v[226:227], v[128:131] \n"
" buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[24:27], 0 offen \n" " buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[24:27], 0 offen \n"
_UK_MFMA_ " v[128:131], acc[132:133], v[228:229], v[128:131] \n" _UK_MFMA_ " v[128:131], acc[132:133], v[228:229], v[128:131] \n"
_UK_MFMA_ " v[128:131], acc[134:135], v[230:231], v[128:131] \n" _UK_MFMA_ " v[128:131], acc[134:135], v[230:231], v[128:131] \n"
...@@ -461,49 +454,49 @@ ...@@ -461,49 +454,49 @@
" buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[24:27], 0 offen \n" " buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[24:27], 0 offen \n"
_UK_MFMA_ " v[144:147], acc[164:165], v[228:229], v[144:147] \n" _UK_MFMA_ " v[144:147], acc[164:165], v[228:229], v[144:147] \n"
_UK_MFMA_ " v[144:147], acc[166:167], v[230:231], v[144:147] \n" _UK_MFMA_ " v[144:147], acc[166:167], v[230:231], v[144:147] \n"
" ds_read_b128 v[192:195], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_0] " ds_read_b128 v[192:195], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_0] \n"
_UK_MFMA_ " v[144:147], acc[168:169], v[232:233], v[144:147] \n" _UK_MFMA_ " v[144:147], acc[168:169], v[232:233], v[144:147] \n"
_UK_MFMA_ " v[144:147], acc[170:171], v[234:235], v[144:147] \n" _UK_MFMA_ " v[144:147], acc[170:171], v[234:235], v[144:147] \n"
" buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[24:27], 0 offen offset:1024 \n" " buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[24:27], 0 offen offset:1024 \n"
_UK_MFMA_ " v[144:147], acc[172:173], v[236:237], v[144:147] \n" _UK_MFMA_ " v[144:147], acc[172:173], v[236:237], v[144:147] \n"
_UK_MFMA_ " v[144:147], acc[174:175], v[238:239], v[144:147] \n" _UK_MFMA_ " v[144:147], acc[174:175], v[238:239], v[144:147] \n"
" ds_read_b128 v[196:199], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_1] " ds_read_b128 v[196:199], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_1] \n"
_UK_MFMA_ " v[148:151], acc[160:161], v[240:241], v[148:151] \n" _UK_MFMA_ " v[148:151], acc[160:161], v[240:241], v[148:151] \n"
_UK_MFMA_ " v[148:151], acc[162:163], v[242:243], v[148:151] \n" _UK_MFMA_ " v[148:151], acc[162:163], v[242:243], v[148:151] \n"
" buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[24:27], 0 offen offset:2048 \n" " buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[24:27], 0 offen offset:2048 \n"
_UK_MFMA_ " v[148:151], acc[164:165], v[244:245], v[148:151] \n" _UK_MFMA_ " v[148:151], acc[164:165], v[244:245], v[148:151] \n"
_UK_MFMA_ " v[148:151], acc[166:167], v[246:247], v[148:151] \n" _UK_MFMA_ " v[148:151], acc[166:167], v[246:247], v[148:151] \n"
" ds_read_b128 v[200:203], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_2] " ds_read_b128 v[200:203], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_2] \n"
_UK_MFMA_ " v[148:151], acc[168:169], v[248:249], v[148:151] \n" _UK_MFMA_ " v[148:151], acc[168:169], v[248:249], v[148:151] \n"
_UK_MFMA_ " v[148:151], acc[170:171], v[250:251], v[148:151] \n" _UK_MFMA_ " v[148:151], acc[170:171], v[250:251], v[148:151] \n"
" buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[24:27], 0 offen offset:3072 \n" " buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[24:27], 0 offen offset:3072 \n"
_UK_MFMA_ " v[148:151], acc[172:173], v[252:253], v[148:151] \n" _UK_MFMA_ " v[148:151], acc[172:173], v[252:253], v[148:151] \n"
_UK_MFMA_ " v[148:151], acc[174:175], v[254:255], v[148:151] \n" _UK_MFMA_ " v[148:151], acc[174:175], v[254:255], v[148:151] \n"
" ds_read_b128 v[204:207], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_3] " ds_read_b128 v[204:207], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_3] \n"
_UK_MFMA_ " v[152:155], acc[176:177], v[224:225], v[152:155] \n" _UK_MFMA_ " v[152:155], acc[176:177], v[224:225], v[152:155] \n"
_UK_MFMA_ " v[152:155], acc[178:179], v[226:227], v[152:155] \n" _UK_MFMA_ " v[152:155], acc[178:179], v[226:227], v[152:155] \n"
" buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[24:27], 0 offen \n" " buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[24:27], 0 offen \n"
_UK_MFMA_ " v[152:155], acc[180:181], v[228:229], v[152:155] \n" _UK_MFMA_ " v[152:155], acc[180:181], v[228:229], v[152:155] \n"
_UK_MFMA_ " v[152:155], acc[182:183], v[230:231], v[152:155] \n" _UK_MFMA_ " v[152:155], acc[182:183], v[230:231], v[152:155] \n"
" ds_read_b128 v[208:211], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_4] " ds_read_b128 v[208:211], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_4] \n"
_UK_MFMA_ " v[152:155], acc[184:185], v[232:233], v[152:155] \n" _UK_MFMA_ " v[152:155], acc[184:185], v[232:233], v[152:155] \n"
_UK_MFMA_ " v[152:155], acc[186:187], v[234:235], v[152:155] \n" _UK_MFMA_ " v[152:155], acc[186:187], v[234:235], v[152:155] \n"
" buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[24:27], 0 offen offset:1024 \n" " buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[24:27], 0 offen offset:1024 \n"
_UK_MFMA_ " v[152:155], acc[188:189], v[236:237], v[152:155] \n" _UK_MFMA_ " v[152:155], acc[188:189], v[236:237], v[152:155] \n"
_UK_MFMA_ " v[152:155], acc[190:191], v[238:239], v[152:155] \n" _UK_MFMA_ " v[152:155], acc[190:191], v[238:239], v[152:155] \n"
" ds_read_b128 v[212:215], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_5] " ds_read_b128 v[212:215], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_5] \n"
_UK_MFMA_ " v[156:159], acc[176:177], v[240:241], v[156:159] \n" _UK_MFMA_ " v[156:159], acc[176:177], v[240:241], v[156:159] \n"
_UK_MFMA_ " v[156:159], acc[178:179], v[242:243], v[156:159] \n" _UK_MFMA_ " v[156:159], acc[178:179], v[242:243], v[156:159] \n"
" buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[24:27], 0 offen offset:2048 \n" " buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[24:27], 0 offen offset:2048 \n"
_UK_MFMA_ " v[156:159], acc[180:181], v[244:245], v[156:159] \n" _UK_MFMA_ " v[156:159], acc[180:181], v[244:245], v[156:159] \n"
_UK_MFMA_ " v[156:159], acc[182:183], v[246:247], v[156:159] \n" _UK_MFMA_ " v[156:159], acc[182:183], v[246:247], v[156:159] \n"
" ds_read_b128 v[216:219], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_6] " ds_read_b128 v[216:219], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_6] \n"
_UK_MFMA_ " v[156:159], acc[184:185], v[248:249], v[156:159] \n" _UK_MFMA_ " v[156:159], acc[184:185], v[248:249], v[156:159] \n"
_UK_MFMA_ " v[156:159], acc[186:187], v[250:251], v[156:159] \n" _UK_MFMA_ " v[156:159], acc[186:187], v[250:251], v[156:159] \n"
" buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[24:27], 0 offen offset:3072 \n" " buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[24:27], 0 offen offset:3072 \n"
_UK_MFMA_ " v[156:159], acc[188:189], v[252:253], v[156:159] \n" _UK_MFMA_ " v[156:159], acc[188:189], v[252:253], v[156:159] \n"
_UK_MFMA_ " v[156:159], acc[190:191], v[254:255], v[156:159] \n" _UK_MFMA_ " v[156:159], acc[190:191], v[254:255], v[156:159] \n"
" ds_read_b128 v[220:223], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_7] " ds_read_b128 v[220:223], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_7] \n"
" s_waitcnt vmcnt(32) \n" " s_waitcnt vmcnt(32) \n"
_UK_MFMA_ " v[160:163], acc[192:193], v[224:225], v[160:163] \n" _UK_MFMA_ " v[160:163], acc[192:193], v[224:225], v[160:163] \n"
_UK_MFMA_ " v[160:163], acc[194:195], v[226:227], v[160:163] \n" _UK_MFMA_ " v[160:163], acc[194:195], v[226:227], v[160:163] \n"
...@@ -601,7 +594,7 @@ ...@@ -601,7 +594,7 @@
" s_cbranch_scc0 label_end \n" " s_cbranch_scc0 label_end \n"
" s_branch label_start%= \n" " s_branch label_start%= \n"
" label_end : \n" " label_end : \n"
//dequant ";---------------------------------------------- \n"
" v_cvt_f32_i32 v128, v128 \n" " v_cvt_f32_i32 v128, v128 \n"
" v_cvt_f32_i32 v129, v129 \n" " v_cvt_f32_i32 v129, v129 \n"
" v_cvt_f32_i32 v130, v130 \n" " v_cvt_f32_i32 v130, v130 \n"
...@@ -794,7 +787,7 @@ ...@@ -794,7 +787,7 @@
" v_mul_f32 v189, v17, v189 row_newbcast:13 \n" " v_mul_f32 v189, v17, v189 row_newbcast:13 \n"
" v_mul_f32 v190, v17, v190 row_newbcast:14 \n" " v_mul_f32 v190, v17, v190 row_newbcast:14 \n"
" v_mul_f32 v191, v17, v191 row_newbcast:15 \n" " v_mul_f32 v191, v17, v191 row_newbcast:15 \n"
#undef _UK_MFMA_ #undef _UK_MFMA_
//dequant end #undef _DEQUAN_CVT_
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp" #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp" #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp" #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk_int8.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp" #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp" #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp" #include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp"
......
...@@ -198,7 +198,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8 ...@@ -198,7 +198,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
//addr in fact //addr in fact
auto a_coords = generate_tuple( auto a_coords = generate_tuple(
[&](auto i) { [&](auto i) {
return (token_id) * kargs.stride_token + return (token_id[i]) * kargs.stride_token +
threadIdx.x % (BlockShape::Block_K0 / kAlignmentA) * kAlignmentA; threadIdx.x % (BlockShape::Block_K0 / kAlignmentA) * kAlignmentA;
}, },
number<row_ids_a.size()>{}); number<row_ids_a.size()>{});
...@@ -254,7 +254,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8 ...@@ -254,7 +254,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
make_tuple(shared_intermediate_size_1), make_tuple(shared_intermediate_size_1),
number<1>{}); number<1>{});
return g_view_; return gq_view_;
}(); }();
auto gq_res = gq_win.get_buffer_view().cached_buf_res_; auto gq_res = gq_win.get_buffer_view().cached_buf_res_;
...@@ -345,7 +345,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8 ...@@ -345,7 +345,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
auto o_coords = generate_tuple( auto o_coords = generate_tuple(
[&](auto i) { [&](auto i) {
return token_id * kargs.stride_token + return token_id[i] * kargs.stride_token +
threadIdx.x % (BlockShape::Block_N1 / kAlignmentO) * kAlignmentO; threadIdx.x % (BlockShape::Block_N1 / kAlignmentO) * kAlignmentO;
}, },
number<row_ids_a.size()>{}); number<row_ids_a.size()>{});
...@@ -376,6 +376,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8 ...@@ -376,6 +376,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
row_ids_a,//fake token id, 2D index for X scale row_ids_a,//fake token id, 2D index for X scale
aq_res, aq_res,
gq_res, gq_res,
gq_res,
dq_res, dq_res,
a_res, a_res,
a_coords, a_coords,
......
...@@ -143,7 +143,7 @@ using WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution = ...@@ -143,7 +143,7 @@ using WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution =
// int8 // int8
using WarpGemmMfma_i32_16x16x64_int8_int8_CTransposed = using WarpGemmMfma_i32_16x16x64_int8_int8_CTransposed =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution< WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImpl_i32_16x16x32_i8<WGAttrCtlEnum::Default_> WarpGemmAttributeMfmaImpl_i32_16x16x32_i8<WGAttrCtlEnum::Default_>,
2>>; 2>>;
} // namespace ck_tile } // namespace ck_tile
...@@ -655,7 +655,7 @@ struct WarpGemmAttributeMfmaImpl_i32_16x16x32_i8 ...@@ -655,7 +655,7 @@ struct WarpGemmAttributeMfmaImpl_i32_16x16x32_i8
else else
{ {
#if defined(__gfx94__) #if defined(__gfx94__)
c_vec = __builtin_amdgcn_mfma_i32_16x16x32i8( c_vec = __builtin_amdgcn_mfma_i32_16x16x32_i8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0); bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
#elif defined(__gfx908__) || defined(__gfx90a__) #elif defined(__gfx908__) || defined(__gfx90a__)
static_for<0, 8, 1>{}([&](auto k) { static_for<0, 8, 1>{}([&](auto k) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment