registry.py 2.67 KB
Newer Older
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
4
from typing import TYPE_CHECKING
5
6

from vllm.logger import init_logger
7
8
9
10
11
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.registry import (
    cached_tokenizer_from_config,
    tokenizer_args_from_config,
)
12
13
from vllm.utils.import_utils import resolve_obj_by_qualname

14
from .base import BaseRenderer
15
16

if TYPE_CHECKING:
17
    from vllm.config import VllmConfig
18
19
20
21
22
23
24
25

# Module-level logger; RendererRegistry uses it for re-registration warnings
# and for the one-time debug message emitted when a renderer class is loaded.
logger = init_logger(__name__)


# Built-in renderers: renderer mode -> (module name relative to the
# ``vllm.renderers`` package, renderer class name). Several modes reuse the
# generic ``HfRenderer`` from the ``hf`` module.
_VLLM_RENDERERS = {
    "deepseek_v32": ("deepseek_v32", "DeepseekV32Renderer"),
    "hf": ("hf", "HfRenderer"),
    "grok2": ("grok2", "Grok2Renderer"),
    "kimi_audio": ("hf", "HfRenderer"),
    "mistral": ("mistral", "MistralRenderer"),
    "qwen_vl": ("hf", "HfRenderer"),
    "terratorch": ("terratorch", "TerratorchRenderer"),
}


@dataclass
class RendererRegistry:
    """Registry mapping a renderer mode to a lazily imported renderer class.

    Entries are stored as ``(module, class_name)`` strings; the class is only
    imported (via ``resolve_obj_by_qualname``) when it is requested.
    """

    # Renderer mode -> (renderer module, renderer class)
    renderers: dict[str, tuple[str, str]] = field(default_factory=dict)

    def register(self, renderer_mode: str, module: str, class_name: str) -> None:
        """Register ``module.class_name`` as the renderer for ``renderer_mode``.

        If ``renderer_mode`` is already registered, the previous entry is
        replaced and a warning naming both the old and new renderer is logged.
        """
        if renderer_mode in self.renderers:
            prev_module, prev_class_name = self.renderers[renderer_mode]
            # Report the entry being replaced, not the incoming one.
            logger.warning(
                "%s.%s is already registered for renderer_mode=%r. "
                "It is overwritten by %s.%s.",
                prev_module,
                prev_class_name,
                renderer_mode,
                module,
                class_name,
            )

        self.renderers[renderer_mode] = (module, class_name)

    def load_renderer_cls(self, renderer_mode: str) -> type[BaseRenderer]:
        """Import and return the renderer class registered for ``renderer_mode``.

        Raises:
            ValueError: If no renderer is registered for ``renderer_mode``.
        """
        if renderer_mode not in self.renderers:
            raise ValueError(f"No renderer registered for {renderer_mode=!r}.")

        module, class_name = self.renderers[renderer_mode]
        logger.debug_once(f"Loading {class_name} for {renderer_mode=!r}")

        return resolve_obj_by_qualname(f"{module}.{class_name}")

    def load_renderer(
        self,
        renderer_mode: str,
        config: "VllmConfig",
        tokenizer: TokenizerLike | None,
    ) -> BaseRenderer:
        """Instantiate the renderer for ``renderer_mode``.

        The resolved class is constructed positionally as
        ``renderer_cls(config, tokenizer)``.

        Raises:
            ValueError: If no renderer is registered for ``renderer_mode``.
        """
        renderer_cls = self.load_renderer_cls(renderer_mode)
        return renderer_cls(config, tokenizer)


# Qualify each built-in module name under the ``vllm.renderers`` package.
RENDERER_REGISTRY = RendererRegistry(
    renderers={
        renderer_mode: ("vllm.renderers." + module_relname, class_name)
        for renderer_mode, (module_relname, class_name) in _VLLM_RENDERERS.items()
    }
)
"""The global `RendererRegistry` instance."""


def renderer_from_config(config: "VllmConfig", **kwargs) -> BaseRenderer:
    """Build the renderer for ``config`` using the global registry.

    The tokenizer comes from ``cached_tokenizer_from_config``. The renderer
    mode is the first value returned by ``tokenizer_args_from_config``.
    ``kwargs`` are forwarded unchanged to both helpers.

    Raises:
        ValueError: If no renderer is registered for the resolved mode.
    """
    model_config = config.model_config

    tokenizer = cached_tokenizer_from_config(model_config, **kwargs)
    # Only the renderer mode is needed here; the remaining tokenizer args
    # are consumed by the tokenizer helpers themselves.
    renderer_mode, *_ = tokenizer_args_from_config(model_config, **kwargs)

    return RENDERER_REGISTRY.load_renderer(renderer_mode, config, tokenizer)