# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tokenizer for Grok-2 .tok.json format."""

import functools
import json
from collections.abc import Collection, Set
from pathlib import Path
from typing import Any, Literal, overload

from huggingface_hub import hf_hub_download
from huggingface_hub.utils import (
    EntryNotFoundError,
    HfHubHTTPError,
    RepositoryNotFoundError,
    RevisionNotFoundError,
)
from transformers import BatchEncoding
from transformers.utils import chat_template_utils as hf_chat_utils

from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.logger import init_logger

from .protocol import TokenizerLike

logger = init_logger(__name__)

# Literal texts of the pad, end-of-sequence and separator special tokens.
PAD = "<|pad|>"
EOS = "<|eos|>"
SEP = "<|separator|>"
# "<|reserved_3|>" through "<|reserved_127|>".
RESERVED_TOKEN_TEXTS = [f"<|reserved_{i}|>" for i in range(3, 128)]
# "<|control1|>" through "<|control704|>".
CONTROL_TOKEN_TEXTS = [f"<|control{i}|>" for i in range(1, 705)]
DEFAULT_SPECIAL_TOKENS = [PAD, SEP, EOS]
# Short role name -> special token text.
DEFAULT_CONTROL_TOKENS = {"pad": PAD, "sep": SEP, "eos": EOS}
# Jinja chat template used when no other template is supplied.
# Each user/system/assistant message is rendered as "<Role>: <content>"
# followed by "<|separator|>" and two newlines; user and system content is
# stripped, assistant content is emitted as-is, and messages with any other
# role produce no output. A bare "Assistant:" is appended when
# `add_generation_prompt` is truthy.
DEFAULT_CHAT_TEMPLATE = (
    "{% for message in messages %}"
    "{% if message['role'] == 'user' %}"
    "{{ 'Human: ' + message['content'].strip() + '<|separator|>\\n\\n' }}"
    "{% elif message['role'] == 'system' %}"
    "{{ 'System: ' + message['content'].strip() + '<|separator|>\\n\\n' }}"
    "{% elif message['role'] == 'assistant' %}"
    "{{ 'Assistant: ' + message['content'] + '<|separator|>\\n\\n' }}"
    "{% endif %}"
    "{% endfor %}"
    "{% if add_generation_prompt %}"
    "{{ 'Assistant:' }}"
    "{% endif %}"
)

# Default + separate each single digit.
# Word-splitting regex selected when the vocab file declares
# `word_split == "V1"`; the unquantified `\p{N}` alternative matches one
# numeric character at a time.
PAT_STR_B = (
    r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}|"""
    r""" ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
)


def _maybe_load_tokenizer_config(
    model_path: Path,
    *,
    repo_id: str | None,
    revision: str | None,
    download_dir: str | None,
) -> dict[str, Any]:
    config_path = model_path / "tokenizer_config.json"
    if config_path.is_file():
        with config_path.open("r", encoding="utf-8") as f:
            return json.load(f)

    if repo_id is None:
        return {}

    try:
        config_file = hf_hub_download(
            repo_id=repo_id,
            filename="tokenizer_config.json",
            revision=revision,
            cache_dir=download_dir,
        )
    except (RepositoryNotFoundError, RevisionNotFoundError, EntryNotFoundError):
        # If the repo, revision, or file does not exist, fall back silently.
        return {}
    except HfHubHTTPError as exc:
        logger.warning(
            "Failed to download tokenizer_config.json from %s. "
            "This may be due to a network or authentication issue. "
            "The default chat template will be used. Error: %s",
            repo_id,
            exc,
        )
        return {}

    try:
        with Path(config_file).open("r", encoding="utf-8") as f:
            return json.load(f)
    except json.JSONDecodeError as exc:
        logger.warning(
            "Failed to parse tokenizer_config.json. "
            "The default chat template will be used. Error: %s",
            exc,
        )
        return {}
    except OSError as exc:
        logger.warning(
            "Failed to open tokenizer_config.json. "
            "The default chat template will be used. Error: %s",
            exc,
        )
        return {}


