Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d09b94ca
Unverified
Commit
d09b94ca
authored
Jul 26, 2024
by
Woosuk Kwon
Committed by
GitHub
Jul 27, 2024
Browse files
[TPU] Support collective communications in XLA devices (#6813)
parent
bb549467
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
70 additions
and
2 deletions
+70
-2
vllm/distributed/device_communicators/tpu_communicator.py
vllm/distributed/device_communicators/tpu_communicator.py
+30
-0
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+22
-0
vllm/lora/layers.py
vllm/lora/layers.py
+4
-0
vllm/model_executor/layers/logits_processor.py
vllm/model_executor/layers/logits_processor.py
+14
-2
No files found.
vllm/distributed/device_communicators/tpu_communicator.py
0 → 100644
View file @
d09b94ca
import
torch
import
torch.distributed
as
dist
from
torch.distributed
import
ProcessGroup
from
vllm.platforms
import
current_platform
if
current_platform
.
is_tpu
():
import
torch_xla.core.xla_model
as
xm
from
torch_xla._internal
import
pjrt
class
TpuCommunicator
:
def
__init__
(
self
,
group
:
ProcessGroup
):
if
not
current_platform
.
is_tpu
():
self
.
disabled
=
True
return
self
.
disabled
=
False
local_rank
=
dist
.
get_rank
(
group
)
world_size
=
dist
.
get_world_size
(
group
)
pjrt
.
initialize_multiprocess
(
local_rank
,
world_size
)
xm
.
_init_world_size_ordinal
()
def
all_reduce
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
xm
.
all_reduce
(
xm
.
REDUCE_SUM
,
x
)
def
all_gather
(
self
,
x
:
torch
.
Tensor
,
dim
:
int
=
-
1
)
->
torch
.
Tensor
:
assert
dim
==
-
1
,
"TPUs only support dim=-1 for all-gather."
return
xm
.
all_gather
(
x
,
dim
=
dim
)
vllm/distributed/parallel_state.py
View file @
d09b94ca
...
@@ -133,6 +133,7 @@ class GroupCoordinator:
...
@@ -133,6 +133,7 @@ class GroupCoordinator:
torch_distributed_backend
:
Union
[
str
,
Backend
],
torch_distributed_backend
:
Union
[
str
,
Backend
],
use_pynccl
:
bool
,
use_pynccl
:
bool
,
use_custom_allreduce
:
bool
,
use_custom_allreduce
:
bool
,
use_tpu_communicator
:
bool
,
use_message_queue_broadcaster
:
bool
=
False
,
use_message_queue_broadcaster
:
bool
=
False
,
):
):
...
@@ -164,6 +165,7 @@ class GroupCoordinator:
...
@@ -164,6 +165,7 @@ class GroupCoordinator:
self
.
use_pynccl
=
use_pynccl
self
.
use_pynccl
=
use_pynccl
self
.
use_custom_allreduce
=
use_custom_allreduce
self
.
use_custom_allreduce
=
use_custom_allreduce
self
.
use_tpu_communicator
=
use_tpu_communicator
# lazy import to avoid documentation build error
# lazy import to avoid documentation build error
from
vllm.distributed.device_communicators.custom_all_reduce
import
(
from
vllm.distributed.device_communicators.custom_all_reduce
import
(
...
@@ -190,6 +192,12 @@ class GroupCoordinator:
...
@@ -190,6 +192,12 @@ class GroupCoordinator:
else
:
else
:
self
.
ca_comm
=
None
self
.
ca_comm
=
None
from
vllm.distributed.device_communicators.tpu_communicator
import
(
TpuCommunicator
)
self
.
tpu_communicator
:
Optional
[
TpuCommunicator
]
if
use_tpu_communicator
and
self
.
world_size
>
1
:
self
.
tpu_communicator
=
TpuCommunicator
(
group
=
self
.
cpu_group
)
from
vllm.distributed.device_communicators.shm_broadcast
import
(
from
vllm.distributed.device_communicators.shm_broadcast
import
(
MessageQueue
)
MessageQueue
)
self
.
mq_broadcaster
:
Optional
[
MessageQueue
]
=
None
self
.
mq_broadcaster
:
Optional
[
MessageQueue
]
=
None
...
@@ -289,6 +297,12 @@ class GroupCoordinator:
...
@@ -289,6 +297,12 @@ class GroupCoordinator:
# Bypass the function if we are using only 1 GPU.
# Bypass the function if we are using only 1 GPU.
if
self
.
world_size
==
1
:
if
self
.
world_size
==
1
:
return
input_
return
input_
# For TPUs, use TPU communicator.
tpu_comm
=
self
.
tpu_communicator
if
tpu_comm
is
not
None
and
not
tpu_comm
.
disabled
:
return
tpu_comm
.
all_reduce
(
input_
)
if
ca_comm
is
not
None
:
if
ca_comm
is
not
None
:
out
=
ca_comm
.
custom_all_reduce
(
input_
)
out
=
ca_comm
.
custom_all_reduce
(
input_
)
if
out
is
not
None
:
if
out
is
not
None
:
...
@@ -310,6 +324,12 @@ class GroupCoordinator:
...
@@ -310,6 +324,12 @@ class GroupCoordinator:
return
input_
return
input_
assert
-
input_
.
dim
()
<=
dim
<
input_
.
dim
(),
(
assert
-
input_
.
dim
()
<=
dim
<
input_
.
dim
(),
(
f
"Invalid dim (
{
dim
}
) for input tensor with shape
{
input_
.
size
()
}
"
)
f
"Invalid dim (
{
dim
}
) for input tensor with shape
{
input_
.
size
()
}
"
)
# For TPUs, use TPU communicator.
tpu_comm
=
self
.
tpu_communicator
if
tpu_comm
is
not
None
and
not
tpu_comm
.
disabled
:
return
tpu_comm
.
all_gather
(
input_
,
dim
)
if
dim
<
0
:
if
dim
<
0
:
# Convert negative dim to positive.
# Convert negative dim to positive.
dim
+=
input_
.
dim
()
dim
+=
input_
.
dim
()
...
@@ -727,6 +747,7 @@ def init_world_group(ranks: List[int], local_rank: int,
...
@@ -727,6 +747,7 @@ def init_world_group(ranks: List[int], local_rank: int,
torch_distributed_backend
=
backend
,
torch_distributed_backend
=
backend
,
use_pynccl
=
False
,
use_pynccl
=
False
,
use_custom_allreduce
=
False
,
use_custom_allreduce
=
False
,
use_tpu_communicator
=
False
,
)
)
...
@@ -745,6 +766,7 @@ def init_model_parallel_group(
...
@@ -745,6 +766,7 @@ def init_model_parallel_group(
torch_distributed_backend
=
backend
,
torch_distributed_backend
=
backend
,
use_pynccl
=
True
,
use_pynccl
=
True
,
use_custom_allreduce
=
use_custom_allreduce
,
use_custom_allreduce
=
use_custom_allreduce
,
use_tpu_communicator
=
True
,
use_message_queue_broadcaster
=
use_message_queue_broadcaster
,
use_message_queue_broadcaster
=
use_message_queue_broadcaster
,
)
)
...
...
vllm/lora/layers.py
View file @
d09b94ca
...
@@ -1067,6 +1067,10 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
...
@@ -1067,6 +1067,10 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
def
soft_cap
(
self
):
def
soft_cap
(
self
):
return
self
.
base_layer
.
soft_cap
return
self
.
base_layer
.
soft_cap
@
property
def
use_gather
(
self
):
return
self
.
base_layer
.
use_gather
@
property
@
property
def
org_vocab_size
(
self
):
def
org_vocab_size
(
self
):
return
self
.
base_layer
.
org_vocab_size
return
self
.
base_layer
.
org_vocab_size
...
...
vllm/model_executor/layers/logits_processor.py
View file @
d09b94ca
...
@@ -5,10 +5,12 @@ from typing import Optional
...
@@ -5,10 +5,12 @@ from typing import Optional
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
vllm.distributed
import
tensor_model_parallel_gather
from
vllm.distributed
import
(
tensor_model_parallel_all_gather
,
tensor_model_parallel_gather
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
VocabParallelEmbedding
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.platforms
import
current_platform
class
LogitsProcessor
(
nn
.
Module
):
class
LogitsProcessor
(
nn
.
Module
):
...
@@ -39,6 +41,8 @@ class LogitsProcessor(nn.Module):
...
@@ -39,6 +41,8 @@ class LogitsProcessor(nn.Module):
self
.
org_vocab_size
=
org_vocab_size
or
vocab_size
self
.
org_vocab_size
=
org_vocab_size
or
vocab_size
# Soft cap the logits. Used in Gemma 2.
# Soft cap the logits. Used in Gemma 2.
self
.
soft_cap
=
soft_cap
self
.
soft_cap
=
soft_cap
# Whether to use gather or all-gather to gather the logits.
self
.
use_gather
=
not
current_platform
.
is_tpu
()
def
forward
(
def
forward
(
self
,
self
,
...
@@ -76,7 +80,15 @@ class LogitsProcessor(nn.Module):
...
@@ -76,7 +80,15 @@ class LogitsProcessor(nn.Module):
logits
=
lm_head
.
linear_method
.
apply
(
lm_head
,
logits
=
lm_head
.
linear_method
.
apply
(
lm_head
,
hidden_states
,
hidden_states
,
bias
=
embedding_bias
)
bias
=
embedding_bias
)
if
self
.
use_gather
:
logits
=
tensor_model_parallel_gather
(
logits
)
logits
=
tensor_model_parallel_gather
(
logits
)
else
:
# Gather is not supported for some devices such as TPUs.
# Use all-gather instead.
# NOTE(woosuk): Here, the outputs of every device should not be None
# because XLA requires strict SPMD among all devices. Every device
# should execute the same operations after gathering the logits.
logits
=
tensor_model_parallel_all_gather
(
logits
)
# Remove paddings in vocab (if any).
# Remove paddings in vocab (if any).
if
logits
is
not
None
:
if
logits
is
not
None
:
logits
=
logits
[:,
:
self
.
org_vocab_size
]
logits
=
logits
[:,
:
self
.
org_vocab_size
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment