Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
01b61136
Unverified
Commit
01b61136
authored
Apr 02, 2025
by
Chengji Yao
Committed by
GitHub
Apr 03, 2025
Browse files
[TPU] optimize the all-reduce performance (#15903)
Signed-off-by:
Chengji Yao
<
chengjiyao@google.com
>
parent
1b84eff0
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
16 additions
and
2 deletions
+16
-2
vllm/distributed/device_communicators/tpu_communicator.py
vllm/distributed/device_communicators/tpu_communicator.py
+6
-1
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+4
-1
vllm/v1/worker/tpu_worker.py
vllm/v1/worker/tpu_worker.py
+6
-0
No files found.
vllm/distributed/device_communicators/tpu_communicator.py
View file @
01b61136
...
@@ -22,6 +22,8 @@ if current_platform.is_tpu():
...
@@ -22,6 +22,8 @@ if current_platform.is_tpu():
import
torch_xla.core.xla_model
as
xm
import
torch_xla.core.xla_model
as
xm
import
torch_xla.runtime
as
xr
import
torch_xla.runtime
as
xr
from
torch_xla._internal
import
pjrt
from
torch_xla._internal
import
pjrt
from
torch_xla.distributed.xla_multiprocessing
import
(
create_optimized_replica_groups
)
if
USE_RAY
:
if
USE_RAY
:
from
vllm.executor
import
ray_utils
from
vllm.executor
import
ray_utils
...
@@ -79,9 +81,12 @@ class TpuCommunicator(DeviceCommunicatorBase):
...
@@ -79,9 +81,12 @@ class TpuCommunicator(DeviceCommunicatorBase):
pjrt
.
initialize_multiprocess
(
local_rank
,
local_world_size
)
pjrt
.
initialize_multiprocess
(
local_rank
,
local_world_size
)
xr
.
_init_world_size_ordinal
()
xr
.
_init_world_size_ordinal
()
self
.
groups
=
create_optimized_replica_groups
()
def
all_reduce
(
self
,
input_
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
all_reduce
(
self
,
input_
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
xm
.
all_reduce
(
xm
.
REDUCE_SUM
,
input_
)
# TODO: Remove the groups specification after XLA compiler can support
# auto-reordering the ring order for all-reduce.
return
xm
.
all_reduce
(
xm
.
REDUCE_SUM
,
input_
,
groups
=
self
.
groups
)
def
all_gather
(
self
,
input_
:
torch
.
Tensor
,
dim
:
int
=
-
1
)
->
torch
.
Tensor
:
def
all_gather
(
self
,
input_
:
torch
.
Tensor
,
dim
:
int
=
-
1
)
->
torch
.
Tensor
:
assert
dim
==
-
1
,
"TPUs only support dim=-1 for all-gather."
assert
dim
==
-
1
,
"TPUs only support dim=-1 for all-gather."
...
...
vllm/distributed/parallel_state.py
View file @
01b61136
...
@@ -119,11 +119,13 @@ def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
...
@@ -119,11 +119,13 @@ def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
if
supports_custom_op
():
if
supports_custom_op
():
from
vllm.platforms
import
current_platform
direct_register_custom_op
(
direct_register_custom_op
(
op_name
=
"all_reduce"
,
op_name
=
"all_reduce"
,
op_func
=
all_reduce
,
op_func
=
all_reduce
,
mutates_args
=
[],
mutates_args
=
[],
fake_impl
=
all_reduce_fake
,
fake_impl
=
all_reduce_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
)
...
@@ -219,7 +221,8 @@ class GroupCoordinator:
...
@@ -219,7 +221,8 @@ class GroupCoordinator:
self
.
cpu_group
,
1
<<
22
,
6
)
self
.
cpu_group
,
1
<<
22
,
6
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
self
.
use_custom_op_call
=
current_platform
.
is_cuda_alike
()
self
.
use_custom_op_call
=
(
current_platform
.
is_cuda_alike
()
or
current_platform
.
is_tpu
())
@
property
@
property
def
first_rank
(
self
):
def
first_rank
(
self
):
...
...
vllm/v1/worker/tpu_worker.py
View file @
01b61136
...
@@ -84,6 +84,12 @@ class TPUWorker:
...
@@ -84,6 +84,12 @@ class TPUWorker:
def
init_device
(
self
):
def
init_device
(
self
):
os
.
environ
[
"PJRT_DEVICE"
]
=
"TPU"
os
.
environ
[
"PJRT_DEVICE"
]
=
"TPU"
# Note: Currently the XLA compiler wrongly uses 2D ring strategy on 1D
# ring, the xla tpu compiler flag
# `xla_tpu_force_1d_allreduce_at_chunk_count` is a temporary solution to
# fix this. It will be removed after the bug in XLA compiler is fixed.
os
.
environ
[
"LIBTPU_INIT_ARGS"
]
=
(
"--xla_tpu_force_1d_allreduce_at_chunk_count=1"
)
torch
.
set_grad_enabled
(
False
)
torch
.
set_grad_enabled
(
False
)
torch
.
set_default_dtype
(
self
.
model_config
.
dtype
)
torch
.
set_default_dtype
(
self
.
model_config
.
dtype
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment