Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5a30bd10
Unverified
Commit
5a30bd10
authored
Aug 18, 2025
by
Ning Xie
Committed by
GitHub
Aug 18, 2025
Browse files
[Bugfix] fix IntermediateTensors equal method (#23027)
Signed-off-by:
Andy Xie
<
andy.xning@gmail.com
>
parent
27e8d1ea
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
45 additions
and
3 deletions
+45
-3
tests/test_sequence.py
tests/test_sequence.py
+38
-2
vllm/sequence.py
vllm/sequence.py
+7
-1
No files found.
tests/test_sequence.py
View file @
5a30bd10
...
...
@@ -2,10 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
import
torch
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
SequenceData
,
SequenceOutput
)
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
IntermediateTensors
,
SequenceData
,
SequenceOutput
)
from
.core.utils
import
create_dummy_prompt
...
...
@@ -98,3 +99,38 @@ def test_sequence_group_stage():
assert
seq_group
.
is_prefill
()
is
True
seq_group
.
update_num_computed_tokens
(
1
)
assert
seq_group
.
is_prefill
()
is
False
def
test_sequence_intermediate_tensors_equal
():
class
AnotherIntermediateTensors
(
IntermediateTensors
):
pass
intermediate_tensors
=
IntermediateTensors
({})
another_intermediate_tensors
=
AnotherIntermediateTensors
({})
assert
intermediate_tensors
!=
another_intermediate_tensors
empty_intermediate_tensors_1
=
IntermediateTensors
({})
empty_intermediate_tensors_2
=
IntermediateTensors
({})
assert
empty_intermediate_tensors_1
==
empty_intermediate_tensors_2
different_key_intermediate_tensors_1
=
IntermediateTensors
(
{
"1"
:
torch
.
zeros
([
2
,
4
],
dtype
=
torch
.
int32
)})
difference_key_intermediate_tensors_2
=
IntermediateTensors
(
{
"2"
:
torch
.
zeros
([
2
,
4
],
dtype
=
torch
.
int32
)})
assert
(
different_key_intermediate_tensors_1
!=
difference_key_intermediate_tensors_2
)
same_key_different_value_intermediate_tensors_1
=
IntermediateTensors
(
{
"1"
:
torch
.
zeros
([
2
,
4
],
dtype
=
torch
.
int32
)})
same_key_different_value_intermediate_tensors_2
=
IntermediateTensors
(
{
"1"
:
torch
.
zeros
([
2
,
5
],
dtype
=
torch
.
int32
)})
assert
(
same_key_different_value_intermediate_tensors_1
!=
same_key_different_value_intermediate_tensors_2
)
same_key_same_value_intermediate_tensors_1
=
IntermediateTensors
(
{
"1"
:
torch
.
zeros
([
2
,
4
],
dtype
=
torch
.
int32
)})
same_key_same_value_intermediate_tensors_2
=
IntermediateTensors
(
{
"1"
:
torch
.
zeros
([
2
,
4
],
dtype
=
torch
.
int32
)})
assert
(
same_key_same_value_intermediate_tensors_1
==
same_key_same_value_intermediate_tensors_2
)
vllm/sequence.py
View file @
5a30bd10
...
...
@@ -1163,7 +1163,13 @@ class IntermediateTensors:
return
len
(
self
.
tensors
)
def
__eq__
(
self
,
other
:
object
):
return
isinstance
(
other
,
self
.
__class__
)
and
self
if
not
isinstance
(
other
,
self
.
__class__
):
return
False
if
self
.
tensors
.
keys
()
!=
other
.
tensors
.
keys
():
return
False
return
all
(
torch
.
equal
(
self
.
tensors
[
k
],
other
.
tensors
[
k
])
for
k
in
self
.
tensors
)
def
__repr__
(
self
)
->
str
:
return
f
"IntermediateTensors(tensors=
{
self
.
tensors
}
)"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment