Unverified Commit 6df9179c authored by Younes Belkada, committed by GitHub

[`core`] Fix fa-2 import (#26785)

* fix fa-2 import

* nit
parent 5bfda28d
@@ -70,7 +70,7 @@ from .utils import (
is_accelerate_available,
is_auto_gptq_available,
is_bitsandbytes_available,
-is_flash_attn_available,
+is_flash_attn_2_available,
is_offline_mode,
is_optimum_available,
is_peft_available,
@@ -1269,7 +1269,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
"request support for this architecture: https://github.com/huggingface/transformers/issues/new"
)
-if not is_flash_attn_available():
+if not is_flash_attn_2_available():
raise ImportError(
"Flash Attention 2.0 is not available. Please refer to the documentation of https://github.com/Dao-AILab/flash-attention for"
" installing it."
@@ -35,13 +35,13 @@ from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
-is_flash_attn_available,
+is_flash_attn_2_available,
logging,
)
from .configuration_falcon import FalconConfig
-if is_flash_attn_available():
+if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
@@ -34,14 +34,14 @@ from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
-is_flash_attn_available,
+is_flash_attn_2_available,
logging,
replace_return_docstrings,
)
from .configuration_llama import LlamaConfig
-if is_flash_attn_available():
+if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
@@ -34,14 +34,14 @@ from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
-is_flash_attn_available,
+is_flash_attn_2_available,
logging,
replace_return_docstrings,
)
from .configuration_mistral import MistralConfig
-if is_flash_attn_available():
+if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
@@ -60,7 +60,7 @@ from .utils import (
is_detectron2_available,
is_essentia_available,
is_faiss_available,
-is_flash_attn_available,
+is_flash_attn_2_available,
is_flax_available,
is_fsdp_available,
is_ftfy_available,
@@ -432,7 +432,7 @@ def require_flash_attn(test_case):
These tests are skipped when Flash Attention isn't installed.
"""
-return unittest.skipUnless(is_flash_attn_available(), "test requires Flash Attention")(test_case)
+return unittest.skipUnless(is_flash_attn_2_available(), "test requires Flash Attention")(test_case)
def require_peft(test_case):
@@ -115,7 +115,7 @@ from .import_utils import (
is_detectron2_available,
is_essentia_available,
is_faiss_available,
-is_flash_attn_available,
+is_flash_attn_2_available,
is_flax_available,
is_fsdp_available,
is_ftfy_available,
@@ -71,7 +71,9 @@ TORCH_FX_REQUIRED_VERSION = version.parse("1.10")
_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
_apex_available = _is_package_available("apex")
_bitsandbytes_available = _is_package_available("bitsandbytes")
-_flash_attn_available = _is_package_available("flash_attn")
+_flash_attn_2_available = _is_package_available("flash_attn") and version.parse(
+    importlib.metadata.version("flash_attn")
+) >= version.parse("2.0.0")
# `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
_bs4_available = importlib.util.find_spec("bs4") is not None
_coloredlogs_available = _is_package_available("coloredlogs")
@@ -579,14 +581,14 @@ def is_bitsandbytes_available():
return _bitsandbytes_available and torch.cuda.is_available()
-def is_flash_attn_available():
+def is_flash_attn_2_available():
if not is_torch_available():
return False
# Let's add an extra check to see if cuda is available
import torch
-return _flash_attn_available and torch.cuda.is_available()
+return _flash_attn_2_available and torch.cuda.is_available()
def is_torchdistx_available():
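For reference, a minimal standalone sketch of the check this change introduces: `flash_attn` must be installed, must be a 2.x release, and CUDA must be usable. It mirrors the `import_utils.py` hunks above but avoids the library-internal `_is_package_available` helper, so everything outside the diff is illustrative rather than the exact implementation.

```python
# Hedged sketch only: the real helper lives in transformers.utils.import_utils.
import importlib.metadata
import importlib.util

from packaging import version


def is_flash_attn_2_available() -> bool:
    # flash-attn ships CUDA kernels, so torch with a visible GPU is required.
    if importlib.util.find_spec("torch") is None:
        return False
    import torch

    if not torch.cuda.is_available():
        return False

    # The package must be installed *and* be the 2.x rewrite; older 1.x
    # releases expose a different API and do not satisfy the new code paths.
    if importlib.util.find_spec("flash_attn") is None:
        return False
    return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.0.0")


# Modeling files then gate the kernel imports on this check, as in the diff:
if is_flash_attn_2_available():
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
```

Guarding the import like this keeps `flash_attn` an optional dependency: importing the modeling files never fails, and only explicitly requesting Flash Attention 2 without a valid install raises the `ImportError` shown in the `PreTrainedModel` hunk.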