Unverified commit de42abb3, authored by Andreas Karatzas, committed by GitHub
Browse files

[CI] Heavy refactoring of Voxtral multimodal audio model tests (#34294)


Signed-off-by: Andreas Karatzas <akaratza@amd.com>
parent 60ca7981
...@@ -96,3 +96,5 @@ albumentations==1.4.6 ...@@ -96,3 +96,5 @@ albumentations==1.4.6
transformers==4.57.3 transformers==4.57.3
# Pin HF Hub version # Pin HF Hub version
huggingface-hub==0.36.2 huggingface-hub==0.36.2
# Pin Mistral Common
mistral-common[image,audio]==1.9.1
...@@ -419,7 +419,6 @@ class HfRunner: ...@@ -419,7 +419,6 @@ class HfRunner:
self.tokenizer: "PreTrainedTokenizer | PreTrainedTokenizerFast" = ( self.tokenizer: "PreTrainedTokenizer | PreTrainedTokenizerFast" = (
AutoTokenizer.from_pretrained( AutoTokenizer.from_pretrained(
model_name, model_name,
dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
) )
...@@ -430,7 +429,6 @@ class HfRunner: ...@@ -430,7 +429,6 @@ class HfRunner:
self.processor = AutoProcessor.from_pretrained( self.processor = AutoProcessor.from_pretrained(
model_name, model_name,
dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
if skip_tokenizer_init: if skip_tokenizer_init:
......
...@@ -4,16 +4,18 @@ ...@@ -4,16 +4,18 @@
import json import json
import pytest import pytest
import pytest_asyncio
from mistral_common.audio import Audio from mistral_common.audio import Audio
from mistral_common.protocol.instruct.chunk import AudioChunk, RawAudio, TextChunk from mistral_common.protocol.instruct.chunk import AudioChunk, RawAudio, TextChunk
from mistral_common.protocol.instruct.messages import UserMessage from mistral_common.protocol.instruct.messages import UserMessage
from transformers import VoxtralForConditionalGeneration
from vllm.tokenizers.mistral import MistralTokenizer from vllm.tokenizers.mistral import MistralTokenizer
from ....conftest import AudioTestAssets from ....conftest import AudioTestAssets
from ....utils import RemoteOpenAIServer from ....utils import RemoteOpenAIServer
from ...utils import check_logprobs_close
from .test_ultravox import MULTI_AUDIO_PROMPT, run_multi_audio_test from .test_ultravox import MULTI_AUDIO_PROMPT, run_multi_audio_test
from .vlm_utils import model_utils
MODEL_NAME = "mistralai/Voxtral-Mini-3B-2507" MODEL_NAME = "mistralai/Voxtral-Mini-3B-2507"
MISTRAL_FORMAT_ARGS = [ MISTRAL_FORMAT_ARGS = [
...@@ -26,40 +28,21 @@ MISTRAL_FORMAT_ARGS = [ ...@@ -26,40 +28,21 @@ MISTRAL_FORMAT_ARGS = [
] ]
@pytest.fixture() def _get_prompt(audio_assets: AudioTestAssets, question: str) -> list[int]:
def server(request, audio_assets: AudioTestAssets): """Build a token-ID prompt via mistral_common for vLLM offline inference."""
args = [
"--enforce-eager",
"--limit-mm-per-prompt",
json.dumps({"audio": len(audio_assets)}),
] + MISTRAL_FORMAT_ARGS
with RemoteOpenAIServer(
MODEL_NAME, args, env_dict={"VLLM_AUDIO_FETCH_TIMEOUT": "30"}
) as remote_server:
yield remote_server
@pytest_asyncio.fixture
async def client(server):
async with server.get_async_client() as async_client:
yield async_client
def _get_prompt(audio_assets, question):
tokenizer = MistralTokenizer.from_pretrained(MODEL_NAME) tokenizer = MistralTokenizer.from_pretrained(MODEL_NAME)
audios = [ audios = [
Audio.from_file(str(audio_assets[i].get_local_path()), strict=False) Audio.from_file(str(asset.get_local_path()), strict=False)
for i in range(len(audio_assets)) for asset in audio_assets
] ]
audio_chunks = [ audio_chunks = [
AudioChunk(input_audio=RawAudio.from_audio(audio)) for audio in audios AudioChunk(input_audio=RawAudio.from_audio(audio)) for audio in audios
] ]
text_chunk = TextChunk(text=question) messages = [
messages = [UserMessage(content=[*audio_chunks, text_chunk]).to_openai()] UserMessage(content=[*audio_chunks, TextChunk(text=question)]).to_openai()
]
return tokenizer.apply_chat_template(messages=messages) return tokenizer.apply_chat_template(messages=messages)
...@@ -77,7 +60,7 @@ def test_models_with_multiple_audios( ...@@ -77,7 +60,7 @@ def test_models_with_multiple_audios(
vllm_prompt = _get_prompt(audio_assets, MULTI_AUDIO_PROMPT) vllm_prompt = _get_prompt(audio_assets, MULTI_AUDIO_PROMPT)
run_multi_audio_test( run_multi_audio_test(
vllm_runner, vllm_runner,
[(vllm_prompt, [audio.audio_and_sample_rate for audio in audio_assets])], [(vllm_prompt, [a.audio_and_sample_rate for a in audio_assets])], # type: ignore[list-item]
MODEL_NAME, MODEL_NAME,
dtype=dtype, dtype=dtype,
max_tokens=max_tokens, max_tokens=max_tokens,
...@@ -86,30 +69,142 @@ def test_models_with_multiple_audios( ...@@ -86,30 +69,142 @@ def test_models_with_multiple_audios(
) )
@pytest.mark.asyncio def test_online_serving(vllm_runner, audio_assets: AudioTestAssets):
async def test_online_serving(client, audio_assets: AudioTestAssets): """Two-layer accuracy and serving validation using Mistral format.
"""Exercises online serving with/without chunked prefill enabled."""
1. Offline vLLM greedy output (runs first to avoid CUDA fork issues
with multiprocessing - see vlm_utils/core.py).
2. Online OpenAI-compatible API output must match offline — validates
that the serving path (chat template, audio encoding, tokenization)
does not corrupt anything.
Steps run sequentially so each releases the GPU before the next starts.
"""
def asset_to_chunk(asset): question = f"What's happening in these {len(audio_assets)} audio clips?"
max_tokens = 10
audio_data = [asset.audio_and_sample_rate for asset in audio_assets]
vllm_prompt = _get_prompt(audio_assets, question)
with vllm_runner(
MODEL_NAME,
dtype="half",
enforce_eager=True,
tokenizer_mode="mistral",
config_format="mistral",
load_format="mistral",
limit_mm_per_prompt={"audio": len(audio_assets)},
) as vllm_model:
offline_outputs = vllm_model.generate_greedy(
[vllm_prompt],
max_tokens,
audios=[audio_data],
)
offline_text = offline_outputs[0][1]
assert offline_text, "Offline vLLM inference produced empty output"
def _asset_to_openai_chunk(asset):
audio = Audio.from_file(str(asset.get_local_path()), strict=False) audio = Audio.from_file(str(asset.get_local_path()), strict=False)
audio.format = "wav" audio.format = "wav"
audio_dict = AudioChunk.from_audio(audio).to_openai() return AudioChunk.from_audio(audio).to_openai()
return audio_dict
audio_chunks = [asset_to_chunk(asset) for asset in audio_assets]
text = f"What's happening in these {len(audio_assets)} audio clips?"
messages = [ messages = [
{ {
"role": "user", "role": "user",
"content": [*audio_chunks, {"type": "text", "text": text}], "content": [
*[_asset_to_openai_chunk(a) for a in audio_assets],
{"type": "text", "text": question},
],
} }
] ]
chat_completion = await client.chat.completions.create( server_args = [
model=MODEL_NAME, messages=messages, max_tokens=10 "--enforce-eager",
) "--limit-mm-per-prompt",
json.dumps({"audio": len(audio_assets)}),
*MISTRAL_FORMAT_ARGS,
]
assert len(chat_completion.choices) == 1 with RemoteOpenAIServer(
choice = chat_completion.choices[0] MODEL_NAME,
assert choice.message.content == "In the first audio clip, you hear a brief" server_args,
env_dict={"VLLM_AUDIO_FETCH_TIMEOUT": "30"},
) as remote_server:
client = remote_server.get_client()
completion = client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=max_tokens,
temperature=0,
)
assert len(completion.choices) == 1
choice = completion.choices[0]
assert choice.finish_reason == "length" assert choice.finish_reason == "length"
assert choice.message.content == offline_text, (
f"Online serving output does not match offline inference.\n"
f" Online: {choice.message.content!r}\n"
f" Offline: {offline_text!r}"
)
def test_hf_reference(hf_runner, vllm_runner, audio_assets: AudioTestAssets):
    """Compare vLLM Mistral-format output against HF Transformers reference.

    Instead of requiring an exact text match (which is brittle across
    attention backends), we compare per-token logprobs using the standard
    check_logprobs_close helper: when tokens diverge at a position, each
    runner's chosen token must appear in the other's top-k logprobs.

    The two runners are used one after the other: the vLLM context is
    exited before the HF model is created.

    NOTE(review): this docstring previously stated that the test is
    "Marked xfail(strict=False)" so that edge-case mismatches don't block
    CI. No such marker is applied to this function, so a failure in
    check_logprobs_close fails the test. Add
    ``@pytest.mark.xfail(strict=False)`` if non-blocking behaviour is
    what is actually intended.
    """
    question = f"What's happening in these {len(audio_assets)} audio clips?"
    max_tokens = 10
    num_logprobs = 5
    audio_data = [asset.audio_and_sample_rate for asset in audio_assets]

    # vLLM side: token-ID prompt built with mistral_common, Mistral formats.
    vllm_prompt = _get_prompt(audio_assets, question)
    with vllm_runner(
        MODEL_NAME,
        dtype="half",
        enforce_eager=True,
        tokenizer_mode="mistral",
        config_format="mistral",
        load_format="mistral",
        limit_mm_per_prompt={"audio": len(audio_assets)},
    ) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy_logprobs(
            [vllm_prompt],
            max_tokens,
            num_logprobs,
            audios=[audio_data],
        )

    assert vllm_outputs[0][1], "vLLM inference produced empty output"

    # HF side: the runner is patched so that the plain-text question plus
    # the audio clips go through the Voxtral chat-template path.
    with hf_runner(
        MODEL_NAME,
        dtype="half",
        auto_cls=VoxtralForConditionalGeneration,
    ) as hf_model:
        hf_model = model_utils.voxtral_patch_hf_runner(hf_model)
        hf_outputs = hf_model.generate_greedy_logprobs_limit(
            [question],
            max_tokens,
            num_logprobs,
            audios=[audio_data],
        )

    assert hf_outputs[0][1], "HF Transformers produced empty output"

    # Printed so both generations are visible in the test log when the
    # comparison below fails.
    print(
        f"HF Reference Comparison\n"
        f"  vLLM: {vllm_outputs[0][1]!r}\n"
        f"  HF:   {hf_outputs[0][1]!r}"
    )

    check_logprobs_close(
        outputs_0_lst=vllm_outputs,
        outputs_1_lst=hf_outputs,
        name_0="vllm",
        name_1="hf",
    )
...@@ -10,6 +10,7 @@ from mistral_common.protocol.transcription.request import ( ...@@ -10,6 +10,7 @@ from mistral_common.protocol.transcription.request import (
TranscriptionRequest, TranscriptionRequest,
) )
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.tokens.tokenizers.tekken import SpecialTokenPolicy
from vllm import LLM, EngineArgs, SamplingParams from vllm import LLM, EngineArgs, SamplingParams
from vllm.assets.audio import AudioAsset from vllm.assets.audio import AudioAsset
...@@ -26,7 +27,7 @@ ENGINE_CONFIG = dict( ...@@ -26,7 +27,7 @@ ENGINE_CONFIG = dict(
load_format="mistral", load_format="mistral",
tokenizer_mode="mistral", tokenizer_mode="mistral",
enforce_eager=True, enforce_eager=True,
gpu_memory_utilization=0.4, gpu_memory_utilization=0.9,
) )
...@@ -148,6 +149,9 @@ async def test_voxtral_realtime_generator(audio_assets, tokenizer, async_engine) ...@@ -148,6 +149,9 @@ async def test_voxtral_realtime_generator(audio_assets, tokenizer, async_engine)
output_tokens_list.append(output_tokens) output_tokens_list.append(output_tokens)
texts = [tokenizer.decode(output_tokens) for output_tokens in output_tokens_list] texts = [
tokenizer.decode(output_tokens, special_token_policy=SpecialTokenPolicy.IGNORE)
for output_tokens in output_tokens_list
]
texts[1] = texts[1].replace("a base hit", "OBS").replace("oh my", "oh, my") texts[1] = texts[1].replace("a base hit", "OBS").replace("oh my", "oh, my")
assert texts == EXPECTED_TEXT assert texts == EXPECTED_TEXT
...@@ -1215,3 +1215,91 @@ def tarsier_patch_hf_runner(hf_model: HfRunner) -> HfRunner: ...@@ -1215,3 +1215,91 @@ def tarsier_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
hf_processor.patch_size = vision_encoder_info.get_patch_size() hf_processor.patch_size = vision_encoder_info.get_patch_size()
return hf_model return hf_model
def voxtral_patch_hf_runner(hf_model: "HfRunner") -> "HfRunner":
    """Adapt an HfRunner so it can drive Voxtral's chat-template processor.

    Two replacements are installed on ``hf_model`` and the same object is
    returned:

    * ``get_inputs`` is replaced by a version that builds one single-turn
      user conversation per prompt (audio entries first, then the text) and
      passes it to ``processor.apply_chat_template()``. Each audio clip is
      sent as a base64-encoded WAV. According to the original author this
      is needed because VoxtralProcessor expects conversation dicts rather
      than ``processor(text=, audio=, sampling_rate=)``, and because the
      stock ``HfRunner.get_inputs`` mis-unpacks several
      ``(array, sample_rate)`` pairs given for one prompt.
    * ``model.generate`` is wrapped so that the prompt tokens are sliced
      off the returned sequences, leaving only generated tokens for the
      caller to decode.

    Args:
        hf_model: Runner to patch in place. Its ``processor``, ``dtype``
            and ``model.generate`` attributes are used.

    Returns:
        The same ``hf_model`` instance, patched.
    """
    import base64
    import io

    import soundfile as sf

    processor = hf_model.processor

    def _wav_base64(samples, rate) -> str:
        """Serialize audio samples to WAV and return it as base64 text."""
        wav_buffer = io.BytesIO()
        sf.write(wav_buffer, samples, int(rate), format="WAV")
        encoded = base64.b64encode(wav_buffer.getvalue())
        return encoded.decode("ascii")

    def _split_clip(clip):
        """Return ``(samples, rate)``; bare arrays fall back to 16 kHz."""
        if isinstance(clip, (list, tuple)) and len(clip) == 2:
            samples, rate = clip
            return samples, rate
        return clip, 16_000

    def patched_get_inputs(prompts, images=None, videos=None, audios=None, **kwargs):
        # images, videos and any extra kwargs are accepted for signature
        # compatibility but are not used.
        batch = []
        for idx, prompt in enumerate(prompts):
            clips = audios[idx] if audios is not None else None
            if clips is None:
                clips = []
            elif not isinstance(clips, list):
                clips = [clips]

            content: list[dict] = [
                {"type": "audio", "base64": _wav_base64(*_split_clip(clip))}
                for clip in clips
            ]
            content.append({"type": "text", "text": prompt})

            model_inputs = processor.apply_chat_template(
                [{"role": "user", "content": content}]
            )
            if hasattr(model_inputs, "to"):
                model_inputs = model_inputs.to(dtype=hf_model.dtype)
            batch.append(model_inputs)
        return batch

    wrapped_generate = hf_model.model.generate

    def patched_generate(*args, **kwargs):
        """Call the original generate, then drop the prompt tokens."""
        prompt_ids = kwargs.get("input_ids")
        if prompt_ids is None and args:
            prompt_ids = args[0]
        n_prompt = 0 if prompt_ids is None else prompt_ids.shape[1]

        result = wrapped_generate(*args, **kwargs)
        if not n_prompt:
            return result
        if isinstance(result, torch.Tensor):
            return result[:, n_prompt:]
        # Structured generate output: only ``sequences`` is trimmed, so
        # scores/logits remain available for per-token logprob extraction.
        result.sequences = result.sequences[:, n_prompt:]
        return result

    hf_model.get_inputs = patched_get_inputs  # type: ignore[method-assign, assignment]
    hf_model.model.generate = patched_generate  # type: ignore[method-assign]
    return hf_model
...@@ -184,22 +184,42 @@ def get_text_token_prompts( ...@@ -184,22 +184,42 @@ def get_text_token_prompts(
text_prompt: str | None text_prompt: str | None
token_prompt: list[int] token_prompt: list[int]
if isinstance(tokenizer, MistralTokenizer): if isinstance(tokenizer, MistralTokenizer):
images = parsed_data.get("image", []) # ChatCompletionRequest only supports ImageChunk natively;
request = ChatCompletionRequest( # for other modalities (e.g. audio), fall back to the model's
messages=[ # own dummy inputs builder which knows the right placeholders.
UserMessage( has_non_image = any(
content=[ k != "image" and count > 0 for k, count in mm_counts.items()
TextChunk(text=""),
*(ImageChunk(image=image) for image in images),
]
),
]
) )
res = tokenizer.mistral.encode_chat_completion(request)
# Mistral does not support decode_tokens with skip_special_tokens=False if has_non_image:
text_prompt = None inputs = dummy_inputs.get_dummy_processor_inputs(
token_prompt = res.tokens model_config.max_model_len,
mm_counts,
)
text_prompt = None
token_prompt = (
inputs.prompt
if isinstance(inputs.prompt, list)
else tokenizer.encode(inputs.prompt, add_special_tokens=False)
)
else:
images = parsed_data.get("image", [])
request = ChatCompletionRequest(
messages=[
UserMessage(
content=[
TextChunk(text=""),
*(ImageChunk(image=image) for image in images),
]
),
]
)
res = tokenizer.mistral.encode_chat_completion(request)
# Mistral does not support decode_tokens with
# skip_special_tokens=False
text_prompt = None
token_prompt = res.tokens
else: else:
inputs = dummy_inputs.get_dummy_processor_inputs( inputs = dummy_inputs.get_dummy_processor_inputs(
model_config.max_model_len, model_config.max_model_len,
......
...@@ -291,6 +291,34 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]) ...@@ -291,6 +291,34 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo])
# skip validation here # skip validation here
... ...
def _apply_hf_processor_mm_only(
    self,
    mm_items: MultiModalDataItems,
    hf_processor_mm_kwargs: Mapping[str, object],
    tokenization_kwargs: Mapping[str, object],
) -> BatchFeature:
    """Produce the multimodal-only processor output for the given items.

    Every entry found under the ``"audios"`` key of the processor data is
    flattened to a 1-D float32 array, padded through the audio processor
    when the audio config is not streaming, and converted to a tensor.
    The tensors are returned under ``"audio_arrays"`` (the key is omitted
    when there is no audio), merged with the passthrough data.

    ``tokenization_kwargs`` is accepted but not used here.
    """
    hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
    mm_data, passthrough = self._get_hf_mm_data(mm_items)

    raw_audios = mm_data.get("audios", [])
    if not isinstance(raw_audios, list):
        # A single item is normalized to a one-element list.
        raw_audios = [raw_audios]

    audio_processor = hf_processor._audio_processor
    audio_config = audio_processor.audio_config

    waveforms: list[torch.Tensor] = []
    for raw in raw_audios:
        samples = np.asarray(raw, dtype=np.float32).ravel()
        if not audio_config.is_streaming:
            samples = audio_processor.pad(
                samples,
                hf_processor.sampling_rate,
                audio_config.is_streaming,
            )
        waveforms.append(torch.tensor(samples))

    if waveforms:
        features = BatchFeature({"audio_arrays": waveforms})
    else:
        features = BatchFeature({})
    features.update(passthrough)
    return features
def _get_prompt_updates( def _get_prompt_updates(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy import copy
import functools import functools
import logging
import math import math
from dataclasses import replace from dataclasses import replace
from functools import partial from functools import partial
...@@ -30,11 +31,20 @@ from vllm.v1.attention.backend import ( ...@@ -30,11 +31,20 @@ from vllm.v1.attention.backend import (
subclass_attention_backend_with_overrides, subclass_attention_backend_with_overrides,
) )
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
try:
from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionBackend
except ImportError:
AiterFlashAttentionBackend = None
from vllm.v1.attention.backends.rocm_attn import RocmAttentionBackend
from vllm.v1.attention.backends.triton_attn import TritonAttentionBackend
from vllm.v1.attention.selector import get_attn_backend from vllm.v1.attention.selector import get_attn_backend
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
from .utils import make_layers from .utils import make_layers
logger = logging.getLogger(__name__)
CausalRMSNorm = partial(RMSNorm, eps=1e-5) CausalRMSNorm = partial(RMSNorm, eps=1e-5)
...@@ -122,6 +132,13 @@ def create_whisper_attention_backend_with_block_pooling( ...@@ -122,6 +132,13 @@ def create_whisper_attention_backend_with_block_pooling(
num_kv_heads=kv_cache_spec.num_kv_heads // block_pool_size, num_kv_heads=kv_cache_spec.num_kv_heads // block_pool_size,
) )
super().__init__(kv_cache_spec, layer_names, vllm_config, device) super().__init__(kv_cache_spec, layer_names, vllm_config, device)
# Override model_config-derived values with the actual
# encoder values from kv_cache_spec
self.num_heads_kv = kv_cache_spec.num_kv_heads
self.headdim = kv_cache_spec.head_size
# num_heads_q for the encoder is the same as num_kv_heads
# (no GQA in whisper encoder)
self.num_heads_q = kv_cache_spec.num_kv_heads
def build( def build(
self, self,
...@@ -192,13 +209,36 @@ def create_whisper_attention_backend_with_block_pooling( ...@@ -192,13 +209,36 @@ def create_whisper_attention_backend_with_block_pooling(
output_block_scale, output_block_scale,
) )
if not issubclass(underlying_attn_backend, FlashAttentionBackend): _SUPPORTED_BACKENDS = tuple(
b
for b in (
AiterFlashAttentionBackend,
FlashAttentionBackend,
RocmAttentionBackend,
TritonAttentionBackend,
)
if b is not None
)
if not issubclass(underlying_attn_backend, _SUPPORTED_BACKENDS):
raise NotImplementedError( raise NotImplementedError(
f"{underlying_attn_backend} is not yet supported." f"{underlying_attn_backend} is not yet supported."
"Contributions to support more backends are much " "Contributions to support more backends are much "
"appreciated." "appreciated."
) )
if not issubclass(underlying_attn_backend, FlashAttentionBackend):
logger.info(
"Using %s for Whisper causal attention with block pooling. "
"This backend was recently enabled for this model. "
"If you encounter any accuracy or performance issues, "
"please open an issue at "
"https://github.com/vllm-project/vllm/issues "
"with the [ROCm] tag so it can be triaged by the "
"appropriate team.",
underlying_attn_backend.get_name(),
)
attn_backend = subclass_attention_backend_with_overrides( attn_backend = subclass_attention_backend_with_overrides(
name_prefix=prefix, name_prefix=prefix,
attention_backend_cls=underlying_attn_backend, attention_backend_cls=underlying_attn_backend,
...@@ -209,14 +249,14 @@ def create_whisper_attention_backend_with_block_pooling( ...@@ -209,14 +249,14 @@ def create_whisper_attention_backend_with_block_pooling(
block_size, block_size,
num_kv_heads, num_kv_heads,
head_size, head_size,
cache_dtype_str: ( cache_dtype_str: underlying_attn_backend.get_kv_cache_shape(
2,
num_blocks, num_blocks,
# we stretch each block by `block_pool_size` # we stretch each block by `block_pool_size`
block_size * block_pool_size, block_size * block_pool_size,
num_kv_heads // block_pool_size, num_kv_heads // block_pool_size,
head_size, head_size,
), # TODO: generalize to other backends cache_dtype_str,
),
"forward_includes_kv_cache_update": True, "forward_includes_kv_cache_update": True,
}, },
) )
......
...@@ -43,8 +43,8 @@ class MistralReasoningParser(BaseThinkingReasoningParser): ...@@ -43,8 +43,8 @@ class MistralReasoningParser(BaseThinkingReasoningParser):
"constructor during construction." "constructor during construction."
) )
self.start_token_id = tokenizer.tokenizer.get_control_token(self.start_token) self.start_token_id = tokenizer.tokenizer.get_special_token(self.start_token)
self.end_token_id = tokenizer.tokenizer.get_control_token(self.end_token) self.end_token_id = tokenizer.tokenizer.get_special_token(self.end_token)
if self.start_token_id is None or self.end_token_id is None: if self.start_token_id is None or self.end_token_id is None:
raise RuntimeError( raise RuntimeError(
......
...@@ -517,7 +517,7 @@ class MistralTokenizer(TokenizerLike): ...@@ -517,7 +517,7 @@ class MistralTokenizer(TokenizerLike):
return [self.tokenizer.id_to_piece(token_id) for token_id in ids] return [self.tokenizer.id_to_piece(token_id) for token_id in ids]
non_skip_special_tokens_ids = { non_skip_special_tokens_ids = {
self.tokenizer.get_control_token(SpecialTokens.tool_calls), self.tokenizer.get_special_token(SpecialTokens.tool_calls),
} }
if isinstance(self.instruct, InstructTokenizerV13): if isinstance(self.instruct, InstructTokenizerV13):
if self.instruct.BEGIN_THINK: if self.instruct.BEGIN_THINK:
......
...@@ -425,8 +425,13 @@ class AiterFlashAttentionMetadataBuilder( ...@@ -425,8 +425,13 @@ class AiterFlashAttentionMetadataBuilder(
sliding_window_configs: set[tuple[int, int] | None] = set() sliding_window_configs: set[tuple[int, int] | None] = set()
layers = get_layers_from_vllm_config(self.vllm_config, Attention) layers = get_layers_from_vllm_config(self.vllm_config, Attention)
for layer in layers.values(): for name, layer in layers.items():
assert isinstance(layer.impl, AiterFlashAttentionImpl) if name not in layer_names:
continue
assert isinstance(layer.impl, AiterFlashAttentionImpl), (
"Aiter Flash Attention Metadata Builder can only be used "
"with Aiter Flash Attention Impl."
)
sliding_window_configs.add(layer.impl.sliding_window) sliding_window_configs.add(layer.impl.sliding_window)
while len(sliding_window_configs) > 0: while len(sliding_window_configs) > 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment