Unverified Commit f6220f98 authored by Andreas Karatzas's avatar Andreas Karatzas Committed by GitHub
Browse files

[ROCm][Test] Fix beam search determinism failures from batch-size-dependent FP...


[ROCm][Test] Fix beam search determinism failures from batch-size-dependent FP divergence and removed wrong marker (#34878)
Signed-off-by: default avatarAndreas Karatzas <akaratza@amd.com>
parent 2df2bb27
...@@ -549,7 +549,7 @@ steps: ...@@ -549,7 +549,7 @@ steps:
- tests/samplers - tests/samplers
- tests/conftest.py - tests/conftest.py
commands: commands:
- pytest -v -s -m samplers - pytest -v -s samplers
- label: LoRA Test %N # 20min each - label: LoRA Test %N # 20min each
timeout_in_minutes: 30 timeout_in_minutes: 30
...@@ -2177,7 +2177,7 @@ steps: ...@@ -2177,7 +2177,7 @@ steps:
- tests/samplers - tests/samplers
- tests/conftest.py - tests/conftest.py
commands: commands:
- pytest -v -s -m samplers - pytest -v -s samplers
- label: LoRA Test %N # 20min each - label: LoRA Test %N # 20min each
timeout_in_minutes: 30 timeout_in_minutes: 30
......
...@@ -18,4 +18,4 @@ steps: ...@@ -18,4 +18,4 @@ steps:
depends_on: depends_on:
- image-build-amd - image-build-amd
commands: commands:
- pytest -v -s -m samplers - pytest -v -s samplers
...@@ -9,6 +9,26 @@ import pytest ...@@ -9,6 +9,26 @@ import pytest
from transformers import AutoModelForSeq2SeqLM from transformers import AutoModelForSeq2SeqLM
from vllm.assets.audio import AudioAsset from vllm.assets.audio import AudioAsset
from vllm.platforms import current_platform
# Extra engine kwargs needed for numerically deterministic beam search.
# On ROCm, floating-point reductions in attention and GEMM kernels are
# non-associative and sensitive to batch geometry, so we:
# async_scheduling=False – deterministic batch composition
# enforce_eager=True – no CUDA-graph padding changing effective size
# enable_prefix_caching=False – avoid prefix-sharing side effects
# max_num_seqs=1 – fixed batch size across runs
# On other platforms these are not needed and the dict is empty.
EXTRA_ENGINE_KWARGS: dict = (
dict(
async_scheduling=False,
enforce_eager=True,
enable_prefix_caching=False,
max_num_seqs=1,
)
if current_platform.is_rocm()
else {}
)
# FIXME(zhuohan): The test can not pass if we: # FIXME(zhuohan): The test can not pass if we:
# 1. Increase max_tokens to 256. # 1. Increase max_tokens to 256.
...@@ -25,6 +45,7 @@ MODELS = ["TinyLlama/TinyLlama-1.1B-Chat-v1.0"] ...@@ -25,6 +45,7 @@ MODELS = ["TinyLlama/TinyLlama-1.1B-Chat-v1.0"]
@pytest.mark.parametrize("max_tokens", MAX_TOKENS) @pytest.mark.parametrize("max_tokens", MAX_TOKENS)
@pytest.mark.parametrize("beam_width", BEAM_WIDTHS) @pytest.mark.parametrize("beam_width", BEAM_WIDTHS)
def test_beam_search_single_input( def test_beam_search_single_input(
monkeypatch,
hf_runner, hf_runner,
vllm_runner, vllm_runner,
example_prompts, example_prompts,
...@@ -33,13 +54,16 @@ def test_beam_search_single_input( ...@@ -33,13 +54,16 @@ def test_beam_search_single_input(
max_tokens: int, max_tokens: int,
beam_width: int, beam_width: int,
) -> None: ) -> None:
if current_platform.is_rocm():
monkeypatch.setenv("VLLM_ROCM_USE_SKINNY_GEMM", "0")
example_prompts = example_prompts[:1] example_prompts = example_prompts[:1]
with hf_runner(model, dtype=dtype) as hf_model: with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_beam_search( hf_outputs = hf_model.generate_beam_search(
example_prompts, beam_width, max_tokens example_prompts, beam_width, max_tokens
) )
with vllm_runner(model, dtype=dtype) as vllm_model: with vllm_runner(model, dtype=dtype, **EXTRA_ENGINE_KWARGS) as vllm_model:
vllm_outputs = vllm_model.generate_beam_search( vllm_outputs = vllm_model.generate_beam_search(
example_prompts, beam_width, max_tokens example_prompts, beam_width, max_tokens
) )
...@@ -66,6 +90,7 @@ def test_beam_search_single_input( ...@@ -66,6 +90,7 @@ def test_beam_search_single_input(
@pytest.mark.parametrize("max_tokens", MAX_TOKENS) @pytest.mark.parametrize("max_tokens", MAX_TOKENS)
@pytest.mark.parametrize("beam_width", BEAM_WIDTHS) @pytest.mark.parametrize("beam_width", BEAM_WIDTHS)
def test_beam_search_with_concurrency_limit( def test_beam_search_with_concurrency_limit(
monkeypatch,
hf_runner, hf_runner,
vllm_runner, vllm_runner,
example_prompts, example_prompts,
...@@ -74,21 +99,29 @@ def test_beam_search_with_concurrency_limit( ...@@ -74,21 +99,29 @@ def test_beam_search_with_concurrency_limit(
max_tokens: int, max_tokens: int,
beam_width: int, beam_width: int,
) -> None: ) -> None:
if current_platform.is_rocm():
monkeypatch.setenv("VLLM_ROCM_USE_SKINNY_GEMM", "0")
# example_prompts[1]&[3]&[7] fails due to unknown reason even without # example_prompts[1]&[3]&[7] fails due to unknown reason even without
# concurrency limit. skip them for now. # concurrency limit. skip them for now.
example_prompts = example_prompts[:8] example_prompts = example_prompts[:8]
concurrency_limit = 2 concurrency_limit = 2
assert len(example_prompts) > concurrency_limit assert len(example_prompts) > concurrency_limit
with vllm_runner(model, dtype=dtype) as vllm_model: with vllm_runner(model, dtype=dtype, **EXTRA_ENGINE_KWARGS) as vllm_model:
outputs_with_limit = vllm_model.generate_beam_search( outputs_with_limit = vllm_model.generate_beam_search(
example_prompts, beam_width, max_tokens, concurrency_limit=concurrency_limit example_prompts,
beam_width,
max_tokens,
concurrency_limit=concurrency_limit,
) )
outputs_without_limit = [] outputs_without_limit = []
for i in range(0, len(example_prompts), concurrency_limit): for i in range(0, len(example_prompts), concurrency_limit):
outputs_without_limit.extend( outputs_without_limit.extend(
vllm_model.generate_beam_search( vllm_model.generate_beam_search(
example_prompts[i : i + concurrency_limit], beam_width, max_tokens example_prompts[i : i + concurrency_limit],
beam_width,
max_tokens,
) )
) )
...@@ -118,6 +151,7 @@ def test_beam_search_with_concurrency_limit( ...@@ -118,6 +151,7 @@ def test_beam_search_with_concurrency_limit(
@pytest.mark.parametrize("max_tokens", MAX_TOKENS) @pytest.mark.parametrize("max_tokens", MAX_TOKENS)
@pytest.mark.parametrize("beam_width", MM_BEAM_WIDTHS) @pytest.mark.parametrize("beam_width", MM_BEAM_WIDTHS)
def test_beam_search_passes_multimodal_data( def test_beam_search_passes_multimodal_data(
monkeypatch,
hf_runner, hf_runner,
vllm_runner, vllm_runner,
dtype: str, dtype: str,
...@@ -125,6 +159,9 @@ def test_beam_search_passes_multimodal_data( ...@@ -125,6 +159,9 @@ def test_beam_search_passes_multimodal_data(
beam_width: int, beam_width: int,
) -> None: ) -> None:
"""Ensure that beam search passes multimodal data through correctly.""" """Ensure that beam search passes multimodal data through correctly."""
if current_platform.is_rocm():
monkeypatch.setenv("VLLM_ROCM_USE_SKINNY_GEMM", "0")
# NOTE - this test is primarily to check that mm data is passed to beams # NOTE - this test is primarily to check that mm data is passed to beams
# correctly. As such, we just need to check one extra modality to make # correctly. As such, we just need to check one extra modality to make
# sure things pass through properly. # sure things pass through properly.
...@@ -145,7 +182,7 @@ def test_beam_search_passes_multimodal_data( ...@@ -145,7 +182,7 @@ def test_beam_search_passes_multimodal_data(
audios=audios, audios=audios,
) )
with vllm_runner(model, dtype=dtype) as vllm_model: with vllm_runner(model, dtype=dtype, **EXTRA_ENGINE_KWARGS) as vllm_model:
vllm_outputs = vllm_model.generate_beam_search( vllm_outputs = vllm_model.generate_beam_search(
prompts, prompts,
beam_width=beam_width, beam_width=beam_width,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment