Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e18227b0
Unverified
Commit
e18227b0
authored
Feb 16, 2025
by
Woosuk Kwon
Committed by
GitHub
Feb 16, 2025
Browse files
[V1][PP] Cache Intermediate Tensors (#13353)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
7b893865
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
28 additions
and
8 deletions
+28
-8
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+28
-8
No files found.
vllm/v1/worker/gpu_model_runner.py
View file @
e18227b0
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
import
gc
import
gc
import
time
import
time
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Tuple
,
cast
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Tuple
,
Union
,
cast
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -149,6 +149,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -149,6 +149,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
positions
=
torch
.
zeros
(
self
.
max_num_tokens
,
self
.
positions
=
torch
.
zeros
(
self
.
max_num_tokens
,
dtype
=
torch
.
int64
,
dtype
=
torch
.
int64
,
device
=
self
.
device
)
device
=
self
.
device
)
# self.intermediate_tensors # Set after load_model
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if
self
.
uses_mrope
:
if
self
.
uses_mrope
:
...
@@ -869,7 +870,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -869,7 +870,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
,
self
,
scheduler_output
:
"SchedulerOutput"
,
scheduler_output
:
"SchedulerOutput"
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
ModelRunnerOutput
:
)
->
Union
[
ModelRunnerOutput
,
torch
.
Tensor
]
:
batch_changed
=
self
.
_update_states
(
scheduler_output
)
batch_changed
=
self
.
_update_states
(
scheduler_output
)
if
self
.
is_multimodal_model
:
if
self
.
is_multimodal_model
:
...
@@ -919,6 +920,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -919,6 +920,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
else
:
else
:
positions
=
self
.
positions
[:
num_input_tokens
]
positions
=
self
.
positions
[:
num_input_tokens
]
if
get_pp_group
().
is_first_rank
:
intermediate_tensors
=
None
else
:
intermediate_tensors
=
IntermediateTensors
({
k
:
v
[:
num_input_tokens
]
for
k
,
v
in
self
.
intermediate_tensors
.
items
()
})
# Run the decoder.
# Run the decoder.
# Use persistent buffers for CUDA graphs.
# Use persistent buffers for CUDA graphs.
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
):
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
):
...
@@ -931,7 +940,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -931,7 +940,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
inputs_embeds
=
inputs_embeds
,
inputs_embeds
=
inputs_embeds
,
)
)
if
not
get_pp_group
().
is_last_rank
:
if
not
get_pp_group
().
is_last_rank
:
# For mid-pipeline stages, return the hidden states.
return
hidden_states
return
hidden_states
hidden_states
=
hidden_states
[:
num_scheduled_tokens
]
hidden_states
=
hidden_states
[:
num_scheduled_tokens
]
sample_hidden_states
=
hidden_states
[
logits_indices
]
sample_hidden_states
=
hidden_states
[
logits_indices
]
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
,
None
)
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
,
None
)
...
@@ -1118,12 +1129,21 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1118,12 +1129,21 @@ class GPUModelRunner(LoRAModelRunnerMixin):
positions
=
self
.
mrope_positions
[:,
:
num_tokens
]
positions
=
self
.
mrope_positions
[:,
:
num_tokens
]
else
:
else
:
positions
=
self
.
positions
[:
num_tokens
]
positions
=
self
.
positions
[:
num_tokens
]
if
get_pp_group
().
is_first_rank
:
intermediate_tensors
=
None
intermediate_tensors
=
None
if
not
get_pp_group
().
is_first_rank
:
else
:
intermediate_tensors
=
self
.
model
.
make_empty_intermediate_tensors
(
if
not
hasattr
(
self
,
"intermediate_tensors"
):
batch_size
=
num_tokens
,
self
.
intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
(
batch_size
=
self
.
max_num_tokens
,
dtype
=
self
.
model_config
.
dtype
,
dtype
=
self
.
model_config
.
dtype
,
device
=
self
.
device
)
device
=
self
.
device
))
intermediate_tensors
=
IntermediateTensors
({
k
:
v
[:
num_tokens
]
for
k
,
v
in
self
.
intermediate_tensors
.
items
()
})
with
set_forward_context
(
None
,
self
.
vllm_config
):
with
set_forward_context
(
None
,
self
.
vllm_config
):
hidden_states
=
model
(
hidden_states
=
model
(
input_ids
=
input_ids
,
input_ids
=
input_ids
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment