# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any

from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ConversationMessage,
    parse_chat_messages,
    parse_chat_messages_async,
)
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike

from .params import ChatParams
from .protocol import BaseRenderer

logger = init_logger(__name__)


22
class TerratorchRenderer(BaseRenderer):
    """Renderer for Terratorch models, which are served without a tokenizer.

    Construction fails unless ``config.skip_tokenizer_init`` is truthy.
    Chat messages are parsed only to collect the conversation and any
    multi-modal data; the prompt itself is rendered from a single
    placeholder token ID.
    """

    @classmethod
    def from_config(
        cls,
        config: "ModelConfig",
        tokenizer_kwargs: dict[str, Any],
    ) -> "BaseRenderer":
        """Create a renderer from ``config``.

        ``tokenizer_kwargs`` is not read by this renderer, since it never
        builds a tokenizer.
        """
        return cls(config)

    def __init__(self, config: ModelConfig) -> None:
        """Initialise the renderer.

        Raises:
            ValueError: If ``config.skip_tokenizer_init`` is falsy.
        """
        super().__init__(config)

        if not config.skip_tokenizer_init:
            raise ValueError("Terratorch renderer requires `skip_tokenizer_init=True`")

    @property
    def tokenizer(self) -> TokenizerLike | None:
        """Always ``None``: this renderer holds no tokenizer."""
        return None

    def get_tokenizer(self) -> TokenizerLike:
        """Always raise ``ValueError``: no tokenizer is available."""
        raise ValueError("Tokenizer not available for Terratorch renderer")

    def _render_placeholder_prompt(
        self,
        mm_data: Any | None,
        mm_uuids: Any | None,
    ) -> TextPrompt | TokensPrompt | EmbedsPrompt:
        """Render the placeholder prompt and attach multi-modal fields.

        Shared by the sync and async message renderers so both build the
        prompt in exactly the same way.
        """
        prompt = self.render_completion([1])  # Dummy token IDs

        # Only set the keys when the parser actually produced a value.
        if mm_data is not None:
            prompt["multi_modal_data"] = mm_data
        if mm_uuids is not None:
            prompt["multi_modal_uuids"] = mm_uuids

        return prompt

    def render_messages(
        self,
        messages: list[ChatCompletionMessageParam],
        params: ChatParams,
    ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
        """Parse ``messages`` and return ``(conversation, prompt)``.

        Messages are parsed with ``content_format="string"``. ``params`` is
        not read by this implementation.
        """
        conversation, mm_data, mm_uuids = parse_chat_messages(
            messages,
            self.config,
            content_format="string",
        )

        return conversation, self._render_placeholder_prompt(mm_data, mm_uuids)

    async def render_messages_async(
        self,
        messages: list[ChatCompletionMessageParam],
        params: ChatParams,
    ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
        """Async counterpart of ``render_messages``.

        Awaits ``parse_chat_messages_async`` with ``content_format="string"``
        and builds the prompt the same way as the sync path. ``params`` is
        not read by this implementation.
        """
        conversation, mm_data, mm_uuids = await parse_chat_messages_async(
            messages,
            self.config,
            content_format="string",
        )

        return conversation, self._render_placeholder_prompt(mm_data, mm_uuids)