Unverified commit eeec9e33, authored by Cyrus Leung and committed by GitHub
Browse files

[Frontend] Separate pooling APIs in offline inference (#11129)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent f93bf2b1
......@@ -617,10 +617,9 @@ class SequenceGroup:
sampling_params: The sampling parameters used to generate the outputs.
arrival_time: The arrival time of the request.
lora_request: LoRA request.
embeddings: The embeddings vectors of the prompt of the sequence group
for a pooling model.
pooling_params: The pooling parameters used to generate the pooling
pooling_params: The parameters used to generate the pooler
for a pooling model.
pooled_data: The extracted hidden states from a pooling model.
encoder_seq: Optional, the single encoder sequence. Should be None
unless you are working with an encoder/decoder model.
trace_headers: OpenTelemetry trace headers.
......@@ -635,8 +634,8 @@ class SequenceGroup:
arrival_time: float,
sampling_params: Optional[SamplingParams] = None,
lora_request: Optional[LoRARequest] = None,
embeddings: Optional[List[float]] = None,
pooling_params: Optional[PoolingParams] = None,
pooled_data: Optional[torch.Tensor] = None,
encoder_seq: Optional[Sequence] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
......@@ -658,8 +657,8 @@ class SequenceGroup:
self.lora_request = lora_request
self.prompt_logprobs: Optional[PromptLogprobs] = None
self.state = SequenceGroupState()
self.embeddings = embeddings
self.pooling_params = pooling_params
self.pooled_data = pooled_data
self.prompt_adapter_request = prompt_adapter_request
self.encoder_seq = encoder_seq
self.trace_headers = trace_headers
......@@ -1033,8 +1032,8 @@ class CompletionSequenceGroupOutput(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
array_like=True): # type: ignore[call-arg]
__metaclass__ = SequenceGroupOutput
"""The model output associated with a completion sequence group."""
__metaclass__ = SequenceGroupOutput
samples: List[SequenceOutput]
# Prompt logprob for each prompt query token.
prompt_logprobs: Optional[PromptLogprobs]
......@@ -1050,23 +1049,24 @@ class CompletionSequenceGroupOutput(
and self.prompt_logprobs == other.prompt_logprobs)
class EmbeddingSequenceGroupOutput(
class PoolingSequenceGroupOutput(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
array_like=True, # type: ignore[call-arg]
):
"""The model output associated with an embedding sequence group."""
"""The model output associated with a pooling sequence group."""
__metaclass__ = SequenceGroupOutput
embeddings: List[int]
# Annotated as Any to be compatible with msgspec
# The actual type is in SequenceGroup.pooled_data
data: Any
def __repr__(self) -> str:
return (f"EmbeddingSequenceGroupOutput("
f"embeddings_shape={len(self.embeddings)})")
return f"PoolingSequenceGroupOutput(data={self.data}"
def __eq__(self, other: object) -> bool:
if not isinstance(other, EmbeddingSequenceGroupOutput):
if not isinstance(other, PoolingSequenceGroupOutput):
raise NotImplementedError()
return self.embeddings == other.embeddings
return self.data == other.data
# cannot use msgspec.Struct here because Dynamo does not support it
......@@ -1085,7 +1085,7 @@ class IntermediateTensors:
elif isinstance(key, slice):
return self.__class__({k: v[key] for k, v in self.tensors.items()})
def __setitem__(self, key: str, value):
def __setitem__(self, key: str, value: torch.Tensor):
self.tensors[key] = value
def __len__(self):
......@@ -1103,16 +1103,12 @@ class PoolerOutput(
omit_defaults=True, # type: ignore[call-arg]
array_like=True): # type: ignore[call-arg]
"""The output from a pooling operation in the pooling model."""
outputs: List[EmbeddingSequenceGroupOutput]
# lazy import to avoid circular import
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None
outputs: List[PoolingSequenceGroupOutput]
def __getitem__(self, idx: int) -> EmbeddingSequenceGroupOutput:
def __getitem__(self, idx: int) -> PoolingSequenceGroupOutput:
return self.outputs[idx]
def __setitem__(self, idx: int, value):
def __setitem__(self, idx: int, value: PoolingSequenceGroupOutput):
self.outputs[idx] = value
def __len__(self):
......@@ -1385,8 +1381,8 @@ class ParallelSampleSequenceGroup(SequenceGroupBase):
arrival_time=seq_group.arrival_time,
sampling_params=original_params,
lora_request=seq_group.lora_request,
embeddings=seq_group.embeddings,
pooling_params=seq_group.pooling_params,
pooled_data=seq_group.pooled_data,
encoder_seq=seq_group.encoder_seq,
trace_headers=seq_group.trace_headers,
prompt_adapter_request=seq_group.prompt_adapter_request,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment