terratorch.py 2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ConversationMessage,
    parse_chat_messages,
    parse_chat_messages_async,
)
from vllm.logger import init_logger

12
from .base import BaseRenderer
13
14
from .inputs import DictPrompt
from .inputs.preprocess import parse_dec_only_prompt
15
from .params import ChatParams
16
17
18
19

# Module-level logger. NOTE(review): it is not referenced by the code shown
# in this file; remove it if it stays unused.
logger = init_logger(__name__)


20
class TerratorchRenderer(BaseRenderer):
    """Renderer that turns chat messages into a multimodal-only prompt.

    Messages are parsed with the ``"string"`` content format. The resulting
    prompt does not carry tokenized text. Instead, it holds a single dummy
    token ID plus any multimodal data and UUIDs extracted from the messages.
    NOTE(review): presumably the model consumes only the multimodal inputs;
    confirm before relying on the text content reaching the model.
    """

    @staticmethod
    def _build_dummy_prompt(mm_data, mm_uuids) -> DictPrompt:
        """Build a decoder-only prompt with dummy token IDs and attach MM inputs.

        Args:
            mm_data: Multimodal data returned by chat-message parsing, or None.
            mm_uuids: Multimodal UUIDs returned by chat-message parsing, or None.

        Returns:
            A fresh prompt dict. ``multi_modal_data`` / ``multi_modal_uuids``
            are set only when the corresponding argument is not None.
        """
        # A new prompt is created on every call, so mutating it below never
        # leaks state between requests.
        prompt = parse_dec_only_prompt([1])  # Dummy token IDs
        if mm_data is not None:
            prompt["multi_modal_data"] = mm_data
        if mm_uuids is not None:
            prompt["multi_modal_uuids"] = mm_uuids
        return prompt

    def render_messages(
        self,
        messages: list[ChatCompletionMessageParam],
        params: ChatParams,
    ) -> tuple[list[ConversationMessage], DictPrompt]:
        """Parse chat messages and build a prompt carrying their multimodal data.

        Args:
            messages: Chat messages to parse (``"string"`` content format).
            params: Chat parameters; only ``media_io_kwargs`` and
                ``mm_processor_kwargs`` are read here.

        Returns:
            A ``(conversation, prompt)`` tuple: the parsed conversation and a
            decoder-only prompt built from a single dummy token ID, with
            multimodal data / UUIDs attached when present.
        """
        conversation, mm_data, mm_uuids = parse_chat_messages(
            messages,
            self.model_config,
            content_format="string",
            media_io_kwargs=params.media_io_kwargs,
            mm_processor_kwargs=params.mm_processor_kwargs,
        )
        return conversation, self._build_dummy_prompt(mm_data, mm_uuids)

    async def render_messages_async(
        self,
        messages: list[ChatCompletionMessageParam],
        params: ChatParams,
    ) -> tuple[list[ConversationMessage], DictPrompt]:
        """Async counterpart of :meth:`render_messages`.

        Uses ``parse_chat_messages_async`` for parsing. Prompt construction is
        shared with the synchronous path through ``_build_dummy_prompt``.

        Args:
            messages: Chat messages to parse (``"string"`` content format).
            params: Chat parameters; only ``media_io_kwargs`` and
                ``mm_processor_kwargs`` are read here.

        Returns:
            A ``(conversation, prompt)`` tuple, identical in shape to
            :meth:`render_messages`.
        """
        conversation, mm_data, mm_uuids = await parse_chat_messages_async(
            messages,
            self.model_config,
            content_format="string",
            media_io_kwargs=params.media_io_kwargs,
            mm_processor_kwargs=params.mm_processor_kwargs,
        )
        return conversation, self._build_dummy_prompt(mm_data, mm_uuids)