Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f10797c0
Unverified
Commit
f10797c0
authored
Nov 08, 2024
by
Yan Ma
Committed by
GitHub
Nov 08, 2024
Browse files
[Bugfix][XPU] Fix xpu tp by introducing XpuCommunicator (#10144)
Signed-off-by:
yan ma
<
yan.ma@intel.com
>
parent
f4c2187e
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
65 additions
and
22 deletions
+65
-22
vllm/distributed/device_communicators/xpu_communicator.py
vllm/distributed/device_communicators/xpu_communicator.py
+47
-0
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+18
-22
No files found.
vllm/distributed/device_communicators/xpu_communicator.py
0 → 100644
View file @
f10797c0
import
torch
import
torch.distributed
as
dist
from
torch.distributed
import
ProcessGroup
from
vllm.platforms
import
current_platform
class XpuCommunicator:
    """Thin wrapper around ``torch.distributed`` collectives for one group.

    If ``current_platform.is_xpu()`` is false at construction time the
    instance only sets ``disabled = True``; ``group`` and ``world_size`` are
    not assigned, so callers must check ``disabled`` before using it.
    """

    def __init__(self, group: ProcessGroup):
        """Bind the communicator to ``group`` when running on XPU.

        Args:
            group: Process group used for every collective issued by this
                communicator.
        """
        if not current_platform.is_xpu():
            self.disabled = True
            return
        self.disabled = False
        self.group = group
        self.world_size = dist.get_world_size(self.group)

    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
        """All-reduce ``x`` in place across ``self.group`` and return it.

        The reduction uses ``torch.distributed.all_reduce`` with its default
        reduce op; the returned tensor is the same object as ``x``.
        """
        dist.all_reduce(x, group=self.group)
        return x

    @staticmethod
    def _concat_gathered(stacked: torch.Tensor, dim: int) -> torch.Tensor:
        """Turn stacked per-rank tensors into a concatenation along ``dim``.

        Args:
            stacked: Tensor of shape ``(world_size, *input_size)`` whose
                slice ``stacked[r]`` is the contribution of rank ``r``.
            dim: Axis of the per-rank tensor to concatenate along. Negative
                values count from the last axis of the per-rank tensor.

        Returns:
            A tensor equal to ``torch.cat(list(stacked), dim=dim)``.

        Raises:
            IndexError: If ``dim`` is outside ``[-ndim, ndim - 1]`` where
                ``ndim`` is the rank of the per-rank tensor.
        """
        world_size = stacked.size(0)
        input_size = stacked.size()[1:]
        ndim = len(input_size)
        if not -ndim <= dim < ndim:
            raise IndexError(
                f"Dimension out of range (expected to be in range of "
                f"[{-ndim}, {ndim - 1}], but got {dim})")
        # A negative dim must be resolved against the per-rank tensor, not
        # against the stacked tensor that has one extra leading axis;
        # otherwise movedim() and the slices below pick the wrong positions.
        if dim < 0:
            dim += ndim
        # Place the rank axis directly in front of `dim`, then fold the two
        # axes together so rank r occupies the r-th chunk along `dim`.
        moved = stacked.movedim(0, dim)
        return moved.reshape(input_size[:dim] +
                             (world_size * input_size[dim], ) +
                             input_size[dim + 1:])

    def gather(self,
               input_: torch.Tensor,
               rank_in_group: int,
               dst: int = 0,
               dim: int = -1):
        """Gather ``input_`` from every rank onto rank ``dst``.

        Args:
            input_: Local tensor contributed by this rank.
            rank_in_group: Rank of the calling process inside the group.
            dst: Group rank that receives the gathered result.
            dim: Axis along which the per-rank tensors are concatenated;
                negative values are accepted.

        Returns:
            On the rank where ``rank_in_group == dst``, the per-rank tensors
            concatenated along ``dim``; ``None`` on every other rank.
        """
        # For xpu path, gather doesn't work properly together with ray
        # cluster so we use all_gather instead for now.
        input_size = input_.size()
        # Allocate output tensor.
        output_tensor = torch.empty((self.world_size, ) + input_size,
                                    dtype=input_.dtype,
                                    device=input_.device)
        # All-gather: every rank takes part, only `dst` keeps the result.
        dist.all_gather_into_tensor(output_tensor, input_, group=self.group)
        if rank_in_group != dst:
            return None
        return self._concat_gathered(output_tensor, dim)
vllm/distributed/parallel_state.py
View file @
f10797c0
...
...
@@ -177,6 +177,7 @@ class GroupCoordinator:
use_custom_allreduce
:
bool
,
use_tpu_communicator
:
bool
,
use_hpu_communicator
:
bool
,
use_xpu_communicator
:
bool
,
use_message_queue_broadcaster
:
bool
=
False
,
group_name
:
Optional
[
str
]
=
None
,
):
...
...
@@ -214,6 +215,7 @@ class GroupCoordinator:
self
.
use_custom_allreduce
=
use_custom_allreduce
self
.
use_tpu_communicator
=
use_tpu_communicator
self
.
use_hpu_communicator
=
use_hpu_communicator
self
.
use_xpu_communicator
=
use_xpu_communicator
# lazy import to avoid documentation build error
from
vllm.distributed.device_communicators.custom_all_reduce
import
(
...
...
@@ -248,6 +250,12 @@ class GroupCoordinator:
if
use_hpu_communicator
and
self
.
world_size
>
1
:
self
.
hpu_communicator
=
HpuCommunicator
(
group
=
self
.
device_group
)
from
vllm.distributed.device_communicators.xpu_communicator
import
(
XpuCommunicator
)
self
.
xpu_communicator
:
Optional
[
XpuCommunicator
]
if
use_xpu_communicator
and
self
.
world_size
>
1
:
self
.
xpu_communicator
=
XpuCommunicator
(
group
=
self
.
device_group
)
from
vllm.distributed.device_communicators.shm_broadcast
import
(
MessageQueue
)
self
.
mq_broadcaster
:
Optional
[
MessageQueue
]
=
None
...
...
@@ -373,6 +381,10 @@ class GroupCoordinator:
not
self
.
hpu_communicator
.
disabled
:
return
self
.
hpu_communicator
.
all_reduce
(
input_
)
if
self
.
xpu_communicator
is
not
None
and
\
not
self
.
xpu_communicator
.
disabled
:
return
self
.
xpu_communicator
.
all_reduce
(
input_
)
if
self
.
ca_comm
is
not
None
and
\
not
self
.
ca_comm
.
disabled
and
\
self
.
ca_comm
.
should_custom_ar
(
input_
):
...
...
@@ -459,28 +471,10 @@ class GroupCoordinator:
if
dim
<
0
:
# Convert negative dim to positive.
dim
+=
input_
.
dim
()
# For xpu path, gather doesn't work properly together with ray
# cluster so we use all_gather instead for now.
if
current_platform
.
is_xpu
():
input_size
=
input_
.
size
()
# Allocate output tensor.
output_tensor
=
torch
.
empty
((
world_size
,
)
+
input_size
,
dtype
=
input_
.
dtype
,
device
=
input_
.
device
)
# All-gather.
torch
.
distributed
.
all_gather_into_tensor
(
output_tensor
,
input_
,
group
=
self
.
device_group
)
if
self
.
rank_in_group
==
dst
:
# Reshape
output_tensor
=
output_tensor
.
movedim
(
0
,
dim
)
output_tensor
=
output_tensor
.
reshape
(
input_size
[:
dim
]
+
(
world_size
*
input_size
[
dim
],
)
+
input_size
[
dim
+
1
:])
else
:
output_tensor
=
None
return
output_tensor
if
self
.
xpu_communicator
is
not
None
and
\
not
self
.
xpu_communicator
.
disabled
:
return
self
.
xpu_communicator
.
gather
(
input_
,
self
.
rank_in_group
,
dst
,
dim
)
# Allocate output tensor.
if
self
.
rank_in_group
==
dst
:
gather_list
=
[
torch
.
empty_like
(
input_
)
for
_
in
range
(
world_size
)]
...
...
@@ -896,6 +890,7 @@ def init_world_group(ranks: List[int], local_rank: int,
use_custom_allreduce
=
False
,
use_tpu_communicator
=
False
,
use_hpu_communicator
=
False
,
use_xpu_communicator
=
False
,
group_name
=
"world"
,
)
...
...
@@ -918,6 +913,7 @@ def init_model_parallel_group(
use_custom_allreduce
=
use_custom_allreduce
,
use_tpu_communicator
=
True
,
use_hpu_communicator
=
True
,
use_xpu_communicator
=
True
,
use_message_queue_broadcaster
=
use_message_queue_broadcaster
,
group_name
=
group_name
,
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment