Unverified Commit a690fb5b authored by Divakar Verma's avatar Divakar Verma Committed by GitHub
Browse files

[CI][ROCm] Fix test_correctness_sliding_window (#29243)


Signed-off-by: default avatarDivakar Verma <divakar.verma@amd.com>
Co-authored-by: default avatarCyrus Leung <tlleungac@connect.ust.hk>
parent 81fe3f82
...@@ -5,6 +5,7 @@ from dataclasses import dataclass ...@@ -5,6 +5,7 @@ from dataclasses import dataclass
import pytest import pytest
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.platforms import current_platform
from ...utils import check_answers, prep_prompts from ...utils import check_answers, prep_prompts
...@@ -40,10 +41,17 @@ def test_sliding_window_retrieval( ...@@ -40,10 +41,17 @@ def test_sliding_window_retrieval(
If we tell it upfront which we are going to be looking for, then If we tell it upfront which we are going to be looking for, then
it answers correctly (mostly). it answers correctly (mostly).
""" """
# NOTE: For ROCm, we have to enforce eager mode to use custom kernel
# implementation of GELU with tanh approximation, as PyTorch's native
# implementation is currently unstable with torch.compile and produces garbage.
enforce_eager = current_platform.is_rocm()
test_config = model_config[model] test_config = model_config[model]
llm = LLM( llm = LLM(
model=model, disable_hybrid_kv_cache_manager=disable_hybrid_kv_cache_manager model=model,
disable_hybrid_kv_cache_manager=disable_hybrid_kv_cache_manager,
enforce_eager=enforce_eager,
) )
sampling_params = SamplingParams(temperature=0.0, max_tokens=100) sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment