"tests/t5/test_tokenization_t5.py" did not exist on "b13c6c18d007c63eb2bcd5cee8f2c40d8ca0d229"
Unverified commit 1c1c9075 authored by Sanchit Gandhi, committed by GitHub

Add Musicgen (#24109)



* Add Audiocraft

* add cross attention

* style

* add for lm

* convert and verify

* introduce t5

* split configs

* load t5 + lm

* clean conversion

* copy from t5

* style

* start pattern provider

* make generation work

* style

* fix pos embs

* propagate shape changes

* propagate shape changes

* style

* delay pattern: pad tokens at end

* audiocraft -> musicgen

* fix inits

* add mdx

* style

* fix pad token in processor

* override generate and add todos

* add init to test

* undo pattern delay mask after gen

* remove cfg logits processor

* remove cfg logits processor

* remove logits processor in favour of mask

* clean pos embs

* make fix copies

* update readmes

* clean pos emb

* refactor encoder/decoder

* make fix copies

* update conversion

* fix config imports

* update config docs

* make style

* send pattern mask to device

* pattern mask with delay

* recover prompted audio tokens

* fix docstrings

* laydown test file

* pattern edge case

* remove t5 ref

* add processing class

* config refactor

* better pattern comment

* check if mask is not present

* check if mask is not present

* refactor to auto class

* remove encoder configs

* fix processor

* processor import

* start updating conversion

* start updating tests

* make style

* convert t5, encodec, lm

* convert as composite

* also convert processor

* run generate

* classifier free gen

* comments and clean up

* make style

* docs for logit proc

* docstring for uncond gen

* start lm tests

* work tests

* let the lm generate

* refactor: reshape inside forward

* undo greedy loop changes

* from_enc_dec -> from_sub_model

* fix input id shapes in docstrings

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* undo generate changes

* from sub model config

* Update src/transformers/models/musicgen/modeling_musicgen.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* make generate work again

* generate uncond -> get uncond inputs

* remove prefix allowed tokens fn

* better error message

* logit proc checks

* Apply suggestions from code review
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* make decoder only tests work

* composite fast tests

* make style

* uncond generation

* feat extr padding

* make audio prompt work

* fix inputs docstrings

* unconditional inputs: dict -> model output

* clean up tests

* more clean up tests

* make style

* t5 encoder -> auto text encoder

* remove comments

* deal with frames

* fix auto text

* slow tests

* nice mdx

* remove can generate

* todo - hub id

* convert m/l

* make fix copies

* only import generation with torch

* ignore decoder from tests

* don't wrap uncond inputs

* make style

* cleaner uncond inputs

* add example to musicgen forward

* fix docs

* ignore MusicGen Model/ForConditionalGeneration in auto mapping

* add doc section to toctree

* add to doc tests

* add processor tests

* fix push to hub in conversion

* tips for decoder only loading

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* fix conversion for s / m / l checkpoints

* import stopping criteria from module

* remove from pipeline tests

* fix uncond docstring

* decode audio method

* fix docs

* org: sanchit-gandhi -> facebook

* fix max pos embeddings

* remove auto doc (not compatible with shapes)

* bump max pos emb

* make style

* fix doc

* fix config doc

* fix config doc

* ignore musicgen config from docstring

* make style

* fix config

* fix config for doctest

* consistent from_sub_models

* don't automap decoder

* fix mdx save audio file

* fix mdx save audio file

* processor batch decode for audio

* remove keys to ignore

* update doc md

* update generation config

* allow changes for default generation config

* update tests

* make style

* fix docstring for uncond

* fix processor test

* fix processor test

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert MusicGen checkpoints from the original repository."""
import argparse
from pathlib import Path
from typing import Dict, OrderedDict, Tuple
import torch
from audiocraft.models import MusicGen
from transformers import (
AutoFeatureExtractor,
AutoTokenizer,
EncodecModel,
MusicgenDecoderConfig,
MusicgenForConditionalGeneration,
MusicgenProcessor,
T5EncoderModel,
)
from transformers.models.musicgen.modeling_musicgen import MusicgenForCausalLM
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
EXPECTED_MISSING_KEYS = ["model.decoder.embed_positions.weights"]
def rename_keys(name):
if "emb" in name:
name = name.replace("emb", "model.decoder.embed_tokens")
if "transformer" in name:
name = name.replace("transformer", "model.decoder")
if "cross_attention" in name:
name = name.replace("cross_attention", "encoder_attn")
if "linear1" in name:
name = name.replace("linear1", "fc1")
if "linear2" in name:
name = name.replace("linear2", "fc2")
if "norm1" in name:
name = name.replace("norm1", "self_attn_layer_norm")
if "norm_cross" in name:
name = name.replace("norm_cross", "encoder_attn_layer_norm")
if "norm2" in name:
name = name.replace("norm2", "final_layer_norm")
if "out_norm" in name:
name = name.replace("out_norm", "model.decoder.layer_norm")
if "linears" in name:
name = name.replace("linears", "lm_heads")
if "condition_provider.conditioners.description.output_proj" in name:
name = name.replace("condition_provider.conditioners.description.output_proj", "enc_to_dec_proj")
return name
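# Illustrative example (not part of the original file): after the replacements above, an
# Audiocraft key such as "transformer.layers.0.cross_attention.in_proj_weight" becomes
# "model.decoder.layers.0.encoder_attn.in_proj_weight"; the fused q/k/v projection is then
# split up in `rename_state_dict` below.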
def rename_state_dict(state_dict: OrderedDict, hidden_size: int) -> Tuple[Dict, Dict]:
"""Function that takes the fairseq Musicgen state dict and renames it according to the HF
module names. It further partitions the state dict into the decoder (LM) state dict, and that for the
encoder-decoder projection."""
keys = list(state_dict.keys())
enc_dec_proj_state_dict = {}
for key in keys:
val = state_dict.pop(key)
key = rename_keys(key)
if "in_proj_weight" in key:
# split fused qkv proj
state_dict[key.replace("in_proj_weight", "q_proj.weight")] = val[:hidden_size, :]
state_dict[key.replace("in_proj_weight", "k_proj.weight")] = val[hidden_size : 2 * hidden_size, :]
state_dict[key.replace("in_proj_weight", "v_proj.weight")] = val[-hidden_size:, :]
elif "enc_to_dec_proj" in key:
enc_dec_proj_state_dict[key[len("enc_to_dec_proj.") :]] = val
else:
state_dict[key] = val
return state_dict, enc_dec_proj_state_dict
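# Shape sketch (assuming the `small` config defined below, hidden_size=1024): a fused
# `in_proj_weight` of shape (3 * 1024, 1024) is split into q_proj/k_proj/v_proj weights of
# shape (1024, 1024) each, while any `enc_to_dec_proj.*` entries are collected into the
# separate projection state dict returned as the second element.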
def decoder_config_from_checkpoint(checkpoint: str) -> MusicgenDecoderConfig:
if checkpoint == "small":
# default config values
hidden_size = 1024
num_hidden_layers = 24
num_attention_heads = 16
elif checkpoint == "medium":
hidden_size = 1536
num_hidden_layers = 48
num_attention_heads = 24
elif checkpoint == "large":
hidden_size = 2048
num_hidden_layers = 48
num_attention_heads = 32
else:
raise ValueError(f"Checkpoint should be one of `['small', 'medium', 'large']`, got {checkpoint}.")
config = MusicgenDecoderConfig(
hidden_size=hidden_size,
ffn_dim=hidden_size * 4,
num_hidden_layers=num_hidden_layers,
num_attention_heads=num_attention_heads,
)
return config
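# For reference (derived from the values above, not stated in the original file): ffn_dim is
# 4 * hidden_size, i.e. 4096 / 6144 / 8192 for the small / medium / large decoder configs.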
@torch.no_grad()
def convert_musicgen_checkpoint(checkpoint, pytorch_dump_folder=None, repo_id=None, device="cpu"):
fairseq_model = MusicGen.get_pretrained(checkpoint, device=device)
decoder_config = decoder_config_from_checkpoint(checkpoint)
decoder_state_dict = fairseq_model.lm.state_dict()
decoder_state_dict, enc_dec_proj_state_dict = rename_state_dict(
decoder_state_dict, hidden_size=decoder_config.hidden_size
)
text_encoder = T5EncoderModel.from_pretrained("t5-base")
audio_encoder = EncodecModel.from_pretrained("facebook/encodec_32khz")
decoder = MusicgenForCausalLM(decoder_config).eval()
# load all decoder weights - expect that we'll be missing embeddings and enc-dec projection
missing_keys, unexpected_keys = decoder.load_state_dict(decoder_state_dict, strict=False)
for key in missing_keys.copy():
if key.startswith(("text_encoder", "audio_encoder")) or key in EXPECTED_MISSING_KEYS:
missing_keys.remove(key)
if len(missing_keys) > 0:
raise ValueError(f"Missing key(s) in state_dict: {missing_keys}")
if len(unexpected_keys) > 0:
raise ValueError(f"Unexpected key(s) in state_dict: {unexpected_keys}")
# init the composite model
model = MusicgenForConditionalGeneration(text_encoder=text_encoder, audio_encoder=audio_encoder, decoder=decoder)
# load the pre-trained enc-dec projection (from the decoder state dict)
model.enc_to_dec_proj.load_state_dict(enc_dec_proj_state_dict)
# check we can do a forward pass
input_ids = torch.arange(0, 8, dtype=torch.long).reshape(2, -1)
decoder_input_ids = input_ids.reshape(2 * 4, -1)
with torch.no_grad():
logits = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits
if logits.shape != (8, 1, 2048):
raise ValueError("Incorrect shape for logits")
# now construct the processor
tokenizer = AutoTokenizer.from_pretrained("t5-base")
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/encodec_32khz", padding_side="left")
processor = MusicgenProcessor(feature_extractor=feature_extractor, tokenizer=tokenizer)
# set the appropriate bos/pad token ids
model.generation_config.decoder_start_token_id = 2048
model.generation_config.pad_token_id = 2048
# set other default generation config params
model.generation_config.max_length = int(30 * audio_encoder.config.frame_rate)
model.generation_config.do_sample = True
model.generation_config.guidance_scale = 3.0
if pytorch_dump_folder is not None:
Path(pytorch_dump_folder).mkdir(exist_ok=True)
logger.info(f"Saving model {checkpoint} to {pytorch_dump_folder}")
model.save_pretrained(pytorch_dump_folder)
processor.save_pretrained(pytorch_dump_folder)
if repo_id:
logger.info(f"Pushing model {checkpoint} to {repo_id}")
model.push_to_hub(repo_id)
processor.push_to_hub(repo_id)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--checkpoint",
default="small",
type=str,
help="Checkpoint size of the MusicGen model you'd like to convert. Can be one of: `['small', 'medium', 'large']`.",
)
parser.add_argument(
"--pytorch_dump_folder",
required=True,
default=None,
type=str,
help="Path to the output PyTorch model directory.",
)
parser.add_argument(
"--push_to_hub", default=None, type=str, help="Where to upload the converted model on the 🤗 hub."
)
parser.add_argument(
"--device", default="cpu", type=str, help="Torch device to run the conversion, either cpu or cuda."
)
args = parser.parse_args()
convert_musicgen_checkpoint(args.checkpoint, args.pytorch_dump_folder, args.push_to_hub, args.device)
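# Example invocation of this conversion script (script path and output folder are placeholders):
#   python path/to/this_script.py --checkpoint small --pytorch_dump_folder ./musicgen-small
# The resulting folder can then be reloaded with, e.g.:
#   model = MusicgenForConditionalGeneration.from_pretrained("./musicgen-small")
#   processor = MusicgenProcessor.from_pretrained("./musicgen-small")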
# coding=utf-8
# Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Musicgen model."""
import copy
import inspect
import math
import random
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.utils.checkpoint import checkpoint
from ...activations import ACT2FN
from ...generation.configuration_utils import GenerationConfig
from ...generation.logits_process import LogitsProcessorList
from ...generation.stopping_criteria import StoppingCriteriaList
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
ModelOutput,
Seq2SeqLMOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from ..auto.configuration_auto import AutoConfig
from ..auto.modeling_auto import AutoModel
from .configuration_musicgen import MusicgenConfig, MusicgenDecoderConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "MusicgenConfig"
_CHECKPOINT_FOR_DOC = "facebook/musicgen-small"
MUSICGEN_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/musicgen-small",
# See all Musicgen models at https://huggingface.co/models?filter=musicgen
]
@dataclass
class MusicgenUnconditionalInput(ModelOutput):
"""
Args:
encoder_outputs (`Tuple[torch.FloatTensor]` of length 1, with tensor shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the text encoder model.
attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Encoder attention mask to avoid performing attention on padding token indices. Mask values selected in `[0,
1]`: 1 for tokens that are **not masked**, 0 for tokens that are **masked**.
guidance_scale (`float`, *optional*):
Guidance scale for classifier free guidance, setting the balance between the conditional logits (predicted
from the prompts) and the unconditional logits (predicted without prompts).
"""
encoder_outputs: Tuple[torch.FloatTensor] = None
attention_mask: torch.LongTensor = None
guidance_scale: float = None
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
# Copied from transformers.models.encoder_decoder.modeling_encoder_decoder.shift_tokens_right
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
"""
Shift input ids one token to the right.
"""
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
if decoder_start_token_id is None:
raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
shifted_input_ids[:, 0] = decoder_start_token_id
if pad_token_id is None:
raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
return shifted_input_ids
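# Worked example (symbolic token ids, not from the original file): with pad_token_id=0 and
# decoder_start_token_id=2, shift_tokens_right(torch.tensor([[5, 6, 7]]), 0, 2) returns
# tensor([[2, 5, 6]]); any -100 entries copied over from the labels would be replaced by 0.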
class MusicgenSinusoidalPositionalEmbedding(nn.Module):
"""This module produces sinusoidal positional embeddings of any length."""
def __init__(self, num_positions: int, embedding_dim: int):
super().__init__()
self.embedding_dim = embedding_dim
self.make_weights(num_positions, embedding_dim)
def make_weights(self, num_embeddings: int, embedding_dim: int):
emb_weights = self.get_embedding(num_embeddings, embedding_dim)
if hasattr(self, "weights"):
# in forward put the weights on the correct dtype and device of the param
emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
self.weights = nn.Parameter(emb_weights)
self.weights.requires_grad = False
self.weights.detach_()
@staticmethod
def get_embedding(num_embeddings: int, embedding_dim: int):
"""
Build sinusoidal embeddings. This matches the implementation in tensor2tensor, but differs slightly from the
description in Section 3.5 of "Attention Is All You Need".
"""
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
emb = torch.cat([torch.cos(emb), torch.sin(emb)], dim=1).view(num_embeddings, -1)
if embedding_dim % 2 == 1:
# zero pad
emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
return emb.to(torch.get_default_dtype())
@torch.no_grad()
def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
bsz, codebooks, seq_len = input_ids.size()
# Create the position ids from the input token ids.
position_ids = (torch.arange(seq_len) + past_key_values_length).to(input_ids.device)
# expand embeddings if needed
if seq_len > self.weights.size(0):
self.make_weights(seq_len + self.offset, self.embedding_dim)
return self.weights.index_select(0, position_ids.view(-1)).detach()
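# Shape note for the module above (a reading aid, not original code): `input_ids` is expected as
# (bsz, num_codebooks, seq_len) and the returned positions have shape (seq_len, embedding_dim),
# which broadcasts over the batch when added to the decoder input embeddings.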
# Copied from transformers.models.bart.modeling_bart.BartAttention
class MusicgenAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {num_heads})."
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, _ = hidden_states.size()
# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
key_states = key_states.reshape(*proj_shape)
value_states = value_states.reshape(*proj_shape)
src_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
f" {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if output_attentions:
# this operation is a bit awkward, but it's required to
# make sure that attn_weights keeps its gradient.
# In order to do so, attn_weights have to be reshaped
# twice and have to be reused in the following
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
else:
attn_weights_reshaped = None
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.bmm(attn_probs, value_states)
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned across GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights_reshaped, past_key_value
class MusicgenDecoderLayer(nn.Module):
def __init__(self, config: MusicgenDecoderConfig):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = MusicgenAttention(
embed_dim=self.embed_dim,
num_heads=config.num_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
bias=False,
)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.encoder_attn = MusicgenAttention(
self.embed_dim,
config.num_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
bias=False,
)
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim, bias=False)
self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim, bias=False)
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer.forward
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = True,
) -> torch.Tensor:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
encoder_hidden_states (`torch.FloatTensor`):
cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
`(encoder_attention_heads,)`.
cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
size `(decoder_attention_heads,)`.
past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual = hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
# Self Attention
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
# add present self-attn cache to positions 1,2 of present_key_value tuple
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
past_key_value=self_attn_past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
# Cross-Attention Block
cross_attn_present_key_value = None
cross_attn_weights = None
if encoder_hidden_states is not None:
residual = hidden_states
hidden_states = self.encoder_attn_layer_norm(hidden_states)
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
# add cross-attn to positions 3,4 of present_key_value tuple
present_key_value = present_key_value + cross_attn_present_key_value
# Fully Connected
residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights, cross_attn_weights)
if use_cache:
outputs += (present_key_value,)
return outputs
class MusicgenPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = MusicgenDecoderConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["MusicgenDecoderLayer", "MusicgenAttention"]
def _init_weights(self, module):
std = self.config.initializer_factor
if isinstance(module, (nn.Linear, nn.Conv1d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, MusicgenDecoder):
module.gradient_checkpointing = value
MUSICGEN_START_DOCSTRING = r"""
The Musicgen model was proposed in [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by
Jade Copet, Felix Kreuk, Itai Gat, Tal Remez, David Kant, Gabriel Synnaeve, Yossi Adi, Alexandre Défossez. It is an
encoder-decoder transformer trained on the task of conditional music generation.
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`MusicgenConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
MUSICGEN_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
decoder_input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length)`, *optional*):
Indices of decoder input sequence tokens in the vocabulary, corresponding to the sequence of audio codes.
Indices can be obtained by encoding an audio prompt with an audio encoder model to predict audio codes,
such as with the [`EncodecModel`]. See [`EncodecModel.encode`] for details.
[What are decoder input IDs?](../glossary#decoder-input-ids)
<Tip warning={true}>
The `decoder_input_ids` will automatically be converted from shape `(batch_size * num_codebooks,
target_sequence_length)` to `(batch_size, num_codebooks, target_sequence_length)` in the forward pass. If
you obtain audio codes from an audio encoding model, such as [`EncodecModel`], ensure that the number of
frames is equal to 1, and that you reshape the audio codes from `(frames, batch_size, num_codebooks,
target_sequence_length)` to `(batch_size * num_codebooks, target_sequence_length)` prior to passing them as
`decoder_input_ids`.
</Tip>
decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
be used by default.
head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is
useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
input (see `past_key_values`). This is useful if you want more control over how to convert
`decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
of `inputs_embeds`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
MUSICGEN_DECODER_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, sequence_length)`):
Indices of input sequence tokens in the vocabulary, corresponding to the sequence of audio codes.
Indices can be obtained by encoding an audio prompt with an audio encoder model to predict audio codes,
such as with the [`EncodecModel`]. See [`EncodecModel.encode`] for details.
[What are input IDs?](../glossary#input-ids)
<Tip warning={true}>
The `input_ids` will automatically be converted from shape `(batch_size * num_codebooks,
target_sequence_length)` to `(batch_size, num_codebooks, target_sequence_length)` in the forward pass. If
you obtain audio codes from an audio encoding model, such as [`EncodecModel`], ensure that the number of
frames is equal to 1, and that you reshape the audio codes from `(frames, batch_size, num_codebooks,
target_sequence_length)` to `(batch_size * num_codebooks, target_sequence_length)` prior to passing them as
`input_ids`.
</Tip>
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of
the decoder.
encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
cross-attention on hidden heads. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is
useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
class MusicgenDecoder(MusicgenPreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MusicgenDecoderLayer`]
"""
def __init__(self, config: MusicgenDecoderConfig):
super().__init__(config)
self.dropout = config.dropout
self.layerdrop = config.layerdrop
self.max_target_positions = config.max_position_embeddings
self.d_model = config.hidden_size
self.num_codebooks = config.num_codebooks
self.embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0
embed_dim = config.vocab_size + 1
self.embed_tokens = nn.ModuleList(
[nn.Embedding(embed_dim, config.hidden_size) for _ in range(config.num_codebooks)]
)
self.embed_positions = MusicgenSinusoidalPositionalEmbedding(
config.max_position_embeddings,
config.hidden_size,
)
self.layers = nn.ModuleList([MusicgenDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self.layer_norm = nn.LayerNorm(config.hidden_size)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
inputs_embeds.device
)
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
@add_start_docstrings_to_model_forward(MUSICGEN_DECODER_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
# (bsz * codebooks, seq_len) -> (bsz, codebooks, seq_len)
input = input_ids.reshape(-1, self.num_codebooks, input_ids.shape[-1])
bsz, num_codebooks, seq_len = input.shape
input_shape = (bsz, seq_len)
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
input = inputs_embeds[:, :, -1:]
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if inputs_embeds is None:
inputs_embeds = torch.zeros((bsz, seq_len, self.d_model), device=input_ids.device)
for codebook in range(num_codebooks):
inputs_embeds += self.embed_tokens[codebook](input[:, codebook])
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
# embed positions
positions = self.embed_positions(input, past_key_values_length)
hidden_states = inputs_embeds + positions.to(inputs_embeds.device)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
next_decoder_cache = () if use_cache else None
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
if attn_mask is not None:
if attn_mask.size()[0] != len(self.layers):
raise ValueError(
f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
f" {attn_mask.size()[0]}."
)
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states:
all_hidden_states += (hidden_states,)
dropout_probability = random.uniform(0, 1)
if self.training and (dropout_probability < self.layerdrop):
continue
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, use_cache)
return custom_forward
layer_outputs = checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
head_mask[idx] if head_mask is not None else None,
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
cross_attn_layer_head_mask=(
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
),
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
if encoder_hidden_states is not None:
all_cross_attentions += (layer_outputs[2],)
hidden_states = self.layer_norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
cross_attentions=all_cross_attentions,
)
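# Reading aid for the embedding step in the forward pass above (not original code): ids of shape
# (bsz * num_codebooks, seq_len) are reshaped to (bsz, num_codebooks, seq_len), each codebook is
# looked up in its own embedding table, and the per-codebook embeddings are summed into the
# (bsz, seq_len, hidden_size) decoder input before the sinusoidal positions are added.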
@add_start_docstrings(
"The bare Musicgen decoder model outputting raw hidden-states without any specific head on top.",
MUSICGEN_START_DOCSTRING,
)
class MusicgenModel(MusicgenPreTrainedModel):
def __init__(self, config: MusicgenDecoderConfig):
super().__init__(config)
self.decoder = MusicgenDecoder(config)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.decoder.embed_tokens
def set_input_embeddings(self, value):
self.decoder.embed_tokens = value
def get_decoder(self):
return self.decoder
@add_start_docstrings_to_model_forward(MUSICGEN_DECODER_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
decoder_outputs = self.decoder(
input_ids=input_ids,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
encoder_hidden_states=encoder_hidden_states,
head_mask=head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if not return_dict:
return decoder_outputs
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=decoder_outputs.last_hidden_state,
past_key_values=decoder_outputs.past_key_values,
hidden_states=decoder_outputs.hidden_states,
attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
)
@add_start_docstrings(
"The MusicGen decoder model with a language modelling head on top.",
MUSICGEN_START_DOCSTRING,
)
class MusicgenForCausalLM(MusicgenPreTrainedModel):
def __init__(self, config: MusicgenDecoderConfig):
super().__init__(config)
self.model = MusicgenModel(config)
self.num_codebooks = config.num_codebooks
self.lm_heads = nn.ModuleList(
[nn.Linear(config.hidden_size, config.vocab_size, bias=False) for _ in range(config.num_codebooks)]
)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.decoder.embed_tokens
def set_input_embeddings(self, value):
self.model.decoder.embed_tokens = value
def get_output_embeddings(self):
return self.lm_heads
def set_output_embeddings(self, new_embeddings):
self.lm_heads = new_embeddings
def set_decoder(self, decoder):
self.model.decoder = decoder
def get_decoder(self):
return self.model.decoder
@add_start_docstrings_to_model_forward(MUSICGEN_DECODER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
Returns:
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
input_ids,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
lm_logits = torch.stack([head(hidden_states) for head in self.lm_heads], dim=1)
loss = None
if labels is not None:
raise NotImplementedError("Training is not implemented for Musicgen.")
# (bsz, num_codebooks, seq_len, vocab_size) -> (bsz * num_codebooks, seq_len, vocab_size)
lm_logits = lm_logits.reshape(-1, *lm_logits.shape[2:])
if not return_dict:
output = (lm_logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithCrossAttentions(
loss=loss,
logits=lm_logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)
def prepare_inputs_for_generation(
self,
input_ids,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
head_mask=None,
cross_attn_head_mask=None,
past_key_values=None,
use_cache=True,
delay_pattern_mask=None,
guidance_scale=None,
**kwargs,
):
if delay_pattern_mask is None:
input_ids, delay_pattern_mask = self.build_delay_pattern_mask(
input_ids,
pad_token_id=self.generation_config.pad_token_id,
max_length=self.generation_config.max_length,
)
# apply the delay pattern mask
input_ids = self.apply_delay_pattern_mask(input_ids, delay_pattern_mask)
if guidance_scale is not None and guidance_scale > 1:
# for classifier free guidance we need to replicate the decoder args across the batch dim (we'll split these
# before sampling)
input_ids = input_ids.repeat((2, 1))
if attention_mask is not None:
attention_mask = attention_mask.repeat((2, 1))
if past_key_values is not None:
input_ids = input_ids[:, -1:]
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"encoder_hidden_states": encoder_hidden_states,
"encoder_attention_mask": encoder_attention_mask,
"head_mask": head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"past_key_values": past_key_values,
"use_cache": use_cache,
}
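# Note on the classifier-free guidance branch above (shapes are a sketch, not original code): with
# guidance_scale > 1, the (bsz * num_codebooks, seq_len) decoder ids are repeated to
# (2 * bsz * num_codebooks, seq_len) so that conditional and unconditional logits can be computed
# in a single forward pass and split again before sampling.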
def build_delay_pattern_mask(self, input_ids: torch.LongTensor, pad_token_id: int, max_length: int = None):
"""Build a delayed pattern mask to the input_ids. Each codebook is offset by the previous codebook by
one, giving a delayed pattern mask at the start of sequence and end of sequence. Take the example where there
are 4 codebooks and a max sequence length of 8, we have the delayed pattern mask of shape `(codebooks,
seq_len)`:
- [P, -1, -1, -1, -1, P, P, P]
- [P, P, -1, -1, -1, -1, P, P]
- [P, P, P, -1, -1, -1, -1, P]
- [P, P, P, P, -1, -1, -1, -1]
where P is the special padding token id and -1 indicates a position that is valid for prediction. If we include
a prompt (decoder input ids), positions covered by the prompt take the corresponding prompt value, and the
remaining -1 positions indicate where new tokens will be predicted:
- [P, a, b, -1, -1, P, P, P]
- [P, P, c, d, -1, -1, P, P]
- [P, P, P, e, f, -1, -1, P]
- [P, P, P, P, g, h, -1, -1]
where a-h indicate the input prompt (decoder input ids), offset by one position per codebook. During generation,
only the -1 positions are overwritten with predicted tokens.
"""
# (bsz * num_codebooks, seq_len) -> (bsz, num_codebooks, seq_len)
input_ids = input_ids.reshape(-1, self.num_codebooks, input_ids.shape[-1])
bsz, num_codebooks, seq_len = input_ids.shape
max_length = max_length if max_length is not None else self.generation_config.max_length
input_ids_shifted = (
torch.ones((bsz, num_codebooks, max_length), dtype=torch.long, device=input_ids.device) * -1
)
# we only apply the mask if we have a large enough seq len - otherwise we return as is
if max_length < 2 * num_codebooks - 1:
return input_ids.reshape(bsz * num_codebooks, -1), input_ids_shifted.reshape(bsz * num_codebooks, -1)
# fill the shifted ids with the prompt entries, offset by the codebook idx
for codebook in range(num_codebooks):
input_ids_shifted[:, codebook, codebook : seq_len + codebook] = input_ids[:, codebook]
# construct a pattern mask that indicates the positions of padding tokens for each codebook
# first fill the upper triangular part (the EOS padding)
delay_pattern = torch.triu(
torch.ones((num_codebooks, max_length), dtype=torch.bool), diagonal=max_length - num_codebooks + 1
)
# then fill the lower triangular part (the BOS padding)
delay_pattern = delay_pattern + torch.tril(torch.ones((num_codebooks, max_length), dtype=torch.bool))
mask = ~delay_pattern.to(input_ids.device)
input_ids = mask * input_ids_shifted + ~mask * pad_token_id
# find the first position to start generating - this is the first place we have the -1 token
# and will always be in the first codebook (since it has no codebook offset)
first_codebook_ids = input_ids[:, 0, :]
start_ids = (first_codebook_ids == -1).nonzero()[:, 1]
if len(start_ids) > 0:
first_start_id = min(start_ids)
else:
# we have no tokens that need to be filled - return entire matrix of input ids
first_start_id = seq_len
# (bsz * num_codebooks, seq_len) -> (bsz, num_codebooks, seq_len)
pattern_mask = input_ids.reshape(bsz * num_codebooks, -1)
input_ids = input_ids[..., :first_start_id].reshape(bsz * num_codebooks, -1)
return input_ids, pattern_mask
@staticmethod
def apply_delay_pattern_mask(input_ids, decoder_pad_token_mask):
"""Apply a delay pattern mask to the decoder input ids, only preserving predictions where
the mask is set to -1, and otherwise setting to the value detailed in the mask."""
seq_len = input_ids.shape[-1]
decoder_pad_token_mask = decoder_pad_token_mask[..., :seq_len]
input_ids = torch.where(decoder_pad_token_mask == -1, input_ids, decoder_pad_token_mask)
return input_ids
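To make the pattern above concrete, here is a small standalone re-creation of the delay pattern layout for a toy prompt. It is a simplified sketch of the idea, not the exact model code, and the token values are hypothetical:

```python
# Toy delay pattern: 4 codebooks, max_length=8, a 3-token prompt that starts
# with the pad/BOS token followed by two prompt tokens (11, 12).
import torch

num_codebooks, max_length, pad = 4, 8, 2048
prompt = torch.tensor([pad, 11, 12])
ids = prompt.repeat(num_codebooks, 1)  # (num_codebooks, seq_len)

shifted = torch.full((num_codebooks, max_length), -1, dtype=torch.long)
for codebook in range(num_codebooks):
    shifted[codebook, codebook : codebook + ids.shape[-1]] = ids[codebook]

# BOS padding on/below the diagonal, EOS padding in the top-right corner
eos_pad = torch.triu(torch.ones(num_codebooks, max_length, dtype=torch.bool), diagonal=max_length - num_codebooks + 1)
bos_pad = torch.tril(torch.ones(num_codebooks, max_length, dtype=torch.bool))
mask = ~(eos_pad | bos_pad)
pattern = mask * shifted + ~mask * pad
print(pattern)  # each row follows the [P, 11, 12, -1, ..., P] layout from the docstring above
```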
@torch.no_grad()
def generate(
self,
inputs: Optional[torch.Tensor] = None,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
synced_gpus: Optional[bool] = None,
**kwargs,
):
"""
Generates sequences of token ids for models with a language modeling head.
<Tip warning={true}>
Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
model's default generation configuration. You can override any `generation_config` by passing the corresponding
parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
For an overview of generation strategies and code examples, check out the [following
guide](./generation_strategies).
</Tip>
Parameters:
inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
should be in the format `input_ids`. For encoder-decoder models *inputs* can represent any of
`input_ids`, `input_values`, `input_features`, or `pixel_values`.
generation_config (`~generation.GenerationConfig`, *optional*):
The generation configuration to be used as base parametrization for the generation call. `**kwargs`
passed to generate matching the attributes of `generation_config` will override them. If
`generation_config` is not provided, the default will be used, which has the following loading
priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
default values, whose documentation should be checked to parameterize generation.
logits_processor (`LogitsProcessorList`, *optional*):
Custom logits processors that complement the default logits processors built from arguments and
generation config. If a logit processor is passed that is already created with the arguments or a
generation config an error is thrown. This feature is intended for advanced users.
stopping_criteria (`StoppingCriteriaList`, *optional*):
Custom stopping criteria that complement the default stopping criteria built from arguments and a
generation config. If a stopping criteria is passed that is already created with the arguments or a
generation config an error is thrown. This feature is intended for advanced users.
synced_gpus (`bool`, *optional*, defaults to `False`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
kwargs:
Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
Return:
[`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`.
If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
[`~utils.ModelOutput`] types are:
- [`~generation.GreedySearchDecoderOnlyOutput`],
- [`~generation.SampleDecoderOnlyOutput`],
- [`~generation.BeamSearchDecoderOnlyOutput`],
- [`~generation.BeamSampleDecoderOnlyOutput`]
If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
[`~utils.ModelOutput`] types are:
- [`~generation.GreedySearchEncoderDecoderOutput`],
- [`~generation.SampleEncoderDecoderOutput`],
- [`~generation.BeamSearchEncoderDecoderOutput`],
- [`~generation.BeamSampleEncoderDecoderOutput`]
"""
# 1. Handle `generation_config` and kwargs that might update it, and validate the resulting objects
if generation_config is None:
generation_config = self.generation_config
generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
generation_config.validate()
self._validate_model_kwargs(model_kwargs.copy())
# 2. Set generation parameters if not already defined
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
if model_kwargs.get("attention_mask", None) is None:
logger.warning(
"The attention mask and the pad token id were not set. As a consequence, you may observe "
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
)
eos_token_id = generation_config.eos_token_id
if isinstance(eos_token_id, list):
eos_token_id = eos_token_id[0]
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
generation_config.pad_token_id = eos_token_id
# 3. Define model inputs
# inputs_tensor has to be defined
# model_input_name is defined if model-specific keyword input is passed
# otherwise model_input_name is None
# all model-specific keyword inputs are removed from `model_kwargs`
input_ids, model_input_name, model_kwargs = self._prepare_model_inputs(
inputs, generation_config.bos_token_id, model_kwargs
)
batch_size = input_ids.shape[0] // self.num_codebooks
# 4. Define other model kwargs
model_kwargs["output_attentions"] = generation_config.output_attentions
model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
model_kwargs["use_cache"] = generation_config.use_cache
model_kwargs["guidance_scale"] = generation_config.guidance_scale
requires_attention_mask = "encoder_outputs" not in model_kwargs
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
input_ids, generation_config.pad_token_id, generation_config.eos_token_id
)
# 5. Prepare `max_length` depending on other stopping criteria.
input_ids_seq_length = input_ids.shape[-1]
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
if has_default_max_length and generation_config.max_new_tokens is None:
logger.warning(
f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
"This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
" recommend using `max_new_tokens` to control the maximum length of the generation.",
)
elif generation_config.max_new_tokens is not None:
if not has_default_max_length:
logger.warning(
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
"Please refer to the documentation for more information. "
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
)
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
raise ValueError(
f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than"
f" the maximum length ({generation_config.max_length})"
)
if input_ids_seq_length >= generation_config.max_length:
logger.warning(
f"Input length of decoder_input_ids is {input_ids_seq_length}, but `max_length` is set to"
f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
" increasing `max_new_tokens`."
)
# 6. Prepare `input_ids` which will be used for auto-regressive generation
# Build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to MusicGen)
input_ids, delay_pattern_mask = self.build_delay_pattern_mask(
input_ids,
pad_token_id=generation_config.decoder_start_token_id,
max_length=generation_config.max_length,
)
# stash the delay mask so that we don't have to recompute it in each forward pass
model_kwargs["delay_pattern_mask"] = delay_pattern_mask
# 7. determine generation mode
is_greedy_gen_mode = (
(generation_config.num_beams == 1)
and (generation_config.num_beam_groups == 1)
and generation_config.do_sample is False
)
is_sample_gen_mode = (
(generation_config.num_beams == 1)
and (generation_config.num_beam_groups == 1)
and generation_config.do_sample is True
)
# 8. prepare distribution pre_processing samplers
logits_processor = self._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_seq_length,
encoder_input_ids=input_ids,
prefix_allowed_tokens_fn=None,
logits_processor=logits_processor,
)
# 9. prepare stopping criteria
stopping_criteria = self._get_stopping_criteria(
generation_config=generation_config, stopping_criteria=stopping_criteria
)
if is_greedy_gen_mode:
if generation_config.num_return_sequences > 1:
raise ValueError(
"num_return_sequences has to be 1 when doing greedy search, "
f"but is {generation_config.num_return_sequences}."
)
# 10. run greedy search
outputs = self.greedy_search(
input_ids,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
pad_token_id=generation_config.pad_token_id,
eos_token_id=generation_config.eos_token_id,
output_scores=generation_config.output_scores,
return_dict_in_generate=generation_config.return_dict_in_generate,
synced_gpus=synced_gpus,
**model_kwargs,
)
elif is_sample_gen_mode:
# 11. prepare logits warper
logits_warper = self._get_logits_warper(generation_config)
# expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids=input_ids,
expand_size=generation_config.num_return_sequences,
**model_kwargs,
)
# 12. run sample
outputs = self.sample(
input_ids,
logits_processor=logits_processor,
logits_warper=logits_warper,
stopping_criteria=stopping_criteria,
pad_token_id=generation_config.pad_token_id,
eos_token_id=generation_config.eos_token_id,
output_scores=generation_config.output_scores,
return_dict_in_generate=generation_config.return_dict_in_generate,
synced_gpus=synced_gpus,
**model_kwargs,
)
else:
raise ValueError(
"Got incompatible mode for generation, should be one of greedy or sampling."
"Ensure that beam search is de-activated by setting `num_beams=1` and `num_beam_groups=1`."
)
if generation_config.return_dict_in_generate:
output_ids = outputs.sequences
else:
output_ids = outputs
# apply the pattern mask to the final ids
output_ids = self.apply_delay_pattern_mask(output_ids, model_kwargs["delay_pattern_mask"])
# revert the pattern delay mask by filtering the pad token id
output_ids = output_ids[output_ids != generation_config.pad_token_id].reshape(
batch_size, self.num_codebooks, -1
)
if generation_config.return_dict_in_generate:
outputs.sequences = output_ids
return outputs
else:
return output_ids
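As a rough usage sketch (assuming the decoder is taken from the composite `facebook/musicgen-small` checkpoint), a single decoder-only forward pass shows the flattened codebook layout of the logits:

```python
# Hedged sketch: a decoder-only forward pass on the standalone MusicgenForCausalLM.
import torch
from transformers import MusicgenForConditionalGeneration

model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
decoder = model.decoder  # MusicgenForCausalLM
pad_token_id = model.generation_config.pad_token_id

# one pad/BOS token per codebook for a batch of two samples
input_ids = torch.full((2 * decoder.num_codebooks, 1), pad_token_id, dtype=torch.long)
with torch.no_grad():
    logits = decoder(input_ids).logits
print(logits.shape)  # (bsz * num_codebooks, seq_len, vocab_size), e.g. torch.Size([8, 1, 2048])
```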
@add_start_docstrings(
"The composite MusicGen model with a text encoder, audio encoder and Musicgen decoder,"
"for music generation tasks with one or both of text and audio prompts.",
MUSICGEN_START_DOCSTRING,
)
class MusicgenForConditionalGeneration(PreTrainedModel):
config_class = MusicgenConfig
base_model_prefix = "encoder_decoder"
main_input_name = "input_ids"
supports_gradient_checkpointing = True
def __init__(
self,
config: Optional[MusicgenConfig] = None,
text_encoder: Optional[PreTrainedModel] = None,
audio_encoder: Optional[PreTrainedModel] = None,
decoder: Optional[MusicgenForCausalLM] = None,
):
if config is None and (text_encoder is None or audio_encoder is None or decoder is None):
raise ValueError(
"Either a configuration has to be provided, or all three of text encoder, audio encoder and MusicGen decoder."
)
if config is None:
config = MusicgenConfig.from_sub_models_config(text_encoder.config, audio_encoder.config, decoder.config)
else:
if not isinstance(config, self.config_class):
raise ValueError(f"Config: {config} has to be of type {self.config_class}")
if config.decoder.cross_attention_hidden_size is not None:
if config.decoder.cross_attention_hidden_size != config.text_encoder.hidden_size:
raise ValueError(
"If `cross_attention_hidden_size` is specified in the MusicGen decoder's configuration, it has to be equal"
f" to the text encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for"
f" `config.decoder.cross_attention_hidden_size` and {config.text_encoder.hidden_size} for"
" `config.text_encoder.hidden_size`."
)
# initialize with config
super().__init__(config)
if text_encoder is None:
from ..auto.modeling_auto import AutoModelForTextEncoding
text_encoder = AutoModelForTextEncoding.from_config(config.text_encoder)
if audio_encoder is None:
from ..auto.modeling_auto import AutoModel
audio_encoder = AutoModel.from_config(config.audio_encoder)
if decoder is None:
decoder = MusicgenForCausalLM(config.decoder)
self.text_encoder = text_encoder
self.audio_encoder = audio_encoder
self.decoder = decoder
if self.text_encoder.config.to_dict() != self.config.text_encoder.to_dict():
logger.warning(
f"Config of the text_encoder: {self.text_encoder.__class__} is overwritten by shared text_encoder config:"
f" {self.config.text_encoder}"
)
if self.audio_encoder.config.to_dict() != self.config.audio_encoder.to_dict():
logger.warning(
f"Config of the audio_encoder: {self.audio_encoder.__class__} is overwritten by shared audio_encoder config:"
f" {self.config.audio_encoder}"
)
if self.decoder.config.to_dict() != self.config.decoder.to_dict():
logger.warning(
f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:"
f" {self.config.decoder}"
)
# make sure that the individual model's config refers to the shared config
# so that the updates to the config will be synced
self.text_encoder.config = self.config.text_encoder
self.audio_encoder.config = self.config.audio_encoder
self.decoder.config = self.config.decoder
# text encoder outputs might need to be projected to different dimension for decoder
if (
self.text_encoder.config.hidden_size != self.decoder.config.hidden_size
and self.decoder.config.cross_attention_hidden_size is None
):
self.enc_to_dec_proj = nn.Linear(self.text_encoder.config.hidden_size, self.decoder.config.hidden_size)
if self.text_encoder.get_output_embeddings() is not None:
raise ValueError(
f"The encoder {self.text_encoder} should not have a LM Head. Please use a model without and LM Head"
)
decoder_signature = set(inspect.signature(self.decoder.forward).parameters.keys())
if "encoder_hidden_states" not in decoder_signature:
raise ValueError(
"The selected decoder is not prepared for the encoder hidden states to be passed. Please see the "
"following discussion on GitHub: https://github.com/huggingface/transformers/issues/23350"
)
# tie text encoder, decoder weights if config set accordingly
self.tie_weights()
def tie_weights(self):
# tie text encoder & decoder if needed
if self.config.tie_encoder_decoder:
# tie text encoder and decoder base model
decoder_base_model_prefix = self.decoder.base_model_prefix
self._tie_encoder_decoder_weights(
self.text_encoder, self.decoder._modules[decoder_base_model_prefix], self.decoder.base_model_prefix
)
def _set_gradient_checkpointing(self, module, value=False):
# call both encoder and decoder function on gradient checkpointing
self.text_encoder._set_gradient_checkpointing(module, value=value)
self.decoder._set_gradient_checkpointing(module, value=value)
def get_audio_encoder(self):
return self.audio_encoder
def get_text_encoder(self):
return self.text_encoder
def get_encoder(self):
# get the text encoder to compute the encoder hidden-states for generation
return self.get_text_encoder()
def get_decoder(self):
return self.decoder
def get_input_embeddings(self):
return self.text_encoder.get_input_embeddings()
def get_output_embeddings(self):
return self.decoder.get_output_embeddings()
def set_output_embeddings(self, new_embeddings):
return self.decoder.set_output_embeddings(new_embeddings)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
r"""
Example:
```python
>>> from transformers import MusicgenForConditionalGeneration
>>> model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
```"""
# At the moment fast initialization is not supported for composite models
if kwargs.get("_fast_init", False):
logger.warning(
"Fast initialization is currently not supported for MusicgenForConditionalGeneration. "
"Falling back to slow initialization..."
)
kwargs["_fast_init"] = False
return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
@classmethod
def from_sub_models_pretrained(
cls,
text_encoder_pretrained_model_name_or_path: str = None,
audio_encoder_pretrained_model_name_or_path: str = None,
decoder_pretrained_model_name_or_path: str = None,
*model_args,
**kwargs,
) -> PreTrainedModel:
r"""
Instantiate a text encoder, an audio encoder, and a MusicGen decoder from one, two or three base classes of the
library from pretrained model checkpoints.
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
the model, you need to first set it back in training mode with `model.train()`.
Params:
text_encoder_pretrained_model_name_or_path (`str`, *optional*):
Information necessary to initiate the text encoder. Can be either:
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
Valid model ids can be located at the root-level, like `t5-base`, or namespaced under a user or
organization name, like `google/flan-t5-base`.
- A path to a *directory* containing model weights saved using
[`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
audio_encoder_pretrained_model_name_or_path (`str`, *optional*):
Information necessary to initiate the audio encoder. Can be either:
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
user or organization name, like `facebook/encodec_24khz`.
- A path to a *directory* containing model weights saved using
[`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
decoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`):
Information necessary to initiate the decoder. Can be either:
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
Valid model ids can be located at the root-level, like `gpt2`, or namespaced under a user or
organization name, like `facebook/musicgen-small`.
- A path to a *directory* containing model weights saved using
[`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
model_args (remaining positional arguments, *optional*):
All remaining positional arguments will be passed to the underlying model's `__init__` method.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
`output_attentions=True`).
- To update the text encoder configuration, use the prefix *text_encoder_* for each configuration
parameter.
- To update the audio encoder configuration, use the prefix *audio_encoder_* for each configuration
parameter.
- To update the decoder configuration, use the prefix *decoder_* for each configuration parameter.
- To update the parent model configuration, do not use a prefix for each configuration parameter.
Behaves differently depending on whether a `config` is provided or automatically loaded.
Example:
```python
>>> from transformers import MusicgenForConditionalGeneration
>>> # initialize a musicgen model from a t5 text encoder, encodec audio encoder, and musicgen decoder
>>> model = MusicgenForConditionalGeneration.from_sub_models_pretrained(
... text_encoder_pretrained_model_name_or_path="t5-base",
... audio_encoder_pretrained_model_name_or_path="facebook/encodec_24khz",
... decoder_pretrained_model_name_or_path="facebook/musicgen-small",
... )
>>> # saving model after fine-tuning
>>> model.save_pretrained("./musicgen-ft")
>>> # load fine-tuned model
>>> model = MusicgenForConditionalGeneration.from_pretrained("./musicgen-ft")
```"""
kwargs_text_encoder = {
argument[len("text_encoder_") :]: value
for argument, value in kwargs.items()
if argument.startswith("text_encoder_")
}
kwargs_audio_encoder = {
argument[len("audio_encoder_") :]: value
for argument, value in kwargs.items()
if argument.startswith("audio_encoder_")
}
kwargs_decoder = {
argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
}
# remove text encoder, audio encoder and decoder kwargs from kwargs
for key in kwargs_text_encoder.keys():
del kwargs["text_encoder_" + key]
for key in kwargs_audio_encoder.keys():
del kwargs["audio_encoder_" + key]
for key in kwargs_decoder.keys():
del kwargs["decoder_" + key]
# Load and initialize the encoder and decoder
# The distinction between encoder and decoder at the model level is made
# by the value of the flag `is_decoder` that we need to set correctly.
text_encoder = kwargs_text_encoder.pop("model", None)
if text_encoder is None:
if text_encoder_pretrained_model_name_or_path is None:
raise ValueError(
"If `text_encoder_model` is not defined as an argument, a `text_encoder_pretrained_model_name_or_path` has "
"to be defined."
)
if "config" not in kwargs_text_encoder:
encoder_config, kwargs_text_encoder = AutoConfig.from_pretrained(
text_encoder_pretrained_model_name_or_path, **kwargs_text_encoder, return_unused_kwargs=True
)
if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
logger.info(
f"Initializing {text_encoder_pretrained_model_name_or_path} as a text_encoder model "
"from a decoder model. Cross-attention and casual mask are disabled."
)
encoder_config.is_decoder = False
encoder_config.add_cross_attention = False
kwargs_text_encoder["config"] = encoder_config
text_encoder = AutoModel.from_pretrained(
text_encoder_pretrained_model_name_or_path, *model_args, **kwargs_text_encoder
)
audio_encoder = kwargs_audio_encoder.pop("model", None)
if audio_encoder is None:
if audio_encoder_pretrained_model_name_or_path is None:
raise ValueError(
"If `audio_encoder_model` is not defined as an argument, an `audio_encoder_pretrained_model_name_or_path` has "
"to be defined."
)
if "config" not in kwargs_audio_encoder:
encoder_config, kwargs_audio_encoder = AutoConfig.from_pretrained(
audio_encoder_pretrained_model_name_or_path, **kwargs_audio_encoder, return_unused_kwargs=True
)
if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
logger.info(
f"Initializing {audio_encoder_pretrained_model_name_or_path} as an audio_encoder model "
"from a decoder model. Cross-attention and casual mask are disabled."
)
encoder_config.is_decoder = False
encoder_config.add_cross_attention = False
kwargs_audio_encoder["config"] = encoder_config
audio_encoder = AutoModel.from_pretrained(
audio_encoder_pretrained_model_name_or_path, *model_args, **kwargs_audio_encoder
)
decoder = kwargs_decoder.pop("model", None)
if decoder is None:
if decoder_pretrained_model_name_or_path is None:
raise ValueError(
"If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
"to be defined."
)
if "config" not in kwargs_decoder:
decoder_config, kwargs_decoder = AutoConfig.from_pretrained(
decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
)
if isinstance(decoder_config, MusicgenConfig):
decoder_config = decoder_config.decoder
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
logger.info(
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention"
f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if"
f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
)
decoder_config.is_decoder = True
decoder_config.add_cross_attention = True
kwargs_decoder["config"] = decoder_config
if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
logger.warning(
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
"make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
"passed to `.from_sub_models_pretrained(...)` are set to `True` or do not pass a "
"`decoder_config` to `.from_sub_models_pretrained(...)`"
)
decoder = MusicgenForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
# instantiate config with corresponding kwargs
config = MusicgenConfig.from_sub_models_config(
text_encoder.config, audio_encoder.config, decoder.config, **kwargs
)
return cls(text_encoder=text_encoder, audio_encoder=audio_encoder, decoder=decoder, config=config)
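As a hedged illustration of the prefixing convention described in the docstring above, kwargs prefixed with `text_encoder_`, `audio_encoder_` or `decoder_` are routed to the matching sub-model; the specific override used here is only an example:

```python
# Hedged sketch: routing a config override to the text encoder via the
# `text_encoder_` prefix when assembling the composite model.
from transformers import MusicgenForConditionalGeneration

model = MusicgenForConditionalGeneration.from_sub_models_pretrained(
    text_encoder_pretrained_model_name_or_path="t5-base",
    audio_encoder_pretrained_model_name_or_path="facebook/encodec_24khz",
    decoder_pretrained_model_name_or_path="facebook/musicgen-small",
    text_encoder_output_attentions=True,  # stripped of its prefix and applied to the T5 config
)
print(model.config.text_encoder.output_attentions)  # expected True under this assumption
```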
@add_start_docstrings_to_model_forward(MUSICGEN_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
input_values: Optional[torch.FloatTensor] = None,
padding_mask: Optional[torch.BoolTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.BoolTensor] = None,
encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
past_key_values: Tuple[Tuple[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Tuple, Seq2SeqLMOutput]:
r"""
Returns:
Examples:
```python
>>> from transformers import AutoProcessor, MusicgenForConditionalGeneration
>>> import torch
>>> processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
>>> model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
>>> inputs = processor(
... text=["80s pop track with bassy drums and synth", "90s rock song with loud guitars and heavy drums"],
... padding=True,
... return_tensors="pt",
... )
>>> pad_token_id = model.generation_config.pad_token_id
>>> decoder_input_ids = (
... torch.ones((inputs.input_ids.shape[0] * model.decoder.num_codebooks, 1), dtype=torch.long)
... * pad_token_id
... )
>>> logits = model(**inputs, decoder_input_ids=decoder_input_ids).logits
>>> logits.shape # (bsz * num_codebooks, tgt_len, vocab_size)
torch.Size([8, 1, 2048])
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
kwargs_text_encoder = {
argument[len("text_encoder_")]: value
for argument, value in kwargs.items()
if argument.startswith("text_encoder_")
}
kwargs_audio_encoder = {
argument[len("audio_encoder_")]: value
for argument, value in kwargs.items()
if argument.startswith("audio_encoder_")
}
kwargs_decoder = {
argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
}
if encoder_outputs is None:
encoder_outputs = self.text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs_text_encoder,
)
elif isinstance(encoder_outputs, tuple):
encoder_outputs = BaseModelOutput(*encoder_outputs)
encoder_hidden_states = encoder_outputs[0]
# optionally project encoder_hidden_states
if (
self.text_encoder.config.hidden_size != self.decoder.config.hidden_size
and self.decoder.config.cross_attention_hidden_size is None
):
encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
if attention_mask is not None:
encoder_hidden_states = encoder_hidden_states * attention_mask[..., None]
if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
)
elif decoder_input_ids is None and decoder_inputs_embeds is None:
audio_encoder_outputs = self.audio_encoder(
input_values=input_values,
padding_mask=padding_mask,
**kwargs_audio_encoder,
)
audio_codes = audio_encoder_outputs.audio_codes
frames, bsz, codebooks, seq_len = audio_codes.shape
if frames != 1:
raise ValueError(
f"Expected 1 frame in the audio code outputs, got {frames} frames. Ensure chunking is "
"disabled by setting `chunk_length=None` in the audio encoder."
)
decoder_input_ids = audio_codes[0, ...].reshape(bsz * self.decoder.num_codebooks, seq_len)
# Decode
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=attention_mask,
inputs_embeds=decoder_inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
use_cache=use_cache,
past_key_values=past_key_values,
return_dict=return_dict,
**kwargs_decoder,
)
loss = None
if labels is not None:
logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
if not return_dict:
if loss is not None:
return (loss,) + decoder_outputs + encoder_outputs
else:
return decoder_outputs + encoder_outputs
return Seq2SeqLMOutput(
loss=loss,
logits=decoder_outputs.logits,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)
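The audio-prompt branch above reshapes the EnCodec codes into decoder input ids; a minimal shape sketch with hypothetical sizes:

```python
# Illustrative only: EnCodec returns codes with a leading frame dimension; for
# MusicGen a single frame is expected, which is dropped before flattening the
# codebooks into the batch dimension.
import torch

frames, bsz, num_codebooks, seq_len = 1, 2, 4, 50
audio_codes = torch.randint(0, 2048, (frames, bsz, num_codebooks, seq_len))

decoder_input_ids = audio_codes[0, ...].reshape(bsz * num_codebooks, seq_len)
print(decoder_input_ids.shape)  # torch.Size([8, 50])
```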
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_attention_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
use_cache=None,
encoder_outputs=None,
decoder_delay_pattern_mask=None,
guidance_scale=None,
**kwargs,
):
if decoder_delay_pattern_mask is None:
decoder_input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
decoder_input_ids,
self.generation_config.pad_token_id,
max_length=self.generation_config.max_length,
)
# apply the delay pattern mask
decoder_input_ids = self.decoder.apply_delay_pattern_mask(decoder_input_ids, decoder_delay_pattern_mask)
if guidance_scale is not None and guidance_scale > 1:
# for classifier free guidance we need to replicate the decoder args across the batch dim (we'll split these
# before sampling)
decoder_input_ids = decoder_input_ids.repeat((2, 1))
if decoder_attention_mask is not None:
decoder_attention_mask = decoder_attention_mask.repeat((2, 1))
# cut decoder_input_ids if past is used
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache,
}
def _prepare_decoder_input_ids_for_generation(
self,
batch_size: int,
model_input_name: str,
model_kwargs: Dict[str, torch.Tensor],
decoder_start_token_id: int = None,
bos_token_id: int = None,
device: torch.device = None,
) -> Tuple[torch.LongTensor, Dict[str, torch.Tensor]]:
"""Prepares `decoder_input_ids` for generation with encoder-decoder models"""
# 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming,
# we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input.
if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
decoder_input_ids = model_kwargs.pop("decoder_input_ids")
elif "input_ids" in model_kwargs and model_input_name != "input_ids":
decoder_input_ids = model_kwargs.pop("input_ids")
else:
decoder_input_ids = None
# 2. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that.
decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
if device is None:
device = self.device
decoder_input_ids_start = (
torch.ones((batch_size * self.decoder.num_codebooks, 1), dtype=torch.long, device=device)
* decoder_start_token_id
)
# no user input -> use decoder_start_token_id as decoder_input_ids
if decoder_input_ids is None:
decoder_input_ids = decoder_input_ids_start
# user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust
# decoder_attention_mask if provided)
elif (decoder_input_ids[..., 0] != decoder_start_token_id).all().item():
decoder_input_ids = torch.cat([decoder_input_ids_start, decoder_input_ids], dim=-1)
if "decoder_attention_mask" in model_kwargs:
decoder_attention_mask = model_kwargs["decoder_attention_mask"]
decoder_attention_mask = torch.cat(
(torch.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask),
dim=-1,
)
model_kwargs["decoder_attention_mask"] = decoder_attention_mask
return decoder_input_ids, model_kwargs
def _prepare_text_encoder_kwargs_for_generation(
self,
inputs_tensor: torch.Tensor,
model_kwargs,
model_input_name: Optional[str] = None,
guidance_scale: Optional[float] = None,
) -> Dict[str, Any]:
# 1. get text encoder
encoder = self.get_text_encoder()
# Compatibility with Accelerate big model inference: we need the encoder outputs to be on the same device
# as the inputs.
if hasattr(encoder, "_hf_hook"):
encoder._hf_hook.io_same_device = True
# 2. Prepare encoder args and encoder kwargs from model kwargs.
irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
encoder_kwargs = {
argument: value
for argument, value in model_kwargs.items()
if not any(argument.startswith(p) for p in irrelevant_prefix)
}
encoder_signature = set(inspect.signature(encoder.forward).parameters)
encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature
if not encoder_accepts_wildcard:
encoder_kwargs = {
argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature
}
# 3. make sure that encoder returns `ModelOutput`
model_input_name = model_input_name if model_input_name is not None else self.text_encoder.main_input_name
encoder_kwargs["return_dict"] = True
encoder_kwargs[model_input_name] = inputs_tensor
last_hidden_state = encoder(**encoder_kwargs).last_hidden_state
# for classifier free guidance we need to add a 'null' input to our encoder hidden states
if guidance_scale is not None and guidance_scale > 1:
last_hidden_state = torch.concatenate([last_hidden_state, torch.zeros_like(last_hidden_state)], dim=0)
if "attention_mask" in model_kwargs:
model_kwargs["attention_mask"] = torch.concatenate(
[model_kwargs["attention_mask"], torch.zeros_like(model_kwargs["attention_mask"])], dim=0
)
model_kwargs["encoder_outputs"] = BaseModelOutput(last_hidden_state=last_hidden_state)
return model_kwargs
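For classifier-free guidance, the conditional encoder states are paired with an all-zero "null" copy along the batch dimension, and the attention mask is extended so that the null half attends to nothing. A minimal sketch with hypothetical sizes:

```python
# Illustrative only: duplicating encoder states for classifier-free guidance.
import torch

bsz, enc_len, hidden_size = 2, 12, 768
last_hidden_state = torch.randn(bsz, enc_len, hidden_size)
attention_mask = torch.ones(bsz, enc_len, dtype=torch.long)

last_hidden_state = torch.concatenate([last_hidden_state, torch.zeros_like(last_hidden_state)], dim=0)
attention_mask = torch.concatenate([attention_mask, torch.zeros_like(attention_mask)], dim=0)
print(last_hidden_state.shape, attention_mask.shape)  # torch.Size([4, 12, 768]) torch.Size([4, 12])
```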
def _prepare_audio_encoder_kwargs_for_generation(
self, input_values, model_kwargs, model_input_name: Optional[str] = None
):
# 1. get audio encoder
encoder = self.get_audio_encoder()
# Compatibility with Accelerate big model inference: we need the encoder outputs to be on the same device
# as the inputs.
if hasattr(encoder, "_hf_hook"):
encoder._hf_hook.io_same_device = True
# 2. Prepare encoder args and encoder kwargs from model kwargs.
irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
encoder_kwargs = {
argument: value
for argument, value in model_kwargs.items()
if not any(argument.startswith(p) for p in irrelevant_prefix)
}
encoder_signature = set(inspect.signature(encoder.forward).parameters)
encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature
if not encoder_accepts_wildcard:
encoder_kwargs = {
argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature
}
# 3. make sure that encoder returns `ModelOutput`
model_input_name = model_input_name if model_input_name is not None else self.audio_encoder.main_input_name
encoder_kwargs["return_dict"] = True
encoder_kwargs[model_input_name] = input_values
audio_encoder_outputs = encoder.encode(**encoder_kwargs)
audio_codes = audio_encoder_outputs.audio_codes
frames, bsz, codebooks, seq_len = audio_codes.shape
if frames != 1:
raise ValueError(
f"Expected 1 frame in the audio code outputs, got {frames} frames. Ensure chunking is "
"disabled by setting `chunk_length=None` in the audio encoder."
)
decoder_input_ids = audio_codes[0, ...].reshape(bsz * self.decoder.num_codebooks, seq_len)
model_kwargs["decoder_input_ids"] = decoder_input_ids
model_kwargs["audio_scales"] = audio_encoder_outputs.audio_scales
return model_kwargs
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
def resize_token_embeddings(self, *args, **kwargs):
raise NotImplementedError(
"Resizing the embedding layers via the EncoderDecoderModel directly is not supported. Please use the"
" respective methods of the wrapped objects (model.encoder.resize_token_embeddings(...) or"
" model.decoder.resize_token_embeddings(...))"
)
def _maybe_initialize_input_ids_for_generation(
self,
inputs: Optional[torch.Tensor] = None,
bos_token_id: Optional[int] = None,
model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
) -> torch.LongTensor:
"""Initializes input ids for generation, if necessary."""
if inputs is not None:
return inputs
encoder_outputs = model_kwargs.get("encoder_outputs")
if encoder_outputs is not None:
# make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
shape = encoder_outputs[0].size()[:-1]
return torch.ones(shape, dtype=torch.long, device=self.device) * -100
if bos_token_id is None:
raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
# If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with
# soft-prompting or in multimodal implementations built on top of decoder-only language models.
batch_size = 1
for value in model_kwargs.values():
if isinstance(value, torch.Tensor):
batch_size = value.shape[0]
break
return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id
@torch.no_grad()
def generate(
self,
inputs: Optional[torch.Tensor] = None,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
synced_gpus: Optional[bool] = None,
**kwargs,
):
"""
Generates sequences of token ids for models with a language modeling head.
<Tip warning={true}>
Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
model's default generation configuration. You can override any `generation_config` by passing the corresponding
parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
For an overview of generation strategies and code examples, check out the [following
guide](./generation_strategies).
</Tip>
Parameters:
inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
should be in the format `input_ids`. For encoder-decoder models *inputs* can represent any of
`input_ids`, `input_values`, `input_features`, or `pixel_values`.
generation_config (`~generation.GenerationConfig`, *optional*):
The generation configuration to be used as base parametrization for the generation call. `**kwargs`
passed to generate matching the attributes of `generation_config` will override them. If
`generation_config` is not provided, the default will be used, which has the following loading
priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
default values, whose documentation should be checked to parameterize generation.
logits_processor (`LogitsProcessorList`, *optional*):
Custom logits processors that complement the default logits processors built from arguments and
generation config. If a logit processor is passed that is already created with the arguments or a
generation config an error is thrown. This feature is intended for advanced users.
stopping_criteria (`StoppingCriteriaList`, *optional*):
Custom stopping criteria that complement the default stopping criteria built from arguments and a
generation config. If a stopping criteria is passed that is already created with the arguments or a
generation config an error is thrown. This feature is intended for advanced users.
synced_gpus (`bool`, *optional*, defaults to `False`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
kwargs:
Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
Return:
[`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`.
If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
[`~utils.ModelOutput`] types are:
- [`~generation.GreedySearchDecoderOnlyOutput`],
- [`~generation.SampleDecoderOnlyOutput`],
- [`~generation.BeamSearchDecoderOnlyOutput`],
- [`~generation.BeamSampleDecoderOnlyOutput`]
If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
[`~utils.ModelOutput`] types are:
- [`~generation.GreedySearchEncoderDecoderOutput`],
- [`~generation.SampleEncoderDecoderOutput`],
- [`~generation.BeamSearchEncoderDecoderOutput`],
- [`~generation.BeamSampleEncoderDecoderOutput`]
"""
# 1. Handle `generation_config` and kwargs that might update it, and validate the resulting objects
if generation_config is None:
generation_config = self.generation_config
generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
generation_config.validate()
self._validate_model_kwargs(model_kwargs.copy())
if model_kwargs.get("encoder_outputs") is not None and type(model_kwargs["encoder_outputs"]) == tuple:
# wrap the unconditional outputs as a BaseModelOutput for compatibility with the rest of generate
model_kwargs["encoder_outputs"] = BaseModelOutput(last_hidden_state=model_kwargs["encoder_outputs"][0])
# 2. Set generation parameters if not already defined
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
if model_kwargs.get("attention_mask", None) is None:
logger.warning(
"The attention mask and the pad token id were not set. As a consequence, you may observe "
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
)
eos_token_id = generation_config.eos_token_id
if isinstance(eos_token_id, list):
eos_token_id = eos_token_id[0]
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
generation_config.pad_token_id = eos_token_id
# 3. Define model inputs
# inputs_tensor has to be defined
# model_input_name is defined if model-specific keyword input is passed
# otherwise model_input_name is None
# all model-specific keyword inputs are removed from `model_kwargs`
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
inputs, generation_config.bos_token_id, model_kwargs
)
batch_size = inputs_tensor.shape[0]
# 4. Define other model kwargs
model_kwargs["output_attentions"] = generation_config.output_attentions
model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
model_kwargs["use_cache"] = generation_config.use_cache
model_kwargs["guidance_scale"] = generation_config.guidance_scale
requires_attention_mask = "encoder_outputs" not in model_kwargs
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
)
if "encoder_outputs" not in model_kwargs:
# encoder_outputs are created and added to `model_kwargs`
model_kwargs = self._prepare_text_encoder_kwargs_for_generation(
inputs_tensor,
model_kwargs,
model_input_name,
guidance_scale=generation_config.guidance_scale,
)
if "decoder_input_ids" not in model_kwargs and "input_values" in model_kwargs:
model_kwargs = self._prepare_audio_encoder_kwargs_for_generation(
model_kwargs["input_values"],
model_kwargs,
)
# 5. Prepare `input_ids` which will be used for auto-regressive generation
input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
batch_size=batch_size,
model_input_name=model_input_name,
model_kwargs=model_kwargs,
decoder_start_token_id=generation_config.decoder_start_token_id,
bos_token_id=generation_config.bos_token_id,
device=inputs_tensor.device,
)
# 6. Prepare `max_length` depending on other stopping criteria.
input_ids_seq_length = input_ids.shape[-1]
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
if has_default_max_length and generation_config.max_new_tokens is None:
logger.warning(
f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
"This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
" recommend using `max_new_tokens` to control the maximum length of the generation.",
)
elif generation_config.max_new_tokens is not None:
if not has_default_max_length:
logger.warning(
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
"Please refer to the documentation for more information. "
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
)
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
raise ValueError(
f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than"
f" the maximum length ({generation_config.max_length})"
)
if input_ids_seq_length >= generation_config.max_length:
logger.warning(
f"Input length of decoder_input_ids is {input_ids_seq_length}, but `max_length` is set to"
f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
" increasing `max_new_tokens`."
)
# build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to MusicGen)
input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
input_ids,
pad_token_id=generation_config.decoder_start_token_id,
max_length=generation_config.max_length,
)
# stash the delay mask so that we don't have to recompute in each forward pass
model_kwargs["decoder_delay_pattern_mask"] = decoder_delay_pattern_mask
# 7. determine generation mode
is_greedy_gen_mode = (
(generation_config.num_beams == 1)
and (generation_config.num_beam_groups == 1)
and generation_config.do_sample is False
)
is_sample_gen_mode = (
(generation_config.num_beams == 1)
and (generation_config.num_beam_groups == 1)
and generation_config.do_sample is True
)
# 8. prepare distribution pre_processing samplers
logits_processor = self._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_seq_length,
encoder_input_ids=inputs_tensor,
prefix_allowed_tokens_fn=None,
logits_processor=logits_processor,
)
# 9. prepare stopping criteria
stopping_criteria = self._get_stopping_criteria(
generation_config=generation_config, stopping_criteria=stopping_criteria
)
if is_greedy_gen_mode:
if generation_config.num_return_sequences > 1:
raise ValueError(
"num_return_sequences has to be 1 when doing greedy search, "
f"but is {generation_config.num_return_sequences}."
)
# 10. run greedy search
outputs = self.greedy_search(
input_ids,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
pad_token_id=generation_config.pad_token_id,
eos_token_id=generation_config.eos_token_id,
output_scores=generation_config.output_scores,
return_dict_in_generate=generation_config.return_dict_in_generate,
synced_gpus=synced_gpus,
**model_kwargs,
)
elif is_sample_gen_mode:
# 11. prepare logits warper
logits_warper = self._get_logits_warper(generation_config)
# expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids=input_ids,
expand_size=generation_config.num_return_sequences,
is_encoder_decoder=self.config.is_encoder_decoder,
**model_kwargs,
)
# 12. run sample
outputs = self.sample(
input_ids,
logits_processor=logits_processor,
logits_warper=logits_warper,
stopping_criteria=stopping_criteria,
pad_token_id=generation_config.pad_token_id,
eos_token_id=generation_config.eos_token_id,
output_scores=generation_config.output_scores,
return_dict_in_generate=generation_config.return_dict_in_generate,
synced_gpus=synced_gpus,
**model_kwargs,
)
else:
raise ValueError(
"Got incompatible mode for generation, should be one of greedy or sampling."
"Ensure that beam search is de-activated by setting `num_beams=1` and `num_beam_groups=1`."
)
if generation_config.return_dict_in_generate:
output_ids = outputs.sequences
else:
output_ids = outputs
# apply the pattern mask to the final ids
output_ids = self.decoder.apply_delay_pattern_mask(output_ids, model_kwargs["decoder_delay_pattern_mask"])
# revert the pattern delay mask by filtering the pad token id
output_ids = output_ids[output_ids != generation_config.pad_token_id].reshape(
batch_size, self.decoder.num_codebooks, -1
)
# append the frame dimension back to the audio codes
output_ids = output_ids[None, ...]
audio_scales = model_kwargs.get("audio_scales")
if audio_scales is None:
audio_scales = [None] * batch_size
output_values = self.audio_encoder.decode(
output_ids,
audio_scales=audio_scales,
)
if generation_config.return_dict_in_generate:
outputs.sequences = output_values.audio_values
return outputs
else:
return output_values.audio_values
def get_unconditional_inputs(self, num_samples=1):
"""
Helper function to get null inputs for unconditional generation, enabling the model to be used without the
feature extractor or tokenizer.
Args:
num_samples (`int`, *optional*, defaults to 1):
Number of audio samples to unconditionally generate. The length of each sample is controlled by the
`max_new_tokens` argument passed to [`generate`] (see the example below): more tokens gives longer audio
samples, at the expense of longer inference.
Example:
```python
>>> from transformers import MusicgenForConditionalGeneration
>>> model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
>>> # get the unconditional (or 'null') inputs for the model
>>> unconditional_inputs = model.get_unconditional_inputs(num_samples=1)
>>> audio_samples = model.generate(**unconditional_inputs, max_new_tokens=256)
```"""
last_hidden_state = torch.zeros(
(num_samples, 1, self.config.text_encoder.hidden_size), device=self.device, dtype=self.dtype
)
attention_mask = torch.zeros((num_samples, 1), device=self.device, dtype=torch.long)
return MusicgenUnconditionalInput(
encoder_outputs=(last_hidden_state,),
attention_mask=attention_mask,
guidance_scale=1.0,
)
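Putting the pieces together, a hedged end-to-end sketch of text-conditional generation with classifier-free guidance (the `guidance_scale` and `max_new_tokens` values are illustrative):

```python
# Hedged sketch: text-conditional music generation with classifier-free guidance.
from transformers import AutoProcessor, MusicgenForConditionalGeneration

processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")

inputs = processor(
    text=["80s pop track with bassy drums and synth"],
    padding=True,
    return_tensors="pt",
)
audio_values = model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=256)
print(audio_values.shape)  # (batch_size, num_channels, sequence_length)
```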
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Text/audio processor class for MusicGen
"""
from typing import List, Optional
import numpy as np
from ...processing_utils import ProcessorMixin
from ...utils import to_numpy
class MusicgenProcessor(ProcessorMixin):
r"""
Constructs a MusicGen processor which wraps an EnCodec feature extractor and a T5 tokenizer into a single processor
class.
[`MusicgenProcessor`] offers all the functionalities of [`EncodecFeatureExtractor`] and [`T5Tokenizer`]. See
[`~MusicgenProcessor.__call__`] and [`~MusicgenProcessor.decode`] for more information.
Args:
feature_extractor (`EncodecFeatureExtractor`):
An instance of [`EncodecFeatureExtractor`]. The feature extractor is a required input.
tokenizer (`T5Tokenizer`):
An instance of [`T5Tokenizer`]. The tokenizer is a required input.
"""
feature_extractor_class = "EncodecFeatureExtractor"
tokenizer_class = ("T5Tokenizer", "T5TokenizerFast")
def __init__(self, feature_extractor, tokenizer):
super().__init__(feature_extractor, tokenizer)
self.current_processor = self.feature_extractor
self._in_target_context_manager = False
def get_decoder_prompt_ids(self, task=None, language=None, no_timestamps=True):
return self.tokenizer.get_decoder_prompt_ids(task=task, language=language, no_timestamps=no_timestamps)
def __call__(self, *args, **kwargs):
"""
Forwards the `audio` argument to EncodecFeatureExtractor's [`~EncodecFeatureExtractor.__call__`] and the `text`
argument to [`~T5Tokenizer.__call__`]. Please refer to the docstrings of the above two methods for more
information.
"""
# For backward compatibility
if self._in_target_context_manager:
return self.current_processor(*args, **kwargs)
audio = kwargs.pop("audio", None)
sampling_rate = kwargs.pop("sampling_rate", None)
text = kwargs.pop("text", None)
if len(args) > 0:
audio = args[0]
args = args[1:]
if audio is None and text is None:
raise ValueError("You need to specify either an `audio` or `text` input to process.")
if text is not None:
inputs = self.tokenizer(text, **kwargs)
if audio is not None:
audio_inputs = self.feature_extractor(audio, *args, sampling_rate=sampling_rate, **kwargs)
if audio is None:
return inputs
elif text is None:
return audio_inputs
else:
inputs["input_values"] = audio_inputs["input_values"]
if "padding_mask" in audio_inputs:
inputs["padding_mask"] = audio_inputs["padding_mask"]
return inputs
def batch_decode(self, *args, **kwargs):
"""
This method is used to decode either batches of audio outputs from the MusicGen model, or batches of token ids
from the tokenizer. In the case of decoding token ids, this method forwards all its arguments to T5Tokenizer's
[`~PreTrainedTokenizer.batch_decode`]. Please refer to the docstring of this method for more information.
"""
audio_values = kwargs.pop("audio", None)
padding_mask = kwargs.pop("padding_mask", None)
if len(args) > 0:
audio_values = args[0]
args = args[1:]
if audio_values is not None:
return self._decode_audio(audio_values, padding_mask=padding_mask)
else:
return self.tokenizer.batch_decode(*args, **kwargs)
def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to T5Tokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to the
docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
def _decode_audio(self, audio_values, padding_mask: Optional[np.ndarray] = None) -> List[np.ndarray]:
"""
This method strips any padding from the audio values to return a list of numpy audio arrays.
"""
audio_values = to_numpy(audio_values)
bsz, channels, seq_len = audio_values.shape
if padding_mask is None:
return list(audio_values)
padding_mask = to_numpy(padding_mask)
# match the sequence length of the padding mask to the generated audio arrays by padding with the **non-padding**
# token (so that the generated audio values are **not** treated as padded tokens)
difference = seq_len - padding_mask.shape[-1]
padding_value = 1 - self.feature_extractor.padding_value
padding_mask = np.pad(padding_mask, ((0, 0), (0, difference)), "constant", constant_values=padding_value)
audio_values = audio_values.tolist()
for i in range(bsz):
sliced_audio = np.asarray(audio_values[i])[
padding_mask[i][None, :] != self.feature_extractor.padding_value
]
audio_values[i] = sliced_audio.reshape(channels, -1)
return audio_values
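For orientation, here is a hedged usage sketch of the processor defined above. The checkpoint name mirrors the one used in the integration tests further down, and the waveforms are silent placeholders rather than real audio prompts.

```python
import numpy as np
from transformers import MusicgenProcessor

processor = MusicgenProcessor.from_pretrained("facebook/musicgen-small")

# text-only conditioning
text_inputs = processor(text=["80s music", "Club techno"], padding=True, return_tensors="pt")

# text + audio-prompt conditioning, with placeholder mono waveforms at 32 kHz
dummy_audio = [np.zeros(32000, dtype=np.float32), np.zeros(16000, dtype=np.float32)]
full_inputs = processor(
    audio=dummy_audio, sampling_rate=32000, text=["80s music", "Club techno"], padding=True, return_tensors="pt"
)

# after generation, batch_decode strips the padding recorded in `padding_mask` from each sample, e.g.:
# audios = processor.batch_decode(generated_audio, padding_mask=full_inputs["padding_mask"])
```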
...@@ -4936,6 +4936,44 @@ class MT5PreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
MUSICGEN_PRETRAINED_MODEL_ARCHIVE_LIST = None
class MusicgenForCausalLM(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class MusicgenForConditionalGeneration(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class MusicgenModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class MusicgenPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class MusicgenProcessor(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
MVP_PRETRAINED_MODEL_ARCHIVE_LIST = None
...
# coding=utf-8
# Copyright 2021, The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch Musicgen model. """
import copy
import inspect
import math
import unittest
import numpy as np
from transformers import (
EncodecConfig,
MusicgenConfig,
MusicgenDecoderConfig,
MusicgenProcessor,
PretrainedConfig,
T5Config,
)
from transformers.testing_utils import is_torch_available, require_torch, slow, torch_device
from transformers.utils import cached_property
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
import torch
from transformers import (
MusicgenForCausalLM,
MusicgenForConditionalGeneration,
MusicgenModel,
set_seed,
)
from transformers.generation import (
GreedySearchDecoderOnlyOutput,
GreedySearchEncoderDecoderOutput,
InfNanRemoveLogitsProcessor,
LogitsProcessorList,
SampleDecoderOnlyOutput,
SampleEncoderDecoderOutput,
)
def _config_zero_init(config):
configs_no_init = copy.deepcopy(config)
for key in configs_no_init.__dict__.keys():
if "_range" in key or "_std" in key or "initializer_factor" in key or "layer_scale" in key:
setattr(configs_no_init, key, 1e-10)
if isinstance(getattr(configs_no_init, key, None), PretrainedConfig):
no_init_subconfig = _config_zero_init(getattr(configs_no_init, key))
setattr(configs_no_init, key, no_init_subconfig)
return configs_no_init
def prepare_musicgen_decoder_inputs_dict(
config,
input_ids,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
cross_attn_head_mask=None,
):
if attention_mask is None:
attention_mask = input_ids.reshape(-1, config.num_codebooks, input_ids.shape[-1])[:, 0, :]
attention_mask = attention_mask.ne(config.pad_token_id)
if head_mask is None:
head_mask = torch.ones(config.num_hidden_layers, config.num_attention_heads, device=torch_device)
if encoder_attention_mask is None and encoder_hidden_states is not None:
encoder_attention_mask = torch.ones(encoder_hidden_states.shape[:2], device=torch_device)
if cross_attn_head_mask is None:
cross_attn_head_mask = torch.ones(config.num_hidden_layers, config.num_attention_heads, device=torch_device)
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"encoder_hidden_states": encoder_hidden_states,
"encoder_attention_mask": encoder_attention_mask,
"head_mask": head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
}
class MusicgenDecoderTester:
def __init__(
self,
parent,
batch_size=2,
seq_length=7,
is_training=False,
use_labels=False,
vocab_size=99,
hidden_size=16,
num_hidden_layers=2,
num_attention_heads=4,
intermediate_size=4,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=100,
pad_token_id=99,
bos_token_id=99,
num_codebooks=4,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
self.num_codebooks = num_codebooks
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size * self.num_codebooks, self.seq_length], self.vocab_size)
encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
config = self.get_config()
inputs_dict = prepare_musicgen_decoder_inputs_dict(
config, input_ids, encoder_hidden_states=encoder_hidden_states
)
return config, inputs_dict
def get_config(self):
config = MusicgenDecoderConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
ffn_dim=self.intermediate_size,
pad_token_id=self.pad_token_id,
decoder_start_token_id=self.bos_token_id,
bos_token_id=self.bos_token_id,
num_codebooks=self.num_codebooks,
tie_word_embeddings=False,
)
return config
def prepare_config_and_inputs_for_common(self):
config, inputs_dict = self.prepare_config_and_inputs()
return config, inputs_dict
@require_torch
class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (MusicgenModel, MusicgenForCausalLM) if is_torch_available() else ()
greedy_sample_model_classes = (
(MusicgenForCausalLM,) if is_torch_available() else ()
) # we don't want to run all the generation tests, only a specific subset
pipeline_model_mapping = {}
test_pruning = False
test_resize_embeddings = False
def setUp(self):
self.model_tester = MusicgenDecoderTester(self)
self.config_tester = ConfigTester(self, config_class=MusicgenDecoderConfig, hidden_size=16)
def test_config(self):
self.config_tester.run_common_tests()
# override since we have to compute the input embeddings over codebooks
def test_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
input_ids = inputs["input_ids"]
del inputs["input_ids"]
embed_tokens = model.get_input_embeddings()
input_ids = input_ids.reshape(-1, config.num_codebooks, input_ids.shape[-1])
inputs["inputs_embeds"] = sum(
[embed_tokens[codebook](input_ids[:, codebook]) for codebook in range(config.num_codebooks)]
)
with torch.no_grad():
model(**inputs)[0]
# override since we have embeddings / LM heads over multiple codebooks
def test_model_common_attributes(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
first_embed = model.get_input_embeddings()[0]
self.assertIsInstance(first_embed, torch.nn.Embedding)
lm_heads = model.get_output_embeddings()
self.assertTrue(lm_heads is None or isinstance(lm_heads[0], torch.nn.Linear))
# skip as this model doesn't support all arguments tested
def test_model_outputs_equivalence(self):
pass
# skip as this model has multiple inputs embeds and lm heads that should not be tied
def test_tie_model_weights(self):
pass
# skip as this model has multiple inputs embeds and lm heads that should not be tied
def test_tied_weights_keys(self):
pass
def _get_input_ids_and_config(self, batch_size=2):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict["input_ids"]
# take max batch_size
sequence_length = input_ids.shape[-1]
input_ids = input_ids[: batch_size * config.num_codebooks, :]
# generate max 3 tokens
max_length = input_ids.shape[-1] + 3
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
return config, input_ids, attention_mask, max_length
@staticmethod
def _get_logits_processor_and_kwargs(
input_length,
eos_token_id,
forced_bos_token_id=None,
forced_eos_token_id=None,
max_length=None,
diversity_penalty=None,
):
process_kwargs = {
"min_length": input_length + 1 if max_length is None else max_length - 1,
}
logits_processor = LogitsProcessorList()
return process_kwargs, logits_processor
# override since we don't expect the outputs of `.generate` and `.greedy_search` to be the same, since we perform
# additional post-processing in the former
def test_greedy_generate_dict_outputs(self):
for model_class in self.greedy_sample_model_classes:
# disable cache
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
config.use_cache = False
model = model_class(config).to(torch_device).eval()
output_greedy, output_generate = self._greedy_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
max_length=max_length,
output_scores=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
self.assertIsInstance(output_greedy, GreedySearchDecoderOnlyOutput)
self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput)
self.assertNotIn(config.pad_token_id, output_generate)
# override since we don't expect the outputs of `.generate` and `.greedy_search` to be the same, since we perform
# additional post-processing in the former
def test_greedy_generate_dict_outputs_use_cache(self):
for model_class in self.greedy_sample_model_classes:
# enable cache
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
output_greedy, output_generate = self._greedy_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
max_length=max_length,
output_scores=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
self.assertIsInstance(output_greedy, GreedySearchDecoderOnlyOutput)
self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput)
# override since we don't expect the outputs of `.generate` and `.sample` to be the same, since we perform
# additional post-processing in the former
def test_sample_generate(self):
for model_class in self.greedy_sample_model_classes:
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
model = model_class(config).to(torch_device).eval()
process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
input_ids.shape[-1],
model.config.eos_token_id,
forced_bos_token_id=model.config.forced_bos_token_id,
forced_eos_token_id=model.config.forced_eos_token_id,
max_length=max_length,
)
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=2)
# check `generate()` and `sample()` are equal
output_sample, output_generate = self._sample_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
max_length=max_length,
num_return_sequences=3,
logits_processor=logits_processor,
logits_warper=logits_warper,
logits_warper_kwargs=logits_warper_kwargs,
process_kwargs=process_kwargs,
)
self.assertIsInstance(output_sample, torch.Tensor)
self.assertIsInstance(output_generate, torch.Tensor)
# override since we don't expect the outputs of `.generate` and `.sample` to be the same, since we perform
# additional post-processing in the former
def test_sample_generate_dict_output(self):
for model_class in self.greedy_sample_model_classes:
# disable cache
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
config.use_cache = False
model = model_class(config).to(torch_device).eval()
process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
input_ids.shape[-1],
model.config.eos_token_id,
forced_bos_token_id=model.config.forced_bos_token_id,
forced_eos_token_id=model.config.forced_eos_token_id,
max_length=max_length,
)
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
output_sample, output_generate = self._sample_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
max_length=max_length,
num_return_sequences=1,
logits_processor=logits_processor,
logits_warper=logits_warper,
logits_warper_kwargs=logits_warper_kwargs,
process_kwargs=process_kwargs,
output_scores=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
self.assertIsInstance(output_sample, SampleDecoderOnlyOutput)
self.assertIsInstance(output_generate, SampleDecoderOnlyOutput)
def prepare_musicgen_inputs_dict(
config,
input_ids,
decoder_input_ids,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
):
if decoder_attention_mask is None:
decoder_attention_mask = decoder_input_ids.reshape(
-1, config.decoder.num_codebooks, decoder_input_ids.shape[-1]
)[:, 0, :]
decoder_attention_mask = decoder_attention_mask.ne(config.decoder.pad_token_id)
if head_mask is None:
head_mask = torch.ones(
config.text_encoder.num_hidden_layers, config.text_encoder.num_attention_heads, device=torch_device
)
if decoder_head_mask is None:
decoder_head_mask = torch.ones(
config.decoder.num_hidden_layers, config.decoder.num_attention_heads, device=torch_device
)
if cross_attn_head_mask is None:
cross_attn_head_mask = torch.ones(
config.decoder.num_hidden_layers, config.decoder.num_attention_heads, device=torch_device
)
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
}
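A short illustration of the shapes the helper above produces, assuming the tester defaults used below (batch of 2, 4 codebooks, sequence length 7, pad id 99):

```python
import torch

batch_size, num_codebooks, seq_len, pad_token_id = 2, 4, 7, 99
input_ids = torch.randint(0, 99, (batch_size, seq_len))  # text tokens for the T5 encoder
decoder_input_ids = torch.randint(0, 99, (batch_size * num_codebooks, seq_len))  # audio codes, codebooks stacked on the batch dim
# the decoder attention mask is derived from codebook 0 only, giving one row per example
decoder_attention_mask = decoder_input_ids.reshape(-1, num_codebooks, seq_len)[:, 0, :].ne(pad_token_id)
print(input_ids.shape, decoder_input_ids.shape, decoder_attention_mask.shape)
# torch.Size([2, 7]) torch.Size([8, 7]) torch.Size([2, 7])
```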
class MusicgenTester:
def __init__(
self,
parent,
batch_size=2,
seq_length=7,
is_training=False,
use_labels=False,
vocab_size=99,
hidden_size=16,
num_hidden_layers=2,
num_attention_heads=4,
intermediate_size=4,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=100,
pad_token_id=99,
bos_token_id=99,
num_codebooks=4,
num_filters=4,
codebook_size=128,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
self.num_codebooks = num_codebooks
self.num_filters = num_filters
self.codebook_size = codebook_size
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
decoder_input_ids = ids_tensor([self.batch_size * self.num_codebooks, self.seq_length], self.vocab_size)
config = self.get_config()
inputs_dict = prepare_musicgen_inputs_dict(config, input_ids, decoder_input_ids=decoder_input_ids)
return config, inputs_dict
def get_config(self):
text_encoder_config = T5Config(
vocab_size=self.vocab_size,
d_model=self.hidden_size,
d_ff=self.intermediate_size,
num_layers=self.num_hidden_layers,
num_heads=self.num_attention_heads,
)
audio_encoder_config = EncodecConfig(
hidden_size=self.vocab_size,
compress=1,
num_filters=self.num_filters,
codebook_size=self.codebook_size,
codebook_dim=self.vocab_size,
)
decoder_config = MusicgenDecoderConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
ffn_dim=self.intermediate_size,
pad_token_id=self.pad_token_id,
decoder_start_token_id=self.bos_token_id,
bos_token_id=self.bos_token_id,
num_codebooks=self.num_codebooks,
tie_word_embeddings=False,
)
config = MusicgenConfig.from_sub_models_config(text_encoder_config, audio_encoder_config, decoder_config)
return config
def prepare_config_and_inputs_for_common(self):
config, inputs_dict = self.prepare_config_and_inputs()
return config, inputs_dict
@require_torch
class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (MusicgenForConditionalGeneration,) if is_torch_available() else ()
greedy_sample_model_classes = (MusicgenForConditionalGeneration,) if is_torch_available() else ()
pipeline_model_mapping = {}
test_pruning = False # training is not supported yet for MusicGen
test_headmasking = False
test_resize_embeddings = False
def setUp(self):
self.model_tester = MusicgenTester(self)
def _check_output_with_attentions(self, outputs, config, input_ids, decoder_input_ids):
text_encoder_config = config.text_encoder
decoder_config = config.decoder
encoder_attentions = outputs["encoder_attentions"]
self.assertEqual(len(encoder_attentions), text_encoder_config.num_hidden_layers)
self.assertEqual(
encoder_attentions[0].shape[-3:],
(text_encoder_config.num_attention_heads, input_ids.shape[-1], input_ids.shape[-1]),
)
decoder_attentions = outputs["decoder_attentions"]
num_decoder_layers = decoder_config.num_hidden_layers
self.assertEqual(len(decoder_attentions), num_decoder_layers)
self.assertEqual(
decoder_attentions[0].shape[-3:],
(decoder_config.num_attention_heads, decoder_input_ids.shape[-1], decoder_input_ids.shape[-1]),
)
cross_attentions = outputs["cross_attentions"]
self.assertEqual(len(cross_attentions), num_decoder_layers)
cross_attention_input_seq_len = decoder_input_ids.shape[-1]
self.assertEqual(
cross_attentions[0].shape[-3:],
(decoder_config.num_attention_heads, cross_attention_input_seq_len, input_ids.shape[-1]),
)
def check_musicgen_model_output_attentions(
self,
model_class,
config,
input_ids,
attention_mask,
decoder_input_ids,
decoder_attention_mask,
**kwargs,
):
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
output_attentions=True,
**kwargs,
)
self._check_output_with_attentions(outputs, config, input_ids, decoder_input_ids)
def check_musicgen_model_output_attentions_from_config(
self,
model_class,
config,
input_ids,
attention_mask,
decoder_input_ids,
decoder_attention_mask,
**kwargs,
):
# Similar to `check_musicgen_model_output_attentions`, but with `output_attentions` triggered from the
# config file. Contrary to most models, changing the model's config won't work -- the defaults are loaded
# from the inner models' configurations.
config.output_attentions = True # model config -> won't work
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
**kwargs,
)
self.assertTrue(
all(key not in outputs for key in ["encoder_attentions", "decoder_attentions", "cross_attentions"])
)
config.text_encoder.output_attentions = True # inner model config -> will work
config.audio_encoder.output_attentions = True
config.decoder.output_attentions = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
**kwargs,
)
self._check_output_with_attentions(outputs, config, input_ids, decoder_input_ids)
# override since changing `output_attentions` from the top-level model config won't work
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
self.check_musicgen_model_output_attentions(model_class, config, **inputs_dict)
self.check_musicgen_model_output_attentions_from_config(model_class, config, **inputs_dict)
# override since we have a specific forward signature for musicgen
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.forward)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = [
"input_ids",
"attention_mask",
"input_values",
"padding_mask",
"decoder_input_ids",
"decoder_attention_mask",
]
expected_arg_names.extend(
["head_mask", "decoder_head_mask", "cross_attn_head_mask", "encoder_outputs"]
if "head_mask" and "decoder_head_mask" and "cross_attn_head_mask" in arg_names
else ["encoder_outputs"]
)
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
# override since changing `gradient_checkpointing` from the top-level model config won't work
def test_gradient_checkpointing_backward_compatibility(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
if not model_class.supports_gradient_checkpointing:
continue
config.text_encoder.gradient_checkpointing = True
config.audio_encoder.gradient_checkpointing = True
config.decoder.gradient_checkpointing = True
model = model_class(config)
self.assertTrue(model.is_gradient_checkpointing)
# skip as this model has multiple inputs embeds and lm heads that should not be tied
def test_tie_model_weights(self):
pass
# skip as this model has multiple inputs embeds and lm heads that should not be tied
def test_tied_model_weights_key_ignore(self):
pass
# skip as this model has multiple inputs embeds and lm heads that should not be tied
def test_tied_weights_keys(self):
pass
# override since changing `output_hidden_states` / `output_attentions` from the top-level model config won't work
def test_retain_grad_hidden_states_attentions(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.text_encoder.output_hidden_states = True
config.audio_encoder.output_hidden_states = True
config.decoder.output_hidden_states = True
config.text_encoder.output_attentions = True
config.decoder.output_attentions = True
# no need to test all models as different heads yield the same functionality
model_class = self.all_model_classes[0]
model = model_class(config)
model.to(torch_device)
inputs = self._prepare_for_class(inputs_dict, model_class)
outputs = model(**inputs)
output = outputs[0]
encoder_hidden_states = outputs.encoder_hidden_states[0]
encoder_hidden_states.retain_grad()
decoder_hidden_states = outputs.decoder_hidden_states[0]
decoder_hidden_states.retain_grad()
if self.has_attentions:
encoder_attentions = outputs.encoder_attentions[0]
encoder_attentions.retain_grad()
decoder_attentions = outputs.decoder_attentions[0]
decoder_attentions.retain_grad()
cross_attentions = outputs.cross_attentions[0]
cross_attentions.retain_grad()
output.flatten()[0].backward(retain_graph=True)
self.assertIsNotNone(encoder_hidden_states.grad)
self.assertIsNotNone(decoder_hidden_states.grad)
if self.has_attentions:
self.assertIsNotNone(encoder_attentions.grad)
self.assertIsNotNone(decoder_attentions.grad)
self.assertIsNotNone(cross_attentions.grad)
# override since changing `output_hidden_states` from the top-level model config won't work
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs.encoder_hidden_states
expected_num_layers = self.model_tester.num_hidden_layers + 1
self.assertEqual(len(hidden_states), expected_num_layers)
seq_length = self.model_tester.seq_length
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
[seq_length, self.model_tester.hidden_size],
)
hidden_states = outputs.decoder_hidden_states
self.assertIsInstance(hidden_states, (list, tuple))
self.assertEqual(len(hidden_states), expected_num_layers)
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
[seq_length, self.model_tester.hidden_size],
)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
inputs_dict["output_hidden_states"] = True
check_hidden_states_output(inputs_dict, config, model_class)
# check that output_hidden_states also work using config
del inputs_dict["output_hidden_states"]
config.text_encoder.output_hidden_states = True
config.audio_encoder.output_hidden_states = True
config.decoder.output_hidden_states = True
check_hidden_states_output(inputs_dict, config, model_class)
# override since the conv layers and lstm's in encodec are exceptions
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
configs_no_init = _config_zero_init(config)
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
for name, param in model.named_parameters():
uniform_init_parms = ["conv"]
ignore_init = ["lstm"]
if param.requires_grad:
if any([x in name for x in uniform_init_parms]):
self.assertTrue(
-1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
elif not any([x in name for x in ignore_init]):
self.assertIn(
((param.data.mean() * 1e9).round() / 1e9).item(),
[0.0, 1.0],
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
# override since we have embeddings / LM heads over multiple codebooks
def test_model_common_attributes(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
self.assertIsInstance(model.get_input_embeddings(), torch.nn.Embedding)
lm_heads = model.get_output_embeddings()
self.assertTrue(lm_heads is None or isinstance(lm_heads[0], torch.nn.Linear))
def _get_input_ids_and_config(self, batch_size=2):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict["input_ids"]
# take max batch_size
sequence_length = input_ids.shape[-1]
input_ids = input_ids[:batch_size, :]
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
# generate max 3 tokens
decoder_input_ids = inputs_dict["decoder_input_ids"]
max_length = decoder_input_ids.shape[-1] + 3
decoder_input_ids = decoder_input_ids[: batch_size * config.decoder.num_codebooks, :]
return config, input_ids, attention_mask, decoder_input_ids, max_length
# override since the `input_ids` cannot be used as the `decoder_input_ids` for musicgen (input / outputs are
# different modalities -> different shapes)
def _greedy_generate(
self,
model,
input_ids,
attention_mask,
decoder_input_ids,
max_length,
output_scores=False,
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
):
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
input_ids.shape[-1],
eos_token_id=model.config.eos_token_id,
forced_bos_token_id=model.config.forced_bos_token_id,
forced_eos_token_id=model.config.forced_eos_token_id,
max_length=max_length,
)
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
output_generate = model.generate(
input_ids,
do_sample=False,
num_beams=1,
max_length=max_length,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_scores=output_scores,
return_dict_in_generate=return_dict_in_generate,
remove_invalid_values=True,
**logits_process_kwargs,
**model_kwargs,
)
encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
model,
input_ids,
attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
with torch.no_grad():
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
output_greedy = model.greedy_search(
decoder_input_ids,
max_length=max_length,
logits_processor=logits_processor,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_scores=output_scores,
return_dict_in_generate=return_dict_in_generate,
encoder_outputs=encoder_outputs,
**model_kwargs,
)
return output_greedy, output_generate
# override since the `input_ids` cannot be used as the `decoder_input_ids` for musicgen (input / outputs are
# different modalities -> different shapes)
def _sample_generate(
self,
model,
input_ids,
attention_mask,
decoder_input_ids,
max_length,
num_return_sequences,
logits_processor,
logits_warper,
logits_warper_kwargs,
process_kwargs,
output_scores=False,
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
):
torch.manual_seed(0)
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
output_generate = model.generate(
input_ids,
do_sample=True,
num_beams=1,
max_length=max_length,
num_return_sequences=num_return_sequences,
output_scores=output_scores,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
remove_invalid_values=True,
**logits_warper_kwargs,
**process_kwargs,
**model_kwargs,
)
torch.manual_seed(0)
encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
model,
input_ids,
attention_mask,
num_interleave=num_return_sequences,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
# prevent flaky generation test failures
logits_processor.append(InfNanRemoveLogitsProcessor())
with torch.no_grad():
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
output_sample = model.sample(
decoder_input_ids.repeat_interleave(num_return_sequences, dim=0),
max_length=max_length,
logits_processor=logits_processor,
logits_warper=logits_warper,
output_scores=output_scores,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
encoder_outputs=encoder_outputs,
**model_kwargs,
)
return output_sample, output_generate
@staticmethod
def _get_logits_processor_and_kwargs(
input_length,
eos_token_id,
forced_bos_token_id=None,
forced_eos_token_id=None,
max_length=None,
diversity_penalty=None,
):
process_kwargs = {
"min_length": input_length + 1 if max_length is None else max_length - 1,
}
logits_processor = LogitsProcessorList()
return process_kwargs, logits_processor
def test_greedy_generate_dict_outputs(self):
for model_class in self.greedy_sample_model_classes:
# disable cache
config, input_ids, attention_mask, decoder_input_ids, max_length = self._get_input_ids_and_config()
config.use_cache = False
model = model_class(config).to(torch_device).eval()
output_greedy, output_generate = self._greedy_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
max_length=max_length,
output_scores=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
self.assertIsInstance(output_greedy, GreedySearchEncoderDecoderOutput)
self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput)
self.assertNotIn(config.pad_token_id, output_generate)
def test_greedy_generate_dict_outputs_use_cache(self):
for model_class in self.greedy_sample_model_classes:
# enable cache
config, input_ids, attention_mask, decoder_input_ids, max_length = self._get_input_ids_and_config()
config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
output_greedy, output_generate = self._greedy_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
max_length=max_length,
output_scores=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
self.assertIsInstance(output_greedy, GreedySearchEncoderDecoderOutput)
self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput)
def test_sample_generate(self):
for model_class in self.greedy_sample_model_classes:
config, input_ids, attention_mask, decoder_input_ids, max_length = self._get_input_ids_and_config()
model = model_class(config).to(torch_device).eval()
process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
input_ids.shape[-1],
model.config.eos_token_id,
forced_bos_token_id=model.config.forced_bos_token_id,
forced_eos_token_id=model.config.forced_eos_token_id,
max_length=max_length,
)
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=2)
# check `generate()` and `sample()` are equal
output_sample, output_generate = self._sample_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
max_length=max_length,
num_return_sequences=1,
logits_processor=logits_processor,
logits_warper=logits_warper,
logits_warper_kwargs=logits_warper_kwargs,
process_kwargs=process_kwargs,
)
self.assertIsInstance(output_sample, torch.Tensor)
self.assertIsInstance(output_generate, torch.Tensor)
def test_sample_generate_dict_output(self):
for model_class in self.greedy_sample_model_classes:
# disable cache
config, input_ids, attention_mask, decoder_input_ids, max_length = self._get_input_ids_and_config()
config.use_cache = False
model = model_class(config).to(torch_device).eval()
process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
input_ids.shape[-1],
model.config.eos_token_id,
forced_bos_token_id=model.config.forced_bos_token_id,
forced_eos_token_id=model.config.forced_eos_token_id,
max_length=max_length,
)
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
output_sample, output_generate = self._sample_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
max_length=max_length,
num_return_sequences=3,
logits_processor=logits_processor,
logits_warper=logits_warper,
logits_warper_kwargs=logits_warper_kwargs,
process_kwargs=process_kwargs,
output_scores=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
self.assertIsInstance(output_sample, SampleEncoderDecoderOutput)
self.assertIsInstance(output_generate, SampleEncoderDecoderOutput)
def test_generate_without_input_ids(self):
config, _, _, _, max_length = self._get_input_ids_and_config()
# if no bos token id => cannot generate from None
if config.bos_token_id is None:
return
for model_class in self.greedy_sample_model_classes:
model = model_class(config).to(torch_device)
model.eval()
output_ids_generate = model.generate(do_sample=False, max_length=max_length, remove_invalid_values=True)
self.assertIsNotNone(output_ids_generate)
def test_generate_fp16(self):
config, input_dict = self.model_tester.prepare_config_and_inputs()
for model_class in self.greedy_sample_model_classes:
model = model_class(config).eval().to(torch_device)
if torch_device == "cuda":
model.half()
model.generate(**input_dict, max_new_tokens=10)
model.generate(**input_dict, do_sample=True, max_new_tokens=10)
def get_bip_bip(bip_duration=0.125, duration=0.5, sample_rate=32000):
"""Produces a series of 'bip bip' sounds at a given frequency."""
timesteps = np.arange(int(duration * sample_rate)) / sample_rate
wav = np.cos(2 * math.pi * 440 * timesteps)
time_period = (timesteps % (2 * bip_duration)) / (2 * bip_duration)
envelope = time_period >= 0.5
return wav * envelope
def place_dict_on_device(dict_to_place, device):
for key in dict_to_place:
if dict_to_place[key] is not None and isinstance(dict_to_place[key], torch.Tensor):
dict_to_place[key] = dict_to_place[key].to(device)
return dict_to_place
@require_torch
class MusicgenIntegrationTests(unittest.TestCase):
@cached_property
def model(self):
return MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small").to(torch_device)
@cached_property
def processor(self):
return MusicgenProcessor.from_pretrained("facebook/musicgen-small")
@slow
def test_logits_text_prompt(self):
model = self.model
processor = self.processor
inputs = processor(text=["80s music", "Club techno"], padding=True, return_tensors="pt")
# prepare the encoder inputs
input_ids = inputs.input_ids.to(torch_device)
attention_mask = inputs.attention_mask.to(torch_device)
# prepare the decoder inputs
pad_token_id = model.generation_config.pad_token_id
decoder_input_ids = (
torch.ones((input_ids.shape[0] * model.decoder.num_codebooks, 1), dtype=torch.long).to(torch_device)
* pad_token_id
)
with torch.no_grad():
logits = model(
input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
).logits
# fmt: off
EXPECTED_LOGITS = torch.tensor(
[
-0.9708, -3.0149, -4.6415, -1.4754, -0.2786, -2.3523, -2.6049, -6.7467,
-1.0206, -3.2984, -3.3968, -1.5108, -1.5786, -3.1493, -1.1503, -0.0545,
]
)
# fmt: on
self.assertTrue(logits.shape == (*decoder_input_ids.shape, model.decoder.config.vocab_size))
self.assertTrue(torch.allclose(logits[0, 0, :16].cpu(), EXPECTED_LOGITS, atol=1e-4))
@slow
def test_logits_text_audio_prompt(self):
model = self.model
processor = self.processor
audio = [get_bip_bip(duration=0.5), get_bip_bip(duration=1.0)]
text = ["80s music", "Club techno"]
inputs = processor(audio=audio, text=text, padding=True, return_tensors="pt")
# prepare the text encoder inputs
input_ids = inputs.input_ids.to(torch_device)
attention_mask = inputs.attention_mask.to(torch_device)
# prepare the audio encoder inputs
input_values = inputs.input_values.to(torch_device)
padding_mask = inputs.padding_mask.to(torch_device)
with torch.no_grad():
logits = model(
input_ids,
attention_mask=attention_mask,
input_values=input_values,
padding_mask=padding_mask,
).logits
# fmt: off
EXPECTED_LOGITS = torch.tensor(
[
0.1841, -2.9324, -0.7898, 0.1857, 0.4971, -2.8685, -1.6525, -1.6541,
2.7757, -2.5942, -3.0959, -1.0120, -1.0147, -0.4605, -0.8885, 0.6820,
]
)
# fmt: on
self.assertTrue(logits.shape == (8, 50, 2048))
self.assertTrue(torch.allclose(logits[0, -1, :16].cpu(), EXPECTED_LOGITS, atol=1e-4))
@slow
def test_generate_unconditional_greedy(self):
model = self.model
# only generate 1 sample with greedy - since it's deterministic all elements of the batch will be the same
unconditional_inputs = model.get_unconditional_inputs(num_samples=1)
unconditional_inputs = place_dict_on_device(unconditional_inputs, device=torch_device)
output_values = model.generate(**unconditional_inputs, do_sample=False, max_new_tokens=5)
# fmt: off
EXPECTED_VALUES = torch.tensor(
[
0.0056, 0.0064, 0.0063, 0.0054, 0.0042, 0.0033, 0.0024, 0.0015,
0.0015, 0.0010, 0.0004, -0.0012, -0.0036, -0.0055, -0.0067, -0.0071,
]
)
# fmt: on
self.assertTrue(output_values.shape == (1, 1, 3200))
self.assertTrue(torch.allclose(output_values[0, 0, :16].cpu(), EXPECTED_VALUES, atol=1e-4))
@slow
def test_generate_unconditional_sampling(self):
model = self.model
# for stochastic sampling we can generate multiple outputs
unconditional_inputs = model.get_unconditional_inputs(num_samples=2)
unconditional_inputs = place_dict_on_device(unconditional_inputs, device=torch_device)
set_seed(0)
output_values = model.generate(**unconditional_inputs, do_sample=True, max_new_tokens=10)
# fmt: off
EXPECTED_VALUES = torch.tensor(
[
0.0765, 0.0758, 0.0749, 0.0759, 0.0759, 0.0771, 0.0775, 0.0760,
0.0762, 0.0765, 0.0767, 0.0760, 0.0738, 0.0714, 0.0713, 0.0730,
]
)
# fmt: on
self.assertTrue(output_values.shape == (2, 1, 4480))
self.assertTrue(torch.allclose(output_values[0, 0, :16].cpu(), EXPECTED_VALUES, atol=1e-4))
@slow
def test_generate_text_prompt_greedy(self):
model = self.model
processor = self.processor
inputs = processor(text=["80s music", "Club techno"], padding=True, return_tensors="pt")
# prepare the encoder inputs
input_ids = inputs.input_ids.to(torch_device)
attention_mask = inputs.attention_mask.to(torch_device)
output_values = model.generate(
input_ids, attention_mask=attention_mask, do_sample=False, guidance_scale=None, max_new_tokens=10
)
# fmt: off
EXPECTED_VALUES = torch.tensor(
[
-1.1998e-04, -2.2302e-04, 4.6296e-04, 1.0524e-03, 2.4827e-04,
-4.0288e-05, -1.2468e-04, 4.9846e-05, 7.1485e-04, 4.4197e-04,
]
)
# fmt: on
self.assertTrue(output_values.shape == (2, 1, 4480))
self.assertTrue(torch.allclose(output_values[0, 0, :10].cpu(), EXPECTED_VALUES, atol=1e-4))
@slow
def test_generate_text_prompt_greedy_with_classifier_free_guidance(self):
model = self.model
processor = self.processor
inputs = processor(text=["80s music", "Club techno"], padding=True, return_tensors="pt")
# prepare the encoder inputs
input_ids = inputs.input_ids.to(torch_device)
attention_mask = inputs.attention_mask.to(torch_device)
output_values = model.generate(
input_ids, attention_mask=attention_mask, do_sample=False, guidance_scale=3, max_new_tokens=10
)
# fmt: off
EXPECTED_VALUES = torch.tensor(
[
0.0283, 0.0246, 0.0650, 0.0640, 0.0599, 0.0711, 0.0420, 0.0112,
0.0511, 0.0746, 0.1363, 0.1213, 0.0185, -0.0578, -0.0908, 0.0443,
]
)
# fmt: on
self.assertTrue(output_values.shape == (2, 1, 4480))
self.assertTrue(torch.allclose(output_values[0, 0, :16].cpu(), EXPECTED_VALUES, atol=1e-4))
@slow
def test_generate_text_prompt_sampling(self):
model = self.model
processor = self.processor
inputs = processor(text=["80s music", "Club techno"], padding=True, return_tensors="pt")
# prepare the encoder inputs
input_ids = inputs.input_ids.to(torch_device)
attention_mask = inputs.attention_mask.to(torch_device)
set_seed(0)
output_values = model.generate(
input_ids, attention_mask=attention_mask, do_sample=True, guidance_scale=None, max_new_tokens=10
)
# fmt: off
EXPECTED_VALUES = torch.tensor(
[
-0.0047, -0.0094, -0.0028, -0.0018, -0.0057, -0.0007, -0.0104, -0.0211,
-0.0097, -0.0150, -0.0066, -0.0004, -0.0201, -0.0325, -0.0326, -0.0098,
]
)
# fmt: on
self.assertTrue(output_values.shape == (2, 1, 4480))
self.assertTrue(torch.allclose(output_values[0, 0, :16].cpu(), EXPECTED_VALUES, atol=1e-4))
@slow
def test_generate_text_audio_prompt(self):
model = self.model
processor = self.processor
audio = [get_bip_bip(duration=0.5), get_bip_bip(duration=1.0)]
text = ["80s music", "Club techno"]
inputs = processor(audio=audio, text=text, padding=True, return_tensors="pt")
inputs = place_dict_on_device(inputs, device=torch_device)
output_values = model.generate(**inputs, do_sample=False, guidance_scale=None, max_new_tokens=10)
# fmt: off
EXPECTED_VALUES = torch.tensor(
[
-0.0036, -0.0130, -0.0261, -0.0384, -0.0557, -0.0718, -0.0680, -0.0632,
-0.0529, -0.0403, -0.0289, -0.0198, -0.0136, -0.0101, -0.0095, -0.0040,
]
)
# fmt: on
self.assertTrue(
output_values.shape == (2, 1, 36480)
) # the audio prompt is 32000 samples long and generation continues from there
self.assertTrue(torch.allclose(output_values[0, 0, -16:].cpu(), EXPECTED_VALUES, atol=1e-4))
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the MusicGen processor."""
import random
import shutil
import tempfile
import unittest
import numpy as np
from transformers import T5Tokenizer, T5TokenizerFast
from transformers.testing_utils import require_sentencepiece, require_torch
from transformers.utils.import_utils import is_speech_available, is_torch_available
if is_torch_available():
pass
if is_speech_available():
from transformers import EncodecFeatureExtractor, MusicgenProcessor
global_rng = random.Random()
def floats_list(shape, scale=1.0, rng=None, name=None):
"""Creates a random float32 tensor"""
if rng is None:
rng = global_rng
values = []
for batch_idx in range(shape[0]):
values.append([])
for _ in range(shape[1]):
values[-1].append(rng.random() * scale)
return values
@require_torch
@require_sentencepiece
class MusicgenProcessorTest(unittest.TestCase):
def setUp(self):
self.checkpoint = "facebook/musicgen-small"
self.tmpdirname = tempfile.mkdtemp()
def get_tokenizer(self, **kwargs):
return T5Tokenizer.from_pretrained(self.checkpoint, **kwargs)
def get_feature_extractor(self, **kwargs):
return EncodecFeatureExtractor.from_pretrained(self.checkpoint, **kwargs)
def tearDown(self):
shutil.rmtree(self.tmpdirname)
def test_save_load_pretrained_default(self):
tokenizer = self.get_tokenizer()
feature_extractor = self.get_feature_extractor()
processor = MusicgenProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
processor.save_pretrained(self.tmpdirname)
processor = MusicgenProcessor.from_pretrained(self.tmpdirname)
self.assertEqual(processor.tokenizer.get_vocab(), tokenizer.get_vocab())
self.assertIsInstance(processor.tokenizer, T5TokenizerFast)
self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor.to_json_string())
self.assertIsInstance(processor.feature_extractor, EncodecFeatureExtractor)
def test_save_load_pretrained_additional_features(self):
processor = MusicgenProcessor(tokenizer=self.get_tokenizer(), feature_extractor=self.get_feature_extractor())
processor.save_pretrained(self.tmpdirname)
tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)")
feature_extractor_add_kwargs = self.get_feature_extractor(do_normalize=False, padding_value=1.0)
processor = MusicgenProcessor.from_pretrained(
self.tmpdirname, bos_token="(BOS)", eos_token="(EOS)", do_normalize=False, padding_value=1.0
)
self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
self.assertIsInstance(processor.tokenizer, T5TokenizerFast)
self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
self.assertIsInstance(processor.feature_extractor, EncodecFeatureExtractor)
def test_feature_extractor(self):
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()
processor = MusicgenProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
raw_speech = floats_list((3, 1000))
input_feat_extract = feature_extractor(raw_speech, return_tensors="np")
input_processor = processor(raw_speech, return_tensors="np")
for key in input_feat_extract.keys():
self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)
def test_tokenizer(self):
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()
processor = MusicgenProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
input_str = "This is a test string"
encoded_processor = processor(text=input_str)
encoded_tok = tokenizer(input_str)
for key in encoded_tok.keys():
self.assertListEqual(encoded_tok[key], encoded_processor[key])
def test_tokenizer_decode(self):
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()
processor = MusicgenProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]]
decoded_processor = processor.batch_decode(sequences=predicted_ids)
decoded_tok = tokenizer.batch_decode(predicted_ids)
self.assertListEqual(decoded_tok, decoded_processor)
def test_model_input_names(self):
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()
processor = MusicgenProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
self.assertListEqual(
processor.model_input_names,
feature_extractor.model_input_names,
msg="`processor` and `feature_extractor` model input names do not match",
)
def test_decode_audio(self):
feature_extractor = self.get_feature_extractor(padding_side="left")
tokenizer = self.get_tokenizer()
processor = MusicgenProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
raw_speech = [floats_list((1, x))[0] for x in range(5, 20, 5)]
padding_mask = processor(raw_speech).padding_mask
generated_speech = np.asarray(floats_list((3, 20)))[:, None, :]
decoded_audios = processor.batch_decode(generated_speech, padding_mask=padding_mask)
self.assertIsInstance(decoded_audios, list)
for audio in decoded_audios:
self.assertIsInstance(audio, np.ndarray)
self.assertTrue(decoded_audios[0].shape == (1, 10))
self.assertTrue(decoded_audios[1].shape == (1, 15))
self.assertTrue(decoded_audios[2].shape == (1, 20))
...@@ -37,6 +37,7 @@ _re_checkpoint = re.compile(r"\[(.+?)\]\((https://huggingface\.co/.+?)\)")
CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK = {
"DecisionTransformerConfig",
"EncoderDecoderConfig",
"MusicgenConfig",
"RagConfig",
"SpeechEncoderDecoderConfig",
"TimmBackboneConfig",
...
...@@ -119,6 +119,7 @@ IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
"MegatronBertEncoder", # Building part of bigger (tested) model.
"MegatronBertDecoder", # Building part of bigger (tested) model.
"MegatronBertDecoderWrapper", # Building part of bigger (tested) model.
"MusicgenDecoder", # Building part of bigger (tested) model.
"MvpDecoderWrapper", # Building part of bigger (tested) model.
"MvpEncoder", # Building part of bigger (tested) model.
"PegasusEncoder", # Building part of bigger (tested) model.
...@@ -331,6 +332,8 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
"SpeechT5ForSpeechToSpeech",
"SpeechT5ForTextToSpeech",
"SpeechT5HifiGan",
"MusicgenModel",
"MusicgenForConditionalGeneration",
]
# Update this list for models that have multiple model types for the same
...
...@@ -134,6 +134,9 @@ src/transformers/models/mobilenet_v1/modeling_mobilenet_v1.py
src/transformers/models/mobilenet_v2/modeling_mobilenet_v2.py
src/transformers/models/mobilevit/modeling_mobilevit.py
src/transformers/models/mobilevit/modeling_tf_mobilevit.py
src/transformers/models/musicgen/modeling_musicgen.py
src/transformers/models/musicgen/configuration_musicgen.py
src/transformers/models/musicgen/processing_musicgen.py
src/transformers/models/mvp/configuration_mvp.py
src/transformers/models/nat/configuration_nat.py
src/transformers/models/nat/modeling_nat.py
...