Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6ac485a9
Unverified
Commit
6ac485a9
authored
Feb 17, 2025
by
Cody Yu
Committed by
GitHub
Feb 17, 2025
Browse files
[V1][PP] Fix intermediate tensor values (#13417)
Signed-off-by:
Cody Yu
<
hao.yu.cody@gmail.com
>
parent
4c21ce9e
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
11 additions
and
2 deletions
+11
-2
vllm/sequence.py
vllm/sequence.py
+3
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+8
-2
No files found.
vllm/sequence.py
View file @
6ac485a9
...
@@ -1137,6 +1137,9 @@ class IntermediateTensors:
...
@@ -1137,6 +1137,9 @@ class IntermediateTensors:
def
__setitem__
(
self
,
key
:
str
,
value
:
torch
.
Tensor
):
def
__setitem__
(
self
,
key
:
str
,
value
:
torch
.
Tensor
):
self
.
tensors
[
key
]
=
value
self
.
tensors
[
key
]
=
value
def
items
(
self
):
return
self
.
tensors
.
items
()
def
__len__
(
self
):
def
__len__
(
self
):
return
len
(
self
.
tensors
)
return
len
(
self
.
tensors
)
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
6ac485a9
...
@@ -151,7 +151,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -151,7 +151,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
positions
=
torch
.
zeros
(
self
.
max_num_tokens
,
self
.
positions
=
torch
.
zeros
(
self
.
max_num_tokens
,
dtype
=
torch
.
int64
,
dtype
=
torch
.
int64
,
device
=
self
.
device
)
device
=
self
.
device
)
# self.intermediate_tensors # Set after load_model
# None in the first PP rank. The rest are set after load_model.
self
.
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if
self
.
uses_mrope
:
if
self
.
uses_mrope
:
...
@@ -922,6 +923,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -922,6 +923,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if
get_pp_group
().
is_first_rank
:
if
get_pp_group
().
is_first_rank
:
intermediate_tensors
=
None
intermediate_tensors
=
None
else
:
else
:
assert
intermediate_tensors
is
not
None
assert
self
.
intermediate_tensors
is
not
None
for
k
,
v
in
intermediate_tensors
.
items
():
self
.
intermediate_tensors
[
k
][:
num_input_tokens
].
copy_
(
v
[:
num_input_tokens
],
non_blocking
=
True
)
intermediate_tensors
=
IntermediateTensors
({
intermediate_tensors
=
IntermediateTensors
({
k
:
v
[:
num_input_tokens
]
k
:
v
[:
num_input_tokens
]
for
k
,
v
in
self
.
intermediate_tensors
.
items
()
for
k
,
v
in
self
.
intermediate_tensors
.
items
()
...
@@ -1120,7 +1126,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1120,7 +1126,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if
get_pp_group
().
is_first_rank
:
if
get_pp_group
().
is_first_rank
:
intermediate_tensors
=
None
intermediate_tensors
=
None
else
:
else
:
if
not
hasattr
(
self
,
"
intermediate_tensors
"
)
:
if
self
.
intermediate_tensors
is
None
:
self
.
intermediate_tensors
=
(
self
.
intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
(
self
.
model
.
make_empty_intermediate_tensors
(
batch_size
=
self
.
max_num_tokens
,
batch_size
=
self
.
max_num_tokens
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment