Unverified commit 44554a00, authored by Wang Yijun, committed by GitHub
Browse files

Add tokenization_kwargs to encode for embedding model truncation (#21033)

parent 226b452a
...@@ -438,6 +438,7 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -438,6 +438,7 @@ class _AsyncLLMEngine(LLMEngine):
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0, priority: int = 0,
data_parallel_rank: Optional[int] = None, data_parallel_rank: Optional[int] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> None: ) -> None:
""" """
Async version of Async version of
...@@ -468,6 +469,7 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -468,6 +469,7 @@ class _AsyncLLMEngine(LLMEngine):
prompt, prompt,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
tokenization_kwargs=tokenization_kwargs,
) )
if isinstance(params, SamplingParams) and \ if isinstance(params, SamplingParams) and \
...@@ -862,6 +864,7 @@ class AsyncLLMEngine(EngineClient): ...@@ -862,6 +864,7 @@ class AsyncLLMEngine(EngineClient):
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0, priority: int = 0,
data_parallel_rank: Optional[int] = None, data_parallel_rank: Optional[int] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]: ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
if not self.is_running: if not self.is_running:
if self.start_engine_loop: if self.start_engine_loop:
...@@ -889,6 +892,7 @@ class AsyncLLMEngine(EngineClient): ...@@ -889,6 +892,7 @@ class AsyncLLMEngine(EngineClient):
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
priority=priority, priority=priority,
data_parallel_rank=data_parallel_rank, data_parallel_rank=data_parallel_rank,
tokenization_kwargs=tokenization_kwargs,
) )
return stream.generator() return stream.generator()
...@@ -996,6 +1000,7 @@ class AsyncLLMEngine(EngineClient): ...@@ -996,6 +1000,7 @@ class AsyncLLMEngine(EngineClient):
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0, priority: int = 0,
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> AsyncGenerator[PoolingRequestOutput, None]: ) -> AsyncGenerator[PoolingRequestOutput, None]:
"""Generate outputs for a request from a pooling model. """Generate outputs for a request from a pooling model.
...@@ -1070,6 +1075,7 @@ class AsyncLLMEngine(EngineClient): ...@@ -1070,6 +1075,7 @@ class AsyncLLMEngine(EngineClient):
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers, trace_headers=trace_headers,
priority=priority, priority=priority,
tokenization_kwargs=tokenization_kwargs,
): ):
yield LLMEngine.validate_output(output, PoolingRequestOutput) yield LLMEngine.validate_output(output, PoolingRequestOutput)
except asyncio.CancelledError: except asyncio.CancelledError:
......
...@@ -965,6 +965,7 @@ class LLM: ...@@ -965,6 +965,7 @@ class LLM:
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
pooling_task: PoolingTask = "encode", pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]: ) -> list[PoolingRequestOutput]:
... ...
...@@ -981,6 +982,7 @@ class LLM: ...@@ -981,6 +982,7 @@ class LLM:
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
pooling_task: PoolingTask = "encode", pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]: ) -> list[PoolingRequestOutput]:
... ...
...@@ -997,6 +999,7 @@ class LLM: ...@@ -997,6 +999,7 @@ class LLM:
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
pooling_task: PoolingTask = "encode", pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]: ) -> list[PoolingRequestOutput]:
... ...
...@@ -1014,6 +1017,7 @@ class LLM: ...@@ -1014,6 +1017,7 @@ class LLM:
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
pooling_task: PoolingTask = "encode", pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]: ) -> list[PoolingRequestOutput]:
... ...
...@@ -1031,6 +1035,7 @@ class LLM: ...@@ -1031,6 +1035,7 @@ class LLM:
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
pooling_task: PoolingTask = "encode", pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]: ) -> list[PoolingRequestOutput]:
... ...
...@@ -1046,6 +1051,7 @@ class LLM: ...@@ -1046,6 +1051,7 @@ class LLM:
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
pooling_task: PoolingTask = "encode", pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]: ) -> list[PoolingRequestOutput]:
... ...
...@@ -1066,6 +1072,7 @@ class LLM: ...@@ -1066,6 +1072,7 @@ class LLM:
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
pooling_task: PoolingTask = "encode", pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]: ) -> list[PoolingRequestOutput]:
"""Apply pooling to the hidden states corresponding to the input """Apply pooling to the hidden states corresponding to the input
prompts. prompts.
...@@ -1131,9 +1138,11 @@ class LLM: ...@@ -1131,9 +1138,11 @@ class LLM:
for pooling_param in pooling_params: for pooling_param in pooling_params:
pooling_param.verify(pooling_task, model_config) pooling_param.verify(pooling_task, model_config)
if tokenization_kwargs is None:
tokenization_kwargs = dict[str, Any]() tokenization_kwargs = dict[str, Any]()
_validate_truncation_size(model_config.max_model_len, _validate_truncation_size(model_config.max_model_len,
truncate_prompt_tokens, tokenization_kwargs) truncate_prompt_tokens,
tokenization_kwargs)
self._validate_and_add_requests( self._validate_and_add_requests(
prompts=parsed_prompts, prompts=parsed_prompts,
......
...@@ -437,6 +437,7 @@ class AsyncLLM(EngineClient): ...@@ -437,6 +437,7 @@ class AsyncLLM(EngineClient):
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0, priority: int = 0,
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> AsyncGenerator[PoolingRequestOutput, None]: ) -> AsyncGenerator[PoolingRequestOutput, None]:
""" """
Main function called by the API server to kick off a request Main function called by the API server to kick off a request
...@@ -465,6 +466,7 @@ class AsyncLLM(EngineClient): ...@@ -465,6 +466,7 @@ class AsyncLLM(EngineClient):
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers, trace_headers=trace_headers,
priority=priority, priority=priority,
tokenization_kwargs=tokenization_kwargs,
) )
# The output_handler task pushes items into the queue. # The output_handler task pushes items into the queue.
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment