Unverified commit e54ee3ea, authored by Nick Hill, committed by GitHub
Browse files

[Core] Deduplicate generate/encode logic in `AsyncLLM` (#31510)


Signed-off-by: njhill <nickhill123@gmail.com>
parent 358bfd31
...@@ -281,6 +281,25 @@ class AsyncLLM(EngineClient): ...@@ -281,6 +281,25 @@ class AsyncLLM(EngineClient):
is_pooling = isinstance(params, PoolingParams) is_pooling = isinstance(params, PoolingParams)
if (
self.vllm_config.cache_config.kv_sharing_fast_prefill
and not is_pooling
and params.prompt_logprobs
):
raise ValueError(
"--kv-sharing-fast-prefill produces incorrect logprobs for "
"prompt tokens, please disable it when the requests need "
"prompt logprobs"
)
if tokenization_kwargs is None:
tokenization_kwargs = {}
_validate_truncation_size(
self.model_config.max_model_len,
params.truncate_prompt_tokens,
tokenization_kwargs,
)
# Convert Input --> Request. # Convert Input --> Request.
if isinstance(prompt, EngineCoreRequest): if isinstance(prompt, EngineCoreRequest):
request = prompt request = prompt
...@@ -291,7 +310,10 @@ class AsyncLLM(EngineClient): ...@@ -291,7 +310,10 @@ class AsyncLLM(EngineClient):
"latter will be used, and the former will be ignored." "latter will be used, and the former will be ignored."
) )
else: else:
assert prompt_text is None if prompt_text is not None:
raise ValueError(
"should only provide prompt_text with EngineCoreRequest"
)
request = self.input_processor.process_inputs( request = self.input_processor.process_inputs(
request_id, request_id,
prompt, prompt,
...@@ -310,6 +332,15 @@ class AsyncLLM(EngineClient): ...@@ -310,6 +332,15 @@ class AsyncLLM(EngineClient):
self.input_processor.assign_request_id(request) self.input_processor.assign_request_id(request)
# We start the output_handler on the first call to add_request() so
# we can call __init__ before the event loop, which enables us
# to handle startup failure gracefully in the OpenAI server.
self._run_output_handler()
# Respect pause state before accepting new requests.
async with self._pause_cond:
await self._pause_cond.wait_for(lambda: not self._paused)
# Create a new output collector for the request. # Create a new output collector for the request.
queue = RequestOutputCollector(params.output_kind, request.request_id) queue = RequestOutputCollector(params.output_kind, request.request_id)
...@@ -385,37 +416,8 @@ class AsyncLLM(EngineClient): ...@@ -385,37 +416,8 @@ class AsyncLLM(EngineClient):
returning the RequestOutput back to the caller. returning the RequestOutput back to the caller.
""" """
if (
self.vllm_config.cache_config.kv_sharing_fast_prefill
and sampling_params.prompt_logprobs
):
raise ValueError(
"--kv-sharing-fast-prefill produces incorrect logprobs for "
"prompt tokens, please disable it when the requests need "
"prompt logprobs"
)
q: RequestOutputCollector | None = None q: RequestOutputCollector | None = None
try: try:
# We start the output_handler on the first call to generate() so
# we can call __init__ before the event loop, which enables us
# to handle startup failure gracefully in the OpenAI server.
self._run_output_handler()
# Wait until generation is resumed if the engine is paused.
async with self._pause_cond:
await self._pause_cond.wait_for(lambda: not self._paused)
if tokenization_kwargs is None:
tokenization_kwargs = {}
truncate_prompt_tokens = sampling_params.truncate_prompt_tokens
_validate_truncation_size(
self.model_config.max_model_len,
truncate_prompt_tokens,
tokenization_kwargs,
)
q = await self.add_request( q = await self.add_request(
request_id, request_id,
prompt, prompt,
...@@ -639,18 +641,6 @@ class AsyncLLM(EngineClient): ...@@ -639,18 +641,6 @@ class AsyncLLM(EngineClient):
q: RequestOutputCollector | None = None q: RequestOutputCollector | None = None
try: try:
# We start the output_handler on the first call to generate() so
# we can call __init__ before the event loop, which enables us
# to handle startup failure gracefully in the OpenAI server.
self._run_output_handler()
# Respect pause state before accepting new requests.
async with self._pause_cond:
await self._pause_cond.wait_for(lambda: not self._paused)
if tokenization_kwargs is None:
tokenization_kwargs = {}
if truncate_prompt_tokens is not None: if truncate_prompt_tokens is not None:
warnings.warn( warnings.warn(
"The `truncate_prompt_tokens` parameter in `AsyncLLM.encode()` " "The `truncate_prompt_tokens` parameter in `AsyncLLM.encode()` "
...@@ -660,12 +650,6 @@ class AsyncLLM(EngineClient): ...@@ -660,12 +650,6 @@ class AsyncLLM(EngineClient):
stacklevel=2, stacklevel=2,
) )
_validate_truncation_size(
self.model_config.max_model_len,
pooling_params.truncate_prompt_tokens,
tokenization_kwargs,
)
q = await self.add_request( q = await self.add_request(
request_id, request_id,
prompt, prompt,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment