# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm.config import VllmConfig
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ConversationMessage,
    parse_chat_messages,
    parse_chat_messages_async,
)
from vllm.logger import init_logger
from vllm.tokenizers.grok2 import Grok2Tokenizer
from vllm.utils.async_utils import make_async

from .base import BaseRenderer
from .inputs import DictPrompt
from .inputs.preprocess import parse_dec_only_prompt
from .params import ChatParams

# Module-level logger named after this module. It is currently unused in this
# module. NOTE(review): remove it if it stays unused.
logger = init_logger(__name__)


class Grok2Renderer(BaseRenderer[Grok2Tokenizer]):
    """Render chat-completion messages into prompts for Grok-2 models.

    Messages are parsed with ``parse_chat_messages`` (or its async
    counterpart), formatted with the tokenizer's ``apply_chat_template`` and
    turned into a decoder-only ``DictPrompt`` via ``parse_dec_only_prompt``.
    Any multimodal data / UUIDs produced while parsing are attached to the
    resulting prompt.
    """

    def __init__(
        self,
        config: VllmConfig,
        tokenizer: Grok2Tokenizer | None,
    ) -> None:
        """Create the renderer.

        Args:
            config: Engine configuration, forwarded to ``BaseRenderer``.
            tokenizer: The Grok-2 tokenizer, or ``None``. What happens on use
                when it is ``None`` is decided by ``get_tokenizer`` in
                ``BaseRenderer`` (NOTE(review): presumably raises — confirm).
        """
        super().__init__(config, tokenizer)

        # Async variant of the chat-template call, wrapped with make_async on
        # ``self._executor`` (presumably provided by BaseRenderer) so the
        # async rendering path can await the tokenizer call.
        self._apply_chat_template_async = make_async(
            self._apply_chat_template, executor=self._executor
        )

    def _apply_chat_template(self, *args, **kwargs):
        """Forward all arguments to the tokenizer's ``apply_chat_template``."""
        return self.get_tokenizer().apply_chat_template(*args, **kwargs)

    @staticmethod
    def _finalize_prompt(prompt_raw, mm_data, mm_uuids) -> DictPrompt:
        """Build the decoder-only prompt and attach multimodal fields.

        Shared by the sync and async rendering paths so both produce the
        same prompt from the same inputs.

        Args:
            prompt_raw: Output of the chat template, passed to
                ``parse_dec_only_prompt``.
            mm_data: Multimodal data from message parsing; stored under
                ``"multi_modal_data"`` only when it is not ``None``.
            mm_uuids: Multimodal UUIDs from message parsing; stored under
                ``"multi_modal_uuids"`` only when it is not ``None``.

        Returns:
            The parsed ``DictPrompt`` with any multimodal fields attached.
        """
        prompt = parse_dec_only_prompt(prompt_raw)
        if mm_data is not None:
            prompt["multi_modal_data"] = mm_data
        if mm_uuids is not None:
            prompt["multi_modal_uuids"] = mm_uuids
        return prompt

    def render_messages(
        self,
        messages: list[ChatCompletionMessageParam],
        params: ChatParams,
    ) -> tuple[list[ConversationMessage], DictPrompt]:
        """Render ``messages`` into a prompt synchronously.

        Args:
            messages: The chat-completion request messages.
            params: Chat parameters. They supply the media-IO and
                multimodal-processor kwargs used while parsing, and the
                extra kwargs forwarded to ``apply_chat_template``.

        Returns:
            A ``(conversation, prompt)`` tuple: the parsed conversation and
            the decoder-only prompt built from it.
        """
        conversation, mm_data, mm_uuids = parse_chat_messages(
            messages,
            self.model_config,
            # The Grok-2 template is given message content in "string" form.
            content_format="string",
            media_io_kwargs=params.media_io_kwargs,
            mm_processor_kwargs=params.mm_processor_kwargs,
        )

        # Both the parsed conversation and the original request messages are
        # forwarded; the tokenizer's apply_chat_template decides how each is
        # used.
        prompt_raw = self._apply_chat_template(
            conversation=conversation,
            messages=messages,
            **params.get_apply_chat_template_kwargs(),
        )

        prompt = self._finalize_prompt(prompt_raw, mm_data, mm_uuids)
        return conversation, prompt

    async def render_messages_async(
        self,
        messages: list[ChatCompletionMessageParam],
        params: ChatParams,
    ) -> tuple[list[ConversationMessage], DictPrompt]:
        """Async counterpart of ``render_messages``.

        Uses ``parse_chat_messages_async`` and the executor-backed
        ``_apply_chat_template_async``; otherwise identical to the sync path.

        Args:
            messages: The chat-completion request messages.
            params: Chat parameters (see ``render_messages``).

        Returns:
            A ``(conversation, prompt)`` tuple, as in ``render_messages``.
        """
        conversation, mm_data, mm_uuids = await parse_chat_messages_async(
            messages,
            self.model_config,
            # Same string-form content as the synchronous path.
            content_format="string",
            media_io_kwargs=params.media_io_kwargs,
            mm_processor_kwargs=params.mm_processor_kwargs,
        )

        prompt_raw = await self._apply_chat_template_async(
            conversation=conversation,
            messages=messages,
            **params.get_apply_chat_template_kwargs(),
        )

        prompt = self._finalize_prompt(prompt_raw, mm_data, mm_uuids)
        return conversation, prompt