# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tokenizer for Kimi-Audio using TikToken."""

import contextlib
import json
from pathlib import Path
from typing import Any, overload

import pybase64
import tiktoken
from huggingface_hub import hf_hub_download
from transformers import AddedToken, BatchEncoding
from transformers.utils import chat_template_utils as hf_chat_utils

from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.logger import init_logger
from vllm.tokenizers.protocol import TokenizerLike

logger = init_logger(__name__)


def _load_tiktoken_encoding(
    vocab_file: Path, special_tokens: dict[str, int]
) -> tuple[Any, dict[str, int]]:
    """Build a ``tiktoken.Encoding`` from a base64 rank file.

    Each usable line of ``vocab_file`` holds two whitespace-separated
    fields: a base64-encoded token and its integer rank. Blank lines and
    lines with any other number of fields are skipped.

    Args:
        vocab_file: Path to the rank file; also used (as a string) for the
            encoding's name.
        special_tokens: Mapping of special-token text to token id, handed to
            the encoding unchanged.

    Returns:
        A ``(encoding, special_tokens)`` pair, where ``special_tokens`` is
        the very mapping that was passed in.
    """
    ranks: dict[bytes, int] = {}
    with open(vocab_file, encoding="utf-8") as handle:
        for raw_line in handle:
            # ``split()`` ignores surrounding whitespace, so an empty or
            # whitespace-only line simply yields no fields.
            fields = raw_line.split()
            if len(fields) != 2:
                continue
            encoded_token, rank_text = fields
            rank = int(rank_text)
            ranks[pybase64.b64decode(encoded_token)] = rank

    # Pre-tokenization regex, split over two raw literals for line length.
    split_pattern = (
        r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}|"""
        r""" ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
    )
    encoding = tiktoken.Encoding(
        name=str(vocab_file),
        pat_str=split_pattern,
        mergeable_ranks=ranks,
        special_tokens=special_tokens,
    )
    return encoding, special_tokens


