grok2.py 2.41 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ConversationMessage,
    parse_chat_messages,
    parse_chat_messages_async,
)
from vllm.logger import init_logger
from vllm.tokenizers.grok2 import Grok2Tokenizer

13
from .base import BaseRenderer
14
15
from .inputs import DictPrompt
from .inputs.preprocess import parse_dec_only_prompt
16
from .params import ChatParams
17
18
19
20

logger = init_logger(__name__)


21
class Grok2Renderer(BaseRenderer[Grok2Tokenizer]):
    """Chat renderer for Grok-2 models backed by ``Grok2Tokenizer``.

    Both the synchronous and asynchronous entry points parse the incoming
    chat messages (requesting the ``"string"`` content format), render them
    with the tokenizer's chat template, and wrap the result in a prompt dict
    via ``parse_dec_only_prompt``, attaching any multimodal data / UUIDs the
    parser returned.
    """

    def render_messages(
        self,
        messages: list[ChatCompletionMessageParam],
        params: ChatParams,
    ) -> tuple[list[ConversationMessage], DictPrompt]:
        """Render ``messages`` into a parsed conversation and a prompt dict.

        Args:
            messages: Chat messages as received from the request.
            params: Chat parameters; supplies ``media_io_kwargs`` and
                ``mm_processor_kwargs`` for message parsing, and the extra
                keyword arguments forwarded to ``apply_chat_template``.

        Returns:
            A ``(conversation, prompt)`` tuple: the parsed conversation
            returned by ``parse_chat_messages`` and the prompt dict built
            from the rendered chat template.
        """
        # Fetch the tokenizer before parsing (same order as before), so a
        # missing tokenizer fails before any message/media parsing work.
        tokenizer = self.get_tokenizer()
        conversation, mm_data, mm_uuids = parse_chat_messages(
            messages,
            self.model_config,
            content_format="string",
            media_io_kwargs=params.media_io_kwargs,
            mm_processor_kwargs=params.mm_processor_kwargs,
        )
        prompt = self._finalize_chat_prompt(
            tokenizer, messages, conversation, mm_data, mm_uuids, params
        )
        return conversation, prompt

    async def render_messages_async(
        self,
        messages: list[ChatCompletionMessageParam],
        params: ChatParams,
    ) -> tuple[list[ConversationMessage], DictPrompt]:
        """Async counterpart of :meth:`render_messages`.

        Identical to the synchronous path except that messages are parsed
        with ``parse_chat_messages_async``; template rendering itself is
        performed synchronously.
        """
        tokenizer = self.get_tokenizer()
        conversation, mm_data, mm_uuids = await parse_chat_messages_async(
            messages,
            self.model_config,
            content_format="string",
            media_io_kwargs=params.media_io_kwargs,
            mm_processor_kwargs=params.mm_processor_kwargs,
        )
        prompt = self._finalize_chat_prompt(
            tokenizer, messages, conversation, mm_data, mm_uuids, params
        )
        return conversation, prompt

    @staticmethod
    def _finalize_chat_prompt(
        tokenizer: Grok2Tokenizer,
        messages: list[ChatCompletionMessageParam],
        conversation: list[ConversationMessage],
        mm_data,
        mm_uuids,
        params: ChatParams,
    ) -> DictPrompt:
        """Apply the chat template and assemble the final prompt dict.

        Shared by the sync and async entry points so the two paths stay
        identical after message parsing.

        ``mm_data`` / ``mm_uuids`` are the second and third values returned
        by ``parse_chat_messages[_async]``; either may be ``None``.
        """
        # NOTE(review): the tokenizer receives both the parsed conversation
        # and the raw request messages — confirm which one Grok2Tokenizer's
        # template actually consumes.
        prompt_raw = tokenizer.apply_chat_template(
            conversation=conversation,
            messages=messages,
            **params.get_apply_chat_template_kwargs(),
        )

        prompt = parse_dec_only_prompt(prompt_raw)
        # Only attach multimodal fields the parser actually produced, so
        # prompts without multimodal input carry no such keys.
        if mm_data is not None:
            prompt["multi_modal_data"] = mm_data
        if mm_uuids is not None:
            prompt["multi_modal_uuids"] = mm_uuids

        return prompt