Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
196802df
Unverified
Commit
196802df
authored
Mar 12, 2026
by
Cyrus Leung
Committed by
GitHub
Mar 11, 2026
Browse files
[Misc] Clean up renderers (#36770)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
c84b519c
Changes
12
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
136 additions
and
220 deletions
+136
-220
tests/models/multimodal/processing/test_common.py
tests/models/multimodal/processing/test_common.py
+25
-68
vllm/config/model.py
vllm/config/model.py
+16
-0
vllm/model_executor/models/kimi_audio.py
vllm/model_executor/models/kimi_audio.py
+47
-35
vllm/model_executor/models/mllama4.py
vllm/model_executor/models/mllama4.py
+0
-12
vllm/model_executor/models/pixtral.py
vllm/model_executor/models/pixtral.py
+10
-4
vllm/model_executor/models/voxtral.py
vllm/model_executor/models/voxtral.py
+10
-2
vllm/renderers/qwen_vl.py
vllm/renderers/qwen_vl.py
+1
-2
vllm/renderers/registry.py
vllm/renderers/registry.py
+0
-7
vllm/tokenizers/registry.py
vllm/tokenizers/registry.py
+0
-12
vllm/transformers_utils/processors/glm4v.py
vllm/transformers_utils/processors/glm4v.py
+3
-0
vllm/transformers_utils/processors/kimi_audio.py
vllm/transformers_utils/processors/kimi_audio.py
+20
-78
vllm/transformers_utils/processors/qwen_vl.py
vllm/transformers_utils/processors/qwen_vl.py
+4
-0
No files found.
tests/models/multimodal/processing/test_common.py
View file @
196802df
...
...
@@ -6,9 +6,6 @@ from functools import partial
import
numpy
as
np
import
pytest
from
mistral_common.protocol.instruct.chunk
import
ImageChunk
,
TextChunk
from
mistral_common.protocol.instruct.messages
import
UserMessage
from
mistral_common.protocol.instruct.request
import
ChatCompletionRequest
from
PIL
import
Image
from
vllm.config
import
ModelConfig
...
...
@@ -21,7 +18,10 @@ from vllm.config.multimodal import (
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalDataDict
from
vllm.multimodal.cache
import
MultiModalProcessorOnlyCache
from
vllm.multimodal.inputs
import
MultiModalInputs
,
batched_tensors_equal
from
vllm.multimodal.processing
import
BaseMultiModalProcessor
,
InputProcessingContext
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
InputProcessingContext
,
)
from
vllm.tokenizers
import
TokenizerLike
,
cached_tokenizer_from_config
from
vllm.utils.mistral
import
is_mistral_tokenizer
...
...
@@ -74,20 +74,6 @@ def glmasr_patch_mm_data(mm_data: MultiModalDataDict) -> MultiModalDataDict:
return
mm_data
# For some multimodal models, tokenizer will always add bos_token
# at the beginning of prompt by default, causing hf_processor outputs
# incorrect token ids. So we need use `add_special_tokens=False` here
# to leave bos_token to be added by the processor.
_ADD_SPECIAL_TOKENS_OVERRIDES
=
{
"lfm2_vl"
:
False
,
"nemotron_parse"
:
False
,
"ovis"
:
False
,
"ovis2_5"
:
False
,
"paligemma"
:
False
,
"ultravox"
:
False
,
"whisper"
:
False
,
}
_IGNORE_MM_KEYS
=
{
# In Ultravox, the audio_features can be different depending on padding
# The slight difference should not be a problem though, since
...
...
@@ -152,63 +138,34 @@ def get_text_token_prompts(
parsed_data
=
processor
.
info
.
parse_mm_data
(
mm_data
)
mm_counts
=
{
k
:
len
(
vs
)
for
k
,
vs
in
parsed_data
.
items
()}
text_prompt
:
str
|
None
token_prompt
:
list
[
int
]
if
is_mistral_tokenizer
(
tokenizer
):
# ChatCompletionRequest only supports ImageChunk natively;
# for other modalities (e.g. audio), fall back to the model's
# own dummy inputs builder which knows the right placeholders.
has_non_image
=
any
(
k
!=
"image"
and
count
>
0
for
k
,
count
in
mm_counts
.
items
()
)
if
has_non_image
:
inputs
=
dummy_inputs
.
get_dummy_processor_inputs
(
model_config
.
max_model_len
,
mm_counts
,
mm_options
=
{},
)
text_prompt
=
None
token_prompt
=
(
inputs
.
prompt
if
isinstance
(
inputs
.
prompt
,
list
)
else
tokenizer
.
encode
(
inputs
.
prompt
,
add_special_tokens
=
False
)
# Assume all Mistral models define this extra argument
mm_data
=
mm_data
,
# type: ignore[call-arg]
)
else
:
images
=
parsed_data
.
get
(
"image"
,
[])
request
=
ChatCompletionRequest
(
messages
=
[
UserMessage
(
content
=
[
TextChunk
(
text
=
""
),
*
(
ImageChunk
(
image
=
image
)
for
image
in
images
),
]
),
]
)
res
=
tokenizer
.
mistral
.
encode_chat_completion
(
request
)
# Mistral does not support decode_tokens with
# skip_special_tokens=False
text_prompt
=
None
token_prompt
=
res
.
tokens
else
:
inputs
=
dummy_inputs
.
get_dummy_processor_inputs
(
model_config
.
max_model_len
,
mm_counts
,
mm_options
=
{},
)
# Some models (e.g., Kimi-Audio) return token IDs directly instead of str
text_prompt
:
str
|
None
token_prompt
:
list
[
int
]
if
isinstance
(
inputs
.
prompt
,
list
):
text_prompt
=
None
token_prompt
=
inputs
.
prompt
else
:
assert
isinstance
(
inputs
.
prompt
,
str
)
elif
isinstance
(
inputs
.
prompt
,
str
):
text_prompt
=
inputs
.
prompt
token_prompt
=
tokenizer
.
encode
(
text_prompt
,
add_special_tokens
=
_ADD_SPECIAL_TOKENS_OVERRIDES
.
get
(
model_type
,
True
),
**
processor
.
info
.
get_default_tok_params
().
get_encode_kwargs
(
),
)
else
:
raise
TypeError
(
type
(
inputs
.
prompt
))
return
text_prompt
,
token_prompt
...
...
@@ -448,7 +405,7 @@ def test_processing_correctness(
)
if
model_id
==
"mistralai/Voxtral-Mini-4B-Realtime-2602"
:
pytest
.
skip
(
"Voxtral Realtime doesn't make use of any place-holder"
"Voxtral Realtime doesn't make use of any place-holder
"
"tokens and hence cannot pass the processing "
"correctness test as is. Let's revisit adapting this "
"test once more realtime models exist."
...
...
vllm/config/model.py
View file @
196802df
...
...
@@ -532,6 +532,22 @@ class ModelConfig:
self
.
_architecture
=
arch
logger
.
info
(
"Resolved architecture: %s"
,
arch
)
# Set default tokenizer modes based on model architecture
if
self
.
tokenizer_mode
==
"auto"
:
if
arch
==
"Grok1ForCausalLM"
:
self
.
tokenizer_mode
=
"grok2"
elif
arch
==
"MoonshotKimiaForCausalLM"
:
self
.
tokenizer_mode
=
"kimi_audio"
elif
arch
==
"QwenVLForConditionalGeneration"
:
self
.
tokenizer_mode
=
"qwen_vl"
if
self
.
tokenizer_mode
!=
"auto"
:
logger
.
info
(
"Defaulting to tokenizer_mode=%r for %s"
,
self
.
tokenizer_mode
,
arch
,
)
# Init pooler config if needed
if
self
.
runner_type
==
"pooling"
:
if
self
.
pooler_config
is
None
:
...
...
vllm/model_executor/models/kimi_audio.py
View file @
196802df
...
...
@@ -10,11 +10,13 @@ from typing import Any, ClassVar, Literal
import
numpy
as
np
import
torch
import
torch.nn
as
nn
from
huggingface_hub
import
snapshot_download
from
safetensors
import
safe_open
from
transformers
import
BatchFeature
from
transformers
import
WhisperConfig
as
HFWhisperConfig
from
vllm.config
import
ModelConfig
,
SpeechToTextConfig
,
VllmConfig
from
vllm.config.multimodal
import
BaseDummyOptions
from
vllm.inputs.data
import
PromptType
,
TokensPrompt
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.model_loader.weight_utils
import
(
...
...
@@ -47,7 +49,10 @@ from vllm.multimodal.processing import (
BaseProcessingInfo
,
PromptReplacement
,
)
from
vllm.multimodal.processing.processor
import
BaseMultiModalProcessor
from
vllm.multimodal.processing.processor
import
(
BaseMultiModalProcessor
,
ProcessorInputs
,
)
from
vllm.sequence
import
IntermediateTensors
from
vllm.tokenizers
import
cached_get_tokenizer
from
vllm.tokenizers.kimi_audio
import
KimiAudioTokenizer
...
...
@@ -59,6 +64,15 @@ from vllm.v1.sample.metadata import SamplingMetadata
KIMIA_WHISPER_SUBFOLDER
=
"whisper-large-v3"
def
_get_whisper_local_path
(
repo_id
:
str
):
if
os
.
path
.
exists
(
repo_id
):
repo_local_path
=
repo_id
else
:
repo_local_path
=
snapshot_download
(
repo_id
,
local_files_only
=
True
)
return
os
.
path
.
join
(
repo_local_path
,
KIMIA_WHISPER_SUBFOLDER
)
def
_get_feat_extract_output_lengths
(
input_lengths
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""Compute output lengths after Whisper feature extraction.
...
...
@@ -88,10 +102,10 @@ class KimiAudioWhisperEncoder(WhisperEncoder):
# Load Whisper config from subfolder (authoritative source)
# Kimi-Audio stores Whisper config in whisper-large-v3/config.json
model_path
=
vllm_config
.
model_config
.
model
whisper_config_path
=
os
.
path
.
join
(
model_path
,
KIMIA_WHISPER_SUBFOLDER
)
# Load WhisperConfig from the subfolder
whisper_config
=
HFWhisperConfig
.
from_pretrained
(
whisper_config_path
)
whisper_dir
=
_get_whisper_local_path
(
model_path
)
whisper_config
=
HFWhisperConfig
.
from_pretrained
(
whisper_dir
)
# Temporarily replace hf_config for WhisperEncoder.__init__()
original_config
=
vllm_config
.
model_config
.
hf_config
...
...
@@ -114,28 +128,18 @@ class KimiAudioWhisperEncoder(WhisperEncoder):
class
KimiAudioProcessingInfo
(
BaseProcessingInfo
):
"""Processing info for vLLM registry."""
def
get_hf_config
(
self
):
return
self
.
ctx
.
model_config
.
hf_config
def
get_hf_processor
(
self
,
**
kwargs
:
object
)
->
KimiAudioProcessor
:
"""Get KimiAudioProcessor with feature extractor and tokenizer."""
# Use vLLM's cached loader for feature extractor
feature_extractor
=
cached_feature_extractor_from_config
(
self
.
ctx
.
model_config
,
subfolder
=
KIMIA_WHISPER_SUBFOLDER
,
)
# Use vLLM's standard tokenizer loading (respects tokenizer_mode)
tokenizer
=
self
.
get_tokenizer
()
# Construct processor directly
return
KimiAudioProcessor
(
feature_extractor
=
feature_extractor
,
tokenizer
=
tokenizer
,
tokenizer
=
self
.
get_
tokenizer
()
,
)
def
get_feature_extractor
(
self
,
**
kwargs
:
object
):
"""Get feature extractor using vLLM's cached loader."""
return
cached_feature_extractor_from_config
(
self
.
ctx
.
model_config
,
subfolder
=
KIMIA_WHISPER_SUBFOLDER
)
...
...
@@ -144,26 +148,16 @@ class KimiAudioProcessingInfo(BaseProcessingInfo):
return
{
"audio"
:
1
}
def
get_data_parser
(
self
)
->
"KimiAudioMultiModalDataParser"
:
"""Get data parser for audio inputs."""
feature_extractor
=
self
.
get_feature_extractor
()
return
KimiAudioMultiModalDataParser
(
target_sr
=
feature_extractor
.
sampling_rate
,
expected_hidden_size
=
self
.
_get_expected_hidden_size
(),
)
class
KimiAudioDummyInputsBuilder
(
BaseDummyInputsBuilder
[
KimiAudioProcessingInfo
]):
"""Dummy inputs builder for vLLM registry."""
def
get_dummy_text
(
self
,
mm_counts
:
Mapping
[
str
,
int
])
->
list
[
int
]:
"""Return dummy text as token IDs directly."""
num_audios
=
mm_counts
.
get
(
"audio"
,
0
)
if
num_audios
==
0
:
return
[
198
]
# "Transcribe" tokenized
# Return as token IDs directly to avoid tokenizer issues
return
[
KimiAudioProcessor
.
KIMIA_MEDIA_BEGIN
,
KimiAudioProcessor
.
KIMIA_TEXT_BLANK
,
KimiAudioProcessor
.
KIMIA_MEDIA_END
,
]
*
num_audios
def
get_dummy_text
(
self
,
mm_counts
:
Mapping
[
str
,
int
])
->
str
:
return
""
def
get_dummy_mm_data
(
self
,
...
...
@@ -186,6 +180,29 @@ class KimiAudioDummyInputsBuilder(BaseDummyInputsBuilder[KimiAudioProcessingInfo
),
}
def
get_dummy_processor_inputs
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
mm_options
:
Mapping
[
str
,
BaseDummyOptions
],
)
->
ProcessorInputs
:
dummy_mm_data
=
self
.
get_dummy_mm_data
(
seq_len
,
mm_counts
,
mm_options
)
dummy_mm_items
=
self
.
info
.
parse_mm_data
(
dummy_mm_data
)
num_audios
=
mm_counts
.
get
(
"audio"
,
0
)
dummy_tokens
=
(
[
198
]
if
num_audios
==
0
else
[
KimiAudioProcessor
.
KIMIA_MEDIA_BEGIN
,
KimiAudioProcessor
.
KIMIA_TEXT_BLANK
,
KimiAudioProcessor
.
KIMIA_MEDIA_END
,
]
*
num_audios
)
return
ProcessorInputs
(
prompt
=
dummy_tokens
,
mm_data_items
=
dummy_mm_items
)
# Field config for Kimi-Audio multimodal data
_KIMIAUDIO_FIELD_CONFIG
=
{
...
...
@@ -197,10 +214,6 @@ _KIMIAUDIO_FIELD_CONFIG = {
class
KimiAudioMultiModalDataParser
(
MultiModalDataParser
):
"""Custom data parser for Kimi-Audio multimodal data."""
def
__init__
(
self
,
**
kwargs
):
# Whisper expects 16kHz audio
super
().
__init__
(
target_sr
=
16000
,
**
kwargs
)
def
_parse_audio_data
(
self
,
data
:
dict
[
str
,
torch
.
Tensor
]
|
ModalityData
[
AudioItem
],
...
...
@@ -589,9 +602,8 @@ class KimiAudioForConditionalGeneration(
loaded
=
loader
.
load_weights
(
main_weights
,
mapper
=
self
.
hf_to_vllm_mapper
)
# Load Whisper encoder weights from subfolder
whisper_path
=
os
.
path
.
join
(
self
.
model_path
,
f
"
{
KIMIA_WHISPER_SUBFOLDER
}
/model.safetensors"
)
whisper_dir
=
_get_whisper_local_path
(
self
.
model_path
)
whisper_path
=
os
.
path
.
join
(
whisper_dir
,
"model.safetensors"
)
if
os
.
path
.
exists
(
whisper_path
):
whisper_loaded
=
self
.
_load_whisper_weights_from_file
(
whisper_path
)
loaded
.
update
(
whisper_loaded
)
...
...
vllm/model_executor/models/mllama4.py
View file @
196802df
...
...
@@ -63,12 +63,10 @@ from vllm.multimodal.processing import (
BaseDummyInputsBuilder
,
BaseMultiModalProcessor
,
BaseProcessingInfo
,
InputProcessingContext
,
PromptReplacement
,
PromptUpdate
,
PromptUpdateDetails
,
)
from
vllm.renderers
import
TokenizeParams
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
...
...
@@ -546,9 +544,6 @@ class Llama4VisionModel(nn.Module):
class
Mllama4ProcessingInfo
(
BaseProcessingInfo
):
def
__init__
(
self
,
ctx
:
InputProcessingContext
)
->
None
:
super
().
__init__
(
ctx
)
def
get_hf_config
(
self
)
->
Llama4Config
:
return
self
.
ctx
.
get_hf_config
(
Llama4Config
)
...
...
@@ -557,9 +552,6 @@ class Mllama4ProcessingInfo(BaseProcessingInfo):
Llama4Processor
,
use_fast
=
kwargs
.
pop
(
"use_fast"
,
True
),
**
kwargs
)
def
get_default_tok_params
(
self
)
->
TokenizeParams
:
return
super
().
get_default_tok_params
().
with_kwargs
(
add_special_tokens
=
False
)
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
int
|
None
]:
# Although vLLM can support more images from an infra capability
# perspective, we do not recommend using >10 images in practice.
...
...
@@ -597,10 +589,6 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo])
mm_kwargs
:
Mapping
[
str
,
object
],
tok_kwargs
:
Mapping
[
str
,
object
],
)
->
BatchFeature
:
tokenizer
=
self
.
info
.
get_tokenizer
()
if
mm_data
is
None
:
return
tokenizer
(
prompt
,
add_special_tokens
=
False
)
# exclude bos
processed_outputs
=
super
().
_call_hf_processor
(
prompt
=
prompt
,
mm_data
=
mm_data
,
...
...
vllm/model_executor/models/pixtral.py
View file @
196802df
...
...
@@ -172,12 +172,20 @@ class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]):
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
mm_options
:
Mapping
[
str
,
BaseDummyOptions
],
mm_data
:
MultiModalDataDict
|
None
=
None
,
)
->
ProcessorInputs
:
tokenizer
=
self
.
info
.
get_tokenizer
()
dummy_text
=
self
.
get_dummy_text
(
mm_counts
)
dummy_mm_data
=
self
.
get_dummy_mm_data
(
seq_len
,
mm_counts
,
mm_options
)
dummy_images
=
dummy_mm_data
.
get
(
"image"
,
[])
dummy_mm_data
=
(
self
.
get_dummy_mm_data
(
seq_len
,
mm_counts
,
mm_options
)
if
mm_data
is
None
else
mm_data
)
dummy_mm_items
=
self
.
info
.
parse_mm_data
(
dummy_mm_data
)
dummy_images
=
(
[]
if
"image"
not
in
dummy_mm_data
else
dummy_mm_items
[
"image"
].
get_all
()
)
request
=
ChatCompletionRequest
(
messages
=
[
...
...
@@ -192,8 +200,6 @@ class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]):
res
=
tokenizer
.
mistral
.
encode_chat_completion
(
request
)
dummy_tokens
=
res
.
tokens
dummy_mm_items
=
self
.
info
.
parse_mm_data
(
dummy_mm_data
)
return
ProcessorInputs
(
prompt
=
dummy_tokens
,
mm_data_items
=
dummy_mm_items
)
...
...
vllm/model_executor/models/voxtral.py
View file @
196802df
...
...
@@ -150,13 +150,21 @@ class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]):
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
mm_options
:
Mapping
[
str
,
BaseDummyOptions
],
mm_data
:
MultiModalDataDict
|
None
=
None
,
)
->
ProcessorInputs
:
tokenizer
=
self
.
info
.
get_tokenizer
()
feature_extractor
=
self
.
info
.
get_hf_processor
().
feature_extractor
dummy_text
=
self
.
get_dummy_text
(
mm_counts
)
dummy_mm_data
=
self
.
get_dummy_mm_data
(
seq_len
,
mm_counts
,
mm_options
)
dummy_audios
=
dummy_mm_data
.
get
(
"audio"
,
[])
dummy_mm_data
=
(
self
.
get_dummy_mm_data
(
seq_len
,
mm_counts
,
mm_options
)
if
mm_data
is
None
else
mm_data
)
dummy_mm_items
=
self
.
info
.
parse_mm_data
(
dummy_mm_data
)
dummy_audios
=
(
[]
if
"audio"
not
in
dummy_mm_data
else
dummy_mm_items
[
"audio"
].
get_all
()
)
audio_chunks
:
list
[
AudioChunk
]
=
[]
format
=
"wav"
...
...
vllm/renderers/qwen_vl.py
View file @
196802df
...
...
@@ -6,11 +6,10 @@ from vllm.config import VllmConfig
from
vllm.tokenizers
import
cached_get_tokenizer
from
vllm.tokenizers.qwen_vl
import
QwenVLTokenizer
from
.base
import
BaseRenderer
from
.hf
import
HfRenderer
class
QwenVLRenderer
(
Base
Renderer
[
QwenVLTokenizer
]
):
class
QwenVLRenderer
(
Hf
Renderer
):
@
classmethod
def
from_config
(
# type: ignore[override]
cls
,
...
...
vllm/renderers/registry.py
View file @
196802df
...
...
@@ -80,13 +80,6 @@ def renderer_from_config(config: "VllmConfig", **kwargs):
model_config
,
**
kwargs
)
# Override tokenizer_mode for Kimi-Audio models
if
model_config
.
architecture
==
"MoonshotKimiaForCausalLM"
:
tokenizer_mode
=
"kimi_audio"
# Update model_config so other components (e.g., multimodal registry)
# also use the correct tokenizer mode
model_config
.
tokenizer_mode
=
"kimi_audio"
if
(
model_config
.
tokenizer_mode
==
"auto"
and
model_config
.
model_impl
==
"terratorch"
...
...
vllm/tokenizers/registry.py
View file @
196802df
...
...
@@ -159,18 +159,6 @@ def resolve_tokenizer_args(
):
tokenizer_mode
=
"mistral"
# Try to use Grok2 tiktoken tokenizer if possible
if
tokenizer_mode
==
"auto"
and
any_pattern_in_repo_files
(
model_name_or_path
=
str
(
tokenizer_name
),
allow_patterns
=
[
"tokenizer.tok.json"
],
revision
=
revision
,
):
tokenizer_mode
=
"grok2"
# Model-specific tokenizers
if
tokenizer_mode
==
"auto"
and
"/Qwen-VL"
in
str
(
tokenizer_name
):
tokenizer_mode
=
"qwen_vl"
# Fallback to HF tokenizer
if
tokenizer_mode
==
"auto"
:
tokenizer_mode
=
"hf"
...
...
vllm/transformers_utils/processors/glm4v.py
View file @
196802df
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from
# https://github.com/zai-org/CogAgent
from
transformers
import
PreTrainedTokenizer
from
transformers.image_processing_utils_fast
import
BaseImageProcessorFast
from
transformers.image_utils
import
PILImageResampling
...
...
vllm/transformers_utils/processors/kimi_audio.py
View file @
196802df
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa
# mypy: ignore-errors
# coding=utf-8
# Copyright 2026 The Moonshot AI team and the HuggingFace Inc. team. All rights reserved.
# Copyright 2026 The Moonshot AI team and the HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -19,42 +17,13 @@
# limitations under the License.
"""Processor for Kimi-Audio ASR model."""
from
collections.abc
import
Mapping
from
typing
import
Any
import
numpy
as
np
import
torch
from
transformers
import
AutoFeatureExtractor
,
BatchFeature
,
ProcessorMixin
from
transformers
import
BatchFeature
,
ProcessorMixin
from
transformers.audio_utils
import
AudioInput
from
transformers.tokenization_utils_base
import
TextInput
from
vllm.tokenizers.kimi_audio
import
KimiAudioTokenizer
def
_get_feat_extract_output_lengths
(
input_lengths
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""Compute output lengths after Whisper feature extraction."""
input_lengths_leave
=
input_lengths
%
100
feat_lengths
=
(
input_lengths_leave
-
1
)
//
2
+
1
output_lengths
=
(
((
feat_lengths
-
1
)
//
2
+
1
-
1
)
//
2
+
1
+
(
input_lengths
//
100
)
*
13
)
return
output_lengths
from
transformers.tokenization_utils_base
import
PreTokenizedInput
,
TextInput
class
KimiAudioProcessor
(
ProcessorMixin
):
r
"""
Constructs a Kimi-Audio processor.
[`KimiAudioProcessor`] offers all the functionalities of [`WhisperFeatureExtractor`], and a tokenizer.
See the [`~KimiAudioProcessor.__call__`] and [`~KimiAudioProcessor.decode`] for more information.
Args:
feature_extractor ([`WhisperFeatureExtractor`], *optional*):
The audio feature extractor.
tokenizer ([`PreTrainedTokenizer`], *optional*):
The text tokenizer.
"""
# Required for ProcessorMixin
attributes
=
[
"feature_extractor"
,
"tokenizer"
]
feature_extractor_class
=
"AutoFeatureExtractor"
...
...
@@ -69,44 +38,30 @@ class KimiAudioProcessor(ProcessorMixin):
AUDIO_SEQ_LEN
:
int
=
376
def
__init__
(
self
,
feature_extractor
=
None
,
tokenizer
=
None
,
**
kwargs
):
# Pass feature_extractor and tokenizer to parent ProcessorMixin
super
().
__init__
(
feature_extractor
=
feature_extractor
,
tokenizer
=
tokenizer
,
**
kwargs
,
)
def
check_argument_for_proper_class
(
self
,
attribute_name
:
str
,
argument
:
Any
):
"""Override to skip class validation for custom tokenizer."""
# Skip validation for tokenizer since KimiAudioTokenizer doesn't inherit
# from PreTrainedTokenizerBase but is compatible
if
attribute_name
==
"tokenizer"
and
argument
is
not
None
:
return
# For other attributes, use default validation
super
().
check_argument_for_proper_class
(
attribute_name
,
argument
)
self
.
feature_extractor
=
feature_extractor
self
.
tokenizer
=
tokenizer
def
__call__
(
self
,
text
:
TextInput
=
None
,
audio
:
AudioInput
=
None
,
text
:
TextInput
|
PreTokenizedInput
|
list
[
TextInput
]
|
list
[
PreTokenizedInput
]
|
None
=
None
,
audio
:
AudioInput
|
None
=
None
,
return_tensors
:
str
=
"pt"
,
**
kwargs
,
)
->
BatchFeature
:
"""
Main method to prepare for the model one or several sequences(s) and audio(s).
if
text
is
not
None
:
if
not
isinstance
(
text
,
list
):
text
=
[
text
]
Args:
text (`str`, `List[str]`):
The sequence or batch of sequences to be encoded.
audio (`np.ndarray`, `List[np.ndarray]`):
The audio or batch of audio to be prepared. Each audio can be a NumPy array.
return_tensors (`str`):
The type of tensors to return ("pt", "np", etc.)
"""
if
text
is
None
:
raise
ValueError
(
"You need to specify either a `text` input to process."
)
text_inputs
=
self
.
tokenizer
(
text
,
return_tensors
=
return_tensors
,
padding
=
True
)
else
:
text_inputs
=
{}
# Process audio if provided
if
audio
is
not
None
:
# Ensure audio is a list
if
isinstance
(
audio
,
np
.
ndarray
):
...
...
@@ -144,19 +99,6 @@ class KimiAudioProcessor(ProcessorMixin):
else
:
audio_inputs
=
{}
# Handle text input - can be string or token IDs from vLLM processor
if
isinstance
(
text
,
list
)
and
len
(
text
)
>
0
and
isinstance
(
text
[
0
],
int
):
# Text is already token IDs (from vLLM processor) - just wrap
text_inputs
=
{
"input_ids"
:
torch
.
tensor
([
text
],
dtype
=
torch
.
long
)}
else
:
# Text is string - tokenize
if
not
isinstance
(
text
,
list
):
text
=
[
text
]
text_inputs
=
self
.
tokenizer
(
text
,
return_tensors
=
return_tensors
,
padding
=
True
)
return
BatchFeature
(
data
=
{
**
text_inputs
,
**
audio_inputs
},
tensor_type
=
return_tensors
,
...
...
vllm/transformers_utils/processors/qwen_vl.py
View file @
196802df
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from
# https://huggingface.co/Qwen/Qwen-VL/blob/main/modeling_qwen.py
# Copyright (c) Alibaba Cloud.
from
transformers.image_processing_utils_fast
import
BaseImageProcessorFast
from
transformers.image_utils
import
PILImageResampling
from
transformers.processing_utils
import
ProcessorMixin
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment