Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cfbca8a2
Unverified
Commit
cfbca8a2
authored
Mar 19, 2025
by
Alexander Matveev
Committed by
GitHub
Mar 20, 2025
Browse files
[V1] TPU - Tensor parallel MP support (#15059)
parent
0fe56098
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
37 additions
and
15 deletions
+37
-15
vllm/config.py
vllm/config.py
+1
-1
vllm/distributed/device_communicators/tpu_communicator.py
vllm/distributed/device_communicators/tpu_communicator.py
+36
-14
No files found.
vllm/config.py
View file @
cfbca8a2
...
@@ -1473,7 +1473,7 @@ class ParallelConfig:
...
@@ -1473,7 +1473,7 @@ class ParallelConfig:
os
.
environ
[
"VLLM_ENABLE_V1_MULTIPROCESSING"
]
=
"0"
os
.
environ
[
"VLLM_ENABLE_V1_MULTIPROCESSING"
]
=
"0"
logger
.
info
(
"Disabling V1 multiprocessing for external launcher."
)
logger
.
info
(
"Disabling V1 multiprocessing for external launcher."
)
ray_only_devices
=
[
"tpu"
]
ray_only_devices
:
list
[
str
]
=
[
]
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
if
(
current_platform
.
device_type
in
ray_only_devices
if
(
current_platform
.
device_type
in
ray_only_devices
and
self
.
world_size
>
1
):
and
self
.
world_size
>
1
):
...
...
vllm/distributed/device_communicators/tpu_communicator.py
View file @
cfbca8a2
...
@@ -6,16 +6,25 @@ from typing import Optional
...
@@ -6,16 +6,25 @@ from typing import Optional
import
torch
import
torch
from
torch.distributed
import
ProcessGroup
from
torch.distributed
import
ProcessGroup
from
vllm.config
import
get_current_vllm_config
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
.base_device_communicator
import
DeviceCommunicatorBase
from
.base_device_communicator
import
DeviceCommunicatorBase
USE_RAY
=
parallel_config
=
get_current_vllm_config
(
).
parallel_config
.
distributed_executor_backend
==
"ray"
logger
=
init_logger
(
__name__
)
if
current_platform
.
is_tpu
():
if
current_platform
.
is_tpu
():
import
torch_xla
import
torch_xla.core.xla_model
as
xm
import
torch_xla.core.xla_model
as
xm
import
torch_xla.runtime
as
xr
import
torch_xla.runtime
as
xr
from
torch_xla._internal
import
pjrt
from
torch_xla._internal
import
pjrt
from
vllm.executor
import
ray_utils
if
USE_RAY
:
from
vllm.executor
import
ray_utils
class
TpuCommunicator
(
DeviceCommunicatorBase
):
class
TpuCommunicator
(
DeviceCommunicatorBase
):
...
@@ -33,19 +42,32 @@ class TpuCommunicator(DeviceCommunicatorBase):
...
@@ -33,19 +42,32 @@ class TpuCommunicator(DeviceCommunicatorBase):
global_rank
=
self
.
global_rank
global_rank
=
self
.
global_rank
global_world_size
=
self
.
global_world_size
global_world_size
=
self
.
global_world_size
# Calculate how many TPU nodes are in the current deployment. This
if
USE_RAY
:
# is the Ray placement group if it is deployed with Ray. Default
logger
.
info
(
"TpuCommunicator initialized with RAY"
)
# to the number of TPU nodes in the Ray cluster. The number of TPU
# Calculate how many TPU nodes are in the current deployment. This
# nodes is computed by the total number of TPUs divided by the
# is the Ray placement group if it is deployed with Ray. Default
# number of TPU accelerators per node, to account for clusters
# to the number of TPU nodes in the Ray cluster. The number of TPU
# with both CPUs and TPUs.
# nodes is computed by the total number of TPUs divided by the
num_nodes
=
ray_utils
.
get_num_tpu_nodes
()
# number of TPU accelerators per node, to account for clusters
num_nodes_in_pg
=
ray_utils
.
get_num_nodes_in_placement_group
()
# with both CPUs and TPUs.
if
num_nodes_in_pg
>
0
:
num_nodes
=
ray_utils
.
get_num_tpu_nodes
()
num_nodes
=
num_nodes_in_pg
num_nodes_in_pg
=
ray_utils
.
get_num_nodes_in_placement_group
()
if
num_nodes_in_pg
>
0
:
local_world_size
=
global_world_size
//
num_nodes
num_nodes
=
num_nodes_in_pg
local_rank
=
global_rank
%
local_world_size
local_world_size
=
global_world_size
//
num_nodes
local_rank
=
global_rank
%
local_world_size
else
:
logger
.
info
(
"TpuCommunicator initialized with MP"
)
# Sanity: Verify we run on a single host
num_hosts
=
torch_xla
.
tpu
.
num_tpu_workers
()
assert
num_hosts
==
1
# Get the current number of TPUs (we have locally)
local_world_size
=
torch_xla
.
tpu
.
num_available_chips
()
# Get current rank
local_rank
=
global_rank
%
local_world_size
# Ensure environment variables are set for multihost deployments.
# Ensure environment variables are set for multihost deployments.
# On GKE, this is needed for libtpu and TPU driver to know which TPU
# On GKE, this is needed for libtpu and TPU driver to know which TPU
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment