Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
30513d1c
Unverified
Commit
30513d1c
authored
Feb 17, 2025
by
Yan Ma
Committed by
GitHub
Feb 17, 2025
Browse files
[Bugfix] fix xpu communicator (#13368)
Signed-off-by:
yan ma
<
yan.ma@intel.com
>
parent
1f69c4a8
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
58 additions
and
0 deletions
+58
-0
vllm/distributed/device_communicators/xpu_communicator.py
vllm/distributed/device_communicators/xpu_communicator.py
+54
-0
vllm/platforms/xpu.py
vllm/platforms/xpu.py
+4
-0
No files found.
vllm/distributed/device_communicators/xpu_communicator.py
0 → 100644
View file @
30513d1c
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Optional
import
torch
import
torch.distributed
as
dist
from
torch.distributed
import
ProcessGroup
from
.base_device_communicator
import
DeviceCommunicatorBase
class XpuCommunicator(DeviceCommunicatorBase):
    """Device communicator used on the XPU platform.

    Collectives are issued through ``torch.distributed`` on
    ``self.device_group``. The attributes ``device_group``, ``world_size``
    and ``rank_in_group`` are read here but not assigned in this class --
    presumably they are set up by ``DeviceCommunicatorBase.__init__``;
    verify against the base class.
    """

    def __init__(self,
                 cpu_group: ProcessGroup,
                 device: Optional[torch.device] = None,
                 device_group: Optional[ProcessGroup] = None,
                 unique_name: str = ""):
        """Forward every argument unchanged to the base communicator."""
        super().__init__(cpu_group, device, device_group, unique_name)

    def all_reduce(self, input_) -> torch.Tensor:
        """All-reduce ``input_`` over the device group and return it.

        The reduction is performed by ``dist.all_reduce`` directly on
        ``input_`` (no copy is made here), using that function's default
        reduction op; the same tensor object is returned.
        """
        dist.all_reduce(input_, group=self.device_group)
        return input_

    def gather(self,
               input_: torch.Tensor,
               dst: int = 0,
               dim: int = -1) -> Optional[torch.Tensor]:
        """Gather ``input_`` from every rank, concatenated along ``dim``.

        Args:
            input_: Tensor contributed by the calling rank.
            dst: Rank (within the group) that receives the result.
            dim: Dimension along which the per-rank tensors are joined;
                negative values count from the last dimension.

        Returns:
            On the rank equal to ``dst``, a tensor whose size along ``dim``
            is ``world_size * input_.size(dim)``; ``None`` on all other
            ranks.

        Raises:
            AssertionError: If ``dim`` is outside
                ``[-input_.dim(), input_.dim())``.
        """
        ndim = input_.dim()
        assert -ndim <= dim < ndim, (
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
        if dim < 0:
            # Normalise a negative dim to its non-negative equivalent.
            dim += ndim

        # For xpu path, gather doesn't work properly together with ray
        # cluster so we use all_gather instead for now. Every rank therefore
        # takes part in the collective, even though only ``dst`` keeps the
        # result.
        shape = input_.size()
        gathered = torch.empty((self.world_size, ) + shape,
                               dtype=input_.dtype,
                               device=input_.device)
        dist.all_gather_into_tensor(gathered, input_, group=self.device_group)

        if self.rank_in_group != dst:
            return None

        # Move the leading per-rank axis next to ``dim`` and then fold the
        # two together, so the per-rank slices end up laid out back to back
        # along ``dim``.
        merged_shape = (shape[:dim] + (self.world_size * shape[dim], ) +
                        shape[dim + 1:])
        return gathered.movedim(0, dim).reshape(merged_shape)
vllm/platforms/xpu.py
View file @
30513d1c
...
@@ -135,3 +135,7 @@ class XPUPlatform(Platform):
...
@@ -135,3 +135,7 @@ class XPUPlatform(Platform):
logger
.
warning
(
"Unknown device name %s, always use float16"
,
logger
.
warning
(
"Unknown device name %s, always use float16"
,
device_name
)
device_name
)
return
False
return
False
@classmethod
def get_device_communicator_cls(cls) -> str:
    """Return the dotted path of the device communicator class.

    The value is a plain string naming ``XpuCommunicator`` by its fully
    qualified module path; nothing is imported here.
    """
    communicator_path = (
        "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator"  # noqa
    )
    return communicator_path
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment