terratorch.py 2.33 KB
Newer Older
1
2
3
4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any

5
from vllm.config import VllmConfig
6
7
8
9
10
11
12
13
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ConversationMessage,
    parse_chat_messages,
    parse_chat_messages_async,
)
from vllm.logger import init_logger

14
from .base import BaseRenderer
15
16
from .inputs import DictPrompt
from .inputs.preprocess import parse_dec_only_prompt
17
from .params import ChatParams
18
19
20
21

# Module-level logger for this module (not referenced by any code visible in
# this file).
logger = init_logger(__name__)


22
class TerratorchRenderer(BaseRenderer):
    """Renderer for Terratorch models, which run without a tokenizer.

    Chat messages are parsed only to extract the conversation and any
    multimodal data. The resulting prompt carries a single dummy token ID,
    with the multimodal data/UUIDs attached to it.
    """

    @classmethod
    def from_config(
        cls,
        config: VllmConfig,  # type: ignore[override]
        tokenizer_kwargs: dict[str, Any],
    ) -> "TerratorchRenderer":
        """Build a renderer from ``config``.

        Args:
            config: Engine configuration. ``config.model_config`` must have
                ``skip_tokenizer_init`` enabled.
            tokenizer_kwargs: Unused here because no tokenizer is created.
                Presumably kept so the signature matches the base class
                method this overrides (see the ``type: ignore[override]``).

        Returns:
            A new renderer. The second constructor argument (presumably the
            tokenizer; confirm against ``BaseRenderer``) is ``None``.

        Raises:
            ValueError: If ``skip_tokenizer_init`` is not enabled.
        """
        model_config = config.model_config
        if not model_config.skip_tokenizer_init:
            raise ValueError("Terratorch renderer requires `skip_tokenizer_init=True`")

        return cls(config, None)

    @staticmethod
    def _build_prompt(mm_data: Any, mm_uuids: Any) -> DictPrompt:
        """Wrap parsed multimodal data in a single-dummy-token decoder prompt.

        Shared by the sync and async render paths so they cannot drift apart.
        ``multi_modal_data`` / ``multi_modal_uuids`` are set only when the
        corresponding value is not ``None``.
        """
        prompt = parse_dec_only_prompt([1])  # Dummy token IDs
        if mm_data is not None:
            prompt["multi_modal_data"] = mm_data
        if mm_uuids is not None:
            prompt["multi_modal_uuids"] = mm_uuids
        return prompt

    def render_messages(
        self,
        messages: list[ChatCompletionMessageParam],
        params: ChatParams,
    ) -> tuple[list[ConversationMessage], DictPrompt]:
        """Parse chat ``messages`` and build a prompt from their multimodal data.

        Message content is parsed with ``content_format="string"``. No chat
        template or tokenization is applied here.

        Args:
            messages: Chat messages to parse.
            params: Chat parameters. Only ``params.media_io_kwargs`` is used.

        Returns:
            The parsed conversation, and a prompt holding a dummy token ID plus
            any multimodal data/UUIDs extracted from the messages.
        """
        conversation, mm_data, mm_uuids = parse_chat_messages(
            messages,
            self.model_config,
            content_format="string",
            media_io_kwargs=params.media_io_kwargs,
        )
        return conversation, self._build_prompt(mm_data, mm_uuids)

    async def render_messages_async(
        self,
        messages: list[ChatCompletionMessageParam],
        params: ChatParams,
    ) -> tuple[list[ConversationMessage], DictPrompt]:
        """Async counterpart of :meth:`render_messages`.

        Uses ``parse_chat_messages_async`` for parsing and otherwise builds
        the prompt exactly as the synchronous path does.
        """
        conversation, mm_data, mm_uuids = await parse_chat_messages_async(
            messages,
            self.model_config,
            content_format="string",
            media_io_kwargs=params.media_io_kwargs,
        )
        return conversation, self._build_prompt(mm_data, mm_uuids)