Unverified commit e02706d2, authored by Andreas Karatzas and committed by GitHub
Browse files

[ROCm][CI][V1] Fix `nixl_connector` test failure and achieve CUDA parity in...


[ROCm][CI][V1] Fix `nixl_connector` test failure and achieve CUDA parity in `test_async_scheduling` (#32000)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
parent b474782a
...@@ -163,14 +163,7 @@ def run_tests( ...@@ -163,14 +163,7 @@ def run_tests(
uni/multiproc executor with spec decoding.""" uni/multiproc executor with spec decoding."""
# Determine attention config based on platform # Determine attention config based on platform
if current_platform.is_rocm(): attention_config = {"backend": "FLEX_ATTENTION"}
if is_testing_with_spec_decoding:
# Use TRITON_ATTN for spec decoding test for consistency
attention_config = {"backend": "TRITON_ATTN"}
else:
attention_config = {"backend": "ROCM_ATTN"}
else:
attention_config = {"backend": "FLEX_ATTENTION"}
with monkeypatch.context() as m: with monkeypatch.context() as m:
# lock matmul precision to full FP32 (IEEE) # lock matmul precision to full FP32 (IEEE)
...@@ -226,15 +219,7 @@ def run_tests( ...@@ -226,15 +219,7 @@ def run_tests(
name_1=f"config=[{test_config}], params={params}", name_1=f"config=[{test_config}], params={params}",
) )
# On ROCm with TRITON_ATTN (spec decoding test), skip strict assert _all_logprobs_match(base_logprobs, test_logprobs)
# logprobs comparison when logprobs are requested
skip_logprobs_check = (
current_platform.is_rocm()
and params.get("logprobs")
and is_testing_with_spec_decoding
)
if not skip_logprobs_check:
assert _all_logprobs_match(base_logprobs, test_logprobs)
if ( if (
base_acceptance_rate is not None base_acceptance_rate is not None
...@@ -374,12 +359,7 @@ def _all_logprobs_match(req_a, req_b) -> bool: ...@@ -374,12 +359,7 @@ def _all_logprobs_match(req_a, req_b) -> bool:
def _logprobs_match(lps_a: dict[int, Logprob], lps_b: dict[int, Logprob]) -> bool: def _logprobs_match(lps_a: dict[int, Logprob], lps_b: dict[int, Logprob]) -> bool:
if current_platform.is_rocm(): rel_tol, abs_tol = 1e-3, 1e-6
# ROCm has higher numerical variance
# due to use of float16.
rel_tol, abs_tol = 5e-2, 1e-5
else:
rel_tol, abs_tol = 1e-3, 1e-6
return ( return (
len(lps_a) == len(lps_b) len(lps_a) == len(lps_b)
and lps_a.keys() == lps_b.keys() and lps_a.keys() == lps_b.keys()
......
...@@ -185,18 +185,21 @@ class FakeNixlWrapper: ...@@ -185,18 +185,21 @@ class FakeNixlWrapper:
def _make_fake_nixl_pkg(): def _make_fake_nixl_pkg():
"""Context manager that creates a temporary package making """Context manager that creates a temporary package making
`from nixl._api import nixl_agent` resolve to our FakeNixlWrapper. `from nixl._api import nixl_agent` resolve to our FakeNixlWrapper.
Also creates rixl package for ROCm compatibility.
Automatically cleans up the temporary directory when done. Automatically cleans up the temporary directory when done.
""" """
with tempfile.TemporaryDirectory() as td: with tempfile.TemporaryDirectory() as td:
pkg_root = os.path.join(td, "nixl", "_api") # Create both nixl and rixl packages for cross-platform compatibility
os.makedirs(pkg_root, exist_ok=True) for pkg_name in ["nixl", "rixl"]:
pkg_root = os.path.join(td, pkg_name, "_api")
os.makedirs(pkg_root, exist_ok=True)
# Get the source code of FakeNixlWrapper class and dedent it # Get the source code of FakeNixlWrapper class and dedent it
fake_nixl_source = inspect.getsource(FakeNixlWrapper) fake_nixl_source = inspect.getsource(FakeNixlWrapper)
fake_nixl_source = textwrap.dedent(fake_nixl_source) fake_nixl_source = textwrap.dedent(fake_nixl_source)
stub = f"""\ stub = f"""\
# Copy of FakeNixlWrapper implementation for Ray workers # Copy of FakeNixlWrapper implementation for Ray workers
import uuid import uuid
from collections import defaultdict from collections import defaultdict
...@@ -206,16 +209,17 @@ from collections import defaultdict ...@@ -206,16 +209,17 @@ from collections import defaultdict
# Export as nixl_agent # Export as nixl_agent
nixl_agent = FakeNixlWrapper nixl_agent = FakeNixlWrapper
""" """
with open(os.path.join(pkg_root, "__init__.py"), "w") as f: with open(os.path.join(pkg_root, "__init__.py"), "w") as f:
f.write(stub) f.write(stub)
# Mock nixlXferTelemetry class # Mock nixlXferTelemetry class
pkg_root2 = os.path.join(td, "nixl", "_bindings") pkg_root2 = os.path.join(td, pkg_name, "_bindings")
os.makedirs(pkg_root2, exist_ok=True) os.makedirs(pkg_root2, exist_ok=True)
with open(os.path.join(pkg_root2, "__init__.py"), "w") as f: with open(os.path.join(pkg_root2, "__init__.py"), "w") as f:
f.write("class nixlXferTelemetry: pass") f.write("class nixlXferTelemetry: pass")
# touch parent package # touch parent package
open(os.path.join(td, "nixl", "__init__.py"), "w").close() open(os.path.join(td, pkg_name, "__init__.py"), "w").close()
yield td yield td
......
...@@ -187,6 +187,11 @@ class EagleProposer: ...@@ -187,6 +187,11 @@ class EagleProposer:
rocm_types.append(MLACommonMetadata) rocm_types.append(MLACommonMetadata)
# FlexAttention backend support
from vllm.v1.attention.backends.flex_attention import FlexAttentionMetadata
rocm_types.append(FlexAttentionMetadata)
self.allowed_attn_types = tuple(rocm_types) self.allowed_attn_types = tuple(rocm_types)
# Parse the speculative token tree. # Parse the speculative token tree.
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment