Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
9d5f21bf
Commit
9d5f21bf
authored
Jan 15, 2025
by
coderfeli
Browse files
fix tokens==1
parent
4be253ee
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing 1 changed file with 8 additions and 4 deletions
+8
-4
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp
...s/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp
+8
-4
No files found.
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp
View file @
9d5f21bf
...
@@ -186,6 +186,8 @@ struct FusedMoeGemmPipeline_FlatmmUk
...
@@ -186,6 +186,8 @@ struct FusedMoeGemmPipeline_FlatmmUk
const
IndexDataType
expert_id
=
__builtin_amdgcn_readfirstlane
(
const
IndexDataType
expert_id
=
__builtin_amdgcn_readfirstlane
(
reinterpret_cast
<
const
IndexDataType
*>
(
kargs
.
sorted_expert_ids_ptr
)[
sorted_tile_id
]);
reinterpret_cast
<
const
IndexDataType
*>
(
kargs
.
sorted_expert_ids_ptr
)[
sorted_tile_id
]);
const
IndexDataType
expert_first_token
=
__builtin_amdgcn_readfirstlane
(
reinterpret_cast
<
const
IndexDataType
*>
(
kargs
.
sorted_token_ids_ptr
)[
sorted_tile_id
*
32
]);
index_t
expert_stride_0
=
shared_intermediate_size_0
*
kargs
.
hidden_size
;
index_t
expert_stride_0
=
shared_intermediate_size_0
*
kargs
.
hidden_size
;
index_t
expert_stride_1
=
shared_intermediate_size_1
*
kargs
.
hidden_size
;
index_t
expert_stride_1
=
shared_intermediate_size_1
*
kargs
.
hidden_size
;
...
@@ -207,10 +209,12 @@ struct FusedMoeGemmPipeline_FlatmmUk
...
@@ -207,10 +209,12 @@ struct FusedMoeGemmPipeline_FlatmmUk
threadIdx
.
x
%
(
BlockShape
::
Block_K0
/
kAlignmentA
)
*
kAlignmentA
;
threadIdx
.
x
%
(
BlockShape
::
Block_K0
/
kAlignmentA
)
*
kAlignmentA
;
},
},
number
<
row_ids_a
.
size
()
>
{});
number
<
row_ids_a
.
size
()
>
{});
// if(threadIdx.x==0)
if
(
expert_first_token
>=
kargs
.
num_tokens
)
// printf("row id %d\n", row_ids_a[0]);
if
(
row_ids_a
.
at
(
0
)
>=
kargs
.
num_tokens
)
return
;
return
;
// printf("tid %d %d\n", blockIdx.x, threadIdx.x);
// for (int i = 0; i < row_ids_a.size(); i++) {
// printf("%d bid %d tid %d rowid %d\n", i, blockIdx.x, threadIdx.x, row_ids_a[i]);
// }
auto
a_res
=
auto
a_res
=
make_wave_buffer_resource
(
reinterpret_cast
<
const
ADataType
*>
(
kargs
.
a_ptr
),
make_wave_buffer_resource
(
reinterpret_cast
<
const
ADataType
*>
(
kargs
.
a_ptr
),
kargs
.
num_tokens
*
kargs
.
stride_token
*
sizeof
(
ADataType
));
kargs
.
num_tokens
*
kargs
.
stride_token
*
sizeof
(
ADataType
));
...
@@ -343,7 +347,7 @@ struct FusedMoeGemmPipeline_FlatmmUk
...
@@ -343,7 +347,7 @@ struct FusedMoeGemmPipeline_FlatmmUk
// for(auto i = 0; i < 16; i++)
// for(auto i = 0; i < 16; i++)
// {
// {
// if(threadIdx.x==0) {
// if(threadIdx.x==0) {
// printf("%d, %.1f, %.1f, %.1f, %.1f\n",i, acc_0_full.get_thread_buffer()[4 * (i) + 0], acc_0_full.get_thread_buffer()[4 * (i) + 1], acc_0_full.get_thread_buffer()[4 * (i) + 2], acc_0_full.get_thread_buffer()[4 * (i) + 3]);
// printf("i %d, tid %d, %.1f, %.1f, %.1f, %.1f\n",i, threadIdx.x, acc_0_full.get_thread_buffer()[4 * (i) + 0], acc_0_full.get_thread_buffer()[4 * (i) + 1], acc_0_full.get_thread_buffer()[4 * (i) + 2], acc_0_full.get_thread_buffer()[4 * (i) + 3]);
// }
// }
// }
// }
// auto acc_0 = IsGateOnly ? acc_0_full : Policy::template GetUK_0<Problem>().MakeCBlockTileGUMerge();
// auto acc_0 = IsGateOnly ? acc_0_full : Policy::template GetUK_0<Problem>().MakeCBlockTileGUMerge();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment