registry.py 3.21 KB
Newer Older
1
2
3
4
5
6
7
8
9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any

from vllm.logger import init_logger
from vllm.tokenizers.registry import tokenizer_args_from_config
from vllm.utils.import_utils import resolve_obj_by_qualname

10
from .base import BaseRenderer
11
12

if TYPE_CHECKING:
13
    from vllm.config import VllmConfig
14
15
16
17
18
19
20
21

# Module-level logger, created through vLLM's `init_logger` for this module.
logger = init_logger(__name__)


_VLLM_RENDERERS = {
    "deepseek_v32": ("deepseek_v32", "DeepseekV32Renderer"),
    "hf": ("hf", "HfRenderer"),
    "grok2": ("grok2", "Grok2Renderer"),
22
    "kimi_audio": ("kimi_audio", "KimiAudioRenderer"),
23
    "mistral": ("mistral", "MistralRenderer"),
24
    "qwen_vl": ("qwen_vl", "QwenVLRenderer"),
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
    "terratorch": ("terratorch", "TerratorchRenderer"),
}


@dataclass
class RendererRegistry:
    """Registry that maps a renderer mode to the renderer class serving it.

    Entries hold only a module path and a class name. Registering a renderer
    therefore does not import its module; the class is resolved by qualified
    name when it is requested through `load_renderer_cls`.
    """

    # Renderer mode ->  (renderer module, renderer class)
    renderers: dict[str, tuple[str, str]] = field(default_factory=dict)

    def register(self, renderer_mode: str, module: str, class_name: str) -> None:
        """Register `module.class_name` as the renderer for `renderer_mode`.

        If the mode already has an entry, it is replaced and a warning is
        logged that names both the entry being dropped and its replacement.

        Args:
            renderer_mode: Key under which the renderer is looked up.
            module: Fully qualified module path that defines the renderer.
            class_name: Name of the renderer class inside `module`.
        """
        if renderer_mode in self.renderers:
            old_module, old_class_name = self.renderers[renderer_mode]
            # Report the entry that is being replaced (not the incoming one),
            # so the log shows exactly which registration was lost.
            logger.warning(
                "%s.%s is already registered for renderer_mode=%r. "
                "It is overwritten by %s.%s.",
                old_module,
                old_class_name,
                renderer_mode,
                module,
                class_name,
            )

        self.renderers[renderer_mode] = (module, class_name)

    def load_renderer_cls(self, renderer_mode: str) -> type[BaseRenderer]:
        """Resolve and return the renderer class registered for a mode.

        Args:
            renderer_mode: Key of a previously registered renderer.

        Returns:
            The object found at `<module>.<class_name>` for that mode.

        Raises:
            ValueError: If no renderer is registered for `renderer_mode`.
        """
        if renderer_mode not in self.renderers:
            raise ValueError(f"No renderer registered for {renderer_mode=!r}.")

        module, class_name = self.renderers[renderer_mode]
        logger.debug_once(f"Loading {class_name} for {renderer_mode=!r}")

        # The import happens here, not at registration time.
        return resolve_obj_by_qualname(f"{module}.{class_name}")

    def load_renderer(
        self,
        renderer_mode: str,
        config: "VllmConfig",
        tokenizer_kwargs: dict[str, Any],
    ) -> BaseRenderer:
        """Instantiate the renderer registered for `renderer_mode`.

        Args:
            renderer_mode: Key of a previously registered renderer.
            config: vLLM configuration forwarded to the renderer.
            tokenizer_kwargs: Tokenizer arguments forwarded to the renderer.

        Returns:
            The result of calling `from_config(config, tokenizer_kwargs)` on
            the resolved renderer class.

        Raises:
            ValueError: If no renderer is registered for `renderer_mode`.
        """
        renderer_cls = self.load_renderer_cls(renderer_mode)
        return renderer_cls.from_config(config, tokenizer_kwargs)


# Seed the global registry with the built-in renderers, expanding each
# relative module name into its full path under the `vllm.renderers` package.
RENDERER_REGISTRY = RendererRegistry(
    renderers={
        renderer_mode: ("vllm.renderers." + module_name, class_name)
        for renderer_mode, (module_name, class_name) in _VLLM_RENDERERS.items()
    },
)
"""The global `RendererRegistry` instance."""


76
77
def renderer_from_config(config: "VllmConfig", **kwargs) -> BaseRenderer:
    """Create the renderer that matches the model described by `config`.

    The tokenizer settings are derived from `config.model_config` via
    `tokenizer_args_from_config`; the resulting tokenizer mode normally
    doubles as the renderer mode, with two special cases handled below.

    Note:
        For the `MoonshotKimiaForCausalLM` architecture this function mutates
        `config.model_config.tokenizer_mode`, setting it to `"kimi_audio"`.

    Args:
        config: vLLM configuration; only `config.model_config` is read here,
            the whole object is forwarded to the renderer.
        **kwargs: Extra tokenizer arguments, passed through
            `tokenizer_args_from_config`.

    Returns:
        The renderer loaded from the global `RENDERER_REGISTRY`.

    Raises:
        ValueError: If no renderer is registered for the selected mode.
    """
    model_config = config.model_config

    # The positional tokenizer args returned here are not needed by renderers.
    tokenizer_mode, tokenizer_name, _args, kwargs = tokenizer_args_from_config(
        model_config, **kwargs
    )

    # Override tokenizer_mode for Kimi-Audio models
    if model_config.architecture == "MoonshotKimiaForCausalLM":
        tokenizer_mode = "kimi_audio"
        # Update model_config so other components (e.g., multimodal registry)
        # also use the correct tokenizer mode
        model_config.tokenizer_mode = "kimi_audio"

    # Terratorch models get their own renderer, but only when the user left
    # the tokenizer mode at "auto"; otherwise the tokenizer mode selects it.
    if (
        model_config.tokenizer_mode == "auto"
        and model_config.model_impl == "terratorch"
    ):
        renderer_mode = "terratorch"
    else:
        renderer_mode = tokenizer_mode

    return RENDERER_REGISTRY.load_renderer(
        renderer_mode,
        config,
        tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
    )