def _load_tiktoken_encoding(
    vocab_file: Path,
) -> tuple[Any, dict[str, int]]:
    """Build a ``tiktoken.Encoding`` from a Grok-2 ``.tok.json`` vocab file.

    The JSON file is read for ``regular_tokens`` and ``special_tokens``
    (lists of ``{"bytes": [...], "token": id}`` items), ``word_split``
    (must be ``"V1"``), and the optional ``pat_str``, ``vocab_size`` and
    ``default_allowed_special`` keys.

    The returned encoding has its ``encode`` replaced by a wrapper that
    always allows the default special tokens (those listed in the file plus
    the pad/sep/eos, control and reserved token texts) and never rejects
    special tokens appearing in the input text.

    Args:
        vocab_file: Path to the ``.tok.json`` vocabulary file.

    Returns:
        ``(encoding, special_tokens)`` where ``special_tokens`` maps each
        special token text to its id.

    Raises:
        ImportError: If ``tiktoken`` is not installed.
        ValueError: If ``word_split`` in the file is not ``"V1"``.
    """
    try:
        import tiktoken
    except ImportError as exc:
        raise ImportError("Grok-2 tokenizer requires the `tiktoken` package.") from exc

    with vocab_file.open("rb") as f:
        xtok_dict = json.load(f)

    mergeable_ranks = {
        bytes(item["bytes"]): item["token"]
        for item in xtok_dict.get("regular_tokens", [])
    }
    special_tokens = {
        bytes(item["bytes"]).decode("utf-8", errors="replace"): item["token"]
        for item in xtok_dict.get("special_tokens", [])
    }

    word_split = xtok_dict.get("word_split")
    if word_split != "V1":
        raise ValueError(f"Unknown word_split: {word_split!r}")
    # An explicit pattern in the file overrides the built-in V1 pattern.
    pat_str = xtok_dict.get("pat_str", PAT_STR_B)

    kwargs = {
        "name": str(vocab_file),
        "pat_str": pat_str,
        "mergeable_ranks": mergeable_ranks,
        "special_tokens": special_tokens,
    }

    if "vocab_size" in xtok_dict:
        kwargs["explicit_n_vocab"] = xtok_dict["vocab_size"]

    tokenizer = tiktoken.Encoding(**kwargs)

    # Special tokens that are always allowed when encoding.
    default_allowed_special: set[str] = {
        bytes(bytes_list).decode("utf-8", errors="replace")
        for bytes_list in xtok_dict.get("default_allowed_special") or []
    }
    default_allowed_special |= set(DEFAULT_CONTROL_TOKENS.values())
    default_allowed_special |= set(CONTROL_TOKEN_TEXTS + RESERVED_TOKEN_TEXTS)

    tokenizer._default_allowed_special = default_allowed_special
    # Copy so that mutating the attribute cannot alter the module constant.
    tokenizer._control_tokens = dict(DEFAULT_CONTROL_TOKENS)

    def encode_patched(
        self,
        text: str,
        *,
        allowed_special: Literal["all"] | Set[str] = frozenset(),
        disallowed_special: Literal["all"] | Collection[str] = "all",
    ) -> list[int]:
        # `disallowed_special` is accepted for signature compatibility only;
        # special tokens in the input are never rejected.
        del disallowed_special
        if isinstance(allowed_special, Set):
            # Build a new set: an in-place `|=` would mutate the caller's
            # set (or a shared default argument).
            allowed_special = set(allowed_special) | self._default_allowed_special
        return tiktoken.Encoding.encode(
            self,
            text,
            allowed_special=allowed_special,
            disallowed_special=(),
        )

    tokenizer.encode = functools.partial(encode_patched, tokenizer)

    return tokenizer, special_tokens


