Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bb2fc8b5
Unverified
Commit
bb2fc8b5
authored
Feb 10, 2026
by
Ilya Markov
Committed by
GitHub
Feb 10, 2026
Browse files
[BugFix] Fix async EPLB hang with DeepEP LL all2all backend (#32860)
Signed-off-by:
ilmarkov
<
markovilya197@gmail.com
>
parent
67132945
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
56 additions
and
0 deletions
+56
-0
vllm/distributed/eplb/eplb_utils.py
vllm/distributed/eplb/eplb_utils.py
+54
-0
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+2
-0
No files found.
vllm/distributed/eplb/eplb_utils.py
0 → 100644
View file @
bb2fc8b5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utility functions for EPLB (Expert Parallel Load Balancing)."""
import
os
from
vllm.config
import
ParallelConfig
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
def
override_envs_for_eplb
(
parallel_config
:
ParallelConfig
)
->
None
:
"""
Override environment variables for EPLB when specific conditions are met.
Args:
parallel_config: The parallel configuration object.
"""
is_data_parallel
=
parallel_config
.
data_parallel_size
>
1
is_eplb_enabled
=
parallel_config
.
enable_eplb
async_eplb
=
parallel_config
.
eplb_config
.
use_async
is_deepep_ll
=
parallel_config
.
all2all_backend
==
"deepep_low_latency"
# Override NCCL_MAX_CTAS to avoid hangs when using async EPLB with the
# DeepEP low-latency backend.
#
# The hang happens when two ranks interleave kernel launches differently
# between NCCL collectives (used by async EPLB weight exchange) and DeepEP
# low-latency (LL) kernels. DeepEP LL uses a cooperative launch and tries
# to reserve a large fraction of the GPU's SMs; if those SMs are currently
# occupied by NCCL, the DeepEP LL launch blocks until enough SMs are
# freed.
#
# If rank A enters DeepEP LL in main thread while rank B is still executing
# NCCL in async thread, rank A can block waiting for SMs, while rank B can
# block inside NCCL waiting for rank A to participate in the collective.
# This circular wait causes a deadlock.
# Limiting NCCL occupancy via NCCL_MAX_CTAS leaves space for the DeepEP
# cooperative kernel to launch and complete, breaking the deadlock.
# See: https://github.com/deepseek-ai/DeepEP/issues/496
if
is_data_parallel
and
is_eplb_enabled
and
is_deepep_ll
and
async_eplb
:
current_value_str
=
os
.
getenv
(
"NCCL_MAX_CTAS"
)
if
current_value_str
and
current_value_str
.
isdigit
():
return
override_value
=
8
os
.
environ
[
"NCCL_MAX_CTAS"
]
=
str
(
override_value
)
logger
.
info_once
(
f
"EPLB: Setting NCCL_MAX_CTAS=
{
override_value
}
"
"for expert parallel with EPLB and deepep_low_latency backend"
,
scope
=
"global"
,
)
vllm/v1/worker/gpu_worker.py
View file @
bb2fc8b5
...
...
@@ -22,6 +22,7 @@ from vllm.distributed import (
set_custom_all_reduce
,
)
from
vllm.distributed.ec_transfer
import
ensure_ec_transfer_initialized
from
vllm.distributed.eplb.eplb_utils
import
override_envs_for_eplb
from
vllm.distributed.kv_transfer
import
(
ensure_kv_transfer_initialized
,
ensure_kv_transfer_shutdown
,
...
...
@@ -1035,6 +1036,7 @@ def init_worker_distributed_environment(
from
vllm.model_executor.layers.batch_invariant
import
init_batch_invariance
init_batch_invariance
(
attention_config
.
backend
)
override_envs_for_eplb
(
parallel_config
)
set_custom_all_reduce
(
not
parallel_config
.
disable_custom_all_reduce
)
init_method
=
distributed_init_method
or
"env://"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment