Unverified Commit e02706d2 authored by Andreas Karatzas's avatar Andreas Karatzas Committed by GitHub
Browse files

[ROCm][CI][V1] Fix `nixl_connector` test failure and achieve CUDA parity in...


[ROCm][CI][V1] Fix `nixl_connector` test failure and achieve CUDA parity in `test_async_scheduling` (#32000)
Signed-off-by: default avatarAndreas Karatzas <akaratza@amd.com>
parent b474782a
......@@ -163,14 +163,7 @@ def run_tests(
uni/multiproc executor with spec decoding."""
# Determine attention config based on platform
if current_platform.is_rocm():
if is_testing_with_spec_decoding:
# Use TRITON_ATTN for spec decoding test for consistency
attention_config = {"backend": "TRITON_ATTN"}
else:
attention_config = {"backend": "ROCM_ATTN"}
else:
attention_config = {"backend": "FLEX_ATTENTION"}
attention_config = {"backend": "FLEX_ATTENTION"}
with monkeypatch.context() as m:
# lock matmul precision to full FP32 (IEEE)
......@@ -226,15 +219,7 @@ def run_tests(
name_1=f"config=[{test_config}], params={params}",
)
# On ROCm with TRITON_ATTN (spec decoding test), skip strict
# logprobs comparison when logprobs are requested
skip_logprobs_check = (
current_platform.is_rocm()
and params.get("logprobs")
and is_testing_with_spec_decoding
)
if not skip_logprobs_check:
assert _all_logprobs_match(base_logprobs, test_logprobs)
assert _all_logprobs_match(base_logprobs, test_logprobs)
if (
base_acceptance_rate is not None
......@@ -374,12 +359,7 @@ def _all_logprobs_match(req_a, req_b) -> bool:
def _logprobs_match(lps_a: dict[int, Logprob], lps_b: dict[int, Logprob]) -> bool:
if current_platform.is_rocm():
# ROCm has higher numerical variance
# due to use of float16.
rel_tol, abs_tol = 5e-2, 1e-5
else:
rel_tol, abs_tol = 1e-3, 1e-6
rel_tol, abs_tol = 1e-3, 1e-6
return (
len(lps_a) == len(lps_b)
and lps_a.keys() == lps_b.keys()
......
......@@ -185,18 +185,21 @@ class FakeNixlWrapper:
def _make_fake_nixl_pkg():
"""Context manager that creates a temporary package making
`from nixl._api import nixl_agent` resolve to our FakeNixlWrapper.
Also creates rixl package for ROCm compatibility.
Automatically cleans up the temporary directory when done.
"""
with tempfile.TemporaryDirectory() as td:
pkg_root = os.path.join(td, "nixl", "_api")
os.makedirs(pkg_root, exist_ok=True)
# Create both nixl and rixl packages for cross-platform compatibility
for pkg_name in ["nixl", "rixl"]:
pkg_root = os.path.join(td, pkg_name, "_api")
os.makedirs(pkg_root, exist_ok=True)
# Get the source code of FakeNixlWrapper class and dedent it
fake_nixl_source = inspect.getsource(FakeNixlWrapper)
fake_nixl_source = textwrap.dedent(fake_nixl_source)
# Get the source code of FakeNixlWrapper class and dedent it
fake_nixl_source = inspect.getsource(FakeNixlWrapper)
fake_nixl_source = textwrap.dedent(fake_nixl_source)
stub = f"""\
stub = f"""\
# Copy of FakeNixlWrapper implementation for Ray workers
import uuid
from collections import defaultdict
......@@ -206,16 +209,17 @@ from collections import defaultdict
# Export as nixl_agent
nixl_agent = FakeNixlWrapper
"""
with open(os.path.join(pkg_root, "__init__.py"), "w") as f:
f.write(stub)
# Mock nixlXferTelemetry class
pkg_root2 = os.path.join(td, "nixl", "_bindings")
os.makedirs(pkg_root2, exist_ok=True)
with open(os.path.join(pkg_root2, "__init__.py"), "w") as f:
f.write("class nixlXferTelemetry: pass")
# touch parent package
open(os.path.join(td, "nixl", "__init__.py"), "w").close()
with open(os.path.join(pkg_root, "__init__.py"), "w") as f:
f.write(stub)
# Mock nixlXferTelemetry class
pkg_root2 = os.path.join(td, pkg_name, "_bindings")
os.makedirs(pkg_root2, exist_ok=True)
with open(os.path.join(pkg_root2, "__init__.py"), "w") as f:
f.write("class nixlXferTelemetry: pass")
# touch parent package
open(os.path.join(td, pkg_name, "__init__.py"), "w").close()
yield td
......
......@@ -187,6 +187,11 @@ class EagleProposer:
rocm_types.append(MLACommonMetadata)
# FlexAttention backend support
from vllm.v1.attention.backends.flex_attention import FlexAttentionMetadata
rocm_types.append(FlexAttentionMetadata)
self.allowed_attn_types = tuple(rocm_types)
# Parse the speculative token tree.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment