Unverified commit fd8afdf3, authored by Micah Williamson, committed by GitHub
Browse files

[ROCm][CI] Reduce Flakiness For test_async_scheduling Using ROCM_ATTN With FP32 (#30811)


Signed-off-by: Micah Williamson <micah.williamson@amd.com>
parent a0b782f9
......@@ -148,7 +148,7 @@ def run_tests(
# Use TRITON_ATTN for spec decoding test for consistency
attention_config = {"backend": "TRITON_ATTN"}
else:
-attention_config = {"backend": "ROCM_AITER_FA"}
+attention_config = {"backend": "ROCM_ATTN"}
else:
attention_config = {"backend": "FLEX_ATTENTION"}
......@@ -284,14 +284,6 @@ def run_test(
print(f"---- TESTING {test_str}: {test_config}")
print("-" * 80)
-# On ROCm: use float16 for first test (ROCM_AITER_FA), but float32 for
-# spec decoding test (TRITON_ATTN) for better precision.
-# On others: always use float32.
-if current_platform.is_rocm() and not is_testing_with_spec_decoding:
-dtype = "float16"
-else:
-dtype = "float32"
with VllmRunner(
model,
max_model_len=512,
......@@ -301,7 +293,7 @@ def run_test(
# enforce_eager=True,
async_scheduling=async_scheduling,
distributed_executor_backend=executor,
-dtype=dtype,
+dtype="float32",
speculative_config=spec_config,
disable_log_stats=False,
attention_config=attention_config,
......
......@@ -152,7 +152,11 @@ class RocmAttentionMetadataBuilder(AttentionMetadataBuilder[RocmAttentionMetadat
class RocmAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
-supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
+supported_dtypes: ClassVar[list[torch.dtype]] = [
+torch.float16,
+torch.bfloat16,
+torch.float32,
+]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment