Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
80c7b089
"vscode:/vscode.git/clone" did not exist on "586f0eba8225d7c0a358c8a286e4e13b9739dfd7"
Unverified
Commit
80c7b089
authored
Aug 29, 2024
by
Woosuk Kwon
Committed by
GitHub
Aug 29, 2024
Browse files
[TPU] Async output processing for TPU (#8011)
parent
428dd144
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
10 additions
and
4 deletions
+10
-4
vllm/config.py
vllm/config.py
+3
-3
vllm/worker/tpu_model_runner.py
vllm/worker/tpu_model_runner.py
+7
-1
No files found.
vllm/config.py
View file @
80c7b089
...
...
@@ -347,10 +347,10 @@ class ModelConfig:
self
.
use_async_output_proc
=
False
return
if
device_config
.
device_type
!=
"cuda"
:
if
device_config
.
device_type
not
in
(
"cuda"
,
"tpu"
)
:
logger
.
warning
(
"Async output processing is only supported for CUDA
.
"
"
Disabling it for other platforms."
)
"Async output processing is only supported for CUDA
or TPU.
"
"Disabling it for other platforms."
)
self
.
use_async_output_proc
=
False
return
...
...
vllm/worker/tpu_model_runner.py
View file @
80c7b089
import
time
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
,
Union
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Type
,
Union
)
from
unittest.mock
import
patch
import
numpy
as
np
...
...
@@ -51,6 +52,7 @@ class ModelInputForTPU(ModelRunnerInputBase):
best_of
:
List
[
int
]
seq_groups
:
List
[
List
[
int
]]
virtual_engine
:
int
=
0
async_callback
:
Optional
[
Callable
]
=
None
def
as_broadcastable_tensor_dict
(
self
)
->
Dict
[
str
,
Union
[
int
,
torch
.
Tensor
]]:
...
...
@@ -562,6 +564,8 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
model_input
.
attn_metadata
,
model_input
.
input_lens
[
i
:
i
+
1
],
model_input
.
t
[
i
:
i
+
1
],
model_input
.
p
[
i
:
i
+
1
],
model_input
.
num_samples
,
kv_caches
)
if
i
==
0
and
model_input
.
async_callback
is
not
None
:
model_input
.
async_callback
()
# Retrieve the outputs to CPU.
next_token_ids
+=
output_token_ids
.
cpu
().
tolist
()
start_idx
=
end_idx
...
...
@@ -572,6 +576,8 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
model_input
.
attn_metadata
,
model_input
.
input_lens
,
model_input
.
t
,
model_input
.
p
,
model_input
.
num_samples
,
kv_caches
)
if
model_input
.
async_callback
is
not
None
:
model_input
.
async_callback
()
# Retrieve the outputs to CPU.
next_token_ids
=
output_token_ids
.
cpu
().
tolist
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment