Unverified Commit f0a1c845 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Frontend] Use new Renderer for Completions and Tokenize API (#32863)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 8980001c
...@@ -834,7 +834,7 @@ def parse_pooling_type(pooling_name: str): ...@@ -834,7 +834,7 @@ def parse_pooling_type(pooling_name: str):
@cache @cache
def get_sentence_transformer_tokenizer_config( def get_sentence_transformer_tokenizer_config(
model: str | Path, revision: str | None = "main" model: str | Path, revision: str | None = "main"
): ) -> dict[str, Any] | None:
""" """
Returns the tokenization configuration dictionary for a Returns the tokenization configuration dictionary for a
given Sentence Transformer BERT model. given Sentence Transformer BERT model.
......
...@@ -50,14 +50,17 @@ class AsyncMicrobatchTokenizer: ...@@ -50,14 +50,17 @@ class AsyncMicrobatchTokenizer:
self._executor = ThreadPoolExecutor(max_workers=1) self._executor = ThreadPoolExecutor(max_workers=1)
# === Public async API === # === Public async API ===
async def __call__(self, prompt, **kwargs): async def __call__(self, prompt, **kwargs) -> BatchEncoding:
result_future: Future = self._loop.create_future() result_future: Future = self._loop.create_future()
key = self._queue_key("encode", kwargs) key = self._queue_key("encode", kwargs)
queue = self._get_queue(self._loop, key) queue = self._get_queue(self._loop, key)
await queue.put((prompt, kwargs, result_future)) await queue.put((prompt, kwargs, result_future))
return await result_future return await result_future
async def decode(self, token_ids, **kwargs): async def encode(self, prompt, **kwargs) -> list[int]:
return (await self(prompt, **kwargs)).input_ids
async def decode(self, token_ids, **kwargs) -> str:
result_future: Future = self._loop.create_future() result_future: Future = self._loop.create_future()
key = self._queue_key("decode", kwargs) key = self._queue_key("decode", kwargs)
queue = self._get_queue(self._loop, key) queue = self._get_queue(self._loop, key)
......
...@@ -16,7 +16,6 @@ from vllm import TokensPrompt ...@@ -16,7 +16,6 @@ from vllm import TokensPrompt
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.inputs import PromptType from vllm.inputs import PromptType
from vllm.inputs.data import StreamingInput from vllm.inputs.data import StreamingInput
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -25,7 +24,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry ...@@ -25,7 +24,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import STREAM_FINISHED, PoolingRequestOutput, RequestOutput from vllm.outputs import STREAM_FINISHED, PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import get_io_processor from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.renderers import RendererLike from vllm.renderers import RendererLike, merge_kwargs
from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.tasks import SupportedTask from vllm.tasks import SupportedTask
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
...@@ -304,13 +303,20 @@ class AsyncLLM(EngineClient): ...@@ -304,13 +303,20 @@ class AsyncLLM(EngineClient):
"prompt logprobs" "prompt logprobs"
) )
if tokenization_kwargs is None: if params.truncate_prompt_tokens is not None:
tokenization_kwargs = {} params_type = type(params).__name__
_validate_truncation_size( warnings.warn(
self.model_config.max_model_len, f"The `truncate_prompt_tokens` parameter in `{params_type}` "
params.truncate_prompt_tokens, "is deprecated and will be removed in v0.16. "
tokenization_kwargs, "Please pass it via `tokenization_kwargs` instead.",
) DeprecationWarning,
stacklevel=2,
)
tokenization_kwargs = merge_kwargs(
tokenization_kwargs,
dict(truncate_prompt_tokens=params.truncate_prompt_tokens),
)
if isinstance(prompt, AsyncGenerator): if isinstance(prompt, AsyncGenerator):
# Streaming input case. # Streaming input case.
...@@ -344,12 +350,12 @@ class AsyncLLM(EngineClient): ...@@ -344,12 +350,12 @@ class AsyncLLM(EngineClient):
request_id, request_id,
prompt, prompt,
params, params,
arrival_time, arrival_time=arrival_time,
lora_request, lora_request=lora_request,
tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
trace_headers, trace_headers=trace_headers,
priority, priority=priority,
data_parallel_rank, data_parallel_rank=data_parallel_rank,
) )
prompt_text = get_prompt_text(prompt) prompt_text = get_prompt_text(prompt)
...@@ -757,7 +763,6 @@ class AsyncLLM(EngineClient): ...@@ -757,7 +763,6 @@ class AsyncLLM(EngineClient):
lora_request: LoRARequest | None = None, lora_request: LoRARequest | None = None,
trace_headers: Mapping[str, str] | None = None, trace_headers: Mapping[str, str] | None = None,
priority: int = 0, priority: int = 0,
truncate_prompt_tokens: int | None = None,
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
) -> AsyncGenerator[PoolingRequestOutput, None]: ) -> AsyncGenerator[PoolingRequestOutput, None]:
""" """
...@@ -772,22 +777,10 @@ class AsyncLLM(EngineClient): ...@@ -772,22 +777,10 @@ class AsyncLLM(EngineClient):
The caller of generate() iterates the returned AsyncGenerator, The caller of generate() iterates the returned AsyncGenerator,
returning the RequestOutput back to the caller. returning the RequestOutput back to the caller.
NOTE: truncate_prompt_tokens is deprecated in v0.14.
TODO: Remove truncate_prompt_tokens in v0.15.
""" """
q: RequestOutputCollector | None = None q: RequestOutputCollector | None = None
try: try:
if truncate_prompt_tokens is not None:
warnings.warn(
"The `truncate_prompt_tokens` parameter in `AsyncLLM.encode()` "
"is deprecated and will be removed in v0.15. "
"Please use `pooling_params.truncate_prompt_tokens` instead.",
DeprecationWarning,
stacklevel=2,
)
q = await self.add_request( q = await self.add_request(
request_id, request_id,
prompt, prompt,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment