"docs/vscode:/vscode.git/clone" did not exist on "48463ebb33c4a3f4035dbdaf55dc43778304f318"
Unverified Commit 0d04b1e2 authored by Yoach Lacombe, committed by GitHub

Add Flash Attention 2 support to Musicgen and Musicgen Melody (#29939)

* add FA2 to o.g Musicgen

* make style

* add FA2 support to Musicgen Melody

* add generation FA2 tests to o.g Musicgen

* make style and fix copies

* add Musicgen to FA2 docs + deprecate list

* add sdpa supports to Musicgen's

* make style and fix copies

* refactor attention implementation arguments

* add Copied from to sdpa tests

* add copied from in sdpa tests melody

* add copied from for FA2 generation tests

* add FA2 inference copied from

* make style
parent fed27ffc
@@ -55,6 +55,8 @@ FlashAttention-2 is currently supported for the following architectures:
* [MBart](https://huggingface.co/docs/transformers/model_doc/mbart#transformers.MBartModel)
* [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
* [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel)
* [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel)
* [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel)
* [OPT](https://huggingface.co/docs/transformers/model_doc/opt#transformers.OPTModel)
* [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel)
* [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel)
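With this change, Musicgen can opt into FlashAttention-2 at load time. A minimal sketch, assuming a CUDA GPU, the `flash-attn` package installed, and half precision (which FlashAttention-2 requires); the prompt-free load uses the public `facebook/musicgen-small` checkpoint:

```python
import torch
from transformers import MusicgenForConditionalGeneration

# Load Musicgen with the FlashAttention-2 backend; fp16 (or bf16) is required.
model = MusicgenForConditionalGeneration.from_pretrained(
    "facebook/musicgen-small",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
).to("cuda")
```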
@@ -190,6 +192,8 @@ For now, Transformers supports SDPA inference and training for the following architectures:
* [Starcoder2](https://huggingface.co/docs/transformers/model_doc/starcoder2#transformers.Starcoder2Model)
* [Qwen2](https://huggingface.co/docs/transformers/model_doc/qwen2#transformers.Qwen2Model)
* [Qwen2MoE](https://huggingface.co/docs/transformers/model_doc/qwen2_moe#transformers.Qwen2MoeModel)
* [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel)
* [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel)
<Tip>
...
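The SDPA backend (PyTorch's `torch.nn.functional.scaled_dot_product_attention`) can likewise be requested without extra dependencies beyond PyTorch 2.0+. A minimal sketch; the prompt and generation length are illustrative:

```python
from transformers import AutoProcessor, MusicgenForConditionalGeneration

processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
model = MusicgenForConditionalGeneration.from_pretrained(
    "facebook/musicgen-small",
    attn_implementation="sdpa",
)

# Text-conditioned generation; returns a batch of audio waveforms.
inputs = processor(text=["80s pop track with bassy drums and synth"], padding=True, return_tensors="pt")
audio_values = model.generate(**inputs, max_new_tokens=256)
```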
@@ -1470,6 +1470,12 @@ MUSICGEN_PRETRAINED_CONFIG_ARCHIVE_MAP = DeprecatedDict(
MUSICGEN_PRETRAINED_MODEL_ARCHIVE_LIST = DeprecatedList(["facebook/musicgen-small"])
MUSICGEN_MELODY_PRETRAINED_CONFIG_ARCHIVE_MAP = DeprecatedDict(
{"facebook/musicgen-melody": "https://huggingface.co/facebook/musicgen-melody/resolve/main/config.json"}
)
MUSICGEN_MELODY_PRETRAINED_MODEL_ARCHIVE_LIST = DeprecatedList(["facebook/musicgen-melody"])
MVP_PRETRAINED_MODEL_ARCHIVE_LIST = DeprecatedList(
    [
        "RUCAIBox/mvp",
...
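For context, the archive-map constants above are routed through deprecation wrappers. The snippet below is only an illustrative sketch of what such a wrapper could look like (warn when the constant is accessed); it is not the actual implementation in `transformers.models.deprecated._archive_maps`:

```python
import warnings


class DeprecatedList(list):
    """Hypothetical sketch of a list that warns whenever an item is accessed."""

    def __getitem__(self, index):
        warnings.warn(
            "Archive maps are deprecated and will be removed in a future version.",
            FutureWarning,
        )
        return super().__getitem__(index)
```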
@@ -239,3 +239,20 @@ class MusicgenConfig(PretrainedConfig):
    # This is a property because you might want to change the codec model on the fly
    def sampling_rate(self):
        return self.audio_encoder.sampling_rate

    @property
    def _attn_implementation(self):
        # This property is made private for now (as it cannot be changed and a PreTrainedModel.use_attn_implementation method needs to be implemented.)
        if hasattr(self, "_attn_implementation_internal"):
            if self._attn_implementation_internal is None:
                # `config.attn_implementation` should never be None, for backward compatibility.
                return "eager"
            else:
                return self._attn_implementation_internal
        else:
            return "eager"

    @_attn_implementation.setter
    def _attn_implementation(self, value):
        self._attn_implementation_internal = value
        self.decoder._attn_implementation = value
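As a quick illustration of the property above (not part of the diff): the getter falls back to `"eager"` when no implementation was requested, and the setter keeps the decoder sub-config in sync. The attribute is private, so this is only to clarify the mechanism rather than a recommended public API:

```python
from transformers import MusicgenConfig

config = MusicgenConfig.from_pretrained("facebook/musicgen-small")
print(config._attn_implementation)          # "eager" unless another implementation was requested at load time

config._attn_implementation = "sdpa"        # the setter also propagates to the decoder sub-config
print(config.decoder._attn_implementation)  # "sdpa"
```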
@@ -21,9 +21,7 @@ from ..auto.configuration_auto import AutoConfig
logger = logging.get_logger(__name__)
-MUSICGEN_MELODY_PRETRAINED_CONFIG_ARCHIVE_MAP = {
-    "facebook/musicgen-melody": "https://huggingface.co/facebook/musicgen-melody/resolve/main/config.json",
-}
+from ..deprecated._archive_maps import MUSICGEN_MELODY_PRETRAINED_CONFIG_ARCHIVE_MAP  # noqa: F401, E402
class MusicgenMelodyDecoderConfig(PretrainedConfig):
@@ -254,3 +252,20 @@ class MusicgenMelodyConfig(PretrainedConfig):
    # This is a property because you might want to change the codec model on the fly
    def sampling_rate(self):
        return self.audio_encoder.sampling_rate

    @property
    def _attn_implementation(self):
        # This property is made private for now (as it cannot be changed and a PreTrainedModel.use_attn_implementation method needs to be implemented.)
        if hasattr(self, "_attn_implementation_internal"):
            if self._attn_implementation_internal is None:
                # `config.attn_implementation` should never be None, for backward compatibility.
                return "eager"
            else:
                return self._attn_implementation_internal
        else:
            return "eager"

    @_attn_implementation.setter
    def _attn_implementation(self, value):
        self._attn_implementation_internal = value
        self.decoder._attn_implementation = value
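The same mechanism applies to Musicgen Melody. A minimal end-to-end sketch, assuming a CUDA GPU with `flash-attn` installed and fp16 weights; the prompt and token budget are illustrative:

```python
import torch
from transformers import AutoProcessor, MusicgenMelodyForConditionalGeneration

processor = AutoProcessor.from_pretrained("facebook/musicgen-melody")
model = MusicgenMelodyForConditionalGeneration.from_pretrained(
    "facebook/musicgen-melody",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
).to("cuda")

# Text-only conditioning also works; audio (chroma) conditioning is optional.
inputs = processor(text=["a light acoustic guitar melody"], return_tensors="pt").to("cuda")
audio_values = model.generate(**inputs, max_new_tokens=256)
```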