Commit eeec9e33 (unverified), authored by Cyrus Leung, committed by GitHub
Browse files

[Frontend] Separate pooling APIs in offline inference (#11129)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent f93bf2b1
...@@ -617,10 +617,9 @@ class SequenceGroup: ...@@ -617,10 +617,9 @@ class SequenceGroup:
sampling_params: The sampling parameters used to generate the outputs. sampling_params: The sampling parameters used to generate the outputs.
arrival_time: The arrival time of the request. arrival_time: The arrival time of the request.
lora_request: LoRA request. lora_request: LoRA request.
embeddings: The embeddings vectors of the prompt of the sequence group pooling_params: The parameters used to generate the pooler
for a pooling model.
pooling_params: The pooling parameters used to generate the pooling
for a pooling model. for a pooling model.
pooled_data: The extracted hidden states from a pooling model.
encoder_seq: Optional, the single encoder sequence. Should be None encoder_seq: Optional, the single encoder sequence. Should be None
unless you are working with an encoder/decoder model. unless you are working with an encoder/decoder model.
trace_headers: OpenTelemetry trace headers. trace_headers: OpenTelemetry trace headers.
...@@ -635,8 +634,8 @@ class SequenceGroup: ...@@ -635,8 +634,8 @@ class SequenceGroup:
arrival_time: float, arrival_time: float,
sampling_params: Optional[SamplingParams] = None, sampling_params: Optional[SamplingParams] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
embeddings: Optional[List[float]] = None,
pooling_params: Optional[PoolingParams] = None, pooling_params: Optional[PoolingParams] = None,
pooled_data: Optional[torch.Tensor] = None,
encoder_seq: Optional[Sequence] = None, encoder_seq: Optional[Sequence] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
...@@ -658,8 +657,8 @@ class SequenceGroup: ...@@ -658,8 +657,8 @@ class SequenceGroup:
self.lora_request = lora_request self.lora_request = lora_request
self.prompt_logprobs: Optional[PromptLogprobs] = None self.prompt_logprobs: Optional[PromptLogprobs] = None
self.state = SequenceGroupState() self.state = SequenceGroupState()
self.embeddings = embeddings
self.pooling_params = pooling_params self.pooling_params = pooling_params
self.pooled_data = pooled_data
self.prompt_adapter_request = prompt_adapter_request self.prompt_adapter_request = prompt_adapter_request
self.encoder_seq = encoder_seq self.encoder_seq = encoder_seq
self.trace_headers = trace_headers self.trace_headers = trace_headers
...@@ -1033,8 +1032,8 @@ class CompletionSequenceGroupOutput( ...@@ -1033,8 +1032,8 @@ class CompletionSequenceGroupOutput(
msgspec.Struct, msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg] omit_defaults=True, # type: ignore[call-arg]
array_like=True): # type: ignore[call-arg] array_like=True): # type: ignore[call-arg]
__metaclass__ = SequenceGroupOutput
"""The model output associated with a completion sequence group.""" """The model output associated with a completion sequence group."""
__metaclass__ = SequenceGroupOutput
samples: List[SequenceOutput] samples: List[SequenceOutput]
# Prompt logprob for each prompt query token. # Prompt logprob for each prompt query token.
prompt_logprobs: Optional[PromptLogprobs] prompt_logprobs: Optional[PromptLogprobs]
...@@ -1050,23 +1049,24 @@ class CompletionSequenceGroupOutput( ...@@ -1050,23 +1049,24 @@ class CompletionSequenceGroupOutput(
and self.prompt_logprobs == other.prompt_logprobs) and self.prompt_logprobs == other.prompt_logprobs)
class EmbeddingSequenceGroupOutput( class PoolingSequenceGroupOutput(
msgspec.Struct, msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg] omit_defaults=True, # type: ignore[call-arg]
array_like=True, # type: ignore[call-arg] array_like=True, # type: ignore[call-arg]
): ):
"""The model output associated with an embedding sequence group.""" """The model output associated with a pooling sequence group."""
__metaclass__ = SequenceGroupOutput __metaclass__ = SequenceGroupOutput
embeddings: List[int] # Annotated as Any to be compatible with msgspec
# The actual type is in SequenceGroup.pooled_data
data: Any
def __repr__(self) -> str: def __repr__(self) -> str:
return (f"EmbeddingSequenceGroupOutput(" return f"PoolingSequenceGroupOutput(data={self.data}"
f"embeddings_shape={len(self.embeddings)})")
def __eq__(self, other: object) -> bool: def __eq__(self, other: object) -> bool:
if not isinstance(other, EmbeddingSequenceGroupOutput): if not isinstance(other, PoolingSequenceGroupOutput):
raise NotImplementedError() raise NotImplementedError()
return self.embeddings == other.embeddings return self.data == other.data
# cannot use msgspec.Struct here because Dynamo does not support it # cannot use msgspec.Struct here because Dynamo does not support it
...@@ -1085,7 +1085,7 @@ class IntermediateTensors: ...@@ -1085,7 +1085,7 @@ class IntermediateTensors:
elif isinstance(key, slice): elif isinstance(key, slice):
return self.__class__({k: v[key] for k, v in self.tensors.items()}) return self.__class__({k: v[key] for k, v in self.tensors.items()})
def __setitem__(self, key: str, value): def __setitem__(self, key: str, value: torch.Tensor):
self.tensors[key] = value self.tensors[key] = value
def __len__(self): def __len__(self):
...@@ -1103,16 +1103,12 @@ class PoolerOutput( ...@@ -1103,16 +1103,12 @@ class PoolerOutput(
omit_defaults=True, # type: ignore[call-arg] omit_defaults=True, # type: ignore[call-arg]
array_like=True): # type: ignore[call-arg] array_like=True): # type: ignore[call-arg]
"""The output from a pooling operation in the pooling model.""" """The output from a pooling operation in the pooling model."""
outputs: List[EmbeddingSequenceGroupOutput] outputs: List[PoolingSequenceGroupOutput]
# lazy import to avoid circular import
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None
def __getitem__(self, idx: int) -> EmbeddingSequenceGroupOutput: def __getitem__(self, idx: int) -> PoolingSequenceGroupOutput:
return self.outputs[idx] return self.outputs[idx]
def __setitem__(self, idx: int, value): def __setitem__(self, idx: int, value: PoolingSequenceGroupOutput):
self.outputs[idx] = value self.outputs[idx] = value
def __len__(self): def __len__(self):
...@@ -1385,8 +1381,8 @@ class ParallelSampleSequenceGroup(SequenceGroupBase): ...@@ -1385,8 +1381,8 @@ class ParallelSampleSequenceGroup(SequenceGroupBase):
arrival_time=seq_group.arrival_time, arrival_time=seq_group.arrival_time,
sampling_params=original_params, sampling_params=original_params,
lora_request=seq_group.lora_request, lora_request=seq_group.lora_request,
embeddings=seq_group.embeddings,
pooling_params=seq_group.pooling_params, pooling_params=seq_group.pooling_params,
pooled_data=seq_group.pooled_data,
encoder_seq=seq_group.encoder_seq, encoder_seq=seq_group.encoder_seq,
trace_headers=seq_group.trace_headers, trace_headers=seq_group.trace_headers,
prompt_adapter_request=seq_group.prompt_adapter_request, prompt_adapter_request=seq_group.prompt_adapter_request,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment