"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "5d8eb93eeec0476b9f0fddc96f2960be0ce782b6"
Unverified commit efba1a17, authored by Younes Belkada, committed by GitHub
Browse files

Bump `flash_attn` version to `2.1` (#27079)

* pin FA-2 to `2.1`

* fix on modeling
parent 90412401
@@ -1273,15 +1273,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         if not is_flash_attn_2_available():
             raise ImportError(
-                "Flash Attention 2.0 is not available. Please refer to the documentation of https://github.com/Dao-AILab/flash-attention for"
-                " installing it."
+                "Flash Attention 2 is not available. Please refer to the documentation of https://github.com/Dao-AILab/flash-attention for"
+                " installing it. Make sure to have at least the version 2.1.0"
             )
         else:
             flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))
-            is_flash_greater_than_2 = flash_attention_version > version.parse("2.0.0")
+            is_flash_greater_than_2 = flash_attention_version >= version.parse("2.1.0")
             if not is_flash_greater_than_2:
                 raise ValueError(
-                    f"You need flash_attn package version to be greater than 2.0. Make sure to have that version installed - detected version {flash_attention_version}"
+                    f"You need flash_attn package version to be greater or equal than 2.1. Make sure to have that version installed - detected version {flash_attention_version}"
                 )
         _is_bettertransformer = getattr(cls, "use_bettertransformer", False)
...
@@ -73,7 +73,7 @@ _apex_available = _is_package_available("apex")
 _bitsandbytes_available = _is_package_available("bitsandbytes")
 _flash_attn_2_available = _is_package_available("flash_attn") and version.parse(
     importlib.metadata.version("flash_attn")
-) >= version.parse("2.0.0")
+) >= version.parse("2.1.0")
 # `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
 _bs4_available = importlib.util.find_spec("bs4") is not None
 _coloredlogs_available = _is_package_available("coloredlogs")
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment