Unverified Commit 42fadebe authored by tunglinwood's avatar tunglinwood Committed by GitHub
Browse files

[Model] Add support for moonshotai/Kimi-Audio-7B-Instruct (#36127)


Signed-off-by: default avatartunglinwood <tunglinwood@gmail.com>
Signed-off-by: default avatartunglinwood <tomwu.tunglin@gmail.com>
Signed-off-by: default avatartunglinwood <113751333+tunglinwood@users.noreply.github.com>
parent a197eda9
...@@ -713,8 +713,9 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen ...@@ -713,8 +713,9 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `KananaVForConditionalGeneration` | Kanana-V | T + I<sup>+</sup> | `kakaocorp/kanana-1.5-v-3b-instruct`, etc. | | ✅︎ | | `KananaVForConditionalGeneration` | Kanana-V | T + I<sup>+</sup> | `kakaocorp/kanana-1.5-v-3b-instruct`, etc. | | ✅︎ |
| `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-8B-Preview` | ✅︎ | ✅︎ | | `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-8B-Preview` | ✅︎ | ✅︎ |
| `KeyeVL1_5ForConditionalGeneration` | Keye-VL-1_5-8B | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-1_5-8B` | ✅︎ | ✅︎ | | `KeyeVL1_5ForConditionalGeneration` | Keye-VL-1_5-8B | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-1_5-8B` | ✅︎ | ✅︎ |
| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I<sup>+</sup> | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | ✅︎ | | `KimiAudioForConditionalGeneration` | Kimi-Audio | T + A<sup>+</sup> | `moonshotai/Kimi-Audio-7B-Instruct` | | ✅︎ |
| `KimiK25ForConditionalGeneration` | Kimi-K2.5 | T + I<sup>+</sup> | `moonshotai/Kimi-K2.5` | | ✅︎ | | `KimiK25ForConditionalGeneration` | Kimi-K2.5 | T + I<sup>+</sup> | `moonshotai/Kimi-K2.5` | | ✅︎ |
| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I<sup>+</sup> | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | ✅︎ |
| `LightOnOCRForConditionalGeneration` | LightOnOCR-1B | T + I<sup>+</sup> | `lightonai/LightOnOCR-1B`, etc | ✅︎ | ✅︎ | | `LightOnOCRForConditionalGeneration` | LightOnOCR-1B | T + I<sup>+</sup> | `lightonai/LightOnOCR-1B`, etc | ✅︎ | ✅︎ |
| `Lfm2VlForConditionalGeneration` | LFM2-VL | T + I<sup>+</sup> | `LiquidAI/LFM2-VL-450M`, `LiquidAI/LFM2-VL-3B`, `LiquidAI/LFM2-VL-8B-A1B`, etc. | ✅︎ | ✅︎ | | `Lfm2VlForConditionalGeneration` | LFM2-VL | T + I<sup>+</sup> | `LiquidAI/LFM2-VL-450M`, `LiquidAI/LFM2-VL-3B`, `LiquidAI/LFM2-VL-8B-A1B`, etc. | ✅︎ | ✅︎ |
| `Llama4ForConditionalGeneration` | Llama 4 | T + I<sup>+</sup> | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | ✅︎ | ✅︎ | | `Llama4ForConditionalGeneration` | Llama 4 | T + I<sup>+</sup> | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | ✅︎ | ✅︎ |
......
...@@ -201,6 +201,34 @@ def run_granite_speech(question: str, audio_count: int) -> ModelRequestData: ...@@ -201,6 +201,34 @@ def run_granite_speech(question: str, audio_count: int) -> ModelRequestData:
) )
# Kimi-Audio-7B-Instruct
def run_kimi_audio(question: str, audio_count: int) -> ModelRequestData:
"""Kimi-Audio-7B-Instruct for audio transcription and understanding."""
model_name = "moonshotai/Kimi-Audio-7B-Instruct"
engine_args = EngineArgs(
model=model_name,
trust_remote_code=True,
max_model_len=4096,
max_num_seqs=2,
limit_mm_per_prompt={"audio": audio_count},
)
# Kimi-Audio uses <|im_kimia_text_blank|> as placeholder for audio features
audio_placeholder = "<|im_kimia_text_blank|>" * audio_count
# Default prompt for transcription
if not question:
question = "Please transcribe the audio"
prompt = f"{audio_placeholder}{question}"
# Stop at EOS token (151644) to prevent repetition
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
stop_token_ids=[151644],
)
# MiDashengLM # MiDashengLM
def run_midashenglm(question: str, audio_count: int): def run_midashenglm(question: str, audio_count: int):
model_name = "mispeech/midashenglm-7b" model_name = "mispeech/midashenglm-7b"
...@@ -485,6 +513,7 @@ model_example_map = { ...@@ -485,6 +513,7 @@ model_example_map = {
"glmasr": run_glmasr, "glmasr": run_glmasr,
"funaudiochat": run_funaudiochat, "funaudiochat": run_funaudiochat,
"granite_speech": run_granite_speech, "granite_speech": run_granite_speech,
"kimi_audio": run_kimi_audio,
"midashenglm": run_midashenglm, "midashenglm": run_midashenglm,
"minicpmo": run_minicpmo, "minicpmo": run_minicpmo,
"phi4_mm": run_phi4mm, "phi4_mm": run_phi4mm,
......
...@@ -198,8 +198,12 @@ def get_text_token_prompts( ...@@ -198,8 +198,12 @@ def get_text_token_prompts(
mm_counts, mm_counts,
mm_options={}, mm_options={},
) )
# Some models (e.g., Kimi-Audio) return token IDs directly instead of str
if isinstance(inputs.prompt, list):
text_prompt = None
token_prompt = inputs.prompt
else:
assert isinstance(inputs.prompt, str) assert isinstance(inputs.prompt, str)
text_prompt = inputs.prompt text_prompt = inputs.prompt
token_prompt = tokenizer.encode( token_prompt = tokenizer.encode(
text_prompt, text_prompt,
......
...@@ -857,6 +857,15 @@ _MULTIMODAL_EXAMPLE_MODELS = { ...@@ -857,6 +857,15 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"Kwai-Keye/Keye-VL-1_5-8B", "Kwai-Keye/Keye-VL-1_5-8B",
trust_remote_code=True, trust_remote_code=True,
), ),
"MoonshotKimiaForCausalLM": _HfExamplesInfo(
"moonshotai/Kimi-Audio-7B-Instruct",
tokenizer_mode="kimi_audio",
trust_remote_code=True,
),
"KimiK25ForConditionalGeneration": _HfExamplesInfo(
"moonshotai/Kimi-K2.5",
trust_remote_code=True,
),
"KimiVLForConditionalGeneration": _HfExamplesInfo( "KimiVLForConditionalGeneration": _HfExamplesInfo(
"moonshotai/Kimi-VL-A3B-Instruct", "moonshotai/Kimi-VL-A3B-Instruct",
extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"}, extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"},
...@@ -870,10 +879,6 @@ _MULTIMODAL_EXAMPLE_MODELS = { ...@@ -870,10 +879,6 @@ _MULTIMODAL_EXAMPLE_MODELS = {
) )
}, },
), ),
"KimiK25ForConditionalGeneration": _HfExamplesInfo(
"moonshotai/Kimi-K2.5",
trust_remote_code=True,
),
"LightOnOCRForConditionalGeneration": _HfExamplesInfo( "LightOnOCRForConditionalGeneration": _HfExamplesInfo(
"lightonai/LightOnOCR-1B-1025" "lightonai/LightOnOCR-1B-1025"
), ),
......
...@@ -103,6 +103,12 @@ def can_initialize( ...@@ -103,6 +103,12 @@ def can_initialize(
"pickle error when loading `transformers.models.auto.CONFIG_MAPPING`" "pickle error when loading `transformers.models.auto.CONFIG_MAPPING`"
) )
if model_arch == "MoonshotKimiaForCausalLM":
pytest.skip(
"Kimi-Audio requires SpeechToTextConfig "
"which is not configured in test environment"
)
if model_arch in ["DeepseekV32ForCausalLM", "GlmMoeDsaForCausalLM"]: if model_arch in ["DeepseekV32ForCausalLM", "GlmMoeDsaForCausalLM"]:
from vllm.platforms import current_platform from vllm.platforms import current_platform
......
This diff is collapsed.
...@@ -421,6 +421,7 @@ _MULTIMODAL_MODELS = { ...@@ -421,6 +421,7 @@ _MULTIMODAL_MODELS = {
"RForConditionalGeneration": ("rvl", "RForConditionalGeneration"), "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
"KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"), # noqa: E501 "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"), # noqa: E501
"KimiK25ForConditionalGeneration": ("kimi_k25", "KimiK25ForConditionalGeneration"), # noqa: E501 "KimiK25ForConditionalGeneration": ("kimi_k25", "KimiK25ForConditionalGeneration"), # noqa: E501
"MoonshotKimiaForCausalLM": ("kimi_audio", "KimiAudioForConditionalGeneration"), # noqa: E501
"LightOnOCRForConditionalGeneration": ( "LightOnOCRForConditionalGeneration": (
"lightonocr", "lightonocr",
"LightOnOCRForConditionalGeneration", "LightOnOCRForConditionalGeneration",
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, cast
from vllm.config import VllmConfig
from vllm.tokenizers.kimi_audio import KimiAudioTokenizer
from vllm.tokenizers.registry import get_tokenizer
from .hf import HfRenderer, HfTokenizer
class KimiAudioRenderer(HfRenderer):
"""Renderer for Kimi-Audio models.
This renderer uses HfRenderer internally with a custom TikToken tokenizer.
"""
@classmethod
def from_config( # type: ignore[override]
cls,
config: VllmConfig,
tokenizer_kwargs: dict[str, Any],
) -> "HfRenderer":
"""Create an HfRenderer instance for Kimi-Audio models."""
model_config = config.model_config
if model_config.skip_tokenizer_init:
tokenizer = None
else:
# Extract tokenizer_name from kwargs (already processed by
# tokenizer_args_from_config for ModelScope/GGUF/etc)
tokenizer_name = tokenizer_kwargs.pop(
"tokenizer_name", model_config.tokenizer
)
# Remove tokenizer_cls from kwargs to avoid duplicate argument
tokenizer_kwargs = {
k: v for k, v in tokenizer_kwargs.items() if k != "tokenizer_cls"
}
# Use get_tokenizer directly instead of cached_get_tokenizer
# (KimiAudioTokenizer doesn't work with get_cached_tokenizer)
tokenizer = cast(
HfTokenizer,
get_tokenizer(
tokenizer_name,
tokenizer_cls=KimiAudioTokenizer, # type: ignore[arg-type]
**tokenizer_kwargs,
),
)
return HfRenderer(config, tokenizer)
...@@ -19,6 +19,7 @@ _VLLM_RENDERERS = { ...@@ -19,6 +19,7 @@ _VLLM_RENDERERS = {
"deepseek_v32": ("deepseek_v32", "DeepseekV32Renderer"), "deepseek_v32": ("deepseek_v32", "DeepseekV32Renderer"),
"hf": ("hf", "HfRenderer"), "hf": ("hf", "HfRenderer"),
"grok2": ("grok2", "Grok2Renderer"), "grok2": ("grok2", "Grok2Renderer"),
"kimi_audio": ("kimi_audio", "KimiAudioRenderer"),
"mistral": ("mistral", "MistralRenderer"), "mistral": ("mistral", "MistralRenderer"),
"qwen_vl": ("qwen_vl", "QwenVLRenderer"), "qwen_vl": ("qwen_vl", "QwenVLRenderer"),
"terratorch": ("terratorch", "TerratorchRenderer"), "terratorch": ("terratorch", "TerratorchRenderer"),
...@@ -74,10 +75,18 @@ RENDERER_REGISTRY = RendererRegistry( ...@@ -74,10 +75,18 @@ RENDERER_REGISTRY = RendererRegistry(
def renderer_from_config(config: "VllmConfig", **kwargs): def renderer_from_config(config: "VllmConfig", **kwargs):
model_config = config.model_config model_config = config.model_config
tokenizer_mode, tokenizer_name, args, kwargs = tokenizer_args_from_config( tokenizer_mode, tokenizer_name, args, kwargs = tokenizer_args_from_config(
model_config, **kwargs model_config, **kwargs
) )
# Override tokenizer_mode for Kimi-Audio models
if model_config.architecture == "MoonshotKimiaForCausalLM":
tokenizer_mode = "kimi_audio"
# Update model_config so other components (e.g., multimodal registry)
# also use the correct tokenizer mode
model_config.tokenizer_mode = "kimi_audio"
if ( if (
model_config.tokenizer_mode == "auto" model_config.tokenizer_mode == "auto"
and model_config.model_impl == "terratorch" and model_config.model_impl == "terratorch"
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tokenizer for Kimi-Audio using TikToken."""
import contextlib
import json
from pathlib import Path
from typing import Any, overload
import pybase64
import tiktoken
from huggingface_hub import hf_hub_download
from transformers import AddedToken, BatchEncoding
from transformers.utils import chat_template_utils as hf_chat_utils
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.logger import init_logger
from vllm.tokenizers.protocol import TokenizerLike
logger = init_logger(__name__)
def _load_tiktoken_encoding(
vocab_file: Path, special_tokens: dict[str, int]
) -> tuple[Any, dict[str, int]]:
"""Load TikToken encoding from vocab file."""
mergeable_ranks: dict[bytes, int] = {}
with open(vocab_file, encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
parts = line.split()
if len(parts) == 2:
token_b64 = parts[0]
rank = int(parts[1])
token_bytes = pybase64.b64decode(token_b64)
mergeable_ranks[token_bytes] = rank
tokenizer = tiktoken.Encoding(
name=str(vocab_file),
pat_str=r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}|"""
r""" ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""",
mergeable_ranks=mergeable_ranks,
special_tokens=special_tokens,
)
return tokenizer, special_tokens
class KimiAudioTokenizer(TokenizerLike):
"""TikToken tokenizer for Kimi-Audio."""
@classmethod
def from_pretrained(
cls,
path_or_repo_id: str | Path,
*args,
trust_remote_code: bool = False,
revision: str | None = None,
download_dir: str | None = None,
**kwargs,
) -> "KimiAudioTokenizer":
if args:
logger.debug_once("Ignoring extra positional args for KimiAudioTokenizer.")
path = Path(path_or_repo_id)
if path.is_file():
vocab_file = path
elif path.is_dir():
vocab_file = path / "tiktoken.model"
if not vocab_file.is_file():
vocab_file = path / "tokenizer.model"
else:
# Download from HuggingFace Hub
repo_id = str(path_or_repo_id)
# Try to download tiktoken.model or tokenizer.model
try:
vocab_path = hf_hub_download(
repo_id=repo_id,
filename="tiktoken.model",
revision=revision,
local_dir=download_dir,
)
vocab_file = Path(vocab_path)
except Exception:
try:
vocab_path = hf_hub_download(
repo_id=repo_id,
filename="tokenizer.model",
revision=revision,
local_dir=download_dir,
)
vocab_file = Path(vocab_path)
except Exception as exc:
raise ValueError(
f"Could not find tiktoken.model or tokenizer.model in {repo_id}"
) from exc
# Also download tokenizer_config.json if available
with contextlib.suppress(Exception):
hf_hub_download(
repo_id=repo_id,
filename="tokenizer_config.json",
revision=revision,
local_dir=download_dir,
)
if not vocab_file.is_file():
raise FileNotFoundError(f"tiktoken.model not found at {vocab_file}.")
return cls(
vocab_file=vocab_file,
name_or_path=str(path_or_repo_id),
truncation_side=kwargs.get("truncation_side", "left"),
)
def __init__(
self,
*,
vocab_file: Path,
name_or_path: str,
truncation_side: str,
) -> None:
super().__init__()
self.name_or_path = name_or_path
self._truncation_side = truncation_side
self._vocab_file = vocab_file
# Load special tokens from tokenizer_config.json
special_tokens: dict[str, int] = {}
tokenizer_config = vocab_file.parent / "tokenizer_config.json"
if tokenizer_config.is_file():
with open(tokenizer_config, encoding="utf-8") as f:
config = json.load(f)
# Extract special tokens from added_tokens_decoder
added_tokens = config.get("added_tokens_decoder", {})
for token_id_str, token_info in added_tokens.items():
token_id = int(token_id_str)
content = token_info.get("content", "")
if content:
special_tokens[content] = token_id
self._tokenizer, self._special_tokens = _load_tiktoken_encoding(
vocab_file, special_tokens
)
# Build token <-> ID mappings
self._token_to_id: dict[str, int] = {}
self._id_to_token: dict[int, str] = {}
for token_bytes, token_id in self._tokenizer._mergeable_ranks.items():
token_str = token_bytes.decode("utf-8", errors="replace")
self._token_to_id[token_str] = token_id
self._id_to_token[token_id] = token_str
# Initialize added_tokens_decoder before adding special tokens
self._added_tokens_decoder: dict[int, Any] = {}
# Add Kimi-Audio special tokens
self._add_kimiaudio_special_tokens()
# Set default special token IDs (will be updated when special tokens are added)
self._bos_token_id = 151643 # Kimi-Audio BOS
self._eos_token_id = 151644 # Kimi-Audio EOS
self._pad_token_id = self._eos_token_id
self._unk_token_id = self._pad_token_id
self._max_chars_per_token = max(
(len(tok) for tok in self._token_to_id), default=10
)
def _add_kimiaudio_special_tokens(self) -> None:
"""Add Kimi-Audio special tokens to the tokenizer."""
# Tokens should already be in self._special_tokens from tokenizer_config.json
# Just add them to added_tokens_decoder for compatibility
kimiaudio_special_tokens = {
"<|im_media_begin|>": 151661,
"<|im_media_end|>": 151663,
"<|im_kimia_text_blank|>": 151666,
"<|im_msg_end|>": 151645,
"<|im_kimia_user_msg_start|>": 151670,
"<|im_kimia_assistant_msg_start|>": 151671,
}
for token_str, token_id in kimiaudio_special_tokens.items():
# Only add if not already present
if token_id not in self._added_tokens_decoder:
self._added_tokens_decoder[token_id] = AddedToken(
token_str, single_word=True, normalized=False, special=True
)
# Also ensure it's in _token_to_id and _id_to_token
if token_str not in self._token_to_id:
self._token_to_id[token_str] = token_id
if token_id not in self._id_to_token:
self._id_to_token[token_id] = token_str
def num_special_tokens_to_add(self) -> int:
return 0
@property
def all_special_tokens(self) -> list[str]:
return list(self._added_tokens_decoder.values())
@property
def all_special_ids(self) -> list[int]:
return list(self._added_tokens_decoder.keys())
@property
def bos_token_id(self) -> int:
return self._bos_token_id
@property
def eos_token_id(self) -> int:
return self._eos_token_id
@property
def pad_token_id(self) -> int:
return self._pad_token_id
@property
def is_fast(self) -> bool:
return False
@property
def vocab_size(self) -> int:
return self._tokenizer.n_vocab
@property
def max_token_id(self) -> int:
return self._tokenizer.n_vocab - 1
@property
def max_chars_per_token(self) -> int:
return self._max_chars_per_token
@property
def truncation_side(self) -> str:
return self._truncation_side
@property
def added_tokens_decoder(self) -> dict[int, Any]:
return self._added_tokens_decoder
@added_tokens_decoder.setter
def added_tokens_decoder(self, value: dict[int, Any]) -> None:
"""Set added tokens decoder and update special token IDs."""
self._added_tokens_decoder = value
# Update special token IDs if known tokens are added
for token_id, token in value.items():
token_str = str(token) if hasattr(token, "__str__") else token
if "<|im_kimia_user_msg_start|>" in token_str:
self._bos_token_id = token_id
elif "<|im_msg_end|>" in token_str or "<|im_end|>" in token_str:
self._eos_token_id = token_id
def get_vocab(self) -> dict[str, int]:
return dict(self._token_to_id)
def __len__(self) -> int:
"""Return vocab size for compatibility with HF tokenizer interface."""
return self._tokenizer.n_vocab
def get_added_vocab(self) -> dict[str, int]:
return {
str(token): token_id
for token_id, token in self._added_tokens_decoder.items()
}
def _maybe_truncate(self, tokens: list[int], max_length: int | None) -> list[int]:
if max_length is None or len(tokens) <= max_length:
return tokens
if self.truncation_side == "left":
return tokens[-max_length:]
return tokens[:max_length]
def encode(
self,
text: str,
truncation: bool | None = None,
max_length: int | None = None,
add_special_tokens: bool = True,
**kwargs,
) -> list[int]:
del add_special_tokens
# Allow Kimi-Audio special tokens to be encoded
tokens = self._tokenizer.encode(
text,
allowed_special={
"<|im_media_begin|>",
"<|im_media_end|>",
"<|im_kimia_text_blank|>",
"<|im_msg_end|>",
"<|im_kimia_user_msg_start|>",
"<|im_kimia_assistant_msg_start|>",
},
)
if truncation:
tokens = self._maybe_truncate(tokens, max_length)
return tokens
def decode(self, ids: list[int] | int, skip_special_tokens: bool = False) -> str:
"""Decode token IDs to text, optionally skipping special tokens."""
if isinstance(ids, int):
ids = [ids]
if skip_special_tokens:
# Skip tokens that are in special_tokens (loaded from config)
special_ids = set(self._special_tokens.values())
ids = [token_id for token_id in ids if token_id not in special_ids]
return self._tokenizer.decode(ids)
@overload
def convert_tokens_to_ids(self, tokens: str) -> int: ...
@overload
def convert_tokens_to_ids(self, tokens: list[str]) -> list[int]: ...
def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]:
if isinstance(tokens, str):
return self._token_to_id.get(tokens, self._unk_token_id)
return [self._token_to_id.get(token, self._unk_token_id) for token in tokens]
def convert_ids_to_tokens(
self, ids: list[int], skip_special_tokens: bool = False
) -> list[str]:
tokens = []
for token_id in ids:
if skip_special_tokens and token_id in self._added_tokens_decoder:
continue
tokens.append(self._id_to_token.get(token_id, "<|unk|>"))
return tokens
def convert_tokens_to_string(self, tokens: list[str]) -> str:
token_ids = self.convert_tokens_to_ids(tokens)
return self.decode(token_ids, skip_special_tokens=False)
def __call__(
self,
text: str | list[str],
text_pair: str | None = None,
add_special_tokens: bool = True,
truncation: bool = False,
max_length: int | None = None,
**kwargs,
) -> BatchEncoding:
if text_pair is not None:
raise NotImplementedError(
"text_pair is not supported for KimiAudioTokenizer."
)
if isinstance(text, list):
input_ids_batch: list[list[int]] = [
self.encode(
item,
truncation=truncation,
max_length=max_length,
add_special_tokens=add_special_tokens,
)
for item in text
]
attention_mask_batch = [[1] * len(ids) for ids in input_ids_batch]
return BatchEncoding(
{"input_ids": input_ids_batch, "attention_mask": attention_mask_batch}
)
input_ids = self.encode(
text,
truncation=truncation,
max_length=max_length,
add_special_tokens=add_special_tokens,
)
attention_mask = [1] * len(input_ids)
return BatchEncoding({"input_ids": input_ids, "attention_mask": attention_mask})
def get_chat_template(
self, chat_template: str | None, tools: list[dict[str, Any]] | None = None
) -> str | None:
del tools
return chat_template
def apply_chat_template(
self,
messages: list[ChatCompletionMessageParam] | None = None,
tools: list[dict[str, Any]] | None = None,
chat_template: str | None = None,
tokenize: bool = False,
**kwargs,
) -> str | list[int]:
# Handle both 'messages' (protocol) and 'conversation' (caller) parameter names
conversation = messages if messages is not None else kwargs.get("conversation")
if conversation is None:
raise ValueError("Either 'messages' or 'conversation' must be provided.")
template = self.get_chat_template(chat_template, tools=tools)
if template is None:
raise ValueError(
"No chat template available. Provide `chat_template` explicitly."
)
# Use render_jinja_template instead of apply_chat_template
# Note: render_jinja_template returns ([prompts], [generation_indices])
rendered, _ = hf_chat_utils.render_jinja_template(
conversation,
chat_template=template,
tools=tools,
**kwargs,
)
# Extract the first (and usually only) prompt
prompt = rendered[0] if rendered else ""
if tokenize:
return self.encode(prompt, add_special_tokens=False)
return prompt
...@@ -35,6 +35,7 @@ _VLLM_TOKENIZERS = { ...@@ -35,6 +35,7 @@ _VLLM_TOKENIZERS = {
"deepseek_v32": ("deepseek_v32", "DeepseekV32Tokenizer"), "deepseek_v32": ("deepseek_v32", "DeepseekV32Tokenizer"),
"grok2": ("grok2", "Grok2Tokenizer"), "grok2": ("grok2", "Grok2Tokenizer"),
"hf": ("hf", "CachedHfTokenizer"), "hf": ("hf", "CachedHfTokenizer"),
"kimi_audio": ("kimi_audio", "KimiAudioTokenizer"),
"mistral": ("mistral", "MistralTokenizer"), "mistral": ("mistral", "MistralTokenizer"),
"qwen_vl": ("qwen_vl", "QwenVLTokenizer"), "qwen_vl": ("qwen_vl", "QwenVLTokenizer"),
} }
......
{% set messages = conversations[0] if conversations else [] -%}
{% if messages and messages[0]['role'] == 'system' -%}
{% set loop_messages = messages[1:] -%}
{% else -%}
{% set loop_messages = messages -%}
{% endif -%}
{% for message in loop_messages -%}
{% if message['role'] == 'user' -%}
<|im_kimia_user_msg_start|>{{ message['content'] }}<|im_msg_end|><|im_kimia_assistant_msg_start|>
{%- elif message['role'] == 'assistant' -%}
{{ message['content'] }}<|im_kimia_text_eos|>
{%- endif -%}
{% endfor -%}
...@@ -10,23 +10,6 @@ reasons: ...@@ -10,23 +10,6 @@ reasons:
import importlib import importlib
_CLASS_TO_MODULE: dict[str, str] = {
"BagelProcessor": "vllm.transformers_utils.processors.bagel",
"DeepseekVLV2Processor": "vllm.transformers_utils.processors.deepseek_vl2",
"FireRedASR2Processor": "vllm.transformers_utils.processors.fireredasr2",
"FunASRProcessor": "vllm.transformers_utils.processors.funasr",
"GLM4VProcessor": "vllm.transformers_utils.processors.glm4v",
"HunYuanVLProcessor": "vllm.transformers_utils.processors.hunyuan_vl",
"HunYuanVLImageProcessor": "vllm.transformers_utils.processors.hunyuan_vl_image",
"MistralCommonPixtralProcessor": "vllm.transformers_utils.processors.pixtral",
"MistralCommonVoxtralProcessor": "vllm.transformers_utils.processors.voxtral",
"OvisProcessor": "vllm.transformers_utils.processors.ovis",
"Ovis2_5Processor": "vllm.transformers_utils.processors.ovis2_5",
"QwenVLProcessor": "vllm.transformers_utils.processors.qwen_vl",
"Qwen3ASRProcessor": "vllm.transformers_utils.processors.qwen3_asr",
}
__all__ = [ __all__ = [
"BagelProcessor", "BagelProcessor",
"DeepseekVLV2Processor", "DeepseekVLV2Processor",
...@@ -35,6 +18,7 @@ __all__ = [ ...@@ -35,6 +18,7 @@ __all__ = [
"GLM4VProcessor", "GLM4VProcessor",
"HunYuanVLProcessor", "HunYuanVLProcessor",
"HunYuanVLImageProcessor", "HunYuanVLImageProcessor",
"KimiAudioProcessor",
"MistralCommonPixtralProcessor", "MistralCommonPixtralProcessor",
"MistralCommonVoxtralProcessor", "MistralCommonVoxtralProcessor",
"OvisProcessor", "OvisProcessor",
...@@ -43,6 +27,23 @@ __all__ = [ ...@@ -43,6 +27,23 @@ __all__ = [
"Qwen3ASRProcessor", "Qwen3ASRProcessor",
] ]
_CLASS_TO_MODULE: dict[str, str] = {
"BagelProcessor": "vllm.transformers_utils.processors.bagel",
"DeepseekVLV2Processor": "vllm.transformers_utils.processors.deepseek_vl2",
"FireRedASR2Processor": "vllm.transformers_utils.processors.fireredasr2",
"FunASRProcessor": "vllm.transformers_utils.processors.funasr",
"GLM4VProcessor": "vllm.transformers_utils.processors.glm4v",
"HunYuanVLProcessor": "vllm.transformers_utils.processors.hunyuan_vl",
"HunYuanVLImageProcessor": "vllm.transformers_utils.processors.hunyuan_vl_image",
"KimiAudioProcessor": "vllm.transformers_utils.processors.kimi_audio",
"MistralCommonPixtralProcessor": "vllm.transformers_utils.processors.pixtral",
"MistralCommonVoxtralProcessor": "vllm.transformers_utils.processors.voxtral",
"OvisProcessor": "vllm.transformers_utils.processors.ovis",
"Ovis2_5Processor": "vllm.transformers_utils.processors.ovis2_5",
"QwenVLProcessor": "vllm.transformers_utils.processors.qwen_vl",
"Qwen3ASRProcessor": "vllm.transformers_utils.processors.qwen3_asr",
}
def __getattr__(name: str): def __getattr__(name: str):
if name in _CLASS_TO_MODULE: if name in _CLASS_TO_MODULE:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa
# mypy: ignore-errors
# coding=utf-8
# Copyright 2026 The Moonshot AI team and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Processor for Kimi-Audio ASR model."""
from collections.abc import Mapping
from typing import Any
import numpy as np
import torch
from transformers import AutoFeatureExtractor, BatchFeature, ProcessorMixin
from transformers.audio_utils import AudioInput
from transformers.tokenization_utils_base import TextInput
from vllm.tokenizers.kimi_audio import KimiAudioTokenizer
def _get_feat_extract_output_lengths(input_lengths: torch.Tensor) -> torch.Tensor:
"""Compute output lengths after Whisper feature extraction."""
input_lengths_leave = input_lengths % 100
feat_lengths = (input_lengths_leave - 1) // 2 + 1
output_lengths = (
((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
)
return output_lengths
class KimiAudioProcessor(ProcessorMixin):
r"""
Constructs a Kimi-Audio processor.
[`KimiAudioProcessor`] offers all the functionalities of [`WhisperFeatureExtractor`], and a tokenizer.
See the [`~KimiAudioProcessor.__call__`] and [`~KimiAudioProcessor.decode`] for more information.
Args:
feature_extractor ([`WhisperFeatureExtractor`], *optional*):
The audio feature extractor.
tokenizer ([`PreTrainedTokenizer`], *optional*):
The text tokenizer.
"""
# Required for ProcessorMixin
attributes = ["feature_extractor", "tokenizer"]
feature_extractor_class = "AutoFeatureExtractor"
tokenizer_class = "AutoTokenizer"
# Special token IDs
KIMIA_MEDIA_BEGIN: int = 151661
KIMIA_MEDIA_END: int = 151663
KIMIA_TEXT_BLANK: int = 151666
# Audio processing constants
AUDIO_SEQ_LEN: int = 376
def __init__(self, feature_extractor=None, tokenizer=None, **kwargs):
# Pass feature_extractor and tokenizer to parent ProcessorMixin
super().__init__(
feature_extractor=feature_extractor,
tokenizer=tokenizer,
**kwargs,
)
def check_argument_for_proper_class(self, attribute_name: str, argument: Any):
"""Override to skip class validation for custom tokenizer."""
# Skip validation for tokenizer since KimiAudioTokenizer doesn't inherit
# from PreTrainedTokenizerBase but is compatible
if attribute_name == "tokenizer" and argument is not None:
return
# For other attributes, use default validation
super().check_argument_for_proper_class(attribute_name, argument)
def __call__(
self,
text: TextInput = None,
audio: AudioInput = None,
return_tensors: str = "pt",
**kwargs,
) -> BatchFeature:
"""
Main method to prepare for the model one or several sequences(s) and audio(s).
Args:
text (`str`, `List[str]`):
The sequence or batch of sequences to be encoded.
audio (`np.ndarray`, `List[np.ndarray]`):
The audio or batch of audio to be prepared. Each audio can be a NumPy array.
return_tensors (`str`):
The type of tensors to return ("pt", "np", etc.)
"""
if text is None:
raise ValueError("You need to specify either a `text` input to process.")
# Process audio if provided
if audio is not None:
# Ensure audio is a list
if isinstance(audio, np.ndarray):
audio = [audio]
# Pad audio to hop length (required by WhisperFeatureExtractor)
hop_length = self.feature_extractor.hop_length
padded_audio = []
for aud in audio:
length = aud.shape[-1]
if length % hop_length != 0:
pad_length = hop_length - (length % hop_length)
aud = np.pad(
aud, (0, pad_length), mode="constant", constant_values=0
)
padded_audio.append(aud)
# Use feature_extractor directly like Qwen3ASR does
audio_inputs = self.feature_extractor(
padded_audio,
sampling_rate=16000,
padding=True,
return_attention_mask=True,
return_tensors=return_tensors,
)
# Rename to match Kimi-Audio expectations
if "input_features" in audio_inputs:
audio_inputs["whisper_input_features"] = audio_inputs.pop(
"input_features"
)
if "attention_mask" in audio_inputs:
audio_inputs["feature_attention_mask"] = audio_inputs.pop(
"attention_mask"
)
else:
audio_inputs = {}
# Handle text input - can be string or token IDs from vLLM processor
if isinstance(text, list) and len(text) > 0 and isinstance(text[0], int):
# Text is already token IDs (from vLLM processor) - just wrap
text_inputs = {"input_ids": torch.tensor([text], dtype=torch.long)}
else:
# Text is string - tokenize
if not isinstance(text, list):
text = [text]
text_inputs = self.tokenizer(
text, return_tensors=return_tensors, padding=True
)
return BatchFeature(
data={**text_inputs, **audio_inputs},
tensor_type=return_tensors,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment