gaoqiong / composable_kernel_ROCM
"vscode:/vscode.git/clone" did not exist on "57a79819b052ed627ea2e8497224fdf7475bb027"
Commit 7bc31426, authored Jan 09, 2025 by carlushuang (parent c35bb816)

fix mock token id
Showing 3 changed files with 28 additions and 13 deletions (+28 −13):
include/ck_tile/host/reference/reference_fused_moe.hpp                      +22 −13
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp                 +3  −0
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp   +3  −0
include/ck_tile/host/reference/reference_fused_moe.hpp

@@ -85,6 +85,24 @@ void reference_fused_moe(
     ck_tile::index_t intermediate_size_0 = intermediate_size;
     ck_tile::index_t intermediate_size_1 = intermediate_size / (gate_only ? 1 : 2);
+    ck_tile::HostTensor<AccDataType> out_topk_tokens({tokens, topk, hidden_size});
+    int max_num_tokens_padded = topk * tokens + experts * block_m - topk;
+    // assert();
+    auto f = [&](auto i_flatten) {
+        ck_tile::index_t i_tile = i_flatten / block_m;
+        if(i_tile >= num_sorted_tiles)
+            return;
+        ck_tile::index_t i_expert = sorted_expert_ids_host.mData[i_tile];
+#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
+        ck_tile::index_t i_token = sorted_token_ids_host.mData[i_flatten];
+        ck_tile::index_t i_topk  = i_token >> 24;
+        i_token &= 0xffffff;
+        if(i_token >= tokens)
+            return;
+        (void)token_ids_host;
+#else
         // TODO: better remove this in the future, or modify the token_id value
         auto get_topk_id = [&](ck_tile::index_t token_id_, ck_tile::index_t expert_id_) {
             for(ck_tile::index_t i_ = 0; i_ < topk; i_++)
...
@@ -95,20 +113,11 @@ void reference_fused_moe(
                 throw std::runtime_error("not correct token/expert pair\n");
             return -1; // TODO: not correct!!
         };
-    ck_tile::HostTensor<AccDataType> out_topk_tokens({tokens, topk, hidden_size});
-    int max_num_tokens_padded = topk * tokens + experts * block_m - topk;
-    // assert();
-    auto f = [&](auto i_flatten) {
-        ck_tile::index_t i_tile = i_flatten / block_m;
-        if(i_tile >= num_sorted_tiles)
-            return;
-        ck_tile::index_t i_expert = sorted_expert_ids_host.mData[i_tile];
         ck_tile::index_t i_token = sorted_token_ids_host.mData[i_flatten];
         if(i_token >= tokens)
             return;
         ck_tile::index_t i_topk = get_topk_id(i_token, i_expert); // TODO: ugly
+#endif
         auto weight = sorted_weight_host.mData[i_flatten];
         ck_tile::HostTensor<AccDataType> acc_0({1, intermediate_size_0});
...
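The reference path above assumes a packed "mock" token id: the topk slot sits in the top 8 bits and the token index in the low 24 bits, which is why it decodes with `>> 24` and `& 0xffffff`; otherwise it falls back to the get_topk_id search. A minimal standalone sketch of that assumed layout (the pack/unpack helpers below are illustrative only, not part of ck_tile):

#include <cassert>
#include <cstdint>

// Assumed layout when CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID is enabled:
// bits [31:24] = topk slot, bits [23:0] = token index.
inline uint32_t pack_mock_token_id(uint32_t token, uint32_t topk_slot)
{
    assert(token <= 0xffffffu && topk_slot <= 0xffu);
    return (topk_slot << 24u) | token;
}

inline void unpack_mock_token_id(uint32_t packed, uint32_t& token, uint32_t& topk_slot)
{
    topk_slot = packed >> 24u;      // same decode as "i_topk = i_token >> 24"
    token     = packed & 0xffffffu; // same decode as "i_token &= 0xffffff"
}

Packing the slot into otherwise-unused high bits keeps the sorted token ids in a single integer buffer while still letting the reference recover the (token, topk) pair without the linear get_topk_id search used in the #else branch.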
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp

@@ -299,6 +299,9 @@ struct FusedMoeGemmKernel
         index_t token_id =
             reinterpret_cast<const index_t*>(kargs.sorted_token_ids_ptr)[sorted_token_id];
+#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
+        token_id &= 0xffffff;
+#endif
         auto topk_weight = reinterpret_cast<const TopkWeightDataType*>(
             kargs.sorted_weight_ptr)[sorted_token_id];
...
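The kernel applies the same mask right after loading the sorted token id, so only the 24-bit token index reaches the rest of the kernel. One consequence of the 24-bit field is that token indices above 0xffffff cannot be represented. The snippet below is a small self-contained check of that arithmetic (the literal values are made up for illustration; nothing here is ck_tile API):

#include <cstdint>
#include <cstdio>

int main()
{
    // Hypothetical packed id: topk slot 3, token index 1234.
    const uint32_t packed = (3u << 24u) | 1234u;
    const uint32_t token  = packed & 0xffffffu; // what the kernel keeps after masking
    const uint32_t slot   = packed >> 24u;      // what the mask discards
    // Any token index >= 2^24 would spill into the slot field and be corrupted.
    std::printf("token=%u slot=%u max_token_index=%u\n", token, slot, (1u << 24u) - 1u);
    return 0;
}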
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp

@@ -125,6 +125,9 @@ struct FusedMoeGemmPipeline_FlatmmUk
        array<index_t, n_size> row_ids;
        static_for<0, n_size, 1>{}([&](auto i) {
            row_ids.at(i) = sorted_token_ids_ptr[coords[i]]; // base_coord + i * MLans;
+#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
+            row_ids.at(i) &= 0xffffff;
+#endif
        });
        return row_ids;
...
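For readers unfamiliar with ck_tile's static_for, the pipeline hunk above is a fully unrolled per-lane version of a plain loop over n_size row ids, with the same mask applied per element. Roughly equivalent run-time logic, written with standard types as stand-ins for ck_tile's array and index_t (a sketch, not the actual implementation):

#include <array>
#include <cstdint>

template <int n_size>
std::array<int32_t, n_size> load_row_ids(const int32_t* sorted_token_ids_ptr,
                                         const std::array<int32_t, n_size>& coords)
{
    std::array<int32_t, n_size> row_ids{};
    for(int i = 0; i < n_size; ++i)
    {
        row_ids[i] = sorted_token_ids_ptr[coords[i]];
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
        // Strip the packed topk slot; keep only the 24-bit token index.
        row_ids[i] &= 0xffffff;
#endif
    }
    return row_ids;
}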