Unverified Commit 5e714f7f authored by Andreas Karatzas's avatar Andreas Karatzas Committed by GitHub
Browse files

[ROCm][CI] Fix HuggingFace flash_attention_2 accuracy issue in Isaac vision encoder (#32233)


Signed-off-by: default avatarAndreas Karatzas <akaratza@amd.com>
parent 11b6af52
...@@ -30,3 +30,22 @@ def pytest_collection_modifyitems(config, items): ...@@ -30,3 +30,22 @@ def pytest_collection_modifyitems(config, items):
UserWarning, UserWarning,
stacklevel=1, stacklevel=1,
) )
def patch_hf_vision_attn_for_rocm(model):
"""Force SDPA for HF vision encoders on ROCm.
HF's flash_attention_2 has accuracy issues on ROCm that bypass
torch.backends.cuda settings. This forces SDPA which then uses
math_sdp via the pytest_collection_modifyitems settings.
"""
if not current_platform.is_rocm():
return
inner = getattr(model, "model", model)
if hasattr(inner, "vision_embedding"):
vit = inner.vision_embedding[0]
for layer in vit.encoder.layers:
if hasattr(layer, "self_attn"):
layer.self_attn.vision_config._attn_implementation = "sdpa"
...@@ -576,6 +576,14 @@ def isaac_patch_hf_runner(hf_model: HfRunner) -> HfRunner: ...@@ -576,6 +576,14 @@ def isaac_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
# ---------------------------- # ----------------------------
isaac_model = hf_model.model.model isaac_model = hf_model.model.model
# [ROCm] Disable Flash/MemEfficient SDP on ROCm to avoid HF Transformers
# accuracy issues: https://github.com/vllm-project/vllm/issues/30167
# TODO: Remove once ROCm SDP accuracy issues are resolved on HuggingFace
# ----------------------------
from ...conftest import patch_hf_vision_attn_for_rocm
patch_hf_vision_attn_for_rocm(hf_model.model)
def patched_forward( def patched_forward(
self, self,
input_ids=None, input_ids=None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment