Unverified Commit 9a7ae77a authored by Aishwarya Badlani, committed by GitHub

Fix PyTorch 2.3.1 compatibility: add version guard for torch.library.… (#12206)



* Fix PyTorch 2.3.1 compatibility: add version guard for torch.library.custom_op

- Add hasattr() check for torch.library.custom_op and register_fake
- These functions were added in PyTorch 2.4, causing import failures in 2.3.1
- Both decorators and functions are now properly guarded with version checks
- Maintains backward compatibility while preserving functionality

Fixes #12195
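
For illustration, a minimal sketch of the hasattr()-based guard this step describes (an assumption about its shape, not the exact code from this commit; the merged change uses the version comparison shown in the diff below):

import torch

# Probe for the custom-op API that only exists in PyTorch >= 2.4, so that
# importing the module does not raise AttributeError on PyTorch 2.3.1.
_CUSTOM_OP_AVAILABLE = hasattr(torch.library, "custom_op") and hasattr(torch.library, "register_fake")

if _CUSTOM_OP_AVAILABLE:
    _custom_op = torch.library.custom_op
    _register_fake = torch.library.register_fake
# else: skip the op registrations entirely, or install no-op decorators,
# as the later commits in this PR do.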

* Use dummy decorators approach for PyTorch version compatibility

- Replace hasattr check with version string comparison
- Add no-op decorator functions for PyTorch < 2.4.0
- Follows pattern from #11941 as suggested by reviewer
- Maintains cleaner code structure without indentation changes
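
As a self-contained illustration of this dummy-decorator pattern, the sketch below guards a hypothetical toy op (the `example_lib::toy_matmul` name and `toy_matmul` function are invented for this example); on PyTorch < 2.4.0 the stand-in decorators simply return the function unchanged, so the module still imports and the function stays callable:

import torch

if torch.__version__ >= "2.4.0":
    _custom_op = torch.library.custom_op
    _register_fake = torch.library.register_fake
else:
    # No-op stand-ins with the same call shape as the real decorators.
    def _custom_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
        def wrap(func):
            return func

        return wrap if fn is None else fn

    def _register_fake(op, fn=None, /, *, lib=None, _stacklevel=1):
        def wrap(func):
            return func

        return wrap if fn is None else fn


# Hypothetical op, used only to show how the guarded aliases are applied.
@_custom_op("example_lib::toy_matmul", mutates_args=(), device_types="cuda")
def toy_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return a @ b


# Fake (meta) implementation so output shapes can be inferred during tracing on 2.4+;
# on older PyTorch this decorator is a no-op.
@_register_fake("example_lib::toy_matmul")
def _(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return a.new_empty(a.shape[0], b.shape[1])

The real registrations in attention_dispatch.py below use these same guarded `_custom_op` / `_register_fake` aliases.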

* Update src/diffusers/models/attention_dispatch.py

Update all the decorator usages
Co-authored-by: Aryan <contact.aryanvs@gmail.com>

* Update src/diffusers/models/attention_dispatch.py
Co-authored-by: Aryan <contact.aryanvs@gmail.com>

* Update src/diffusers/models/attention_dispatch.py
Co-authored-by: Aryan <contact.aryanvs@gmail.com>

* Update src/diffusers/models/attention_dispatch.py
Co-authored-by: Aryan <contact.aryanvs@gmail.com>

* Move version check to top of file and use private naming as requested

* Apply style fixes

---------
Co-authored-by: Aryan <contact.aryanvs@gmail.com>
Co-authored-by: Aryan <aryan@huggingface.co>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
parent 673d4357
@@ -110,6 +110,27 @@ if _CAN_USE_XFORMERS_ATTN:
else:
    xops = None


# Version guard for PyTorch compatibility - custom_op was added in PyTorch 2.4
if torch.__version__ >= "2.4.0":
    _custom_op = torch.library.custom_op
    _register_fake = torch.library.register_fake
else:

    def custom_op_no_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
        def wrap(func):
            return func

        return wrap if fn is None else fn

    def register_fake_no_op(op, fn=None, /, *, lib=None, _stacklevel=1):
        def wrap(func):
            return func

        return wrap if fn is None else fn

    _custom_op = custom_op_no_op
    _register_fake = register_fake_no_op


logger = get_logger(__name__)  # pylint: disable=invalid-name
@@ -473,12 +494,11 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
# ===== torch op registrations =====
# Registrations are required for fullgraph tracing compatibility
# TODO: library.custom_op and register_fake probably need version guards?
# TODO: this is only required because the beta release FA3 does not have it. There is a PR adding
# this but it was never merged: https://github.com/Dao-AILab/flash-attention/pull/1590
@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
@_custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _wrapped_flash_attn_3_original(
    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -487,7 +507,7 @@ def _wrapped_flash_attn_3_original(
    return out, lse


@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
@_register_fake("flash_attn_3::_flash_attn_forward")
def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    batch_size, seq_len, num_heads, head_dim = query.shape
    lse_shape = (batch_size, seq_len, num_heads)