# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any

from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ConversationMessage,
    parse_chat_messages,
    parse_chat_messages_async,
)
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike

from .params import ChatParams
from .protocol import RendererLike

# Module-level logger, created via vLLM's `init_logger` and named after this module.
logger = init_logger(__name__)


class TerratorchRenderer(RendererLike):
    """Renderer for Terratorch models, which run without a tokenizer.

    The renderer never exposes a tokenizer: `tokenizer` is always `None` and
    `get_tokenizer` always raises. Chat messages are parsed for their
    conversation and multi-modal content, and the returned prompt is built
    from a single dummy token ID with the multi-modal payload attached.
    """

    @classmethod
    def from_config(
        cls,
        config: "ModelConfig",
        tokenizer_kwargs: dict[str, Any],
    ) -> "RendererLike":
        """Build a renderer from a model config.

        Args:
            config: Model configuration passed on to the constructor.
            tokenizer_kwargs: Not used by this implementation (no tokenizer
                is ever created).

        Returns:
            A new instance of this renderer class.

        Raises:
            ValueError: If `config.skip_tokenizer_init` is falsy (raised by
                the constructor).
        """
        return cls(config)

    def __init__(self, config: ModelConfig) -> None:
        """Store the model config and validate that tokenizer init is skipped.

        Args:
            config: Model configuration; `skip_tokenizer_init` must be truthy.

        Raises:
            ValueError: If `config.skip_tokenizer_init` is falsy.
        """
        super().__init__()

        self.config = config

        if not config.skip_tokenizer_init:
            raise ValueError("Terratorch renderer requires `skip_tokenizer_init=True`")

    @property
    def tokenizer(self) -> TokenizerLike | None:
        """Always `None`: this renderer has no tokenizer."""
        return None

    def get_tokenizer(self) -> TokenizerLike:
        """Always raise, since no tokenizer is available.

        Raises:
            ValueError: Unconditionally.
        """
        raise ValueError("Tokenizer not available for Terratorch renderer")

    def _build_dummy_prompt(
        self,
        mm_data: Any,
        mm_uuids: Any,
    ) -> TextPrompt | TokensPrompt | EmbedsPrompt:
        """Create the dummy-token prompt and attach multi-modal inputs.

        Shared by the sync and async chat paths so that both assemble the
        prompt in exactly the same way.

        Args:
            mm_data: Multi-modal data from chat parsing, or `None`.
            mm_uuids: Multi-modal UUIDs from chat parsing, or `None`.

        Returns:
            The prompt returned by `render_completion([1])`, with the
            `"multi_modal_data"` / `"multi_modal_uuids"` keys set only for
            the arguments that are not `None`.
        """
        prompt = self.render_completion([1])  # Dummy token IDs

        if mm_data is not None:
            prompt["multi_modal_data"] = mm_data
        if mm_uuids is not None:
            prompt["multi_modal_uuids"] = mm_uuids

        return prompt

    def render_messages(
        self,
        messages: list[ChatCompletionMessageParam],
        params: ChatParams,
    ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
        """Parse chat messages and build the corresponding prompt.

        Args:
            messages: Chat messages to parse.
            params: Not used by this implementation.

        Returns:
            A `(conversation, prompt)` tuple: the parsed conversation and a
            dummy-token prompt carrying any multi-modal data and UUIDs found
            in the messages.
        """
        conversation, mm_data, mm_uuids = parse_chat_messages(
            messages,
            self.config,
            content_format="string",
        )

        return conversation, self._build_dummy_prompt(mm_data, mm_uuids)

    async def render_messages_async(
        self,
        messages: list[ChatCompletionMessageParam],
        params: ChatParams,
    ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
        """Async variant of `render_messages`.

        Identical to the synchronous method except that the messages are
        parsed with `parse_chat_messages_async`.

        Args:
            messages: Chat messages to parse.
            params: Not used by this implementation.

        Returns:
            A `(conversation, prompt)` tuple: the parsed conversation and a
            dummy-token prompt carrying any multi-modal data and UUIDs found
            in the messages.
        """
        conversation, mm_data, mm_uuids = await parse_chat_messages_async(
            messages,
            self.config,
            content_format="string",
        )

        return conversation, self._build_dummy_prompt(mm_data, mm_uuids)