Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
de42abb3
Unverified
Commit
de42abb3
authored
Feb 13, 2026
by
Andreas Karatzas
Committed by
GitHub
Feb 13, 2026
Browse files
[CI] Heavy refactoring of Voxtral multimodal audio model tests (#34294)
Signed-off-by:
Andreas Karatzas
<
akaratza@amd.com
>
parent
60ca7981
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
350 additions
and
70 deletions
+350
-70
requirements/rocm-test.txt
requirements/rocm-test.txt
+2
-0
tests/conftest.py
tests/conftest.py
+0
-2
tests/models/multimodal/generation/test_voxtral.py
tests/models/multimodal/generation/test_voxtral.py
+138
-43
tests/models/multimodal/generation/test_voxtral_realtime.py
tests/models/multimodal/generation/test_voxtral_realtime.py
+6
-2
tests/models/multimodal/generation/vlm_utils/model_utils.py
tests/models/multimodal/generation/vlm_utils/model_utils.py
+88
-0
tests/models/multimodal/processing/test_common.py
tests/models/multimodal/processing/test_common.py
+34
-14
vllm/model_executor/models/voxtral.py
vllm/model_executor/models/voxtral.py
+28
-0
vllm/model_executor/models/whisper_causal.py
vllm/model_executor/models/whisper_causal.py
+44
-4
vllm/reasoning/mistral_reasoning_parser.py
vllm/reasoning/mistral_reasoning_parser.py
+2
-2
vllm/tokenizers/mistral.py
vllm/tokenizers/mistral.py
+1
-1
vllm/v1/attention/backends/rocm_aiter_fa.py
vllm/v1/attention/backends/rocm_aiter_fa.py
+7
-2
No files found.
requirements/rocm-test.txt
View file @
de42abb3
...
@@ -96,3 +96,5 @@ albumentations==1.4.6
...
@@ -96,3 +96,5 @@ albumentations==1.4.6
transformers==4.57.3
transformers==4.57.3
# Pin HF Hub version
# Pin HF Hub version
huggingface-hub==0.36.2
huggingface-hub==0.36.2
# Pin Mistral Common
mistral-common[image,audio]==1.9.1
tests/conftest.py
View file @
de42abb3
...
@@ -419,7 +419,6 @@ class HfRunner:
...
@@ -419,7 +419,6 @@ class HfRunner:
self
.
tokenizer
:
"PreTrainedTokenizer | PreTrainedTokenizerFast"
=
(
self
.
tokenizer
:
"PreTrainedTokenizer | PreTrainedTokenizerFast"
=
(
AutoTokenizer
.
from_pretrained
(
AutoTokenizer
.
from_pretrained
(
model_name
,
model_name
,
dtype
=
dtype
,
trust_remote_code
=
trust_remote_code
,
trust_remote_code
=
trust_remote_code
,
)
)
)
)
...
@@ -430,7 +429,6 @@ class HfRunner:
...
@@ -430,7 +429,6 @@ class HfRunner:
self
.
processor
=
AutoProcessor
.
from_pretrained
(
self
.
processor
=
AutoProcessor
.
from_pretrained
(
model_name
,
model_name
,
dtype
=
dtype
,
trust_remote_code
=
trust_remote_code
,
trust_remote_code
=
trust_remote_code
,
)
)
if
skip_tokenizer_init
:
if
skip_tokenizer_init
:
...
...
tests/models/multimodal/generation/test_voxtral.py
View file @
de42abb3
...
@@ -4,16 +4,18 @@
...
@@ -4,16 +4,18 @@
import
json
import
json
import
pytest
import
pytest
import
pytest_asyncio
from
mistral_common.audio
import
Audio
from
mistral_common.audio
import
Audio
from
mistral_common.protocol.instruct.chunk
import
AudioChunk
,
RawAudio
,
TextChunk
from
mistral_common.protocol.instruct.chunk
import
AudioChunk
,
RawAudio
,
TextChunk
from
mistral_common.protocol.instruct.messages
import
UserMessage
from
mistral_common.protocol.instruct.messages
import
UserMessage
from
transformers
import
VoxtralForConditionalGeneration
from
vllm.tokenizers.mistral
import
MistralTokenizer
from
vllm.tokenizers.mistral
import
MistralTokenizer
from
....conftest
import
AudioTestAssets
from
....conftest
import
AudioTestAssets
from
....utils
import
RemoteOpenAIServer
from
....utils
import
RemoteOpenAIServer
from
...utils
import
check_logprobs_close
from
.test_ultravox
import
MULTI_AUDIO_PROMPT
,
run_multi_audio_test
from
.test_ultravox
import
MULTI_AUDIO_PROMPT
,
run_multi_audio_test
from
.vlm_utils
import
model_utils
MODEL_NAME
=
"mistralai/Voxtral-Mini-3B-2507"
MODEL_NAME
=
"mistralai/Voxtral-Mini-3B-2507"
MISTRAL_FORMAT_ARGS
=
[
MISTRAL_FORMAT_ARGS
=
[
...
@@ -26,40 +28,21 @@ MISTRAL_FORMAT_ARGS = [
...
@@ -26,40 +28,21 @@ MISTRAL_FORMAT_ARGS = [
]
]
@
pytest
.
fixture
()
def
_get_prompt
(
audio_assets
:
AudioTestAssets
,
question
:
str
)
->
list
[
int
]:
def
server
(
request
,
audio_assets
:
AudioTestAssets
):
"""Build a token-ID prompt via mistral_common for vLLM offline inference."""
args
=
[
"--enforce-eager"
,
"--limit-mm-per-prompt"
,
json
.
dumps
({
"audio"
:
len
(
audio_assets
)}),
]
+
MISTRAL_FORMAT_ARGS
with
RemoteOpenAIServer
(
MODEL_NAME
,
args
,
env_dict
=
{
"VLLM_AUDIO_FETCH_TIMEOUT"
:
"30"
}
)
as
remote_server
:
yield
remote_server
@
pytest_asyncio
.
fixture
async
def
client
(
server
):
async
with
server
.
get_async_client
()
as
async_client
:
yield
async_client
def
_get_prompt
(
audio_assets
,
question
):
tokenizer
=
MistralTokenizer
.
from_pretrained
(
MODEL_NAME
)
tokenizer
=
MistralTokenizer
.
from_pretrained
(
MODEL_NAME
)
audios
=
[
audios
=
[
Audio
.
from_file
(
str
(
a
udio_assets
[
i
]
.
get_local_path
()),
strict
=
False
)
Audio
.
from_file
(
str
(
a
sset
.
get_local_path
()),
strict
=
False
)
for
i
in
range
(
len
(
audio_assets
))
for
asset
in
audio_assets
]
]
audio_chunks
=
[
audio_chunks
=
[
AudioChunk
(
input_audio
=
RawAudio
.
from_audio
(
audio
))
for
audio
in
audios
AudioChunk
(
input_audio
=
RawAudio
.
from_audio
(
audio
))
for
audio
in
audios
]
]
text_chunk
=
TextChunk
(
text
=
question
)
messages
=
[
messages
=
[
UserMessage
(
content
=
[
*
audio_chunks
,
t
ext
_c
hunk
]).
to_openai
()
]
UserMessage
(
content
=
[
*
audio_chunks
,
T
ext
C
hunk
(
text
=
question
)
]).
to_openai
()
]
return
tokenizer
.
apply_chat_template
(
messages
=
messages
)
return
tokenizer
.
apply_chat_template
(
messages
=
messages
)
...
@@ -77,7 +60,7 @@ def test_models_with_multiple_audios(
...
@@ -77,7 +60,7 @@ def test_models_with_multiple_audios(
vllm_prompt
=
_get_prompt
(
audio_assets
,
MULTI_AUDIO_PROMPT
)
vllm_prompt
=
_get_prompt
(
audio_assets
,
MULTI_AUDIO_PROMPT
)
run_multi_audio_test
(
run_multi_audio_test
(
vllm_runner
,
vllm_runner
,
[(
vllm_prompt
,
[
a
udio
.
audio_and_sample_rate
for
a
udio
in
audio_assets
])],
[(
vllm_prompt
,
[
a
.
audio_and_sample_rate
for
a
in
audio_assets
])],
# type: ignore[list-item]
MODEL_NAME
,
MODEL_NAME
,
dtype
=
dtype
,
dtype
=
dtype
,
max_tokens
=
max_tokens
,
max_tokens
=
max_tokens
,
...
@@ -86,30 +69,142 @@ def test_models_with_multiple_audios(
...
@@ -86,30 +69,142 @@ def test_models_with_multiple_audios(
)
)
@
pytest
.
mark
.
asyncio
def
test_online_serving
(
vllm_runner
,
audio_assets
:
AudioTestAssets
):
async
def
test_online_serving
(
client
,
audio_assets
:
AudioTestAssets
):
"""Two-layer accuracy and serving validation using Mistral format.
"""Exercises online serving with/without chunked prefill enabled."""
1. Offline vLLM greedy output (runs first to avoid CUDA fork issues
with multiprocessing - see vlm_utils/core.py).
2. Online OpenAI-compatible API output must match offline — validates
that the serving path (chat template, audio encoding, tokenization)
does not corrupt anything.
Steps run sequentially so each releases the GPU before the next starts.
"""
def
asset_to_chunk
(
asset
):
question
=
f
"What's happening in these
{
len
(
audio_assets
)
}
audio clips?"
max_tokens
=
10
audio_data
=
[
asset
.
audio_and_sample_rate
for
asset
in
audio_assets
]
vllm_prompt
=
_get_prompt
(
audio_assets
,
question
)
with
vllm_runner
(
MODEL_NAME
,
dtype
=
"half"
,
enforce_eager
=
True
,
tokenizer_mode
=
"mistral"
,
config_format
=
"mistral"
,
load_format
=
"mistral"
,
limit_mm_per_prompt
=
{
"audio"
:
len
(
audio_assets
)},
)
as
vllm_model
:
offline_outputs
=
vllm_model
.
generate_greedy
(
[
vllm_prompt
],
max_tokens
,
audios
=
[
audio_data
],
)
offline_text
=
offline_outputs
[
0
][
1
]
assert
offline_text
,
"Offline vLLM inference produced empty output"
def
_asset_to_openai_chunk
(
asset
):
audio
=
Audio
.
from_file
(
str
(
asset
.
get_local_path
()),
strict
=
False
)
audio
=
Audio
.
from_file
(
str
(
asset
.
get_local_path
()),
strict
=
False
)
audio
.
format
=
"wav"
audio
.
format
=
"wav"
audio_dict
=
AudioChunk
.
from_audio
(
audio
).
to_openai
()
return
AudioChunk
.
from_audio
(
audio
).
to_openai
()
return
audio_dict
audio_chunks
=
[
asset_to_chunk
(
asset
)
for
asset
in
audio_assets
]
text
=
f
"What's happening in these
{
len
(
audio_assets
)
}
audio clips?"
messages
=
[
messages
=
[
{
{
"role"
:
"user"
,
"role"
:
"user"
,
"content"
:
[
*
audio_chunks
,
{
"type"
:
"text"
,
"text"
:
text
}],
"content"
:
[
*
[
_asset_to_openai_chunk
(
a
)
for
a
in
audio_assets
],
{
"type"
:
"text"
,
"text"
:
question
},
],
}
}
]
]
chat_completion
=
await
client
.
chat
.
completions
.
create
(
server_args
=
[
model
=
MODEL_NAME
,
messages
=
messages
,
max_tokens
=
10
"--enforce-eager"
,
"--limit-mm-per-prompt"
,
json
.
dumps
({
"audio"
:
len
(
audio_assets
)}),
*
MISTRAL_FORMAT_ARGS
,
]
with
RemoteOpenAIServer
(
MODEL_NAME
,
server_args
,
env_dict
=
{
"VLLM_AUDIO_FETCH_TIMEOUT"
:
"30"
},
)
as
remote_server
:
client
=
remote_server
.
get_client
()
completion
=
client
.
chat
.
completions
.
create
(
model
=
MODEL_NAME
,
messages
=
messages
,
max_tokens
=
max_tokens
,
temperature
=
0
,
)
)
assert
len
(
chat_completion
.
choices
)
==
1
assert
len
(
completion
.
choices
)
==
1
choice
=
chat_completion
.
choices
[
0
]
choice
=
completion
.
choices
[
0
]
assert
choice
.
message
.
content
==
"In the first audio clip, you hear a brief"
assert
choice
.
finish_reason
==
"length"
assert
choice
.
finish_reason
==
"length"
assert
choice
.
message
.
content
==
offline_text
,
(
f
"Online serving output does not match offline inference.
\n
"
f
" Online:
{
choice
.
message
.
content
!
r
}
\n
"
f
" Offline:
{
offline_text
!
r
}
"
)
def
test_hf_reference
(
hf_runner
,
vllm_runner
,
audio_assets
:
AudioTestAssets
):
"""Compare vLLM Mistral-format output against HF Transformers reference.
Instead of requiring an exact text match (which is brittle across
attention backends), we compare per-token logprobs using the standard
check_logprobs_close helper: when tokens diverge at a position, each
runner's chosen token must appear in the other's top-k logprobs.
Marked xfail(strict=False) so remaining edge-case mismatches
don't block CI.
"""
question
=
f
"What's happening in these
{
len
(
audio_assets
)
}
audio clips?"
max_tokens
=
10
num_logprobs
=
5
audio_data
=
[
asset
.
audio_and_sample_rate
for
asset
in
audio_assets
]
vllm_prompt
=
_get_prompt
(
audio_assets
,
question
)
with
vllm_runner
(
MODEL_NAME
,
dtype
=
"half"
,
enforce_eager
=
True
,
tokenizer_mode
=
"mistral"
,
config_format
=
"mistral"
,
load_format
=
"mistral"
,
limit_mm_per_prompt
=
{
"audio"
:
len
(
audio_assets
)},
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
generate_greedy_logprobs
(
[
vllm_prompt
],
max_tokens
,
num_logprobs
,
audios
=
[
audio_data
],
)
assert
vllm_outputs
[
0
][
1
],
"vLLM inference produced empty output"
with
hf_runner
(
MODEL_NAME
,
dtype
=
"half"
,
auto_cls
=
VoxtralForConditionalGeneration
,
)
as
hf_model
:
hf_model
=
model_utils
.
voxtral_patch_hf_runner
(
hf_model
)
hf_outputs
=
hf_model
.
generate_greedy_logprobs_limit
(
[
question
],
max_tokens
,
num_logprobs
,
audios
=
[
audio_data
],
)
assert
hf_outputs
[
0
][
1
],
"HF Transformers produced empty output"
print
(
f
"HF Reference Comparison
\n
"
f
" vLLM:
{
vllm_outputs
[
0
][
1
]
!
r
}
\n
"
f
" HF:
{
hf_outputs
[
0
][
1
]
!
r
}
"
)
check_logprobs_close
(
outputs_0_lst
=
vllm_outputs
,
outputs_1_lst
=
hf_outputs
,
name_0
=
"vllm"
,
name_1
=
"hf"
,
)
tests/models/multimodal/generation/test_voxtral_realtime.py
View file @
de42abb3
...
@@ -10,6 +10,7 @@ from mistral_common.protocol.transcription.request import (
...
@@ -10,6 +10,7 @@ from mistral_common.protocol.transcription.request import (
TranscriptionRequest
,
TranscriptionRequest
,
)
)
from
mistral_common.tokens.tokenizers.mistral
import
MistralTokenizer
from
mistral_common.tokens.tokenizers.mistral
import
MistralTokenizer
from
mistral_common.tokens.tokenizers.tekken
import
SpecialTokenPolicy
from
vllm
import
LLM
,
EngineArgs
,
SamplingParams
from
vllm
import
LLM
,
EngineArgs
,
SamplingParams
from
vllm.assets.audio
import
AudioAsset
from
vllm.assets.audio
import
AudioAsset
...
@@ -26,7 +27,7 @@ ENGINE_CONFIG = dict(
...
@@ -26,7 +27,7 @@ ENGINE_CONFIG = dict(
load_format
=
"mistral"
,
load_format
=
"mistral"
,
tokenizer_mode
=
"mistral"
,
tokenizer_mode
=
"mistral"
,
enforce_eager
=
True
,
enforce_eager
=
True
,
gpu_memory_utilization
=
0.
4
,
gpu_memory_utilization
=
0.
9
,
)
)
...
@@ -148,6 +149,9 @@ async def test_voxtral_realtime_generator(audio_assets, tokenizer, async_engine)
...
@@ -148,6 +149,9 @@ async def test_voxtral_realtime_generator(audio_assets, tokenizer, async_engine)
output_tokens_list
.
append
(
output_tokens
)
output_tokens_list
.
append
(
output_tokens
)
texts
=
[
tokenizer
.
decode
(
output_tokens
)
for
output_tokens
in
output_tokens_list
]
texts
=
[
tokenizer
.
decode
(
output_tokens
,
special_token_policy
=
SpecialTokenPolicy
.
IGNORE
)
for
output_tokens
in
output_tokens_list
]
texts
[
1
]
=
texts
[
1
].
replace
(
"a base hit"
,
"OBS"
).
replace
(
"oh my"
,
"oh, my"
)
texts
[
1
]
=
texts
[
1
].
replace
(
"a base hit"
,
"OBS"
).
replace
(
"oh my"
,
"oh, my"
)
assert
texts
==
EXPECTED_TEXT
assert
texts
==
EXPECTED_TEXT
tests/models/multimodal/generation/vlm_utils/model_utils.py
View file @
de42abb3
...
@@ -1215,3 +1215,91 @@ def tarsier_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
...
@@ -1215,3 +1215,91 @@ def tarsier_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
hf_processor
.
patch_size
=
vision_encoder_info
.
get_patch_size
()
hf_processor
.
patch_size
=
vision_encoder_info
.
get_patch_size
()
return
hf_model
return
hf_model
def
voxtral_patch_hf_runner
(
hf_model
:
"HfRunner"
)
->
"HfRunner"
:
"""Patch HfRunner for Voxtral's conversation-based processor.
Two issues in HfRunner require patching:
1. VoxtralProcessor requires ``apply_chat_template()`` with conversation
dicts (accepting ``url``, ``path``, or ``base64`` audio) rather than
the standard ``processor(text=, audio=, sampling_rate=)`` interface.
2. HfRunner.get_inputs cannot handle multi-audio per prompt because it
mis-unpacks ``[(arr1, sr1), (arr2, sr2)]`` via a ``len == 2`` check.
We override ``get_inputs`` to build conversation dicts and call
``apply_chat_template`` directly, bypassing both issues. We also wrap
``model.generate`` to strip prompt tokens before decoding, since
HfRunner.generate calls batch_decode on the full sequence (prompt +
generated).
"""
import
base64
import
io
import
soundfile
as
sf
processor
=
hf_model
.
processor
def
_audio_to_base64
(
audio_array
,
sample_rate
:
int
)
->
str
:
"""Encode a numpy audio array as a base64 WAV string."""
buf
=
io
.
BytesIO
()
sf
.
write
(
buf
,
audio_array
,
int
(
sample_rate
),
format
=
"WAV"
)
return
base64
.
b64encode
(
buf
.
getvalue
()).
decode
(
"ascii"
)
def
patched_get_inputs
(
prompts
,
images
=
None
,
videos
=
None
,
audios
=
None
,
**
kwargs
):
all_inputs
=
[]
for
i
,
prompt
in
enumerate
(
prompts
):
content
:
list
[
dict
]
=
[]
if
audios
is
not
None
and
audios
[
i
]
is
not
None
:
items
=
audios
[
i
]
if
not
isinstance
(
items
,
list
):
items
=
[
items
]
for
item
in
items
:
if
isinstance
(
item
,
(
list
,
tuple
))
and
len
(
item
)
==
2
:
arr
,
sr
=
item
else
:
arr
,
sr
=
item
,
16_000
content
.
append
(
{
"type"
:
"audio"
,
"base64"
:
_audio_to_base64
(
arr
,
sr
),
}
)
content
.
append
({
"type"
:
"text"
,
"text"
:
prompt
})
inputs
=
processor
.
apply_chat_template
(
[{
"role"
:
"user"
,
"content"
:
content
}]
)
if
hasattr
(
inputs
,
"to"
):
inputs
=
inputs
.
to
(
dtype
=
hf_model
.
dtype
)
all_inputs
.
append
(
inputs
)
return
all_inputs
_orig_generate
=
hf_model
.
model
.
generate
def
patched_generate
(
*
args
,
**
kwargs
):
"""Strip prompt tokens so only generated tokens are decoded."""
input_ids
=
kwargs
.
get
(
"input_ids"
)
if
input_ids
is
None
and
args
:
input_ids
=
args
[
0
]
prompt_len
=
input_ids
.
shape
[
1
]
if
input_ids
is
not
None
else
0
output
=
_orig_generate
(
*
args
,
**
kwargs
)
if
prompt_len
:
if
isinstance
(
output
,
torch
.
Tensor
):
output
=
output
[:,
prompt_len
:]
else
:
# GenerateDecoderOnlyOutput - trim sequences but preserve
# scores/logits so generate_greedy_logprobs_limit can
# extract per-token logprobs.
output
.
sequences
=
output
.
sequences
[:,
prompt_len
:]
return
output
hf_model
.
get_inputs
=
patched_get_inputs
# type: ignore[method-assign, assignment]
hf_model
.
model
.
generate
=
patched_generate
# type: ignore[method-assign]
return
hf_model
tests/models/multimodal/processing/test_common.py
View file @
de42abb3
...
@@ -184,6 +184,25 @@ def get_text_token_prompts(
...
@@ -184,6 +184,25 @@ def get_text_token_prompts(
text_prompt
:
str
|
None
text_prompt
:
str
|
None
token_prompt
:
list
[
int
]
token_prompt
:
list
[
int
]
if
isinstance
(
tokenizer
,
MistralTokenizer
):
if
isinstance
(
tokenizer
,
MistralTokenizer
):
# ChatCompletionRequest only supports ImageChunk natively;
# for other modalities (e.g. audio), fall back to the model's
# own dummy inputs builder which knows the right placeholders.
has_non_image
=
any
(
k
!=
"image"
and
count
>
0
for
k
,
count
in
mm_counts
.
items
()
)
if
has_non_image
:
inputs
=
dummy_inputs
.
get_dummy_processor_inputs
(
model_config
.
max_model_len
,
mm_counts
,
)
text_prompt
=
None
token_prompt
=
(
inputs
.
prompt
if
isinstance
(
inputs
.
prompt
,
list
)
else
tokenizer
.
encode
(
inputs
.
prompt
,
add_special_tokens
=
False
)
)
else
:
images
=
parsed_data
.
get
(
"image"
,
[])
images
=
parsed_data
.
get
(
"image"
,
[])
request
=
ChatCompletionRequest
(
request
=
ChatCompletionRequest
(
messages
=
[
messages
=
[
...
@@ -197,7 +216,8 @@ def get_text_token_prompts(
...
@@ -197,7 +216,8 @@ def get_text_token_prompts(
)
)
res
=
tokenizer
.
mistral
.
encode_chat_completion
(
request
)
res
=
tokenizer
.
mistral
.
encode_chat_completion
(
request
)
# Mistral does not support decode_tokens with skip_special_tokens=False
# Mistral does not support decode_tokens with
# skip_special_tokens=False
text_prompt
=
None
text_prompt
=
None
token_prompt
=
res
.
tokens
token_prompt
=
res
.
tokens
else
:
else
:
...
...
vllm/model_executor/models/voxtral.py
View file @
de42abb3
...
@@ -291,6 +291,34 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo])
...
@@ -291,6 +291,34 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo])
# skip validation here
# skip validation here
...
...
def
_apply_hf_processor_mm_only
(
self
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
tokenization_kwargs
:
Mapping
[
str
,
object
],
)
->
BatchFeature
:
processor
=
self
.
info
.
get_hf_processor
(
**
hf_processor_mm_kwargs
)
processor_data
,
passthrough_data
=
self
.
_get_hf_mm_data
(
mm_items
)
audios
=
processor_data
.
get
(
"audios"
,
[])
if
not
isinstance
(
audios
,
list
):
audios
=
[
audios
]
audio_config
=
processor
.
_audio_processor
.
audio_config
audio_tensors
:
list
[
torch
.
Tensor
]
=
[]
for
audio
in
audios
:
audio
=
np
.
asarray
(
audio
,
dtype
=
np
.
float32
).
ravel
()
if
not
audio_config
.
is_streaming
:
audio
=
processor
.
_audio_processor
.
pad
(
audio
,
processor
.
sampling_rate
,
audio_config
.
is_streaming
,
)
audio_tensors
.
append
(
torch
.
tensor
(
audio
))
result
=
BatchFeature
({
"audio_arrays"
:
audio_tensors
}
if
audio_tensors
else
{})
result
.
update
(
passthrough_data
)
return
result
def
_get_prompt_updates
(
def
_get_prompt_updates
(
self
,
self
,
mm_items
:
MultiModalDataItems
,
mm_items
:
MultiModalDataItems
,
...
...
vllm/model_executor/models/whisper_causal.py
View file @
de42abb3
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
copy
import
copy
import
functools
import
functools
import
logging
import
math
import
math
from
dataclasses
import
replace
from
dataclasses
import
replace
from
functools
import
partial
from
functools
import
partial
...
@@ -30,11 +31,20 @@ from vllm.v1.attention.backend import (
...
@@ -30,11 +31,20 @@ from vllm.v1.attention.backend import (
subclass_attention_backend_with_overrides
,
subclass_attention_backend_with_overrides
,
)
)
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionBackend
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionBackend
try
:
from
vllm.v1.attention.backends.rocm_aiter_fa
import
AiterFlashAttentionBackend
except
ImportError
:
AiterFlashAttentionBackend
=
None
from
vllm.v1.attention.backends.rocm_attn
import
RocmAttentionBackend
from
vllm.v1.attention.backends.triton_attn
import
TritonAttentionBackend
from
vllm.v1.attention.selector
import
get_attn_backend
from
vllm.v1.attention.selector
import
get_attn_backend
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
.utils
import
make_layers
from
.utils
import
make_layers
logger
=
logging
.
getLogger
(
__name__
)
CausalRMSNorm
=
partial
(
RMSNorm
,
eps
=
1e-5
)
CausalRMSNorm
=
partial
(
RMSNorm
,
eps
=
1e-5
)
...
@@ -122,6 +132,13 @@ def create_whisper_attention_backend_with_block_pooling(
...
@@ -122,6 +132,13 @@ def create_whisper_attention_backend_with_block_pooling(
num_kv_heads
=
kv_cache_spec
.
num_kv_heads
//
block_pool_size
,
num_kv_heads
=
kv_cache_spec
.
num_kv_heads
//
block_pool_size
,
)
)
super
().
__init__
(
kv_cache_spec
,
layer_names
,
vllm_config
,
device
)
super
().
__init__
(
kv_cache_spec
,
layer_names
,
vllm_config
,
device
)
# Override model_config-derived values with the actual
# encoder values from kv_cache_spec
self
.
num_heads_kv
=
kv_cache_spec
.
num_kv_heads
self
.
headdim
=
kv_cache_spec
.
head_size
# num_heads_q for the encoder is the same as num_kv_heads
# (no GQA in whisper encoder)
self
.
num_heads_q
=
kv_cache_spec
.
num_kv_heads
def
build
(
def
build
(
self
,
self
,
...
@@ -192,13 +209,36 @@ def create_whisper_attention_backend_with_block_pooling(
...
@@ -192,13 +209,36 @@ def create_whisper_attention_backend_with_block_pooling(
output_block_scale
,
output_block_scale
,
)
)
if
not
issubclass
(
underlying_attn_backend
,
FlashAttentionBackend
):
_SUPPORTED_BACKENDS
=
tuple
(
b
for
b
in
(
AiterFlashAttentionBackend
,
FlashAttentionBackend
,
RocmAttentionBackend
,
TritonAttentionBackend
,
)
if
b
is
not
None
)
if
not
issubclass
(
underlying_attn_backend
,
_SUPPORTED_BACKENDS
):
raise
NotImplementedError
(
raise
NotImplementedError
(
f
"
{
underlying_attn_backend
}
is not yet supported."
f
"
{
underlying_attn_backend
}
is not yet supported."
"Contributions to support more backends are much "
"Contributions to support more backends are much "
"appreciated."
"appreciated."
)
)
if
not
issubclass
(
underlying_attn_backend
,
FlashAttentionBackend
):
logger
.
info
(
"Using %s for Whisper causal attention with block pooling. "
"This backend was recently enabled for this model. "
"If you encounter any accuracy or performance issues, "
"please open an issue at "
"https://github.com/vllm-project/vllm/issues "
"with the [ROCm] tag so it can be triaged by the "
"appropriate team."
,
underlying_attn_backend
.
get_name
(),
)
attn_backend
=
subclass_attention_backend_with_overrides
(
attn_backend
=
subclass_attention_backend_with_overrides
(
name_prefix
=
prefix
,
name_prefix
=
prefix
,
attention_backend_cls
=
underlying_attn_backend
,
attention_backend_cls
=
underlying_attn_backend
,
...
@@ -209,14 +249,14 @@ def create_whisper_attention_backend_with_block_pooling(
...
@@ -209,14 +249,14 @@ def create_whisper_attention_backend_with_block_pooling(
block_size
,
block_size
,
num_kv_heads
,
num_kv_heads
,
head_size
,
head_size
,
cache_dtype_str
:
(
cache_dtype_str
:
underlying_attn_backend
.
get_kv_cache_shape
(
2
,
num_blocks
,
num_blocks
,
# we stretch each block by `block_pool_size`
# we stretch each block by `block_pool_size`
block_size
*
block_pool_size
,
block_size
*
block_pool_size
,
num_kv_heads
//
block_pool_size
,
num_kv_heads
//
block_pool_size
,
head_size
,
head_size
,
),
# TODO: generalize to other backends
cache_dtype_str
,
),
"forward_includes_kv_cache_update"
:
True
,
"forward_includes_kv_cache_update"
:
True
,
},
},
)
)
...
...
vllm/reasoning/mistral_reasoning_parser.py
View file @
de42abb3
...
@@ -43,8 +43,8 @@ class MistralReasoningParser(BaseThinkingReasoningParser):
...
@@ -43,8 +43,8 @@ class MistralReasoningParser(BaseThinkingReasoningParser):
"constructor during construction."
"constructor during construction."
)
)
self
.
start_token_id
=
tokenizer
.
tokenizer
.
get_
contro
l_token
(
self
.
start_token
)
self
.
start_token_id
=
tokenizer
.
tokenizer
.
get_
specia
l_token
(
self
.
start_token
)
self
.
end_token_id
=
tokenizer
.
tokenizer
.
get_
contro
l_token
(
self
.
end_token
)
self
.
end_token_id
=
tokenizer
.
tokenizer
.
get_
specia
l_token
(
self
.
end_token
)
if
self
.
start_token_id
is
None
or
self
.
end_token_id
is
None
:
if
self
.
start_token_id
is
None
or
self
.
end_token_id
is
None
:
raise
RuntimeError
(
raise
RuntimeError
(
...
...
vllm/tokenizers/mistral.py
View file @
de42abb3
...
@@ -517,7 +517,7 @@ class MistralTokenizer(TokenizerLike):
...
@@ -517,7 +517,7 @@ class MistralTokenizer(TokenizerLike):
return
[
self
.
tokenizer
.
id_to_piece
(
token_id
)
for
token_id
in
ids
]
return
[
self
.
tokenizer
.
id_to_piece
(
token_id
)
for
token_id
in
ids
]
non_skip_special_tokens_ids
=
{
non_skip_special_tokens_ids
=
{
self
.
tokenizer
.
get_
contro
l_token
(
SpecialTokens
.
tool_calls
),
self
.
tokenizer
.
get_
specia
l_token
(
SpecialTokens
.
tool_calls
),
}
}
if
isinstance
(
self
.
instruct
,
InstructTokenizerV13
):
if
isinstance
(
self
.
instruct
,
InstructTokenizerV13
):
if
self
.
instruct
.
BEGIN_THINK
:
if
self
.
instruct
.
BEGIN_THINK
:
...
...
vllm/v1/attention/backends/rocm_aiter_fa.py
View file @
de42abb3
...
@@ -425,8 +425,13 @@ class AiterFlashAttentionMetadataBuilder(
...
@@ -425,8 +425,13 @@ class AiterFlashAttentionMetadataBuilder(
sliding_window_configs
:
set
[
tuple
[
int
,
int
]
|
None
]
=
set
()
sliding_window_configs
:
set
[
tuple
[
int
,
int
]
|
None
]
=
set
()
layers
=
get_layers_from_vllm_config
(
self
.
vllm_config
,
Attention
)
layers
=
get_layers_from_vllm_config
(
self
.
vllm_config
,
Attention
)
for
layer
in
layers
.
values
():
for
name
,
layer
in
layers
.
items
():
assert
isinstance
(
layer
.
impl
,
AiterFlashAttentionImpl
)
if
name
not
in
layer_names
:
continue
assert
isinstance
(
layer
.
impl
,
AiterFlashAttentionImpl
),
(
"Aiter Flash Attention Metadata Builder can only be used "
"with Aiter Flash Attention Impl."
)
sliding_window_configs
.
add
(
layer
.
impl
.
sliding_window
)
sliding_window_configs
.
add
(
layer
.
impl
.
sliding_window
)
while
len
(
sliding_window_configs
)
>
0
:
while
len
(
sliding_window_configs
)
>
0
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment