Unverified commit 59c62332, authored by Li, committed by GitHub
Browse files

Support prompt_embeds for pooling requests in output processor (#34904)


Signed-off-by: Li Zhang <lzhanga@amazon.com>
Co-authored-by: Li Zhang <lzhanga@amazon.com>
parent d38cd3dd
...@@ -337,16 +337,20 @@ class RequestState: ...@@ -337,16 +337,20 @@ class RequestState:
finished: bool, finished: bool,
kv_transfer_params: dict[str, Any] | None = None, kv_transfer_params: dict[str, Any] | None = None,
) -> RequestOutput | PoolingRequestOutput: ) -> RequestOutput | PoolingRequestOutput:
# If prompt embeds were used, put placeholder prompt token ids
prompt_token_ids = self.prompt_token_ids
if prompt_token_ids is None and self.prompt_embeds is not None:
prompt_token_ids = [0] * len(self.prompt_embeds)
assert prompt_token_ids is not None
first_output = outputs[0] first_output = outputs[0]
if isinstance(first_output, PoolingOutput): if isinstance(first_output, PoolingOutput):
assert len(outputs) == 1 assert len(outputs) == 1
# Prompt embeddings are currently not supported by pooling requests.
assert self.prompt_token_ids is not None
return PoolingRequestOutput( return PoolingRequestOutput(
request_id=external_req_id, request_id=external_req_id,
outputs=first_output, outputs=first_output,
num_cached_tokens=self.num_cached_tokens, num_cached_tokens=self.num_cached_tokens,
prompt_token_ids=self.prompt_token_ids, prompt_token_ids=prompt_token_ids,
finished=finished, finished=finished,
) )
assert self.logprobs_processor is not None assert self.logprobs_processor is not None
...@@ -356,11 +360,6 @@ class RequestState: ...@@ -356,11 +360,6 @@ class RequestState:
else: else:
prompt_logprobs = self.logprobs_processor.prompt_logprobs prompt_logprobs = self.logprobs_processor.prompt_logprobs
# If prompt embeds were used, put placeholder prompt token ids
prompt_token_ids = self.prompt_token_ids
if prompt_token_ids is None and self.prompt_embeds is not None:
prompt_token_ids = [0] * len(self.prompt_embeds)
return RequestOutput( return RequestOutput(
request_id=external_req_id, # request_id is what was provided externally request_id=external_req_id, # request_id is what was provided externally
lora_request=self.lora_request, lora_request=self.lora_request,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment