Unverified commit 196802df, authored by Cyrus Leung and committed by GitHub
Browse files

[Misc] Clean up renderers (#36770)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent c84b519c
...@@ -6,9 +6,6 @@ from functools import partial ...@@ -6,9 +6,6 @@ from functools import partial
import numpy as np import numpy as np
import pytest import pytest
from mistral_common.protocol.instruct.chunk import ImageChunk, TextChunk
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from PIL import Image from PIL import Image
from vllm.config import ModelConfig from vllm.config import ModelConfig
...@@ -21,7 +18,10 @@ from vllm.config.multimodal import ( ...@@ -21,7 +18,10 @@ from vllm.config.multimodal import (
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
from vllm.multimodal.cache import MultiModalProcessorOnlyCache from vllm.multimodal.cache import MultiModalProcessorOnlyCache
from vllm.multimodal.inputs import MultiModalInputs, batched_tensors_equal from vllm.multimodal.inputs import MultiModalInputs, batched_tensors_equal
from vllm.multimodal.processing import BaseMultiModalProcessor, InputProcessingContext from vllm.multimodal.processing import (
BaseMultiModalProcessor,
InputProcessingContext,
)
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
from vllm.utils.mistral import is_mistral_tokenizer from vllm.utils.mistral import is_mistral_tokenizer
...@@ -74,20 +74,6 @@ def glmasr_patch_mm_data(mm_data: MultiModalDataDict) -> MultiModalDataDict: ...@@ -74,20 +74,6 @@ def glmasr_patch_mm_data(mm_data: MultiModalDataDict) -> MultiModalDataDict:
return mm_data return mm_data
# For some multimodal models, tokenizer will always add bos_token
# at the beginning of prompt by default, causing hf_processor outputs
# incorrect token ids. So we need use `add_special_tokens=False` here
# to leave bos_token to be added by the processor.
_ADD_SPECIAL_TOKENS_OVERRIDES = {
"lfm2_vl": False,
"nemotron_parse": False,
"ovis": False,
"ovis2_5": False,
"paligemma": False,
"ultravox": False,
"whisper": False,
}
_IGNORE_MM_KEYS = { _IGNORE_MM_KEYS = {
# In Ultravox, the audio_features can be different depending on padding # In Ultravox, the audio_features can be different depending on padding
# The slight difference should not be a problem though, since # The slight difference should not be a problem though, since
...@@ -152,63 +138,34 @@ def get_text_token_prompts( ...@@ -152,63 +138,34 @@ def get_text_token_prompts(
parsed_data = processor.info.parse_mm_data(mm_data) parsed_data = processor.info.parse_mm_data(mm_data)
mm_counts = {k: len(vs) for k, vs in parsed_data.items()} mm_counts = {k: len(vs) for k, vs in parsed_data.items()}
text_prompt: str | None
token_prompt: list[int]
if is_mistral_tokenizer(tokenizer): if is_mistral_tokenizer(tokenizer):
# ChatCompletionRequest only supports ImageChunk natively; inputs = dummy_inputs.get_dummy_processor_inputs(
# for other modalities (e.g. audio), fall back to the model's model_config.max_model_len,
# own dummy inputs builder which knows the right placeholders. mm_counts,
has_non_image = any( mm_options={},
k != "image" and count > 0 for k, count in mm_counts.items() # Assume all Mistral models define this extra argument
mm_data=mm_data, # type: ignore[call-arg]
) )
if has_non_image:
inputs = dummy_inputs.get_dummy_processor_inputs(
model_config.max_model_len,
mm_counts,
mm_options={},
)
text_prompt = None
token_prompt = (
inputs.prompt
if isinstance(inputs.prompt, list)
else tokenizer.encode(inputs.prompt, add_special_tokens=False)
)
else:
images = parsed_data.get("image", [])
request = ChatCompletionRequest(
messages=[
UserMessage(
content=[
TextChunk(text=""),
*(ImageChunk(image=image) for image in images),
]
),
]
)
res = tokenizer.mistral.encode_chat_completion(request)
# Mistral does not support decode_tokens with
# skip_special_tokens=False
text_prompt = None
token_prompt = res.tokens
else: else:
inputs = dummy_inputs.get_dummy_processor_inputs( inputs = dummy_inputs.get_dummy_processor_inputs(
model_config.max_model_len, model_config.max_model_len,
mm_counts, mm_counts,
mm_options={}, mm_options={},
) )
# Some models (e.g., Kimi-Audio) return token IDs directly instead of str
if isinstance(inputs.prompt, list): text_prompt: str | None
text_prompt = None token_prompt: list[int]
token_prompt = inputs.prompt if isinstance(inputs.prompt, list):
else: text_prompt = None
assert isinstance(inputs.prompt, str) token_prompt = inputs.prompt
text_prompt = inputs.prompt elif isinstance(inputs.prompt, str):
token_prompt = tokenizer.encode( text_prompt = inputs.prompt
text_prompt, token_prompt = tokenizer.encode(
add_special_tokens=_ADD_SPECIAL_TOKENS_OVERRIDES.get(model_type, True), text_prompt,
) **processor.info.get_default_tok_params().get_encode_kwargs(),
)
else:
raise TypeError(type(inputs.prompt))
return text_prompt, token_prompt return text_prompt, token_prompt
...@@ -448,7 +405,7 @@ def test_processing_correctness( ...@@ -448,7 +405,7 @@ def test_processing_correctness(
) )
if model_id == "mistralai/Voxtral-Mini-4B-Realtime-2602": if model_id == "mistralai/Voxtral-Mini-4B-Realtime-2602":
pytest.skip( pytest.skip(
"Voxtral Realtime doesn't make use of any place-holder" "Voxtral Realtime doesn't make use of any place-holder "
"tokens and hence cannot pass the processing " "tokens and hence cannot pass the processing "
"correctness test as is. Let's revisit adapting this " "correctness test as is. Let's revisit adapting this "
"test once more realtime models exist." "test once more realtime models exist."
......
...@@ -532,6 +532,22 @@ class ModelConfig: ...@@ -532,6 +532,22 @@ class ModelConfig:
self._architecture = arch self._architecture = arch
logger.info("Resolved architecture: %s", arch) logger.info("Resolved architecture: %s", arch)
# Set default tokenizer modes based on model architecture
if self.tokenizer_mode == "auto":
if arch == "Grok1ForCausalLM":
self.tokenizer_mode = "grok2"
elif arch == "MoonshotKimiaForCausalLM":
self.tokenizer_mode = "kimi_audio"
elif arch == "QwenVLForConditionalGeneration":
self.tokenizer_mode = "qwen_vl"
if self.tokenizer_mode != "auto":
logger.info(
"Defaulting to tokenizer_mode=%r for %s",
self.tokenizer_mode,
arch,
)
# Init pooler config if needed # Init pooler config if needed
if self.runner_type == "pooling": if self.runner_type == "pooling":
if self.pooler_config is None: if self.pooler_config is None:
......
...@@ -10,11 +10,13 @@ from typing import Any, ClassVar, Literal ...@@ -10,11 +10,13 @@ from typing import Any, ClassVar, Literal
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from huggingface_hub import snapshot_download
from safetensors import safe_open from safetensors import safe_open
from transformers import BatchFeature from transformers import BatchFeature
from transformers import WhisperConfig as HFWhisperConfig from transformers import WhisperConfig as HFWhisperConfig
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.inputs.data import PromptType, TokensPrompt from vllm.inputs.data import PromptType, TokensPrompt
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
...@@ -47,7 +49,10 @@ from vllm.multimodal.processing import ( ...@@ -47,7 +49,10 @@ from vllm.multimodal.processing import (
BaseProcessingInfo, BaseProcessingInfo,
PromptReplacement, PromptReplacement,
) )
from vllm.multimodal.processing.processor import BaseMultiModalProcessor from vllm.multimodal.processing.processor import (
BaseMultiModalProcessor,
ProcessorInputs,
)
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.tokenizers import cached_get_tokenizer from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.kimi_audio import KimiAudioTokenizer from vllm.tokenizers.kimi_audio import KimiAudioTokenizer
...@@ -59,6 +64,15 @@ from vllm.v1.sample.metadata import SamplingMetadata ...@@ -59,6 +64,15 @@ from vllm.v1.sample.metadata import SamplingMetadata
KIMIA_WHISPER_SUBFOLDER = "whisper-large-v3" KIMIA_WHISPER_SUBFOLDER = "whisper-large-v3"
def _get_whisper_local_path(repo_id: str):
if os.path.exists(repo_id):
repo_local_path = repo_id
else:
repo_local_path = snapshot_download(repo_id, local_files_only=True)
return os.path.join(repo_local_path, KIMIA_WHISPER_SUBFOLDER)
def _get_feat_extract_output_lengths(input_lengths: torch.Tensor) -> torch.Tensor: def _get_feat_extract_output_lengths(input_lengths: torch.Tensor) -> torch.Tensor:
"""Compute output lengths after Whisper feature extraction. """Compute output lengths after Whisper feature extraction.
...@@ -88,10 +102,10 @@ class KimiAudioWhisperEncoder(WhisperEncoder): ...@@ -88,10 +102,10 @@ class KimiAudioWhisperEncoder(WhisperEncoder):
# Load Whisper config from subfolder (authoritative source) # Load Whisper config from subfolder (authoritative source)
# Kimi-Audio stores Whisper config in whisper-large-v3/config.json # Kimi-Audio stores Whisper config in whisper-large-v3/config.json
model_path = vllm_config.model_config.model model_path = vllm_config.model_config.model
whisper_config_path = os.path.join(model_path, KIMIA_WHISPER_SUBFOLDER)
# Load WhisperConfig from the subfolder # Load WhisperConfig from the subfolder
whisper_config = HFWhisperConfig.from_pretrained(whisper_config_path) whisper_dir = _get_whisper_local_path(model_path)
whisper_config = HFWhisperConfig.from_pretrained(whisper_dir)
# Temporarily replace hf_config for WhisperEncoder.__init__() # Temporarily replace hf_config for WhisperEncoder.__init__()
original_config = vllm_config.model_config.hf_config original_config = vllm_config.model_config.hf_config
...@@ -114,28 +128,18 @@ class KimiAudioWhisperEncoder(WhisperEncoder): ...@@ -114,28 +128,18 @@ class KimiAudioWhisperEncoder(WhisperEncoder):
class KimiAudioProcessingInfo(BaseProcessingInfo): class KimiAudioProcessingInfo(BaseProcessingInfo):
"""Processing info for vLLM registry.""" """Processing info for vLLM registry."""
def get_hf_config(self):
return self.ctx.model_config.hf_config
def get_hf_processor(self, **kwargs: object) -> KimiAudioProcessor: def get_hf_processor(self, **kwargs: object) -> KimiAudioProcessor:
"""Get KimiAudioProcessor with feature extractor and tokenizer."""
# Use vLLM's cached loader for feature extractor
feature_extractor = cached_feature_extractor_from_config( feature_extractor = cached_feature_extractor_from_config(
self.ctx.model_config, self.ctx.model_config,
subfolder=KIMIA_WHISPER_SUBFOLDER, subfolder=KIMIA_WHISPER_SUBFOLDER,
) )
# Use vLLM's standard tokenizer loading (respects tokenizer_mode)
tokenizer = self.get_tokenizer()
# Construct processor directly
return KimiAudioProcessor( return KimiAudioProcessor(
feature_extractor=feature_extractor, feature_extractor=feature_extractor,
tokenizer=tokenizer, tokenizer=self.get_tokenizer(),
) )
def get_feature_extractor(self, **kwargs: object): def get_feature_extractor(self, **kwargs: object):
"""Get feature extractor using vLLM's cached loader."""
return cached_feature_extractor_from_config( return cached_feature_extractor_from_config(
self.ctx.model_config, subfolder=KIMIA_WHISPER_SUBFOLDER self.ctx.model_config, subfolder=KIMIA_WHISPER_SUBFOLDER
) )
...@@ -144,26 +148,16 @@ class KimiAudioProcessingInfo(BaseProcessingInfo): ...@@ -144,26 +148,16 @@ class KimiAudioProcessingInfo(BaseProcessingInfo):
return {"audio": 1} return {"audio": 1}
def get_data_parser(self) -> "KimiAudioMultiModalDataParser": def get_data_parser(self) -> "KimiAudioMultiModalDataParser":
"""Get data parser for audio inputs.""" feature_extractor = self.get_feature_extractor()
return KimiAudioMultiModalDataParser( return KimiAudioMultiModalDataParser(
target_sr=feature_extractor.sampling_rate,
expected_hidden_size=self._get_expected_hidden_size(), expected_hidden_size=self._get_expected_hidden_size(),
) )
class KimiAudioDummyInputsBuilder(BaseDummyInputsBuilder[KimiAudioProcessingInfo]): class KimiAudioDummyInputsBuilder(BaseDummyInputsBuilder[KimiAudioProcessingInfo]):
"""Dummy inputs builder for vLLM registry.""" def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
return ""
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> list[int]:
"""Return dummy text as token IDs directly."""
num_audios = mm_counts.get("audio", 0)
if num_audios == 0:
return [198] # "Transcribe" tokenized
# Return as token IDs directly to avoid tokenizer issues
return [
KimiAudioProcessor.KIMIA_MEDIA_BEGIN,
KimiAudioProcessor.KIMIA_TEXT_BLANK,
KimiAudioProcessor.KIMIA_MEDIA_END,
] * num_audios
def get_dummy_mm_data( def get_dummy_mm_data(
self, self,
...@@ -186,6 +180,29 @@ class KimiAudioDummyInputsBuilder(BaseDummyInputsBuilder[KimiAudioProcessingInfo ...@@ -186,6 +180,29 @@ class KimiAudioDummyInputsBuilder(BaseDummyInputsBuilder[KimiAudioProcessingInfo
), ),
} }
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions],
) -> ProcessorInputs:
dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options)
dummy_mm_items = self.info.parse_mm_data(dummy_mm_data)
num_audios = mm_counts.get("audio", 0)
dummy_tokens = (
[198]
if num_audios == 0
else [
KimiAudioProcessor.KIMIA_MEDIA_BEGIN,
KimiAudioProcessor.KIMIA_TEXT_BLANK,
KimiAudioProcessor.KIMIA_MEDIA_END,
]
* num_audios
)
return ProcessorInputs(prompt=dummy_tokens, mm_data_items=dummy_mm_items)
# Field config for Kimi-Audio multimodal data # Field config for Kimi-Audio multimodal data
_KIMIAUDIO_FIELD_CONFIG = { _KIMIAUDIO_FIELD_CONFIG = {
...@@ -197,10 +214,6 @@ _KIMIAUDIO_FIELD_CONFIG = { ...@@ -197,10 +214,6 @@ _KIMIAUDIO_FIELD_CONFIG = {
class KimiAudioMultiModalDataParser(MultiModalDataParser): class KimiAudioMultiModalDataParser(MultiModalDataParser):
"""Custom data parser for Kimi-Audio multimodal data.""" """Custom data parser for Kimi-Audio multimodal data."""
def __init__(self, **kwargs):
# Whisper expects 16kHz audio
super().__init__(target_sr=16000, **kwargs)
def _parse_audio_data( def _parse_audio_data(
self, self,
data: dict[str, torch.Tensor] | ModalityData[AudioItem], data: dict[str, torch.Tensor] | ModalityData[AudioItem],
...@@ -589,9 +602,8 @@ class KimiAudioForConditionalGeneration( ...@@ -589,9 +602,8 @@ class KimiAudioForConditionalGeneration(
loaded = loader.load_weights(main_weights, mapper=self.hf_to_vllm_mapper) loaded = loader.load_weights(main_weights, mapper=self.hf_to_vllm_mapper)
# Load Whisper encoder weights from subfolder # Load Whisper encoder weights from subfolder
whisper_path = os.path.join( whisper_dir = _get_whisper_local_path(self.model_path)
self.model_path, f"{KIMIA_WHISPER_SUBFOLDER}/model.safetensors" whisper_path = os.path.join(whisper_dir, "model.safetensors")
)
if os.path.exists(whisper_path): if os.path.exists(whisper_path):
whisper_loaded = self._load_whisper_weights_from_file(whisper_path) whisper_loaded = self._load_whisper_weights_from_file(whisper_path)
loaded.update(whisper_loaded) loaded.update(whisper_loaded)
......
...@@ -63,12 +63,10 @@ from vllm.multimodal.processing import ( ...@@ -63,12 +63,10 @@ from vllm.multimodal.processing import (
BaseDummyInputsBuilder, BaseDummyInputsBuilder,
BaseMultiModalProcessor, BaseMultiModalProcessor,
BaseProcessingInfo, BaseProcessingInfo,
InputProcessingContext,
PromptReplacement, PromptReplacement,
PromptUpdate, PromptUpdate,
PromptUpdateDetails, PromptUpdateDetails,
) )
from vllm.renderers import TokenizeParams
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
...@@ -546,9 +544,6 @@ class Llama4VisionModel(nn.Module): ...@@ -546,9 +544,6 @@ class Llama4VisionModel(nn.Module):
class Mllama4ProcessingInfo(BaseProcessingInfo): class Mllama4ProcessingInfo(BaseProcessingInfo):
def __init__(self, ctx: InputProcessingContext) -> None:
super().__init__(ctx)
def get_hf_config(self) -> Llama4Config: def get_hf_config(self) -> Llama4Config:
return self.ctx.get_hf_config(Llama4Config) return self.ctx.get_hf_config(Llama4Config)
...@@ -557,9 +552,6 @@ class Mllama4ProcessingInfo(BaseProcessingInfo): ...@@ -557,9 +552,6 @@ class Mllama4ProcessingInfo(BaseProcessingInfo):
Llama4Processor, use_fast=kwargs.pop("use_fast", True), **kwargs Llama4Processor, use_fast=kwargs.pop("use_fast", True), **kwargs
) )
def get_default_tok_params(self) -> TokenizeParams:
return super().get_default_tok_params().with_kwargs(add_special_tokens=False)
def get_supported_mm_limits(self) -> Mapping[str, int | None]: def get_supported_mm_limits(self) -> Mapping[str, int | None]:
# Although vLLM can support more images from an infra capability # Although vLLM can support more images from an infra capability
# perspective, we do not recommend using >10 images in practice. # perspective, we do not recommend using >10 images in practice.
...@@ -597,10 +589,6 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]) ...@@ -597,10 +589,6 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo])
mm_kwargs: Mapping[str, object], mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object], tok_kwargs: Mapping[str, object],
) -> BatchFeature: ) -> BatchFeature:
tokenizer = self.info.get_tokenizer()
if mm_data is None:
return tokenizer(prompt, add_special_tokens=False) # exclude bos
processed_outputs = super()._call_hf_processor( processed_outputs = super()._call_hf_processor(
prompt=prompt, prompt=prompt,
mm_data=mm_data, mm_data=mm_data,
......
...@@ -172,12 +172,20 @@ class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]): ...@@ -172,12 +172,20 @@ class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]):
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions], mm_options: Mapping[str, BaseDummyOptions],
mm_data: MultiModalDataDict | None = None,
) -> ProcessorInputs: ) -> ProcessorInputs:
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
dummy_text = self.get_dummy_text(mm_counts) dummy_text = self.get_dummy_text(mm_counts)
dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options) dummy_mm_data = (
dummy_images = dummy_mm_data.get("image", []) self.get_dummy_mm_data(seq_len, mm_counts, mm_options)
if mm_data is None
else mm_data
)
dummy_mm_items = self.info.parse_mm_data(dummy_mm_data)
dummy_images = (
[] if "image" not in dummy_mm_data else dummy_mm_items["image"].get_all()
)
request = ChatCompletionRequest( request = ChatCompletionRequest(
messages=[ messages=[
...@@ -192,8 +200,6 @@ class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]): ...@@ -192,8 +200,6 @@ class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]):
res = tokenizer.mistral.encode_chat_completion(request) res = tokenizer.mistral.encode_chat_completion(request)
dummy_tokens = res.tokens dummy_tokens = res.tokens
dummy_mm_items = self.info.parse_mm_data(dummy_mm_data)
return ProcessorInputs(prompt=dummy_tokens, mm_data_items=dummy_mm_items) return ProcessorInputs(prompt=dummy_tokens, mm_data_items=dummy_mm_items)
......
...@@ -150,13 +150,21 @@ class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]): ...@@ -150,13 +150,21 @@ class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]):
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions], mm_options: Mapping[str, BaseDummyOptions],
mm_data: MultiModalDataDict | None = None,
) -> ProcessorInputs: ) -> ProcessorInputs:
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
feature_extractor = self.info.get_hf_processor().feature_extractor feature_extractor = self.info.get_hf_processor().feature_extractor
dummy_text = self.get_dummy_text(mm_counts) dummy_text = self.get_dummy_text(mm_counts)
dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options) dummy_mm_data = (
dummy_audios = dummy_mm_data.get("audio", []) self.get_dummy_mm_data(seq_len, mm_counts, mm_options)
if mm_data is None
else mm_data
)
dummy_mm_items = self.info.parse_mm_data(dummy_mm_data)
dummy_audios = (
[] if "audio" not in dummy_mm_data else dummy_mm_items["audio"].get_all()
)
audio_chunks: list[AudioChunk] = [] audio_chunks: list[AudioChunk] = []
format = "wav" format = "wav"
......
...@@ -6,11 +6,10 @@ from vllm.config import VllmConfig ...@@ -6,11 +6,10 @@ from vllm.config import VllmConfig
from vllm.tokenizers import cached_get_tokenizer from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.qwen_vl import QwenVLTokenizer from vllm.tokenizers.qwen_vl import QwenVLTokenizer
from .base import BaseRenderer
from .hf import HfRenderer from .hf import HfRenderer
class QwenVLRenderer(BaseRenderer[QwenVLTokenizer]): class QwenVLRenderer(HfRenderer):
@classmethod @classmethod
def from_config( # type: ignore[override] def from_config( # type: ignore[override]
cls, cls,
......
...@@ -80,13 +80,6 @@ def renderer_from_config(config: "VllmConfig", **kwargs): ...@@ -80,13 +80,6 @@ def renderer_from_config(config: "VllmConfig", **kwargs):
model_config, **kwargs model_config, **kwargs
) )
# Override tokenizer_mode for Kimi-Audio models
if model_config.architecture == "MoonshotKimiaForCausalLM":
tokenizer_mode = "kimi_audio"
# Update model_config so other components (e.g., multimodal registry)
# also use the correct tokenizer mode
model_config.tokenizer_mode = "kimi_audio"
if ( if (
model_config.tokenizer_mode == "auto" model_config.tokenizer_mode == "auto"
and model_config.model_impl == "terratorch" and model_config.model_impl == "terratorch"
......
...@@ -159,18 +159,6 @@ def resolve_tokenizer_args( ...@@ -159,18 +159,6 @@ def resolve_tokenizer_args(
): ):
tokenizer_mode = "mistral" tokenizer_mode = "mistral"
# Try to use Grok2 tiktoken tokenizer if possible
if tokenizer_mode == "auto" and any_pattern_in_repo_files(
model_name_or_path=str(tokenizer_name),
allow_patterns=["tokenizer.tok.json"],
revision=revision,
):
tokenizer_mode = "grok2"
# Model-specific tokenizers
if tokenizer_mode == "auto" and "/Qwen-VL" in str(tokenizer_name):
tokenizer_mode = "qwen_vl"
# Fallback to HF tokenizer # Fallback to HF tokenizer
if tokenizer_mode == "auto": if tokenizer_mode == "auto":
tokenizer_mode = "hf" tokenizer_mode = "hf"
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from
# https://github.com/zai-org/CogAgent
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
from transformers.image_processing_utils_fast import BaseImageProcessorFast from transformers.image_processing_utils_fast import BaseImageProcessorFast
from transformers.image_utils import PILImageResampling from transformers.image_utils import PILImageResampling
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa # Copyright 2026 The Moonshot AI team and the HuggingFace Inc. team.
# mypy: ignore-errors # All rights reserved.
# coding=utf-8
# Copyright 2026 The Moonshot AI team and the HuggingFace Inc. team. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -19,42 +17,13 @@ ...@@ -19,42 +17,13 @@
# limitations under the License. # limitations under the License.
"""Processor for Kimi-Audio ASR model.""" """Processor for Kimi-Audio ASR model."""
from collections.abc import Mapping
from typing import Any
import numpy as np import numpy as np
import torch from transformers import BatchFeature, ProcessorMixin
from transformers import AutoFeatureExtractor, BatchFeature, ProcessorMixin
from transformers.audio_utils import AudioInput from transformers.audio_utils import AudioInput
from transformers.tokenization_utils_base import TextInput from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from vllm.tokenizers.kimi_audio import KimiAudioTokenizer
def _get_feat_extract_output_lengths(input_lengths: torch.Tensor) -> torch.Tensor:
"""Compute output lengths after Whisper feature extraction."""
input_lengths_leave = input_lengths % 100
feat_lengths = (input_lengths_leave - 1) // 2 + 1
output_lengths = (
((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
)
return output_lengths
class KimiAudioProcessor(ProcessorMixin): class KimiAudioProcessor(ProcessorMixin):
r"""
Constructs a Kimi-Audio processor.
[`KimiAudioProcessor`] offers all the functionalities of [`WhisperFeatureExtractor`], and a tokenizer.
See the [`~KimiAudioProcessor.__call__`] and [`~KimiAudioProcessor.decode`] for more information.
Args:
feature_extractor ([`WhisperFeatureExtractor`], *optional*):
The audio feature extractor.
tokenizer ([`PreTrainedTokenizer`], *optional*):
The text tokenizer.
"""
# Required for ProcessorMixin # Required for ProcessorMixin
attributes = ["feature_extractor", "tokenizer"] attributes = ["feature_extractor", "tokenizer"]
feature_extractor_class = "AutoFeatureExtractor" feature_extractor_class = "AutoFeatureExtractor"
...@@ -69,44 +38,30 @@ class KimiAudioProcessor(ProcessorMixin): ...@@ -69,44 +38,30 @@ class KimiAudioProcessor(ProcessorMixin):
AUDIO_SEQ_LEN: int = 376 AUDIO_SEQ_LEN: int = 376
def __init__(self, feature_extractor=None, tokenizer=None, **kwargs): def __init__(self, feature_extractor=None, tokenizer=None, **kwargs):
# Pass feature_extractor and tokenizer to parent ProcessorMixin self.feature_extractor = feature_extractor
super().__init__( self.tokenizer = tokenizer
feature_extractor=feature_extractor,
tokenizer=tokenizer,
**kwargs,
)
def check_argument_for_proper_class(self, attribute_name: str, argument: Any):
"""Override to skip class validation for custom tokenizer."""
# Skip validation for tokenizer since KimiAudioTokenizer doesn't inherit
# from PreTrainedTokenizerBase but is compatible
if attribute_name == "tokenizer" and argument is not None:
return
# For other attributes, use default validation
super().check_argument_for_proper_class(attribute_name, argument)
def __call__( def __call__(
self, self,
text: TextInput = None, text: TextInput
audio: AudioInput = None, | PreTokenizedInput
| list[TextInput]
| list[PreTokenizedInput]
| None = None,
audio: AudioInput | None = None,
return_tensors: str = "pt", return_tensors: str = "pt",
**kwargs, **kwargs,
) -> BatchFeature: ) -> BatchFeature:
""" if text is not None:
Main method to prepare for the model one or several sequences(s) and audio(s). if not isinstance(text, list):
text = [text]
Args: text_inputs = self.tokenizer(
text (`str`, `List[str]`): text, return_tensors=return_tensors, padding=True
The sequence or batch of sequences to be encoded. )
audio (`np.ndarray`, `List[np.ndarray]`): else:
The audio or batch of audio to be prepared. Each audio can be a NumPy array. text_inputs = {}
return_tensors (`str`):
The type of tensors to return ("pt", "np", etc.)
"""
if text is None:
raise ValueError("You need to specify either a `text` input to process.")
# Process audio if provided
if audio is not None: if audio is not None:
# Ensure audio is a list # Ensure audio is a list
if isinstance(audio, np.ndarray): if isinstance(audio, np.ndarray):
...@@ -144,19 +99,6 @@ class KimiAudioProcessor(ProcessorMixin): ...@@ -144,19 +99,6 @@ class KimiAudioProcessor(ProcessorMixin):
else: else:
audio_inputs = {} audio_inputs = {}
# Handle text input - can be string or token IDs from vLLM processor
if isinstance(text, list) and len(text) > 0 and isinstance(text[0], int):
# Text is already token IDs (from vLLM processor) - just wrap
text_inputs = {"input_ids": torch.tensor([text], dtype=torch.long)}
else:
# Text is string - tokenize
if not isinstance(text, list):
text = [text]
text_inputs = self.tokenizer(
text, return_tensors=return_tensors, padding=True
)
return BatchFeature( return BatchFeature(
data={**text_inputs, **audio_inputs}, data={**text_inputs, **audio_inputs},
tensor_type=return_tensors, tensor_type=return_tensors,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from
# https://huggingface.co/Qwen/Qwen-VL/blob/main/modeling_qwen.py
# Copyright (c) Alibaba Cloud.
from transformers.image_processing_utils_fast import BaseImageProcessorFast from transformers.image_processing_utils_fast import BaseImageProcessorFast
from transformers.image_utils import PILImageResampling from transformers.image_utils import PILImageResampling
from transformers.processing_utils import ProcessorMixin from transformers.processing_utils import ProcessorMixin
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment