Unverified Commit d117a4d1 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Frontend] Introduce Renderer for processing chat messages (using `ModelConfig`) (#30200)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 421012b6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from concurrent.futures import ThreadPoolExecutor
from typing import Any
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ConversationMessage,
parse_chat_messages,
parse_chat_messages_async,
)
from vllm.inputs import TextPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.utils.async_utils import make_async
from .protocol import RendererLike
logger = init_logger(__name__)
def safe_apply_chat_template(
tokenizer: MistralTokenizer,
messages: list[ChatCompletionMessageParam],
**kwargs,
) -> str | list[int]:
from mistral_common.exceptions import MistralCommonException
try:
return tokenizer.apply_chat_template(messages, **kwargs)
# mistral-common uses assert statements to stop processing of input
# if input does not comply with the expected format.
# We convert those assertion errors to ValueErrors so they can be
# properly caught in the preprocessing_input step
except (AssertionError, MistralCommonException) as e:
raise ValueError(str(e)) from e
# External library exceptions can sometimes occur despite the framework's
# internal exception management capabilities.
except Exception as e:
# Log and report any library-related exceptions for further
# investigation.
logger.exception(
"An error occurred in `mistral_common` while applying chat template"
)
raise ValueError(str(e)) from e
class MistralRenderer(RendererLike):
@classmethod
def from_config(
cls,
config: ModelConfig,
tokenizer_kwargs: dict[str, Any],
) -> "RendererLike":
return cls(config, tokenizer_kwargs)
def __init__(
self,
config: ModelConfig,
tokenizer_kwargs: dict[str, Any],
) -> None:
super().__init__()
self.config = config
if config.skip_tokenizer_init:
tokenizer = None
else:
tokenizer = cached_get_tokenizer(
tokenizer_cls=MistralTokenizer,
**tokenizer_kwargs,
)
self._tokenizer = tokenizer
self._apply_chat_template_executor = ThreadPoolExecutor(max_workers=1)
self._apply_chat_template_async = make_async(
safe_apply_chat_template, executor=self._apply_chat_template_executor
)
@property
def tokenizer(self) -> MistralTokenizer | None:
return self._tokenizer
def get_tokenizer(self) -> MistralTokenizer:
tokenizer = self.tokenizer
if tokenizer is None:
raise ValueError("Tokenizer not available when `skip_tokenizer_init=True`")
return tokenizer
def render_messages(
self,
messages: list[ChatCompletionMessageParam],
**kwargs,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = parse_chat_messages(
messages,
self.config,
content_format="string",
)
prompt_raw = safe_apply_chat_template(tokenizer, messages, **kwargs)
prompt = (
TextPrompt(prompt=prompt_raw)
if isinstance(prompt_raw, str)
else TokensPrompt(prompt_token_ids=prompt_raw)
)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
prompt["multi_modal_uuids"] = mm_uuids
return conversation, prompt # type: ignore[return-value]
async def render_messages_async(
self,
messages: list[ChatCompletionMessageParam],
**kwargs,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
messages,
self.config,
content_format="string",
)
prompt_raw = await self._apply_chat_template_async(
tokenizer, messages, **kwargs
)
prompt = (
TextPrompt(prompt=prompt_raw)
if isinstance(prompt_raw, str)
else TokensPrompt(prompt_token_ids=prompt_raw)
)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
prompt["multi_modal_uuids"] = mm_uuids
return conversation, prompt # type: ignore[return-value]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Any, Protocol
from vllm.inputs import TextPrompt, TokensPrompt
from vllm.tokenizers import TokenizerLike
if TYPE_CHECKING:
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ConversationMessage,
)
class RendererLike(Protocol):
@classmethod
def from_config(
cls,
config: "ModelConfig",
tokenizer_kwargs: dict[str, Any],
) -> "RendererLike":
raise NotImplementedError
@property
def tokenizer(self) -> TokenizerLike | None:
raise NotImplementedError
def get_tokenizer(self) -> TokenizerLike:
tokenizer = self.tokenizer
if tokenizer is None:
raise ValueError("Tokenizer not available when `skip_tokenizer_init=True`")
return tokenizer
def render_messages(
self,
messages: list["ChatCompletionMessageParam"],
**kwargs,
) -> tuple[list["ConversationMessage"], TextPrompt | TokensPrompt]:
raise NotImplementedError
async def render_messages_async(
self,
messages: list["ChatCompletionMessageParam"],
**kwargs,
) -> tuple[list["ConversationMessage"], TextPrompt | TokensPrompt]:
return self.render_messages(messages, **kwargs)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from vllm.logger import init_logger
from vllm.tokenizers.registry import tokenizer_args_from_config
from vllm.utils.import_utils import resolve_obj_by_qualname
from .protocol import RendererLike
if TYPE_CHECKING:
from vllm.config import ModelConfig
logger = init_logger(__name__)
_VLLM_RENDERERS = {
"deepseek_v32": ("deepseek_v32", "DeepseekV32Renderer"),
"hf": ("hf", "HfRenderer"),
"grok2": ("grok2", "Grok2Renderer"),
"mistral": ("mistral", "MistralRenderer"),
"terratorch": ("terratorch", "TerratorchRenderer"),
}
@dataclass
class RendererRegistry:
# Renderer mode -> (renderer module, renderer class)
renderers: dict[str, tuple[str, str]] = field(default_factory=dict)
def register(self, renderer_mode: str, module: str, class_name: str) -> None:
if renderer_mode in self.renderers:
logger.warning(
"%s.%s is already registered for renderer_mode=%r. "
"It is overwritten by the new one.",
module,
class_name,
renderer_mode,
)
self.renderers[renderer_mode] = (module, class_name)
return None
def load_renderer_cls(self, renderer_mode: str) -> type[RendererLike]:
if renderer_mode not in self.renderers:
raise ValueError(f"No renderer registered for {renderer_mode=!r}.")
module, class_name = self.renderers[renderer_mode]
logger.debug_once(f"Loading {class_name} for {renderer_mode=!r}")
return resolve_obj_by_qualname(f"{module}.{class_name}")
def load_renderer(
self,
renderer_mode: str,
config: "ModelConfig",
tokenizer_kwargs: dict[str, Any],
) -> RendererLike:
renderer_cls = self.load_renderer_cls(renderer_mode)
return renderer_cls.from_config(config, tokenizer_kwargs)
RENDERER_REGISTRY = RendererRegistry(
{
mode: (f"vllm.renderers.{mod_relname}", cls_name)
for mode, (mod_relname, cls_name) in _VLLM_RENDERERS.items()
}
)
"""The global `RendererRegistry` instance."""
def renderer_from_config(config: "ModelConfig", **kwargs):
tokenizer_mode, tokenizer_name, args, kwargs = tokenizer_args_from_config(
config, **kwargs
)
if config.tokenizer_mode == "auto" and config.model_impl == "terratorch":
renderer_mode = "terratorch"
else:
renderer_mode = tokenizer_mode
return RENDERER_REGISTRY.load_renderer(
renderer_mode,
config,
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ConversationMessage,
parse_chat_messages,
parse_chat_messages_async,
)
from vllm.inputs import TextPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from .protocol import RendererLike
logger = init_logger(__name__)
class TerratorchRenderer(RendererLike):
@classmethod
def from_config(
cls,
config: "ModelConfig",
tokenizer_kwargs: dict[str, Any],
) -> "RendererLike":
return cls(config)
def __init__(self, config: ModelConfig) -> None:
super().__init__()
self.config = config
if not config.skip_tokenizer_init:
raise ValueError("Terratorch renderer requires `skip_tokenizer_init=True`")
@property
def tokenizer(self) -> TokenizerLike | None:
return None
def get_tokenizer(self) -> TokenizerLike:
raise ValueError("Tokenizer not available for Terratorch renderer")
def render_messages(
self,
messages: list[ChatCompletionMessageParam],
**kwargs,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
model_config = self.config
conversation, mm_data, mm_uuids = parse_chat_messages(
messages,
model_config,
content_format="string",
)
prompt = TokensPrompt(prompt_token_ids=[1])
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
prompt["multi_modal_uuids"] = mm_uuids
return conversation, prompt
async def render_messages_async(
self,
messages: list[ChatCompletionMessageParam],
**kwargs,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
model_config = self.config
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
messages,
model_config,
content_format="string",
)
prompt = TokensPrompt(prompt_token_ids=[1]) # Dummy token IDs
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
prompt["multi_modal_uuids"] = mm_uuids
return conversation, prompt
......@@ -23,9 +23,10 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams
from vllm.renderers import RendererLike
from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
from vllm.tokenizers import TokenizerLike
from vllm.tracing import init_tracer
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
from vllm.usage.usage_lib import UsageContext
......@@ -106,9 +107,7 @@ class AsyncLLM(EngineClient):
"enabling logging without default stat loggers."
)
tokenizer = cached_tokenizer_from_config(self.model_config)
self.input_processor = InputProcessor(self.vllm_config, tokenizer)
self.input_processor = InputProcessor(self.vllm_config)
self.io_processor = get_io_processor(
self.vllm_config,
self.model_config.io_processor_plugin,
......@@ -709,13 +708,12 @@ class AsyncLLM(EngineClient):
def tokenizer(self) -> TokenizerLike | None:
return self.input_processor.tokenizer
async def get_tokenizer(self) -> TokenizerLike:
if self.tokenizer is None:
raise ValueError(
"Unable to get tokenizer because `skip_tokenizer_init=True`"
)
def get_tokenizer(self) -> TokenizerLike:
return self.input_processor.get_tokenizer()
return self.tokenizer
@property
def renderer(self) -> RendererLike:
return self.input_processor.renderer
async def is_tracing_enabled(self) -> bool:
return self.observability_config.otlp_traces_endpoint is not None # type: ignore
......
......@@ -19,6 +19,7 @@ from vllm.multimodal.parse import MultiModalDataParser
from vllm.multimodal.processing.context import set_request_id
from vllm.multimodal.utils import argsort_mm_positions
from vllm.pooling_params import PoolingParams
from vllm.renderers import RendererLike
from vllm.sampling_params import _SAMPLING_EPS, SamplingParams
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import MistralTokenizer
......@@ -45,7 +46,6 @@ class InputProcessor:
def __init__(
self,
vllm_config: VllmConfig,
tokenizer: TokenizerLike | None,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
) -> None:
self.vllm_config = vllm_config
......@@ -61,8 +61,7 @@ class InputProcessor:
self.input_preprocessor = InputPreprocessor(
self.model_config,
tokenizer,
self.vllm_config.observability_config,
vllm_config.observability_config,
mm_registry,
mm_processor_cache=self.mm_processor_cache,
)
......@@ -71,6 +70,13 @@ class InputProcessor:
def tokenizer(self) -> TokenizerLike | None:
return self.input_preprocessor.tokenizer
def get_tokenizer(self) -> TokenizerLike:
return self.input_preprocessor.get_tokenizer()
@property
def renderer(self) -> RendererLike:
return self.input_preprocessor.renderer
def _validate_logprobs(
self,
params: SamplingParams,
......
......@@ -21,9 +21,10 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams
from vllm.renderers import RendererLike
from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
from vllm.tokenizers import TokenizerLike
from vllm.tracing import init_tracer
from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine import EngineCoreRequest
......@@ -84,9 +85,7 @@ class LLMEngine:
self.dp_group = None
self.should_execute_dummy_batch = False
tokenizer = cached_tokenizer_from_config(self.model_config)
self.input_processor = InputProcessor(self.vllm_config, tokenizer)
self.input_processor = InputProcessor(self.vllm_config)
self.io_processor = get_io_processor(
self.vllm_config,
self.model_config.io_processor_plugin,
......@@ -357,12 +356,11 @@ class LLMEngine:
return self.input_processor.tokenizer
def get_tokenizer(self) -> TokenizerLike:
if self.tokenizer is None:
raise ValueError(
"Unable to get tokenizer because `skip_tokenizer_init=True`"
)
return self.input_processor.get_tokenizer()
return self.tokenizer
@property
def renderer(self) -> RendererLike:
return self.input_processor.renderer
def do_log_stats(self) -> None:
"""Log stats if logging is enabled."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment