Unverified Commit 18bdcf41 authored by Mac Misiura's avatar Mac Misiura Committed by GitHub
Browse files

feat - add a new endpoint `get_tokenizer_info` to provide...


feat - add a new endpoint `get_tokenizer_info` to provide tokenizer/chat-template information (#20575)
Signed-off-by: m-misiura <mmisiura@redhat.com>
parent 1c3198b6
...@@ -32,6 +32,7 @@ def server(zephyr_lora_added_tokens_files: str): # noqa: F811 ...@@ -32,6 +32,7 @@ def server(zephyr_lora_added_tokens_files: str): # noqa: F811
f"zephyr-lora2={zephyr_lora_added_tokens_files}", f"zephyr-lora2={zephyr_lora_added_tokens_files}",
"--max-lora-rank", "--max-lora-rank",
"64", "64",
"--enable-tokenizer-info-endpoint",
] ]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
...@@ -283,3 +284,106 @@ async def test_detokenize( ...@@ -283,3 +284,106 @@ async def test_detokenize(
response.raise_for_status() response.raise_for_status()
assert response.json() == {"prompt": prompt} assert response.json() == {"prompt": prompt}
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name,tokenizer_name",
    [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
    indirect=["tokenizer_name"],
)
async def test_tokenizer_info_basic(
    server: RemoteOpenAIServer,
    model_name: str,
    tokenizer_name: str,
):
    """Check that the tokenizer info endpoint reports a tokenizer class.

    The response must contain a ``tokenizer_class`` entry that is a
    non-empty string.
    """
    # NOTE(review): model_name / tokenizer_name come from the
    # parametrization but are not referenced by the request below.
    resp = requests.get(server.url_for("tokenizer_info"))
    resp.raise_for_status()
    payload = resp.json()

    assert "tokenizer_class" in payload
    tokenizer_class = payload["tokenizer_class"]
    assert isinstance(tokenizer_class, str)
    assert tokenizer_class
@pytest.mark.asyncio
async def test_tokenizer_info_schema(server: RemoteOpenAIServer):
    """Verify the Python type of every known field found in the response.

    Fields that are missing from the response, or present with a null
    value, are not checked.
    """
    resp = requests.get(server.url_for("tokenizer_info"))
    resp.raise_for_status()
    payload = resp.json()

    expected_types = {
        "add_bos_token": bool,
        "add_prefix_space": bool,
        "clean_up_tokenization_spaces": bool,
        "split_special_tokens": bool,
        "bos_token": str,
        "eos_token": str,
        "pad_token": str,
        "unk_token": str,
        "chat_template": str,
        "errors": str,
        "model_max_length": int,
        "additional_special_tokens": list,
        "added_tokens_decoder": dict,
    }

    for field, expected_type in expected_types.items():
        # Skip anything the server did not report.
        if field not in payload or payload[field] is None:
            continue
        value = payload[field]
        assert isinstance(value, expected_type), (
            f"{field} should be {expected_type.__name__}")
@pytest.mark.asyncio
async def test_tokenizer_info_added_tokens_structure(
        server: RemoteOpenAIServer):
    """Validate each ``added_tokens_decoder`` entry when the field is set.

    Every entry must map a string token id to a dict that carries a
    ``content`` key and a boolean ``special`` key. Nothing is checked when
    the field is absent or empty.
    """
    resp = requests.get(server.url_for("tokenizer_info"))
    resp.raise_for_status()
    payload = resp.json()

    decoder = payload.get("added_tokens_decoder")
    # A missing/empty decoder yields no iterations, so the test passes.
    for token_id, token_info in (decoder or {}).items():
        assert isinstance(token_id, str), "Token IDs should be strings"
        assert isinstance(token_info, dict), "Token info should be a dict"
        assert "content" in token_info, "Token info should have content"
        assert "special" in token_info, (
            "Token info should have special flag")
        special_flag = token_info["special"]
        assert isinstance(special_flag,
                          bool), ("Special flag should be boolean")
@pytest.mark.asyncio
async def test_tokenizer_info_consistency_with_tokenize(
        server: RemoteOpenAIServer):
    """Compare the tokenizer info max length with the tokenize endpoint.

    When both endpoints report a (truthy) maximum length, the value from
    the tokenizer info response must be at least the one returned by the
    tokenize response.
    """
    info_resp = requests.get(server.url_for("tokenizer_info"))
    info_resp.raise_for_status()
    info_payload = info_resp.json()

    tokenize_body = {"model": MODEL_NAME, "prompt": "Hello world!"}
    tokenize_resp = requests.post(server.url_for("tokenize"),
                                  json=tokenize_body)
    tokenize_resp.raise_for_status()
    tokenize_payload = tokenize_resp.json()

    info_max_len = info_payload.get("model_max_length")
    tokenize_max_len = tokenize_payload.get("max_model_len")

    # Nothing to compare unless both sides reported a length.
    if not info_max_len or not tokenize_max_len:
        return

    assert info_max_len >= tokenize_max_len, (
        "Info max length should be >= tokenize max length")
@pytest.mark.asyncio
async def test_tokenizer_info_chat_template(server: RemoteOpenAIServer):
    """Check the ``chat_template`` field when the response provides one.

    A reported template must be a string with non-whitespace content; a
    missing or empty template is accepted without further checks.
    """
    resp = requests.get(server.url_for("tokenizer_info"))
    resp.raise_for_status()
    payload = resp.json()

    template = payload.get("chat_template")
    if not template:
        return

    assert isinstance(template, str), ("Chat template should be a string")
    assert template.strip(), "Chat template should not be empty"
\ No newline at end of file
...@@ -522,6 +522,19 @@ async def detokenize(request: DetokenizeRequest, raw_request: Request): ...@@ -522,6 +522,19 @@ async def detokenize(request: DetokenizeRequest, raw_request: Request):
assert_never(generator) assert_never(generator)
def maybe_register_tokenizer_info_endpoint(args):
    """Register the ``GET /tokenizer_info`` route when it is enabled.

    The route is added to the module-level ``router`` only if ``args`` has
    a truthy ``enable_tokenizer_info_endpoint`` attribute; otherwise this
    function does nothing.
    """
    if not getattr(args, 'enable_tokenizer_info_endpoint', False):
        return

    @router.get("/tokenizer_info")
    async def get_tokenizer_info(raw_request: Request):
        """Get comprehensive tokenizer information."""
        result = await tokenization(raw_request).get_tokenizer_info()

        # Error responses carry their own HTTP status code.
        status_code = 200
        if isinstance(result, ErrorResponse):
            status_code = result.code

        return JSONResponse(content=result.model_dump(),
                            status_code=status_code)
@router.get("/v1/models") @router.get("/v1/models")
async def show_available_models(raw_request: Request): async def show_available_models(raw_request: Request):
handler = models(raw_request) handler = models(raw_request)
...@@ -1692,6 +1705,7 @@ async def run_server_worker(listen_address, ...@@ -1692,6 +1705,7 @@ async def run_server_worker(listen_address,
uvicorn_kwargs['log_config'] = log_config uvicorn_kwargs['log_config'] = log_config
async with build_async_engine_client(args, client_config) as engine_client: async with build_async_engine_client(args, client_config) as engine_client:
maybe_register_tokenizer_info_endpoint(args)
app = build_app(args) app = build_app(args)
vllm_config = await engine_client.get_vllm_config() vllm_config = await engine_client.get_vllm_config()
......
...@@ -182,6 +182,9 @@ schema. Example: `[{"type": "text", "text": "Hello world!"}]`""" ...@@ -182,6 +182,9 @@ schema. Example: `[{"type": "text", "text": "Hello world!"}]`"""
"""If set to True, enable tracking server_load_metrics in the app state.""" """If set to True, enable tracking server_load_metrics in the app state."""
enable_force_include_usage: bool = False enable_force_include_usage: bool = False
"""If set to True, including usage on every request.""" """If set to True, including usage on every request."""
enable_tokenizer_info_endpoint: bool = False
"""Enable the /tokenizer_info endpoint. May expose chat
templates and other tokenizer configuration."""
@staticmethod @staticmethod
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
......
...@@ -1953,6 +1953,16 @@ class DetokenizeResponse(OpenAIBaseModel): ...@@ -1953,6 +1953,16 @@ class DetokenizeResponse(OpenAIBaseModel):
prompt: str prompt: str
class TokenizerInfoResponse(OpenAIBaseModel):
    """
    Response containing tokenizer configuration
    equivalent to tokenizer_config.json
    """

    # The only explicitly declared field of the response.
    tokenizer_class: str

    # extra="allow" lets the model be constructed with additional keyword
    # fields beyond the one declared above.
    model_config = ConfigDict(extra="allow")
class LoadLoRAAdapterRequest(BaseModel): class LoadLoRAAdapterRequest(BaseModel):
lora_name: str lora_name: str
lora_path: str lora_path: str
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Final, Optional, Union from typing import Any, Final, Optional, Union
import jinja2 import jinja2
from fastapi import Request from fastapi import Request
...@@ -17,11 +17,13 @@ from vllm.entrypoints.openai.protocol import (DetokenizeRequest, ...@@ -17,11 +17,13 @@ from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
ErrorResponse, ErrorResponse,
TokenizeChatRequest, TokenizeChatRequest,
TokenizeRequest, TokenizeRequest,
TokenizeResponse) TokenizeResponse,
TokenizerInfoResponse)
# yapf: enable # yapf: enable
from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -155,3 +157,49 @@ class OpenAIServingTokenization(OpenAIServing): ...@@ -155,3 +157,49 @@ class OpenAIServingTokenization(OpenAIServing):
input_text = prompt_input["prompt"] input_text = prompt_input["prompt"]
return DetokenizeResponse(prompt=input_text) return DetokenizeResponse(prompt=input_text)
async def get_tokenizer_info(
    self,
) -> Union[TokenizerInfoResponse, ErrorResponse]:
    """Build a response describing the engine client's tokenizer.

    Returns:
        A ``TokenizerInfoResponse`` populated from
        ``TokenizerInfo.to_dict()`` on success. If any exception is raised
        along the way, the value of ``self.create_error_response`` for a
        message that includes the exception text.
    """
    try:
        tok = await self.engine_client.get_tokenizer()
        tokenizer_info = TokenizerInfo(tok, self.chat_template)
        payload = tokenizer_info.to_dict()
        return TokenizerInfoResponse(**payload)
    except Exception as exc:
        message = f"Failed to get tokenizer info: {exc!s}"
        return self.create_error_response(message)
@dataclass
class TokenizerInfo:
    """Builds a plain-dict view of a tokenizer's configuration."""

    # Tokenizer whose ``init_kwargs`` attribute (if any) seeds the config.
    tokenizer: AnyTokenizer
    # When truthy, stored under the ``chat_template`` key of the config.
    chat_template: Optional[str]

    def to_dict(self) -> dict[str, Any]:
        """Return the tokenizer configuration."""
        return self._get_tokenizer_config()

    def _get_tokenizer_config(self) -> dict[str, Any]:
        """Get tokenizer configuration directly from the tokenizer object.

        Returns:
            A new dict based on a copy of ``tokenizer.init_kwargs`` (empty
            when that attribute is missing or falsy), without the
            ``vocab_file`` / ``merges_file`` entries, with nested values
            converted by ``_make_json_serializable``, plus a
            ``tokenizer_class`` entry and, when set, ``chat_template``.
        """
        # Copy so the pops below never touch the tokenizer's own dict.
        config = dict(getattr(self.tokenizer, "init_kwargs", None) or {})

        # Remove file path fields
        config.pop("vocab_file", None)
        config.pop("merges_file", None)

        config = self._make_json_serializable(config)
        config["tokenizer_class"] = type(self.tokenizer).__name__
        if self.chat_template:
            config["chat_template"] = self.chat_template
        return config

    def _make_json_serializable(self, obj: Any) -> Any:
        """Convert any non-JSON-serializable objects to serializable format.

        Objects exposing a ``content`` attribute are replaced by that
        attribute; dicts, lists and tuples are converted recursively
        (tuples come back as lists); everything else is returned as is.
        """
        if hasattr(obj, "content"):
            return obj.content
        if isinstance(obj, dict):
            return {k: self._make_json_serializable(v) for k, v in obj.items()}
        # Tuples are walked too, so ``content``-bearing objects nested in
        # them are converted instead of leaking through unchanged.
        if isinstance(obj, (list, tuple)):
            return [self._make_json_serializable(item) for item in obj]
        return obj
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment