Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
8af425fc
Commit
8af425fc
authored
Jan 11, 2025
by
so
Browse files
fix v0 and wave id calc
parent
339a674b
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
100 additions
and
92 deletions
+100
-92
example/ck_tile/15_fused_moe/main.cpp
example/ck_tile/15_fused_moe/main.cpp
+1
-1
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_1.inc
...uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_1.inc
+33
-38
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_2.inc
...uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_2.inc
+2
-1
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_3.inc
...uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_3.inc
+1
-2
include/ck_tile/ops/flatmm/block/uk/flatmm_uk_gfx9_32x512x256_1x1x1_16x16x32_int8.inc
...lock/uk/flatmm_uk_gfx9_32x512x256_1x1x1_16x16x32_int8.inc
+5
-0
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk_int8.hpp
...ed_moe/pipeline/fused_moegemm_pipeline_flatmm_uk_int8.hpp
+58
-50
No files found.
example/ck_tile/15_fused_moe/main.cpp
View file @
8af425fc
...
...
@@ -90,7 +90,7 @@ auto create_args(int argc, char* argv[])
arg_parser
.
insert
(
"t"
,
"32"
,
"num input tokens"
)
.
insert
(
"e"
,
"1"
,
"num of experts"
)
.
insert
(
"k"
,
"1"
,
"topk"
)
.
insert
(
"h"
,
"
256
"
,
"hidden_size of this model"
)
.
insert
(
"h"
,
"
512
"
,
"hidden_size of this model"
)
.
insert
(
"i"
,
"4096"
,
"intermediate_size between 2 gemms of FFN"
)
.
insert
(
"stride"
,
"-1"
,
"stride per row, if -1 then equal to hidden_size"
)
.
insert
(
"bm"
,
"32"
,
"blocking factor for sorted tokens"
)
...
...
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_1.inc
View file @
8af425fc
...
...
@@ -27,12 +27,9 @@
# define _UK_ATOMIC_ADD_ "global_atomic_pk_add_f16"
#endif
" v_and_b32 v0, 0x3f, v0
\n
"
" v_lshrrev_b32 v3, 6, v0
\n
"
" v_readfirstlane_b32 s7, v3
\n
"
" s_waitcnt vmcnt(24)
\n
"
"buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[12:15], 0 offen
\n
"
"buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[12:15], 0 offen offset:1024
\n
"
"
buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[12:15], 0 offen
\n
"
"
buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[12:15], 0 offen offset:1024
\n
"
" v_mul_f32 v54, v128, v128
\n
"
" v_mul_f32 v55, v129, v129
\n
"
" v_mul_f32 v56, v130, v130
\n
"
...
...
@@ -65,7 +62,7 @@
" v_mul_f32 v129, v129, v55
\n
"
" v_mul_f32 v130, v130, v56
\n
"
" v_mul_f32 v131, v131, v57
\n
"
"buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[12:15], 0 offen offset:2048
\n
"
"
buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[12:15], 0 offen offset:2048
\n
"
" v_mul_f32 v54, v132, v132
\n
"
" v_mul_f32 v55, v133, v133
\n
"
" v_mul_f32 v56, v134, v134
\n
"
...
...
@@ -86,7 +83,7 @@
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v57, v57
\n
"
"buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[12:15], 0 offen offset:3072
\n
"
"
buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[12:15], 0 offen offset:3072
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
...
...
@@ -99,7 +96,7 @@
" v_mul_f32 v133, v133, v55
\n
"
" v_mul_f32 v134, v134, v56
\n
"
" v_mul_f32 v135, v135, v57
\n
"
"buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[12:15], 0 offen
\n
"
"
buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[12:15], 0 offen
\n
"
" v_mul_f32 v54, v136, v136
\n
"
" v_mul_f32 v55, v137, v137
\n
"
" v_mul_f32 v56, v138, v138
\n
"
...
...
@@ -120,7 +117,7 @@
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v57, v57
\n
"
"buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[12:15], 0 offen offset:1024
\n
"
"
buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[12:15], 0 offen offset:1024
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
...
...
@@ -133,7 +130,7 @@
" v_mul_f32 v137, v137, v55
\n
"
" v_mul_f32 v138, v138, v56
\n
"
" v_mul_f32 v139, v139, v57
\n
"
"buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[12:15], 0 offen offset:2048
\n
"
"
buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[12:15], 0 offen offset:2048
\n
"
" v_mul_f32 v54, v140, v140
\n
"
" v_mul_f32 v55, v141, v141
\n
"
" v_mul_f32 v56, v142, v142
\n
"
...
...
@@ -154,7 +151,7 @@
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v57, v57
\n
"
"buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[12:15], 0 offen offset:3072
\n
"
"
buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[12:15], 0 offen offset:3072
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
...
...
@@ -168,7 +165,7 @@
" v_mul_f32 v142, v142, v56
\n
"
" v_mul_f32 v143, v143, v57
\n
"
" s_waitcnt vmcnt(24)
\n
"
"buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[12:15], 0 offen
\n
"
"
buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[12:15], 0 offen
\n
"
" v_mul_f32 v54, v144, v144
\n
"
" v_mul_f32 v55, v145, v145
\n
"
" v_mul_f32 v56, v146, v146
\n
"
...
...
@@ -189,7 +186,7 @@
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v57, v57
\n
"
"buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[12:15], 0 offen offset:1024
\n
"
"
buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[12:15], 0 offen offset:1024
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
...
...
@@ -202,7 +199,7 @@
" v_mul_f32 v145, v145, v55
\n
"
" v_mul_f32 v146, v146, v56
\n
"
" v_mul_f32 v147, v147, v57
\n
"
"buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[12:15], 0 offen offset:2048
\n
"
"
buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[12:15], 0 offen offset:2048
\n
"
" v_mul_f32 v54, v148, v148
\n
"
" v_mul_f32 v55, v149, v149
\n
"
" v_mul_f32 v56, v150, v150
\n
"
...
...
@@ -223,7 +220,7 @@
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v57, v57
\n
"
"buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[12:15], 0 offen offset:3072
\n
"
"
buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[12:15], 0 offen offset:3072
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
...
...
@@ -236,7 +233,7 @@
" v_mul_f32 v149, v149, v55
\n
"
" v_mul_f32 v150, v150, v56
\n
"
" v_mul_f32 v151, v151, v57
\n
"
"buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[12:15], 0 offen
\n
"
"
buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[12:15], 0 offen
\n
"
" v_mul_f32 v54, v152, v152
\n
"
" v_mul_f32 v55, v153, v153
\n
"
" v_mul_f32 v56, v154, v154
\n
"
...
...
@@ -257,7 +254,7 @@
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v57, v57
\n
"
"buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[12:15], 0 offen offset:1024
\n
"
"
buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[12:15], 0 offen offset:1024
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
...
...
@@ -270,7 +267,7 @@
" v_mul_f32 v153, v153, v55
\n
"
" v_mul_f32 v154, v154, v56
\n
"
" v_mul_f32 v155, v155, v57
\n
"
"buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[12:15], 0 offen offset:2048
\n
"
"
buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[12:15], 0 offen offset:2048
\n
"
" v_mul_f32 v54, v156, v156
\n
"
" v_mul_f32 v55, v157, v157
\n
"
" v_mul_f32 v56, v158, v158
\n
"
...
...
@@ -291,7 +288,7 @@
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v57, v57
\n
"
"buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[12:15], 0 offen offset:3072
\n
"
"
buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[12:15], 0 offen offset:3072
\n
"
" s_add_u32 s12, %[s_tile_os_b_half], s12
\n
"
" s_addc_u32 s13, 0, s13
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
...
...
@@ -307,7 +304,7 @@
" v_mul_f32 v158, v158, v56
\n
"
" v_mul_f32 v159, v159, v57
\n
"
" s_waitcnt vmcnt(24)
\n
"
"buffer_load_dwordx4 acc[64:67], %[v_os_b0], s[12:15], 0 offen
\n
"
"
buffer_load_dwordx4 acc[64:67], %[v_os_b0], s[12:15], 0 offen
\n
"
" v_mul_f32 v54, v160, v160
\n
"
" v_mul_f32 v55, v161, v161
\n
"
" v_mul_f32 v56, v162, v162
\n
"
...
...
@@ -328,7 +325,7 @@
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v57, v57
\n
"
"buffer_load_dwordx4 acc[68:71], %[v_os_b0], s[12:15], 0 offen offset:1024
\n
"
"
buffer_load_dwordx4 acc[68:71], %[v_os_b0], s[12:15], 0 offen offset:1024
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
...
...
@@ -341,7 +338,7 @@
" v_mul_f32 v161, v161, v55
\n
"
" v_mul_f32 v162, v162, v56
\n
"
" v_mul_f32 v163, v163, v57
\n
"
"buffer_load_dwordx4 acc[72:75], %[v_os_b0], s[12:15], 0 offen offset:2048
\n
"
"
buffer_load_dwordx4 acc[72:75], %[v_os_b0], s[12:15], 0 offen offset:2048
\n
"
" v_mul_f32 v54, v164, v164
\n
"
" v_mul_f32 v55, v165, v165
\n
"
" v_mul_f32 v56, v166, v166
\n
"
...
...
@@ -362,7 +359,7 @@
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v57, v57
\n
"
"buffer_load_dwordx4 acc[76:79], %[v_os_b0], s[12:15], 0 offen offset:3072
\n
"
"
buffer_load_dwordx4 acc[76:79], %[v_os_b0], s[12:15], 0 offen offset:3072
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
...
...
@@ -375,7 +372,7 @@
" v_mul_f32 v165, v165, v55
\n
"
" v_mul_f32 v166, v166, v56
\n
"
" v_mul_f32 v167, v167, v57
\n
"
"buffer_load_dwordx4 acc[80:83], %[v_os_b1], s[12:15], 0 offen
\n
"
"
buffer_load_dwordx4 acc[80:83], %[v_os_b1], s[12:15], 0 offen
\n
"
" v_mul_f32 v54, v168, v168
\n
"
" v_mul_f32 v55, v169, v169
\n
"
" v_mul_f32 v56, v170, v170
\n
"
...
...
@@ -396,7 +393,7 @@
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v57, v57
\n
"
"buffer_load_dwordx4 acc[84:87], %[v_os_b1], s[12:15], 0 offen offset:1024
\n
"
"
buffer_load_dwordx4 acc[84:87], %[v_os_b1], s[12:15], 0 offen offset:1024
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
...
...
@@ -409,7 +406,7 @@
" v_mul_f32 v169, v169, v55
\n
"
" v_mul_f32 v170, v170, v56
\n
"
" v_mul_f32 v171, v171, v57
\n
"
"buffer_load_dwordx4 acc[88:91], %[v_os_b1], s[12:15], 0 offen offset:2048
\n
"
"
buffer_load_dwordx4 acc[88:91], %[v_os_b1], s[12:15], 0 offen offset:2048
\n
"
" v_mul_f32 v54, v172, v172
\n
"
" v_mul_f32 v55, v173, v173
\n
"
" v_mul_f32 v56, v174, v174
\n
"
...
...
@@ -430,7 +427,7 @@
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v57, v57
\n
"
"buffer_load_dwordx4 acc[92:95], %[v_os_b1], s[12:15], 0 offen offset:3072
\n
"
"
buffer_load_dwordx4 acc[92:95], %[v_os_b1], s[12:15], 0 offen offset:3072
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
...
...
@@ -444,7 +441,7 @@
" v_mul_f32 v174, v174, v56
\n
"
" v_mul_f32 v175, v175, v57
\n
"
" s_waitcnt vmcnt(24)
\n
"
"buffer_load_dwordx4 acc[96:99], %[v_os_b2], s[12:15], 0 offen
\n
"
"
buffer_load_dwordx4 acc[96:99], %[v_os_b2], s[12:15], 0 offen
\n
"
" v_mul_f32 v54, v176, v176
\n
"
" v_mul_f32 v55, v177, v177
\n
"
" v_mul_f32 v56, v178, v178
\n
"
...
...
@@ -465,7 +462,7 @@
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v57, v57
\n
"
"buffer_load_dwordx4 acc[100:103], %[v_os_b2], s[12:15], 0 offen offset:1024
\n
"
"
buffer_load_dwordx4 acc[100:103], %[v_os_b2], s[12:15], 0 offen offset:1024
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
...
...
@@ -478,7 +475,7 @@
" v_mul_f32 v177, v177, v55
\n
"
" v_mul_f32 v178, v178, v56
\n
"
" v_mul_f32 v179, v179, v57
\n
"
"buffer_load_dwordx4 acc[104:107], %[v_os_b2], s[12:15], 0 offen offset:2048
\n
"
"
buffer_load_dwordx4 acc[104:107], %[v_os_b2], s[12:15], 0 offen offset:2048
\n
"
" v_mul_f32 v54, v180, v180
\n
"
" v_mul_f32 v55, v181, v181
\n
"
" v_mul_f32 v56, v182, v182
\n
"
...
...
@@ -499,7 +496,7 @@
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v57, v57
\n
"
"buffer_load_dwordx4 acc[108:111], %[v_os_b2], s[12:15], 0 offen offset:3072
\n
"
"
buffer_load_dwordx4 acc[108:111], %[v_os_b2], s[12:15], 0 offen offset:3072
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
...
...
@@ -512,7 +509,7 @@
" v_mul_f32 v181, v181, v55
\n
"
" v_mul_f32 v182, v182, v56
\n
"
" v_mul_f32 v183, v183, v57
\n
"
"buffer_load_dwordx4 acc[112:115], %[v_os_b3], s[12:15], 0 offen
\n
"
"
buffer_load_dwordx4 acc[112:115], %[v_os_b3], s[12:15], 0 offen
\n
"
" v_mul_f32 v54, v184, v184
\n
"
" v_mul_f32 v55, v185, v185
\n
"
" v_mul_f32 v56, v186, v186
\n
"
...
...
@@ -533,7 +530,7 @@
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v57, v57
\n
"
"buffer_load_dwordx4 acc[116:119], %[v_os_b3], s[12:15], 0 offen offset:1024
\n
"
"
buffer_load_dwordx4 acc[116:119], %[v_os_b3], s[12:15], 0 offen offset:1024
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
...
...
@@ -546,7 +543,7 @@
" v_mul_f32 v185, v185, v55
\n
"
" v_mul_f32 v186, v186, v56
\n
"
" v_mul_f32 v187, v187, v57
\n
"
"buffer_load_dwordx4 acc[120:123], %[v_os_b3], s[12:15], 0 offen offset:2048
\n
"
"
buffer_load_dwordx4 acc[120:123], %[v_os_b3], s[12:15], 0 offen offset:2048
\n
"
" v_mul_f32 v54, v188, v188
\n
"
" v_mul_f32 v55, v189, v189
\n
"
" v_mul_f32 v56, v190, v190
\n
"
...
...
@@ -567,7 +564,7 @@
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v57, v57
\n
"
"buffer_load_dwordx4 acc[124:127], %[v_os_b3], s[12:15], 0 offen offset:3072
\n
"
"
buffer_load_dwordx4 acc[124:127], %[v_os_b3], s[12:15], 0 offen offset:3072
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
...
...
@@ -644,7 +641,7 @@
" v_mul_f32 v189, %[smq_scale1], v189 row_newbcast:13
\n
"
" v_mul_f32 v190, %[smq_scale1], v190 row_newbcast:14
\n
"
" v_mul_f32 v191, %[smq_scale1], v191 row_newbcast:15
\n
"
"
;--
buffer_load_dword v12, %[v_os_dq], s[16:19], 0 offen
\n
"
"
buffer_load_dword v12, %[v_os_dq], s[16:19], 0 offen
\n
"
" v_mov_b32 v22, 0x358637bd
\n
"
" v_mov_b32 v23, 0x358637bd
\n
"
" v_max3_f32 v22, abs(v128), abs(v129), v22
\n
"
...
...
@@ -974,5 +971,3 @@
#undef _UK_PK_CVT_
#undef _UK_ATOMIC_ADD_
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_2.inc
View file @
8af425fc
...
...
@@ -166,7 +166,7 @@
" buffer_load_dwordx4 acc[224:227], %[v_os_b2], s[12:15], 0 offen
\n
"
" v_mfma_i32_16x16x32_i8 v[208:211], acc[100:101], v[148:149], v[208:211]
\n
"
" v_mfma_i32_16x16x32_i8 v[208:211], acc[102:103], v[150:151], v[208:211]
\n
"
";--
-
buffer_load_dword v13, %[v_os_dq], s[16:19], 0 offen
\n
"
";-- buffer_load_dword v13, %[v_os_dq], s[16:19], 0 offen
\n
"
" v_mfma_i32_16x16x32_i8 v[208:211], acc[104:105], v[152:153], v[208:211]
\n
"
" v_mfma_i32_16x16x32_i8 v[208:211], acc[106:107], v[154:155], v[208:211]
\n
"
" buffer_load_dwordx4 acc[228:231], %[v_os_b2], s[12:15], 0 offen offset:1024
\n
"
...
...
@@ -483,3 +483,4 @@
#undef _UK_PK_CVT_
#undef _UK_ATOMIC_ADD_
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_3.inc
View file @
8af425fc
...
...
@@ -214,7 +214,7 @@
" buffer_load_dwordx4 acc[96:99], %[v_os_b2], s[12:15], 0 offen
\n
"
" v_mfma_i32_16x16x32_i8 v[240:243], acc[228:229], v[148:149], v[240:243]
\n
"
" v_mfma_i32_16x16x32_i8 v[240:243], acc[230:231], v[150:151], v[240:243]
\n
"
"
;--
buffer_load_dword v12, %[v_os_dq], s[16:19], 0 offen
\n
"
"
buffer_load_dword v12, %[v_os_dq], s[16:19], 0 offen
\n
"
" v_mfma_i32_16x16x32_i8 v[240:243], acc[232:233], v[152:153], v[240:243]
\n
"
" v_mfma_i32_16x16x32_i8 v[240:243], acc[234:235], v[154:155], v[240:243]
\n
"
" buffer_load_dwordx4 acc[100:103], %[v_os_b2], s[12:15], 0 offen offset:1024
\n
"
...
...
@@ -587,4 +587,3 @@
#undef _UK_PK_CVT_
#undef _UK_ATOMIC_ADD_
include/ck_tile/ops/flatmm/block/uk/flatmm_uk_gfx9_32x512x256_1x1x1_16x16x32_int8.inc
View file @
8af425fc
...
...
@@ -54,6 +54,10 @@
" v_mov_b32 v53, 0x00007fff
\n
"
" s_waitcnt 0x0000
\n
"
" s_mov_b32 s80, 0
\n
"
" v_lshrrev_b32 v3, 6, v0
\n
"
" v_and_b32 v0, 0x3f, v0
\n
"
" v_readfirstlane_b32 s7, v3
\n
"
";--ds write v3 gen"
" v_lshrrev_b32 v54, 4, v0
\n
"
" v_mul_i32_i24 v3, 34, v54
\n
"
" v_and_b32 v54, 15, v0
\n
"
...
...
@@ -62,6 +66,7 @@
" s_mul_i32 s60, s7, 0x00000088
\n
"
" v_add_u32 v3, s60, v3
\n
"
" v_lshlrev_b32 v3, 2, v3
\n
"
";--ds read v4 gen
\n
"
" v_lshrrev_b32 v54, 1, v0
\n
"
" v_mul_i32_i24 v4, 34, v54
\n
"
" v_and_b32 v55, 1, v0
\n
"
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk_int8.hpp
View file @
8af425fc
...
...
@@ -105,9 +105,9 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
// TODO: properlly support scatter/gather for load only
CK_TILE_DEVICE
auto
GetRowCoords_A
(
index_t
base_offset
)
{
constexpr
index_t
KLans
=
BlockShape
::
Block_K0
/
kAlignmentA
;
constexpr
index_t
MLans
=
BlockShape
::
BlockSize
/
KLans
;
constexpr
index_t
MRepeat
=
BlockShape
::
Block_M0
/
MLans
;
constexpr
index_t
KLans
=
BlockShape
::
Block_K0
/
kAlignmentA
;
//64
constexpr
index_t
MLans
=
BlockShape
::
BlockSize
/
KLans
;
//4
constexpr
index_t
MRepeat
=
BlockShape
::
Block_M0
/
MLans
;
//8
auto
base_coord
=
threadIdx
.
x
/
KLans
+
base_offset
;
...
...
@@ -424,50 +424,58 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
gqsmq_coords
,
(
reinterpret_cast
<
const
GScaleDataType
*>
(
kargs
.
g_scale_ptr
)
+
static_cast
<
long_index_t
>
(
expert_id
)
*
shared_intermediate_size_0
));
auto
smq_scale
=
GetSMQScale
(
gqsmq_coords
,
(
reinterpret_cast
<
const
YSmoothScaleDataType
*>
(
kargs
.
y_smooth_scale_ptr
)
+
static_cast
<
long_index_t
>
(
expert_id
)
*
shared_intermediate_size_0
));
if
(
threadIdx
.
x
==
95
&&
blockIdx
.
x
==
0
&&
blockIdx
.
y
==
0
)
{
printf
(
"
\n
blockIdx.x :%x, blockIdx.y :%x, d ptr: %p, wg d ptr :%x%x,gemm0 done
\n
"
,
blockIdx
.
x
,
blockIdx
.
y
,
kargs
.
d_ptr
,
d_res
[
1
],
d_res
[
0
]);
// // printf("\n wg 1 1, wave 1, row_coords_a 0: %x 1: %x, 2: %x, 3:%x, 5: %x 6: %x, 7: %x, 8:%x,, \n", row_coords_a[number<0>{}],row_coords_a[number<1>{}],row_coords_a[number<2>{}],row_coords_a[number<3>{}], row_coords_a[number<4>{}],row_coords_a[number<5>{}],row_coords_a[number<6>{}],row_coords_a[number<7>{}]);
// // printf("\n -------------- -row_ids_a 0: %x 1: %x, 2: %x, 3:%x, 5: %x 6: %x, 7: %x, 8:%x,, \n", row_ids_a[number<0>{}],row_ids_a[number<1>{}],row_ids_a[number<2>{}],row_ids_a[number<3>{}], row_ids_a[number<4>{}],row_ids_a[number<5>{}],row_ids_a[number<6>{}],row_ids_a[number<7>{}]);
// // printf("\n -----------thread id %x--- - token_id 0: %x 1: %x, 2: %x, 3:%x, 5: %x 6: %x, 7: %x, 8:%x,, \n", hipThreadIdx_x , token_id[number<0>{}],token_id[number<1>{}],token_id[number<2>{}],token_id[number<3>{}], token_id[number<4>{}],token_id[number<5>{}],token_id[number<6>{}],token_id[number<7>{}]);
// // printf("\n -----------thread id %x--- - token_id , 7:%x,, \n", hipThreadIdx_x , token_id[number<7>{}]);
// // printf("\n -------------- - exec 0: %x 1: %x, 2: %x, 3:%x, 5: %x 6: %x, 7: %x, 8:%x,, \n", o_flags[number<0>{}][0],o_flags[number<1>{}][0],o_flags[number<2>{}][0],o_flags[number<3>{}][0], o_flags[number<4>{}][0],o_flags[number<5>{}][0],o_flags[number<6>{}][0],o_flags[number<7>{}][0]);
printf
(
"
\n
token id :%x,%x,%x,%x, %x,%x,%x,%x
\n
d_coords: %x,%x,%x,%x,
\n
row_idx: %x,%x,%x,%x, %x,%x,%x,%x
\n
o_flags:%x,%x,%x,%x, %x,%x,%x,%x
\n
"
,
token_id
[
number
<
0
>
{}],
token_id
[
number
<
1
>
{}],
token_id
[
number
<
2
>
{}],
token_id
[
number
<
3
>
{}],
token_id
[
number
<
4
>
{}],
token_id
[
number
<
5
>
{}],
token_id
[
number
<
6
>
{}],
token_id
[
number
<
7
>
{}],
d_coords
[
number
<
0
>
{}],
d_coords
[
number
<
1
>
{}],
d_coords
[
number
<
2
>
{}],
d_coords
[
number
<
3
>
{}],
// d_coords[number<4>{}],
// d_coords[number<5>{}],
// d_coords[number<6>{}],
// d_coords[number<7>{}],
row_ids_a
[
number
<
0
>
{}],
row_ids_a
[
number
<
1
>
{}],
row_ids_a
[
number
<
2
>
{}],
row_ids_a
[
number
<
3
>
{}],
row_ids_a
[
number
<
4
>
{}],
row_ids_a
[
number
<
5
>
{}],
row_ids_a
[
number
<
6
>
{}],
row_ids_a
[
number
<
7
>
{}],
o_flags
[
number
<
0
>
{}][
0
],
o_flags
[
number
<
1
>
{}][
0
],
o_flags
[
number
<
2
>
{}][
0
],
o_flags
[
number
<
3
>
{}][
0
],
o_flags
[
number
<
4
>
{}][
0
],
o_flags
[
number
<
5
>
{}][
0
],
o_flags
[
number
<
6
>
{}][
0
],
o_flags
[
number
<
7
>
{}][
0
]);
// return;
}
__builtin_amdgcn_sched_barrier
(
0
);
// if(threadIdx.x == 255 && blockIdx.x == 0 && blockIdx.y == 0)
// {
// printf("\nblockIdx.x :%x, blockIdx.y :%x, d ptr: %p, o ptr: %p, wg d ptr :%x%x,gemm0 done\n", blockIdx.x, blockIdx.y,kargs.d_ptr, kargs.o_ptr,d_res[1],d_res[0]);
// // // printf("\n wg 1 1, wave 1, row_coords_a 0: %x 1: %x, 2: %x, 3:%x, 5: %x 6: %x, 7: %x, 8:%x,, \n", row_coords_a[number<0>{}],row_coords_a[number<1>{}],row_coords_a[number<2>{}],row_coords_a[number<3>{}], row_coords_a[number<4>{}],row_coords_a[number<5>{}],row_coords_a[number<6>{}],row_coords_a[number<7>{}]);
// // // printf("\n -------------- -row_ids_a 0: %x 1: %x, 2: %x, 3:%x, 5: %x 6: %x, 7: %x, 8:%x,, \n", row_ids_a[number<0>{}],row_ids_a[number<1>{}],row_ids_a[number<2>{}],row_ids_a[number<3>{}], row_ids_a[number<4>{}],row_ids_a[number<5>{}],row_ids_a[number<6>{}],row_ids_a[number<7>{}]);
// // // printf("\n -----------thread id %x--- - token_id 0: %x 1: %x, 2: %x, 3:%x, 5: %x 6: %x, 7: %x, 8:%x,, \n", hipThreadIdx_x , token_id[number<0>{}],token_id[number<1>{}],token_id[number<2>{}],token_id[number<3>{}], token_id[number<4>{}],token_id[number<5>{}],token_id[number<6>{}],token_id[number<7>{}]);
// // // printf("\n -----------thread id %x--- - token_id , 7:%x,, \n", hipThreadIdx_x , token_id[number<7>{}]);
// // // printf("\n -------------- - exec 0: %x 1: %x, 2: %x, 3:%x, 5: %x 6: %x, 7: %x, 8:%x,, \n", o_flags[number<0>{}][0],o_flags[number<1>{}][0],o_flags[number<2>{}][0],o_flags[number<3>{}][0], o_flags[number<4>{}][0],o_flags[number<5>{}][0],o_flags[number<6>{}][0],o_flags[number<7>{}][0]);
// printf("\ntoken_id :%x,%x,%x,%x, %x,%x,%x,%x \no_coords :%x,%x,%x,%x, %x,%x,%x,%x \n d_coords: %x,%x,%x,%x, \n row_idx: %x,%x,%x,%x, %x,%x,%x,%x \n o_flags:%x,%x,%x,%x, %x,%x,%x,%x \n",
// token_id[number<0>{}],
// token_id[number<1>{}],
// token_id[number<2>{}],
// token_id[number<3>{}],
// token_id[number<4>{}],
// token_id[number<5>{}],
// token_id[number<6>{}],
// token_id[number<7>{}],
// o_coords[number<0>{}],
// o_coords[number<1>{}],
// o_coords[number<2>{}],
// o_coords[number<3>{}],
// o_coords[number<4>{}],
// o_coords[number<5>{}],
// o_coords[number<6>{}],
// o_coords[number<7>{}],
// d_coords[number<0>{}],
// d_coords[number<1>{}],
// d_coords[number<2>{}],
// d_coords[number<3>{}],
// // d_coords[number<4>{}],
// // d_coords[number<5>{}],
// // d_coords[number<6>{}],
// // d_coords[number<7>{}],
// row_ids_a[number<0>{}],
// row_ids_a[number<1>{}],
// row_ids_a[number<2>{}],
// row_ids_a[number<3>{}],
// row_ids_a[number<4>{}],
// row_ids_a[number<5>{}],
// row_ids_a[number<6>{}],
// row_ids_a[number<7>{}],
// o_flags[number<0>{}][0],
// o_flags[number<1>{}][0],
// o_flags[number<2>{}][0],
// o_flags[number<3>{}][0],
// o_flags[number<4>{}][0],
// o_flags[number<5>{}][0],
// o_flags[number<6>{}][0],
// o_flags[number<7>{}][0]);
// // return;
// }
// __builtin_amdgcn_sched_barrier(0);
auto
uk_0
=
Policy
::
template
GetUK_0
<
Problem
>();
// auto acc_0= uk_0(
...
...
@@ -483,10 +491,10 @@ if(threadIdx.x == 95 && blockIdx.x == 0 && blockIdx.y == 0)
kargs
.
hidden_size
,
BlockShape
::
Block_K0
,
// tile offset for B matrix each unroll
16
*
256
,
kargs
.
num_tokens
*
kargs
.
stride_token
);
// tile offset for B matrix each unroll
__builtin_amdgcn_readfirstlane
(
kargs
.
num_tokens
*
kargs
.
stride_token
)
)
;
// tile offset for B matrix each unroll
// return;
__builtin_amdgcn_sched_barrier
(
0
);
// return;
// // sweep_tile(
// acc_0,
// [&](auto idx0, auto idx1) {
...
...
@@ -515,7 +523,7 @@ if(threadIdx.x == 95 && blockIdx.x == 0 && blockIdx.y == 0)
o_coords
,
o_flags
,
smem
,
kargs
.
hidden_size
,
// total n number
__builtin_amdgcn_readfirstlane
(
kargs
.
hidden_size
)
,
// total n number
w_scale
,
smq_scale
,
BlockShape
::
Block_N1
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment