Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
a759277d
Commit
a759277d
authored
Jan 02, 2025
by
shengnxu
Browse files
fix some error
parent
f549173b
Changes
12
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
2431 additions
and
614 deletions
+2431
-614
example/ck_tile/15_fused_moe/main.cpp
example/ck_tile/15_fused_moe/main.cpp
+14
-8
include/ck_tile/host/reference/reference_fused_moe.hpp
include/ck_tile/host/reference/reference_fused_moe.hpp
+34
-11
include/ck_tile/ops/flatmm.hpp
include/ck_tile/ops/flatmm.hpp
+2
-0
include/ck_tile/ops/flatmm/block/flatmm_32x512x256_1x4x1_16x16x64_int8.hpp
...ps/flatmm/block/flatmm_32x512x256_1x4x1_16x16x64_int8.hpp
+119
-270
include/ck_tile/ops/flatmm/block/flatmm_sn_32x256x512_1x4x1_16x16x64_int8.hpp
...flatmm/block/flatmm_sn_32x256x512_1x4x1_16x16x64_int8.hpp
+185
-260
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_1.inc
...uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_1.inc
+947
-0
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_2.inc
...uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_2.inc
+1070
-0
include/ck_tile/ops/flatmm/block/uk/flatmm_uk_gfx9_32x512x256_1x1x1_16x16x32_int8.inc
...lock/uk/flatmm_uk_gfx9_32x512x256_1x1x1_16x16x32_int8.inc
+53
-60
include/ck_tile/ops/fused_moe.hpp
include/ck_tile/ops/fused_moe.hpp
+1
-0
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk_int8.hpp
...ed_moe/pipeline/fused_moegemm_pipeline_flatmm_uk_int8.hpp
+4
-3
include/ck_tile/ops/gemm/warp/warp_gemm.hpp
include/ck_tile/ops/gemm/warp/warp_gemm.hpp
+1
-1
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
...e/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
+1
-1
No files found.
example/ck_tile/15_fused_moe/main.cpp
View file @
a759277d
...
...
@@ -97,14 +97,14 @@ auto create_args(int argc, char* argv[])
.
insert
(
"tp"
,
"8"
,
"tensor parallel size"
)
.
insert
(
"v"
,
"1"
,
"cpu validation or not"
)
.
insert
(
"kname"
,
"1"
,
"print kernel name or not"
)
.
insert
(
"prec_i"
,
"
bf16
"
,
"input precision"
)
.
insert
(
"prec_w"
,
"
bf16
"
,
"weight precision"
)
.
insert
(
"prec_i"
,
"
int8
"
,
"input precision"
)
.
insert
(
"prec_w"
,
"
int8
"
,
"weight precision"
)
.
insert
(
"prec_o"
,
"bf16"
,
"output precision"
)
.
insert
(
"prec_st"
,
"auto"
,
"token scale data type. auto will set to fp32"
)
.
insert
(
"prec_sw"
,
"auto"
,
"weight scale data type. auto will set to fp32"
)
.
insert
(
"prec_sq"
,
"auto"
,
"(dynamic) smooth quant data type. auto will set to fp32"
)
.
insert
(
"prec_kw"
,
"auto"
,
"topk-weight data type. auto will set to fp32"
)
.
insert
(
"fquant"
,
"
0
"
,
"fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant"
)
.
insert
(
"fquant"
,
"
1
"
,
"fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant"
)
.
insert
(
"gate_only"
,
"1"
,
"w0(gate/up) style, 0:gate+up will double interm size, 1:only gate"
)
.
insert
(
"api"
,
"0"
,
"benchmark api set: 0:fused-moe(moe-gemm+moe-sorting), 1:moe-gemm"
)
...
...
@@ -218,10 +218,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
HostTensor
<
GDataType
>
g_host
({
experts
,
shared_intermediate_size_0
,
hidden_size
});
ck_tile
::
HostTensor
<
DDataType
>
d_host
({
experts
,
hidden_size
,
shared_intermediate_size_1
});
ck_tile
::
HostTensor
<
ODataType
>
o_host
({
tokens
,
hidden_size
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
AScaleDataType
>
sa_host
({
tokens
});
ck_tile
::
HostTensor
<
GScaleDataType
>
sg_host
({
shared_intermediate_size_0
});
ck_tile
::
HostTensor
<
DScaleDataType
>
sd_host
({
shared_intermediate_size_1
});
ck_tile
::
HostTensor
<
YSmoothScaleDataType
>
sy_host
({
shared_intermediate_size_1
});
// smooth-quant
if
(
fused_quant
==
1
)
{
ck_tile
::
HostTensor
<
AScaleDataType
>
sa_host
({
tokens
,
topk
});
}
else
{
ck_tile
::
HostTensor
<
AScaleDataType
>
sa_host
({
tokens
});
}
ck_tile
::
HostTensor
<
GScaleDataType
>
sg_host
({
experts
,
shared_intermediate_size_0
});
ck_tile
::
HostTensor
<
DScaleDataType
>
sd_host
({
experts
,
shared_intermediate_size_1
});
ck_tile
::
HostTensor
<
YSmoothScaleDataType
>
sy_host
({
experts
,
shared_intermediate_size_1
});
// smooth-quant
ck_tile
::
HostTensor
<
IndexDataType
>
topk_ids_host
({
tokens
,
topk
});
// to be sort
ck_tile
::
HostTensor
<
TopkWeightDataType
>
topk_weight_host
({
tokens
,
topk
});
// to be sort
...
...
@@ -440,7 +445,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
hidden_size
,
shared_intermediate_size_0
,
topk
,
gate_only
);
gate_only
,
fused_quant
);
auto
o_dev
=
o_buf
.
ToHost
<
ODataType
>
();
// o_dev.savetxt("gpu-out.txt", "float");
...
...
include/ck_tile/host/reference/reference_fused_moe.hpp
View file @
a759277d
...
...
@@ -75,7 +75,8 @@ void reference_fused_moe(
ck_tile
::
index_t
hidden_size
,
ck_tile
::
index_t
intermediate_size
,
// this size is for gate/up
ck_tile
::
index_t
topk
,
ck_tile
::
index_t
gate_only
)
ck_tile
::
index_t
gate_only
,
ck_tile
::
index_t
fquant
)
{
assert
(
sorted_token_ids_host
.
get_num_of_dimension
()
==
1
);
assert
(
sorted_weight_host
.
get_num_of_dimension
()
==
1
);
...
...
@@ -106,22 +107,40 @@ void reference_fused_moe(
return
;
ck_tile
::
index_t
i_expert
=
sorted_expert_ids_host
.
mData
[
i_tile
];
ck_tile
::
index_t
i_token
=
sorted_token_ids_host
.
mData
[
i_flatten
];
if
(
i_token
>=
tokens
)
ck_tile
::
index_t
i_weight_idx
;
if
(
fquant
==
1
)
{
i_weight_idx
=
i_token
>>
24
;
i_token
=
i_token
&
0xffffff
;
}
if
(
i_token
>=
tokens
)
return
;
ck_tile
::
index_t
i_topk
=
get_topk_id
(
i_token
,
i_expert
);
// TODO: ugly
auto
weight
=
sorted_weight_host
.
mData
[
i_flatten
];
auto
weight
=
sorted_weight_host
.
mData
[
i_flatten
];
//top k ratio?
ck_tile
::
HostTensor
<
AccDataType
>
acc_0
({
1
,
intermediate_size_0
});
ck_tile
::
HostTensor
<
float
>
acc_0
({
1
,
intermediate_size_0
});
// first gemm
for
(
ck_tile
::
index_t
i_n
=
0
;
i_n
<
intermediate_size_0
;
i_n
++
)
{
AccDataType
acc
=
static_cast
<
AccDataType
>
(
0
);
for
(
ck_tile
::
index_t
i_k
=
0
;
i_k
<
hidden_size
;
i_k
++
)
{
acc
+=
type_convert
<
AccDataType
>
(
a_host
(
i_token
,
i_k
))
*
type_convert
<
AccDataType
>
(
g_host
(
i_expert
,
i_n
,
i_k
));
acc
+=
type_convert
<
float
>
(
a_host
(
i_token
,
i_k
))
*
type_convert
<
float
>
(
g_host
(
i_expert
,
i_n
,
i_k
));
}
if
(
fquant
==
1
)
{
//smooth
acc_0
(
0
,
i_n
)
=
acc
*
sa_host
(
i_token
,
i_weight_idx
)
*
sg_host
(
i_expert
,
i_n
);
}
else
if
(
fquant
==
2
)
{
//dynamic
acc_0
(
0
,
i_n
)
=
acc
*
sa_host
(
i_token
)
*
sg_host
(
i_expert
,
i_n
);
}
else
{
//no quant
acc_0
(
0
,
i_n
)
=
acc
;
}
acc_0
(
0
,
i_n
)
=
acc
;
// printf("ie:%2d, it:%3d, in:%d, %f\n", i_expert, i_token, i_n, acc);
}
...
...
@@ -158,10 +177,14 @@ void reference_fused_moe(
{
AccDataType
acc
=
static_cast
<
AccDataType
>
(
0
);
for
(
ck_tile
::
index_t
i_k
=
0
;
i_k
<
intermediate_size_1
;
i_k
++
)
{
acc
+=
y
(
0
,
i_k
)
*
type_convert
<
AccDataType
>
(
d_host
(
i_expert
,
i_n
,
i_k
));
{
if
(
fquant
==
1
)
{
acc
+=
y
(
0
,
i_k
)
*
sy_host
(
i_expert
,
i_k
)
*
type_convert
<
float
>
(
d_host
(
i_expert
,
i_n
,
i_k
));
}
else
{
acc
+=
y
(
0
,
i_k
)
*
type_convert
<
float
>
(
d_host
(
i_expert
,
i_n
,
i_k
));
}
}
acc_1
(
0
,
i_n
)
=
acc
*
weight
;
// multiple weight here
acc_1
(
0
,
i_n
)
=
acc
*
type_convert
<
float
>
(
weight
)
;
// multiple weight here
}
for
(
ck_tile
::
index_t
i_n
=
0
;
i_n
<
hidden_size
;
i_n
++
)
...
...
@@ -177,7 +200,7 @@ void reference_fused_moe(
auto
r
=
[
&
](
auto
i_token
)
{
for
(
ck_tile
::
index_t
i_n
=
0
;
i_n
<
hidden_size
;
i_n
++
)
{
AccDataType
acc
=
type_convert
<
AccDataType
>
(
0
);
AccDataType
acc
=
type_convert
<
float
>
(
0
);
for
(
ck_tile
::
index_t
i_topk
=
0
;
i_topk
<
topk
;
i_topk
++
)
{
acc
+=
out_topk_tokens
(
i_token
,
i_topk
,
i_n
);
...
...
include/ck_tile/ops/flatmm.hpp
View file @
a759277d
...
...
@@ -4,7 +4,9 @@
#pragma once
#include "ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_32x512x256_1x4x1_16x16x64_int8.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x256x512_1x4x1_16x16x64_int8.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
...
...
include/ck_tile/ops/flatmm/block/flatmm_32x512x256_1x4x1_16x16x64_int8.hpp
View file @
a759277d
This diff is collapsed.
Click to expand it.
include/ck_tile/ops/flatmm/block/flatmm_sn_32x256x512_1x4x1_16x16x64_int8.hpp
View file @
a759277d
This diff is collapsed.
Click to expand it.
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_1.inc
0 → 100644
View file @
a759277d
This diff is collapsed.
Click to expand it.
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_2.inc
0 → 100644
View file @
a759277d
This diff is collapsed.
Click to expand it.
include/ck_tile/ops/flatmm/block/uk/flatmm_uk_gfx9_32x512x256_1x1x1_16x16x32_int8.inc
View file @
a759277d
# define _DEQUAN_CVT_(a, b, c) \
" v_cvt_f32_i32 a[0], a[0]
\n
"
\
" v_cvt_f32_i32 a[1], a[1]
\n
"
\
" v_cvt_f32_i32 a[2], a[2]
\n
"
\
" v_cvt_f32_i32 a[3], a[3]
\n
"
\
" v_mul_f32 a[0], v15, a[0]
\n
"
\
" v_mul_f32 a[1], v15, a[1]
\n
"
\
" v_mul_f32 a[2], v15, a[2]
\n
"
\
" v_mul_f32 a[3], v15, a[3]
\n
"
\
" v_mul_f32 a[0], v17, a[0] row_newbcast:12
\n
"
\
" v_mul_f32 a[1], v17, a[1] row_newbcast:13
\n
"
\
" v_mul_f32 a[2], v17, a[2] row_newbcast:14
\n
"
\
" v_mul_f32 a[3], v17, a[3] row_newbcast:15
\n
"
\
#ifndef CK_TILE_FLATMM_UK_MFMA
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_INT8
#endif
#if CK_TILE_FLATMM_UK_MFMA == CK_TILE_FLATMM_UK_MFMA_INT8
# define _UK_MFMA_ "v_mfma_i32_16x16x32_i8"
#endif
# define _DEQUAN_CVT_(a0,a1,a2,a3, b, c) \
" v_cvt_f32_i32 a0, a0
\n
"
\
" v_cvt_f32_i32 a1, a1
\n
"
\
" v_cvt_f32_i32 a2, a2
\n
"
\
" v_cvt_f32_i32 a3, a3
\n
"
\
" v_mul_f32 a0, v15, a0
\n
"
\
" v_mul_f32 a1, v15, a1
\n
"
\
" v_mul_f32 a2, v15, a2
\n
"
\
" v_mul_f32 a3, v15, a3
\n
"
\
" v_mul_f32 a0, v17, a0 row_newbcast:12
\n
"
\
" v_mul_f32 a1, v17, a1 row_newbcast:13
\n
"
\
" v_mul_f32 a2, v17, a2 row_newbcast:14
\n
"
\
" v_mul_f32 a3, v17, a3 row_newbcast:15
\n
"
\
";-------------------------------
\n
"
"s_mov_b32 s28, %[s_res_aq0]
\n
"
"s_mov_b32 s29, %[s_res_aq1]
\n
"
"s_mov_b32 s30, %[s_res_aq2]
\n
"
"s_mov_b32 s31, %[s_res_aq3]
\n
"
"s_mov_b32 s16, %[s_res_dq0]
\n
"
"s_mov_b32 s17, %[s_res_dq1]
\n
"
"s_mov_b32 s18, %[s_res_dq2]
\n
"
...
...
@@ -32,19 +43,7 @@
"s_mov_b32 s25, %[s_res_b1]
\n
"
"s_mov_b32 s26, %[s_res_b2]
\n
"
"s_mov_b32 s27, %[s_res_b3]
\n
"
//////////GQ/DQ/GsmQ_addr///////////////
//expert weight addr no need
// s_mul_i32 s60, s3, 32 // 00000000056C: 923CA003 s3 s_tg_idy
// s_mul_i32 s60, 4, s60 // 000000000570: 923C3C84
// s_add_u32 s40, s60, s40 // 000000000574: 8028283C s40 sw_ptr
// s_addc_u32 s41, 0, s41 // 000000000578: 82292980 s41 sw_ptr
// v_and_b32 v54, 15, v0 // 00000000057C: 266C008F
// v_lshlrev_b32 v8, 2, v54 // 000000000580: 24106C82 v8/9 w addr
// v_add_u32 v9, 64, v8 // 000000000584: 681210C0
//GQDQ addr function kkkkkkkkkkkkkk
";----------------------------------------------
\n
"
" v_lshrrev_b32 v54, 4, v0
\n
"
" v_lshlrev_b32 v55, 2, v54
\n
"
" v_and_b32 v54, 15, v0
\n
"
...
...
@@ -55,21 +54,17 @@
" v_add_u32 v55, v54, v55
\n
"
" v_lshlrev_b32 v10, 2, v55
\n
"
" v_add_u32 v11, 0x00000400, v10
\n
"
" s_mul_i32 s60, %[s_wave_id], 16
\n
"
" s_mul_i32 s60, %[s_wave_id], 16
\n
"
" s_mul_i32 s60, s60, 4
\n
"
" v_add_u32 v10, s60, v10
\n
"
" v_add_u32 v11, s60, v11
\n
"
" v_mov_b32 v5, v10
\n
"
//////////////////////////////
";----------------------------------------------
\n
"
" s_mov_b32 s57, 0x00000100
\n
"
" s_mov_b32 s58, 0x00001000
\n
"
" s_mov_b32 s79, 0x00000400
\n
"
" s_mov_b32 s59, 0x00000200
\n
"
////////
//" s_mul_i32 s60, s70, 0x00000100 \n"
//" s_sub_u32 s56, s60, 0x00001000 \n"
///////////////
";----------------------------------------------
\n
"
" s_mov_b32 s78, 0x00001000
\n
"
" s_mov_b32 s52, 0x07060302
\n
"
" s_mov_b32 s53, 0x00000400
\n
"
...
...
@@ -82,7 +77,7 @@
" v_mov_b32 v52, 0x7fff0000
\n
"
" v_mov_b32 v53, 0x00007fff
\n
"
" s_waitcnt 0x0000
\n
"
///XQ ADDR, fake token id
";----------------------------------------------
\n
"
" v_mov_b32 %[v_token_id], %[v_token_id]
\n
"
" v_lshrrev_b32 v54, 24, %[v_token_id]
\n
"
" v_mul_i32_i24 v54, s66, v54
\n
"
...
...
@@ -104,8 +99,7 @@
" buffer_load_dword v21, v9, s[40:43], 0 offen
\n
"
" s_mov_b32 s80, 0
\n
"
//---------------------v26-33 no need
// "s_nop 4\n"
";----------------------------------------------
\n
"
"; -- prefetch A0
\n
"
"s_add_u32 m0, 0, %[s_m0_init]
\n
"
"buffer_load_dword %[v_os_a0], s[20:23], 0 offen lds
\n
"
...
...
@@ -183,18 +177,17 @@
" s_waitcnt vmcnt(40)
\n
"
" s_barrier
\n
"
///////////////////////////////
"ds_read_b128 v[192:195], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_0]
\n
"
// 1024: N stride, 64 K stride
"ds_read_b128 v[196:199], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_1]
\n
"
"ds_read_b128 v[200:203], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_2]
\n
"
"ds_read_b128 v[204:207], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_3]
\n
"
"ds_read_b128 v[208:211], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_4]
\n
"
"ds_read_b128 v[212:215], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_5]
\n
"
"ds_read_b128 v[216:219], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_6]
\n
"
"ds_read_b128 v[220:223], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_7]
\n
"
////////////////////////////
"label_start:
";----------------------------------------------
\n
"
"ds_read_b128 v[192:195], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_0]
\n
"
// 1024: N stride, 64 K stride
"ds_read_b128 v[196:199], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_1]
\n
"
"ds_read_b128 v[200:203], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_2]
\n
"
"ds_read_b128 v[204:207], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_3]
\n
"
"ds_read_b128 v[208:211], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_4]
\n
"
"ds_read_b128 v[212:215], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_5]
\n
"
"ds_read_b128 v[216:219], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_6]
\n
"
"ds_read_b128 v[220:223], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_7]
\n
"
";----------------------------------------------
\n
"
" label_start:
\n
"
" s_waitcnt vmcnt(24) & lgkmcnt(0)
\n
"
" s_barrier
\n
"
_UK_MFMA_
" v[128:131], acc[0:1], v[192:193], v[128:131]
\n
"
...
...
@@ -400,7 +393,7 @@
" s_waitcnt vmcnt(24) & lgkmcnt(0)
\n
"
" s_barrier
\n
"
_UK_MFMA_
" v[128:131], acc[128:129], v[224:225], v[128:131]
\n
"
_UK_MFMA_ "
v
[
128
:
131
],
acc
[
130
:
131
],
v
[
226
:
227
],
v
[
128
:
131
]
\
n
"
_UK_MFMA_
" v[128:131], acc[130:131], v[226:227], v[128:131]
\n
"
" buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[24:27], 0 offen
\n
"
_UK_MFMA_
" v[128:131], acc[132:133], v[228:229], v[128:131]
\n
"
_UK_MFMA_
" v[128:131], acc[134:135], v[230:231], v[128:131]
\n
"
...
...
@@ -461,49 +454,49 @@
" buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[24:27], 0 offen
\n
"
_UK_MFMA_
" v[144:147], acc[164:165], v[228:229], v[144:147]
\n
"
_UK_MFMA_
" v[144:147], acc[166:167], v[230:231], v[144:147]
\n
"
"
ds_read_b128
v
[
192
:
195
],
%
[
v_os_sld
]
offset
:
0
*%
[
smem_sz
]
+
%
[
sld_os_0
]
" ds_read_b128 v[192:195], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_0]
\n
"
_UK_MFMA_
" v[144:147], acc[168:169], v[232:233], v[144:147]
\n
"
_UK_MFMA_
" v[144:147], acc[170:171], v[234:235], v[144:147]
\n
"
" buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[24:27], 0 offen offset:1024
\n
"
_UK_MFMA_
" v[144:147], acc[172:173], v[236:237], v[144:147]
\n
"
_UK_MFMA_
" v[144:147], acc[174:175], v[238:239], v[144:147]
\n
"
" ds_read_b128 v[196:199], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_1]
" ds_read_b128 v[196:199], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_1]
\n
"
_UK_MFMA_
" v[148:151], acc[160:161], v[240:241], v[148:151]
\n
"
_UK_MFMA_
" v[148:151], acc[162:163], v[242:243], v[148:151]
\n
"
" buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[24:27], 0 offen offset:2048
\n
"
_UK_MFMA_
" v[148:151], acc[164:165], v[244:245], v[148:151]
\n
"
_UK_MFMA_
" v[148:151], acc[166:167], v[246:247], v[148:151]
\n
"
"
ds_read_b128
v
[
200
:
203
],
%
[
v_os_sld
]
offset
:
0
*%
[
smem_sz
]
+
%
[
sld_os_2
]
" ds_read_b128 v[200:203], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_2]
\n
"
_UK_MFMA_
" v[148:151], acc[168:169], v[248:249], v[148:151]
\n
"
_UK_MFMA_
" v[148:151], acc[170:171], v[250:251], v[148:151]
\n
"
" buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[24:27], 0 offen offset:3072
\n
"
_UK_MFMA_
" v[148:151], acc[172:173], v[252:253], v[148:151]
\n
"
_UK_MFMA_
" v[148:151], acc[174:175], v[254:255], v[148:151]
\n
"
" ds_read_b128 v[204:207], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_3]
" ds_read_b128 v[204:207], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_3]
\n
"
_UK_MFMA_
" v[152:155], acc[176:177], v[224:225], v[152:155]
\n
"
_UK_MFMA_
" v[152:155], acc[178:179], v[226:227], v[152:155]
\n
"
" buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[24:27], 0 offen
\n
"
_UK_MFMA_
" v[152:155], acc[180:181], v[228:229], v[152:155]
\n
"
_UK_MFMA_
" v[152:155], acc[182:183], v[230:231], v[152:155]
\n
"
"
ds_read_b128
v
[
208
:
211
],
%
[
v_os_sld
]
offset
:
0
*%
[
smem_sz
]
+
%
[
sld_os_4
]
" ds_read_b128 v[208:211], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_4]
\n
"
_UK_MFMA_
" v[152:155], acc[184:185], v[232:233], v[152:155]
\n
"
_UK_MFMA_
" v[152:155], acc[186:187], v[234:235], v[152:155]
\n
"
" buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[24:27], 0 offen offset:1024
\n
"
_UK_MFMA_
" v[152:155], acc[188:189], v[236:237], v[152:155]
\n
"
_UK_MFMA_
" v[152:155], acc[190:191], v[238:239], v[152:155]
\n
"
" ds_read_b128 v[212:215], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_5]
" ds_read_b128 v[212:215], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_5]
\n
"
_UK_MFMA_
" v[156:159], acc[176:177], v[240:241], v[156:159]
\n
"
_UK_MFMA_
" v[156:159], acc[178:179], v[242:243], v[156:159]
\n
"
" buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[24:27], 0 offen offset:2048
\n
"
_UK_MFMA_
" v[156:159], acc[180:181], v[244:245], v[156:159]
\n
"
_UK_MFMA_
" v[156:159], acc[182:183], v[246:247], v[156:159]
\n
"
"
ds_read_b128
v
[
216
:
219
],
%
[
v_os_sld
]
offset
:
0
*%
[
smem_sz
]
+
%
[
sld_os_6
]
" ds_read_b128 v[216:219], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_6]
\n
"
_UK_MFMA_
" v[156:159], acc[184:185], v[248:249], v[156:159]
\n
"
_UK_MFMA_
" v[156:159], acc[186:187], v[250:251], v[156:159]
\n
"
" buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[24:27], 0 offen offset:3072
\n
"
_UK_MFMA_
" v[156:159], acc[188:189], v[252:253], v[156:159]
\n
"
_UK_MFMA_
" v[156:159], acc[190:191], v[254:255], v[156:159]
\n
"
" ds_read_b128 v[220:223], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_7]
" ds_read_b128 v[220:223], %[v_os_sld] offset:0*%[smem_sz] + %[sld_os_7]
\n
"
" s_waitcnt vmcnt(32)
\n
"
_UK_MFMA_
" v[160:163], acc[192:193], v[224:225], v[160:163]
\n
"
_UK_MFMA_
" v[160:163], acc[194:195], v[226:227], v[160:163]
\n
"
...
...
@@ -601,7 +594,7 @@
" s_cbranch_scc0 label_end
\n
"
" s_branch label_start%=
\n
"
" label_end :
\n
"
//dequant
";----------------------------------------------
\n
"
" v_cvt_f32_i32 v128, v128
\n
"
" v_cvt_f32_i32 v129, v129
\n
"
" v_cvt_f32_i32 v130, v130
\n
"
...
...
@@ -794,7 +787,7 @@
" v_mul_f32 v189, v17, v189 row_newbcast:13
\n
"
" v_mul_f32 v190, v17, v190 row_newbcast:14
\n
"
" v_mul_f32 v191, v17, v191 row_newbcast:15
\n
"
#undef _UK_MFMA_
//dequant end
#undef _UK_MFMA_
#undef _DEQUAN_CVT_
include/ck_tile/ops/fused_moe.hpp
View file @
a759277d
...
...
@@ -10,6 +10,7 @@
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk_int8.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp"
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk_int8.hpp
View file @
a759277d
...
...
@@ -198,7 +198,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
//addr in fact
auto
a_coords
=
generate_tuple
(
[
&
](
auto
i
)
{
return
(
token_id
)
*
kargs
.
stride_token
+
return
(
token_id
[
i
]
)
*
kargs
.
stride_token
+
threadIdx
.
x
%
(
BlockShape
::
Block_K0
/
kAlignmentA
)
*
kAlignmentA
;
},
number
<
row_ids_a
.
size
()
>
{});
...
...
@@ -254,7 +254,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
make_tuple
(
shared_intermediate_size_1
),
number
<
1
>
{});
return
g_view_
;
return
g
q
_view_
;
}();
auto
gq_res
=
gq_win
.
get_buffer_view
().
cached_buf_res_
;
...
...
@@ -345,7 +345,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
auto
o_coords
=
generate_tuple
(
[
&
](
auto
i
)
{
return
token_id
*
kargs
.
stride_token
+
return
token_id
[
i
]
*
kargs
.
stride_token
+
threadIdx
.
x
%
(
BlockShape
::
Block_N1
/
kAlignmentO
)
*
kAlignmentO
;
},
number
<
row_ids_a
.
size
()
>
{});
...
...
@@ -376,6 +376,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
row_ids_a
,
//fake token id, 2D index for X scale
aq_res
,
gq_res
,
gq_res
,
dq_res
,
a_res
,
a_coords
,
...
...
include/ck_tile/ops/gemm/warp/warp_gemm.hpp
View file @
a759277d
...
...
@@ -143,7 +143,7 @@ using WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution =
// int8
using
WarpGemmMfma_i32_16x16x64_int8_int8_CTransposed
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
<
WarpGemmAttributeMfmaImpl_i32_16x16x32_i8
<
WGAttrCtlEnum
::
Default_
>
,
WarpGemmAttributeMfmaImpl_i32_16x16x32_i8
<
WGAttrCtlEnum
::
Default_
>
,
2
>>
;
}
// namespace ck_tile
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
View file @
a759277d
...
...
@@ -655,7 +655,7 @@ struct WarpGemmAttributeMfmaImpl_i32_16x16x32_i8
else
{
#if defined(__gfx94__)
c_vec
=
__builtin_amdgcn_mfma_i32_16x16x32i8
(
c_vec
=
__builtin_amdgcn_mfma_i32_16x16x32
_
i8
(
bit_cast
<
long
>
(
a_vec
),
bit_cast
<
long
>
(
b_vec
),
c_vec
,
0
,
0
,
0
);
#elif defined(__gfx908__) || defined(__gfx90a__)
static_for
<
0
,
8
,
1
>
{}([
&
](
auto
k
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment