Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
2a66e080
Commit
2a66e080
authored
Jan 06, 2025
by
shengnxu
Browse files
fix some issue, next step, res recalc,
parent
7cc808f2
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
60 additions
and
44 deletions
+60
-44
include/ck_tile/ops/flatmm/block/flatmm_32x512x256_1x4x1_16x16x64_int8.hpp
...ps/flatmm/block/flatmm_32x512x256_1x4x1_16x16x64_int8.hpp
+7
-3
include/ck_tile/ops/flatmm/block/flatmm_sn_32x256x512_1x4x1_16x16x64_int8.hpp
...flatmm/block/flatmm_sn_32x256x512_1x4x1_16x16x64_int8.hpp
+8
-4
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_2.inc
...uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_2.inc
+34
-34
include/ck_tile/ops/flatmm/block/uk/flatmm_uk_gfx9_32x512x256_1x1x1_16x16x32_int8.inc
...lock/uk/flatmm_uk_gfx9_32x512x256_1x1x1_16x16x32_int8.inc
+3
-0
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk_int8.hpp
...ed_moe/pipeline/fused_moegemm_pipeline_flatmm_uk_int8.hpp
+8
-3
No files found.
include/ck_tile/ops/flatmm/block/flatmm_32x512x256_1x4x1_16x16x64_int8.hpp
View file @
2a66e080
...
@@ -295,6 +295,11 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
...
@@ -295,6 +295,11 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
const
auto
[
m0_init_value
,
size_per_issue
]
=
get_async_store_smem_info
(
a_sst
);
const
auto
[
m0_init_value
,
size_per_issue
]
=
get_async_store_smem_info
(
a_sst
);
constexpr
auto
smem_buf_size
=
constexpr
auto
smem_buf_size
=
MakeLdsLoadDesc_A
().
get_element_space_size
()
*
sizeof
(
ADataType
);
MakeLdsLoadDesc_A
().
get_element_space_size
()
*
sizeof
(
ADataType
);
// if(threadIdx.x%64 == 0 ){
// printf("wave id:%d, m0_init_value:%d, size_per_issue:%d\n",
// int(threadIdx.x/64),int(m0_init_value), int(size_per_issue));
// }
static_assert
(
a_sld
.
get_num_of_access
()
==
8
);
static_assert
(
a_sld
.
get_num_of_access
()
==
8
);
constexpr
auto
sld_os
=
generate_tuple
(
constexpr
auto
sld_os
=
generate_tuple
(
[
&
](
auto
i_access
)
{
[
&
](
auto
i_access
)
{
...
@@ -533,9 +538,8 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
...
@@ -533,9 +538,8 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
"s55"
,
"s56"
,
"s57"
,
"s58"
,
"s59"
,
"s60"
,
"s61"
,
"s62"
,
"s63"
,
"s55"
,
"s56"
,
"s57"
,
"s58"
,
"s59"
,
"s60"
,
"s61"
,
"s62"
,
"s63"
,
"s64"
,
"s65"
,
"s66"
,
"s67"
,
"s68"
,
"s69"
,
"s70"
,
"s71"
,
"s72"
,
"s64"
,
"s65"
,
"s66"
,
"s67"
,
"s68"
,
"s69"
,
"s70"
,
"s71"
,
"s72"
,
"s73"
,
"s74"
,
"s75"
,
"s76"
,
"s77"
,
"s78"
,
"s79"
,
"s80"
,
// s86 as tmp
"s73"
,
"s74"
,
"s75"
,
"s76"
,
"s77"
,
"s78"
,
"s79"
,
"s80"
,
// s86 as tmp
"v1"
,
"v2"
,
"v3"
,
"v4"
,
"v5"
,
"v6"
,
"v7"
,
"v8"
,
"v9"
,
"v10"
,
"v1"
,
"v2"
,
"v3"
,
"v4"
,
"v5"
,
"v12"
,
"v13"
,
"v21"
,
"v22"
,
"v23"
,
"v11"
,
"v12"
,
"v13"
,
"v14"
,
"v15"
,
"v16"
,
"v17"
,
"v18"
,
"v19"
,
"v24"
,
"v25"
,
"v50"
,
"v51"
,
"v52"
,
"v53"
,
"v54"
,
"v55"
,
"v20"
,
"v21"
,
"v22"
,
"v23"
,
"v24"
,
"v25"
,
"v50"
,
"v51"
,
"v52"
,
"v53"
,
"v54"
,
"v55"
,
"v56"
,
"v57"
,
"v64"
,
"v56"
,
"v57"
,
"v64"
,
"v65"
,
"v66"
,
"v67"
,
"v68"
,
"v69"
,
"v70"
,
"v71"
,
"v72"
,
"v73"
,
"v65"
,
"v66"
,
"v67"
,
"v68"
,
"v69"
,
"v70"
,
"v71"
,
"v72"
,
"v73"
,
"v74"
,
"v75"
,
"v76"
,
"v77"
,
"v78"
,
"v79"
,
"v80"
,
"v81"
,
"v82"
,
"v74"
,
"v75"
,
"v76"
,
"v77"
,
"v78"
,
"v79"
,
"v80"
,
"v81"
,
"v82"
,
...
...
include/ck_tile/ops/flatmm/block/flatmm_sn_32x256x512_1x4x1_16x16x64_int8.hpp
View file @
2a66e080
...
@@ -233,8 +233,7 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
...
@@ -233,8 +233,7 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
"s55"
,
"s56"
,
"s57"
,
"s58"
,
"s59"
,
"s60"
,
"s61"
,
"s62"
,
"s63"
,
"s55"
,
"s56"
,
"s57"
,
"s58"
,
"s59"
,
"s60"
,
"s61"
,
"s62"
,
"s63"
,
"s64"
,
"s65"
,
"s66"
,
"s67"
,
"s68"
,
"s69"
,
"s70"
,
"s71"
,
"s72"
,
"s64"
,
"s65"
,
"s66"
,
"s67"
,
"s68"
,
"s69"
,
"s70"
,
"s71"
,
"s72"
,
"s73"
,
"s74"
,
"s75"
,
"s76"
,
"s77"
,
"s78"
,
"s79"
,
"s80"
,
// s86 as tmp
"s73"
,
"s74"
,
"s75"
,
"s76"
,
"s77"
,
"s78"
,
"s79"
,
"s80"
,
// s86 as tmp
"v1"
,
"v2"
,
"v3"
,
"v4"
,
"v12"
,
"v13"
,
"v14"
,
"v15"
,
"v16"
,
"v17"
,
"v18"
,
"v19"
,
"v1"
,
"v2"
,
"v3"
,
"v4"
,
"v5"
,
"v12"
,
"v13"
,
"v21"
,
"v22"
,
"v23"
,
"v24"
,
"v25"
,
"v50"
,
"v51"
,
"v52"
,
"v53"
,
"v54"
,
"v55"
,
"v20"
,
"v21"
,
"v22"
,
"v23"
,
"v24"
,
"v25"
,
"v50"
,
"v51"
,
"v52"
,
"v53"
,
"v54"
,
"v55"
,
"v56"
,
"v57"
,
"v64"
,
"v56"
,
"v57"
,
"v64"
,
"v65"
,
"v66"
,
"v67"
,
"v68"
,
"v69"
,
"v70"
,
"v71"
,
"v72"
,
"v73"
,
"v65"
,
"v66"
,
"v67"
,
"v68"
,
"v69"
,
"v70"
,
"v71"
,
"v72"
,
"v73"
,
"v74"
,
"v75"
,
"v76"
,
"v77"
,
"v78"
,
"v79"
,
"v80"
,
"v81"
,
"v82"
,
"v74"
,
"v75"
,
"v76"
,
"v77"
,
"v78"
,
"v79"
,
"v80"
,
"v81"
,
"v82"
,
...
@@ -261,6 +260,12 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
...
@@ -261,6 +260,12 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
"v245"
,
"v246"
,
"v247"
,
"v248"
,
"v249"
,
"v250"
,
"v251"
,
"v252"
,
"v245"
,
"v246"
,
"v247"
,
"v248"
,
"v249"
,
"v250"
,
"v251"
,
"v252"
,
"v253"
,
"v254"
,
"v255"
"v253"
,
"v254"
,
"v255"
);
);
if
(
hipBlockIdx_x
==
0
&&
hipBlockIdx_y
==
0
&&
hipBlockIdx_z
==
0
&&
hipThreadIdx_x
==
5
)
{
printf
(
"
\n
sn0 done
\n
"
);
}
asm
volatile
(
asm
volatile
(
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_INT8
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_INT8
#include "uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_2.inc"
#include "uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_2.inc"
...
@@ -335,8 +340,7 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
...
@@ -335,8 +340,7 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
"s55"
,
"s56"
,
"s57"
,
"s58"
,
"s59"
,
"s60"
,
"s61"
,
"s62"
,
"s63"
,
"s55"
,
"s56"
,
"s57"
,
"s58"
,
"s59"
,
"s60"
,
"s61"
,
"s62"
,
"s63"
,
"s64"
,
"s65"
,
"s66"
,
"s67"
,
"s68"
,
"s69"
,
"s70"
,
"s71"
,
"s72"
,
"s64"
,
"s65"
,
"s66"
,
"s67"
,
"s68"
,
"s69"
,
"s70"
,
"s71"
,
"s72"
,
"s73"
,
"s74"
,
"s75"
,
"s76"
,
"s77"
,
"s78"
,
"s79"
,
"s80"
,
// s86 as tmp
"s73"
,
"s74"
,
"s75"
,
"s76"
,
"s77"
,
"s78"
,
"s79"
,
"s80"
,
// s86 as tmp
"v1"
,
"v2"
,
"v3"
,
"v4"
,
"v12"
,
"v13"
,
"v14"
,
"v15"
,
"v16"
,
"v17"
,
"v18"
,
"v19"
,
"v1"
,
"v2"
,
"v3"
,
"v4"
,
"v5"
,
"v12"
,
"v13"
,
"v21"
,
"v22"
,
"v23"
,
"v24"
,
"v25"
,
"v50"
,
"v51"
,
"v52"
,
"v53"
,
"v54"
,
"v55"
,
"v20"
,
"v21"
,
"v22"
,
"v23"
,
"v24"
,
"v25"
,
"v50"
,
"v51"
,
"v52"
,
"v53"
,
"v54"
,
"v55"
,
"v56"
,
"v57"
,
"v64"
,
"v56"
,
"v57"
,
"v64"
,
"v65"
,
"v66"
,
"v67"
,
"v68"
,
"v69"
,
"v70"
,
"v71"
,
"v72"
,
"v73"
,
"v65"
,
"v66"
,
"v67"
,
"v68"
,
"v69"
,
"v70"
,
"v71"
,
"v72"
,
"v73"
,
"v74"
,
"v75"
,
"v76"
,
"v77"
,
"v78"
,
"v79"
,
"v80"
,
"v81"
,
"v82"
,
"v74"
,
"v75"
,
"v76"
,
"v77"
,
"v78"
,
"v79"
,
"v80"
,
"v81"
,
"v82"
,
...
...
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_2.inc
View file @
2a66e080
...
@@ -27,6 +27,8 @@
...
@@ -27,6 +27,8 @@
# define _UK_ATOMIC_ADD_ "global_atomic_pk_add_f16"
# define _UK_ATOMIC_ADD_ "global_atomic_pk_add_f16"
#endif
#endif
" s_mov_b32 s36, -1
\n
"
" s_mov_b32 s37, -1
\n
"
" s_add_u32 s12, %[s_tile_os_b], s12
\n
"
" s_add_u32 s12, %[s_tile_os_b], s12
\n
"
" s_addc_u32 s13, 0, s13
\n
"
" s_addc_u32 s13, 0, s13
\n
"
" s_add_u32 s16, %[s_tile_os_dq], s16
\n
"
" s_add_u32 s16, %[s_tile_os_dq], s16
\n
"
...
@@ -478,52 +480,52 @@
...
@@ -478,52 +480,52 @@
" ds_read_b32 v78, v4 offset:43872
\n
"
" ds_read_b32 v78, v4 offset:43872
\n
"
" ds_read_b32 v79, v4 offset:48224
\n
"
" ds_read_b32 v79, v4 offset:48224
\n
"
" s_waitcnt lgkmcnt(0)
\n
"
" s_waitcnt lgkmcnt(0)
\n
"
" s_mov_b64 exec,
s[20:21
]
\n
"
" s_mov_b64 exec,
%[s_execflag_0
]
\n
"
" global_atomic_pk_add_bf16 %[v_os_o0], v64, [%[s_res_o0],%[s_res_o1]]
\n
"
" global_atomic_pk_add_bf16 %[v_os_o0], v64, [%[s_res_o0],%[s_res_o1]]
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec,
s[20:21
]
\n
"
" s_mov_b64 exec,
%[s_execflag_0
]
\n
"
" global_atomic_pk_add_bf16 %[v_os_o0], v65, [%[s_res_o0],%[s_res_o1]] inst_offset:256
\n
"
" global_atomic_pk_add_bf16 %[v_os_o0], v65, [%[s_res_o0],%[s_res_o1]] inst_offset:256
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec,
s[22:23
]
\n
"
" s_mov_b64 exec,
%[s_execflag_1
]
\n
"
" global_atomic_pk_add_bf16 %[v_os_o1], v66, [%[s_res_o0],%[s_res_o1]]
\n
"
" global_atomic_pk_add_bf16 %[v_os_o1], v66, [%[s_res_o0],%[s_res_o1]]
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec,
s[22:23
]
\n
"
" s_mov_b64 exec,
%[s_execflag_1
]
\n
"
" global_atomic_pk_add_bf16 %[v_os_o1], v67, [%[s_res_o0],%[s_res_o1]] inst_offset:256
\n
"
" global_atomic_pk_add_bf16 %[v_os_o1], v67, [%[s_res_o0],%[s_res_o1]] inst_offset:256
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec,
s[24:25
]
\n
"
" s_mov_b64 exec,
%[s_execflag_2
]
\n
"
" global_atomic_pk_add_bf16 %[v_os_o2], v68, [%[s_res_o0],%[s_res_o1]]
\n
"
" global_atomic_pk_add_bf16 %[v_os_o2], v68, [%[s_res_o0],%[s_res_o1]]
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec,
s[24:25
]
\n
"
" s_mov_b64 exec,
%[s_execflag_2
]
\n
"
" global_atomic_pk_add_bf16 %[v_os_o2], v69, [%[s_res_o0],%[s_res_o1]] inst_offset:256
\n
"
" global_atomic_pk_add_bf16 %[v_os_o2], v69, [%[s_res_o0],%[s_res_o1]] inst_offset:256
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec,
s[26:27
]
\n
"
" s_mov_b64 exec,
%[s_execflag_3
]
\n
"
" global_atomic_pk_add_bf16 %[v_os_o3], v70, [%[s_res_o0],%[s_res_o1]]
\n
"
" global_atomic_pk_add_bf16 %[v_os_o3], v70, [%[s_res_o0],%[s_res_o1]]
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec,
s[26:27
]
\n
"
" s_mov_b64 exec,
%[s_execflag_3
]
\n
"
" global_atomic_pk_add_bf16 %[v_os_o3], v71, [%[s_res_o0],%[s_res_o1]] inst_offset:256
\n
"
" global_atomic_pk_add_bf16 %[v_os_o3], v71, [%[s_res_o0],%[s_res_o1]] inst_offset:256
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec,
s[28:29
]
\n
"
" s_mov_b64 exec,
%[s_execflag_4
]
\n
"
" global_atomic_pk_add_bf16 %[v_os_o4], v72, [%[s_res_o0],%[s_res_o1]]
\n
"
" global_atomic_pk_add_bf16 %[v_os_o4], v72, [%[s_res_o0],%[s_res_o1]]
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec,
s[28:29
]
\n
"
" s_mov_b64 exec,
%[s_execflag_4
]
\n
"
" global_atomic_pk_add_bf16 %[v_os_o4], v73, [%[s_res_o0],%[s_res_o1]] inst_offset:256
\n
"
" global_atomic_pk_add_bf16 %[v_os_o4], v73, [%[s_res_o0],%[s_res_o1]] inst_offset:256
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec,
s[30:31
]
\n
"
" s_mov_b64 exec,
%[s_execflag_5
]
\n
"
" global_atomic_pk_add_bf16 %[v_os_o5], v74, [%[s_res_o0],%[s_res_o1]]
\n
"
" global_atomic_pk_add_bf16 %[v_os_o5], v74, [%[s_res_o0],%[s_res_o1]]
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec,
s[30:31
]
\n
"
" s_mov_b64 exec,
%[s_execflag_5
]
\n
"
" global_atomic_pk_add_bf16 %[v_os_o5], v75, [%[s_res_o0],%[s_res_o1]] inst_offset:256
\n
"
" global_atomic_pk_add_bf16 %[v_os_o5], v75, [%[s_res_o0],%[s_res_o1]] inst_offset:256
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec,
s[32:33
]
\n
"
" s_mov_b64 exec,
%[s_execflag_6
]
\n
"
" global_atomic_pk_add_bf16 %[v_os_o6], v76, [%[s_res_o0],%[s_res_o1]]
\n
"
" global_atomic_pk_add_bf16 %[v_os_o6], v76, [%[s_res_o0],%[s_res_o1]]
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec,
s[32:33
]
\n
"
" s_mov_b64 exec,
%[s_execflag_6
]
\n
"
" global_atomic_pk_add_bf16 %[v_os_o6], v77, [%[s_res_o0],%[s_res_o1]] inst_offset:256
\n
"
" global_atomic_pk_add_bf16 %[v_os_o6], v77, [%[s_res_o0],%[s_res_o1]] inst_offset:256
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec,
s[34:35
]
\n
"
" s_mov_b64 exec,
%[s_execflag_7
]
\n
"
" global_atomic_pk_add_bf16 %[v_os_o7], v78, [%[s_res_o0],%[s_res_o1]]
\n
"
" global_atomic_pk_add_bf16 %[v_os_o7], v78, [%[s_res_o0],%[s_res_o1]]
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec,
s[34:35
]
\n
"
" s_mov_b64 exec,
%[s_execflag_7
]
\n
"
" global_atomic_pk_add_bf16 %[v_os_o7], v79, [%[s_res_o0],%[s_res_o1]] inst_offset:256
\n
"
" global_atomic_pk_add_bf16 %[v_os_o7], v79, [%[s_res_o0],%[s_res_o1]] inst_offset:256
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_add_u32 %[s_res_o0], s59, %[s_res_o0]
\n
"
" s_add_u32 %[s_res_o0], s59, %[s_res_o0]
\n
"
...
@@ -975,52 +977,52 @@
...
@@ -975,52 +977,52 @@
" ds_read_b32 v78, v4 offset:43872
\n
"
" ds_read_b32 v78, v4 offset:43872
\n
"
" ds_read_b32 v79, v4 offset:48224
\n
"
" ds_read_b32 v79, v4 offset:48224
\n
"
" s_waitcnt lgkmcnt(0)
\n
"
" s_waitcnt lgkmcnt(0)
\n
"
" s_mov_b64 exec,
s[20:21
]
\n
"
" s_mov_b64 exec,
%[s_execflag_0
]
\n
"
" global_atomic_pk_add_bf16 %[v_os_o0], v64, [%[s_res_o0],%[s_res_o1]]
\n
"
" global_atomic_pk_add_bf16 %[v_os_o0], v64, [%[s_res_o0],%[s_res_o1]]
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec,
s[20:21
]
\n
"
" s_mov_b64 exec,
%[s_execflag_0
]
\n
"
" global_atomic_pk_add_bf16 %[v_os_o0], v65, [%[s_res_o0],%[s_res_o1]] inst_offset:256
\n
"
" global_atomic_pk_add_bf16 %[v_os_o0], v65, [%[s_res_o0],%[s_res_o1]] inst_offset:256
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec,
s[22:23
]
\n
"
" s_mov_b64 exec,
%[s_execflag_1
]
\n
"
" global_atomic_pk_add_bf16 %[v_os_o1], v66, [%[s_res_o0],%[s_res_o1]]
\n
"
" global_atomic_pk_add_bf16 %[v_os_o1], v66, [%[s_res_o0],%[s_res_o1]]
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec,
s[22:23
]
\n
"
" s_mov_b64 exec,
%[s_execflag_1
]
\n
"
" global_atomic_pk_add_bf16 %[v_os_o1], v67, [%[s_res_o0],%[s_res_o1]] inst_offset:256
\n
"
" global_atomic_pk_add_bf16 %[v_os_o1], v67, [%[s_res_o0],%[s_res_o1]] inst_offset:256
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec,
s[24:25
]
\n
"
" s_mov_b64 exec,
%[s_execflag_2
]
\n
"
" global_atomic_pk_add_bf16 %[v_os_o2], v68, [%[s_res_o0],%[s_res_o1]]
\n
"
" global_atomic_pk_add_bf16 %[v_os_o2], v68, [%[s_res_o0],%[s_res_o1]]
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec,
s[24:25
]
\n
"
" s_mov_b64 exec,
%[s_execflag_2
]
\n
"
" global_atomic_pk_add_bf16 %[v_os_o2], v69, [%[s_res_o0],%[s_res_o1]] inst_offset:256
\n
"
" global_atomic_pk_add_bf16 %[v_os_o2], v69, [%[s_res_o0],%[s_res_o1]] inst_offset:256
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec,
s[26:27
]
\n
"
" s_mov_b64 exec,
%[s_execflag_3
]
\n
"
" global_atomic_pk_add_bf16 %[v_os_o3], v70, [%[s_res_o0],%[s_res_o1]]
\n
"
" global_atomic_pk_add_bf16 %[v_os_o3], v70, [%[s_res_o0],%[s_res_o1]]
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec,
s[26:27
]
\n
"
" s_mov_b64 exec,
%[s_execflag_3
]
\n
"
" global_atomic_pk_add_bf16 %[v_os_o3], v71, [%[s_res_o0],%[s_res_o1]] inst_offset:256
\n
"
" global_atomic_pk_add_bf16 %[v_os_o3], v71, [%[s_res_o0],%[s_res_o1]] inst_offset:256
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec,
s[28:29
]
\n
"
" s_mov_b64 exec,
%[s_execflag_4
]
\n
"
" global_atomic_pk_add_bf16 %[v_os_o4], v72, [%[s_res_o0],%[s_res_o1]]
\n
"
" global_atomic_pk_add_bf16 %[v_os_o4], v72, [%[s_res_o0],%[s_res_o1]]
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec,
s[28:29
]
\n
"
" s_mov_b64 exec,
%[s_execflag_4
]
\n
"
" global_atomic_pk_add_bf16 %[v_os_o4], v73, [%[s_res_o0],%[s_res_o1]] inst_offset:256
\n
"
" global_atomic_pk_add_bf16 %[v_os_o4], v73, [%[s_res_o0],%[s_res_o1]] inst_offset:256
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec,
s[30:31
]
\n
"
" s_mov_b64 exec,
%[s_execflag_5
]
\n
"
" global_atomic_pk_add_bf16 %[v_os_o5], v74, [%[s_res_o0],%[s_res_o1]]
\n
"
" global_atomic_pk_add_bf16 %[v_os_o5], v74, [%[s_res_o0],%[s_res_o1]]
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec,
s[30:31
]
\n
"
" s_mov_b64 exec,
%[s_execflag_5
]
\n
"
" global_atomic_pk_add_bf16 %[v_os_o5], v75, [%[s_res_o0],%[s_res_o1]] inst_offset:256
\n
"
" global_atomic_pk_add_bf16 %[v_os_o5], v75, [%[s_res_o0],%[s_res_o1]] inst_offset:256
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec,
s[32:33
]
\n
"
" s_mov_b64 exec,
%[s_execflag_6
]
\n
"
" global_atomic_pk_add_bf16 %[v_os_o6], v76, [%[s_res_o0],%[s_res_o1]]
\n
"
" global_atomic_pk_add_bf16 %[v_os_o6], v76, [%[s_res_o0],%[s_res_o1]]
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec,
s[32:33
]
\n
"
" s_mov_b64 exec,
%[s_execflag_6
]
\n
"
" global_atomic_pk_add_bf16 %[v_os_o6], v77, [%[s_res_o0],%[s_res_o1]] inst_offset:256
\n
"
" global_atomic_pk_add_bf16 %[v_os_o6], v77, [%[s_res_o0],%[s_res_o1]] inst_offset:256
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec,
s[34:35
]
\n
"
" s_mov_b64 exec,
%[s_execflag_7
]
\n
"
" global_atomic_pk_add_bf16 %[v_os_o7], v78, [%[s_res_o0],%[s_res_o1]]
\n
"
" global_atomic_pk_add_bf16 %[v_os_o7], v78, [%[s_res_o0],%[s_res_o1]]
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec,
s[34:35
]
\n
"
" s_mov_b64 exec,
%[s_execflag_7
]
\n
"
" global_atomic_pk_add_bf16 %[v_os_o7], v79, [%[s_res_o0],%[s_res_o1]] inst_offset:256
\n
"
" global_atomic_pk_add_bf16 %[v_os_o7], v79, [%[s_res_o0],%[s_res_o1]] inst_offset:256
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_add_u32 %[s_res_o0], s59, %[s_res_o0]
\n
"
" s_add_u32 %[s_res_o0], s59, %[s_res_o0]
\n
"
...
@@ -1038,5 +1040,3 @@
...
@@ -1038,5 +1040,3 @@
include/ck_tile/ops/flatmm/block/uk/flatmm_uk_gfx9_32x512x256_1x1x1_16x16x32_int8.inc
View file @
2a66e080
...
@@ -29,6 +29,8 @@
...
@@ -29,6 +29,8 @@
"s_mov_b32 s27, %[s_res_b3]
\n
"
"s_mov_b32 s27, %[s_res_b3]
\n
"
"s_mov_b32 s16, %[s_res_dq0]
\n
"
"s_mov_b32 s16, %[s_res_dq0]
\n
"
"s_mov_b32 s17, %[s_res_dq1]
\n
"
"s_mov_b32 s17, %[s_res_dq1]
\n
"
"s_mov_b32 s18, %[s_res_dq2]
\n
"
"s_mov_b32 s19, %[s_res_dq3]
\n
"
"s_mov_b32 s12, %[s_res_d0]
\n
"
"s_mov_b32 s12, %[s_res_d0]
\n
"
"s_mov_b32 s13, %[s_res_d1]
\n
"
"s_mov_b32 s13, %[s_res_d1]
\n
"
"s_mov_b32 s14, %[s_res_d2]
\n
"
"s_mov_b32 s14, %[s_res_d2]
\n
"
...
@@ -584,3 +586,4 @@ _DEQUAN_CVT_("%[c60]","%[c61]","%[c62]","%[c63]","%[a_scale1]"," %[gq_scale1]","
...
@@ -584,3 +586,4 @@ _DEQUAN_CVT_("%[c60]","%[c61]","%[c62]","%[c63]","%[a_scale1]"," %[gq_scale1]","
#undef _DEQUAN_CVT_
#undef _DEQUAN_CVT_
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk_int8.hpp
View file @
2a66e080
...
@@ -276,10 +276,10 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
...
@@ -276,10 +276,10 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
number
<
row_ids_a
.
size
()
>
{});
number
<
row_ids_a
.
size
()
>
{});
auto
a_coords
=
generate_tuple
(
auto
a_coords
=
generate_tuple
(
[
&
](
auto
i
)
{
[
&
](
auto
i
)
{
return
(
(
row_ids_a
[
i
])
&
0xffffff
)
*
kargs
.
stride_token
+
return
(
token_id
[
i
]
)
*
kargs
.
stride_token
+
threadIdx
.
x
%
(
BlockShape
::
Block_K0
/
kAlignmentA
)
*
kAlignmentA
;
threadIdx
.
x
%
(
BlockShape
::
Block_K0
/
kAlignmentA
)
*
kAlignmentA
;
},
},
number
<
row_ids_a
.
size
()
>
{});
number
<
token_id
.
size
()
>
{});
auto
a_res
=
auto
a_res
=
make_wave_buffer_resource
(
reinterpret_cast
<
const
ADataType
*>
(
kargs
.
a_ptr
),
make_wave_buffer_resource
(
reinterpret_cast
<
const
ADataType
*>
(
kargs
.
a_ptr
),
kargs
.
num_tokens
*
kargs
.
stride_token
*
sizeof
(
ADataType
));
kargs
.
num_tokens
*
kargs
.
stride_token
*
sizeof
(
ADataType
));
...
@@ -407,7 +407,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
...
@@ -407,7 +407,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
gqsmq_coords
,
(
reinterpret_cast
<
const
YSmoothScaleDataType
*>
(
kargs
.
y_smooth_scale_ptr
)
+
static_cast
<
long_index_t
>
(
expert_id
)
*
shared_intermediate_size_0
));
gqsmq_coords
,
(
reinterpret_cast
<
const
YSmoothScaleDataType
*>
(
kargs
.
y_smooth_scale_ptr
)
+
static_cast
<
long_index_t
>
(
expert_id
)
*
shared_intermediate_size_0
));
auto
uk_0
=
Policy
::
template
GetUK_0
<
Problem
>();
auto
uk_0
=
Policy
::
template
GetUK_0
<
Problem
>();
// auto acc_0= uk_0(
// auto acc_0= uk_0(
uk_0
(
a_scale
,
uk_0
(
a_scale
,
gq_scale
,
gq_scale
,
d_res
,
d_res
,
dq_res
,
dq_res
,
...
@@ -420,7 +420,12 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
...
@@ -420,7 +420,12 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
BlockShape
::
Block_K0
,
// tile offset for B matrix each unroll
BlockShape
::
Block_K0
,
// tile offset for B matrix each unroll
BlockShape
::
Block_Kr0
*
BlockShape
::
Block_Kr0
*
BlockShape
::
Block_W0
);
// tile offset for B matrix each unroll
BlockShape
::
Block_W0
);
// tile offset for B matrix each unroll
if
(
hipBlockIdx_x
==
0
&&
hipBlockIdx_y
==
0
&&
hipBlockIdx_z
==
0
&&
hipThreadIdx_x
==
5
)
{
printf
(
"
\n
gemm0 done
\n
"
);
}
// sweep_tile(
// sweep_tile(
// acc_0,
// acc_0,
// [&](auto idx0, auto idx1) {
// [&](auto idx0, auto idx1) {
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment