Unverified commit 0753134f authored by Mohit Sharma, committed by GitHub

Disable the FA backend for SDPA on AMD GPUs (#30850)

* disable fa

* disable fa

* update warning

* update warning
parent 9d889f87
...@@ -1479,6 +1479,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
            config,
            hard_check_only=False if requested_attn_implementation is None else True,
        )
        if (
            torch.version.hip is not None
            and config._attn_implementation == "sdpa"
            and torch.cuda.device_count() > 1
        ):
            logger.warning_once(
                "Using the `SDPA` attention implementation on multi-gpu setup with ROCM may lead to performance issues due to the FA backend. Disabling it to use alternative backends."
            )
            torch.backends.cuda.enable_flash_sdp(False)
        else:
            config._attn_implementation = "eager"
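The guard added in this diff can be read as a single predicate: disable the flash backend only when all three conditions hold (ROCm build, SDPA attention, more than one GPU). A minimal sketch of that logic as a standalone, hypothetical helper (not part of the PR; the real code reads `torch.version.hip` and `torch.cuda.device_count()` directly and then calls `torch.backends.cuda.enable_flash_sdp(False)`):

```python
def should_disable_flash_sdp(hip_version, attn_implementation, device_count):
    """Mirror the diff's condition: on a ROCm build (torch.version.hip is not
    None), with the "sdpa" attention implementation selected, and more than
    one GPU visible, the flash attention backend for SDPA gets disabled."""
    return (
        hip_version is not None        # ROCm build, i.e. AMD GPUs
        and attn_implementation == "sdpa"
        and device_count > 1           # multi-GPU setup
    )

# ROCm build + SDPA + 2 GPUs -> the FA backend would be disabled
print(should_disable_flash_sdp("5.7", "sdpa", 2))   # True
# CUDA build (hip_version is None) -> left enabled
print(should_disable_flash_sdp(None, "sdpa", 2))    # False
```

Note that `enable_flash_sdp(False)` is a process-wide PyTorch toggle, so after this branch runs, SDPA falls back to the other backends (memory-efficient or math) for all subsequent calls.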