# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any

from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ConversationMessage,
    parse_chat_messages,
    parse_chat_messages_async,
)
12
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
13
14
15
16
from vllm.logger import init_logger
from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.grok2 import Grok2Tokenizer

17
from .params import ChatParams
18
from .protocol import BaseRenderer
19
20
21
22

logger = init_logger(__name__)


23
class Grok2Renderer(BaseRenderer):
    """Renderer that builds chat prompts using a `Grok2Tokenizer`.

    The tokenizer is loaded once at construction time (unless
    `config.skip_tokenizer_init` is set) and is then used to apply the chat
    template to parsed chat messages.
    """

    @classmethod
    def from_config(
        cls,
        config: ModelConfig,
        tokenizer_kwargs: dict[str, Any],
    ) -> "BaseRenderer":
        """Alternate constructor: build a renderer from a model config.

        Args:
            config: Model configuration forwarded to `__init__`.
            tokenizer_kwargs: Keyword arguments forwarded to `__init__`, which
                passes them on to `cached_get_tokenizer`.

        Returns:
            A new instance of `cls`.
        """
        return cls(config, tokenizer_kwargs)

    def __init__(
        self,
        config: ModelConfig,
        tokenizer_kwargs: dict[str, Any],
    ) -> None:
        """Initialize the renderer and (optionally) load the tokenizer.

        Args:
            config: Model configuration; `config.skip_tokenizer_init` decides
                whether a tokenizer is loaded at all.
            tokenizer_kwargs: Extra keyword arguments passed to
                `cached_get_tokenizer` alongside `tokenizer_cls`.
        """
        super().__init__(config)

        if config.skip_tokenizer_init:
            # No tokenizer is created; `get_tokenizer()` will raise later.
            tokenizer = None
        else:
            tokenizer = cached_get_tokenizer(
                tokenizer_cls=Grok2Tokenizer,
                **tokenizer_kwargs,
            )

        self._tokenizer = tokenizer

    @property
    def tokenizer(self) -> Grok2Tokenizer | None:
        """The loaded tokenizer, or `None` if tokenizer init was skipped."""
        return self._tokenizer

    def get_tokenizer(self) -> Grok2Tokenizer:
        """Return the tokenizer, failing loudly if none was loaded.

        Raises:
            ValueError: If the renderer was built with
                `skip_tokenizer_init=True`, so no tokenizer exists.
        """
        tokenizer = self.tokenizer
        if tokenizer is None:
            raise ValueError("Tokenizer not available when `skip_tokenizer_init=True`")

        return tokenizer

    def _render_parsed_chat(
        self,
        tokenizer: Grok2Tokenizer,
        messages: list[ChatCompletionMessageParam],
        conversation: list[ConversationMessage],
        mm_data: Any,
        mm_uuids: Any,
        params: ChatParams,
    ) -> TextPrompt | TokensPrompt | EmbedsPrompt:
        """Turn already-parsed chat messages into a prompt.

        Shared tail of `render_messages` and `render_messages_async`, so the
        sync and async paths cannot drift apart.

        Args:
            tokenizer: Tokenizer whose chat template is applied.
            messages: The original, unparsed chat messages.
            conversation: Parsed conversation from the chat-message parser.
            mm_data: Multimodal data from the parser, or `None`.
            mm_uuids: Multimodal UUIDs from the parser, or `None`.
            params: Chat parameters supplying the chat-template kwargs.

        Returns:
            The prompt produced by `self.render_completion`, with
            `"multi_modal_data"` / `"multi_modal_uuids"` set when present.
        """
        prompt_raw = tokenizer.apply_chat_template(
            conversation=conversation,
            messages=messages,
            **params.get_apply_chat_template_kwargs(),
        )

        # `render_completion` is not defined in this class; presumably it is
        # provided by `BaseRenderer`.
        prompt = self.render_completion(prompt_raw)

        # Only attach multimodal fields when the parser actually produced them.
        if mm_data is not None:
            prompt["multi_modal_data"] = mm_data
        if mm_uuids is not None:
            prompt["multi_modal_uuids"] = mm_uuids

        return prompt

    def render_messages(
        self,
        messages: list[ChatCompletionMessageParam],
        params: ChatParams,
    ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
        """Parse chat messages and render them into a prompt (synchronous).

        Args:
            messages: Chat messages to render.
            params: Chat parameters supplying the chat-template kwargs.

        Returns:
            A `(conversation, prompt)` tuple: the parsed conversation and the
            rendered prompt.

        Raises:
            ValueError: If no tokenizer is available (raised before parsing).
        """
        # Resolve the tokenizer first so a missing tokenizer fails before any
        # message parsing is attempted.
        tokenizer = self.get_tokenizer()
        conversation, mm_data, mm_uuids = parse_chat_messages(
            messages,
            self.config,
            content_format="string",
        )

        prompt = self._render_parsed_chat(
            tokenizer, messages, conversation, mm_data, mm_uuids, params
        )
        return conversation, prompt

    async def render_messages_async(
        self,
        messages: list[ChatCompletionMessageParam],
        params: ChatParams,
    ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
        """Parse chat messages and render them into a prompt (asynchronous).

        Same contract as `render_messages`, except that message parsing is
        awaited via `parse_chat_messages_async`.

        Args:
            messages: Chat messages to render.
            params: Chat parameters supplying the chat-template kwargs.

        Returns:
            A `(conversation, prompt)` tuple: the parsed conversation and the
            rendered prompt.

        Raises:
            ValueError: If no tokenizer is available (raised before parsing).
        """
        tokenizer = self.get_tokenizer()
        conversation, mm_data, mm_uuids = await parse_chat_messages_async(
            messages,
            self.config,
            content_format="string",
        )

        # NOTE(review): the chat template is applied synchronously inside this
        # coroutine, exactly as before; confirm that is acceptable for callers.
        prompt = self._render_parsed_chat(
            tokenizer, messages, conversation, mm_data, mm_uuids, params
        )
        return conversation, prompt