Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fe56180c
Unverified
Commit
fe56180c
authored
Jul 24, 2025
by
Woosuk Kwon
Committed by
GitHub
Jul 24, 2025
Browse files
[MoE] More balanced expert sharding (#21497)
Signed-off-by:
Woosuk Kwon
<
woosuk@thinkingmachines.ai
>
parent
07d80d7b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
12 deletions
+10
-12
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+10
-12
No files found.
vllm/model_executor/layers/fused_moe/layer.py
View file @
fe56180c
...
@@ -591,22 +591,20 @@ def determine_expert_map(
...
@@ -591,22 +591,20 @@ def determine_expert_map(
if
ep_size
==
1
:
if
ep_size
==
1
:
return
(
global_num_experts
,
None
)
return
(
global_num_experts
,
None
)
local_num_experts
=
global_num_experts
//
ep_size
# Distribute experts as evenly as possible to each rank.
base_experts
=
global_num_experts
//
ep_size
remainder
=
global_num_experts
%
ep_size
if
ep_rank
<
remainder
:
local_num_experts
=
base_experts
+
1
else
:
local_num_experts
=
base_experts
# Create a tensor of size num_experts filled with -1
# Create a tensor of size num_experts filled with -1
expert_map
=
torch
.
full
((
global_num_experts
,
),
-
1
,
dtype
=
torch
.
int32
)
expert_map
=
torch
.
full
((
global_num_experts
,
),
-
1
,
dtype
=
torch
.
int32
)
# Create a expert map for the local experts
# Create a expert map for the local experts
if
ep_rank
<
(
ep_size
-
1
):
start_idx
=
ep_rank
*
base_experts
+
min
(
ep_rank
,
remainder
)
# Each non-last rank gets local_num_experts experts.
expert_map
[
start_idx
:
start_idx
+
local_num_experts
]
=
torch
.
arange
(
expert_map
[
ep_rank
*
local_num_experts
:
0
,
local_num_experts
,
dtype
=
torch
.
int32
)
(
ep_rank
+
1
)
*
local_num_experts
]
=
\
torch
.
arange
(
0
,
local_num_experts
,
dtype
=
torch
.
int32
)
else
:
# All remaining experts are assigned to the last rank.
local_num_experts
=
(
global_num_experts
-
ep_rank
*
local_num_experts
)
expert_map
[
-
local_num_experts
:]
=
\
torch
.
arange
(
0
,
local_num_experts
,
dtype
=
torch
.
int32
)
return
(
local_num_experts
,
expert_map
)
return
(
local_num_experts
,
expert_map
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment