Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
de42abb3
Unverified
Commit
de42abb3
authored
Feb 13, 2026
by
Andreas Karatzas
Committed by
GitHub
Feb 13, 2026
Browse files
[CI] Heavy refactoring of Voxtral multimodal audio model tests (#34294)
Signed-off-by:
Andreas Karatzas
<
akaratza@amd.com
>
parent
60ca7981
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
350 additions
and
70 deletions
+350
-70
requirements/rocm-test.txt
requirements/rocm-test.txt
+2
-0
tests/conftest.py
tests/conftest.py
+0
-2
tests/models/multimodal/generation/test_voxtral.py
tests/models/multimodal/generation/test_voxtral.py
+138
-43
tests/models/multimodal/generation/test_voxtral_realtime.py
tests/models/multimodal/generation/test_voxtral_realtime.py
+6
-2
tests/models/multimodal/generation/vlm_utils/model_utils.py
tests/models/multimodal/generation/vlm_utils/model_utils.py
+88
-0
tests/models/multimodal/processing/test_common.py
tests/models/multimodal/processing/test_common.py
+34
-14
vllm/model_executor/models/voxtral.py
vllm/model_executor/models/voxtral.py
+28
-0
vllm/model_executor/models/whisper_causal.py
vllm/model_executor/models/whisper_causal.py
+44
-4
vllm/reasoning/mistral_reasoning_parser.py
vllm/reasoning/mistral_reasoning_parser.py
+2
-2
vllm/tokenizers/mistral.py
vllm/tokenizers/mistral.py
+1
-1
vllm/v1/attention/backends/rocm_aiter_fa.py
vllm/v1/attention/backends/rocm_aiter_fa.py
+7
-2
No files found.
requirements/rocm-test.txt
View file @
de42abb3
...
...
@@ -96,3 +96,5 @@ albumentations==1.4.6
transformers==4.57.3
# Pin HF Hub version
huggingface-hub==0.36.2
# Pin Mistral Common
mistral-common[image,audio]==1.9.1
tests/conftest.py
View file @
de42abb3
...
...
@@ -419,7 +419,6 @@ class HfRunner:
self
.
tokenizer
:
"PreTrainedTokenizer | PreTrainedTokenizerFast"
=
(
AutoTokenizer
.
from_pretrained
(
model_name
,
dtype
=
dtype
,
trust_remote_code
=
trust_remote_code
,
)
)
...
...
@@ -430,7 +429,6 @@ class HfRunner:
self
.
processor
=
AutoProcessor
.
from_pretrained
(
model_name
,
dtype
=
dtype
,
trust_remote_code
=
trust_remote_code
,
)
if
skip_tokenizer_init
:
...
...
tests/models/multimodal/generation/test_voxtral.py
View file @
de42abb3
...
...
@@ -4,16 +4,18 @@
import
json
import
pytest
import
pytest_asyncio
from
mistral_common.audio
import
Audio
from
mistral_common.protocol.instruct.chunk
import
AudioChunk
,
RawAudio
,
TextChunk
from
mistral_common.protocol.instruct.messages
import
UserMessage
from
transformers
import
VoxtralForConditionalGeneration
from
vllm.tokenizers.mistral
import
MistralTokenizer
from
....conftest
import
AudioTestAssets
from
....utils
import
RemoteOpenAIServer
from
...utils
import
check_logprobs_close
from
.test_ultravox
import
MULTI_AUDIO_PROMPT
,
run_multi_audio_test
from
.vlm_utils
import
model_utils
MODEL_NAME
=
"mistralai/Voxtral-Mini-3B-2507"
MISTRAL_FORMAT_ARGS
=
[
...
...
@@ -26,40 +28,21 @@ MISTRAL_FORMAT_ARGS = [
]
@
pytest
.
fixture
()
def
server
(
request
,
audio_assets
:
AudioTestAssets
):
args
=
[
"--enforce-eager"
,
"--limit-mm-per-prompt"
,
json
.
dumps
({
"audio"
:
len
(
audio_assets
)}),
]
+
MISTRAL_FORMAT_ARGS
with
RemoteOpenAIServer
(
MODEL_NAME
,
args
,
env_dict
=
{
"VLLM_AUDIO_FETCH_TIMEOUT"
:
"30"
}
)
as
remote_server
:
yield
remote_server
@
pytest_asyncio
.
fixture
async
def
client
(
server
):
async
with
server
.
get_async_client
()
as
async_client
:
yield
async_client
def
_get_prompt
(
audio_assets
,
question
):
def
_get_prompt
(
audio_assets
:
AudioTestAssets
,
question
:
str
)
->
list
[
int
]:
"""Build a token-ID prompt via mistral_common for vLLM offline inference."""
tokenizer
=
MistralTokenizer
.
from_pretrained
(
MODEL_NAME
)
audios
=
[
Audio
.
from_file
(
str
(
a
udio_assets
[
i
]
.
get_local_path
()),
strict
=
False
)
for
i
in
range
(
len
(
audio_assets
))
Audio
.
from_file
(
str
(
a
sset
.
get_local_path
()),
strict
=
False
)
for
asset
in
audio_assets
]
audio_chunks
=
[
AudioChunk
(
input_audio
=
RawAudio
.
from_audio
(
audio
))
for
audio
in
audios
]
text_chunk
=
TextChunk
(
text
=
question
)
messages
=
[
UserMessage
(
content
=
[
*
audio_chunks
,
t
ext
_c
hunk
]).
to_openai
()
]
messages
=
[
UserMessage
(
content
=
[
*
audio_chunks
,
T
ext
C
hunk
(
text
=
question
)
]).
to_openai
()
]
return
tokenizer
.
apply_chat_template
(
messages
=
messages
)
...
...
@@ -77,7 +60,7 @@ def test_models_with_multiple_audios(
vllm_prompt
=
_get_prompt
(
audio_assets
,
MULTI_AUDIO_PROMPT
)
run_multi_audio_test
(
vllm_runner
,
[(
vllm_prompt
,
[
a
udio
.
audio_and_sample_rate
for
a
udio
in
audio_assets
])],
[(
vllm_prompt
,
[
a
.
audio_and_sample_rate
for
a
in
audio_assets
])],
# type: ignore[list-item]
MODEL_NAME
,
dtype
=
dtype
,
max_tokens
=
max_tokens
,
...
...
@@ -86,30 +69,142 @@ def test_models_with_multiple_audios(
)
@
pytest
.
mark
.
asyncio
async
def
test_online_serving
(
client
,
audio_assets
:
AudioTestAssets
):
"""Exercises online serving with/without chunked prefill enabled."""
def
test_online_serving
(
vllm_runner
,
audio_assets
:
AudioTestAssets
):
"""Two-layer accuracy and serving validation using Mistral format.
1. Offline vLLM greedy output (runs first to avoid CUDA fork issues
with multiprocessing - see vlm_utils/core.py).
2. Online OpenAI-compatible API output must match offline — validates
that the serving path (chat template, audio encoding, tokenization)
does not corrupt anything.
Steps run sequentially so each releases the GPU before the next starts.
"""
def
asset_to_chunk
(
asset
):
question
=
f
"What's happening in these
{
len
(
audio_assets
)
}
audio clips?"
max_tokens
=
10
audio_data
=
[
asset
.
audio_and_sample_rate
for
asset
in
audio_assets
]
vllm_prompt
=
_get_prompt
(
audio_assets
,
question
)
with
vllm_runner
(
MODEL_NAME
,
dtype
=
"half"
,
enforce_eager
=
True
,
tokenizer_mode
=
"mistral"
,
config_format
=
"mistral"
,
load_format
=
"mistral"
,
limit_mm_per_prompt
=
{
"audio"
:
len
(
audio_assets
)},
)
as
vllm_model
:
offline_outputs
=
vllm_model
.
generate_greedy
(
[
vllm_prompt
],
max_tokens
,
audios
=
[
audio_data
],
)
offline_text
=
offline_outputs
[
0
][
1
]
assert
offline_text
,
"Offline vLLM inference produced empty output"
def
_asset_to_openai_chunk
(
asset
):
audio
=
Audio
.
from_file
(
str
(
asset
.
get_local_path
()),
strict
=
False
)
audio
.
format
=
"wav"
audio_dict
=
AudioChunk
.
from_audio
(
audio
).
to_openai
()
return
audio_dict
return
AudioChunk
.
from_audio
(
audio
).
to_openai
()
audio_chunks
=
[
asset_to_chunk
(
asset
)
for
asset
in
audio_assets
]
text
=
f
"What's happening in these
{
len
(
audio_assets
)
}
audio clips?"
messages
=
[
{
"role"
:
"user"
,
"content"
:
[
*
audio_chunks
,
{
"type"
:
"text"
,
"text"
:
text
}],
"content"
:
[
*
[
_asset_to_openai_chunk
(
a
)
for
a
in
audio_assets
],
{
"type"
:
"text"
,
"text"
:
question
},
],
}
]
chat_completion
=
await
client
.
chat
.
completions
.
create
(
model
=
MODEL_NAME
,
messages
=
messages
,
max_tokens
=
10
server_args
=
[
"--enforce-eager"
,
"--limit-mm-per-prompt"
,
json
.
dumps
({
"audio"
:
len
(
audio_assets
)}),
*
MISTRAL_FORMAT_ARGS
,
]
with
RemoteOpenAIServer
(
MODEL_NAME
,
server_args
,
env_dict
=
{
"VLLM_AUDIO_FETCH_TIMEOUT"
:
"30"
},
)
as
remote_server
:
client
=
remote_server
.
get_client
()
completion
=
client
.
chat
.
completions
.
create
(
model
=
MODEL_NAME
,
messages
=
messages
,
max_tokens
=
max_tokens
,
temperature
=
0
,
)
assert
len
(
chat_completion
.
choices
)
==
1
choice
=
chat_completion
.
choices
[
0
]
assert
choice
.
message
.
content
==
"In the first audio clip, you hear a brief"
assert
len
(
completion
.
choices
)
==
1
choice
=
completion
.
choices
[
0
]
assert
choice
.
finish_reason
==
"length"
assert
choice
.
message
.
content
==
offline_text
,
(
f
"Online serving output does not match offline inference.
\n
"
f
" Online:
{
choice
.
message
.
content
!
r
}
\n
"
f
" Offline:
{
offline_text
!
r
}
"
)
def
test_hf_reference
(
hf_runner
,
vllm_runner
,
audio_assets
:
AudioTestAssets
):
"""Compare vLLM Mistral-format output against HF Transformers reference.
Instead of requiring an exact text match (which is brittle across
attention backends), we compare per-token logprobs using the standard
check_logprobs_close helper: when tokens diverge at a position, each
runner's chosen token must appear in the other's top-k logprobs.
Marked xfail(strict=False) so remaining edge-case mismatches
don't block CI.
"""
question
=
f
"What's happening in these
{
len
(
audio_assets
)
}
audio clips?"
max_tokens
=
10
num_logprobs
=
5
audio_data
=
[
asset
.
audio_and_sample_rate
for
asset
in
audio_assets
]
vllm_prompt
=
_get_prompt
(
audio_assets
,
question
)
with
vllm_runner
(
MODEL_NAME
,
dtype
=
"half"
,
enforce_eager
=
True
,
tokenizer_mode
=
"mistral"
,
config_format
=
"mistral"
,
load_format
=
"mistral"
,
limit_mm_per_prompt
=
{
"audio"
:
len
(
audio_assets
)},
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
generate_greedy_logprobs
(
[
vllm_prompt
],
max_tokens
,
num_logprobs
,
audios
=
[
audio_data
],
)
assert
vllm_outputs
[
0
][
1
],
"vLLM inference produced empty output"
with
hf_runner
(
MODEL_NAME
,
dtype
=
"half"
,
auto_cls
=
VoxtralForConditionalGeneration
,
)
as
hf_model
:
hf_model
=
model_utils
.
voxtral_patch_hf_runner
(
hf_model
)
hf_outputs
=
hf_model
.
generate_greedy_logprobs_limit
(
[
question
],
max_tokens
,
num_logprobs
,
audios
=
[
audio_data
],
)
assert
hf_outputs
[
0
][
1
],
"HF Transformers produced empty output"
print
(
f
"HF Reference Comparison
\n
"
f
" vLLM:
{
vllm_outputs
[
0
][
1
]
!
r
}
\n
"
f
" HF:
{
hf_outputs
[
0
][
1
]
!
r
}
"
)
check_logprobs_close
(
outputs_0_lst
=
vllm_outputs
,
outputs_1_lst
=
hf_outputs
,
name_0
=
"vllm"
,
name_1
=
"hf"
,
)
tests/models/multimodal/generation/test_voxtral_realtime.py
View file @
de42abb3
...
...
@@ -10,6 +10,7 @@ from mistral_common.protocol.transcription.request import (
TranscriptionRequest
,
)
from
mistral_common.tokens.tokenizers.mistral
import
MistralTokenizer
from
mistral_common.tokens.tokenizers.tekken
import
SpecialTokenPolicy
from
vllm
import
LLM
,
EngineArgs
,
SamplingParams
from
vllm.assets.audio
import
AudioAsset
...
...
@@ -26,7 +27,7 @@ ENGINE_CONFIG = dict(
load_format
=
"mistral"
,
tokenizer_mode
=
"mistral"
,
enforce_eager
=
True
,
gpu_memory_utilization
=
0.
4
,
gpu_memory_utilization
=
0.
9
,
)
...
...
@@ -148,6 +149,9 @@ async def test_voxtral_realtime_generator(audio_assets, tokenizer, async_engine)
output_tokens_list
.
append
(
output_tokens
)
texts
=
[
tokenizer
.
decode
(
output_tokens
)
for
output_tokens
in
output_tokens_list
]
texts
=
[
tokenizer
.
decode
(
output_tokens
,
special_token_policy
=
SpecialTokenPolicy
.
IGNORE
)
for
output_tokens
in
output_tokens_list
]
texts
[
1
]
=
texts
[
1
].
replace
(
"a base hit"
,
"OBS"
).
replace
(
"oh my"
,
"oh, my"
)
assert
texts
==
EXPECTED_TEXT
tests/models/multimodal/generation/vlm_utils/model_utils.py
View file @
de42abb3
...
...
@@ -1215,3 +1215,91 @@ def tarsier_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
hf_processor
.
patch_size
=
vision_encoder_info
.
get_patch_size
()
return
hf_model
def
voxtral_patch_hf_runner
(
hf_model
:
"HfRunner"
)
->
"HfRunner"
:
"""Patch HfRunner for Voxtral's conversation-based processor.
Two issues in HfRunner require patching:
1. VoxtralProcessor requires ``apply_chat_template()`` with conversation
dicts (accepting ``url``, ``path``, or ``base64`` audio) rather than
the standard ``processor(text=, audio=, sampling_rate=)`` interface.
2. HfRunner.get_inputs cannot handle multi-audio per prompt because it
mis-unpacks ``[(arr1, sr1), (arr2, sr2)]`` via a ``len == 2`` check.
We override ``get_inputs`` to build conversation dicts and call
``apply_chat_template`` directly, bypassing both issues. We also wrap
``model.generate`` to strip prompt tokens before decoding, since
HfRunner.generate calls batch_decode on the full sequence (prompt +
generated).
"""
import
base64
import
io
import
soundfile
as
sf
processor
=
hf_model
.
processor
def
_audio_to_base64
(
audio_array
,
sample_rate
:
int
)
->
str
:
"""Encode a numpy audio array as a base64 WAV string."""
buf
=
io
.
BytesIO
()
sf
.
write
(
buf
,
audio_array
,
int
(
sample_rate
),
format
=
"WAV"
)
return
base64
.
b64encode
(
buf
.
getvalue
()).
decode
(
"ascii"
)
def
patched_get_inputs
(
prompts
,
images
=
None
,
videos
=
None
,
audios
=
None
,
**
kwargs
):
all_inputs
=
[]
for
i
,
prompt
in
enumerate
(
prompts
):
content
:
list
[
dict
]
=
[]
if
audios
is
not
None
and
audios
[
i
]
is
not
None
:
items
=
audios
[
i
]
if
not
isinstance
(
items
,
list
):
items
=
[
items
]
for
item
in
items
:
if
isinstance
(
item
,
(
list
,
tuple
))
and
len
(
item
)
==
2
:
arr
,
sr
=
item
else
:
arr
,
sr
=
item
,
16_000
content
.
append
(
{
"type"
:
"audio"
,
"base64"
:
_audio_to_base64
(
arr
,
sr
),
}
)
content
.
append
({
"type"
:
"text"
,
"text"
:
prompt
})
inputs
=
processor
.
apply_chat_template
(
[{
"role"
:
"user"
,
"content"
:
content
}]
)
if
hasattr
(
inputs
,
"to"
):
inputs
=
inputs
.
to
(
dtype
=
hf_model
.
dtype
)
all_inputs
.
append
(
inputs
)
return
all_inputs
_orig_generate
=
hf_model
.
model
.
generate
def
patched_generate
(
*
args
,
**
kwargs
):
"""Strip prompt tokens so only generated tokens are decoded."""
input_ids
=
kwargs
.
get
(
"input_ids"
)
if
input_ids
is
None
and
args
:
input_ids
=
args
[
0
]
prompt_len
=
input_ids
.
shape
[
1
]
if
input_ids
is
not
None
else
0
output
=
_orig_generate
(
*
args
,
**
kwargs
)
if
prompt_len
:
if
isinstance
(
output
,
torch
.
Tensor
):
output
=
output
[:,
prompt_len
:]
else
:
# GenerateDecoderOnlyOutput - trim sequences but preserve
# scores/logits so generate_greedy_logprobs_limit can
# extract per-token logprobs.
output
.
sequences
=
output
.
sequences
[:,
prompt_len
:]
return
output
hf_model
.
get_inputs
=
patched_get_inputs
# type: ignore[method-assign, assignment]
hf_model
.
model
.
generate
=
patched_generate
# type: ignore[method-assign]
return
hf_model
tests/models/multimodal/processing/test_common.py
View file @
de42abb3
...
...
@@ -184,6 +184,25 @@ def get_text_token_prompts(
text_prompt
:
str
|
None
token_prompt
:
list
[
int
]
if
isinstance
(
tokenizer
,
MistralTokenizer
):
# ChatCompletionRequest only supports ImageChunk natively;
# for other modalities (e.g. audio), fall back to the model's
# own dummy inputs builder which knows the right placeholders.
has_non_image
=
any
(
k
!=
"image"
and
count
>
0
for
k
,
count
in
mm_counts
.
items
()
)
if
has_non_image
:
inputs
=
dummy_inputs
.
get_dummy_processor_inputs
(
model_config
.
max_model_len
,
mm_counts
,
)
text_prompt
=
None
token_prompt
=
(
inputs
.
prompt
if
isinstance
(
inputs
.
prompt
,
list
)
else
tokenizer
.
encode
(
inputs
.
prompt
,
add_special_tokens
=
False
)
)
else
:
images
=
parsed_data
.
get
(
"image"
,
[])
request
=
ChatCompletionRequest
(
messages
=
[
...
...
@@ -197,7 +216,8 @@ def get_text_token_prompts(
)
res
=
tokenizer
.
mistral
.
encode_chat_completion
(
request
)
# Mistral does not support decode_tokens with skip_special_tokens=False
# Mistral does not support decode_tokens with
# skip_special_tokens=False
text_prompt
=
None
token_prompt
=
res
.
tokens
else
:
...
...
vllm/model_executor/models/voxtral.py
View file @
de42abb3
...
...
@@ -291,6 +291,34 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo])
# skip validation here
...
def
_apply_hf_processor_mm_only
(
self
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
tokenization_kwargs
:
Mapping
[
str
,
object
],
)
->
BatchFeature
:
processor
=
self
.
info
.
get_hf_processor
(
**
hf_processor_mm_kwargs
)
processor_data
,
passthrough_data
=
self
.
_get_hf_mm_data
(
mm_items
)
audios
=
processor_data
.
get
(
"audios"
,
[])
if
not
isinstance
(
audios
,
list
):
audios
=
[
audios
]
audio_config
=
processor
.
_audio_processor
.
audio_config
audio_tensors
:
list
[
torch
.
Tensor
]
=
[]
for
audio
in
audios
:
audio
=
np
.
asarray
(
audio
,
dtype
=
np
.
float32
).
ravel
()
if
not
audio_config
.
is_streaming
:
audio
=
processor
.
_audio_processor
.
pad
(
audio
,
processor
.
sampling_rate
,
audio_config
.
is_streaming
,
)
audio_tensors
.
append
(
torch
.
tensor
(
audio
))
result
=
BatchFeature
({
"audio_arrays"
:
audio_tensors
}
if
audio_tensors
else
{})
result
.
update
(
passthrough_data
)
return
result
def
_get_prompt_updates
(
self
,
mm_items
:
MultiModalDataItems
,
...
...
vllm/model_executor/models/whisper_causal.py
View file @
de42abb3
...
...
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
copy
import
functools
import
logging
import
math
from
dataclasses
import
replace
from
functools
import
partial
...
...
@@ -30,11 +31,20 @@ from vllm.v1.attention.backend import (
subclass_attention_backend_with_overrides
,
)
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionBackend
try
:
from
vllm.v1.attention.backends.rocm_aiter_fa
import
AiterFlashAttentionBackend
except
ImportError
:
AiterFlashAttentionBackend
=
None
from
vllm.v1.attention.backends.rocm_attn
import
RocmAttentionBackend
from
vllm.v1.attention.backends.triton_attn
import
TritonAttentionBackend
from
vllm.v1.attention.selector
import
get_attn_backend
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
.utils
import
make_layers
logger
=
logging
.
getLogger
(
__name__
)
CausalRMSNorm
=
partial
(
RMSNorm
,
eps
=
1e-5
)
...
...
@@ -122,6 +132,13 @@ def create_whisper_attention_backend_with_block_pooling(
num_kv_heads
=
kv_cache_spec
.
num_kv_heads
//
block_pool_size
,
)
super
().
__init__
(
kv_cache_spec
,
layer_names
,
vllm_config
,
device
)
# Override model_config-derived values with the actual
# encoder values from kv_cache_spec
self
.
num_heads_kv
=
kv_cache_spec
.
num_kv_heads
self
.
headdim
=
kv_cache_spec
.
head_size
# num_heads_q for the encoder is the same as num_kv_heads
# (no GQA in whisper encoder)
self
.
num_heads_q
=
kv_cache_spec
.
num_kv_heads
def
build
(
self
,
...
...
@@ -192,13 +209,36 @@ def create_whisper_attention_backend_with_block_pooling(
output_block_scale
,
)
if
not
issubclass
(
underlying_attn_backend
,
FlashAttentionBackend
):
_SUPPORTED_BACKENDS
=
tuple
(
b
for
b
in
(
AiterFlashAttentionBackend
,
FlashAttentionBackend
,
RocmAttentionBackend
,
TritonAttentionBackend
,
)
if
b
is
not
None
)
if
not
issubclass
(
underlying_attn_backend
,
_SUPPORTED_BACKENDS
):
raise
NotImplementedError
(
f
"
{
underlying_attn_backend
}
is not yet supported."
"Contributions to support more backends are much "
"appreciated."
)
if
not
issubclass
(
underlying_attn_backend
,
FlashAttentionBackend
):
logger
.
info
(
"Using %s for Whisper causal attention with block pooling. "
"This backend was recently enabled for this model. "
"If you encounter any accuracy or performance issues, "
"please open an issue at "
"https://github.com/vllm-project/vllm/issues "
"with the [ROCm] tag so it can be triaged by the "
"appropriate team."
,
underlying_attn_backend
.
get_name
(),
)
attn_backend
=
subclass_attention_backend_with_overrides
(
name_prefix
=
prefix
,
attention_backend_cls
=
underlying_attn_backend
,
...
...
@@ -209,14 +249,14 @@ def create_whisper_attention_backend_with_block_pooling(
block_size
,
num_kv_heads
,
head_size
,
cache_dtype_str
:
(
2
,
cache_dtype_str
:
underlying_attn_backend
.
get_kv_cache_shape
(
num_blocks
,
# we stretch each block by `block_pool_size`
block_size
*
block_pool_size
,
num_kv_heads
//
block_pool_size
,
head_size
,
),
# TODO: generalize to other backends
cache_dtype_str
,
),
"forward_includes_kv_cache_update"
:
True
,
},
)
...
...
vllm/reasoning/mistral_reasoning_parser.py
View file @
de42abb3
...
...
@@ -43,8 +43,8 @@ class MistralReasoningParser(BaseThinkingReasoningParser):
"constructor during construction."
)
self
.
start_token_id
=
tokenizer
.
tokenizer
.
get_
contro
l_token
(
self
.
start_token
)
self
.
end_token_id
=
tokenizer
.
tokenizer
.
get_
contro
l_token
(
self
.
end_token
)
self
.
start_token_id
=
tokenizer
.
tokenizer
.
get_
specia
l_token
(
self
.
start_token
)
self
.
end_token_id
=
tokenizer
.
tokenizer
.
get_
specia
l_token
(
self
.
end_token
)
if
self
.
start_token_id
is
None
or
self
.
end_token_id
is
None
:
raise
RuntimeError
(
...
...
vllm/tokenizers/mistral.py
View file @
de42abb3
...
...
@@ -517,7 +517,7 @@ class MistralTokenizer(TokenizerLike):
return
[
self
.
tokenizer
.
id_to_piece
(
token_id
)
for
token_id
in
ids
]
non_skip_special_tokens_ids
=
{
self
.
tokenizer
.
get_
contro
l_token
(
SpecialTokens
.
tool_calls
),
self
.
tokenizer
.
get_
specia
l_token
(
SpecialTokens
.
tool_calls
),
}
if
isinstance
(
self
.
instruct
,
InstructTokenizerV13
):
if
self
.
instruct
.
BEGIN_THINK
:
...
...
vllm/v1/attention/backends/rocm_aiter_fa.py
View file @
de42abb3
...
...
@@ -425,8 +425,13 @@ class AiterFlashAttentionMetadataBuilder(
sliding_window_configs
:
set
[
tuple
[
int
,
int
]
|
None
]
=
set
()
layers
=
get_layers_from_vllm_config
(
self
.
vllm_config
,
Attention
)
for
layer
in
layers
.
values
():
assert
isinstance
(
layer
.
impl
,
AiterFlashAttentionImpl
)
for
name
,
layer
in
layers
.
items
():
if
name
not
in
layer_names
:
continue
assert
isinstance
(
layer
.
impl
,
AiterFlashAttentionImpl
),
(
"Aiter Flash Attention Metadata Builder can only be used "
"with Aiter Flash Attention Impl."
)
sliding_window_configs
.
add
(
layer
.
impl
.
sliding_window
)
while
len
(
sliding_window_configs
)
>
0
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment