Unverified commit e02706d2, authored by Andreas Karatzas, committed by GitHub
Browse files

[ROCm][CI][V1] Fix `nixl_connector` test failure and achieve CUDA parity in...


[ROCm][CI][V1] Fix `nixl_connector` test failure and achieve CUDA parity in `test_async_scheduling` (#32000)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
parent b474782a
...@@ -163,13 +163,6 @@ def run_tests( ...@@ -163,13 +163,6 @@ def run_tests(
uni/multiproc executor with spec decoding.""" uni/multiproc executor with spec decoding."""
# Determine attention config based on platform # Determine attention config based on platform
if current_platform.is_rocm():
if is_testing_with_spec_decoding:
# Use TRITON_ATTN for spec decoding test for consistency
attention_config = {"backend": "TRITON_ATTN"}
else:
attention_config = {"backend": "ROCM_ATTN"}
else:
attention_config = {"backend": "FLEX_ATTENTION"} attention_config = {"backend": "FLEX_ATTENTION"}
with monkeypatch.context() as m: with monkeypatch.context() as m:
...@@ -226,14 +219,6 @@ def run_tests( ...@@ -226,14 +219,6 @@ def run_tests(
name_1=f"config=[{test_config}], params={params}", name_1=f"config=[{test_config}], params={params}",
) )
# On ROCm with TRITON_ATTN (spec decoding test), skip strict
# logprobs comparison when logprobs are requested
skip_logprobs_check = (
current_platform.is_rocm()
and params.get("logprobs")
and is_testing_with_spec_decoding
)
if not skip_logprobs_check:
assert _all_logprobs_match(base_logprobs, test_logprobs) assert _all_logprobs_match(base_logprobs, test_logprobs)
if ( if (
...@@ -374,11 +359,6 @@ def _all_logprobs_match(req_a, req_b) -> bool: ...@@ -374,11 +359,6 @@ def _all_logprobs_match(req_a, req_b) -> bool:
def _logprobs_match(lps_a: dict[int, Logprob], lps_b: dict[int, Logprob]) -> bool: def _logprobs_match(lps_a: dict[int, Logprob], lps_b: dict[int, Logprob]) -> bool:
if current_platform.is_rocm():
# ROCm has higher numerical variance
# due to use of float16.
rel_tol, abs_tol = 5e-2, 1e-5
else:
rel_tol, abs_tol = 1e-3, 1e-6 rel_tol, abs_tol = 1e-3, 1e-6
return ( return (
len(lps_a) == len(lps_b) len(lps_a) == len(lps_b)
......
...@@ -185,11 +185,14 @@ class FakeNixlWrapper: ...@@ -185,11 +185,14 @@ class FakeNixlWrapper:
def _make_fake_nixl_pkg(): def _make_fake_nixl_pkg():
"""Context manager that creates a temporary package making """Context manager that creates a temporary package making
`from nixl._api import nixl_agent` resolve to our FakeNixlWrapper. `from nixl._api import nixl_agent` resolve to our FakeNixlWrapper.
Also creates rixl package for ROCm compatibility.
Automatically cleans up the temporary directory when done. Automatically cleans up the temporary directory when done.
""" """
with tempfile.TemporaryDirectory() as td: with tempfile.TemporaryDirectory() as td:
pkg_root = os.path.join(td, "nixl", "_api") # Create both nixl and rixl packages for cross-platform compatibility
for pkg_name in ["nixl", "rixl"]:
pkg_root = os.path.join(td, pkg_name, "_api")
os.makedirs(pkg_root, exist_ok=True) os.makedirs(pkg_root, exist_ok=True)
# Get the source code of FakeNixlWrapper class and dedent it # Get the source code of FakeNixlWrapper class and dedent it
...@@ -210,12 +213,13 @@ nixl_agent = FakeNixlWrapper ...@@ -210,12 +213,13 @@ nixl_agent = FakeNixlWrapper
f.write(stub) f.write(stub)
# Mock nixlXferTelemetry class # Mock nixlXferTelemetry class
pkg_root2 = os.path.join(td, "nixl", "_bindings") pkg_root2 = os.path.join(td, pkg_name, "_bindings")
os.makedirs(pkg_root2, exist_ok=True) os.makedirs(pkg_root2, exist_ok=True)
with open(os.path.join(pkg_root2, "__init__.py"), "w") as f: with open(os.path.join(pkg_root2, "__init__.py"), "w") as f:
f.write("class nixlXferTelemetry: pass") f.write("class nixlXferTelemetry: pass")
# touch parent package # touch parent package
open(os.path.join(td, "nixl", "__init__.py"), "w").close() open(os.path.join(td, pkg_name, "__init__.py"), "w").close()
yield td yield td
......
...@@ -187,6 +187,11 @@ class EagleProposer: ...@@ -187,6 +187,11 @@ class EagleProposer:
rocm_types.append(MLACommonMetadata) rocm_types.append(MLACommonMetadata)
# FlexAttention backend support
from vllm.v1.attention.backends.flex_attention import FlexAttentionMetadata
rocm_types.append(FlexAttentionMetadata)
self.allowed_attn_types = tuple(rocm_types) self.allowed_attn_types = tuple(rocm_types)
# Parse the speculative token tree. # Parse the speculative token tree.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment