Commit 2e225f7b (unverified), authored by Cyrus Leung, committed by GitHub
Browse files

[Renderer] Consolidate factory methods (#38218)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 757eafcf
......@@ -16,7 +16,7 @@ from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.serve.render.serving import OpenAIServingRender
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.renderers.hf import HfRenderer
from vllm.tokenizers.registry import tokenizer_args_from_config
from vllm.tokenizers.registry import cached_tokenizer_from_config
from vllm.v1.engine.async_llm import AsyncLLM
MODEL_NAME = "openai-community/gpt2"
......@@ -72,11 +72,9 @@ class MockVllmConfig:
def _build_renderer(model_config: MockModelConfig):
_, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)
return HfRenderer.from_config(
return HfRenderer(
MockVllmConfig(model_config, parallel_config=MockParallelConfig()),
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
cached_tokenizer_from_config(model_config),
)
......
......@@ -41,7 +41,7 @@ from vllm.renderers.hf import HfRenderer
from vllm.renderers.mistral import MistralRenderer
from vllm.tokenizers import get_tokenizer
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.tokenizers.registry import tokenizer_args_from_config
from vllm.tokenizers.registry import cached_tokenizer_from_config
from vllm.tool_parsers import ToolParserManager
from vllm.v1.engine.async_llm import AsyncLLM
......@@ -553,11 +553,9 @@ class MockVllmConfig:
def _build_renderer(model_config: MockModelConfig):
_, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)
return HfRenderer.from_config(
return HfRenderer(
MockVllmConfig(model_config, parallel_config=MockParallelConfig()),
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
cached_tokenizer_from_config(model_config),
)
......
......@@ -16,7 +16,7 @@ from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.serve.render.serving import OpenAIServingRender
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.renderers.hf import HfRenderer
from vllm.tokenizers.registry import tokenizer_args_from_config
from vllm.tokenizers.registry import cached_tokenizer_from_config
from vllm.v1.engine.async_llm import AsyncLLM
MODEL_NAME = "openai-community/gpt2"
......@@ -93,11 +93,9 @@ def _build_serving_completion(engine: AsyncLLM) -> OpenAIServingCompletion:
def _build_renderer(model_config: MockModelConfig):
_, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)
return HfRenderer.from_config(
return HfRenderer(
MockVllmConfig(model_config, parallel_config=MockParallelConfig()),
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
cached_tokenizer_from_config(model_config),
)
......
......@@ -18,7 +18,7 @@ from vllm.entrypoints.serve.render.serving import OpenAIServingRender
from vllm.lora.request import LoRARequest
from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry
from vllm.renderers.hf import HfRenderer
from vllm.tokenizers.registry import tokenizer_args_from_config
from vllm.tokenizers.registry import cached_tokenizer_from_config
from vllm.v1.engine.async_llm import AsyncLLM
MODEL_NAME = "openai-community/gpt2"
......@@ -101,11 +101,9 @@ def register_mock_resolver():
def _build_renderer(model_config: MockModelConfig):
_, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)
return HfRenderer.from_config(
return HfRenderer(
MockVllmConfig(model_config, parallel_config=MockParallelConfig()),
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
cached_tokenizer_from_config(model_config),
)
......
......@@ -15,7 +15,6 @@ from vllm.inputs import SingletonPrompt
from vllm.renderers import TokenizeParams
from vllm.renderers.hf import HfRenderer
from vllm.renderers.inputs.preprocess import parse_model_prompt, prompt_to_seq
from vllm.tokenizers.registry import tokenizer_args_from_config
MODEL_NAME = "openai-community/gpt2"
......@@ -81,8 +80,6 @@ def _build_renderer(
truncation_side: str = "left",
max_chars_per_token: int = 1,
):
_, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)
renderer = HfRenderer(
MockVllmConfig(model_config, parallel_config=MockParallelConfig()),
tokenizer=(
......
......@@ -8,7 +8,7 @@ from vllm.assets.video import VideoAsset
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.multimodal.parse import parse_mm_uuids
from vllm.renderers.hf import HfRenderer
from vllm.tokenizers.registry import tokenizer_args_from_config
from vllm.tokenizers.registry import cached_tokenizer_from_config
cherry_pil_image = ImageAsset("cherry_blossom").pil_image
stop_pil_image = ImageAsset("stop_sign").pil_image
......@@ -29,11 +29,9 @@ def _build_renderer(
cache_config=CacheConfig(enable_prefix_caching=enable_prefix_caching),
)
_, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)
return HfRenderer.from_config(
return HfRenderer(
vllm_config,
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
cached_tokenizer_from_config(model_config),
)
......
......@@ -542,7 +542,9 @@ class ModelConfig:
# Set default tokenizer modes based on model architecture
if self.tokenizer_mode == "auto":
if arch == "Grok1ForCausalLM":
if self.model_impl == "terratorch":
self.tokenizer_mode = "terratorch"
elif arch == "Grok1ForCausalLM":
self.tokenizer_mode = "grok2"
elif arch == "MoonshotKimiaForCausalLM":
self.tokenizer_mode = "kimi_audio"
......
......@@ -69,15 +69,6 @@ _T = TypeVar("_T", bound=TokenizerLike, default=TokenizerLike)
class BaseRenderer(ABC, Generic[_T]):
@classmethod
@abstractmethod
def from_config(
    cls,
    config: "VllmConfig",
    tokenizer_kwargs: dict[str, Any],
) -> "BaseRenderer":
    """Alternate constructor that concrete renderers must implement.

    Args:
        config: Engine configuration handed to the renderer.
        tokenizer_kwargs: Keyword arguments describing the tokenizer to
            build for the renderer.

    Returns:
        A renderer instance.

    Raises:
        NotImplementedError: Always, in this abstract base implementation.
    """
    raise NotImplementedError
def __init__(self, config: "VllmConfig", tokenizer: _T | None) -> None:
super().__init__()
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
from vllm.config import VllmConfig
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ConversationMessage,
......@@ -10,7 +8,6 @@ from vllm.entrypoints.chat_utils import (
parse_chat_messages_async,
)
from vllm.logger import init_logger
from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.deepseek_v32 import DeepseekV32Tokenizer
from .base import BaseRenderer
......@@ -22,23 +19,6 @@ logger = init_logger(__name__)
class DeepseekV32Renderer(BaseRenderer[DeepseekV32Tokenizer]):
@classmethod
def from_config(  # type: ignore[override]
    cls,
    config: VllmConfig,
    tokenizer_kwargs: dict[str, Any],
) -> "DeepseekV32Renderer":
    """Build a renderer from ``config``, loading its tokenizer unless skipped.

    Args:
        config: Engine configuration passed through to the constructor; its
            ``model_config.skip_tokenizer_init`` flag decides whether a
            tokenizer is loaded.
        tokenizer_kwargs: Keyword arguments forwarded unchanged to
            ``cached_get_tokenizer``.

    Returns:
        ``cls(config, tokenizer)``, where ``tokenizer`` is ``None`` when
        tokenizer initialization is skipped.
    """
    # Guard clause: nothing to load when tokenizer init is disabled.
    if config.model_config.skip_tokenizer_init:
        return cls(config, None)

    loaded_tokenizer = cached_get_tokenizer(
        tokenizer_cls=DeepseekV32Tokenizer,
        **tokenizer_kwargs,
    )
    return cls(config, loaded_tokenizer)
def render_messages(
self,
messages: list[ChatCompletionMessageParam],
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
from vllm.config import VllmConfig
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ConversationMessage,
......@@ -10,7 +8,6 @@ from vllm.entrypoints.chat_utils import (
parse_chat_messages_async,
)
from vllm.logger import init_logger
from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.grok2 import Grok2Tokenizer
from .base import BaseRenderer
......@@ -22,23 +19,6 @@ logger = init_logger(__name__)
class Grok2Renderer(BaseRenderer[Grok2Tokenizer]):
@classmethod
def from_config(  # type: ignore[override]
    cls,
    config: VllmConfig,
    tokenizer_kwargs: dict[str, Any],
) -> "Grok2Renderer":
    """Build a renderer from ``config``, loading its tokenizer unless skipped.

    Args:
        config: Engine configuration passed through to the constructor; its
            ``model_config.skip_tokenizer_init`` flag decides whether a
            tokenizer is loaded.
        tokenizer_kwargs: Keyword arguments forwarded unchanged to
            ``cached_get_tokenizer``.

    Returns:
        ``cls(config, tokenizer)``, where ``tokenizer`` is ``None`` when
        tokenizer initialization is skipped.
    """
    # Guard clause: nothing to load when tokenizer init is disabled.
    if config.model_config.skip_tokenizer_init:
        return cls(config, None)

    loaded_tokenizer = cached_get_tokenizer(
        tokenizer_cls=Grok2Tokenizer,
        **tokenizer_kwargs,
    )
    return cls(config, loaded_tokenizer)
def render_messages(
self,
messages: list[ChatCompletionMessageParam],
......
......@@ -27,8 +27,7 @@ from vllm.entrypoints.chat_utils import (
)
from vllm.inputs import MultiModalDataDict, MultiModalUUIDDict
from vllm.logger import init_logger
from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.hf import CachedHfTokenizer, HfTokenizer
from vllm.tokenizers.hf import HfTokenizer
from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path
from vllm.transformers_utils.processor import cached_get_processor
from vllm.utils.func_utils import supports_kw
......@@ -604,26 +603,6 @@ def replace_vision_chunk_video_placeholder(
class HfRenderer(BaseRenderer[HfTokenizer]):
@classmethod
def from_config(  # type: ignore[override]
    cls,
    config: VllmConfig,
    tokenizer_kwargs: dict[str, Any],
) -> "HfRenderer":
    """Build a renderer from ``config``, loading its tokenizer unless skipped.

    Args:
        config: Engine configuration passed through to the constructor; its
            ``model_config.skip_tokenizer_init`` flag decides whether a
            tokenizer is loaded.
        tokenizer_kwargs: Keyword arguments forwarded unchanged to
            ``cached_get_tokenizer``.

    Returns:
        ``cls(config, tokenizer)``, where ``tokenizer`` is ``None`` when
        tokenizer initialization is skipped.
    """
    # Guard clause: nothing to load when tokenizer init is disabled.
    if config.model_config.skip_tokenizer_init:
        return cls(config, None)

    raw_tokenizer = cached_get_tokenizer(
        tokenizer_cls=CachedHfTokenizer,  # type: ignore[type-abstract]
        **tokenizer_kwargs,
    )
    # The cast only narrows the static type; it performs no runtime check.
    return cls(config, cast(HfTokenizer, raw_tokenizer))
def __init__(
self,
config: VllmConfig,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, cast
from vllm.config import VllmConfig
from vllm.tokenizers.kimi_audio import KimiAudioTokenizer
from vllm.tokenizers.registry import get_tokenizer
from .hf import HfRenderer, HfTokenizer
class KimiAudioRenderer(HfRenderer):
    """Renderer for Kimi-Audio models.

    This renderer uses HfRenderer internally with a custom TikToken tokenizer.
    """

    @classmethod
    def from_config(  # type: ignore[override]
        cls,
        config: VllmConfig,
        tokenizer_kwargs: dict[str, Any],
    ) -> "HfRenderer":
        """Create an HfRenderer instance for Kimi-Audio models.

        Args:
            config: Engine configuration passed through to ``HfRenderer``; its
                ``model_config.skip_tokenizer_init`` flag decides whether a
                tokenizer is loaded.
            tokenizer_kwargs: Tokenizer arguments. ``tokenizer_name`` is used
                as the tokenizer to load (falling back to
                ``model_config.tokenizer`` when absent) and ``tokenizer_cls``
                is ignored because ``KimiAudioTokenizer`` is always used.
                The mapping itself is left unmodified.

        Returns:
            A plain ``HfRenderer`` (not ``cls``) built with the loaded
            tokenizer, or with ``None`` when tokenizer initialization is
            skipped.
        """
        model_config = config.model_config
        if model_config.skip_tokenizer_init:
            return HfRenderer(config, None)

        # tokenizer_name was already processed by tokenizer_args_from_config
        # for ModelScope/GGUF/etc. Read it with .get() instead of .pop() so
        # the caller's dict is not mutated as a side effect.
        tokenizer_name = tokenizer_kwargs.get(
            "tokenizer_name", model_config.tokenizer
        )

        # Forward a filtered copy: tokenizer_name is passed positionally and
        # tokenizer_cls is supplied explicitly below, so keeping either key
        # would produce a duplicate argument.
        forwarded_kwargs = {
            key: value
            for key, value in tokenizer_kwargs.items()
            if key not in ("tokenizer_name", "tokenizer_cls")
        }

        # Use get_tokenizer directly instead of cached_get_tokenizer
        # (KimiAudioTokenizer doesn't work with get_cached_tokenizer)
        tokenizer = cast(
            HfTokenizer,
            get_tokenizer(
                tokenizer_name,
                tokenizer_cls=KimiAudioTokenizer,  # type: ignore[arg-type]
                **forwarded_kwargs,
            ),
        )
        return HfRenderer(config, tokenizer)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from concurrent.futures import ThreadPoolExecutor
from typing import Any
from vllm.config import VllmConfig
from vllm.entrypoints.chat_utils import (
......@@ -11,7 +10,6 @@ from vllm.entrypoints.chat_utils import (
parse_chat_messages_async,
)
from vllm.logger import init_logger
from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.utils.async_utils import make_async
......@@ -51,23 +49,6 @@ def safe_apply_chat_template(
class MistralRenderer(BaseRenderer[MistralTokenizer]):
@classmethod
def from_config(  # type: ignore[override]
    cls,
    config: VllmConfig,
    tokenizer_kwargs: dict[str, Any],
) -> "MistralRenderer":
    """Build a renderer from ``config``, loading its tokenizer unless skipped.

    Args:
        config: Engine configuration passed through to the constructor; its
            ``model_config.skip_tokenizer_init`` flag decides whether a
            tokenizer is loaded.
        tokenizer_kwargs: Keyword arguments forwarded unchanged to
            ``cached_get_tokenizer``.

    Returns:
        ``cls(config, tokenizer)``, where ``tokenizer`` is ``None`` when
        tokenizer initialization is skipped.
    """
    # Guard clause: nothing to load when tokenizer init is disabled.
    if config.model_config.skip_tokenizer_init:
        return cls(config, None)

    loaded_tokenizer = cached_get_tokenizer(
        tokenizer_cls=MistralTokenizer,
        **tokenizer_kwargs,
    )
    return cls(config, loaded_tokenizer)
def __init__(
self,
config: VllmConfig,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
from vllm.config import VllmConfig
from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.qwen_vl import QwenVLTokenizer
from .hf import HfRenderer
class QwenVLRenderer(HfRenderer):
    """HfRenderer subclass whose factory loads a ``QwenVLTokenizer``."""

    @classmethod
    def from_config(  # type: ignore[override]
        cls,
        config: VllmConfig,
        tokenizer_kwargs: dict[str, Any],
    ) -> "HfRenderer":
        """Build an ``HfRenderer`` using a ``QwenVLTokenizer``.

        Args:
            config: Engine configuration passed through to ``HfRenderer``; its
                ``model_config.skip_tokenizer_init`` flag decides whether a
                tokenizer is loaded.
            tokenizer_kwargs: Keyword arguments forwarded unchanged to
                ``cached_get_tokenizer``.

        Returns:
            A plain ``HfRenderer`` (not ``cls``) built with the loaded
            tokenizer, or with ``None`` when tokenizer initialization is
            skipped.
        """
        # Guard clause: nothing to load when tokenizer init is disabled.
        if config.model_config.skip_tokenizer_init:
            return HfRenderer(config, None)

        loaded_tokenizer = cached_get_tokenizer(
            tokenizer_cls=QwenVLTokenizer,
            **tokenizer_kwargs,
        )
        return HfRenderer(config, loaded_tokenizer)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING
from vllm.logger import init_logger
from vllm.tokenizers.registry import tokenizer_args_from_config
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.registry import (
cached_tokenizer_from_config,
tokenizer_args_from_config,
)
from vllm.utils.import_utils import resolve_obj_by_qualname
from .base import BaseRenderer
......@@ -19,9 +23,9 @@ _VLLM_RENDERERS = {
"deepseek_v32": ("deepseek_v32", "DeepseekV32Renderer"),
"hf": ("hf", "HfRenderer"),
"grok2": ("grok2", "Grok2Renderer"),
"kimi_audio": ("kimi_audio", "KimiAudioRenderer"),
"kimi_audio": ("hf", "HfRenderer"),
"mistral": ("mistral", "MistralRenderer"),
"qwen_vl": ("qwen_vl", "QwenVLRenderer"),
"qwen_vl": ("hf", "HfRenderer"),
"terratorch": ("terratorch", "TerratorchRenderer"),
}
......@@ -58,10 +62,10 @@ class RendererRegistry:
self,
renderer_mode: str,
config: "VllmConfig",
tokenizer_kwargs: dict[str, Any],
tokenizer: TokenizerLike | None,
) -> BaseRenderer:
renderer_cls = self.load_renderer_cls(renderer_mode)
return renderer_cls.from_config(config, tokenizer_kwargs)
return renderer_cls(config, tokenizer)
RENDERER_REGISTRY = RendererRegistry(
......@@ -76,20 +80,7 @@ RENDERER_REGISTRY = RendererRegistry(
def renderer_from_config(config: "VllmConfig", **kwargs):
model_config = config.model_config
tokenizer_mode, tokenizer_name, args, kwargs = tokenizer_args_from_config(
model_config, **kwargs
)
if (
model_config.tokenizer_mode == "auto"
and model_config.model_impl == "terratorch"
):
renderer_mode = "terratorch"
else:
renderer_mode = tokenizer_mode
return RENDERER_REGISTRY.load_renderer(
renderer_mode,
config,
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
)
tokenizer = cached_tokenizer_from_config(model_config, **kwargs)
renderer_mode, *_ = tokenizer_args_from_config(model_config, **kwargs)
return RENDERER_REGISTRY.load_renderer(renderer_mode, config, tokenizer)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
from vllm.config import VllmConfig
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ConversationMessage,
......@@ -20,18 +18,6 @@ logger = init_logger(__name__)
class TerratorchRenderer(BaseRenderer):
@classmethod
def from_config(
    cls,
    config: VllmConfig,  # type: ignore[override]
    tokenizer_kwargs: dict[str, Any],
) -> "TerratorchRenderer":
    """Build a tokenizer-less renderer from ``config``.

    Args:
        config: Engine configuration passed through to the constructor; its
            ``model_config.skip_tokenizer_init`` flag must be set.
        tokenizer_kwargs: Accepted for signature parity with the other
            renderers; not used here.

    Returns:
        ``cls(config, None)``.

    Raises:
        ValueError: If ``model_config.skip_tokenizer_init`` is not set.
    """
    # Accept path first: this renderer is always created without a tokenizer.
    if config.model_config.skip_tokenizer_init:
        return cls(config, None)

    raise ValueError("Terratorch renderer requires `skip_tokenizer_init=True`")
def render_messages(
self,
messages: list[ChatCompletionMessageParam],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment