Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
edadca10
Unverified
Commit
edadca10
authored
Jan 15, 2026
by
kzwrime
Committed by
GitHub
Jan 15, 2026
Browse files
[Bugfix] Add CpuCommunicator.dispatch and combine to fix DP+MoE inference (#31867)
Signed-off-by:
kunzh
<
zhikun.wu@outlook.com
>
parent
d86fc23b
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
45 additions
and
1 deletion
+45
-1
vllm/distributed/device_communicators/base_device_communicator.py
...tributed/device_communicators/base_device_communicator.py
+4
-1
vllm/distributed/device_communicators/cpu_communicator.py
vllm/distributed/device_communicators/cpu_communicator.py
+41
-0
No files found.
vllm/distributed/device_communicators/base_device_communicator.py
View file @
edadca10
...
...
@@ -286,7 +286,10 @@ class DeviceCommunicatorBase:
router_logits
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
"""
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
...
...
vllm/distributed/device_communicators/cpu_communicator.py
View file @
edadca10
...
...
@@ -8,11 +8,14 @@ import torch
from
torch.distributed
import
ProcessGroup
from
vllm.distributed.utils
import
pickle
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.platforms.interface
import
CpuArchEnum
from
.base_device_communicator
import
DeviceCommunicatorBase
logger
=
init_logger
(
__name__
)
class
CpuCommunicator
(
DeviceCommunicatorBase
):
def
__init__
(
...
...
@@ -32,6 +35,20 @@ class CpuCommunicator(DeviceCommunicatorBase):
):
self
.
dist_module
=
_CPUSHMDistributed
(
self
)
if
self
.
use_all2all
:
if
self
.
all2all_backend
!=
"naive"
:
# type: ignore[has-type]
logger
.
warning
(
"`%s` all2all manager is not supported on CPU. "
"Falling back to `naive` all2all manager for CPU."
,
self
.
all2all_backend
,
# type: ignore[has-type]
)
self
.
all2all_backend
=
"naive"
if
self
.
all2all_backend
==
"naive"
:
from
.all2all
import
NaiveAll2AllManager
self
.
all2all_manager
=
NaiveAll2AllManager
(
self
.
cpu_group
)
logger
.
info
(
"Using naive all2all manager."
)
def
all_reduce
(
self
,
input_
):
self
.
dist_module
.
all_reduce
(
input_
,
group
=
self
.
device_group
)
return
input_
...
...
@@ -110,6 +127,30 @@ class CpuCommunicator(DeviceCommunicatorBase):
)
->
dict
[
str
,
torch
.
Tensor
|
Any
]:
return
self
.
dist_module
.
recv_tensor_dict
(
src
)
def
dispatch
(
# type: ignore[override]
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
self
.
all2all_manager
is
not
None
return
self
.
all2all_manager
.
dispatch
(
hidden_states
,
router_logits
,
is_sequence_parallel
,
extra_tensors
,
# type: ignore[call-arg]
)
def
combine
(
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
)
->
torch
.
Tensor
:
assert
self
.
all2all_manager
is
not
None
hidden_states
=
self
.
all2all_manager
.
combine
(
hidden_states
,
is_sequence_parallel
)
return
hidden_states
class
_CPUSHMDistributed
:
def
__init__
(
self
,
communicator
:
CpuCommunicator
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment