Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b9cdc852
Unverified
Commit
b9cdc852
authored
Mar 31, 2026
by
Andreas Karatzas
Committed by
GitHub
Mar 31, 2026
Browse files
[ROCm][CI] Fix Whisper translation test attention backend selection (#38508)
Signed-off-by:
Andreas Karatzas
<
akaratza@amd.com
>
parent
3e802e87
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
39 additions
and
6 deletions
+39
-6
tests/entrypoints/openai/speech_to_text/test_translation_validation.py
...ints/openai/speech_to_text/test_translation_validation.py
+39
-6
No files found.
tests/entrypoints/openai/speech_to_text/test_translation_validation.py
View file @
b9cdc852
...
@@ -16,10 +16,41 @@ import soundfile as sf
...
@@ -16,10 +16,41 @@ import soundfile as sf
from
tests.entrypoints.openai.conftest
import
add_attention_backend
from
tests.entrypoints.openai.conftest
import
add_attention_backend
from
tests.utils
import
RemoteOpenAIServer
from
tests.utils
import
RemoteOpenAIServer
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
SERVER_ARGS
=
[
"--enforce-eager"
]
SERVER_ARGS
=
[
"--enforce-eager"
]
def _get_rocm_attention_config(model_name):
    """Choose the attention backend config for ``model_name`` on ROCm.

    Whisper relies on cross-attention (ENCODER_DECODER), which ROCM_AITER_FA
    does not support. Whisper models therefore get ROCM_AITER_UNIFIED_ATTN
    when ``_ON_MI3XX`` is truthy and TRITON_ATTN otherwise (including when
    ``_ON_MI3XX`` cannot be imported). Every other model gets ROCM_AITER_FA.

    Args:
        model_name: Model identifier; matched case-insensitively against
            the substring ``"whisper"``.

    Returns:
        ``None`` when the current platform is not ROCm, otherwise a dict
        with a single ``"backend"`` key naming the attention backend.
    """
    from vllm.platforms import current_platform

    # Non-ROCm platforms get no explicit attention config.
    if not current_platform.is_rocm():
        return None

    # Anything that is not Whisper can use the AITER flash-attention backend.
    if "whisper" not in model_name.lower():
        return {"backend": "ROCM_AITER_FA"}

    try:
        from vllm.platforms.rocm import _ON_MI3XX
    except ImportError:
        logger.warning(
            "Could not import _ON_MI3XX from rocm platform, "
            "falling back to TRITON_ATTN for Whisper."
        )
        use_unified_attn = False
    else:
        use_unified_attn = bool(_ON_MI3XX)

    whisper_backend = (
        "ROCM_AITER_UNIFIED_ATTN" if use_unified_attn else "TRITON_ATTN"
    )
    return {"backend": whisper_backend}
def
_get_server_args
(
attention_config
):
def
_get_server_args
(
attention_config
):
"""Get server args with attention backend if specified."""
"""Get server args with attention backend if specified."""
args
=
SERVER_ARGS
.
copy
()
args
=
SERVER_ARGS
.
copy
()
...
@@ -30,10 +61,11 @@ def _get_server_args(attention_config):
...
@@ -30,10 +61,11 @@ def _get_server_args(attention_config):
@
pytest
.
fixture
(
@
pytest
.
fixture
(
scope
=
"module"
,
params
=
[
"openai/whisper-small"
,
"google/gemma-3n-E2B-it"
]
scope
=
"module"
,
params
=
[
"openai/whisper-small"
,
"google/gemma-3n-E2B-it"
]
)
)
def
server
(
request
,
rocm_aiter_fa_attention
):
def
server
(
request
):
# Parametrize over model name
# Parametrize over model name
attention_config
=
_get_rocm_attention_config
(
request
.
param
)
with
RemoteOpenAIServer
(
with
RemoteOpenAIServer
(
request
.
param
,
_get_server_args
(
rocm_aiter_fa_
attention
)
request
.
param
,
_get_server_args
(
attention
_config
)
)
as
remote_server
:
)
as
remote_server
:
yield
remote_server
,
request
.
param
yield
remote_server
,
request
.
param
...
@@ -46,11 +78,12 @@ async def client_and_model(server):
...
@@ -46,11 +78,12 @@ async def client_and_model(server):
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
asyncio
async
def
test_non_asr_model
(
foscolo
,
rocm_aiter_fa_attention
):
async
def
test_non_asr_model
(
foscolo
):
# text to text model
# text to text model
model_name
=
"JackFram/llama-68m"
model_name
=
"JackFram/llama-68m"
attention_config
=
_get_rocm_attention_config
(
model_name
)
with
RemoteOpenAIServer
(
with
RemoteOpenAIServer
(
model_name
,
_get_server_args
(
rocm_aiter_fa_
attention
)
model_name
,
_get_server_args
(
attention
_config
)
)
as
remote_server
:
)
as
remote_server
:
client
=
remote_server
.
get_async_client
()
client
=
remote_server
.
get_async_client
()
...
@@ -61,7 +94,7 @@ async def test_non_asr_model(foscolo, rocm_aiter_fa_attention):
...
@@ -61,7 +94,7 @@ async def test_non_asr_model(foscolo, rocm_aiter_fa_attention):
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
asyncio
async
def
test_basic_audio_with_lora
(
mary_had_lamb
,
rocm_aiter_fa_attention
):
async
def
test_basic_audio_with_lora
(
mary_had_lamb
):
"""Ensure STT (translate) requests can pass LoRA through to generate."""
"""Ensure STT (translate) requests can pass LoRA through to generate."""
# ROCm SPECIFIC CONFIGURATION:
# ROCm SPECIFIC CONFIGURATION:
# To ensure the test passes on ROCm, we modify the max model length to 512.
# To ensure the test passes on ROCm, we modify the max model length to 512.
...
@@ -85,7 +118,7 @@ async def test_basic_audio_with_lora(mary_had_lamb, rocm_aiter_fa_attention):
...
@@ -85,7 +118,7 @@ async def test_basic_audio_with_lora(mary_had_lamb, rocm_aiter_fa_attention):
"1"
,
"1"
,
]
]
add_attention_backend
(
server_args
,
rocm_a
iter_fa_attention
)
add_attention_backend
(
server_args
,
_get_
rocm_a
ttention_config
(
model_name
)
)
# Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
# Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
with
RemoteOpenAIServer
(
model_name
,
server_args
)
as
remote_server
:
with
RemoteOpenAIServer
(
model_name
,
server_args
)
as
remote_server
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment