Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9d880f59
Unverified
Commit
9d880f59
authored
Jun 12, 2025
by
Varun Sundar Rabindranath
Committed by
GitHub
Jun 12, 2025
Browse files
[Misc] Turn MOE_DP_CHUNK_SIZE into an env var (#19506)
parent
017ef648
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
18 additions
and
8 deletions
+18
-8
vllm/envs.py
vllm/envs.py
+9
-0
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+9
-8
No files found.
vllm/envs.py
View file @
9d880f59
...
...
@@ -112,6 +112,7 @@ if TYPE_CHECKING:
VLLM_DP_SIZE
:
int
=
1
VLLM_DP_MASTER_IP
:
str
=
""
VLLM_DP_MASTER_PORT
:
int
=
0
VLLM_MOE_DP_CHUNK_SIZE
:
int
=
256
VLLM_RANDOMIZE_DP_DUMMY_INPUTS
:
bool
=
False
VLLM_MARLIN_USE_ATOMIC_ADD
:
bool
=
False
VLLM_V0_USE_OUTLINES_CACHE
:
bool
=
False
...
...
@@ -773,6 +774,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_DP_MASTER_PORT"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_DP_MASTER_PORT"
,
"0"
)),
# In the context of executing MoE models with Data-Parallel, Expert-Parallel
# and Batched All-to-All dispatch/combine kernels, VLLM_MOE_DP_CHUNK_SIZE
# dictates the quantum of tokens that can be dispatched from a DP
# rank. All DP ranks process the activations in VLLM_MOE_DP_CHUNK_SIZE
# units.
"VLLM_MOE_DP_CHUNK_SIZE"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_MOE_DP_CHUNK_SIZE"
,
"256"
)),
# Randomize inputs during dummy runs when using Data Parallel
"VLLM_RANDOMIZE_DP_DUMMY_INPUTS"
:
lambda
:
os
.
environ
.
get
(
"VLLM_RANDOMIZE_DP_DUMMY_INPUTS"
,
"0"
)
==
"1"
,
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
9d880f59
...
...
@@ -61,10 +61,6 @@ else:
fused_moe_pallas
=
None
# type: ignore
logger
=
init_logger
(
__name__
)
# Note: this limit is somewhat arbitrary and might be changed later.
# The size of the activations will be E x MOE_DP_CHUNK_SIZE x hidden_dim.
MOE_DP_CHUNK_SIZE
=
256
@
dataclass
class
FusedMoEParallelConfig
:
...
...
@@ -218,7 +214,12 @@ class MoEConfig:
# TODO: add more quantization params, blocked, per-token, etc.
block_size
:
int
=
128
max_num_tokens
:
int
=
MOE_DP_CHUNK_SIZE
max_num_tokens
:
int
=
envs
.
VLLM_MOE_DP_CHUNK_SIZE
def
__post_init__
(
self
):
if
self
.
dp_size
>
1
:
logger
.
debug
(
"Using MOEConfig::max_num_tokens=%d"
,
self
.
max_num_tokens
)
@
property
def
tp_size
(
self
):
...
...
@@ -913,7 +914,7 @@ class FusedMoE(torch.nn.Module):
moe_parallel_config
=
self
.
moe_parallel_config
,
in_dtype
=
params_dtype
,
quant_dtype
=
quant_dtype
,
max_num_tokens
=
MOE_DP_CHUNK_SIZE
,
max_num_tokens
=
envs
.
VLLM_
MOE_DP_CHUNK_SIZE
,
)
self
.
moe_config
=
moe
self
.
quant_config
=
quant_config
...
...
@@ -952,12 +953,12 @@ class FusedMoE(torch.nn.Module):
or
self
.
moe_parallel_config
.
use_deepep_ll_kernels
):
act_dtype
=
vllm_config
.
model_config
.
dtype
self
.
batched_hidden_states
=
torch
.
zeros
(
(
MOE_DP_CHUNK_SIZE
,
self
.
hidden_size
),
(
envs
.
VLLM_
MOE_DP_CHUNK_SIZE
,
self
.
hidden_size
),
dtype
=
act_dtype
,
device
=
torch
.
cuda
.
current_device
())
self
.
batched_router_logits
=
torch
.
zeros
(
(
MOE_DP_CHUNK_SIZE
,
self
.
global_num_experts
),
(
envs
.
VLLM_
MOE_DP_CHUNK_SIZE
,
self
.
global_num_experts
),
dtype
=
act_dtype
,
device
=
torch
.
cuda
.
current_device
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment