Unverified Commit 1c1c9075 authored by Sanchit Gandhi, committed by GitHub

Add Musicgen (#24109)



* Add Audiocraft

* add cross attention

* style

* add for lm

* convert and verify

* introduce t5

* split configs

* load t5 + lm

* clean conversion

* copy from t5

* style

* start pattern provider

* make generation work

* style

* fix pos embs

* propagate shape changes

* propagate shape changes

* style

* delay pattern: pad tokens at end

* audiocraft -> musicgen

* fix inits

* add mdx

* style

* fix pad token in processor

* override generate and add todos

* add init to test

* undo pattern delay mask after gen

* remove cfg logits processor

* remove cfg logits processor

* remove logits processor in favour of mask

* clean pos embs

* make fix copies

* update readmes

* clean pos emb

* refactor encoder/decoder

* make fix copies

* update conversion

* fix config imports

* update config docs

* make style

* send pattern mask to device

* pattern mask with delay

* recover prompted audio tokens

* fix docstrings

* laydown test file

* pattern edge case

* remove t5 ref

* add processing class

* config refactor

* better pattern comment

* check if mask is not present

* check if mask is not present

* refactor to auto class

* remove encoder configs

* fix processor

* processor import

* start updating conversion

* start updating tests

* make style

* convert t5, encodec, lm

* convert as composite

* also convert processor

* run generate

* classifier free gen

* comments and clean up

* make style

* docs for logit proc

* docstring for uncond gen

* start lm tests

* work tests

* let the lm generate

* refactor: reshape inside forward

* undo greedy loop changes

* from_enc_dec -> from_sub_model

* fix input id shapes in docstrings

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* undo generate changes

* from sub model config

* Update src/transformers/models/musicgen/modeling_musicgen.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* make generate work again

* generate uncond -> get uncond inputs

* remove prefix allowed tokens fn

* better error message

* logit proc checks

* Apply suggestions from code review
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* make decoder only tests work

* composite fast tests

* make style

* uncond generation

* feat extr padding

* make audio prompt work

* fix inputs docstrings

* unconditional inputs: dict -> model output

* clean up tests

* more clean up tests

* make style

* t5 encoder -> auto text encoder

* remove comments

* deal with frames

* fix auto text

* slow tests

* nice mdx

* remove can generate

* todo - hub id

* convert m/l

* make fix copies

* only import generation with torch

* ignore decoder from tests

* don't wrap uncond inputs

* make style

* cleaner uncond inputs

* add example to musicgen forward

* fix docs

* ignore MusicGen Model/ForConditionalGeneration in auto mapping

* add doc section to toctree

* add to doc tests

* add processor tests

* fix push to hub in conversion

* tips for decoder only loading

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* fix conversion for s / m / l checkpoints

* import stopping criteria from module

* remove from pipeline tests

* fix uncond docstring

* decode audio method

* fix docs

* org: sanchit-gandhi -> facebook

* fix max pos embeddings

* remove auto doc (not compatible with shapes)

* bump max pos emb

* make style

* fix doc

* fix config doc

* fix config doc

* ignore musicgen config from docstring

* make style

* fix config

* fix config for doctest

* consistent from_sub_models

* don't automap decoder

* fix mdx save audio file

* fix mdx save audio file

* processor batch decode for audio

* remove keys to ignore

* update doc md

* update generation config

* allow changes for default generation config

* update tests

* make style

* fix docstring for uncond

* fix processor test

* fix processor test

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 2dc5e1a1
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert MusicGen checkpoints from the original repository."""
import argparse
from pathlib import Path
from typing import Dict, OrderedDict, Tuple

import torch
from audiocraft.models import MusicGen

from transformers import (
    AutoFeatureExtractor,
    AutoTokenizer,
    EncodecModel,
    MusicgenDecoderConfig,
    MusicgenForConditionalGeneration,
    MusicgenProcessor,
    T5EncoderModel,
)
from transformers.models.musicgen.modeling_musicgen import MusicgenForCausalLM
from transformers.utils import logging

logging.set_verbosity_info()
logger = logging.get_logger(__name__)

EXPECTED_MISSING_KEYS = ["model.decoder.embed_positions.weights"]

def rename_keys(name):
    if "emb" in name:
        name = name.replace("emb", "model.decoder.embed_tokens")
    if "transformer" in name:
        name = name.replace("transformer", "model.decoder")
    if "cross_attention" in name:
        name = name.replace("cross_attention", "encoder_attn")
    if "linear1" in name:
        name = name.replace("linear1", "fc1")
    if "linear2" in name:
        name = name.replace("linear2", "fc2")
    if "norm1" in name:
        name = name.replace("norm1", "self_attn_layer_norm")
    if "norm_cross" in name:
        name = name.replace("norm_cross", "encoder_attn_layer_norm")
    if "norm2" in name:
        name = name.replace("norm2", "final_layer_norm")
    if "out_norm" in name:
        name = name.replace("out_norm", "model.decoder.layer_norm")
    if "linears" in name:
        name = name.replace("linears", "lm_heads")
    if "condition_provider.conditioners.description.output_proj" in name:
        name = name.replace("condition_provider.conditioners.description.output_proj", "enc_to_dec_proj")
    return name
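
# Illustrative example of the mapping above (not part of the original script): the
# Audiocraft key
#     "transformer.layers.0.cross_attention.in_proj_weight"
# is renamed by `rename_keys` to
#     "model.decoder.layers.0.encoder_attn.in_proj_weight",
# after which its fused qkv weight is split in `rename_state_dict` below.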

def rename_state_dict(state_dict: OrderedDict, hidden_size: int) -> Tuple[Dict, Dict]:
    """Function that takes the original Audiocraft MusicGen state dict and renames it according to the HF
    module names. It further partitions the state dict into the decoder (LM) state dict, and that for the
    encoder-decoder projection."""
    keys = list(state_dict.keys())
    enc_dec_proj_state_dict = {}
    for key in keys:
        val = state_dict.pop(key)
        key = rename_keys(key)
        if "in_proj_weight" in key:
            # split fused qkv proj
            state_dict[key.replace("in_proj_weight", "q_proj.weight")] = val[:hidden_size, :]
            state_dict[key.replace("in_proj_weight", "k_proj.weight")] = val[hidden_size : 2 * hidden_size, :]
            state_dict[key.replace("in_proj_weight", "v_proj.weight")] = val[-hidden_size:, :]
        elif "enc_to_dec_proj" in key:
            enc_dec_proj_state_dict[key[len("enc_to_dec_proj.") :]] = val
        else:
            state_dict[key] = val
    return state_dict, enc_dec_proj_state_dict
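
# Shape sketch for the qkv split above (illustrative): with the "small" checkpoint,
# hidden_size = 1024, so a fused `in_proj_weight` of shape (3 * 1024, 1024) is sliced
# row-wise into three (1024, 1024) matrices for the q/k/v projections respectively.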

def decoder_config_from_checkpoint(checkpoint: str) -> MusicgenDecoderConfig:
    if checkpoint == "small":
        # default config values
        hidden_size = 1024
        num_hidden_layers = 24
        num_attention_heads = 16
    elif checkpoint == "medium":
        hidden_size = 1536
        num_hidden_layers = 48
        num_attention_heads = 24
    elif checkpoint == "large":
        hidden_size = 2048
        num_hidden_layers = 48
        num_attention_heads = 32
    else:
        raise ValueError(f"Checkpoint should be one of `['small', 'medium', 'large']`, got {checkpoint}.")
    config = MusicgenDecoderConfig(
        hidden_size=hidden_size,
        ffn_dim=hidden_size * 4,
        num_hidden_layers=num_hidden_layers,
        num_attention_heads=num_attention_heads,
    )
    return config
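
# For reference (illustrative), the "small" checkpoint therefore yields:
#     MusicgenDecoderConfig(hidden_size=1024, ffn_dim=4096, num_hidden_layers=24, num_attention_heads=16)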

@torch.no_grad()
def convert_musicgen_checkpoint(checkpoint, pytorch_dump_folder=None, repo_id=None, device="cpu"):
    fairseq_model = MusicGen.get_pretrained(checkpoint, device=device)
    decoder_config = decoder_config_from_checkpoint(checkpoint)

    decoder_state_dict = fairseq_model.lm.state_dict()
    decoder_state_dict, enc_dec_proj_state_dict = rename_state_dict(
        decoder_state_dict, hidden_size=decoder_config.hidden_size
    )

    text_encoder = T5EncoderModel.from_pretrained("t5-base")
    audio_encoder = EncodecModel.from_pretrained("facebook/encodec_32khz")
    decoder = MusicgenForCausalLM(decoder_config).eval()

    # load all decoder weights - expect that we'll be missing embeddings and enc-dec projection
    missing_keys, unexpected_keys = decoder.load_state_dict(decoder_state_dict, strict=False)

    for key in missing_keys.copy():
        if key.startswith(("text_encoder", "audio_encoder")) or key in EXPECTED_MISSING_KEYS:
            missing_keys.remove(key)

    if len(missing_keys) > 0:
        raise ValueError(f"Missing key(s) in state_dict: {missing_keys}")

    if len(unexpected_keys) > 0:
        raise ValueError(f"Unexpected key(s) in state_dict: {unexpected_keys}")

    # init the composite model
    model = MusicgenForConditionalGeneration(text_encoder=text_encoder, audio_encoder=audio_encoder, decoder=decoder)

    # load the pre-trained enc-dec projection (from the decoder state dict)
    model.enc_to_dec_proj.load_state_dict(enc_dec_proj_state_dict)

    # check we can do a forward pass
    input_ids = torch.arange(0, 8, dtype=torch.long).reshape(2, -1)
    decoder_input_ids = input_ids.reshape(2 * 4, -1)

    with torch.no_grad():
        logits = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits

    if logits.shape != (8, 1, 2048):
        raise ValueError("Incorrect shape for logits")

    # now construct the processor
    tokenizer = AutoTokenizer.from_pretrained("t5-base")
    feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/encodec_32khz", padding_side="left")

    processor = MusicgenProcessor(feature_extractor=feature_extractor, tokenizer=tokenizer)

    # set the appropriate bos/pad token ids
    model.generation_config.decoder_start_token_id = 2048
    model.generation_config.pad_token_id = 2048

    # set other default generation config params
    model.generation_config.max_length = int(30 * audio_encoder.config.frame_rate)
    model.generation_config.do_sample = True
    model.generation_config.guidance_scale = 3.0

    if pytorch_dump_folder is not None:
        Path(pytorch_dump_folder).mkdir(exist_ok=True)
        logger.info(f"Saving model {checkpoint} to {pytorch_dump_folder}")
        model.save_pretrained(pytorch_dump_folder)
        processor.save_pretrained(pytorch_dump_folder)

    if repo_id:
        logger.info(f"Pushing model {checkpoint} to {repo_id}")
        model.push_to_hub(repo_id)
        processor.push_to_hub(repo_id)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--checkpoint",
        default="small",
        type=str,
        help="Checkpoint size of the MusicGen model you'd like to convert. Can be one of: `['small', 'medium', 'large']`.",
    )
    parser.add_argument(
        "--pytorch_dump_folder",
        required=True,
        type=str,
        help="Path to the output PyTorch model directory.",
    )
    parser.add_argument(
        "--push_to_hub", default=None, type=str, help="Where to upload the converted model on the 🤗 hub."
    )
    parser.add_argument(
        "--device", default="cpu", type=str, help="Torch device to run the conversion, either cpu or cuda."
    )

    args = parser.parse_args()
    convert_musicgen_checkpoint(args.checkpoint, args.pytorch_dump_folder, args.push_to_hub, args.device)
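
# Example invocation (illustrative; the script filename and Hub repo id are placeholders):
#     python convert_musicgen_checkpoint.py \
#         --checkpoint small \
#         --pytorch_dump_folder ./musicgen-small \
#         --push_to_hub <your-username>/musicgen-small \
#         --device cpu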
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Text/audio processor class for MusicGen
"""
from typing import List, Optional
import numpy as np
from ...processing_utils import ProcessorMixin
from ...utils import to_numpy

class MusicgenProcessor(ProcessorMixin):
    r"""
    Constructs a MusicGen processor which wraps an EnCodec feature extractor and a T5 tokenizer into a single processor
    class.

    [`MusicgenProcessor`] offers all the functionalities of [`EncodecFeatureExtractor`] and [`T5Tokenizer`]. See
    [`~MusicgenProcessor.__call__`] and [`~MusicgenProcessor.decode`] for more information.

    Args:
        feature_extractor (`EncodecFeatureExtractor`):
            An instance of [`EncodecFeatureExtractor`]. The feature extractor is a required input.
        tokenizer (`T5Tokenizer`):
            An instance of [`T5Tokenizer`]. The tokenizer is a required input.
    """

    feature_extractor_class = "EncodecFeatureExtractor"
    tokenizer_class = ("T5Tokenizer", "T5TokenizerFast")

    def __init__(self, feature_extractor, tokenizer):
        super().__init__(feature_extractor, tokenizer)
        self.current_processor = self.feature_extractor
        self._in_target_context_manager = False

    def get_decoder_prompt_ids(self, task=None, language=None, no_timestamps=True):
        return self.tokenizer.get_decoder_prompt_ids(task=task, language=language, no_timestamps=no_timestamps)

    def __call__(self, *args, **kwargs):
        """
        Forwards the `audio` argument to EncodecFeatureExtractor's [`~EncodecFeatureExtractor.__call__`] and the `text`
        argument to [`~T5Tokenizer.__call__`]. Please refer to the docstring of the above two methods for more
        information.
        """
        # For backward compatibility
        if self._in_target_context_manager:
            return self.current_processor(*args, **kwargs)

        audio = kwargs.pop("audio", None)
        sampling_rate = kwargs.pop("sampling_rate", None)
        text = kwargs.pop("text", None)
        if len(args) > 0:
            audio = args[0]
            args = args[1:]

        if audio is None and text is None:
            raise ValueError("You need to specify either an `audio` or `text` input to process.")

        if text is not None:
            inputs = self.tokenizer(text, **kwargs)

        if audio is not None:
            audio_inputs = self.feature_extractor(audio, *args, sampling_rate=sampling_rate, **kwargs)

        if audio is None:
            return inputs
        elif text is None:
            return audio_inputs
        else:
            inputs["input_values"] = audio_inputs["input_values"]
            if "padding_mask" in audio_inputs:
                inputs["padding_mask"] = audio_inputs["padding_mask"]
            return inputs
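
    # Illustrative usage, assuming a converted checkpoint such as "facebook/musicgen-small":
    #     processor = MusicgenProcessor.from_pretrained("facebook/musicgen-small")
    #     inputs = processor(text=["80s pop track with punchy drums"], padding=True, return_tensors="pt")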

    def batch_decode(self, *args, **kwargs):
        """
        This method is used to decode either batches of audio outputs from the MusicGen model, or batches of token ids
        from the tokenizer. In the case of decoding token ids, this method forwards all its arguments to T5Tokenizer's
        [`~PreTrainedTokenizer.batch_decode`]. Please refer to the docstring of this method for more information.
        """
        audio_values = kwargs.pop("audio", None)
        padding_mask = kwargs.pop("padding_mask", None)

        if len(args) > 0:
            audio_values = args[0]
            args = args[1:]

        if audio_values is not None:
            return self._decode_audio(audio_values, padding_mask=padding_mask)
        else:
            return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to T5Tokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to the
        docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    def _decode_audio(self, audio_values, padding_mask: Optional[np.ndarray] = None) -> List[np.ndarray]:
        """
        This method strips any padding from the audio values to return a list of numpy audio arrays.
        """
        audio_values = to_numpy(audio_values)
        bsz, channels, seq_len = audio_values.shape

        if padding_mask is None:
            return list(audio_values)

        padding_mask = to_numpy(padding_mask)

        # match the sequence length of the padding mask to the generated audio arrays by padding with the **non-padding**
        # token (so that the generated audio values are **not** treated as padded tokens)
        difference = seq_len - padding_mask.shape[-1]
        padding_value = 1 - self.feature_extractor.padding_value
        padding_mask = np.pad(padding_mask, ((0, 0), (0, difference)), "constant", constant_values=padding_value)

        audio_values = audio_values.tolist()
        for i in range(bsz):
            sliced_audio = np.asarray(audio_values[i])[
                padding_mask[i][None, :] != self.feature_extractor.padding_value
            ]
            audio_values[i] = sliced_audio.reshape(channels, -1)

        return audio_values
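
# Illustrative round trip for audio decoding (variable names assumed, not part of this file):
#     audio_values = model.generate(**inputs)  # (bsz, channels, seq_len)
#     audios = processor.batch_decode(audio_values, padding_mask=inputs["padding_mask"])
#     # -> list of np.ndarray, one per batch item, with any left-padding stripped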
@@ -4936,6 +4936,44 @@ class MT5PreTrainedModel(metaclass=DummyObject):
        requires_backends(self, ["torch"])


MUSICGEN_PRETRAINED_MODEL_ARCHIVE_LIST = None


class MusicgenForCausalLM(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class MusicgenForConditionalGeneration(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class MusicgenModel(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class MusicgenPreTrainedModel(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class MusicgenProcessor(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


MVP_PRETRAINED_MODEL_ARCHIVE_LIST = None
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the MusicGen processor."""
import random
import shutil
import tempfile
import unittest
import numpy as np
from transformers import T5Tokenizer, T5TokenizerFast
from transformers.testing_utils import require_sentencepiece, require_torch
from transformers.utils.import_utils import is_speech_available, is_torch_available

if is_torch_available():
    pass

if is_speech_available():
    from transformers import EncodecFeatureExtractor, MusicgenProcessor
global_rng = random.Random()

def floats_list(shape, scale=1.0, rng=None, name=None):
    """Creates a random float32 tensor"""
    if rng is None:
        rng = global_rng

    values = []
    for batch_idx in range(shape[0]):
        values.append([])
        for _ in range(shape[1]):
            values[-1].append(rng.random() * scale)

    return values

@require_torch
@require_sentencepiece
class MusicgenProcessorTest(unittest.TestCase):
    def setUp(self):
        self.checkpoint = "facebook/musicgen-small"
        self.tmpdirname = tempfile.mkdtemp()

    def get_tokenizer(self, **kwargs):
        return T5Tokenizer.from_pretrained(self.checkpoint, **kwargs)

    def get_feature_extractor(self, **kwargs):
        return EncodecFeatureExtractor.from_pretrained(self.checkpoint, **kwargs)

    def tearDown(self):
        shutil.rmtree(self.tmpdirname)

    def test_save_load_pretrained_default(self):
        tokenizer = self.get_tokenizer()
        feature_extractor = self.get_feature_extractor()
        processor = MusicgenProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
        processor.save_pretrained(self.tmpdirname)
        processor = MusicgenProcessor.from_pretrained(self.tmpdirname)

        self.assertEqual(processor.tokenizer.get_vocab(), tokenizer.get_vocab())
        self.assertIsInstance(processor.tokenizer, T5TokenizerFast)
        self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor.to_json_string())
        self.assertIsInstance(processor.feature_extractor, EncodecFeatureExtractor)

    def test_save_load_pretrained_additional_features(self):
        processor = MusicgenProcessor(tokenizer=self.get_tokenizer(), feature_extractor=self.get_feature_extractor())
        processor.save_pretrained(self.tmpdirname)

        tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)")
        feature_extractor_add_kwargs = self.get_feature_extractor(do_normalize=False, padding_value=1.0)

        processor = MusicgenProcessor.from_pretrained(
            self.tmpdirname, bos_token="(BOS)", eos_token="(EOS)", do_normalize=False, padding_value=1.0
        )

        self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
        self.assertIsInstance(processor.tokenizer, T5TokenizerFast)
        self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
        self.assertIsInstance(processor.feature_extractor, EncodecFeatureExtractor)

    def test_feature_extractor(self):
        feature_extractor = self.get_feature_extractor()
        tokenizer = self.get_tokenizer()
        processor = MusicgenProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)

        raw_speech = floats_list((3, 1000))
        input_feat_extract = feature_extractor(raw_speech, return_tensors="np")
        input_processor = processor(raw_speech, return_tensors="np")

        for key in input_feat_extract.keys():
            self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)

    def test_tokenizer(self):
        feature_extractor = self.get_feature_extractor()
        tokenizer = self.get_tokenizer()
        processor = MusicgenProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)

        input_str = "This is a test string"
        encoded_processor = processor(text=input_str)
        encoded_tok = tokenizer(input_str)

        for key in encoded_tok.keys():
            self.assertListEqual(encoded_tok[key], encoded_processor[key])

    def test_tokenizer_decode(self):
        feature_extractor = self.get_feature_extractor()
        tokenizer = self.get_tokenizer()
        processor = MusicgenProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)

        predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]]
        decoded_processor = processor.batch_decode(sequences=predicted_ids)
        decoded_tok = tokenizer.batch_decode(predicted_ids)

        self.assertListEqual(decoded_tok, decoded_processor)

    def test_model_input_names(self):
        feature_extractor = self.get_feature_extractor()
        tokenizer = self.get_tokenizer()
        processor = MusicgenProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)

        self.assertListEqual(
            processor.model_input_names,
            feature_extractor.model_input_names,
            msg="`processor` and `feature_extractor` model input names do not match",
        )

    def test_decode_audio(self):
        feature_extractor = self.get_feature_extractor(padding_side="left")
        tokenizer = self.get_tokenizer()
        processor = MusicgenProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)

        raw_speech = [floats_list((1, x))[0] for x in range(5, 20, 5)]
        padding_mask = processor(raw_speech).padding_mask

        generated_speech = np.asarray(floats_list((3, 20)))[:, None, :]
        decoded_audios = processor.batch_decode(generated_speech, padding_mask=padding_mask)

        self.assertIsInstance(decoded_audios, list)
        for audio in decoded_audios:
            self.assertIsInstance(audio, np.ndarray)
        self.assertTrue(decoded_audios[0].shape == (1, 10))
        self.assertTrue(decoded_audios[1].shape == (1, 15))
        self.assertTrue(decoded_audios[2].shape == (1, 20))
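
# To run these tests locally (illustrative; the test file path is a placeholder):
#     python -m pytest tests/models/musicgen/test_processing_musicgen.py -v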
@@ -37,6 +37,7 @@ _re_checkpoint = re.compile(r"\[(.+?)\]\((https://huggingface\.co/.+?)\)")
CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK = {
    "DecisionTransformerConfig",
    "EncoderDecoderConfig",
    "MusicgenConfig",
    "RagConfig",
    "SpeechEncoderDecoderConfig",
    "TimmBackboneConfig",
@@ -119,6 +119,7 @@ IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
    "MegatronBertEncoder",  # Building part of bigger (tested) model.
    "MegatronBertDecoder",  # Building part of bigger (tested) model.
    "MegatronBertDecoderWrapper",  # Building part of bigger (tested) model.
    "MusicgenDecoder",  # Building part of bigger (tested) model.
    "MvpDecoderWrapper",  # Building part of bigger (tested) model.
    "MvpEncoder",  # Building part of bigger (tested) model.
    "PegasusEncoder",  # Building part of bigger (tested) model.
@@ -331,6 +332,8 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
    "SpeechT5ForSpeechToSpeech",
    "SpeechT5ForTextToSpeech",
    "SpeechT5HifiGan",
    "MusicgenModel",
    "MusicgenForConditionalGeneration",
]

# Update this list for models that have multiple model types for the same
@@ -134,6 +134,9 @@ src/transformers/models/mobilenet_v1/modeling_mobilenet_v1.py
src/transformers/models/mobilenet_v2/modeling_mobilenet_v2.py
src/transformers/models/mobilevit/modeling_mobilevit.py
src/transformers/models/mobilevit/modeling_tf_mobilevit.py
src/transformers/models/musicgen/modeling_musicgen.py
src/transformers/models/musicgen/configuration_musicgen.py
src/transformers/models/musicgen/processing_musicgen.py
src/transformers/models/mvp/configuration_mvp.py
src/transformers/models/nat/configuration_nat.py
src/transformers/models/nat/modeling_nat.py