"vscode:/vscode.git/clone" did not exist on "ab9f2cfd1942f7ddfee658ce86ea96b4789862af"
Unverified Commit f0a1c845 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Frontend] Use new Renderer for Completions and Tokenize API (#32863)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 8980001c
......@@ -834,7 +834,7 @@ def parse_pooling_type(pooling_name: str):
@cache
def get_sentence_transformer_tokenizer_config(
model: str | Path, revision: str | None = "main"
):
) -> dict[str, Any] | None:
"""
Returns the tokenization configuration dictionary for a
given Sentence Transformer BERT model.
......
......@@ -50,14 +50,17 @@ class AsyncMicrobatchTokenizer:
self._executor = ThreadPoolExecutor(max_workers=1)
# === Public async API ===
async def __call__(self, prompt, **kwargs):
async def __call__(self, prompt, **kwargs) -> BatchEncoding:
result_future: Future = self._loop.create_future()
key = self._queue_key("encode", kwargs)
queue = self._get_queue(self._loop, key)
await queue.put((prompt, kwargs, result_future))
return await result_future
async def decode(self, token_ids, **kwargs):
async def encode(self, prompt, **kwargs) -> list[int]:
return (await self(prompt, **kwargs)).input_ids
async def decode(self, token_ids, **kwargs) -> str:
result_future: Future = self._loop.create_future()
key = self._queue_key("decode", kwargs)
queue = self._get_queue(self._loop, key)
......
......@@ -16,7 +16,6 @@ from vllm import TokensPrompt
from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.inputs import PromptType
from vllm.inputs.data import StreamingInput
from vllm.logger import init_logger
......@@ -25,7 +24,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import STREAM_FINISHED, PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams
from vllm.renderers import RendererLike
from vllm.renderers import RendererLike, merge_kwargs
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.tasks import SupportedTask
from vllm.tokenizers import TokenizerLike
......@@ -304,13 +303,20 @@ class AsyncLLM(EngineClient):
"prompt logprobs"
)
if tokenization_kwargs is None:
tokenization_kwargs = {}
_validate_truncation_size(
self.model_config.max_model_len,
params.truncate_prompt_tokens,
tokenization_kwargs,
)
if params.truncate_prompt_tokens is not None:
params_type = type(params).__name__
warnings.warn(
f"The `truncate_prompt_tokens` parameter in `{params_type}` "
"is deprecated and will be removed in v0.16. "
"Please pass it via `tokenization_kwargs` instead.",
DeprecationWarning,
stacklevel=2,
)
tokenization_kwargs = merge_kwargs(
tokenization_kwargs,
dict(truncate_prompt_tokens=params.truncate_prompt_tokens),
)
if isinstance(prompt, AsyncGenerator):
# Streaming input case.
......@@ -344,12 +350,12 @@ class AsyncLLM(EngineClient):
request_id,
prompt,
params,
arrival_time,
lora_request,
tokenization_kwargs,
trace_headers,
priority,
data_parallel_rank,
arrival_time=arrival_time,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=priority,
data_parallel_rank=data_parallel_rank,
)
prompt_text = get_prompt_text(prompt)
......@@ -757,7 +763,6 @@ class AsyncLLM(EngineClient):
lora_request: LoRARequest | None = None,
trace_headers: Mapping[str, str] | None = None,
priority: int = 0,
truncate_prompt_tokens: int | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
) -> AsyncGenerator[PoolingRequestOutput, None]:
"""
......@@ -772,22 +777,10 @@ class AsyncLLM(EngineClient):
The caller of generate() iterates the returned AsyncGenerator,
returning the RequestOutput back to the caller.
NOTE: truncate_prompt_tokens is deprecated in v0.14.
TODO: Remove truncate_prompt_tokens in v0.15.
"""
q: RequestOutputCollector | None = None
try:
if truncate_prompt_tokens is not None:
warnings.warn(
"The `truncate_prompt_tokens` parameter in `AsyncLLM.encode()` "
"is deprecated and will be removed in v0.15. "
"Please use `pooling_params.truncate_prompt_tokens` instead.",
DeprecationWarning,
stacklevel=2,
)
q = await self.add_request(
request_id,
prompt,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment