Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8f987883
Unverified
Commit
8f987883
authored
Jan 26, 2026
by
Wentao Ye
Committed by
GitHub
Jan 26, 2026
Browse files
[Refactor] Remove unused `_moe_permute` function (#33108)
Signed-off-by:
yewentao256
<
zhyanwentao@126.com
>
parent
ebe0ba91
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
38 additions
and
169 deletions
+38
-169
benchmarks/kernels/benchmark_moe_permute_unpermute.py
benchmarks/kernels/benchmark_moe_permute_unpermute.py
+38
-105
vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py
.../model_executor/layers/fused_moe/moe_permute_unpermute.py
+0
-64
No files found.
benchmarks/kernels/benchmark_moe_permute_unpermute.py
View file @
8f987883
...
...
@@ -10,8 +10,6 @@ from transformers import AutoConfig
from
vllm.model_executor.layers.fused_moe
import
fused_topk
from
vllm.model_executor.layers.fused_moe.moe_permute_unpermute
import
(
_moe_permute
,
_moe_unpermute_and_reduce
,
moe_permute
,
moe_unpermute
,
)
...
...
@@ -41,7 +39,6 @@ def benchmark_permute(
use_fp8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
num_iters
:
int
=
100
,
use_customized_permute
:
bool
=
False
,
)
->
float
:
# init_dtype = torch.float16 if use_fp8_w8a8 else dtype
hidden_states
=
torch
.
randn
(
num_tokens
,
hidden_size
,
dtype
=
dtype
)
...
...
@@ -64,29 +61,14 @@ def benchmark_permute(
input_gating
.
copy_
(
gating_output
[
i
])
def
run
():
if
use_customized_permute
:
(
permuted_hidden_states
,
a1q_scale
,
first_token_off
,
inv_perm_idx
,
m_indices
,
)
=
moe_permute
(
qhidden_states
,
a1q_scale
=
None
,
topk_ids
=
topk_ids
,
n_expert
=
num_experts
,
expert_map
=
None
,
align_block_size
=
align_block_size
,
)
else
:
(
permuted_hidden_states
,
a1q_scale
,
sorted_token_ids
,
expert_ids
,
inv_perm
,
)
=
_moe_permute
(
qhidden_states
,
None
,
topk_ids
,
num_experts
,
None
,
16
)
moe_permute
(
qhidden_states
,
a1q_scale
=
None
,
topk_ids
=
topk_ids
,
n_expert
=
num_experts
,
expert_map
=
None
,
align_block_size
=
align_block_size
,
)
# JIT compilation & warmup
run
()
...
...
@@ -131,11 +113,9 @@ def benchmark_unpermute(
use_fp8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
num_iters
:
int
=
100
,
use_customized_permute
:
bool
=
False
,
)
->
float
:
# init_dtype = torch.float16 if use_fp8_w8a8 else dtype
hidden_states
=
torch
.
randn
(
num_tokens
,
hidden_size
,
dtype
=
dtype
)
output_hidden_states
=
torch
.
empty_like
(
hidden_states
)
if
use_fp8_w8a8
:
align_block_size
=
128
# deepgemm needs 128 m aligned block
qhidden_states
,
scale
=
_fp8_quantize
(
hidden_states
,
None
,
None
)
...
...
@@ -150,78 +130,37 @@ def benchmark_unpermute(
)
def
prepare
():
if
use_customized_permute
:
(
permuted_hidden_states
,
a1q_scale
,
first_token_off
,
inv_perm_idx
,
m_indices
,
)
=
moe_permute
(
qhidden_states
,
a1q_scale
=
None
,
topk_ids
=
topk_ids
,
n_expert
=
num_experts
,
expert_map
=
None
,
align_block_size
=
align_block_size
,
)
# convert to fp16/bf16 as gemm output
return
(
permuted_hidden_states
.
to
(
dtype
),
first_token_off
,
inv_perm_idx
,
m_indices
,
)
else
:
(
permuted_qhidden_states
,
a1q_scale
,
sorted_token_ids
,
expert_ids
,
inv_perm
,
)
=
_moe_permute
(
qhidden_states
,
None
,
topk_ids
,
num_experts
,
None
,
block_m
=
16
)
# convert to fp16/bf16 as gemm output
return
(
permuted_qhidden_states
.
to
(
dtype
),
a1q_scale
,
sorted_token_ids
,
expert_ids
,
inv_perm
,
)
(
permuted_hidden_states
,
_
,
first_token_off
,
inv_perm_idx
,
_
,
)
=
moe_permute
(
qhidden_states
,
a1q_scale
=
None
,
topk_ids
=
topk_ids
,
n_expert
=
num_experts
,
expert_map
=
None
,
align_block_size
=
align_block_size
,
)
# convert to fp16/bf16 as gemm output
return
(
permuted_hidden_states
.
to
(
dtype
),
first_token_off
,
inv_perm_idx
,
)
def
run
(
input
:
tuple
):
if
use_customized_permute
:
(
permuted_hidden_states
,
first_token_off
,
inv_perm_idx
,
m_indices
,
)
=
input
output
=
torch
.
empty_like
(
hidden_states
)
moe_unpermute
(
output
,
permuted_hidden_states
,
topk_weights
,
inv_perm_idx
,
first_token_off
,
)
else
:
(
permuted_hidden_states
,
a1q_scale
,
sorted_token_ids
,
expert_ids
,
inv_perm
,
)
=
input
_moe_unpermute_and_reduce
(
output_hidden_states
,
permuted_hidden_states
,
inv_perm
,
topk_weights
,
True
,
)
(
permuted_hidden_states
,
first_token_off
,
inv_perm_idx
)
=
input
output
=
torch
.
empty_like
(
hidden_states
)
moe_unpermute
(
output
,
permuted_hidden_states
,
topk_weights
,
inv_perm_idx
,
first_token_off
,
)
# JIT compilation & warmup
input
=
prepare
()
...
...
@@ -276,8 +215,7 @@ class BenchmarkWorker:
dtype
:
torch
.
dtype
,
use_fp8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
use_customized_permute
:
bool
=
False
,
)
->
tuple
[
dict
[
str
,
int
],
float
]:
)
->
tuple
[
float
,
float
]:
set_random_seed
(
self
.
seed
)
permute_time
=
benchmark_permute
(
...
...
@@ -289,7 +227,6 @@ class BenchmarkWorker:
use_fp8_w8a8
,
use_int8_w8a16
,
num_iters
=
100
,
use_customized_permute
=
use_customized_permute
,
)
unpermute_time
=
benchmark_unpermute
(
num_tokens
,
...
...
@@ -300,7 +237,6 @@ class BenchmarkWorker:
use_fp8_w8a8
,
use_int8_w8a16
,
num_iters
=
100
,
use_customized_permute
=
use_customized_permute
,
)
return
permute_time
,
unpermute_time
...
...
@@ -347,7 +283,6 @@ def main(args: argparse.Namespace):
dtype
=
torch
.
float16
if
current_platform
.
is_rocm
()
else
config
.
dtype
use_fp8_w8a8
=
args
.
dtype
==
"fp8_w8a8"
use_int8_w8a16
=
args
.
dtype
==
"int8_w8a16"
use_customized_permute
=
args
.
use_customized_permute
if
args
.
batch_size
is
None
:
batch_sizes
=
[
...
...
@@ -399,7 +334,6 @@ def main(args: argparse.Namespace):
dtype
,
use_fp8_w8a8
,
use_int8_w8a16
,
use_customized_permute
,
)
for
batch_size
in
batch_sizes
],
...
...
@@ -419,7 +353,6 @@ if __name__ == "__main__":
parser
.
add_argument
(
"--dtype"
,
type
=
str
,
choices
=
[
"auto"
,
"fp8_w8a8"
,
"int8_w8a16"
],
default
=
"auto"
)
parser
.
add_argument
(
"--use-customized-permute"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--seed"
,
type
=
int
,
default
=
0
)
parser
.
add_argument
(
"--batch-size"
,
type
=
int
,
required
=
False
)
parser
.
add_argument
(
"--trust-remote-code"
,
action
=
"store_true"
)
...
...
vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py
View file @
8f987883
...
...
@@ -3,70 +3,6 @@
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.fused_moe.moe_align_block_size
import
(
moe_align_block_size
,
)
from
vllm.model_executor.layers.fused_moe.utils
import
_fp8_perm
def _moe_permute(
    curr_hidden_states: torch.Tensor,
    a1q_scale: torch.Tensor | None,
    curr_topk_ids: torch.Tensor,
    global_num_experts: int,
    expert_map: torch.Tensor | None,
    block_m: int,
) -> tuple[
    torch.Tensor,
    torch.Tensor | None,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """
    Compute `sorted_token_ids` / `expert_ids` for this chunk and gather the
    hidden states (and optional scales) into that order.

    Args:
        curr_hidden_states: activations; dim 0 is read as the token count.
        a1q_scale: optional scales, indexed the same way as the activations;
            passed through as None when absent.
        curr_topk_ids: expert ids; dim 1 is read as top-k.
        global_num_experts: forwarded to `moe_align_block_size`.
        expert_map: forwarded to `moe_align_block_size`.
        block_m: block size used for alignment and for repeating
            `expert_ids`.

    Returns:
        (permuted_hidden_states, permuted_a1q_scale, sorted_token_ids,
        expert_ids, inv_perm), where `sorted_token_ids` is clamped to
        `num_tokens - 1` and `inv_perm` is the first `num_tokens` entries of
        the argsort of the unclamped `sorted_token_ids`.
    """
    topk = curr_topk_ids.size(1)
    chunk_tokens = curr_hidden_states.size(0)
    flat_token_count = topk * chunk_tokens

    # The third return value (padded token count) is not needed here.
    sorted_token_ids, block_expert_ids, _ = moe_align_block_size(
        curr_topk_ids,
        block_m,
        global_num_experts,
        expert_map,
        pad_sorted_ids=True,
    )

    # One expert id per block -> one expert id per row of each block.
    expert_ids = torch.repeat_interleave(block_expert_ids, block_m, dim=0)

    # The inverse permutation must be taken before clamping below.
    inv_perm = torch.argsort(sorted_token_ids)[:flat_token_count]

    # Permute according to sorted token ids; ids past the last valid
    # flattened slot are clamped so they stay in range as gather indices.
    sorted_token_ids = sorted_token_ids.clamp(max=flat_token_count - 1)
    source_rows = sorted_token_ids // topk

    permuted_states = _fp8_perm(curr_hidden_states, source_rows)
    permuted_scale = None if a1q_scale is None else a1q_scale[source_rows]

    return (
        permuted_states,
        permuted_scale,
        sorted_token_ids,
        expert_ids,
        inv_perm,
    )
def _moe_unpermute_and_reduce(
    out: torch.Tensor,
    curr_hidden: torch.Tensor,
    inv_perm: torch.Tensor | None,
    topk_weight: torch.Tensor,
    apply_router_weight_on_input: bool,
) -> None:
    """
    Undo the permutation, optionally scale by the router weights, and reduce
    the result into `out` via `ops.moe_sum`.

    Args:
        out: destination tensor handed to `ops.moe_sum`.
        curr_hidden: permuted hidden states; viewed as (-1, topk, K) where K
            is its last dimension.
        inv_perm: row indices used to unpermute `curr_hidden`; skipped when
            None.
        topk_weight: 2-D router weights of shape (M, topk).
        apply_router_weight_on_input: when True the weights are NOT applied
            here.

    Note:
        When `inv_perm` is None the weighting multiplies through a view of
        the caller's `curr_hidden`, so that tensor is modified in place.
    """
    num_rows, topk = topk_weight.size()
    hidden_dim = curr_hidden.size(-1)

    unpermuted = curr_hidden
    if inv_perm is not None:
        unpermuted = unpermuted[inv_perm, ...]
    unpermuted = unpermuted.view(-1, topk, hidden_dim)

    if not apply_router_weight_on_input:
        # Broadcast each (row, k) weight across the hidden dimension.
        weights = topk_weight.view(num_rows, -1, 1)
        unpermuted.mul_(weights)

    ops.moe_sum(unpermuted, out)
def
moe_permute
(
hidden_states
:
torch
.
Tensor
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment