# deepseek_v32.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from pathlib import Path
5
from typing import Any
6
7
8

from transformers import BatchEncoding

9
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
10

11
12
13
from .deepseek_v32_encoding import encode_messages
from .hf import CachedHfTokenizer
from .protocol import TokenizerLike
14
15


16
class DeepseekV32Tokenizer(CachedHfTokenizer):
    """Tokenizer wrapper used for DeepSeek-V3.2 chat prompts.

    Chat prompts are rendered by ``encode_messages`` rather than by a
    template stored on the tokenizer. Every other operation (encoding,
    decoding, vocabulary and special-token queries) is delegated to the
    wrapped tokenizer held in ``self.tokenizer``.
    """

    @classmethod
    def from_pretrained(
        cls,
        path_or_repo_id: str | Path,
        *args,
        trust_remote_code: bool = False,
        revision: str | None = None,
        download_dir: str | None = None,
        **kwargs,
    ) -> "TokenizerLike":
        """Load the underlying tokenizer through the parent class and wrap it.

        All arguments are forwarded unchanged to the parent class's
        ``from_pretrained``; the tokenizer it returns is wrapped in an
        instance of ``cls``.
        """
        tokenizer = super().from_pretrained(
            path_or_repo_id,
            *args,
            trust_remote_code=trust_remote_code,
            revision=revision,
            download_dir=download_dir,
            **kwargs,
        )
        # Instantiate through ``cls`` (not the hard-coded class name) so that
        # subclasses calling ``from_pretrained`` get instances of themselves.
        return cls(tokenizer)

    def __init__(self, tokenizer: TokenizerLike) -> None:
        """Wrap ``tokenizer`` and cache its added vocabulary.

        Args:
            tokenizer: The tokenizer that every delegated call is sent to.
        """
        super().__init__()

        self.tokenizer = tokenizer
        self.name_or_path = getattr(tokenizer, "name_or_path", "")

        # Fetched once here; ``__len__`` and ``get_added_vocab`` reuse it.
        self._added_vocab = self.tokenizer.get_added_vocab()
        self._added_vocab_size = len(self._added_vocab)

    def apply_chat_template(
        self,
        messages: list["ChatCompletionMessageParam"],
        tools: list[dict[str, Any]] | None = None,
        **kwargs,
    ) -> str | list[int]:
        """Render a conversation into a prompt string or its token ids.

        Args:
            messages: Chat messages to render. Ignored when ``kwargs``
                contains a ``conversation`` entry, which takes precedence.
            tools: Optional tool definitions. When non-empty, a system
                message carrying them under the ``"tools"`` key is prepended.
            **kwargs: Recognised keys are

                * ``thinking`` / ``enable_thinking``: if either is truthy the
                  prompt is encoded in ``"thinking"`` mode, otherwise
                  ``"chat"`` mode.
                * ``conversation``: replaces ``messages`` when present.
                * ``tokenize``: defaults to ``True``; when falsy the prompt
                  string is returned instead of token ids.
                * ``truncation`` / ``max_length``: forwarded to ``encode``
                  when tokenizing.

        Returns:
            The result of ``self.encode`` on the rendered prompt when
            tokenizing, otherwise the prompt string.
        """
        thinking = kwargs.get("thinking", False) or kwargs.get(
            "enable_thinking", False
        )
        thinking_mode = "thinking" if thinking else "chat"

        # Shallow-copy so that prepending the tools message below never
        # mutates the list owned by the caller.
        messages = list(kwargs.get("conversation", messages))

        if tools:
            system_message: dict[str, Any] = {"role": "system", "tools": tools}
            messages.insert(0, system_message)  # type: ignore

        # Historical reasoning content is dropped when a new user message is
        # introduced. The emptiness check avoids an IndexError on an empty
        # conversation; the empty list is then passed on to encode_messages.
        drop_thinking = bool(messages) and messages[-1]["role"] == "user"

        encode_config = {
            "thinking_mode": thinking_mode,
            "drop_thinking": drop_thinking,
        }
        prompt_str = encode_messages(messages, **encode_config)  # type: ignore

        if not kwargs.get("tokenize", True):
            return prompt_str

        tokenizer_kwargs = {
            k: kwargs[k] for k in ("truncation", "max_length") if k in kwargs
        }
        return self.encode(
            prompt_str,
            add_special_tokens=False,
            **tokenizer_kwargs,
        )

    def num_special_tokens_to_add(self) -> int:
        """Return the number of tokens produced by encoding an empty string."""
        return len(self.encode(""))

    @property
    def all_special_tokens(self) -> list[str]:
        """Special tokens reported by the wrapped tokenizer."""
        return self.tokenizer.all_special_tokens

    @property
    def all_special_ids(self) -> list[int]:
        """Special token ids reported by the wrapped tokenizer."""
        return self.tokenizer.all_special_ids

    @property
    def bos_token_id(self) -> int:
        """``bos_token_id`` of the wrapped tokenizer."""
        return self.tokenizer.bos_token_id

    @property
    def eos_token_id(self) -> int:
        """``eos_token_id`` of the wrapped tokenizer."""
        return self.tokenizer.eos_token_id

    @property
    def pad_token_id(self) -> int:
        """``pad_token_id`` of the wrapped tokenizer."""
        return self.tokenizer.pad_token_id

    @property
    def is_fast(self) -> bool:
        """``is_fast`` flag of the wrapped tokenizer."""
        return self.tokenizer.is_fast

    @property
    def vocab_size(self) -> int:
        """``vocab_size`` of the wrapped tokenizer."""
        return self.tokenizer.vocab_size

    @property
    def max_token_id(self) -> int:
        """``max_token_id`` of the wrapped tokenizer."""
        return self.tokenizer.max_token_id

    @property
    def truncation_side(self) -> str:
        """``truncation_side`` of the wrapped tokenizer."""
        return self.tokenizer.truncation_side

    def __hash__(self) -> int:
        """Hash by object identity."""
        return hash(id(self))

    def __len__(self) -> int:
        """Return ``vocab_size`` plus the number of added-vocabulary entries."""
        # </think> is an added token in DeepseekV32 tokenizer
        return self.vocab_size + self._added_vocab_size

    def __call__(
        self,
        text: str | list[str],
        text_pair: str | None = None,
        add_special_tokens: bool = True,
        truncation: bool = False,
        max_length: int | None = None,
    ) -> "BatchEncoding":
        """Call the wrapped tokenizer with the given arguments."""
        return self.tokenizer(
            text,
            text_pair=text_pair,
            add_special_tokens=add_special_tokens,
            truncation=truncation,
            max_length=max_length,
        )

    def get_vocab(self) -> dict[str, int]:
        """Return the wrapped tokenizer's vocabulary mapping."""
        return self.tokenizer.get_vocab()

    def get_added_vocab(self) -> dict[str, int]:
        """Return a copy of the added vocabulary cached at construction."""
        return self._added_vocab.copy()

    def encode(
        self,
        text: str,
        truncation: bool | None = None,
        max_length: int | None = None,
        add_special_tokens: bool = True,
    ) -> list[int]:
        """Encode ``text`` with the wrapped tokenizer's ``encode``."""
        return self.tokenizer.encode(
            text,
            truncation=truncation,
            max_length=max_length,
            add_special_tokens=add_special_tokens,
        )

    def convert_tokens_to_string(self, tokens: list[str]) -> str:
        """Delegate to the wrapped tokenizer's ``convert_tokens_to_string``."""
        return self.tokenizer.convert_tokens_to_string(tokens)

    def decode(self, ids: list[int] | int, skip_special_tokens: bool = False) -> str:
        """Decode ``ids`` with the wrapped tokenizer's ``decode``."""
        return self.tokenizer.decode(ids, skip_special_tokens=skip_special_tokens)

    def convert_ids_to_tokens(
        self,
        ids: list[int],
        skip_special_tokens: bool = False,
    ) -> list[str]:
        """Delegate to the wrapped tokenizer's ``convert_ids_to_tokens``."""
        return self.tokenizer.convert_ids_to_tokens(
            ids, skip_special_tokens=skip_special_tokens
        )