Unverified commit ef1f7030, authored by Micah Williamson and committed by GitHub
Browse files

[ROCm][CI] Fix test_cudagraph_mode failure in AMD CI (#29367)


Signed-off-by: Micah Williamson <micah.williamson@amd.com>
parent 12c007e2
......@@ -340,4 +340,11 @@ full_cg_backend_configs = {
"cudagraph_mode": "FULL_AND_PIECEWISE",
},
),
"RocmAttn": BackendConfig(
name="RocmAttn",
env_vars={"VLLM_V1_USE_PREFILL_DECODE_ATTENTION": "1"},
comp_config={
"cudagraph_mode": "FULL",
},
),
}
......@@ -35,14 +35,22 @@ def temporary_environ(env_vars):
# test attention backend and cudagraph_mode combo
# (backend_name, cudagraph_mode, supported)
combo_cases_1 = [
if current_platform.is_rocm():
combo_cases_1 = [
("RocmAttn", "FULL", True),
("RocmAttn", "FULL_AND_PIECEWISE", True),
("TritonAttn", "FULL", True),
("TritonAttn", "FULL_AND_PIECEWISE", True),
]
else:
combo_cases_1 = [
("FA3", "FULL", True),
("FA3", "FULL_AND_PIECEWISE", True),
("FA2", "FULL", True), # Should fallback to FULL_AND_PIECEWISE
("FA2", "FULL_AND_PIECEWISE", True),
("FlashInfer", "FULL", True), # Should fallback to FULL_AND_PIECEWISE
("FlashInfer", "FULL_AND_PIECEWISE", True),
]
]
@pytest.mark.parametrize("backend_name, cudagraph_mode, supported", combo_cases_1)
......@@ -92,7 +100,21 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supporte
# test cudagraph_mode with different compilation mode.
# (backend_name, cudagraph_mode, compilation_mode, supported)
combo_cases_2 = [
if current_platform.is_rocm():
combo_cases_2 = [
("RocmAttn", "FULL", CompilationMode.NONE, True),
("RocmAttn", "FULL", CompilationMode.VLLM_COMPILE, True),
("RocmAttn", "PIECEWISE", CompilationMode.NONE, False),
("RocmAttn", "PIECEWISE", CompilationMode.VLLM_COMPILE, True),
("RocmAttn", "FULL_AND_PIECEWISE", CompilationMode.NONE, False),
("RocmAttn", "FULL_AND_PIECEWISE", CompilationMode.VLLM_COMPILE, True),
("RocmAttn", "FULL_DECODE_ONLY", CompilationMode.NONE, True),
("RocmAttn", "FULL_DECODE_ONLY", CompilationMode.VLLM_COMPILE, True),
("RocmAttn", "NONE", CompilationMode.NONE, True),
("RocmAttn", "NONE", CompilationMode.VLLM_COMPILE, True),
]
else:
combo_cases_2 = [
("FA2", "FULL", CompilationMode.NONE, True),
("FA2", "FULL", CompilationMode.VLLM_COMPILE, True),
("FA2", "PIECEWISE", CompilationMode.NONE, False),
......@@ -103,7 +125,7 @@ combo_cases_2 = [
("FA2", "FULL_DECODE_ONLY", CompilationMode.VLLM_COMPILE, True),
("FA2", "NONE", CompilationMode.NONE, True),
("FA2", "NONE", CompilationMode.VLLM_COMPILE, True),
]
]
@pytest.mark.parametrize(
......
......@@ -321,8 +321,8 @@ class RocmPlatform(Platform):
return AttentionBackendEnum.TRITON_ATTN.get_path()
raise RuntimeError(
"V0 attention backends have been removed. Set VLLM_USE_V1=1 "
"to select a supported backend."
f"Attention backend {selected_backend.name} is not supported on "
"ROCm. Note that V0 attention backends have been removed."
)
@classmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment