registry.py 2.83 KB
Newer Older
1
2
3
4
5
6
7
8
9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any

from vllm.logger import init_logger
from vllm.tokenizers.registry import tokenizer_args_from_config
from vllm.utils.import_utils import resolve_obj_by_qualname

10
from .base import BaseRenderer
11
12

if TYPE_CHECKING:
13
    from vllm.config import VllmConfig
14
15
16
17
18
19
20
21
22

logger = init_logger(__name__)


# Built-in renderers shipped with vLLM.
# Renderer mode -> (module name under `vllm.renderers`, renderer class name)
_VLLM_RENDERERS: dict[str, tuple[str, str]] = {
    "deepseek_v32": ("deepseek_v32", "DeepseekV32Renderer"),
    "hf": ("hf", "HfRenderer"),
    "grok2": ("grok2", "Grok2Renderer"),
    "mistral": ("mistral", "MistralRenderer"),
    "qwen_vl": ("qwen_vl", "QwenVLRenderer"),
    "terratorch": ("terratorch", "TerratorchRenderer"),
}


@dataclass
class RendererRegistry:
    """Registry that maps a renderer mode to a renderer class.

    Entries are stored as ``(module, class_name)`` string pairs; the class
    object itself is only resolved (via `resolve_obj_by_qualname`) when
    `load_renderer_cls` is called for that mode.
    """

    # Renderer mode -> (renderer module, renderer class)
    renderers: dict[str, tuple[str, str]] = field(default_factory=dict)

    def register(self, renderer_mode: str, module: str, class_name: str) -> None:
        """Register the renderer class to use for ``renderer_mode``.

        An existing registration for the same mode is replaced, and a
        warning naming both the old and the new entry is logged.

        Args:
            renderer_mode: Key under which the renderer is looked up.
            module: Qualified name of the module defining the renderer.
            class_name: Name of the renderer class inside ``module``.
        """
        if renderer_mode in self.renderers:
            # Report the entry that is actually being replaced, not the
            # new one, so the log shows what was lost.
            old_module, old_class_name = self.renderers[renderer_mode]
            logger.warning(
                "%s.%s is already registered for renderer_mode=%r. "
                "It is overwritten by %s.%s.",
                old_module,
                old_class_name,
                renderer_mode,
                module,
                class_name,
            )

        self.renderers[renderer_mode] = (module, class_name)

    def load_renderer_cls(self, renderer_mode: str) -> type[BaseRenderer]:
        """Resolve the renderer class registered for ``renderer_mode``.

        Args:
            renderer_mode: Key of a previously registered renderer.

        Returns:
            The object found at ``"<module>.<class_name>"``.

        Raises:
            ValueError: If nothing is registered for ``renderer_mode``.
        """
        if renderer_mode not in self.renderers:
            raise ValueError(f"No renderer registered for {renderer_mode=!r}.")

        module, class_name = self.renderers[renderer_mode]
        logger.debug_once(f"Loading {class_name} for {renderer_mode=!r}")

        return resolve_obj_by_qualname(f"{module}.{class_name}")

    def load_renderer(
        self,
        renderer_mode: str,
        config: "VllmConfig",
        tokenizer_kwargs: dict[str, Any],
    ) -> BaseRenderer:
        """Build a renderer instance for ``renderer_mode``.

        Args:
            renderer_mode: Key of a previously registered renderer.
            config: Passed through to the renderer's ``from_config``.
            tokenizer_kwargs: Passed through to the renderer's
                ``from_config``.

        Returns:
            Whatever ``from_config`` of the resolved renderer class returns.

        Raises:
            ValueError: If nothing is registered for ``renderer_mode``.
        """
        renderer_cls = self.load_renderer_cls(renderer_mode)
        return renderer_cls.from_config(config, tokenizer_kwargs)


# Seed the registry with the built-in renderers, expanding each relative
# module name into its fully qualified `vllm.renderers.*` path.
RENDERER_REGISTRY = RendererRegistry(
    renderers={
        renderer_mode: (f"vllm.renderers.{module_name}", class_name)
        for renderer_mode, (module_name, class_name) in _VLLM_RENDERERS.items()
    }
)
"""The global `RendererRegistry` instance."""


75
76
def renderer_from_config(config: "VllmConfig", **kwargs) -> "BaseRenderer":
    """Create the renderer that matches ``config``.

    The tokenizer settings are derived from ``config.model_config`` with
    `tokenizer_args_from_config`, and the resulting tokenizer mode is used
    as the renderer mode. The one exception: when the configured tokenizer
    mode is ``"auto"`` and the model implementation is ``"terratorch"``,
    the ``"terratorch"`` renderer is selected instead.

    Args:
        config: The vLLM configuration; only ``model_config`` is read here,
            but the whole object is forwarded to the renderer.
        **kwargs: Forwarded to `tokenizer_args_from_config`.

    Returns:
        The renderer built by `RENDERER_REGISTRY.load_renderer`.

    Raises:
        ValueError: If no renderer is registered for the selected mode.
    """
    model_config = config.model_config
    # The positional tokenizer args are not needed by the renderer.
    tokenizer_mode, tokenizer_name, _, kwargs = tokenizer_args_from_config(
        model_config, **kwargs
    )

    # This checks the *configured* tokenizer mode on the model config, not
    # the `tokenizer_mode` value returned above.
    if (
        model_config.tokenizer_mode == "auto"
        and model_config.model_impl == "terratorch"
    ):
        renderer_mode = "terratorch"
    else:
        renderer_mode = tokenizer_mode

    return RENDERER_REGISTRY.load_renderer(
        renderer_mode,
        config,
        # `tokenizer_name` is placed last so it wins over any same-named key.
        tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
    )