class Grok2Tokenizer(TokenizerLike):
    """Tokenizer for Grok-2 ``tokenizer.tok.json`` vocabularies.

    Encoding and decoding are delegated to a ``tiktoken`` encoding built by
    ``_load_tiktoken_encoding``. String<->id lookup tables are derived from
    the vocabulary by decoding token bytes as UTF-8 with ``errors="replace"``,
    so tokens whose bytes are not valid UTF-8 are represented lossily in
    ``get_vocab`` / ``convert_*`` results.
    """

    @classmethod
    def from_pretrained(
        cls,
        path_or_repo_id: str | Path,
        *args,
        trust_remote_code: bool = False,
        revision: str | None = None,
        download_dir: str | None = None,
        **kwargs,
    ) -> "Grok2Tokenizer":
        """Create a tokenizer from a vocab file, a directory, or a Hub repo.

        Args:
            path_or_repo_id: Path to a vocab file, path to a directory
                containing ``tokenizer.tok.json``, or otherwise a Hub
                repository id to download that file from.
            *args: Ignored (a debug message is logged if any are given).
            trust_remote_code: Accepted but not used by this method.
            revision: Hub revision used for downloads.
            download_dir: Cache directory used for downloads.
            **kwargs: Only ``truncation_side`` is read (default ``"left"``).

        Raises:
            FileNotFoundError: If the resolved vocab file does not exist.
        """
        if args:
            logger.debug_once("Ignoring extra positional args for Grok2Tokenizer.")

        path = Path(path_or_repo_id)
        if path.is_file():
            vocab_file = path
            model_path = path.parent
            repo_id = None
        elif path.is_dir():
            vocab_file = path / "tokenizer.tok.json"
            model_path = path
            repo_id = None
        else:
            # Not a local path: treat the argument as a Hub repository id.
            vocab_file = Path(
                hf_hub_download(
                    repo_id=str(path_or_repo_id),
                    filename="tokenizer.tok.json",
                    revision=revision,
                    cache_dir=download_dir,
                )
            )
            model_path = vocab_file.parent
            repo_id = str(path_or_repo_id)

        if not vocab_file.is_file():
            raise FileNotFoundError(f"tokenizer.tok.json not found at {vocab_file}.")

        config = _maybe_load_tokenizer_config(
            model_path,
            repo_id=repo_id,
            revision=revision,
            download_dir=download_dir,
        )

        return cls(
            vocab_file=vocab_file,
            name_or_path=str(path_or_repo_id),
            truncation_side=kwargs.get("truncation_side", "left"),
            chat_template=config.get("chat_template"),
            init_kwargs=config,
        )

    def __init__(
        self,
        *,
        vocab_file: Path,
        name_or_path: str,
        truncation_side: str,
        chat_template: str | None,
        init_kwargs: dict[str, Any] | None = None,
    ) -> None:
        """Load the vocabulary and derive lookup tables and special ids.

        Args:
            vocab_file: Path to the ``.tok.json`` vocabulary file.
            name_or_path: Identifier stored on the instance as given.
            truncation_side: ``"left"`` keeps the last tokens when
                truncating; any other value keeps the first tokens.
            chat_template: Chat template to use; ``DEFAULT_CHAT_TEMPLATE``
                is used when this is ``None`` or empty.
            init_kwargs: Extra configuration stored on the instance.
        """
        super().__init__()
        self.name_or_path = name_or_path
        self._truncation_side = truncation_side
        self.init_kwargs = init_kwargs or {}
        self._chat_template = chat_template or DEFAULT_CHAT_TEMPLATE

        self._tokenizer, self._special_tokens = _load_tiktoken_encoding(vocab_file)
        # Precomputed for O(1) membership tests when skipping special tokens.
        self._special_token_ids = frozenset(self._special_tokens.values())

        self._token_to_id: dict[str, int] = {}
        self._id_to_token: dict[int, str] = {}
        for token, token_id in self._tokenizer._mergeable_ranks.items():
            token_str = token.decode("utf-8", errors="replace")
            self._token_to_id[token_str] = token_id
            self._id_to_token[token_id] = token_str

        # Special tokens are added last, so they win on any text collision.
        for token, token_id in self._special_tokens.items():
            self._token_to_id[token] = token_id
            self._id_to_token[token_id] = token

        # BOS falls back through SEP -> PAD -> EOS -> 0.
        bos_token_id = 0
        for candidate in (SEP, PAD, EOS):
            candidate_id = self._special_tokens.get(candidate)
            if candidate_id is not None:
                bos_token_id = candidate_id
                break
        self._bos_token_id = bos_token_id

        self._eos_token_id = self._special_tokens.get(EOS, self._bos_token_id)
        self._pad_token_id = self._special_tokens.get(PAD, self._eos_token_id)
        # The pad id doubles as the id reported for unknown token strings.
        self._unk_token_id = self._pad_token_id

        # `default=0` keeps construction working for an empty vocabulary.
        self._max_chars_per_token = max(
            (len(tok) for tok in self._token_to_id), default=0
        )

    def num_special_tokens_to_add(self) -> int:
        """Return 0: ``encode`` never adds special tokens on its own."""
        return 0

    @property
    def all_special_tokens(self) -> list[str]:
        """Texts of all special tokens from the vocab file."""
        return list(self._special_tokens.keys())

    @property
    def all_special_ids(self) -> list[int]:
        """Ids of all special tokens from the vocab file."""
        return list(self._special_tokens.values())

    @property
    def bos_token_id(self) -> int:
        return self._bos_token_id

    @property
    def eos_token_id(self) -> int:
        return self._eos_token_id

    @property
    def pad_token_id(self) -> int:
        return self._pad_token_id

    @property
    def is_fast(self) -> bool:
        return False

    @property
    def vocab_size(self) -> int:
        return self._tokenizer.n_vocab

    @property
    def max_token_id(self) -> int:
        return self._tokenizer.n_vocab - 1

    @property
    def max_chars_per_token(self) -> int:
        """Length of the longest token string in the lookup table."""
        return self._max_chars_per_token

    @property
    def truncation_side(self) -> str:
        return self._truncation_side

    def get_vocab(self) -> dict[str, int]:
        """Return a copy of the token-text -> id table (regular + special)."""
        return dict(self._token_to_id)

    def get_added_vocab(self) -> dict[str, int]:
        """Return a copy of the special-token-text -> id table."""
        return dict(self._special_tokens)

    def _maybe_truncate(self, tokens: list[int], max_length: int | None) -> list[int]:
        """Truncate ``tokens`` to ``max_length`` on the configured side.

        Returns ``tokens`` unchanged when ``max_length`` is ``None`` or the
        list is already short enough.
        """
        if max_length is None or len(tokens) <= max_length:
            return tokens
        if self.truncation_side == "left":
            # Use an explicit start index: `tokens[-max_length:]` would
            # return the whole list when max_length == 0.
            return tokens[len(tokens) - max_length :]
        return tokens[:max_length]

    def encode(
        self,
        text: str,
        truncation: bool | None = None,
        max_length: int | None = None,
        add_special_tokens: bool = True,
    ) -> list[int]:
        """Encode ``text`` to token ids.

        ``max_length`` is applied only when ``truncation`` is truthy.
        ``add_special_tokens`` is accepted but has no effect.
        """
        del add_special_tokens
        tokens = self._tokenizer.encode(text)
        if truncation:
            tokens = self._maybe_truncate(tokens, max_length)
        return tokens

    def decode(self, ids: list[int] | int, skip_special_tokens: bool = False) -> str:
        """Decode one id or a list of ids, optionally dropping special ids."""
        if isinstance(ids, int):
            ids = [ids]
        if skip_special_tokens:
            special_ids = self._special_token_ids
            ids = [token_id for token_id in ids if token_id not in special_ids]
        return self._tokenizer.decode(ids)

    @overload
    def convert_tokens_to_ids(self, tokens: str) -> int: ...

    @overload
    def convert_tokens_to_ids(self, tokens: list[str]) -> list[int]: ...

    def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]:
        """Map token text(s) to id(s); unknown texts map to the unk/pad id."""
        if isinstance(tokens, str):
            return self._token_to_id.get(tokens, self._unk_token_id)
        return [self._token_to_id.get(token, self._unk_token_id) for token in tokens]

    def convert_ids_to_tokens(
        self, ids: list[int], skip_special_tokens: bool = False
    ) -> list[str]:
        """Map ids to token texts; unknown ids become ``"<|unk|>"``."""
        special_ids = self._special_token_ids
        return [
            self._id_to_token.get(token_id, "<|unk|>")
            for token_id in ids
            if not (skip_special_tokens and token_id in special_ids)
        ]

    def convert_tokens_to_string(self, tokens: list[str]) -> str:
        """Join token texts by mapping them back to ids and decoding."""
        token_ids = self.convert_tokens_to_ids(tokens)
        return self.decode(token_ids, skip_special_tokens=False)

    def __call__(
        self,
        text: str | list[str],
        text_pair: str | None = None,
        add_special_tokens: bool = True,
        truncation: bool = False,
        max_length: int | None = None,
    ) -> BatchEncoding:
        """Encode a string or a list of strings into a ``BatchEncoding``.

        The result holds ``input_ids`` and an all-ones ``attention_mask``;
        no padding is applied, so batch rows may differ in length.

        Raises:
            NotImplementedError: If ``text_pair`` is given.
        """
        if text_pair is not None:
            raise NotImplementedError("text_pair is not supported for Grok2Tokenizer.")

        if isinstance(text, list):
            input_ids_batch: list[list[int]] = [
                self.encode(
                    item,
                    truncation=truncation,
                    max_length=max_length,
                    add_special_tokens=add_special_tokens,
                )
                for item in text
            ]
            attention_mask_batch = [[1] * len(ids) for ids in input_ids_batch]
            return BatchEncoding(
                {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch}
            )

        input_ids = self.encode(
            text,
            truncation=truncation,
            max_length=max_length,
            add_special_tokens=add_special_tokens,
        )
        attention_mask = [1] * len(input_ids)
        return BatchEncoding({"input_ids": input_ids, "attention_mask": attention_mask})

    def get_chat_template(
        self, chat_template: str | None, tools: list[dict[str, Any]] | None = None
    ) -> str | None:
        """Return ``chat_template`` if given, else the instance's template.

        ``tools`` is accepted but does not influence the choice.
        """
        del tools
        return chat_template or self._chat_template

    def apply_chat_template(
        self,
        messages: list[ChatCompletionMessageParam],
        tools: list[dict[str, Any]] | None = None,
        chat_template: str | None = None,
        tokenize: bool = False,
        **kwargs,
    ) -> str | list[int]:
        """Render ``messages`` with the chat template.

        Args:
            messages: Conversation to render.
            tools: Tool definitions forwarded to the template renderer.
            chat_template: Template overriding the instance's template.
            tokenize: When true, return token ids instead of the prompt text.
            **kwargs: Forwarded to the renderer; ``return_dict`` is always
                forced to ``False``.

        Raises:
            ValueError: If no chat template is available.
        """
        template = self.get_chat_template(chat_template, tools=tools)
        if template is None:
            raise ValueError(
                "No chat template available. Provide `chat_template` explicitly."
            )
        kwargs["return_dict"] = False
        prompt = hf_chat_utils.apply_chat_template(
            conversation=messages,
            chat_template=template,
            tools=tools,
            **kwargs,
        )
        if tokenize:
            return self.encode(prompt, add_special_tokens=False)
        return prompt