Unverified Commit d4c57863 authored by Andreas Karatzas's avatar Andreas Karatzas Committed by GitHub
Browse files

[ROCm][CI] Fix engine teardown and text normalization to stabilize voxtral test (#37138)


Signed-off-by: default avatarAndreas Karatzas <akaratza@amd.com>
parent 68e1b711
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
from dataclasses import asdict from dataclasses import asdict
import pytest import pytest
import pytest_asyncio
from mistral_common.audio import Audio from mistral_common.audio import Audio
from mistral_common.protocol.instruct.chunk import RawAudio from mistral_common.protocol.instruct.chunk import RawAudio
from mistral_common.protocol.transcription.request import ( from mistral_common.protocol.transcription.request import (
...@@ -17,18 +19,21 @@ from vllm.assets.audio import AudioAsset ...@@ -17,18 +19,21 @@ from vllm.assets.audio import AudioAsset
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.async_llm import AsyncLLM
from ....utils import ROCM_ENGINE_KWARGS
MODEL_NAME = "mistralai/Voxtral-Mini-4B-Realtime-2602" MODEL_NAME = "mistralai/Voxtral-Mini-4B-Realtime-2602"
ENGINE_CONFIG = dict( ENGINE_CONFIG = {
model=MODEL_NAME, "model": MODEL_NAME,
max_model_len=8192, "max_model_len": 8192,
max_num_seqs=4, "max_num_seqs": 4,
limit_mm_per_prompt={"audio": 1}, "limit_mm_per_prompt": {"audio": 1},
config_format="mistral", "config_format": "mistral",
load_format="mistral", "load_format": "mistral",
tokenizer_mode="mistral", "tokenizer_mode": "mistral",
enforce_eager=True, "enforce_eager": True,
gpu_memory_utilization=0.9, "gpu_memory_utilization": 0.9,
) **ROCM_ENGINE_KWARGS,
}
EXPECTED_TEXT = [ EXPECTED_TEXT = [
...@@ -49,6 +54,14 @@ EXPECTED_TEXT = [ ...@@ -49,6 +54,14 @@ EXPECTED_TEXT = [
] ]
def _normalize(texts: list[str]) -> list[str]:
# The model occasionally transcribes "OBS" as "a base hit" and
# "oh, my" as "oh my", but both are acoustically valid. Normalise so
# the assertion is stable across runs and hardware.
texts[1] = texts[1].replace("a base hit", "OBS").replace("oh my", "oh, my")
return texts
@pytest.fixture @pytest.fixture
def audio_assets() -> list[AudioAsset]: def audio_assets() -> list[AudioAsset]:
return [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")] return [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]
...@@ -60,15 +73,27 @@ def tokenizer() -> MistralTokenizer: ...@@ -60,15 +73,27 @@ def tokenizer() -> MistralTokenizer:
@pytest.fixture @pytest.fixture
def engine() -> LLM: def engine():
engine_args = EngineArgs(**ENGINE_CONFIG) engine_args = EngineArgs(**ENGINE_CONFIG)
return LLM(**asdict(engine_args)) llm = LLM(**asdict(engine_args))
try:
yield llm
finally:
with contextlib.suppress(Exception):
llm.llm_engine.engine_core.shutdown()
import torch
torch.accelerator.empty_cache()
@pytest.fixture
def async_engine() -> AsyncLLM: @pytest_asyncio.fixture
async def async_engine():
engine_args = AsyncEngineArgs(**ENGINE_CONFIG) engine_args = AsyncEngineArgs(**ENGINE_CONFIG)
return AsyncLLM.from_engine_args(engine_args) llm = AsyncLLM.from_engine_args(engine_args)
try:
yield llm
finally:
llm.shutdown()
def test_voxtral_realtime_forward(audio_assets, tokenizer, engine): def test_voxtral_realtime_forward(audio_assets, tokenizer, engine):
...@@ -108,8 +133,13 @@ def test_voxtral_realtime_forward(audio_assets, tokenizer, engine): ...@@ -108,8 +133,13 @@ def test_voxtral_realtime_forward(audio_assets, tokenizer, engine):
sampling_params=sampling_params, sampling_params=sampling_params,
) )
texts = [out.outputs[0].text for out in outputs] texts = _normalize([out.outputs[0].text for out in outputs])
assert texts == EXPECTED_TEXT for i, (got, expected) in enumerate(zip(texts, EXPECTED_TEXT)):
assert got == expected, (
f"Output mismatch at index {i}:\n"
f" got: {got!r}\n"
f" expected: {expected!r}"
)
@pytest.mark.asyncio @pytest.mark.asyncio
...@@ -149,9 +179,17 @@ async def test_voxtral_realtime_generator(audio_assets, tokenizer, async_engine) ...@@ -149,9 +179,17 @@ async def test_voxtral_realtime_generator(audio_assets, tokenizer, async_engine)
output_tokens_list.append(output_tokens) output_tokens_list.append(output_tokens)
texts = [ texts = _normalize(
tokenizer.decode(output_tokens, special_token_policy=SpecialTokenPolicy.IGNORE) [
for output_tokens in output_tokens_list tokenizer.decode(
] output_tokens, special_token_policy=SpecialTokenPolicy.IGNORE
texts[1] = texts[1].replace("a base hit", "OBS").replace("oh my", "oh, my") )
assert texts == EXPECTED_TEXT for output_tokens in output_tokens_list
]
)
for i, (got, expected) in enumerate(zip(texts, EXPECTED_TEXT)):
assert got == expected, (
f"Output mismatch at index {i}:\n"
f" got: {got!r}\n"
f" expected: {expected!r}"
)
...@@ -122,6 +122,12 @@ ROCM_EXTRA_ARGS = ( ...@@ -122,6 +122,12 @@ ROCM_EXTRA_ARGS = (
if current_platform.is_rocm() if current_platform.is_rocm()
else [] else []
) )
# Python-API equivalent of ROCM_EXTRA_ARGS for use with EngineArgs kwargs.
ROCM_ENGINE_KWARGS: dict = (
{"enable_prefix_caching": False, "max_num_seqs": 1}
if current_platform.is_rocm()
else {}
)
class RemoteVLLMServer: class RemoteVLLMServer:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment