Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
88412f9e
Commit
88412f9e
authored
Feb 17, 2025
by
coderfeli
Browse files
impl sorting count eid
parent
4b91d1ce
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
9 additions
and
5 deletions
+9
-5
example/ck_tile/13_moe_sorting/moe_sorting.cpp
example/ck_tile/13_moe_sorting/moe_sorting.cpp
+3
-2
include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp
include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp
+6
-3
No files found.
example/ck_tile/13_moe_sorting/moe_sorting.cpp
View file @
88412f9e
...
...
@@ -125,7 +125,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
ck_tile
::
HostTensor
<
IndexType
>
sorted_ids_host
({
max_output_ids
},
{
1
});
ck_tile
::
HostTensor
<
WeightType
>
sorted_weights_host
({
max_output_ids
},
{
1
});
ck_tile
::
HostTensor
<
IndexType
>
sorted_expert_ids_host
({
max_output_ids
/
unit_size
},
{
1
});
ck_tile
::
HostTensor
<
IndexType
>
sorted_id_cnt_host
({
1
},
{
1
});
ck_tile
::
HostTensor
<
IndexType
>
sorted_id_cnt_host
({
1
+
max_output_ids
/
unit_size
},
{
1
});
ck_tile
::
HostTensor
<
float
>
moe_buf_host
({
moe_buf_size
});
ck_tile
::
FillUniformDistribution
<
WeightType
>
{
-
.5
f
,
.5
f
}(
weights_host
);
...
...
@@ -205,7 +205,8 @@ bool test_moe_sorting(ck_tile::ArgParser args)
{
moe_buf_dev
.
FromDevice
(
moe_buf_host
.
data
());
}
sorted_expert_ids_host
.
savetxt
(
"sorted_expert_ids_host.txt"
,
"int"
);
sorted_id_cnt_host
.
savetxt
(
"sorted_id_cnt_host.txt"
,
"int"
);
bool
rtn
=
true
;
if
(
validate
)
{
...
...
include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp
View file @
88412f9e
...
...
@@ -573,11 +573,13 @@ struct MoeSortingKernel
{
int
e_start
=
cumsum
[
tid
];
int
e_end
=
cumsum
[
tid
+
1
];
index_t
*
p_sorted_expert_cnts
=
p_total_tokens_post_pad
+
1
;
int
e_size
=
unit_size_mdiv
.
div
(
e_end
-
e_start
+
unit_size_mdiv
.
divisor
-
1
);
for
(
int
i
=
e_start
;
i
<
e_end
;
i
+=
unit_size_mdiv
.
divisor
)
{
p_sorted_expert_ids
[
unit_size_mdiv
.
div
(
i
)]
=
tid
;
p_sorted_expert_cnts
[]
p_sorted_expert_cnts
[
unit_size_mdiv
.
div
(
i
)]
=
e_size
;
printf
(
"tid %d size %d
\n
"
,
tid
,
e_size
);
}
}
...
...
@@ -866,7 +868,6 @@ struct MoeSortingKernel
}
__syncthreads
();
}
for
(
int
i_e
=
tid
;
i_e
<
num_experts
;
i_e
+=
block_size
)
{
int
e_start
=
smem_cumsum
(
i_e
);
...
...
@@ -894,10 +895,12 @@ struct MoeSortingKernel
if
(
local_expert_mask
[
i_e
]
==
0
)
continue
;
}
index_t
*
p_sorted_expert_cnts
=
p_total_tokens_post_pad
+
1
;
int
e_size
=
unit_size_mdiv
.
div
(
e_end
-
e_start
+
unit_size_mdiv
.
divisor
-
1
);
for
(
int
i
=
e_start
;
i
<
e_end
;
i
+=
unit_size_mdiv
.
divisor
)
{
p_sorted_expert_ids
[
unit_size_mdiv
.
div
(
i
)]
=
expert_id
;
p_sorted_expert_cnts
[
unit_size_mdiv
.
div
(
i
)]
=
e_size
;
}
}
smem_cumdup
(
num_experts
)
=
smem_cumsum
(
num_experts
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment