Unverified Commit ab71f3c8 authored by Sayak Paul, committed by GitHub

[core] Refactor hub attn kernels (#12475)



* refactor how attention kernels from hub are used.

* up

* refactor according to Dhruv's ideas.
Co-authored-by: Dhruv Nair <dhruv@huggingface.co>

* empty
Co-authored-by: Dhruv Nair <dhruv@huggingface.co>

* empty
Co-authored-by: Dhruv Nair <dhruv@huggingface.co>

* empty
Co-authored-by: dn6 <dhruv@huggingface.co>

* up

---------
Co-authored-by: Dhruv Nair <dhruv@huggingface.co>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
parent b7df4a53
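In short, this PR moves the Hub-provided FA3 attention kernel behind a small registry in `attention_dispatch`: a `_HubKernelConfig` dataclass records the kernel's `repo_id`, the function attribute to pull, an optional `revision`, and a cached `kernel_fn`, and `_maybe_download_kernel_for_backend` fetches and caches the kernel only when a Hub-backed backend is actually selected. The import-time setup gated by the `DIFFUSERS_ENABLE_HUB_KERNELS` env var (and the `kernels_utils` helper) is removed. A hedged sketch of the resulting user-facing flow follows; `set_attention_backend` is assumed to be the `ModelMixin` method whose body is edited below, and the model class and checkpoint id are purely illustrative.

```python
# Hedged sketch, not taken from the diff itself. Assumptions: `set_attention_backend`
# is the ModelMixin method edited below; FluxTransformer2DModel and the checkpoint id
# are only examples of a model that supports attention backends.
import torch
from diffusers import FluxTransformer2DModel
from diffusers.models.attention_dispatch import AttentionBackendName

model = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", subfolder="transformer", torch_dtype=torch.bfloat16
)
# No DIFFUSERS_ENABLE_HUB_KERNELS env var anymore: the FA3 kernel is downloaded from
# the Hub and cached on its _HubKernelConfig entry when the backend is applied.
model.set_attention_backend(AttentionBackendName._FLASH_3_HUB.value)
```

The `kernels` package is still required; per the diff, `_check_attention_backend_requirements` raises a `RuntimeError` if it is missing.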
@@ -16,6 +16,7 @@ import contextlib
 import functools
 import inspect
 import math
+from dataclasses import dataclass
 from enum import Enum
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union
@@ -42,7 +43,7 @@ from ..utils import (
     is_xformers_available,
     is_xformers_version,
 )
-from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS, DIFFUSERS_ENABLE_HUB_KERNELS
+from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS


 if TYPE_CHECKING:
@@ -82,24 +83,11 @@ else:
     flash_attn_3_func = None
     flash_attn_3_varlen_func = None

 if _CAN_USE_AITER_ATTN:
     from aiter import flash_attn_func as aiter_flash_attn_func
 else:
     aiter_flash_attn_func = None

-if DIFFUSERS_ENABLE_HUB_KERNELS:
-    if not is_kernels_available():
-        raise ImportError(
-            "To use FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`."
-        )
-    from ..utils.kernels_utils import _get_fa3_from_hub
-
-    flash_attn_interface_hub = _get_fa3_from_hub()
-    flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func
-else:
-    flash_attn_3_func_hub = None
-
 if _CAN_USE_SAGE_ATTN:
     from sageattention import (
         sageattn,
@@ -261,6 +249,25 @@ class _AttentionBackendRegistry:
         return supports_context_parallel


+@dataclass
+class _HubKernelConfig:
+    """Configuration for downloading and using a hub-based attention kernel."""
+
+    repo_id: str
+    function_attr: str
+    revision: Optional[str] = None
+    kernel_fn: Optional[Callable] = None
+
+
+# Registry for hub-based attention kernels
+_HUB_KERNELS_REGISTRY: Dict["AttentionBackendName", _HubKernelConfig] = {
+    # TODO: temporary revision for now. Remove when merged upstream into `main`.
+    AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
+        repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs"
+    )
+}
+
+
 @contextlib.contextmanager
 def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE):
     """
@@ -415,13 +422,9 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
     # TODO: add support Hub variant of FA3 varlen later
     elif backend in [AttentionBackendName._FLASH_3_HUB]:
-        if not DIFFUSERS_ENABLE_HUB_KERNELS:
-            raise RuntimeError(
-                f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`."
-            )
         if not is_kernels_available():
             raise RuntimeError(
-                f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
+                f"Backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
             )

     elif backend == AttentionBackendName.AITER:
@@ -571,6 +574,29 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
     return q_idx >= kv_idx


+# ===== Helpers for downloading kernels =====
+
+
+def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
+    if backend not in _HUB_KERNELS_REGISTRY:
+        return
+
+    config = _HUB_KERNELS_REGISTRY[backend]
+    if config.kernel_fn is not None:
+        return
+
+    try:
+        from kernels import get_kernel
+
+        kernel_module = get_kernel(config.repo_id, revision=config.revision)
+        kernel_func = getattr(kernel_module, config.function_attr)
+        # Cache the downloaded kernel function in the config object
+        config.kernel_fn = kernel_func
+    except Exception as e:
+        logger.error(f"An error occurred while fetching kernel '{config.repo_id}' from the Hub: {e}")
+        raise
+
+
 # ===== torch op registrations =====
 # Registrations are required for fullgraph tracing compatibility
 # TODO: this is only required because the beta release FA3 does not have it. There is a PR adding
@@ -1418,7 +1444,8 @@ def _flash_attention_3_hub(
     return_attn_probs: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
-    out = flash_attn_3_func_hub(
+    func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
+    out = func(
         q=query,
         k=key,
         v=value,
...
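For reference, on a cache miss the new helper boils down to the `kernels` call already visible in the hunks above. A minimal sketch, using the repo id and temporary revision pinned in `_HUB_KERNELS_REGISTRY`:

```python
# Minimal sketch of the cache-miss path in _maybe_download_kernel_for_backend, using the
# get_kernel(repo_id, revision=...) call shown in the diff. Repo id and the temporary
# revision are the values from the _HUB_KERNELS_REGISTRY entry above.
from kernels import get_kernel

kernel_module = get_kernel("kernels-community/flash-attn3", revision="fake-ops-return-probs")
flash_attn_func = getattr(kernel_module, "flash_attn_func")  # cached on config.kernel_fn in the real code
```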
@@ -595,7 +595,11 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 attention as backend.
         """
         from .attention import AttentionModuleMixin
-        from .attention_dispatch import AttentionBackendName, _check_attention_backend_requirements
+        from .attention_dispatch import (
+            AttentionBackendName,
+            _check_attention_backend_requirements,
+            _maybe_download_kernel_for_backend,
+        )

         # TODO: the following will not be required when everything is refactored to AttentionModuleMixin
         from .attention_processor import Attention, MochiAttention
@@ -606,8 +610,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         available_backends = {x.value for x in AttentionBackendName.__members__.values()}
         if backend not in available_backends:
             raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))

         backend = AttentionBackendName(backend)
         _check_attention_backend_requirements(backend)
+        _maybe_download_kernel_for_backend(backend)
+
         attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
         for module in self.modules():
...
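The two module-level helpers that `ModelMixin` now calls can also be exercised directly, which is handy when debugging kernel downloads. A minimal sketch, assuming the dispatch module is importable as `diffusers.models.attention_dispatch` (it is imported above as a sibling module of `modeling_utils`):

```python
# Minimal sketch, assuming the module path diffusers.models.attention_dispatch.
# Mirrors the two calls ModelMixin makes when an attention backend is applied.
from diffusers.models.attention_dispatch import (
    AttentionBackendName,
    _check_attention_backend_requirements,
    _maybe_download_kernel_for_backend,
)

backend = AttentionBackendName._FLASH_3_HUB
_check_attention_backend_requirements(backend)  # raises RuntimeError if `kernels` is missing
_maybe_download_kernel_for_backend(backend)     # one-time download; caches config.kernel_fn
```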
@@ -46,7 +46,6 @@ DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0").upper() in ENV_VARS_TRUE_VALUES
 DEFAULT_HF_PARALLEL_LOADING_WORKERS = 8
 HF_ENABLE_PARALLEL_LOADING = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES
 DIFFUSERS_DISABLE_REMOTE_CODE = os.getenv("DIFFUSERS_DISABLE_REMOTE_CODE", "false").upper() in ENV_VARS_TRUE_VALUES
-DIFFUSERS_ENABLE_HUB_KERNELS = os.environ.get("DIFFUSERS_ENABLE_HUB_KERNELS", "").upper() in ENV_VARS_TRUE_VALUES

 # Below should be `True` if the current version of `peft` and `transformers` are compatible with
 # PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
...
-from ..utils import get_logger
-from .import_utils import is_kernels_available
-
-
-logger = get_logger(__name__)
-
-_DEFAULT_HUB_ID_FA3 = "kernels-community/flash-attn3"
-
-
-def _get_fa3_from_hub():
-    if not is_kernels_available():
-        return None
-    else:
-        from kernels import get_kernel
-
-        try:
-            # TODO: temporary revision for now. Remove when merged upstream into `main`.
-            flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3, revision="fake-ops-return-probs")
-            return flash_attn_3_hub
-        except Exception as e:
-            logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")
-            raise
@@ -7,7 +7,6 @@ To run this test suite:
 ```bash
 export RUN_ATTENTION_BACKEND_TESTS=yes
-export DIFFUSERS_ENABLE_HUB_KERNELS=yes
 pytest tests/others/test_attention_backends.py
 ```
...