Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
43735bf5
Unverified
Commit
43735bf5
authored
Aug 19, 2024
by
Woosuk Kwon
Committed by
GitHub
Aug 19, 2024
Browse files
[TPU] Remove redundant input tensor cloning (#7660)
parent
da115230
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
20 deletions
+8
-20
vllm/worker/tpu_model_runner.py
vllm/worker/tpu_model_runner.py
+8
-20
No files found.
vllm/worker/tpu_model_runner.py
View file @
43735bf5
...
...
@@ -516,27 +516,19 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
raise
ValueError
(
"TPUModelRunner does not support multi-step execution."
)
def
_execute_model
(
*
args
,
clone
:
bool
=
False
)
->
torch
.
Tensor
:
def
_execute_model
(
*
args
)
:
"""Move input args from CPU to device and execute the model."""
def
_copy_to_device
(
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
clone
:
# When x is a slice of a CPU tensor, XLA may copy the whole
# original tensor to TPU instead of only copying x.
# To avoid this, we copy x after cloning.
x
=
x
.
clone
()
return
x
.
to
(
self
.
device
)
new_args
=
[]
for
arg
in
args
:
if
isinstance
(
arg
,
torch
.
Tensor
):
arg
=
_copy_to_
device
(
arg
)
arg
=
arg
.
to
(
self
.
device
)
elif
isinstance
(
arg
,
AttentionMetadata
):
arg
.
slot_mapping
=
_copy_to_device
(
arg
.
slot_mapping
)
arg
.
slot_mapping
=
arg
.
slot_mapping
.
to
(
self
.
device
)
if
getattr
(
arg
,
"block_tables"
,
None
)
is
not
None
:
arg
.
block_tables
=
_copy_to_device
(
arg
.
block_tables
)
arg
.
block_tables
=
arg
.
block_tables
.
to
(
self
.
device
)
if
getattr
(
arg
,
"context_lens"
,
None
)
is
not
None
:
arg
.
context_lens
=
_copy_to_device
(
arg
.
context_lens
)
arg
.
context_lens
=
arg
.
context_lens
.
to
(
self
.
device
)
new_args
.
append
(
arg
)
return
self
.
model
(
*
new_args
)
...
...
@@ -563,13 +555,9 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
output_token_ids
=
_execute_model
(
model_input
.
token_ids
[
None
,
start_idx
:
end_idx
],
model_input
.
position_ids
[
None
,
start_idx
:
end_idx
],
model_input
.
attn_metadata
,
model_input
.
input_lens
[
i
:
i
+
1
],
model_input
.
t
[
i
:
i
+
1
],
model_input
.
p
[
i
:
i
+
1
],
model_input
.
num_samples
,
kv_caches
,
clone
=
True
)
model_input
.
attn_metadata
,
model_input
.
input_lens
[
i
:
i
+
1
],
model_input
.
t
[
i
:
i
+
1
],
model_input
.
p
[
i
:
i
+
1
],
model_input
.
num_samples
,
kv_caches
)
# Retrieve the outputs to CPU.
next_token_ids
+=
output_token_ids
.
cpu
().
tolist
()
start_idx
=
end_idx
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment