Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a0231b7c
Unverified
Commit
a0231b7c
authored
Feb 16, 2025
by
youkaichao
Committed by
GitHub
Feb 16, 2025
Browse files
[platform] add base class for communicators (#13208)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
124776eb
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
364 additions
and
282 deletions
+364
-282
vllm/distributed/device_communicators/base_device_communicator.py
...tributed/device_communicators/base_device_communicator.py
+117
-0
vllm/distributed/device_communicators/cpu_communicator.py
vllm/distributed/device_communicators/cpu_communicator.py
+33
-0
vllm/distributed/device_communicators/cuda_communicator.py
vllm/distributed/device_communicators/cuda_communicator.py
+106
-0
vllm/distributed/device_communicators/hpu_communicator.py
vllm/distributed/device_communicators/hpu_communicator.py
+14
-19
vllm/distributed/device_communicators/tpu_communicator.py
vllm/distributed/device_communicators/tpu_communicator.py
+16
-13
vllm/distributed/device_communicators/xpu_communicator.py
vllm/distributed/device_communicators/xpu_communicator.py
+0
-49
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+48
-201
vllm/platforms/cpu.py
vllm/platforms/cpu.py
+7
-0
vllm/platforms/cuda.py
vllm/platforms/cuda.py
+4
-0
vllm/platforms/hpu.py
vllm/platforms/hpu.py
+4
-0
vllm/platforms/interface.py
vllm/platforms/interface.py
+7
-0
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+4
-0
vllm/platforms/tpu.py
vllm/platforms/tpu.py
+4
-0
No files found.
vllm/distributed/device_communicators/base_device_communicator.py
0 → 100644
View file @
a0231b7c
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Optional
import
torch
import
torch.distributed
as
dist
from
torch.distributed
import
ProcessGroup
class DeviceCommunicatorBase:
    """
    Base class for device-specific communicator.

    It can use the `cpu_group` to initialize the communicator.
    If the device has PyTorch integration (PyTorch can recognize its
    communication backend), the `device_group` will also be given.

    The default implementations below route every collective through
    `torch.distributed` on `device_group`; platform-specific subclasses
    override the operations they can accelerate.
    """

    def __init__(self,
                 cpu_group: ProcessGroup,
                 device: Optional[torch.device] = None,
                 device_group: Optional[ProcessGroup] = None,
                 unique_name: str = ""):
        """Record the groups and cache rank/world-size information.

        Args:
            cpu_group: process group used to query ranks and world size.
            device: device on which `recv` allocates its buffer; falls back
                to the CPU device when not given.
            device_group: process group passed to the `torch.distributed`
                collectives issued by this class.
            unique_name: name of the owning group, stored as-is.
        """
        self.device = device or torch.device("cpu")
        self.cpu_group = cpu_group
        self.device_group = device_group
        self.unique_name = unique_name

        # Group-relative information, all derived from `cpu_group`.
        self.rank = dist.get_rank(cpu_group)
        self.world_size = dist.get_world_size(cpu_group)
        # `ranks[i]` is the global rank of the i-th member of the group; it
        # is used to translate local ranks for gather/send/recv.
        self.ranks = dist.get_process_group_ranks(cpu_group)

        # Information about the default (global) process group.
        self.global_rank = dist.get_rank()
        self.global_world_size = dist.get_world_size()
        self.rank_in_group = dist.get_group_rank(self.cpu_group,
                                                 self.global_rank)

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        """All-reduce `input_` in place over `device_group` and return it."""
        dist.all_reduce(input_, group=self.device_group)
        return input_

    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
        """Gather `input_` from every rank and concatenate along `dim`.

        Args:
            input_: local tensor contributed by this rank.
            dim: dimension along which the per-rank tensors are
                concatenated; negative values count from the end.

        Returns:
            A tensor whose size along `dim` is `world_size` times that of
            `input_`; all other dimensions are unchanged.
        """
        # Same range check as `gather`: without it an out-of-range `dim`
        # would silently produce a nonsensical reshape target below.
        assert -input_.dim() <= dim < input_.dim(), (
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()
        input_size = input_.size()
        # NOTE: we have to use concat-style all-gather here,
        # stack-style all-gather has compatibility issues with
        # torch.compile . see https://github.com/pytorch/pytorch/issues/138795
        output_size = (input_size[0] * self.world_size, ) + input_size[1:]
        # Allocate output tensor.
        output_tensor = torch.empty(output_size,
                                    dtype=input_.dtype,
                                    device=input_.device)
        # All-gather.
        dist.all_gather_into_tensor(output_tensor,
                                    input_,
                                    group=self.device_group)
        # Reshape: split the leading axis back into (world_size, *input_size),
        # move the rank axis next to `dim`, then fold it into `dim`.
        output_tensor = output_tensor.reshape((self.world_size, ) + input_size)
        output_tensor = output_tensor.movedim(0, dim)
        output_tensor = output_tensor.reshape(input_size[:dim] +
                                              (self.world_size *
                                               input_size[dim], ) +
                                              input_size[dim + 1:])
        return output_tensor

    def gather(self,
               input_: torch.Tensor,
               dst: int = 0,
               dim: int = -1) -> Optional[torch.Tensor]:
        """Gather `input_` from every rank onto rank `dst`.

        NOTE: We assume that the input tensor is on the same device across
        all the ranks.
        NOTE: `dst` is the local rank of the destination rank.

        Returns:
            On the destination rank, the per-rank tensors concatenated along
            `dim`; `None` on every other rank.
        """
        world_size = self.world_size
        assert -input_.dim() <= dim < input_.dim(), (
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()
        is_destination = self.rank_in_group == dst
        # Only the destination rank allocates receive buffers.
        if is_destination:
            gather_list = [torch.empty_like(input_) for _ in range(world_size)]
        else:
            gather_list = None
        # Gather. `dst` is translated from a local to a global rank.
        dist.gather(input_,
                    gather_list,
                    dst=self.ranks[dst],
                    group=self.device_group)
        if is_destination:
            output_tensor = torch.cat(gather_list, dim=dim)
        else:
            output_tensor = None
        return output_tensor

    def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
        """Send a tensor to the destination rank.

        NOTE: `dst` is the local rank of the destination rank. When it is
        omitted, the tensor goes to the next rank in the group (wrapping
        around).

        This default implementation delegates to `torch.distributed.send`.
        """
        if dst is None:
            dst = (self.rank_in_group + 1) % self.world_size
        dist.send(tensor, self.ranks[dst], self.device_group)

    def recv(self,
             size: torch.Size,
             dtype: torch.dtype,
             src: Optional[int] = None) -> torch.Tensor:
        """Receive a tensor from the source rank.

        NOTE: `src` is the local rank of the source rank. When it is
        omitted, the tensor is received from the previous rank in the group
        (wrapping around).

        Returns:
            A newly allocated tensor of shape `size` and type `dtype` on
            `self.device`, filled with the received data.
        """
        if src is None:
            src = (self.rank_in_group - 1) % self.world_size

        tensor = torch.empty(size, dtype=dtype, device=self.device)
        dist.recv(tensor, self.ranks[src], self.device_group)
        return tensor

    def destroy(self):
        """Release communicator resources; the base class holds none."""
        pass
vllm/distributed/device_communicators/cpu_communicator.py
0 → 100644
View file @
a0231b7c
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Optional
import
torch
from
torch.distributed
import
ProcessGroup
from
.base_device_communicator
import
DeviceCommunicatorBase
class CpuCommunicator(DeviceCommunicatorBase):
    """Communicator for CPU tensors.

    Uses the `distributed` module of Intel Extension for PyTorch (IPEX)
    when that package can be imported, and `torch.distributed` otherwise.
    """

    def __init__(self,
                 cpu_group: ProcessGroup,
                 device: Optional[torch.device] = None,
                 device_group: Optional[ProcessGroup] = None,
                 unique_name: str = ""):
        """Initialize the base communicator and pick the all-reduce backend.

        The arguments are forwarded unchanged to `DeviceCommunicatorBase`.
        """
        super().__init__(cpu_group, device, device_group, unique_name)
        self.ipex_available = False
        self.dist_module = torch.distributed
        try:
            import intel_extension_for_pytorch as ipex
            self.ipex_available = True
            self.dist_module = ipex.distributed
        except ImportError:
            # Intel IPEX not found. Falling back to PyTorch native
            # all_reduce for CPU (e.g. MacOS)
            pass

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        """All-reduce `input_` in place over `device_group` and return it.

        The collective mutates `input_` and its own return value is not the
        reduced tensor (`torch.distributed.all_reduce` returns `None` in
        synchronous mode), so the tensor itself is returned, matching the
        contract of `DeviceCommunicatorBase.all_reduce`.
        """
        self.dist_module.all_reduce(input_, group=self.device_group)
        return input_
vllm/distributed/device_communicators/cuda_communicator.py
0 → 100644
View file @
a0231b7c
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Optional
import
torch
from
torch.distributed
import
ProcessGroup
from
.base_device_communicator
import
DeviceCommunicatorBase
class CudaCommunicator(DeviceCommunicatorBase):
    """Communicator for CUDA devices.

    All-reduce tries the custom all-reduce implementation first, then
    PyNccl, then `torch.distributed`. Point-to-point send/recv use PyNccl
    when it is available and enabled, `torch.distributed` otherwise.
    """

    def __init__(self,
                 cpu_group: ProcessGroup,
                 device: Optional[torch.device] = None,
                 device_group: Optional[ProcessGroup] = None,
                 unique_name: str = ""):
        """Initialize the base communicator and create the CUDA backends.

        The arguments are forwarded unchanged to `DeviceCommunicatorBase`.
        `unique_name` additionally decides whether custom all-reduce is
        considered: names containing "pp" never use it.
        """
        super().__init__(cpu_group, device, device_group, unique_name)
        if "pp" in unique_name:
            # pipeline parallel does not need custom allreduce
            use_custom_allreduce = False
        else:
            # Imported here rather than at module level; parallel_state
            # itself imports the device communicators.
            from vllm.distributed.parallel_state import (
                _ENABLE_CUSTOM_ALL_REDUCE)
            use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
        use_pynccl = True

        self.use_pynccl = use_pynccl
        self.use_custom_allreduce = use_custom_allreduce

        # lazy import to avoid documentation build error
        from vllm.distributed.device_communicators.custom_all_reduce import (
            CustomAllreduce)
        from vllm.distributed.device_communicators.pynccl import (
            PyNcclCommunicator)

        # Communicators are only created when there is more than one rank.
        self.pynccl_comm: Optional[PyNcclCommunicator] = None
        if use_pynccl and self.world_size > 1:
            self.pynccl_comm = PyNcclCommunicator(
                group=self.cpu_group,
                device=self.device,
            )

        self.ca_comm: Optional[CustomAllreduce] = None
        if use_custom_allreduce and self.world_size > 1:
            # Initialize a custom fast all-reduce implementation.
            self.ca_comm = CustomAllreduce(
                group=self.cpu_group,
                device=self.device,
            )

    def all_reduce(self, input_):
        """All-reduce `input_` and return the reduced tensor.

        Requires `pynccl_comm` to have been created (asserted below) unless
        the custom all-reduce path handles the tensor.
        """
        # always try custom allreduce first,
        # and then pynccl.
        ca_comm = self.ca_comm
        if ca_comm is not None and not ca_comm.disabled and \
            ca_comm.should_custom_ar(input_):
            out = ca_comm.custom_all_reduce(input_)
            assert out is not None
            return out
        pynccl_comm = self.pynccl_comm
        assert pynccl_comm is not None
        out = pynccl_comm.all_reduce(input_)
        if out is None:
            # fall back to the default all-reduce using PyTorch.
            # this usually happens during testing.
            # when we run the model, allreduce only happens for the TP
            # group, where we always have either custom allreduce or pynccl.
            out = input_.clone()
            torch.distributed.all_reduce(out, group=self.device_group)
        return out

    def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
        """Send a tensor to the destination rank.

        NOTE: `dst` is the local rank of the destination rank. When it is
        omitted, the tensor goes to the next rank in the group (wrapping
        around).

        NOTE(review): the original docstring described this as
        non-blocking; whether the call blocks depends on the backend used
        and cannot be confirmed from this class alone.
        """
        if dst is None:
            dst = (self.rank_in_group + 1) % self.world_size

        pynccl_comm = self.pynccl_comm
        if pynccl_comm is not None and not pynccl_comm.disabled:
            # PyNccl is addressed with the local rank.
            pynccl_comm.send(tensor, dst)
        else:
            # torch.distributed is addressed with the global rank.
            torch.distributed.send(tensor, self.ranks[dst], self.device_group)

    def recv(self,
             size: torch.Size,
             dtype: torch.dtype,
             src: Optional[int] = None) -> torch.Tensor:
        """Receive a tensor from the source rank.

        NOTE: `src` is the local rank of the source rank. When it is
        omitted, the tensor is received from the previous rank in the group
        (wrapping around).

        Returns:
            A newly allocated tensor of shape `size` and type `dtype` on
            `self.device`, filled with the received data.
        """
        if src is None:
            src = (self.rank_in_group - 1) % self.world_size

        tensor = torch.empty(size, dtype=dtype, device=self.device)
        pynccl_comm = self.pynccl_comm
        if pynccl_comm is not None and not pynccl_comm.disabled:
            pynccl_comm.recv(tensor, src)
        else:
            torch.distributed.recv(tensor, self.ranks[src], self.device_group)
        return tensor

    def destroy(self):
        """Drop the references to the PyNccl and custom all-reduce objects."""
        if self.pynccl_comm is not None:
            self.pynccl_comm = None
        if self.ca_comm is not None:
            self.ca_comm = None
vllm/distributed/device_communicators/hpu_communicator.py
View file @
a0231b7c
...
@@ -2,45 +2,40 @@
...
@@ -2,45 +2,40 @@
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
torch.distributed
import
ProcessGroup
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
.base_device_communicator
import
DeviceCommunicatorBase
if
current_platform
.
is_hpu
():
if
current_platform
.
is_hpu
():
import
habana_frameworks.torch
as
htorch
# noqa: F401
import
habana_frameworks.torch
as
htorch
# noqa: F401
class
HpuCommunicator
:
class
HpuCommunicator
(
DeviceCommunicatorBase
):
def
__init__
(
self
,
group
:
ProcessGroup
):
if
not
current_platform
.
is_hpu
():
self
.
disabled
=
True
return
self
.
disabled
=
False
self
.
group
=
group
self
.
world_size
=
dist
.
get_world_size
(
self
.
group
)
def
all_reduce
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
all_reduce
(
self
,
input_
:
torch
.
Tensor
)
->
torch
.
Tensor
:
# FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
# FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
# occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used
# occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used
# (which is required for tensor parallel HPUGraph inference)
# (which is required for tensor parallel HPUGraph inference)
htorch
.
core
.
mark_step
()
htorch
.
core
.
mark_step
()
dist
.
all_reduce
(
x
,
group
=
self
.
group
)
dist
.
all_reduce
(
input_
,
group
=
self
.
device_
group
)
return
x
return
input_
def
all_gather
(
self
,
x
:
torch
.
Tensor
,
dim
:
int
=
-
1
)
->
torch
.
Tensor
:
def
all_gather
(
self
,
input_
:
torch
.
Tensor
,
dim
:
int
=
-
1
)
->
torch
.
Tensor
:
world_size
=
self
.
world_size
world_size
=
self
.
world_size
if
dim
<
0
:
if
dim
<
0
:
# Convert negative dim to positive.
# Convert negative dim to positive.
dim
+=
x
.
dim
()
dim
+=
input_
.
dim
()
input_size
=
x
.
size
()
input_size
=
input_
.
size
()
# Allocate output tensor.
# Allocate output tensor.
output_tensor
=
torch
.
empty
((
world_size
,
)
+
input_size
,
output_tensor
=
torch
.
empty
((
world_size
,
)
+
input_size
,
dtype
=
x
.
dtype
,
dtype
=
input_
.
dtype
,
device
=
x
.
device
)
device
=
input_
.
device
)
# All-gather.
# All-gather.
htorch
.
core
.
mark_step
()
htorch
.
core
.
mark_step
()
dist
.
all_gather_into_tensor
(
output_tensor
,
x
,
group
=
self
.
group
)
dist
.
all_gather_into_tensor
(
output_tensor
,
input_
,
group
=
self
.
device_group
)
# Reshape
# Reshape
output_tensor
=
output_tensor
.
movedim
(
0
,
dim
)
output_tensor
=
output_tensor
.
movedim
(
0
,
dim
)
output_tensor
=
output_tensor
.
reshape
(
input_size
[:
dim
]
+
output_tensor
=
output_tensor
.
reshape
(
input_size
[:
dim
]
+
...
...
vllm/distributed/device_communicators/tpu_communicator.py
View file @
a0231b7c
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
os
import
os
from
typing
import
Optional
import
torch
import
torch
import
torch.distributed
as
dist
from
torch.distributed
import
ProcessGroup
from
torch.distributed
import
ProcessGroup
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
.base_device_communicator
import
DeviceCommunicatorBase
if
current_platform
.
is_tpu
():
if
current_platform
.
is_tpu
():
import
torch_xla.core.xla_model
as
xm
import
torch_xla.core.xla_model
as
xm
import
torch_xla.runtime
as
xr
import
torch_xla.runtime
as
xr
...
@@ -16,19 +18,20 @@ if current_platform.is_tpu():
...
@@ -16,19 +18,20 @@ if current_platform.is_tpu():
from
vllm.executor
import
ray_utils
from
vllm.executor
import
ray_utils
class
TpuCommunicator
:
class
TpuCommunicator
(
DeviceCommunicatorBase
)
:
def
__init__
(
self
,
group
:
ProcessGroup
):
def
__init__
(
self
,
if
not
current_platform
.
is_tpu
():
cpu_group
:
ProcessGroup
,
self
.
disabled
=
True
device
:
Optional
[
torch
.
device
]
=
None
,
return
device_group
:
Optional
[
ProcessGroup
]
=
None
,
self
.
disabled
=
False
unique_name
:
str
=
""
):
super
().
__init__
(
cpu_group
,
device
,
device_group
,
unique_name
)
# NOTE(woosuk): When using TP > 1 on TPUs, every TPU on the same node
# NOTE(woosuk): When using TP > 1 on TPUs, every TPU on the same node
# must be used together. Therefore, the local rank and world size can
# must be used together. Therefore, the local rank and world size can
# be simply calculated as follows.
# be simply calculated as follows.
global_rank
=
dist
.
get_rank
(
group
)
global_rank
=
self
.
global_rank
global_world_size
=
dist
.
get
_world_size
(
group
)
global_world_size
=
self
.
global
_world_size
# Calculate how many TPU nodes are in the current deployment. This
# Calculate how many TPU nodes are in the current deployment. This
# is the Ray placement group if it is deployed with Ray. Default
# is the Ray placement group if it is deployed with Ray. Default
...
@@ -55,9 +58,9 @@ class TpuCommunicator:
...
@@ -55,9 +58,9 @@ class TpuCommunicator:
pjrt
.
initialize_multiprocess
(
local_rank
,
local_world_size
)
pjrt
.
initialize_multiprocess
(
local_rank
,
local_world_size
)
xr
.
_init_world_size_ordinal
()
xr
.
_init_world_size_ordinal
()
def
all_reduce
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
all_reduce
(
self
,
input_
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
xm
.
all_reduce
(
xm
.
REDUCE_SUM
,
x
)
return
xm
.
all_reduce
(
xm
.
REDUCE_SUM
,
input_
)
def
all_gather
(
self
,
x
:
torch
.
Tensor
,
dim
:
int
=
-
1
)
->
torch
.
Tensor
:
def
all_gather
(
self
,
input_
:
torch
.
Tensor
,
dim
:
int
=
-
1
)
->
torch
.
Tensor
:
assert
dim
==
-
1
,
"TPUs only support dim=-1 for all-gather."
assert
dim
==
-
1
,
"TPUs only support dim=-1 for all-gather."
return
xm
.
all_gather
(
x
,
dim
=
dim
)
return
xm
.
all_gather
(
input_
,
dim
=
dim
)
vllm/distributed/device_communicators/xpu_communicator.py
deleted
100644 → 0
View file @
124776eb
# SPDX-License-Identifier: Apache-2.0
import
torch
import
torch.distributed
as
dist
from
torch.distributed
import
ProcessGroup
from
vllm.platforms
import
current_platform
class XpuCommunicator:
    """Communicator for XPU devices built on `torch.distributed`."""

    def __init__(self, group: ProcessGroup):
        """Bind the communicator to `group`.

        On non-XPU platforms the communicator only marks itself as
        `disabled`; `group` and `world_size` are not set in that case.
        """
        if not current_platform.is_xpu():
            self.disabled = True
            return
        self.disabled = False
        self.group = group
        self.world_size = dist.get_world_size(self.group)

    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
        """All-reduce `x` in place over the group and return it."""
        dist.all_reduce(x, group=self.group)
        return x

    def gather(self,
               input_: torch.Tensor,
               rank_in_group: int,
               dst: int = 0,
               dim: int = -1):
        """Gather `input_` from every rank, concatenated along `dim`.

        Args:
            input_: local tensor contributed by this rank.
            rank_in_group: local rank of the calling process.
            dst: local rank that receives the result.
            dim: concatenation dimension; negative values count from the
                end.

        Returns:
            The concatenated tensor when `rank_in_group == dst`, otherwise
            `None`.
        """
        # For xpu path, gather doesn't work properly together with ray
        # cluster so we use all_gather instead for now.
        if dim < 0:
            # Convert negative dim to positive. The slicing in the reshape
            # below is only correct for a non-negative dim: with dim == -1,
            # `input_size[dim + 1:]` would be the whole shape.
            dim += input_.dim()
        input_size = input_.size()
        # Allocate output tensor.
        output_tensor = torch.empty((self.world_size, ) + input_size,
                                    dtype=input_.dtype,
                                    device=input_.device)
        # All-gather.
        torch.distributed.all_gather_into_tensor(output_tensor,
                                                 input_,
                                                 group=self.group)
        if rank_in_group == dst:
            # Reshape: move the rank axis next to `dim`, then fold it in.
            output_tensor = output_tensor.movedim(0, dim)
            output_tensor = output_tensor.reshape(input_size[:dim] +
                                                  (self.world_size *
                                                   input_size[dim], ) +
                                                  input_size[dim + 1:])
        else:
            # Every rank takes part in the all-gather, but only the
            # destination keeps the result.
            output_tensor = None
        return output_tensor
vllm/distributed/parallel_state.py
View file @
a0231b7c
...
@@ -39,9 +39,12 @@ from torch.distributed import Backend, ProcessGroup
...
@@ -39,9 +39,12 @@ from torch.distributed import Backend, ProcessGroup
import
vllm.distributed.kv_transfer.kv_transfer_agent
as
kv_transfer
import
vllm.distributed.kv_transfer.kv_transfer_agent
as
kv_transfer
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.distributed.device_communicators.base_device_communicator
import
(
DeviceCommunicatorBase
)
from
vllm.distributed.utils
import
StatelessProcessGroup
from
vllm.distributed.utils
import
StatelessProcessGroup
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
direct_register_custom_op
,
supports_custom_op
from
vllm.utils
import
(
direct_register_custom_op
,
resolve_obj_by_qualname
,
supports_custom_op
)
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
...
@@ -130,9 +133,8 @@ class GroupCoordinator:
...
@@ -130,9 +133,8 @@ class GroupCoordinator:
PyTorch ProcessGroup is bound to one specific communication backend,
PyTorch ProcessGroup is bound to one specific communication backend,
e.g. NCCL, Gloo, MPI, etc.
e.g. NCCL, Gloo, MPI, etc.
GroupCoordinator takes charge of all the communication operations among
GroupCoordinator takes charge of all the communication operations among
the processes in the group. It can route the communication to
the processes in the group. It manages both CPU and device
a specific implementation (e.g. switch allreduce implementation
communication.
based on the tensor size and cuda graph mode).
"""
"""
# available attributes:
# available attributes:
...
@@ -150,11 +152,8 @@ class GroupCoordinator:
...
@@ -150,11 +152,8 @@ class GroupCoordinator:
rank_in_group
:
int
# rank inside the group
rank_in_group
:
int
# rank inside the group
cpu_group
:
ProcessGroup
# group for CPU communication
cpu_group
:
ProcessGroup
# group for CPU communication
device_group
:
ProcessGroup
# group for device communication
device_group
:
ProcessGroup
# group for device communication
use_pynccl
:
bool
# a hint of whether to use PyNccl
use_device_communicator
:
bool
# whether to use device communicator
use_custom_allreduce
:
bool
# a hint of whether to use CustomAllreduce
device_communicator
:
DeviceCommunicatorBase
# device communicator
# communicators are only created for world size > 1
pynccl_comm
:
Optional
[
Any
]
# PyNccl communicator
ca_comm
:
Optional
[
Any
]
# Custom allreduce communicator
mq_broadcaster
:
Optional
[
Any
]
# shared memory broadcaster
mq_broadcaster
:
Optional
[
Any
]
# shared memory broadcaster
def
__init__
(
def
__init__
(
...
@@ -162,11 +161,7 @@ class GroupCoordinator:
...
@@ -162,11 +161,7 @@ class GroupCoordinator:
group_ranks
:
List
[
List
[
int
]],
group_ranks
:
List
[
List
[
int
]],
local_rank
:
int
,
local_rank
:
int
,
torch_distributed_backend
:
Union
[
str
,
Backend
],
torch_distributed_backend
:
Union
[
str
,
Backend
],
use_pynccl
:
bool
,
use_device_communicator
:
bool
,
use_custom_allreduce
:
bool
,
use_tpu_communicator
:
bool
,
use_hpu_communicator
:
bool
,
use_xpu_communicator
:
bool
,
use_message_queue_broadcaster
:
bool
=
False
,
use_message_queue_broadcaster
:
bool
=
False
,
group_name
:
Optional
[
str
]
=
None
,
group_name
:
Optional
[
str
]
=
None
,
):
):
...
@@ -196,56 +191,26 @@ class GroupCoordinator:
...
@@ -196,56 +191,26 @@ class GroupCoordinator:
assert
self
.
device_group
is
not
None
assert
self
.
device_group
is
not
None
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
# TODO: fix it for other platforms
if
current_platform
.
is_cuda_alike
():
if
current_platform
.
is_cuda_alike
():
self
.
device
=
torch
.
device
(
f
"cuda:
{
local_rank
}
"
)
self
.
device
=
torch
.
device
(
f
"cuda:
{
local_rank
}
"
)
else
:
else
:
self
.
device
=
torch
.
device
(
"cpu"
)
self
.
device
=
torch
.
device
(
"cpu"
)
self
.
use_pynccl
=
use_pynccl
self
.
use_device_communicator
=
use_device_communicator
self
.
use_custom_allreduce
=
use_custom_allreduce
self
.
use_tpu_communicator
=
use_tpu_communicator
self
.
use_hpu_communicator
=
use_hpu_communicator
self
.
use_xpu_communicator
=
use_xpu_communicator
# lazy import to avoid documentation build error
from
vllm.distributed.device_communicators.custom_all_reduce
import
(
CustomAllreduce
)
from
vllm.distributed.device_communicators.pynccl
import
(
PyNcclCommunicator
)
self
.
pynccl_comm
:
Optional
[
PyNcclCommunicator
]
=
None
if
use_pynccl
and
self
.
world_size
>
1
:
self
.
pynccl_comm
=
PyNcclCommunicator
(
group
=
self
.
cpu_group
,
device
=
self
.
device
,
)
self
.
ca_comm
:
Optional
[
CustomAllreduce
]
=
None
self
.
device_communicator
:
DeviceCommunicatorBase
=
None
# type: ignore
if
use_custom_allreduce
and
self
.
world_size
>
1
:
if
use_device_communicator
and
self
.
world_size
>
1
:
# Initialize a custom fast all-reduce implementation.
device_comm_cls
=
resolve_obj_by_qualname
(
self
.
ca_comm
=
CustomAllreduce
(
current_platform
.
get_device_communicator_cls
())
group
=
self
.
cpu_group
,
self
.
device_communicator
=
device_comm_cls
(
cpu_group
=
self
.
cpu_group
,
device
=
self
.
device
,
device
=
self
.
device
,
device_group
=
self
.
device_group
,
unique_name
=
self
.
unique_name
,
)
)
from
vllm.distributed.device_communicators.tpu_communicator
import
(
TpuCommunicator
)
self
.
tpu_communicator
:
Optional
[
TpuCommunicator
]
=
None
if
use_tpu_communicator
and
self
.
world_size
>
1
:
self
.
tpu_communicator
=
TpuCommunicator
(
group
=
self
.
cpu_group
)
from
vllm.distributed.device_communicators.hpu_communicator
import
(
HpuCommunicator
)
self
.
hpu_communicator
:
Optional
[
HpuCommunicator
]
if
use_hpu_communicator
and
self
.
world_size
>
1
:
self
.
hpu_communicator
=
HpuCommunicator
(
group
=
self
.
device_group
)
from
vllm.distributed.device_communicators.xpu_communicator
import
(
XpuCommunicator
)
self
.
xpu_communicator
:
Optional
[
XpuCommunicator
]
if
use_xpu_communicator
and
self
.
world_size
>
1
:
self
.
xpu_communicator
=
XpuCommunicator
(
group
=
self
.
device_group
)
from
vllm.distributed.device_communicators.shm_broadcast
import
(
from
vllm.distributed.device_communicators.shm_broadcast
import
(
MessageQueue
)
MessageQueue
)
self
.
mq_broadcaster
:
Optional
[
MessageQueue
]
=
None
self
.
mq_broadcaster
:
Optional
[
MessageQueue
]
=
None
...
@@ -253,6 +218,9 @@ class GroupCoordinator:
...
@@ -253,6 +218,9 @@ class GroupCoordinator:
self
.
mq_broadcaster
=
MessageQueue
.
create_from_process_group
(
self
.
mq_broadcaster
=
MessageQueue
.
create_from_process_group
(
self
.
cpu_group
,
1
<<
22
,
6
)
self
.
cpu_group
,
1
<<
22
,
6
)
from
vllm.platforms
import
current_platform
self
.
use_custom_op_call
=
current_platform
.
is_cuda_alike
()
@
property
@
property
def
first_rank
(
self
):
def
first_rank
(
self
):
"""Return the global rank of the first process in the group"""
"""Return the global rank of the first process in the group"""
...
@@ -296,9 +264,16 @@ class GroupCoordinator:
...
@@ -296,9 +264,16 @@ class GroupCoordinator:
else
:
else
:
stream
=
graph_capture_context
.
stream
stream
=
graph_capture_context
.
stream
ca_comm
=
self
.
ca_comm
# only cuda uses this function,
maybe_ca_context
=
nullcontext
(
# so we don't abstract it into the base class
)
if
ca_comm
is
None
else
ca_comm
.
capture
()
maybe_ca_context
=
nullcontext
()
from
vllm.distributed.device_communicators.cuda_communicator
import
(
CudaCommunicator
)
if
self
.
device_communicator
is
not
None
:
assert
isinstance
(
self
.
device_communicator
,
CudaCommunicator
)
ca_comm
=
self
.
device_communicator
.
ca_comm
if
ca_comm
is
not
None
:
maybe_ca_context
=
ca_comm
.
capture
()
# type: ignore
# ensure all initialization operations complete before attempting to
# ensure all initialization operations complete before attempting to
# capture the graph on another stream
# capture the graph on another stream
...
@@ -328,54 +303,14 @@ class GroupCoordinator:
...
@@ -328,54 +303,14 @@ class GroupCoordinator:
if
self
.
world_size
==
1
:
if
self
.
world_size
==
1
:
return
input_
return
input_
if
input_
.
is_cpu
:
if
self
.
use_custom_op_call
:
try
:
return
torch
.
ops
.
vllm
.
all_reduce
(
input_
,
import
intel_extension_for_pytorch
as
ipex
group_name
=
self
.
unique_name
)
ipex
.
distributed
.
all_reduce
(
input_
,
group
=
self
.
device_group
)
else
:
return
input_
return
self
.
_all_reduce_out_place
(
input_
)
except
ImportError
:
"""
Intel IPEX not found. Falling back to PyTorch native
all_reduce for CPU
"""
torch
.
distributed
.
all_reduce
(
input_
,
group
=
self
.
device_group
)
return
input_
if
self
.
tpu_communicator
is
not
None
and
\
not
self
.
tpu_communicator
.
disabled
:
# TPU handles Dynamo with its own logic.
return
self
.
tpu_communicator
.
all_reduce
(
input_
)
if
self
.
hpu_communicator
is
not
None
and
\
not
self
.
hpu_communicator
.
disabled
:
return
self
.
hpu_communicator
.
all_reduce
(
input_
)
if
self
.
xpu_communicator
is
not
None
and
\
not
self
.
xpu_communicator
.
disabled
:
return
self
.
xpu_communicator
.
all_reduce
(
input_
)
return
torch
.
ops
.
vllm
.
all_reduce
(
input_
,
group_name
=
self
.
unique_name
)
def
_all_reduce_out_place
(
self
,
input_
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
_all_reduce_out_place
(
self
,
input_
:
torch
.
Tensor
)
->
torch
.
Tensor
:
# always try custom allreduce first,
return
self
.
device_communicator
.
all_reduce
(
input_
)
# and then pynccl.
ca_comm
=
self
.
ca_comm
if
ca_comm
is
not
None
and
not
ca_comm
.
disabled
and
\
ca_comm
.
should_custom_ar
(
input_
):
out
=
ca_comm
.
custom_all_reduce
(
input_
)
assert
out
is
not
None
return
out
pynccl_comm
=
self
.
pynccl_comm
assert
pynccl_comm
is
not
None
out
=
pynccl_comm
.
all_reduce
(
input_
)
if
out
is
None
:
# fall back to the default all-reduce using PyTorch.
# this usually happens during testing.
# when we run the model, allreduce only happens for the TP
# group, where we always have either custom allreduce or pynccl.
out
=
input_
.
clone
()
torch
.
distributed
.
all_reduce
(
out
,
group
=
self
.
device_group
)
return
out
def
all_gather
(
self
,
input_
:
torch
.
Tensor
,
dim
:
int
=
-
1
)
->
torch
.
Tensor
:
def
all_gather
(
self
,
input_
:
torch
.
Tensor
,
dim
:
int
=
-
1
)
->
torch
.
Tensor
:
world_size
=
self
.
world_size
world_size
=
self
.
world_size
...
@@ -385,40 +320,7 @@ class GroupCoordinator:
...
@@ -385,40 +320,7 @@ class GroupCoordinator:
assert
-
input_
.
dim
()
<=
dim
<
input_
.
dim
(),
(
assert
-
input_
.
dim
()
<=
dim
<
input_
.
dim
(),
(
f
"Invalid dim (
{
dim
}
) for input tensor with shape
{
input_
.
size
()
}
"
)
f
"Invalid dim (
{
dim
}
) for input tensor with shape
{
input_
.
size
()
}
"
)
# For TPUs, use TPU communicator.
return
self
.
device_communicator
.
all_gather
(
input_
,
dim
)
tpu_comm
=
self
.
tpu_communicator
if
tpu_comm
is
not
None
and
not
tpu_comm
.
disabled
:
return
tpu_comm
.
all_gather
(
input_
,
dim
)
# For HPUs, use HPU communicator.
hpu_comm
=
self
.
hpu_communicator
if
hpu_comm
is
not
None
and
not
hpu_comm
.
disabled
:
return
hpu_comm
.
all_gather
(
input_
,
dim
)
if
dim
<
0
:
# Convert negative dim to positive.
dim
+=
input_
.
dim
()
input_size
=
input_
.
size
()
# NOTE: we have to use concat-style all-gather here,
# stack-style all-gather has compatibility issues with
# torch.compile . see https://github.com/pytorch/pytorch/issues/138795
output_size
=
(
input_size
[
0
]
*
world_size
,
)
+
input_size
[
1
:]
# Allocate output tensor.
output_tensor
=
torch
.
empty
(
output_size
,
dtype
=
input_
.
dtype
,
device
=
input_
.
device
)
# All-gather.
torch
.
distributed
.
all_gather_into_tensor
(
output_tensor
,
input_
,
group
=
self
.
device_group
)
# Reshape
output_tensor
=
output_tensor
.
reshape
((
world_size
,
)
+
input_size
)
output_tensor
=
output_tensor
.
movedim
(
0
,
dim
)
output_tensor
=
output_tensor
.
reshape
(
input_size
[:
dim
]
+
(
world_size
*
input_size
[
dim
],
)
+
input_size
[
dim
+
1
:])
return
output_tensor
def
gather
(
self
,
def
gather
(
self
,
input_
:
torch
.
Tensor
,
input_
:
torch
.
Tensor
,
...
@@ -433,30 +335,7 @@ class GroupCoordinator:
...
@@ -433,30 +335,7 @@ class GroupCoordinator:
# Bypass the function if we are using only 1 GPU.
# Bypass the function if we are using only 1 GPU.
if
world_size
==
1
:
if
world_size
==
1
:
return
input_
return
input_
assert
-
input_
.
dim
()
<=
dim
<
input_
.
dim
(),
(
return
self
.
device_communicator
.
gather
(
input_
,
dst
,
dim
)
f
"Invalid dim (
{
dim
}
) for input tensor with shape
{
input_
.
size
()
}
"
)
if
dim
<
0
:
# Convert negative dim to positive.
dim
+=
input_
.
dim
()
if
self
.
xpu_communicator
is
not
None
and
\
not
self
.
xpu_communicator
.
disabled
:
return
self
.
xpu_communicator
.
gather
(
input_
,
self
.
rank_in_group
,
dst
,
dim
)
# Allocate output tensor.
if
self
.
rank_in_group
==
dst
:
gather_list
=
[
torch
.
empty_like
(
input_
)
for
_
in
range
(
world_size
)]
else
:
gather_list
=
None
# Gather.
torch
.
distributed
.
gather
(
input_
,
gather_list
,
dst
=
self
.
ranks
[
dst
],
group
=
self
.
device_group
)
if
self
.
rank_in_group
==
dst
:
output_tensor
=
torch
.
cat
(
gather_list
,
dim
=
dim
)
else
:
output_tensor
=
None
return
output_tensor
def
broadcast
(
self
,
input_
:
torch
.
Tensor
,
src
:
int
=
0
):
def
broadcast
(
self
,
input_
:
torch
.
Tensor
,
src
:
int
=
0
):
"""Broadcast the input tensor.
"""Broadcast the input tensor.
...
@@ -798,14 +677,7 @@ class GroupCoordinator:
...
@@ -798,14 +677,7 @@ class GroupCoordinator:
def
send
(
self
,
tensor
:
torch
.
Tensor
,
dst
:
Optional
[
int
]
=
None
)
->
None
:
def
send
(
self
,
tensor
:
torch
.
Tensor
,
dst
:
Optional
[
int
]
=
None
)
->
None
:
"""Sends a tensor to the destination rank in a non-blocking way"""
"""Sends a tensor to the destination rank in a non-blocking way"""
"""NOTE: `dst` is the local rank of the destination rank."""
"""NOTE: `dst` is the local rank of the destination rank."""
if
dst
is
None
:
self
.
device_communicator
.
send
(
tensor
,
dst
)
dst
=
(
self
.
rank_in_group
+
1
)
%
self
.
world_size
pynccl_comm
=
self
.
pynccl_comm
if
pynccl_comm
is
not
None
and
not
pynccl_comm
.
disabled
:
pynccl_comm
.
send
(
tensor
,
dst
)
else
:
torch
.
distributed
.
send
(
tensor
,
self
.
ranks
[
dst
],
self
.
device_group
)
def
recv
(
self
,
def
recv
(
self
,
size
:
torch
.
Size
,
size
:
torch
.
Size
,
...
@@ -813,16 +685,7 @@ class GroupCoordinator:
...
@@ -813,16 +685,7 @@ class GroupCoordinator:
src
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
src
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
"""Receives a tensor from the source rank."""
"""Receives a tensor from the source rank."""
"""NOTE: `src` is the local rank of the source rank."""
"""NOTE: `src` is the local rank of the source rank."""
if
src
is
None
:
return
self
.
device_communicator
.
recv
(
size
,
dtype
,
src
)
src
=
(
self
.
rank_in_group
-
1
)
%
self
.
world_size
tensor
=
torch
.
empty
(
size
,
dtype
=
dtype
,
device
=
self
.
device
)
pynccl_comm
=
self
.
pynccl_comm
if
pynccl_comm
is
not
None
and
not
pynccl_comm
.
disabled
:
pynccl_comm
.
recv
(
tensor
,
src
)
else
:
torch
.
distributed
.
recv
(
tensor
,
self
.
ranks
[
src
],
self
.
device_group
)
return
tensor
def
destroy
(
self
):
def
destroy
(
self
):
if
self
.
device_group
is
not
None
:
if
self
.
device_group
is
not
None
:
...
@@ -831,10 +694,8 @@ class GroupCoordinator:
...
@@ -831,10 +694,8 @@ class GroupCoordinator:
if
self
.
cpu_group
is
not
None
:
if
self
.
cpu_group
is
not
None
:
torch
.
distributed
.
destroy_process_group
(
self
.
cpu_group
)
torch
.
distributed
.
destroy_process_group
(
self
.
cpu_group
)
self
.
cpu_group
=
None
self
.
cpu_group
=
None
if
self
.
pynccl_comm
is
not
None
:
if
self
.
device_communicator
is
not
None
:
self
.
pynccl_comm
=
None
self
.
device_communicator
.
destroy
()
if
self
.
ca_comm
is
not
None
:
self
.
ca_comm
=
None
if
self
.
mq_broadcaster
is
not
None
:
if
self
.
mq_broadcaster
is
not
None
:
self
.
mq_broadcaster
=
None
self
.
mq_broadcaster
=
None
...
@@ -853,11 +714,7 @@ def init_world_group(ranks: List[int], local_rank: int,
...
@@ -853,11 +714,7 @@ def init_world_group(ranks: List[int], local_rank: int,
group_ranks
=
[
ranks
],
group_ranks
=
[
ranks
],
local_rank
=
local_rank
,
local_rank
=
local_rank
,
torch_distributed_backend
=
backend
,
torch_distributed_backend
=
backend
,
use_pynccl
=
False
,
use_device_communicator
=
False
,
use_custom_allreduce
=
False
,
use_tpu_communicator
=
False
,
use_hpu_communicator
=
False
,
use_xpu_communicator
=
False
,
group_name
=
"world"
,
group_name
=
"world"
,
)
)
...
@@ -866,23 +723,15 @@ def init_model_parallel_group(
...
@@ -866,23 +723,15 @@ def init_model_parallel_group(
group_ranks
:
List
[
List
[
int
]],
group_ranks
:
List
[
List
[
int
]],
local_rank
:
int
,
local_rank
:
int
,
backend
:
str
,
backend
:
str
,
use_custom_allreduce
:
Optional
[
bool
]
=
None
,
use_message_queue_broadcaster
:
bool
=
False
,
use_message_queue_broadcaster
:
bool
=
False
,
group_name
:
Optional
[
str
]
=
None
,
group_name
:
Optional
[
str
]
=
None
,
)
->
GroupCoordinator
:
)
->
GroupCoordinator
:
if
use_custom_allreduce
is
None
:
use_custom_allreduce
=
_ENABLE_CUSTOM_ALL_REDUCE
from
vllm.platforms
import
current_platform
return
GroupCoordinator
(
return
GroupCoordinator
(
group_ranks
=
group_ranks
,
group_ranks
=
group_ranks
,
local_rank
=
local_rank
,
local_rank
=
local_rank
,
torch_distributed_backend
=
backend
,
torch_distributed_backend
=
backend
,
use_pynccl
=
current_platform
.
is_cuda_alike
(),
use_device_communicator
=
True
,
use_custom_allreduce
=
current_platform
.
is_cuda_alike
()
and
use_custom_allreduce
,
use_tpu_communicator
=
True
,
use_hpu_communicator
=
True
,
use_xpu_communicator
=
True
,
use_message_queue_broadcaster
=
use_message_queue_broadcaster
,
use_message_queue_broadcaster
=
use_message_queue_broadcaster
,
group_name
=
group_name
,
group_name
=
group_name
,
)
)
...
@@ -1053,11 +902,9 @@ def initialize_model_parallel(
...
@@ -1053,11 +902,9 @@ def initialize_model_parallel(
for
i
in
range
(
num_pipeline_model_parallel_groups
):
for
i
in
range
(
num_pipeline_model_parallel_groups
):
ranks
=
list
(
range
(
i
,
world_size
,
num_pipeline_model_parallel_groups
))
ranks
=
list
(
range
(
i
,
world_size
,
num_pipeline_model_parallel_groups
))
group_ranks
.
append
(
ranks
)
group_ranks
.
append
(
ranks
)
# pipeline parallel does not need custom allreduce
_PP
=
init_model_parallel_group
(
group_ranks
,
_PP
=
init_model_parallel_group
(
group_ranks
,
get_world_group
().
local_rank
,
get_world_group
().
local_rank
,
backend
,
backend
,
use_custom_allreduce
=
False
,
group_name
=
"pp"
)
group_name
=
"pp"
)
...
...
vllm/platforms/cpu.py
View file @
a0231b7c
...
@@ -146,3 +146,10 @@ class CpuPlatform(Platform):
...
@@ -146,3 +146,10 @@ class CpuPlatform(Platform):
@
classmethod
@
classmethod
def
get_punica_wrapper
(
cls
)
->
str
:
def
get_punica_wrapper
(
cls
)
->
str
:
return
"vllm.lora.punica_wrapper.punica_cpu.PunicaWrapperCPU"
return
"vllm.lora.punica_wrapper.punica_cpu.PunicaWrapperCPU"
@classmethod
def get_device_communicator_cls(cls) -> str:
    """Return the dotted path of the CPU communicator class.

    The returned string names the device-specific communicator class
    used for distributed communication on this platform.
    """
    communicator_path = "vllm.distributed.device_communicators.cpu_communicator.CpuCommunicator"  # noqa
    return communicator_path
vllm/platforms/cuda.py
View file @
a0231b7c
...
@@ -233,6 +233,10 @@ class CudaPlatformBase(Platform):
...
@@ -233,6 +233,10 @@ class CudaPlatformBase(Platform):
def
get_punica_wrapper
(
cls
)
->
str
:
def
get_punica_wrapper
(
cls
)
->
str
:
return
"vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
return
"vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
@classmethod
def get_device_communicator_cls(cls) -> str:
    """Return the dotted path of the CUDA communicator class.

    The returned string names the device-specific communicator class
    used for distributed communication on this platform.
    """
    communicator_path = "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa
    return communicator_path
# NVML utils
# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
...
...
vllm/platforms/hpu.py
View file @
a0231b7c
...
@@ -88,3 +88,7 @@ class HpuPlatform(Platform):
...
@@ -88,3 +88,7 @@ class HpuPlatform(Platform):
@
classmethod
@
classmethod
def
get_punica_wrapper
(
cls
)
->
str
:
def
get_punica_wrapper
(
cls
)
->
str
:
return
"vllm.lora.punica_wrapper.punica_hpu.PunicaWrapperHPU"
return
"vllm.lora.punica_wrapper.punica_hpu.PunicaWrapperHPU"
@classmethod
def get_device_communicator_cls(cls) -> str:
    """Return the dotted path of the HPU communicator class.

    The returned string names the device-specific communicator class
    used for distributed communication on this platform.
    """
    communicator_path = "vllm.distributed.device_communicators.hpu_communicator.HpuCommunicator"  # noqa
    return communicator_path
vllm/platforms/interface.py
View file @
a0231b7c
...
@@ -322,6 +322,13 @@ class Platform:
...
@@ -322,6 +322,13 @@ class Platform:
"""
"""
raise
NotImplementedError
raise
NotImplementedError
@classmethod
def get_device_communicator_cls(cls) -> str:
    """
    Get device specific communicator class for distributed communication.

    Returns:
        The fully qualified dotted path of the communicator class. This
        base implementation returns the generic base communicator;
        platform subclasses override it to name their own class.
    """
    # The package is `device_communicators` (plural), matching the module
    # location vllm/distributed/device_communicators/ and the paths
    # returned by the platform-specific overrides. The singular spelling
    # names a package that does not exist.
    # NOTE(review): presumably this string is imported by qualified name
    # when the communicator is constructed -- confirm against the caller.
    return "vllm.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase"  # noqa
class
UnspecifiedPlatform
(
Platform
):
class
UnspecifiedPlatform
(
Platform
):
_enum
=
PlatformEnum
.
UNSPECIFIED
_enum
=
PlatformEnum
.
UNSPECIFIED
...
...
vllm/platforms/rocm.py
View file @
a0231b7c
...
@@ -186,3 +186,7 @@ class RocmPlatform(Platform):
...
@@ -186,3 +186,7 @@ class RocmPlatform(Platform):
torch
.
cuda
.
reset_peak_memory_stats
(
device
)
torch
.
cuda
.
reset_peak_memory_stats
(
device
)
return
torch
.
cuda
.
mem_get_info
(
device
)[
1
]
-
torch
.
cuda
.
mem_get_info
(
return
torch
.
cuda
.
mem_get_info
(
device
)[
1
]
-
torch
.
cuda
.
mem_get_info
(
device
)[
0
]
device
)[
0
]
@classmethod
def get_device_communicator_cls(cls) -> str:
    """Return the dotted path of the communicator class used on ROCm.

    The path returned here names the CUDA communicator class, which is
    the device-specific communicator used for distributed communication
    on this platform.
    """
    communicator_path = "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa
    return communicator_path
vllm/platforms/tpu.py
View file @
a0231b7c
...
@@ -115,3 +115,7 @@ class TpuPlatform(Platform):
...
@@ -115,3 +115,7 @@ class TpuPlatform(Platform):
def
is_pin_memory_available
(
cls
):
def
is_pin_memory_available
(
cls
):
logger
.
warning
(
"Pin memory is not supported on TPU."
)
logger
.
warning
(
"Pin memory is not supported on TPU."
)
return
False
return
False
@classmethod
def get_device_communicator_cls(cls) -> str:
    """Return the dotted path of the TPU communicator class.

    The returned string names the device-specific communicator class
    used for distributed communication on this platform.
    """
    communicator_path = "vllm.distributed.device_communicators.tpu_communicator.TpuCommunicator"  # noqa
    return communicator_path
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment