Unverified Commit 5a30bd10 authored by Ning Xie's avatar Ning Xie Committed by GitHub
Browse files

[Bugfix] fix IntermediateTensors equal method (#23027)


Signed-off-by: default avatarAndy Xie <andy.xning@gmail.com>
parent 27e8d1ea
...@@ -2,10 +2,11 @@ ...@@ -2,10 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest import pytest
import torch
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import (CompletionSequenceGroupOutput, SequenceData, from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
SequenceOutput) SequenceData, SequenceOutput)
from .core.utils import create_dummy_prompt from .core.utils import create_dummy_prompt
...@@ -98,3 +99,38 @@ def test_sequence_group_stage(): ...@@ -98,3 +99,38 @@ def test_sequence_group_stage():
assert seq_group.is_prefill() is True assert seq_group.is_prefill() is True
seq_group.update_num_computed_tokens(1) seq_group.update_num_computed_tokens(1)
assert seq_group.is_prefill() is False assert seq_group.is_prefill() is False
def test_sequence_intermediate_tensors_equal():
class AnotherIntermediateTensors(IntermediateTensors):
pass
intermediate_tensors = IntermediateTensors({})
another_intermediate_tensors = AnotherIntermediateTensors({})
assert intermediate_tensors != another_intermediate_tensors
empty_intermediate_tensors_1 = IntermediateTensors({})
empty_intermediate_tensors_2 = IntermediateTensors({})
assert empty_intermediate_tensors_1 == empty_intermediate_tensors_2
different_key_intermediate_tensors_1 = IntermediateTensors(
{"1": torch.zeros([2, 4], dtype=torch.int32)})
difference_key_intermediate_tensors_2 = IntermediateTensors(
{"2": torch.zeros([2, 4], dtype=torch.int32)})
assert (different_key_intermediate_tensors_1
!= difference_key_intermediate_tensors_2)
same_key_different_value_intermediate_tensors_1 = IntermediateTensors(
{"1": torch.zeros([2, 4], dtype=torch.int32)})
same_key_different_value_intermediate_tensors_2 = IntermediateTensors(
{"1": torch.zeros([2, 5], dtype=torch.int32)})
assert (same_key_different_value_intermediate_tensors_1
!= same_key_different_value_intermediate_tensors_2)
same_key_same_value_intermediate_tensors_1 = IntermediateTensors(
{"1": torch.zeros([2, 4], dtype=torch.int32)})
same_key_same_value_intermediate_tensors_2 = IntermediateTensors(
{"1": torch.zeros([2, 4], dtype=torch.int32)})
assert (same_key_same_value_intermediate_tensors_1 ==
same_key_same_value_intermediate_tensors_2)
...@@ -1163,7 +1163,13 @@ class IntermediateTensors: ...@@ -1163,7 +1163,13 @@ class IntermediateTensors:
return len(self.tensors) return len(self.tensors)
def __eq__(self, other: object): def __eq__(self, other: object):
return isinstance(other, self.__class__) and self if not isinstance(other, self.__class__):
return False
if self.tensors.keys() != other.tensors.keys():
return False
return all(
torch.equal(self.tensors[k], other.tensors[k])
for k in self.tensors)
def __repr__(self) -> str: def __repr__(self) -> str:
return f"IntermediateTensors(tensors={self.tensors})" return f"IntermediateTensors(tensors={self.tensors})"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment