Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
13f6630a
Unverified
Commit
13f6630a
authored
Jan 20, 2026
by
YiSheng5
Committed by
GitHub
Jan 20, 2026
Browse files
[XPU]Support AgRsAll2AllManager on XPU device (#32654)
Signed-off-by:
yisheng
<
yi.sheng@intel.com
>
parent
fda3f03e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
130 additions
and
7 deletions
+130
-7
vllm/distributed/device_communicators/xpu_communicator.py
vllm/distributed/device_communicators/xpu_communicator.py
+130
-7
No files found.
vllm/distributed/device_communicators/xpu_communicator.py
View file @
13f6630a
...
@@ -23,23 +23,146 @@ class XpuCommunicator(DeviceCommunicatorBase):
...
@@ -23,23 +23,146 @@ class XpuCommunicator(DeviceCommunicatorBase):
):
):
super
().
__init__
(
cpu_group
,
device
,
device_group
,
unique_name
)
super
().
__init__
(
cpu_group
,
device
,
device_group
,
unique_name
)
if
self
.
use_all2all
:
if
self
.
use_all2all
:
if
self
.
all2all_backend
!=
"naive"
:
# type: ignore[has-type]
logger
.
warning
(
"`%s` all2all manager is not supported on XPU. "
"Falling back to `naive` all2all manager for XPU."
,
self
.
all2all_backend
,
# type: ignore[has-type]
)
self
.
all2all_backend
=
"naive"
if
self
.
all2all_backend
==
"naive"
:
if
self
.
all2all_backend
==
"naive"
:
from
.all2all
import
NaiveAll2AllManager
from
.all2all
import
NaiveAll2AllManager
self
.
all2all_manager
=
NaiveAll2AllManager
(
self
.
cpu_group
)
self
.
all2all_manager
=
NaiveAll2AllManager
(
self
.
cpu_group
)
logger
.
info
(
"Using naive all2all manager."
)
logger
.
info
(
"Using naive all2all manager."
)
elif
self
.
all2all_backend
==
"allgather_reducescatter"
:
from
.all2all
import
AgRsAll2AllManager
self
.
all2all_manager
=
AgRsAll2AllManager
(
self
.
cpu_group
)
logger
.
info
(
"Using AgRs manager on XPU device."
)
else
:
# type: ignore[has-type]
logger
.
warning
(
"`%s` all2all manager is not supported on XPU. "
"Falling back to AgRs manager for XPU, "
"which is the Default backend"
,
self
.
all2all_backend
,
# type: ignore[has-type]
)
from
.all2all
import
AgRsAll2AllManager
self
.
all2all_manager
=
AgRsAll2AllManager
(
self
.
cpu_group
)
logger
.
info
(
"Using AgRs manager on XPU device."
)
def
all_reduce
(
self
,
input_
)
->
torch
.
Tensor
:
def
all_reduce
(
self
,
input_
)
->
torch
.
Tensor
:
dist
.
all_reduce
(
input_
,
group
=
self
.
device_group
)
dist
.
all_reduce
(
input_
,
group
=
self
.
device_group
)
return
input_
return
input_
def
reduce_scatter
(
self
,
input_
:
torch
.
Tensor
,
dim
:
int
=
-
1
):
world_size
=
self
.
world_size
if
dim
<
0
:
# Convert negative dim to positive.
dim
+=
input_
.
dim
()
# Note: This will produce an incorrect answer if we don't make
# the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
input_tensor
=
input_
.
movedim
(
0
,
dim
).
contiguous
()
assert
input_tensor
.
shape
[
0
]
%
world_size
==
0
chunk_size
=
input_tensor
.
shape
[
0
]
//
world_size
output_shape
=
(
chunk_size
,)
+
input_tensor
.
shape
[
1
:]
output
=
torch
.
empty
(
output_shape
,
dtype
=
input_tensor
.
dtype
,
device
=
input_tensor
.
device
)
dist
.
reduce_scatter_tensor
(
output
,
input_tensor
)
# Reshape before returning
return
output
.
movedim
(
0
,
dim
).
contiguous
()
def
reduce_scatterv
(
self
,
input_
:
torch
.
Tensor
,
dim
:
int
=
-
1
,
sizes
:
list
[
int
]
|
None
=
None
):
world_size
=
self
.
world_size
if
dim
<
0
:
# Convert negative dim to positive.
dim
+=
input_
.
dim
()
# Note: This will produce an incorrect answer if we don't make
# the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
input_tensor
=
input_
.
movedim
(
0
,
dim
).
contiguous
()
if
sizes
is
not
None
:
assert
len
(
sizes
)
==
world_size
assert
input_tensor
.
shape
[
0
]
==
sum
(
sizes
)
chunk_size
=
sizes
[
self
.
rank_in_group
]
else
:
assert
input_tensor
.
shape
[
0
]
%
world_size
==
0
chunk_size
=
input_tensor
.
shape
[
0
]
//
world_size
output_shape
=
(
chunk_size
,)
+
input_tensor
.
shape
[
1
:]
output
=
torch
.
empty
(
output_shape
,
dtype
=
input_tensor
.
dtype
,
device
=
input_tensor
.
device
)
if
sizes
is
not
None
and
sizes
.
count
(
sizes
[
0
])
!=
len
(
sizes
):
# if inputs shape in different ranks is not the same using reduce_scatter
input_splits
=
list
(
input_tensor
.
split
(
sizes
,
dim
=
0
))
dist
.
reduce_scatter
(
output
,
input_splits
)
else
:
dist
.
reduce_scatter_tensor
(
output
,
input_tensor
)
# Reshape before returning
return
output
.
movedim
(
0
,
dim
).
contiguous
()
def
all_gatherv
(
self
,
input_
:
torch
.
Tensor
|
list
[
torch
.
Tensor
],
dim
:
int
=
0
,
sizes
:
list
[
int
]
|
None
=
None
,
):
if
dim
!=
0
:
raise
NotImplementedError
(
"only dim 0 all-gatherv is supported"
)
world_size
=
self
.
world_size
# 'sizes' is not needed if all inputs in the same group have the same
# shape
if
sizes
is
not
None
and
all
(
s
==
sizes
[
0
]
for
s
in
sizes
):
sizes
=
None
def
_all_gather_single
(
input_
:
torch
.
Tensor
,
sizes
:
list
[
int
]
|
None
=
None
):
input_size
=
input_
.
size
()
if
sizes
is
not
None
:
assert
len
(
sizes
)
==
world_size
assert
input_
.
shape
[
dim
]
==
sizes
[
self
.
rank_in_group
],
(
f
"
{
input_
.
shape
[
dim
]
}
!=
{
sizes
[
self
.
rank_in_group
]
}
"
)
output_size
=
(
sum
(
sizes
),)
+
input_size
[
1
:]
else
:
output_size
=
(
input_size
[
0
]
*
world_size
,)
+
input_size
[
1
:]
# Allocate output tensor.
output_tensor
=
torch
.
empty
(
output_size
,
dtype
=
input_
.
dtype
,
device
=
input_
.
device
)
if
sizes
is
not
None
:
all_gather_list
=
[]
for
size
in
sizes
:
all_gather_list
.
append
(
torch
.
empty
(
(
size
,)
+
input_
.
shape
[
1
:],
dtype
=
input_
.
dtype
,
device
=
input_
.
device
,
)
)
dist
.
all_gather
(
all_gather_list
,
input_
)
output_tensor
=
torch
.
cat
(
all_gather_list
,
dim
=
0
)
else
:
dist
.
all_gather
([
output_tensor
],
input_
)
return
output_tensor
if
isinstance
(
input_
,
torch
.
Tensor
):
return
_all_gather_single
(
input_
,
sizes
)
output_list
=
[]
for
inp
in
input_
:
output_list
.
append
(
_all_gather_single
(
inp
,
sizes
=
sizes
))
return
output_list
def
gather
(
def
gather
(
self
,
input_
:
torch
.
Tensor
,
dst
:
int
=
0
,
dim
:
int
=
-
1
self
,
input_
:
torch
.
Tensor
,
dst
:
int
=
0
,
dim
:
int
=
-
1
)
->
torch
.
Tensor
|
None
:
)
->
torch
.
Tensor
|
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment