Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
5d00b37e
Commit
5d00b37e
authored
Jan 07, 2025
by
shengnxu
Browse files
fix loop cnt and half d buffer size
parent
2a66e080
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
17 additions
and
18 deletions
+17
-18
include/ck_tile/ops/flatmm/block/flatmm_sn_32x256x512_1x4x1_16x16x64_int8.hpp
...flatmm/block/flatmm_sn_32x256x512_1x4x1_16x16x64_int8.hpp
+1
-1
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_2.inc
...uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_2.inc
+6
-9
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk_int8.hpp
...ed_moe/pipeline/fused_moegemm_pipeline_flatmm_uk_int8.hpp
+10
-8
No files found.
include/ck_tile/ops/flatmm/block/flatmm_sn_32x256x512_1x4x1_16x16x64_int8.hpp
View file @
5d00b37e
...
@@ -72,7 +72,7 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_Base
...
@@ -72,7 +72,7 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_Base
struct
FlatmmSn_32x256x512_1x4x1_16x16x64_int8
:
public
FlatmmSn_32x256x512_1x4x1_16x16x64_Base
struct
FlatmmSn_32x256x512_1x4x1_16x16x64_int8
:
public
FlatmmSn_32x256x512_1x4x1_16x16x64_Base
{
{
using
BDataType
=
int8_t
;
using
BDataType
=
int8_t
;
using
ODataType
=
int8
_t
;
using
ODataType
=
bf16
_t
;
using
DScaleDataType
=
float_t
;
using
DScaleDataType
=
float_t
;
// TODO: need paired with tile_window_linear!
// TODO: need paired with tile_window_linear!
...
...
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_2.inc
View file @
5d00b37e
...
@@ -205,7 +205,7 @@
...
@@ -205,7 +205,7 @@
" v_mfma_i32_16x16x32_i8 v[220:223], acc[124:125], v[188:189], v[220:223]
\n
"
" v_mfma_i32_16x16x32_i8 v[220:223], acc[124:125], v[188:189], v[220:223]
\n
"
" v_mfma_i32_16x16x32_i8 v[220:223], acc[126:127], v[190:191], v[220:223]
\n
"
" v_mfma_i32_16x16x32_i8 v[220:223], acc[126:127], v[190:191], v[220:223]
\n
"
" s_add_u32 s60, 0x00000200, s80
\n
"
" s_add_u32 s60, 0x00000200, s80
\n
"
" s_cmp_lt_u32 s60,
s81
\n
"
" s_cmp_lt_u32 s60,
%[s_loop_cnt]
\n
"
" s_cselect_b32 %[s_tile_os_b], %[s_tile_os_b], 0
\n
"
" s_cselect_b32 %[s_tile_os_b], %[s_tile_os_b], 0
\n
"
" s_cselect_b32 %[s_tile_os_b_half], %[s_tile_os_b_half], 0
\n
"
" s_cselect_b32 %[s_tile_os_b_half], %[s_tile_os_b_half], 0
\n
"
" s_cselect_b32 %[s_tile_os_dq], %[s_tile_os_dq], 0
\n
"
" s_cselect_b32 %[s_tile_os_dq], %[s_tile_os_dq], 0
\n
"
...
@@ -528,10 +528,10 @@
...
@@ -528,10 +528,10 @@
" s_mov_b64 exec, %[s_execflag_7]
\n
"
" s_mov_b64 exec, %[s_execflag_7]
\n
"
" global_atomic_pk_add_bf16 %[v_os_o7], v79, [%[s_res_o0],%[s_res_o1]] inst_offset:256
\n
"
" global_atomic_pk_add_bf16 %[v_os_o7], v79, [%[s_res_o0],%[s_res_o1]] inst_offset:256
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_add_u32 %[s_res_o0],
s59
, %[s_res_o0]
\n
"
" s_add_u32 %[s_res_o0],
%[s_tile_os_o]
, %[s_res_o0]
\n
"
" s_addc_u32 %[s_res_o1], 0, %[s_res_o1]
\n
"
" s_addc_u32 %[s_res_o1], 0, %[s_res_o1]
\n
"
" s_addk_i32 s80, 0x0100
\n
"
" s_addk_i32 s80, 0x0100
\n
"
" s_cmp_lt_i32 s80,
s81
\n
"
" s_cmp_lt_i32 s80,
%[s_loop_cnt]
\n
"
" s_cbranch_scc0 label_end_gemm2
\n
"
" s_cbranch_scc0 label_end_gemm2
\n
"
" s_waitcnt vmcnt(41)
\n
"
" s_waitcnt vmcnt(41)
\n
"
" s_barrier
\n
"
" s_barrier
\n
"
...
@@ -702,7 +702,7 @@
...
@@ -702,7 +702,7 @@
" v_mfma_i32_16x16x32_i8 v[252:255], acc[252:253], v[188:189], v[252:255]
\n
"
" v_mfma_i32_16x16x32_i8 v[252:255], acc[252:253], v[188:189], v[252:255]
\n
"
" v_mfma_i32_16x16x32_i8 v[252:255], acc[254:255], v[190:191], v[252:255]
\n
"
" v_mfma_i32_16x16x32_i8 v[252:255], acc[254:255], v[190:191], v[252:255]
\n
"
" s_add_u32 s60, 0x00000200, s80
\n
"
" s_add_u32 s60, 0x00000200, s80
\n
"
" s_cmp_lt_u32 s60,
s81
\n
"
" s_cmp_lt_u32 s60,
%[s_loop_cnt]
\n
"
" s_cselect_b32 %[s_tile_os_b], %[s_tile_os_b], 0
\n
"
" s_cselect_b32 %[s_tile_os_b], %[s_tile_os_b], 0
\n
"
" s_cselect_b32 %[s_tile_os_b_half], %[s_tile_os_b_half], 0
\n
"
" s_cselect_b32 %[s_tile_os_b_half], %[s_tile_os_b_half], 0
\n
"
" s_cselect_b32 %[s_tile_os_dq], %[s_tile_os_dq], 0
\n
"
" s_cselect_b32 %[s_tile_os_dq], %[s_tile_os_dq], 0
\n
"
...
@@ -1025,10 +1025,10 @@
...
@@ -1025,10 +1025,10 @@
" s_mov_b64 exec, %[s_execflag_7]
\n
"
" s_mov_b64 exec, %[s_execflag_7]
\n
"
" global_atomic_pk_add_bf16 %[v_os_o7], v79, [%[s_res_o0],%[s_res_o1]] inst_offset:256
\n
"
" global_atomic_pk_add_bf16 %[v_os_o7], v79, [%[s_res_o0],%[s_res_o1]] inst_offset:256
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_mov_b64 exec, s[36:37]
\n
"
" s_add_u32 %[s_res_o0],
s59
, %[s_res_o0]
\n
"
" s_add_u32 %[s_res_o0],
%[s_tile_os_o]
, %[s_res_o0]
\n
"
" s_addc_u32 %[s_res_o1], 0, %[s_res_o1]
\n
"
" s_addc_u32 %[s_res_o1], 0, %[s_res_o1]
\n
"
" s_addk_i32 s80, 0x0100
\n
"
" s_addk_i32 s80, 0x0100
\n
"
" s_cmp_lt_i32 s80,
s81
\n
"
" s_cmp_lt_i32 s80,
%[s_loop_cnt]
\n
"
" s_cbranch_scc0 label_end_gemm2
\n
"
" s_cbranch_scc0 label_end_gemm2
\n
"
" s_branch label_startgemm2
\n
"
" s_branch label_startgemm2
\n
"
" label_end_gemm2:
\n
"
" label_end_gemm2:
\n
"
...
@@ -1037,6 +1037,3 @@
...
@@ -1037,6 +1037,3 @@
#undef _UK_MFMA_
#undef _UK_MFMA_
#undef _UK_PK_CVT_
#undef _UK_PK_CVT_
#undef _UK_ATOMIC_ADD_
#undef _UK_ATOMIC_ADD_
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk_int8.hpp
View file @
5d00b37e
...
@@ -372,8 +372,8 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
...
@@ -372,8 +372,8 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
Nl_
;
// Kr0_ * Kr1_ * W_;
Nl_
;
// Kr0_ * Kr1_ * W_;
return
generate_tuple
(
return
generate_tuple
(
[
&
](
auto
i
)
{
[
&
](
auto
i
)
{
constexpr
auto
i_nr_
=
number
<
i
%
Nr_
>
{};
//
constexpr auto i_nr_ = number<i % Nr_>{};
return
i
_nr_
*
shared_intermediate_size_1
*
Nw_
*
Nl_
+
return
i
*
shared_intermediate_size_1
*
Nw_
*
Nl_
+
base_os_
;
base_os_
;
},
},
number
<
num_offsets_
>
{});
number
<
num_offsets_
>
{});
...
@@ -382,7 +382,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
...
@@ -382,7 +382,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
auto
o_coords
=
generate_tuple
(
auto
o_coords
=
generate_tuple
(
[
&
](
auto
i
)
{
[
&
](
auto
i
)
{
return
token_id
[
i
]
*
kargs
.
stride_token
+
return
token_id
[
i
]
*
kargs
.
stride_token
+
threadIdx
.
x
%
(
BlockShape
::
Block_N1
/
kAlignmentO
)
*
kAlignmentO
;
threadIdx
.
x
%
(
BlockShape
::
Block_N1
/
2
/
kAlignmentO
)
*
kAlignmentO
;
},
},
number
<
row_ids_a
.
size
()
>
{});
number
<
row_ids_a
.
size
()
>
{});
...
@@ -420,11 +420,13 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
...
@@ -420,11 +420,13 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
BlockShape
::
Block_K0
,
// tile offset for B matrix each unroll
BlockShape
::
Block_K0
,
// tile offset for B matrix each unroll
BlockShape
::
Block_Kr0
*
BlockShape
::
Block_Kr0
*
BlockShape
::
Block_W0
);
// tile offset for B matrix each unroll
BlockShape
::
Block_W0
);
// tile offset for B matrix each unroll
if
(
hipBlockIdx_x
==
0
&&
hipBlockIdx_y
==
0
&&
hipBlockIdx_z
==
0
&&
if
(
hipBlockIdx_x
==
1
&&
hipBlockIdx_y
==
1
&&
hipBlockIdx_z
==
0
&&
hipThreadIdx_x
==
5
)
hipThreadIdx_x
==
64
)
{
{
printf
(
"
\n
gemm0 done
\n
"
);
printf
(
"
\n
gemm0 done
\n
"
);
// printf("\n wg 1 1, wave 1, row_coords_a 0: %x 1: %x, 2: %x, 3:%x, 5: %x 6: %x, 7: %x, 8:%x,, \n", row_coords_a[number<0>{}],row_coords_a[number<1>{}],row_coords_a[number<2>{}],row_coords_a[number<3>{}], row_coords_a[number<4>{}],row_coords_a[number<5>{}],row_coords_a[number<6>{}],row_coords_a[number<7>{}]);
// printf("\n -------------- -row_ids_a 0: %x 1: %x, 2: %x, 3:%x, 5: %x 6: %x, 7: %x, 8:%x,, \n", row_ids_a[number<0>{}],row_ids_a[number<1>{}],row_ids_a[number<2>{}],row_ids_a[number<3>{}], row_ids_a[number<4>{}],row_ids_a[number<5>{}],row_ids_a[number<6>{}],row_ids_a[number<7>{}]);
printf
(
"
\n
-------------- - token_id 0: %x 1: %x, 2: %x, 3:%x, 5: %x 6: %x, 7: %x, 8:%x,,
\n
"
,
token_id
[
number
<
0
>
{}],
token_id
[
number
<
1
>
{}],
token_id
[
number
<
2
>
{}],
token_id
[
number
<
3
>
{}],
token_id
[
number
<
4
>
{}],
token_id
[
number
<
5
>
{}],
token_id
[
number
<
6
>
{}],
token_id
[
number
<
7
>
{}]);
}
}
// sweep_tile(
// sweep_tile(
// acc_0,
// acc_0,
...
@@ -457,8 +459,8 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
...
@@ -457,8 +459,8 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
w_scale
,
w_scale
,
smq_scale
,
smq_scale
,
BlockShape
::
Block_N1
,
BlockShape
::
Block_N1
,
shared_intermediate_size_1
*
BlockShape
::
Block_N1
-
kr_1
*
BlockShape
::
Block_W1
,
// along N
shared_intermediate_size_1
*
BlockShape
::
Block_N1
-
256
*
16
,
// along N
kr_1
*
BlockShape
::
Block_W1
,
256
*
16
,
BlockShape
::
Block_N1
);
// along N
BlockShape
::
Block_N1
);
// along N
}
}
};
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment