Unverified Commit 6df9179c authored by Younes Belkada, committed by GitHub

[`core`] Fix fa-2 import (#26785)

* fix fa-2 import

* nit
parent 5bfda28d
@@ -70,7 +70,7 @@ from .utils import (
     is_accelerate_available,
     is_auto_gptq_available,
     is_bitsandbytes_available,
-    is_flash_attn_available,
+    is_flash_attn_2_available,
     is_offline_mode,
     is_optimum_available,
     is_peft_available,
@@ -1269,7 +1269,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
                 "request support for this architecture: https://github.com/huggingface/transformers/issues/new"
             )
-        if not is_flash_attn_available():
+        if not is_flash_attn_2_available():
             raise ImportError(
                 "Flash Attention 2.0 is not available. Please refer to the documentation of https://github.com/Dao-AILab/flash-attention for"
                 " installing it."
...
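For reference, the renamed helper is all that user code needs in order to mirror the guard above. A minimal sketch, assuming only the `is_flash_attn_2_available` function re-exported from `transformers.utils` by this commit; the wrapper function and the script entry point are illustrative, not part of the library:

```python
# Sketch of the guard this commit renames: callers check the FA-2-specific
# helper before enabling Flash Attention 2 and raise the same ImportError
# otherwise. Illustrative only, not the library's exact code path.
from transformers.utils import is_flash_attn_2_available


def check_flash_attn_2() -> None:
    # is_flash_attn_2_available() is True only when torch with CUDA is usable
    # and the installed flash_attn package is recent enough (see the
    # import_utils hunk further down).
    if not is_flash_attn_2_available():
        raise ImportError(
            "Flash Attention 2.0 is not available. Please refer to the documentation of "
            "https://github.com/Dao-AILab/flash-attention for installing it."
        )


if __name__ == "__main__":
    check_flash_attn_2()
    print("Flash Attention 2 is usable on this machine.")
```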
@@ -35,13 +35,13 @@ from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
-    is_flash_attn_available,
+    is_flash_attn_2_available,
     logging,
 )
 from .configuration_falcon import FalconConfig
 
 
-if is_flash_attn_available():
+if is_flash_attn_2_available():
     from flash_attn import flash_attn_func, flash_attn_varlen_func
     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
...
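The Falcon hunk above, and the Llama and Mistral hunks below, all use the same guarded-import pattern: `flash_attn` is only imported when the check passes, so the modeling files still import cleanly on machines without the package. A small sketch of that pattern with a direct `flash_attn_func` call; the tensor shapes and the `causal=True` keyword assume flash-attn 2.x's `(batch, seq_len, num_heads, head_dim)` interface and are illustrative rather than taken from this commit:

```python
import torch

from transformers.utils import is_flash_attn_2_available

if is_flash_attn_2_available():
    from flash_attn import flash_attn_func

    # flash-attn expects fp16/bf16 tensors on a CUDA device.
    q = torch.randn(1, 128, 8, 64, dtype=torch.float16, device="cuda")
    k = torch.randn(1, 128, 8, 64, dtype=torch.float16, device="cuda")
    v = torch.randn(1, 128, 8, 64, dtype=torch.float16, device="cuda")

    # Causal attention over the whole (unpadded) batch.
    out = flash_attn_func(q, k, v, causal=True)
    print(out.shape)  # torch.Size([1, 128, 8, 64])
else:
    print("flash_attn >= 2.0.0 with CUDA is not available; skipping the demo.")
```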
@@ -34,14 +34,14 @@ from ...pytorch_utils import ALL_LAYERNORM_LAYERS
 from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
-    is_flash_attn_available,
+    is_flash_attn_2_available,
     logging,
     replace_return_docstrings,
 )
 from .configuration_llama import LlamaConfig
 
 
-if is_flash_attn_available():
+if is_flash_attn_2_available():
     from flash_attn import flash_attn_func, flash_attn_varlen_func
     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
...
@@ -34,14 +34,14 @@ from ...modeling_utils import PreTrainedModel
 from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
-    is_flash_attn_available,
+    is_flash_attn_2_available,
     logging,
     replace_return_docstrings,
 )
 from .configuration_mistral import MistralConfig
 
 
-if is_flash_attn_available():
+if is_flash_attn_2_available():
     from flash_attn import flash_attn_func, flash_attn_varlen_func
     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
...
@@ -60,7 +60,7 @@ from .utils import (
     is_detectron2_available,
     is_essentia_available,
     is_faiss_available,
-    is_flash_attn_available,
+    is_flash_attn_2_available,
     is_flax_available,
     is_fsdp_available,
     is_ftfy_available,
@@ -432,7 +432,7 @@ def require_flash_attn(test_case):
     These tests are skipped when Flash Attention isn't installed.
     """
-    return unittest.skipUnless(is_flash_attn_available(), "test requires Flash Attention")(test_case)
+    return unittest.skipUnless(is_flash_attn_2_available(), "test requires Flash Attention")(test_case)
 
 
 def require_peft(test_case):
...
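The `require_flash_attn` decorator shown above now skips tests unless the FA-2-specific check passes. A hedged example of opting a test into that behavior; the test class and test body are made up for illustration:

```python
import unittest

from transformers.testing_utils import require_flash_attn


class FlashAttn2SmokeTest(unittest.TestCase):
    @require_flash_attn  # skipped unless flash_attn >= 2.0.0 and CUDA are available
    def test_flash_attn_2_is_importable(self):
        from flash_attn import flash_attn_func  # noqa: F401

        self.assertTrue(callable(flash_attn_func))


if __name__ == "__main__":
    unittest.main()
```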
@@ -115,7 +115,7 @@ from .import_utils import (
     is_detectron2_available,
     is_essentia_available,
     is_faiss_available,
-    is_flash_attn_available,
+    is_flash_attn_2_available,
     is_flax_available,
     is_fsdp_available,
     is_ftfy_available,
...
@@ -71,7 +71,9 @@ TORCH_FX_REQUIRED_VERSION = version.parse("1.10")
 _accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
 _apex_available = _is_package_available("apex")
 _bitsandbytes_available = _is_package_available("bitsandbytes")
-_flash_attn_available = _is_package_available("flash_attn")
+_flash_attn_2_available = _is_package_available("flash_attn") and version.parse(
+    importlib.metadata.version("flash_attn")
+) >= version.parse("2.0.0")
 # `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
 _bs4_available = importlib.util.find_spec("bs4") is not None
 _coloredlogs_available = _is_package_available("coloredlogs")
@@ -579,14 +581,14 @@ def is_bitsandbytes_available():
     return _bitsandbytes_available and torch.cuda.is_available()
 
 
-def is_flash_attn_available():
+def is_flash_attn_2_available():
     if not is_torch_available():
         return False
 
     # Let's add an extra check to see if cuda is available
     import torch
 
-    return _flash_attn_available and torch.cuda.is_available()
+    return _flash_attn_2_available and torch.cuda.is_available()
 
 
 def is_torchdistx_available():
...
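The new `_flash_attn_2_available` flag adds a version gate on top of the plain import check: the package must be importable and its installed version must be at least 2.0.0. A self-contained sketch of the same gate using only the standard library and `packaging`; the `is_flash_attn_2_installed` helper name is made up for illustration, and unlike the library helper above it does not also require torch and CUDA:

```python
import importlib.metadata
import importlib.util

from packaging import version


def is_flash_attn_2_installed() -> bool:
    # Check that the package can be found before asking for its version,
    # since importlib.metadata.version raises if it is not installed.
    if importlib.util.find_spec("flash_attn") is None:
        return False
    return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.0.0")


print(is_flash_attn_2_installed())
```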