class KimiAudioTokenizer(TokenizerLike):
    """TikToken tokenizer for Kimi-Audio.

    Wraps a ``tiktoken.Encoding`` built from a base64 rank file and exposes
    the subset of the Hugging Face tokenizer interface implemented below
    (``encode``/``decode``, token <-> id conversion, ``__call__`` and chat
    template rendering).
    """

    # Kimi-Audio control tokens and their ids. This single table drives both
    # the entries registered in ``added_tokens_decoder`` and the set of
    # special tokens that ``encode`` accepts in input text.
    _KIMIAUDIO_SPECIAL_TOKENS: dict[str, int] = {
        "<|im_media_begin|>": 151661,
        "<|im_media_end|>": 151663,
        "<|im_kimia_text_blank|>": 151666,
        "<|im_msg_end|>": 151645,
        "<|im_kimia_user_msg_start|>": 151670,
        "<|im_kimia_assistant_msg_start|>": 151671,
    }

    @classmethod
    def from_pretrained(
        cls,
        path_or_repo_id: str | Path,
        *args,
        trust_remote_code: bool = False,
        revision: str | None = None,
        download_dir: str | None = None,
        **kwargs,
    ) -> "KimiAudioTokenizer":
        """Create a tokenizer from a local path or a Hugging Face Hub repo.

        Resolution order for ``path_or_repo_id``:

        * an existing file: used directly as the vocab file;
        * an existing directory: ``tiktoken.model`` inside it, falling back
          to ``tokenizer.model``;
        * anything else: treated as a Hub repo id, from which
          ``tiktoken.model`` (then ``tokenizer.model``) is downloaded, plus
          ``tokenizer_config.json`` on a best-effort basis.

        Args:
            path_or_repo_id: Local file, local directory, or Hub repo id.
            *args: Ignored (a debug message is logged if any are given).
            trust_remote_code: Accepted for interface compatibility; unused.
            revision: Hub revision forwarded to ``hf_hub_download``.
            download_dir: Forwarded to ``hf_hub_download`` as ``local_dir``.
            **kwargs: Only ``truncation_side`` is read (default ``"left"``).

        Raises:
            ValueError: Neither vocab file could be downloaded from the Hub.
            FileNotFoundError: The resolved vocab path is not a file.
        """
        if args:
            logger.debug_once("Ignoring extra positional args for KimiAudioTokenizer.")

        path = Path(path_or_repo_id)
        if path.is_file():
            vocab_file = path
        elif path.is_dir():
            vocab_file = path / "tiktoken.model"
            if not vocab_file.is_file():
                vocab_file = path / "tokenizer.model"
        else:
            # Not a local path: download from the Hugging Face Hub.
            repo_id = str(path_or_repo_id)

            # Try the candidate vocab filenames in order of preference.
            # Any download failure just moves on to the next candidate.
            download_error: Exception | None = None
            for filename in ("tiktoken.model", "tokenizer.model"):
                try:
                    vocab_path = hf_hub_download(
                        repo_id=repo_id,
                        filename=filename,
                        revision=revision,
                        local_dir=download_dir,
                    )
                except Exception as exc:
                    download_error = exc
                    continue
                vocab_file = Path(vocab_path)
                break
            else:
                raise ValueError(
                    f"Could not find tiktoken.model or tokenizer.model in {repo_id}"
                ) from download_error

            # Also fetch tokenizer_config.json if available. It is optional
            # (``__init__`` copes with it being absent), so failures here
            # are deliberately ignored.
            with contextlib.suppress(Exception):
                hf_hub_download(
                    repo_id=repo_id,
                    filename="tokenizer_config.json",
                    revision=revision,
                    local_dir=download_dir,
                )

        if not vocab_file.is_file():
            raise FileNotFoundError(f"tiktoken.model not found at {vocab_file}.")

        return cls(
            vocab_file=vocab_file,
            name_or_path=str(path_or_repo_id),
            truncation_side=kwargs.get("truncation_side", "left"),
        )

    def __init__(
        self,
        *,
        vocab_file: Path,
        name_or_path: str,
        truncation_side: str,
    ) -> None:
        """Load the vocabulary and build the token <-> id lookup tables.

        Args:
            vocab_file: Path to the base64 rank file. A sibling
                ``tokenizer_config.json``, when present, supplies special
                tokens through its ``added_tokens_decoder`` section.
            name_or_path: Identifier stored on the instance as
                ``name_or_path``.
            truncation_side: ``"left"`` keeps the last tokens when
                truncating; any other value keeps the first tokens.
        """
        super().__init__()
        self.name_or_path = name_or_path
        self._truncation_side = truncation_side
        self._vocab_file = vocab_file

        # Load special tokens from tokenizer_config.json (optional file).
        special_tokens: dict[str, int] = {}
        tokenizer_config = vocab_file.parent / "tokenizer_config.json"
        if tokenizer_config.is_file():
            with open(tokenizer_config, encoding="utf-8") as f:
                config = json.load(f)
            # Extract special tokens from added_tokens_decoder, which maps
            # a stringified token id to a dict carrying a "content" entry.
            added_tokens = config.get("added_tokens_decoder", {})
            for token_id_str, token_info in added_tokens.items():
                token_id = int(token_id_str)
                content = token_info.get("content", "")
                if content:
                    special_tokens[content] = token_id

        self._tokenizer, self._special_tokens = _load_tiktoken_encoding(
            vocab_file, special_tokens
        )
        # Ids of the config-provided special tokens, cached once so that
        # ``skip_special_tokens`` filtering does not rebuild a set per call.
        self._config_special_ids: frozenset[int] = frozenset(
            self._special_tokens.values()
        )

        # Build token <-> ID mappings from the BPE ranks. Byte sequences
        # that are not valid UTF-8 are decoded with replacement characters,
        # so these string views of the vocabulary are lossy.
        self._token_to_id: dict[str, int] = {}
        self._id_to_token: dict[int, str] = {}
        for token_bytes, token_id in self._tokenizer._mergeable_ranks.items():
            token_str = token_bytes.decode("utf-8", errors="replace")
            self._token_to_id[token_str] = token_id
            self._id_to_token[token_id] = token_str

        # Register the config special tokens in the same maps so that
        # convert_tokens_to_ids / convert_ids_to_tokens resolve them instead
        # of falling back to the unknown token. Existing entries win.
        for token_str, token_id in self._special_tokens.items():
            self._token_to_id.setdefault(token_str, token_id)
            self._id_to_token.setdefault(token_id, token_str)

        # Initialize added_tokens_decoder before adding special tokens
        self._added_tokens_decoder: dict[int, Any] = {}

        # Add Kimi-Audio special tokens
        self._add_kimiaudio_special_tokens()

        # Set default special token IDs (will be updated when special tokens
        # are assigned through the ``added_tokens_decoder`` setter).
        self._bos_token_id = 151643  # Kimi-Audio BOS
        self._eos_token_id = 151644  # Kimi-Audio EOS
        self._pad_token_id = self._eos_token_id
        self._unk_token_id = self._pad_token_id

        self._max_chars_per_token = max(
            (len(tok) for tok in self._token_to_id), default=10
        )

    def _add_kimiaudio_special_tokens(self) -> None:
        """Register the Kimi-Audio control tokens as added tokens.

        Each token from ``_KIMIAUDIO_SPECIAL_TOKENS`` whose id is not yet in
        ``added_tokens_decoder`` is stored there as an ``AddedToken`` and is
        added to the token <-> id maps unless an entry already exists.
        """
        for token_str, token_id in self._KIMIAUDIO_SPECIAL_TOKENS.items():
            # Only add if not already present
            if token_id in self._added_tokens_decoder:
                continue
            self._added_tokens_decoder[token_id] = AddedToken(
                token_str, single_word=True, normalized=False, special=True
            )
            # Also ensure it's in _token_to_id and _id_to_token, without
            # overriding mappings that were registered earlier.
            self._token_to_id.setdefault(token_str, token_id)
            self._id_to_token.setdefault(token_id, token_str)

    def _is_special_id(self, token_id: int) -> bool:
        """Return whether ``token_id`` is a config special or an added token."""
        return (
            token_id in self._config_special_ids
            or token_id in self._added_tokens_decoder
        )

    def num_special_tokens_to_add(self) -> int:
        """Return 0: ``encode`` never inserts extra special tokens."""
        return 0

    @property
    def all_special_tokens(self) -> list[str]:
        """Text of every token in ``added_tokens_decoder``."""
        # Entries are AddedToken objects; convert to plain strings to honor
        # the declared return type (same conversion as get_added_vocab).
        return [str(token) for token in self._added_tokens_decoder.values()]

    @property
    def all_special_ids(self) -> list[int]:
        """Ids of every token in ``added_tokens_decoder``."""
        return list(self._added_tokens_decoder)

    @property
    def bos_token_id(self) -> int:
        return self._bos_token_id

    @property
    def eos_token_id(self) -> int:
        return self._eos_token_id

    @property
    def pad_token_id(self) -> int:
        return self._pad_token_id

    @property
    def is_fast(self) -> bool:
        return False

    @property
    def vocab_size(self) -> int:
        return self._tokenizer.n_vocab

    @property
    def max_token_id(self) -> int:
        return self._tokenizer.n_vocab - 1

    @property
    def max_chars_per_token(self) -> int:
        """Length of the longest token string known at construction time."""
        return self._max_chars_per_token

    @property
    def truncation_side(self) -> str:
        return self._truncation_side

    @property
    def added_tokens_decoder(self) -> dict[int, Any]:
        """Mapping of token id to added-token object (returned by reference)."""
        return self._added_tokens_decoder

    @added_tokens_decoder.setter
    def added_tokens_decoder(self, value: dict[int, Any]) -> None:
        """Set added tokens decoder and update special token IDs.

        A token whose text contains ``<|im_kimia_user_msg_start|>`` becomes
        the BOS id; one containing ``<|im_msg_end|>`` or ``<|im_end|>``
        becomes the EOS id. The pad and unk ids are left untouched.
        """
        self._added_tokens_decoder = value
        # Update special token IDs if known tokens are added
        for token_id, token in value.items():
            token_str = str(token)
            if "<|im_kimia_user_msg_start|>" in token_str:
                self._bos_token_id = token_id
            elif "<|im_msg_end|>" in token_str or "<|im_end|>" in token_str:
                self._eos_token_id = token_id

    def get_vocab(self) -> dict[str, int]:
        """Return a copy of the token-string -> id mapping."""
        return dict(self._token_to_id)

    def __len__(self) -> int:
        """Return vocab size for compatibility with HF tokenizer interface."""
        return self._tokenizer.n_vocab

    def get_added_vocab(self) -> dict[str, int]:
        """Return added-token text -> id, derived from ``added_tokens_decoder``."""
        return {
            str(token): token_id
            for token_id, token in self._added_tokens_decoder.items()
        }

    def _maybe_truncate(self, tokens: list[int], max_length: int | None) -> list[int]:
        """Trim ``tokens`` to at most ``max_length`` items.

        ``None`` disables truncation. With ``truncation_side == "left"`` the
        last ``max_length`` tokens are kept, otherwise the first ones.
        """
        if max_length is None or len(tokens) <= max_length:
            return tokens
        if max_length <= 0:
            # ``tokens[-0:]`` would return the *whole* list, so a
            # non-positive limit is handled explicitly.
            return []
        if self.truncation_side == "left":
            return tokens[-max_length:]
        return tokens[:max_length]

    def encode(
        self,
        text: str,
        truncation: bool | None = None,
        max_length: int | None = None,
        add_special_tokens: bool = True,
        **kwargs,
    ) -> list[int]:
        """Encode ``text`` into token ids.

        Args:
            text: Input text. The Kimi-Audio control tokens may appear in it
                and are encoded as special tokens.
            truncation: When truthy, the result is limited to ``max_length``.
            max_length: Maximum number of ids; only used with ``truncation``.
            add_special_tokens: Accepted for compatibility and ignored.
            **kwargs: Ignored.
        """
        del add_special_tokens
        # Allow Kimi-Audio special tokens to be encoded
        tokens = self._tokenizer.encode(
            text,
            allowed_special=set(self._KIMIAUDIO_SPECIAL_TOKENS),
        )
        if truncation:
            tokens = self._maybe_truncate(tokens, max_length)
        return tokens

    def decode(self, ids: list[int] | int, skip_special_tokens: bool = False) -> str:
        """Decode token IDs to text, optionally skipping special tokens.

        Special tokens are those loaded from ``tokenizer_config.json`` plus
        everything in ``added_tokens_decoder``, the same definition used by
        ``convert_ids_to_tokens``.
        """
        if isinstance(ids, int):
            ids = [ids]
        if skip_special_tokens:
            ids = [token_id for token_id in ids if not self._is_special_id(token_id)]
        return self._tokenizer.decode(ids)

    @overload
    def convert_tokens_to_ids(self, tokens: str) -> int: ...

    @overload
    def convert_tokens_to_ids(self, tokens: list[str]) -> list[int]: ...

    def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]:
        """Map token string(s) to id(s); unknown tokens map to the unk id."""
        if isinstance(tokens, str):
            return self._token_to_id.get(tokens, self._unk_token_id)
        return [self._token_to_id.get(token, self._unk_token_id) for token in tokens]

    def convert_ids_to_tokens(
        self, ids: list[int], skip_special_tokens: bool = False
    ) -> list[str]:
        """Map ids to token strings; unknown ids become ``"<|unk|>"``.

        With ``skip_special_tokens`` the special ids (same definition as in
        ``decode``) are omitted from the result.
        """
        return [
            self._id_to_token.get(token_id, "<|unk|>")
            for token_id in ids
            if not (skip_special_tokens and self._is_special_id(token_id))
        ]

    def convert_tokens_to_string(self, tokens: list[str]) -> str:
        """Join token strings into text by re-encoding to ids and decoding."""
        token_ids = self.convert_tokens_to_ids(tokens)
        return self.decode(token_ids, skip_special_tokens=False)

    def __call__(
        self,
        text: str | list[str],
        text_pair: str | None = None,
        add_special_tokens: bool = True,
        truncation: bool = False,
        max_length: int | None = None,
        **kwargs,
    ) -> BatchEncoding:
        """Tokenize one string or a list of strings.

        Returns:
            A ``BatchEncoding`` with ``input_ids`` and an all-ones
            ``attention_mask``; both are nested lists when ``text`` is a
            list. No padding is applied.

        Raises:
            NotImplementedError: If ``text_pair`` is given.
        """
        if text_pair is not None:
            raise NotImplementedError(
                "text_pair is not supported for KimiAudioTokenizer."
            )

        if isinstance(text, list):
            input_ids_batch: list[list[int]] = [
                self.encode(
                    item,
                    truncation=truncation,
                    max_length=max_length,
                    add_special_tokens=add_special_tokens,
                )
                for item in text
            ]
            attention_mask_batch = [[1] * len(ids) for ids in input_ids_batch]
            return BatchEncoding(
                {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch}
            )

        input_ids = self.encode(
            text,
            truncation=truncation,
            max_length=max_length,
            add_special_tokens=add_special_tokens,
        )
        attention_mask = [1] * len(input_ids)
        return BatchEncoding({"input_ids": input_ids, "attention_mask": attention_mask})

    def get_chat_template(
        self, chat_template: str | None, tools: list[dict[str, Any]] | None = None
    ) -> str | None:
        """Return ``chat_template`` unchanged; no built-in template exists."""
        del tools
        return chat_template

    def apply_chat_template(
        self,
        messages: list[ChatCompletionMessageParam] | None = None,
        tools: list[dict[str, Any]] | None = None,
        chat_template: str | None = None,
        tokenize: bool = False,
        **kwargs,
    ) -> str | list[int]:
        """Render a conversation with a Jinja chat template.

        Args:
            messages: Conversation to render. May instead be supplied as the
                ``conversation`` keyword argument.
            tools: Tool definitions forwarded to the template renderer.
            chat_template: Template source; required.
            tokenize: When true, return token ids instead of the prompt text.
            **kwargs: Forwarded to ``render_jinja_template``.

        Raises:
            ValueError: No conversation or no chat template was provided.
        """
        # Handle both 'messages' (protocol) and 'conversation' (caller)
        # parameter names. Pop the alias so it is not forwarded to the
        # renderer as an extra template variable.
        caller_conversation = kwargs.pop("conversation", None)
        conversation = messages if messages is not None else caller_conversation
        if conversation is None:
            raise ValueError("Either 'messages' or 'conversation' must be provided.")
        template = self.get_chat_template(chat_template, tools=tools)
        if template is None:
            raise ValueError(
                "No chat template available. Provide `chat_template` explicitly."
            )
        # Use render_jinja_template instead of apply_chat_template
        # Note: render_jinja_template returns ([prompts], [generation_indices])
        rendered, _ = hf_chat_utils.render_jinja_template(
            conversation,
            chat_template=template,
            tools=tools,
            **kwargs,
        )
        # Extract the first (and usually only) prompt
        prompt = rendered[0] if rendered else ""
        if tokenize:
            return self.encode(prompt, add_special_tokens=False)
        return prompt