Unverified commit e46e139f authored by Aryan, committed by GitHub

Remove logger warnings for attention backends and hard error at runtime instead (#11967)

* update

* update

* update
parent 14725164
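
For readers skimming the diff, the user-facing effect is that selecting an attention backend whose optional dependency is missing now fails loudly at selection time with a `RuntimeError`, instead of emitting an import-time logger warning and typically failing later when the unavailable function is actually called. A minimal sketch of the new behavior; the import path and the `"flash"` backend name are taken from the diff below and should be treated as assumptions rather than a documented public API:

```python
from diffusers.models.attention_dispatch import attention_backend

try:
    # Entering the context manager now validates the backend's requirements and
    # raises immediately if, e.g., flash-attn>=2.6.3 is not installed.
    with attention_backend("flash"):
        ...  # run the model / attention calls here
except RuntimeError as err:
    print(f"Backend unavailable: {err}")
```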
@@ -38,18 +38,29 @@ from ..utils import (
 from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
-logger = get_logger(__name__)  # pylint: disable=invalid-name
-if is_flash_attn_available() and is_flash_attn_version(">=", "2.6.3"):
+_REQUIRED_FLASH_VERSION = "2.6.3"
+_REQUIRED_SAGE_VERSION = "2.1.1"
+_REQUIRED_FLEX_VERSION = "2.5.0"
+_REQUIRED_XLA_VERSION = "2.2"
+_REQUIRED_XFORMERS_VERSION = "0.0.29"
+_CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION)
+_CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()
+_CAN_USE_SAGE_ATTN = is_sageattention_available() and is_sageattention_version(">=", _REQUIRED_SAGE_VERSION)
+_CAN_USE_FLEX_ATTN = is_torch_version(">=", _REQUIRED_FLEX_VERSION)
+_CAN_USE_NPU_ATTN = is_torch_npu_available()
+_CAN_USE_XLA_ATTN = is_torch_xla_available() and is_torch_xla_version(">=", _REQUIRED_XLA_VERSION)
+_CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _REQUIRED_XFORMERS_VERSION)
+if _CAN_USE_FLASH_ATTN:
     from flash_attn import flash_attn_func, flash_attn_varlen_func
 else:
-    logger.warning("`flash-attn` is not available or the version is too old. Please install `flash-attn>=2.6.3`.")
     flash_attn_func = None
     flash_attn_varlen_func = None
-if is_flash_attn_3_available():
+if _CAN_USE_FLASH_ATTN_3:
     from flash_attn_interface import flash_attn_func as flash_attn_3_func
     from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
 else:
@@ -57,7 +68,7 @@ else:
     flash_attn_3_varlen_func = None
-if is_sageattention_available() and is_sageattention_version(">=", "2.1.1"):
+if _CAN_USE_SAGE_ATTN:
     from sageattention import (
         sageattn,
         sageattn_qk_int8_pv_fp8_cuda,
@@ -67,9 +78,6 @@ if is_sageattention_available() and is_sageattention_version(">=", "2.1.1"):
         sageattn_varlen,
     )
 else:
-    logger.warning(
-        "`sageattention` is not available or the version is too old. Please install `sageattention>=2.1.1`."
-    )
     sageattn = None
     sageattn_qk_int8_pv_fp16_cuda = None
     sageattn_qk_int8_pv_fp16_triton = None
@@ -78,39 +86,39 @@ else:
     sageattn_varlen = None
-if is_torch_version(">=", "2.5.0"):
+if _CAN_USE_FLEX_ATTN:
     # We cannot import the flex_attention function from the package directly because it is expected (from the
     # pytorch documentation) that the user may compile it. If we import directly, we will not have access to the
     # compiled function.
     import torch.nn.attention.flex_attention as flex_attention
-if is_torch_npu_available():
+if _CAN_USE_NPU_ATTN:
     from torch_npu import npu_fusion_attention
 else:
     npu_fusion_attention = None
-if is_torch_xla_available() and is_torch_xla_version(">", "2.2"):
+if _CAN_USE_XLA_ATTN:
     from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
 else:
     xla_flash_attention = None
-if is_xformers_available() and is_xformers_version(">=", "0.0.29"):
+if _CAN_USE_XFORMERS_ATTN:
     import xformers.ops as xops
 else:
-    logger.warning("`xformers` is not available or the version is too old. Please install `xformers>=0.0.29`.")
     xops = None
+logger = get_logger(__name__)  # pylint: disable=invalid-name
 # TODO(aryan): Add support for the following:
 # - Sage Attention++
 # - block sparse, radial and other attention methods
 # - CP with sage attention, flex, xformers, other missing backends
 # - Add support for normal and CP training with backends that don't support it yet
 _SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"]
 _SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"]
 _SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"]
@@ -179,13 +187,16 @@ class _AttentionBackendRegistry:
 @contextlib.contextmanager
-def attention_backend(backend: AttentionBackendName = AttentionBackendName.NATIVE):
+def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE):
     """
     Context manager to set the active attention backend.
     """
-    if backend not in _AttentionBackendRegistry._backends:
-        raise ValueError(f"Backend {backend} is not registered.")
+    backend = AttentionBackendName(backend)
+    _check_attention_backend_requirements(backend)
     old_backend = _AttentionBackendRegistry._active_backend
     _AttentionBackendRegistry._active_backend = backend
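
Because `attention_backend` now accepts `Union[str, AttentionBackendName]` and coerces through the enum, an unknown name fails with a `ValueError` from the enum constructor, while a known-but-unavailable backend fails with a `RuntimeError` from `_check_attention_backend_requirements`. A short sketch, assuming the backend strings match the enum values used elsewhere in this file:

```python
from diffusers.models.attention_dispatch import AttentionBackendName, attention_backend

# The enum member and its string value are interchangeable after this change.
with attention_backend(AttentionBackendName.NATIVE):
    ...
with attention_backend("native"):
    ...

# Unknown names fail in the enum constructor, before any requirement check runs.
try:
    with attention_backend("not-a-backend"):
        ...
except ValueError as err:
    print(err)
```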
@@ -226,9 +237,10 @@ def dispatch_attention_fn(
         "dropout_p": dropout_p,
         "is_causal": is_causal,
         "scale": scale,
+        "enable_gqa": enable_gqa,
         **attention_kwargs,
     }
-    if is_torch_version(">=", "2.5.0"):
-        kwargs["enable_gqa"] = enable_gqa
     if _AttentionBackendRegistry._checks_enabled:
         removed_kwargs = set(kwargs) - set(_AttentionBackendRegistry._supported_arg_names[backend_name])
@@ -305,6 +317,57 @@ def _check_shape(
 # ===== Helper functions =====
+def _check_attention_backend_requirements(backend: AttentionBackendName) -> None:
+    if backend in [AttentionBackendName.FLASH, AttentionBackendName.FLASH_VARLEN]:
+        if not _CAN_USE_FLASH_ATTN:
+            raise RuntimeError(
+                f"Flash Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `flash-attn>={_REQUIRED_FLASH_VERSION}`."
+            )
+    elif backend in [AttentionBackendName._FLASH_3, AttentionBackendName._FLASH_VARLEN_3]:
+        if not _CAN_USE_FLASH_ATTN_3:
+            raise RuntimeError(
+                f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
+            )
+    elif backend in [
+        AttentionBackendName.SAGE,
+        AttentionBackendName.SAGE_VARLEN,
+        AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA,
+        AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90,
+        AttentionBackendName._SAGE_QK_INT8_PV_FP16_CUDA,
+        AttentionBackendName._SAGE_QK_INT8_PV_FP16_TRITON,
+    ]:
+        if not _CAN_USE_SAGE_ATTN:
+            raise RuntimeError(
+                f"Sage Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `sageattention>={_REQUIRED_SAGE_VERSION}`."
+            )
+    elif backend == AttentionBackendName.FLEX:
+        if not _CAN_USE_FLEX_ATTN:
+            raise RuntimeError(
+                f"Flex Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch>=2.5.0`."
+            )
+    elif backend == AttentionBackendName._NATIVE_NPU:
+        if not _CAN_USE_NPU_ATTN:
+            raise RuntimeError(
+                f"NPU Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch_npu`."
+            )
+    elif backend == AttentionBackendName._NATIVE_XLA:
+        if not _CAN_USE_XLA_ATTN:
+            raise RuntimeError(
+                f"XLA Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch_xla>={_REQUIRED_XLA_VERSION}`."
+            )
+    elif backend == AttentionBackendName.XFORMERS:
+        if not _CAN_USE_XFORMERS_ATTN:
+            raise RuntimeError(
+                f"Xformers Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `xformers>={_REQUIRED_XFORMERS_VERSION}`."
+            )
 @functools.lru_cache(maxsize=128)
 def _prepare_for_flash_attn_or_sage_varlen_without_mask(
     batch_size: int,
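
The structure the file converges on, reduced to a self-contained sketch: probe optional dependencies once at import time, keep only booleans and `None` placeholders at module scope, and raise only when a caller actually requests the unavailable backend. The names below are illustrative, not the diffusers implementation, and the version check is omitted for brevity:

```python
import importlib.util
from enum import Enum

# Import-time probe: record availability; never warn or raise here.
_CAN_USE_FLASH = importlib.util.find_spec("flash_attn") is not None

flash_attn_func = None
if _CAN_USE_FLASH:
    from flash_attn import flash_attn_func  # noqa: F401


class Backend(str, Enum):
    NATIVE = "native"
    FLASH = "flash"


def check_requirements(backend: Backend) -> None:
    # Use-time check: fail loudly only when the unavailable backend is selected.
    if backend == Backend.FLASH and not _CAN_USE_FLASH:
        raise RuntimeError("`flash` backend requested but `flash-attn` is not installed.")


check_requirements(Backend.NATIVE)  # no-op: native SDPA needs no extra package
```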
@@ -622,19 +622,21 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 attention as backend.
         """
         from .attention import AttentionModuleMixin
-        from .attention_dispatch import AttentionBackendName
+        from .attention_dispatch import AttentionBackendName, _check_attention_backend_requirements
         # TODO: the following will not be required when everything is refactored to AttentionModuleMixin
         from .attention_processor import Attention, MochiAttention
+        logger.warning("Attention backends are an experimental feature and the API may be subject to change.")
         backend = backend.lower()
         available_backends = {x.value for x in AttentionBackendName.__members__.values()}
         if backend not in available_backends:
             raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
         backend = AttentionBackendName(backend)
-        attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
+        _check_attention_backend_requirements(backend)
+        attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
         for module in self.modules():
             if not isinstance(module, attention_classes):
                 continue
@@ -651,6 +653,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         from .attention import AttentionModuleMixin
         from .attention_processor import Attention, MochiAttention
+        logger.warning("Attention backends are an experimental feature and the API may be subject to change.")
         attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
         for module in self.modules():
             if not isinstance(module, attention_classes):
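
On the model side, `set_attention_backend` now runs the same requirement check before touching any attention module, so a bad selection fails before any processor state is modified. A usage sketch; the loader, model id, and backend strings are illustrative assumptions, not part of this diff:

```python
from diffusers import AutoModel

# Any model inheriting from ModelMixin exposes set_attention_backend().
model = AutoModel.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="transformer")

try:
    model.set_attention_backend("_flash_3")  # RuntimeError unless FA3 is built from source
except RuntimeError as err:
    print(err)

model.set_attention_backend("native")  # PyTorch SDPA, always available
```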