# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from pathlib import Path
5
from typing import Any
6
7
8

from transformers import BatchEncoding

9
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
10

11
12
13
from .deepseek_v32_encoding import encode_messages
from .hf import CachedHfTokenizer
from .protocol import TokenizerLike
14
15


16
class DeepseekV32Tokenizer(CachedHfTokenizer):
    """Tokenizer wrapper whose chat prompts are rendered by ``encode_messages``.

    Every tokenizer operation (encoding, decoding, vocabulary lookups, special
    token accessors) is delegated to the wrapped ``tokenizer``. Only
    ``apply_chat_template`` carries logic of its own: it builds the prompt
    text with ``encode_messages`` and optionally tokenizes it.
    """

    @classmethod
    def from_pretrained(
        cls,
        path_or_repo_id: str | Path,
        *args,
        trust_remote_code: bool = False,
        revision: str | None = None,
        download_dir: str | None = None,
        **kwargs,
    ) -> "TokenizerLike":
        """Load the underlying tokenizer and wrap it in this class.

        All arguments are forwarded unchanged to the parent class's
        ``from_pretrained``.

        Returns:
            An instance of ``cls`` wrapping the tokenizer produced by the
            parent loader.
        """
        wrapped = super().from_pretrained(
            path_or_repo_id,
            *args,
            trust_remote_code=trust_remote_code,
            revision=revision,
            download_dir=download_dir,
            **kwargs,
        )
        # Construct through ``cls`` so that a subclass calling this alternate
        # constructor receives an instance of its own type.
        return cls(wrapped)

    def __init__(self, tokenizer: TokenizerLike) -> None:
        """Wrap ``tokenizer`` and take a snapshot of its added vocabulary."""
        super().__init__()

        self.tokenizer = tokenizer
        self.name_or_path = getattr(tokenizer, "name_or_path", "")

        # Queried once here; ``__len__`` and ``get_added_vocab`` rely on this
        # snapshot instead of asking the wrapped tokenizer again.
        self._added_vocab = self.tokenizer.get_added_vocab()
        self._added_vocab_size = len(self._added_vocab)

    def apply_chat_template(
        self,
        messages: list["ChatCompletionMessageParam"],
        tools: list[dict[str, Any]] | None = None,
        **kwargs,
    ) -> str | list[int]:
        """Render a conversation into a prompt string or a list of token ids.

        Args:
            messages: Conversation to render. Ignored when a ``conversation``
                keyword argument is supplied, which then takes its place.
            tools: Optional tool definitions. When non-empty, a
                ``{"role": "system", "tools": tools}`` message is placed in
                front of the conversation.
            **kwargs: Recognized keys are ``thinking`` (default ``False``),
                ``conversation``, ``tokenize`` (default ``True``),
                ``truncation`` and ``max_length``. Other keys are ignored.

        Returns:
            Token ids when ``tokenize`` is truthy, otherwise the prompt
            string.

        Raises:
            IndexError: If the conversation is empty and no tools are given,
                because the role of the last message is inspected.
        """
        thinking_mode = "thinking" if kwargs.get("thinking", False) else "chat"

        conversation = kwargs.get("conversation", messages)
        # Shallow copy: inserting the tools message below must not alter the
        # list object owned by the caller.
        messages = conversation.copy()

        if tools:
            messages.insert(0, {"role": "system"})
            messages[0]["tools"] = tools  # type: ignore[typeddict-unknown-key]

        # Historical reasoning content is dropped when a new user message is introduced
        drop_thinking = messages[-1]["role"] == "user"

        encode_config = dict(thinking_mode=thinking_mode, drop_thinking=drop_thinking)
        prompt_str = encode_messages(messages, **encode_config)  # type: ignore

        if not kwargs.get("tokenize", True):
            return prompt_str

        # Only these two options are passed on to ``encode``.
        tokenizer_kwargs = {
            k: kwargs[k] for k in ("truncation", "max_length") if k in kwargs
        }
        return self.encode(
            prompt_str,
            add_special_tokens=False,
            **tokenizer_kwargs,
        )

    def num_special_tokens_to_add(self) -> int:
        """Return the number of tokens produced by encoding an empty string."""
        return len(self.encode(""))

    @property
    def all_special_tokens(self) -> list[str]:
        """``all_special_tokens`` of the wrapped tokenizer."""
        return self.tokenizer.all_special_tokens

    @property
    def all_special_ids(self) -> list[int]:
        """``all_special_ids`` of the wrapped tokenizer."""
        return self.tokenizer.all_special_ids

    @property
    def bos_token_id(self) -> int:
        """``bos_token_id`` of the wrapped tokenizer."""
        return self.tokenizer.bos_token_id

    @property
    def eos_token_id(self) -> int:
        """``eos_token_id`` of the wrapped tokenizer."""
        return self.tokenizer.eos_token_id

    @property
    def pad_token_id(self) -> int:
        """``pad_token_id`` of the wrapped tokenizer."""
        return self.tokenizer.pad_token_id

    @property
    def is_fast(self) -> bool:
        """``is_fast`` of the wrapped tokenizer."""
        return self.tokenizer.is_fast

    @property
    def vocab_size(self) -> int:
        """``vocab_size`` of the wrapped tokenizer."""
        return self.tokenizer.vocab_size

    @property
    def max_token_id(self) -> int:
        """``max_token_id`` of the wrapped tokenizer."""
        return self.tokenizer.max_token_id

    @property
    def truncation_side(self) -> str:
        """``truncation_side`` of the wrapped tokenizer."""
        return self.tokenizer.truncation_side

    def __hash__(self) -> int:
        """Hash by object identity."""
        return hash(id(self))

    def __len__(self) -> int:
        """Return ``vocab_size`` plus the number of added-vocabulary entries."""
        # </think> is an added token in DeepseekV32 tokenizer
        return self.vocab_size + self._added_vocab_size

    def __call__(
        self,
        text: str | list[str],
        text_pair: str | None = None,
        add_special_tokens: bool = True,
        truncation: bool = False,
        max_length: int | None = None,
    ) -> "BatchEncoding":
        """Call the wrapped tokenizer with the given arguments."""
        return self.tokenizer(
            text,
            text_pair=text_pair,
            add_special_tokens=add_special_tokens,
            truncation=truncation,
            max_length=max_length,
        )

    def get_vocab(self) -> dict[str, int]:
        """Return ``get_vocab()`` of the wrapped tokenizer."""
        return self.tokenizer.get_vocab()

    def get_added_vocab(self) -> dict[str, int]:
        """Return a copy of the added vocabulary captured at construction."""
        return self._added_vocab.copy()

    def encode(
        self,
        text: str,
        truncation: bool | None = None,
        max_length: int | None = None,
        add_special_tokens: bool = True,
    ) -> list[int]:
        """Encode ``text`` into token ids using the wrapped tokenizer."""
        return self.tokenizer.encode(
            text,
            truncation=truncation,
            max_length=max_length,
            add_special_tokens=add_special_tokens,
        )

    def convert_tokens_to_string(self, tokens: list[str]) -> str:
        """Join ``tokens`` into a string using the wrapped tokenizer."""
        return self.tokenizer.convert_tokens_to_string(tokens)

    def decode(self, ids: list[int] | int, skip_special_tokens: bool = False) -> str:
        """Decode ``ids`` into text using the wrapped tokenizer."""
        return self.tokenizer.decode(ids, skip_special_tokens=skip_special_tokens)

    def convert_ids_to_tokens(
        self,
        ids: list[int],
        skip_special_tokens: bool = False,
    ) -> list[str]:
        """Map ``ids`` to token strings using the wrapped tokenizer."""
        return self.tokenizer.convert_ids_to_tokens(
            ids, skip_special_tokens=skip_special_tokens
        )