# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any

from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ConversationMessage,
    parse_chat_messages,
    parse_chat_messages_async,
)
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.grok2 import Grok2Tokenizer

from .params import ChatParams
from .protocol import RendererLike

logger = init_logger(__name__)


class Grok2Renderer(RendererLike):
    """Render chat messages into prompts using a `Grok2Tokenizer`.

    The tokenizer is created in `__init__` unless
    `config.skip_tokenizer_init` is set; in that case `tokenizer` is `None`
    and `get_tokenizer()` raises `ValueError`.
    """

    @classmethod
    def from_config(
        cls,
        config: ModelConfig,
        tokenizer_kwargs: dict[str, Any],
    ) -> "RendererLike":
        """Alternate constructor; forwards both arguments to `__init__`."""
        return cls(config, tokenizer_kwargs)

    def __init__(
        self,
        config: ModelConfig,
        tokenizer_kwargs: dict[str, Any],
    ) -> None:
        """Store `config` and load the tokenizer unless initialization is skipped.

        Args:
            config: Model configuration; `skip_tokenizer_init` is read here and
                the object is later passed to the chat-message parsers.
            tokenizer_kwargs: Keyword arguments forwarded to
                `cached_get_tokenizer` alongside `tokenizer_cls=Grok2Tokenizer`.
        """
        super().__init__()

        self.config = config

        if config.skip_tokenizer_init:
            tokenizer = None
        else:
            tokenizer = cached_get_tokenizer(
                tokenizer_cls=Grok2Tokenizer,
                **tokenizer_kwargs,
            )

        self._tokenizer = tokenizer

    @property
    def tokenizer(self) -> Grok2Tokenizer | None:
        """The loaded tokenizer, or `None` when tokenizer init was skipped."""
        return self._tokenizer

    def get_tokenizer(self) -> Grok2Tokenizer:
        """Return the tokenizer, failing loudly if it was never loaded.

        Raises:
            ValueError: If the renderer was built with
                `skip_tokenizer_init=True`.
        """
        tokenizer = self.tokenizer
        if tokenizer is None:
            raise ValueError("Tokenizer not available when `skip_tokenizer_init=True`")

        return tokenizer

    def _render_conversation(
        self,
        tokenizer: Grok2Tokenizer,
        messages: list[ChatCompletionMessageParam],
        conversation: list[ConversationMessage],
        mm_data: Any,
        mm_uuids: Any,
        params: ChatParams,
    ) -> TextPrompt | TokensPrompt | EmbedsPrompt:
        """Apply the chat template and build the prompt (shared sync/async tail).

        Args:
            tokenizer: Tokenizer whose `apply_chat_template` is invoked.
            messages: The original request messages.
            conversation: Parsed conversation produced by the message parser.
            mm_data: Multi-modal data from the parser, or `None`.
            mm_uuids: Multi-modal UUIDs from the parser, or `None`.
            params: Source of the chat-template keyword arguments.

        Returns:
            The prompt from `self.render_completion`, with `multi_modal_data`
            and `multi_modal_uuids` set when the corresponding value is not
            `None`.
        """
        # The previous code assigned `kwargs["return_dict"] = False` on an
        # undefined `kwargs` (NameError). Merge into a fresh dict so the
        # mapping returned by `params` is never mutated.
        # NOTE(review): presumably `return_dict=False` keeps the template
        # output a plain string / token list - confirm against the tokenizer.
        template_kwargs = {
            **params.get_apply_chat_template_kwargs(),
            "return_dict": False,
        }

        prompt_raw = tokenizer.apply_chat_template(
            conversation=conversation,
            messages=messages,
            **template_kwargs,
        )

        # `render_completion` is not defined in this class; presumably it is
        # inherited from `RendererLike` - confirm.
        prompt = self.render_completion(prompt_raw)
        if mm_data is not None:
            prompt["multi_modal_data"] = mm_data
        if mm_uuids is not None:
            prompt["multi_modal_uuids"] = mm_uuids

        return prompt

    def render_messages(
        self,
        messages: list[ChatCompletionMessageParam],
        params: ChatParams,
    ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
        """Parse `messages` and render them into a prompt.

        Args:
            messages: Chat messages to render.
            params: Supplies keyword arguments for the chat template.

        Returns:
            A `(conversation, prompt)` tuple.

        Raises:
            ValueError: If no tokenizer is available.
        """
        # Resolve the tokenizer first so a missing tokenizer fails before
        # any message parsing work is done.
        tokenizer = self.get_tokenizer()
        conversation, mm_data, mm_uuids = parse_chat_messages(
            messages,
            self.config,
            content_format="string",
        )

        prompt = self._render_conversation(
            tokenizer,
            messages,
            conversation,
            mm_data,
            mm_uuids,
            params,
        )

        return conversation, prompt

    async def render_messages_async(
        self,
        messages: list[ChatCompletionMessageParam],
        params: ChatParams,
    ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
        """Async variant of `render_messages`; only message parsing is awaited.

        Args:
            messages: Chat messages to render.
            params: Supplies keyword arguments for the chat template.

        Returns:
            A `(conversation, prompt)` tuple.

        Raises:
            ValueError: If no tokenizer is available.
        """
        # Resolve the tokenizer first so a missing tokenizer fails before
        # any message parsing work is done.
        tokenizer = self.get_tokenizer()
        conversation, mm_data, mm_uuids = await parse_chat_messages_async(
            messages,
            self.config,
            content_format="string",
        )

        prompt = self._render_conversation(
            tokenizer,
            messages,
            conversation,
            mm_data,
            mm_uuids,
            params,
        )

        return conversation, prompt