Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4c966e44
Unverified
Commit
4c966e44
authored
Sep 23, 2025
by
Fanli Lin
Committed by
GitHub
Sep 23, 2025
Browse files
[XPU] Fix MOE DP accuracy issue on XPU (#25465)
parent
da5e7e43
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
29 additions
and
1 deletion
+29
-1
examples/offline_inference/data_parallel.py
examples/offline_inference/data_parallel.py
+10
-1
vllm/distributed/device_communicators/xpu_communicator.py
vllm/distributed/device_communicators/xpu_communicator.py
+19
-0
No files found.
examples/offline_inference/data_parallel.py
View file @
4c966e44
...
@@ -101,6 +101,13 @@ def parse_args():
...
@@ -101,6 +101,13 @@ def parse_args():
"--quantization"
,
"--quantization"
,
type
=
str
,
type
=
str
,
)
)
parser
.
add_argument
(
"--disable-expert-parallel"
,
dest
=
"enable_expert_parallel"
,
action
=
"store_false"
,
help
=
"Disable expert parallel (default: enabled)."
,
)
parser
.
set_defaults
(
enable_expert_parallel
=
True
)
return
parser
.
parse_args
()
return
parser
.
parse_args
()
...
@@ -113,6 +120,7 @@ def main(
...
@@ -113,6 +120,7 @@ def main(
dp_master_port
,
dp_master_port
,
GPUs_per_dp_rank
,
GPUs_per_dp_rank
,
enforce_eager
,
enforce_eager
,
enable_expert_parallel
,
trust_remote_code
,
trust_remote_code
,
max_num_seqs
,
max_num_seqs
,
max_model_len
,
max_model_len
,
...
@@ -168,7 +176,7 @@ def main(
...
@@ -168,7 +176,7 @@ def main(
model
=
model
,
model
=
model
,
tensor_parallel_size
=
GPUs_per_dp_rank
,
tensor_parallel_size
=
GPUs_per_dp_rank
,
enforce_eager
=
enforce_eager
,
enforce_eager
=
enforce_eager
,
enable_expert_parallel
=
True
,
enable_expert_parallel
=
enable_expert_parallel
,
trust_remote_code
=
trust_remote_code
,
trust_remote_code
=
trust_remote_code
,
max_num_seqs
=
max_num_seqs
,
max_num_seqs
=
max_num_seqs
,
max_model_len
=
max_model_len
,
max_model_len
=
max_model_len
,
...
@@ -229,6 +237,7 @@ if __name__ == "__main__":
...
@@ -229,6 +237,7 @@ if __name__ == "__main__":
dp_master_port
,
dp_master_port
,
tp_size
,
tp_size
,
args
.
enforce_eager
,
args
.
enforce_eager
,
args
.
enable_expert_parallel
,
args
.
trust_remote_code
,
args
.
trust_remote_code
,
args
.
max_num_seqs
,
args
.
max_num_seqs
,
args
.
max_model_len
,
args
.
max_model_len
,
...
...
vllm/distributed/device_communicators/xpu_communicator.py
View file @
4c966e44
...
@@ -25,6 +25,12 @@ class XpuCommunicator(DeviceCommunicatorBase):
...
@@ -25,6 +25,12 @@ class XpuCommunicator(DeviceCommunicatorBase):
super
().
__init__
(
cpu_group
,
device
,
device_group
,
unique_name
)
super
().
__init__
(
cpu_group
,
device
,
device_group
,
unique_name
)
if
self
.
use_all2all
:
if
self
.
use_all2all
:
all2all_backend
=
envs
.
VLLM_ALL2ALL_BACKEND
all2all_backend
=
envs
.
VLLM_ALL2ALL_BACKEND
if
all2all_backend
!=
"naive"
:
logger
.
warning
(
"`%s` all2all manager is not supported on XPU."
"Falling back to `naive` all2all manager for XPU."
,
all2all_backend
)
all2all_backend
=
"naive"
if
all2all_backend
==
"naive"
:
if
all2all_backend
==
"naive"
:
from
.all2all
import
NaiveAll2AllManager
from
.all2all
import
NaiveAll2AllManager
self
.
all2all_manager
=
NaiveAll2AllManager
(
self
.
cpu_group
)
self
.
all2all_manager
=
NaiveAll2AllManager
(
self
.
cpu_group
)
...
@@ -67,3 +73,16 @@ class XpuCommunicator(DeviceCommunicatorBase):
...
@@ -67,3 +73,16 @@ class XpuCommunicator(DeviceCommunicatorBase):
def
broadcast
(
self
,
input_
:
torch
.
Tensor
,
src
:
int
=
0
)
->
None
:
def
broadcast
(
self
,
input_
:
torch
.
Tensor
,
src
:
int
=
0
)
->
None
:
dist
.
broadcast
(
input_
,
src
=
src
,
group
=
self
.
device_group
)
dist
.
broadcast
(
input_
,
src
=
src
,
group
=
self
.
device_group
)
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
self
.
all2all_manager
is
not
None
hidden_states
,
router_logits
=
self
.
all2all_manager
.
dispatch
(
hidden_states
,
router_logits
)
return
hidden_states
,
router_logits
def
combine
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
assert
self
.
all2all_manager
is
not
None
hidden_states
=
self
.
all2all_manager
.
combine
(
hidden_states
)
return
hidden_states
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment