Unverified commit 862f8418, authored by eustlb and committed by GitHub

Add static cache (#89)



* add rope

* don't include padding in rope

* possibly use cross-attn for prompt

* fix rope

* fix cross-attn

* fix self-attn

* fix dummy model

* clean-up rope

* first gqa implementation

* fix wer eval

* feat: add flash attention and spda

* chore: add README for flash attention

* chore: add benchmark script

* chore: add benchmark attention approach

* multi node and fix wer and fix compile

* Update modeling_parler_tts.py

* fix FA2, SDPA and add cross-attn MHA and attention type forcing

* better cross_attention key values number of heads default + add training arguments for attn implementation

* fix audio padding when torch compile or pad_to_max_length=True

* correct multi node

* make rope faster

* fix encoder sdpa

* fix training with cross attention + FA2

* use fp32 as default model dtype + fix generation when using FA2 with autocast

* remove redundant passes in generate + clean and fix attentions

* fix edge case in WER evaluation when longform generation

* better multi-node mapping and saving / add eval dataloader num workers

* remove old benchmarks

* faster audio encoding + checkpointing + fix generation step

* unpin trfms

* remove CFG

* imports and constants
Co-Authored-By: sang-nguyen-ts <sang-nguyen-ts@users.noreply.github.com>

* attention modifications to handle static cache
Co-Authored-By: sang-nguyen-ts <sang-nguyen-ts@users.noreply.github.com>

* decoder layer modification to handle static cache
Co-Authored-By: sang-nguyen-ts <sang-nguyen-ts@users.noreply.github.com>

* ParlerTTSPreTrainedModel modifs to handle static cache
Co-Authored-By: sang-nguyen-ts <sang-nguyen-ts@users.noreply.github.com>

* ParlerTTSDecoder modifs to handle static cache
Co-Authored-By: sang-nguyen-ts <sang-nguyen-ts@users.noreply.github.com>

* ParlerTTSModel + ParlerTTSForCausalLM modifs
Co-Authored-By: sang-nguyen-ts <sang-nguyen-ts@users.noreply.github.com>

* ParlerTTSForConditionalGeneration modifs
Co-Authored-By: sang-nguyen-ts <sang-nguyen-ts@users.noreply.github.com>

* decoder_attention_mask for static cache
Co-Authored-By: sang-nguyen-ts <sang-nguyen-ts@users.noreply.github.com>

* create inputs_embeds early to have a good cache initialization
Co-Authored-By: sang-nguyen-ts <sang-nguyen-ts@users.noreply.github.com>

* _get_cache method
Co-Authored-By: sang-nguyen-ts <sang-nguyen-ts@users.noreply.github.com>

* init the cache
Co-Authored-By: sang-nguyen-ts <sang-nguyen-ts@users.noreply.github.com>

* ensure good device
Co-Authored-By: sang-nguyen-ts <sang-nguyen-ts@users.noreply.github.com>

* pin tfrms version
Co-Authored-By: sang-nguyen-ts <sang-nguyen-ts@users.noreply.github.com>

* fix attention_mask FA2

* remove unnecessary method

* Update parler_tts/modeling_parler_tts.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update parler_tts/modeling_parler_tts.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* remove unnecessary imports

* replace the hardcoded cache_position with a more elegant approach

* make style

* unpin transformers

* pin transformers

* pin torch

* refactor + unpin torch

* Update parler_tts/modeling_parler_tts.py
Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com>

* update training script to match 11b209e1



* Update parler_tts/modeling_parler_tts.py
Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com>

* ensure compatibility with trfms 4.43.3, changes taken from #31980 on trfms

* fix input_ids_length

* warning full attention mask creation

* changes for training compatibility

---------
Co-authored-by: sanchit-gandhi <sanchit@huggingface.co>
Co-authored-by: Yoach Lacombe <yoach.lacombe@gmail.com>
Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com>
Co-authored-by: sang-nguyen-ts <sang.nguyen@trustingsocial.com>
Co-authored-by: Yoach Lacombe <yoach@huggingface.co>
Co-authored-by: sang-nguyen-ts <sang-nguyen-ts@users.noreply.github.com>
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
parent 11b209e1
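As a usage sketch (not part of the commit itself): the static cache added here would typically be driven the same way as in transformers, by setting `cache_implementation="static"` on the generation config before calling `generate`. The checkpoint name, the `prompt_input_ids` argument, and the commented-out `torch.compile` call below follow the public Parler-TTS README and are illustrative assumptions, not interfaces introduced by this commit.

import torch
from transformers import AutoTokenizer
from parler_tts import ParlerTTSForConditionalGeneration

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Checkpoint name assumed from the public Parler-TTS README.
model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler_tts_mini_v0.1").to(device)
tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler_tts_mini_v0.1")

description = "A female speaker delivers a slightly expressive and animated speech."
prompt = "Hey, how are you doing today?"

input_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
prompt_input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

# Assumption: the static cache is requested through the standard transformers
# generation-config field rather than a Parler-TTS-specific flag.
model.generation_config.cache_implementation = "static"

# Optional and unverified: with static shapes the decoder forward can be compiled,
# cf. the "fix audio padding when torch compile or pad_to_max_length=True" commit.
# model.forward = torch.compile(model.forward, mode="reduce-overhead")

generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
audio = generation.cpu().numpy().squeeze()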
import gradio as gr
import torch
+from transformers import AutoFeatureExtractor, AutoTokenizer, set_seed
from parler_tts import ParlerTTSForConditionalGeneration
-from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed
device = "cuda:0" if torch.cuda.is_available() else "cpu"

@@ -57,7 +58,7 @@ css = """
background-color: #000000;
justify-content: center;
align-items: center;
border-radius: 9999px !important;
width: 13rem;
margin-top: 10px;
margin-left: auto;
...
-from parler_tts import ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration, ParlerTTSDecoderConfig
-from transformers import AutoConfig
-import os
import argparse
+import os
+from transformers import AutoConfig
+from parler_tts import ParlerTTSDecoderConfig, ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration
if __name__ == "__main__":
...
-from parler_tts import ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration, ParlerTTSDecoderConfig
-from transformers import AutoConfig
-import os
import argparse
+import os
+from transformers import AutoConfig
+from parler_tts import ParlerTTSDecoderConfig, ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
...
-from parler_tts import ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration, ParlerTTSDecoderConfig
-from transformers import AutoConfig
-import os
import argparse
+import os
+from transformers import AutoConfig
+from parler_tts import ParlerTTSDecoderConfig, ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration
if __name__ == "__main__":
...
import dac
+from transformers import AutoConfig, AutoModel, EncodecFeatureExtractor
from parler_tts import DACConfig, DACModel
from transformers import AutoConfig, AutoModel
from transformers import EncodecFeatureExtractor
...
+from transformers import AutoFeatureExtractor, AutoTokenizer
from parler_tts import ParlerTTSForConditionalGeneration
-from transformers import AutoTokenizer, AutoFeatureExtractor
path = "TODO"
repo_id = "parler_tts_600M"
...
__version__ = "0.1"

+from transformers import AutoConfig, AutoModel
from .configuration_parler_tts import ParlerTTSConfig, ParlerTTSDecoderConfig
+from .dac_wrapper import DACConfig, DACModel
from .modeling_parler_tts import (
    ParlerTTSForCausalLM,
    ParlerTTSForConditionalGeneration,
@@ -9,8 +12,6 @@ from .modeling_parler_tts import (
    build_delay_pattern_mask,
)
-from .dac_wrapper import DACConfig, DACModel
-from transformers import AutoConfig, AutoModel

AutoConfig.register("dac", DACConfig)
AutoModel.register(DACConfig, DACModel)
from transformers import PretrainedConfig
-from typing import List

class DACConfig(PretrainedConfig):
...
import torch

+from dac.model import DAC
from transformers import PreTrainedModel
-from transformers.models.encodec.modeling_encodec import EncodecEncoderOutput, EncodecDecoderOutput
+from transformers.models.encodec.modeling_encodec import EncodecDecoderOutput, EncodecEncoderOutput
-from .configuration_dac import DACConfig
-from dac.model import DAC
+from .configuration_dac import DACConfig

# model doesn't support batching yet

@@ -134,4 +133,4 @@ class DACModel(PreTrainedModel):
        return EncodecDecoderOutput(audio_values)

    def forward(self, tensor):
-        raise ValueError(f"`DACModel.forward` not implemented yet")
+        raise ValueError("`DACModel.forward` not implemented yet")
@@ -18,21 +18,28 @@ import inspect
import math
import random
from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union, List
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
+import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers import AutoConfig, AutoModel, AutoModelForTextEncoding
from transformers.activations import ACT2FN
+from transformers.cache_utils import (
+    Cache,
+    DynamicCache,
+    EncoderDecoderCache,
+    SlidingWindowCache,
+    StaticCache,
+)
from transformers.generation.configuration_utils import GenerationConfig
-from transformers.generation.logits_process import ClassifierFreeGuidanceLogitsProcessor, LogitsProcessorList
+from transformers.generation.logits_process import LogitsProcessorList
from transformers.generation.stopping_criteria import StoppingCriteriaList
from transformers.modeling_attn_mask_utils import (
+    AttentionMaskConverter,
    _prepare_4d_attention_mask,
-    _prepare_4d_causal_attention_mask,
    _prepare_4d_attention_mask_for_sdpa,
-    _prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.modeling_outputs import (
    BaseModelOutput,

@@ -48,13 +55,11 @@ from transformers.utils import (
    logging,
    replace_return_docstrings,
)
-import torch.nn.functional as F
-from transformers.utils.import_utils import is_flash_attn_greater_or_equal_2_10, is_flash_attn_2_available
+from transformers.utils.import_utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10

from .configuration_parler_tts import ParlerTTSConfig, ParlerTTSDecoderConfig
from .dac_wrapper import DACConfig, DACModel
-from transformers import AutoConfig, AutoModel

AutoConfig.register("dac", DACConfig)
AutoModel.register(DACConfig, DACModel)

@@ -80,6 +85,9 @@ MUSICGEN_PRETRAINED_MODEL_ARCHIVE_LIST = [
]

+NEED_SETUP_CACHE_CLASSES_MAPPING = {"static": StaticCache, "sliding_window": SlidingWindowCache}
def apply_delay_pattern_mask(input_ids, decoder_pad_token_mask):
    """Apply a delay pattern mask to the decoder input ids, only preserving predictions where
    the mask is set to -1, and otherwise setting to the value detailed in the mask."""

@@ -176,14 +184,10 @@ class ParlerTTSUnconditionalInput(ModelOutput):
        attention_mask (`torch.LongTensor`) of shape `(batch_size, sequence_length)`, *optional*):
            Encoder attention mask to avoid performing attention on padding token indices. Mask values selected in `[0,
            1]`: 1 for tokens that are **not masked**, 0 for tokens that are **masked**.
-        guidance_scale (`float`, *optional*):
-            Guidance scale for classifier free guidance, setting the balance between the conditional logits (predicted
-            from the prompts) and the unconditional logits (predicted without prompts).
    """

    encoder_outputs: Tuple[torch.FloatTensor] = None
    attention_mask: torch.LongTensor = None
-    guidance_scale: float = None

# Copied from transformers.models.encoder_decoder.modeling_encoder_decoder.shift_tokens_right
@@ -244,7 +248,7 @@ class ParlerTTSSinusoidalPositionalEmbedding(nn.Module):
    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
        bsz, seq_len, _ = input_ids.size()
        # Create the position ids from the input token ids.
-        position_ids = (torch.arange(seq_len) + past_key_values_length).to(input_ids.device)
+        position_ids = torch.arange(seq_len, device=input_ids.device) + past_key_values_length
        # expand embeddings if needed
        if seq_len > self.weights.size(0):
            self.make_weights(seq_len + self.offset, self.embedding_dim)
@@ -331,6 +335,7 @@ class ParlerTTSAttention(nn.Module):
        bias: bool = True,
        is_causal: bool = False,
        rope_embeddings: bool = False,
+        layer_idx: Optional[int] = None,
        config: Optional[ParlerTTSDecoderConfig] = None,
    ):
        super().__init__()

@@ -351,6 +356,14 @@ class ParlerTTSAttention(nn.Module):
        self.is_decoder = is_decoder
        self.is_causal = is_causal
+        if layer_idx is None and is_decoder:
+            logger.warning_once(
+                f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
+                "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
+                "when creating this class."
+            )
+        self.layer_idx = layer_idx
        self.k_proj = nn.Linear(embed_dim, self.num_key_value_heads * self.head_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, self.num_key_value_heads * self.head_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
...@@ -368,40 +381,42 @@ class ParlerTTSAttention(nn.Module): ...@@ -368,40 +381,42 @@ class ParlerTTSAttention(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None, key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[EncoderDecoderCache] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
cos: Optional[torch.LongTensor] = None, cos: Optional[torch.LongTensor] = None,
sin: Optional[torch.LongTensor] = None, sin: Optional[torch.LongTensor] = None,
layer_head_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False, output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel""" """Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer # if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder # for the decoder
is_cross_attention = key_value_states is not None is_cross_attention = key_value_states is not None
bsz, tgt_len = hidden_states.shape[:2] bsz, tgt_len = hidden_states.shape[:2]
# get query proj # get query proj
query_states = self.q_proj(hidden_states) * self.scaling query_states = self.q_proj(hidden_states) * self.scaling
query_states = self._shape_query(query_states, tgt_len, bsz) query_states = self._shape_query(query_states, tgt_len, bsz)
if self.rope_embeddings: if self.rope_embeddings:
query_states = apply_rotary_pos_emb(query_states, cos, sin) query_states = apply_rotary_pos_emb(query_states, cos, sin)
current_states = key_value_states if key_value_states is not None else hidden_states if past_key_value is not None:
is_updated = past_key_value.is_updated.get(self.layer_idx)
if is_cross_attention:
# after the first generated id, we can subsequently re-use all key/value_states from cache
past_key_value.is_updated[self.layer_idx] = True
past_key_value = past_key_value.cross_attention_cache
else:
past_key_value = past_key_value.self_attention_cache
# checking that the `sequence_length` of the `past_key_value` is the same as # use key_value_states if cross attention
# the provided `key_value_states` to support prefix tuning current_states = key_value_states if key_value_states is not None else hidden_states
if ( if is_cross_attention and past_key_value and is_updated:
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions # reuse k,v, cross_attentions
key_states = past_key_value[0] key_states = past_key_value.key_cache[self.layer_idx]
value_states = past_key_value[1] value_states = past_key_value.value_cache[self.layer_idx]
else: else:
key_states = self._shape_key_value(self.k_proj(current_states), -1, bsz) key_states = self._shape_key_value(self.k_proj(current_states), -1, bsz)
value_states = self._shape_key_value(self.v_proj(current_states), -1, bsz) value_states = self._shape_key_value(self.v_proj(current_states), -1, bsz)
...@@ -410,45 +425,21 @@ class ParlerTTSAttention(nn.Module): ...@@ -410,45 +425,21 @@ class ParlerTTSAttention(nn.Module):
# cached key states already have rope applied - only apply to new state # cached key states already have rope applied - only apply to new state
key_states = apply_rotary_pos_emb(key_states, cos, sin) if self.rope_embeddings else key_states key_states = apply_rotary_pos_emb(key_states, cos, sin) if self.rope_embeddings else key_states
if past_key_value is not None: if past_key_value is not None:
key_states = torch.cat([past_key_value[0], key_states], dim=2) # save all key/value_states to cache to be re-used for fast auto-regressive generation
value_states = torch.cat([past_key_value[1], value_states], dim=2) cache_position = cache_position if not is_cross_attention else None
key_states, value_states = past_key_value.update(
if self.is_decoder: key_states, value_states, self.layer_idx, {"cache_position": cache_position}
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. )
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups) key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups)
proj_shape = (bsz * self.num_heads, -1, self.head_dim) attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
query_states = query_states.reshape(*proj_shape)
key_states = key_states.reshape(*proj_shape)
value_states = value_states.reshape(*proj_shape)
src_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): if attention_mask is not None: # no matter the length, we just slice it
raise ValueError( causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" attn_weights = attn_weights + causal_mask
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = nn.functional.softmax(attn_weights, dim=-1) attn_weights = nn.functional.softmax(attn_weights, dim=-1)
...@@ -458,35 +449,25 @@ class ParlerTTSAttention(nn.Module): ...@@ -458,35 +449,25 @@ class ParlerTTSAttention(nn.Module):
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
f" {layer_head_mask.size()}" f" {layer_head_mask.size()}"
) )
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if output_attentions:
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
else:
attn_weights_reshaped = None
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.matmul(attn_probs, value_states)
attn_output = torch.bmm(attn_probs, value_states) if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError( raise ValueError(
f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {attn_output.size()}" f" {attn_output.size()}"
) )
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2) attn_output = attn_output.transpose(1, 2)
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned across GPUs when using tensor-parallelism. # partitioned across GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
attn_output = self.out_proj(attn_output) attn_output = self.out_proj(attn_output)
return attn_output, attn_weights_reshaped, past_key_value return attn_output, attn_weights, past_key_value
def _get_unpad_data(attention_mask): def _get_unpad_data(attention_mask):
...@@ -522,64 +503,66 @@ class ParlerTTSFlashAttention2(ParlerTTSAttention): ...@@ -522,64 +503,66 @@ class ParlerTTSFlashAttention2(ParlerTTSAttention):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None, key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[EncoderDecoderCache] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
cos: Optional[torch.LongTensor] = None, cos: Optional[torch.LongTensor] = None,
sin: Optional[torch.LongTensor] = None, sin: Optional[torch.LongTensor] = None,
layer_head_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False, output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# ParlerTTSFlashAttention2 attention does not support output_attentions # ParlerTTSFlashAttention2 attention does not support output_attentions
if output_attentions: if isinstance(past_key_value, StaticCache):
raise ValueError("ParlerTTSFlashAttention2 attention does not support output_attentions") raise ValueError(
"The `static` cache implementation is not compatible with `attn_implementation='flash_attention_2'`. "
"Use `attn_implementation='sdpa'` in the meantime, and open an issue at https://github.com/huggingface/transformers"
)
# if key_value_states are provided this layer is used as a cross-attention layer # if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder # for the decoder
is_cross_attention = key_value_states is not None is_cross_attention = key_value_states is not None
bsz, q_len = hidden_states.shape[:2] bsz, tgt_len = hidden_states.shape[:2]
# get query proj # get query proj
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) query_states = self.q_proj(hidden_states).view(bsz, tgt_len, self.num_heads, self.head_dim)
if self.rope_embeddings: if self.rope_embeddings:
query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2) query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2)
current_states = key_value_states if key_value_states is not None else hidden_states if past_key_value is not None:
is_updated = past_key_value.is_updated.get(self.layer_idx)
if is_cross_attention:
# after the first generated id, we can subsequently re-use all key/value_states from cache
past_key_value.is_updated[self.layer_idx] = True
past_key_value = past_key_value.cross_attention_cache
else:
past_key_value = past_key_value.self_attention_cache
# checking that the `sequence_length` of the `past_key_value` is the same as # use key_value_states if cross attention
# the provided `key_value_states` to support prefix tuning current_states = key_value_states if key_value_states is not None else hidden_states
if ( if is_cross_attention and past_key_value and is_updated:
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions # reuse k,v, cross_attentions
key_states = past_key_value[0].transpose(1, 2) key_states = past_key_value.key_cache[self.layer_idx]
value_states = past_key_value[1].transpose(1, 2) value_states = past_key_value.value_cache[self.layer_idx]
else: else:
key_states = self.k_proj(current_states).view(bsz, -1, self.num_key_value_heads, self.head_dim) key_states = self._shape_key_value(self.k_proj(current_states), -1, bsz)
value_states = self.v_proj(current_states).view(bsz, -1, self.num_key_value_heads, self.head_dim) value_states = self._shape_key_value(self.v_proj(current_states), -1, bsz)
if not is_cross_attention: if not is_cross_attention and self.rope_embeddings:
# cached key states already have rope applied - only apply to new state # cached key states already have rope applied - only apply to new state
key_states = ( key_states = apply_rotary_pos_emb(key_states, cos, sin)
apply_rotary_pos_emb(key_states, cos, sin, unsqueeze_dim=2) if self.rope_embeddings else key_states
)
if past_key_value is not None: if past_key_value is not None:
key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) # save all key/value_states to cache to be re-used for fast auto-regressive generation
value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) cache_position = cache_position if not is_cross_attention else None
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
)
if self.is_decoder: # # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. # # We would need to refactor the KV cache to be able to avoid many of these transpose/reshape/view.
# Further calls to cross_attention layer can then reuse all cross-attention key_states = key_states.transpose(1, 2)
# key/value_states (first "if" case) value_states = value_states.transpose(1, 2)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
# In PEFT, usually we cast the layer norms in float32 for training stability reasons # In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need # therefore the input hidden states gets silently casted in float32. Hence, we need
...@@ -607,10 +590,10 @@ class ParlerTTSFlashAttention2(ParlerTTSAttention): ...@@ -607,10 +590,10 @@ class ParlerTTSFlashAttention2(ParlerTTSAttention):
value_states = value_states.to(target_dtype) value_states = value_states.to(target_dtype)
attn_output = self._flash_attention_forward( attn_output = self._flash_attention_forward(
query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout query_states, key_states, value_states, attention_mask, tgt_len, dropout=self.dropout
) )
attn_output = attn_output.reshape(bsz, q_len, -1) attn_output = attn_output.reshape(bsz, tgt_len, -1)
attn_output = self.out_proj(attn_output) attn_output = self.out_proj(attn_output)
if not output_attentions: if not output_attentions:
...@@ -723,12 +706,13 @@ class ParlerTTSSdpaAttention(ParlerTTSAttention): ...@@ -723,12 +706,13 @@ class ParlerTTSSdpaAttention(ParlerTTSAttention):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None, key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[EncoderDecoderCache] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
cos: Optional[torch.LongTensor] = None, cos: Optional[torch.LongTensor] = None,
sin: Optional[torch.LongTensor] = None, sin: Optional[torch.LongTensor] = None,
layer_head_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False, output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel""" """Input shape: Batch x Time x Channel"""
if output_attentions or layer_head_mask is not None: if output_attentions or layer_head_mask is not None:
...@@ -744,6 +728,7 @@ class ParlerTTSSdpaAttention(ParlerTTSAttention): ...@@ -744,6 +728,7 @@ class ParlerTTSSdpaAttention(ParlerTTSAttention):
attention_mask=attention_mask, attention_mask=attention_mask,
layer_head_mask=layer_head_mask, layer_head_mask=layer_head_mask,
output_attentions=output_attentions, output_attentions=output_attentions,
cache_position=cache_position,
) )
# if key_value_states are provided this layer is used as a cross-attention layer # if key_value_states are provided this layer is used as a cross-attention layer
...@@ -759,39 +744,39 @@ class ParlerTTSSdpaAttention(ParlerTTSAttention): ...@@ -759,39 +744,39 @@ class ParlerTTSSdpaAttention(ParlerTTSAttention):
if self.rope_embeddings: if self.rope_embeddings:
query_states = apply_rotary_pos_emb(query_states, cos, sin) query_states = apply_rotary_pos_emb(query_states, cos, sin)
current_states = key_value_states if key_value_states is not None else hidden_states if past_key_value is not None:
is_updated = past_key_value.is_updated.get(self.layer_idx)
if is_cross_attention:
# after the first generated id, we can subsequently re-use all key/value_states from cache
past_key_value.is_updated[self.layer_idx] = True
past_key_value = past_key_value.cross_attention_cache
else:
past_key_value = past_key_value.self_attention_cache
# checking that the `sequence_length` of the `past_key_value` is the same as # use key_value_states if cross attention
# the provided `key_value_states` to support prefix tuning current_states = key_value_states if key_value_states is not None else hidden_states
if ( if is_cross_attention and past_key_value and is_updated:
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions # reuse k,v, cross_attentions
key_states = past_key_value[0] key_states = past_key_value.key_cache[self.layer_idx]
value_states = past_key_value[1] value_states = past_key_value.value_cache[self.layer_idx]
else: else:
key_states = self._shape_key_value(self.k_proj(current_states), -1, bsz) key_states = self._shape_key_value(self.k_proj(current_states), -1, bsz)
value_states = self._shape_key_value(self.v_proj(current_states), -1, bsz) value_states = self._shape_key_value(self.v_proj(current_states), -1, bsz)
if not is_cross_attention: if not is_cross_attention and self.rope_embeddings:
# cached key states already have rope applied - only apply to new state # cached key states already have rope applied - only apply to new state
key_states = apply_rotary_pos_emb(key_states, cos, sin) if self.rope_embeddings else key_states key_states = apply_rotary_pos_emb(key_states, cos, sin)
if past_key_value is not None: if past_key_value is not None:
key_states = torch.cat([past_key_value[0], key_states], dim=2) # save all key/value_states to cache to be re-used for fast auto-regressive generation
value_states = torch.cat([past_key_value[1], value_states], dim=2) cache_position = cache_position if not is_cross_attention else None
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
)
if self.is_decoder: causal_mask = attention_mask
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. if attention_mask is not None: # no matter the length, we just slice it
# Further calls to cross_attention layer can then reuse all cross-attention causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
# repeat k/v heads if n_kv_heads < n_heads # repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups) key_states = repeat_kv(key_states, self.num_key_value_groups)
...@@ -800,7 +785,7 @@ class ParlerTTSSdpaAttention(ParlerTTSAttention): ...@@ -800,7 +785,7 @@ class ParlerTTSSdpaAttention(ParlerTTSAttention):
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False is_causal = True if self.is_causal and causal_mask is None and tgt_len > 1 else False
# NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
# but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
...@@ -808,7 +793,7 @@ class ParlerTTSSdpaAttention(ParlerTTSAttention): ...@@ -808,7 +793,7 @@ class ParlerTTSSdpaAttention(ParlerTTSAttention):
query_states, query_states,
key_states, key_states,
value_states, value_states,
attn_mask=attention_mask, attn_mask=causal_mask,
dropout_p=self.dropout if self.training else 0.0, dropout_p=self.dropout if self.training else 0.0,
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
is_causal=is_causal, is_causal=is_causal,
...@@ -839,7 +824,7 @@ PARLERTTS_ATTENTION_CLASSES = { ...@@ -839,7 +824,7 @@ PARLERTTS_ATTENTION_CLASSES = {
class ParlerTTSDecoderLayer(nn.Module): class ParlerTTSDecoderLayer(nn.Module):
def __init__(self, config: ParlerTTSDecoderConfig): def __init__(self, config: ParlerTTSDecoderConfig, layer_idx: int = None):
super().__init__() super().__init__()
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
...@@ -852,6 +837,7 @@ class ParlerTTSDecoderLayer(nn.Module): ...@@ -852,6 +837,7 @@ class ParlerTTSDecoderLayer(nn.Module):
is_causal=True, is_causal=True,
bias=False, bias=False,
rope_embeddings=config.rope_embeddings, rope_embeddings=config.rope_embeddings,
layer_idx=layer_idx,
config=config, config=config,
) )
self.dropout = config.dropout self.dropout = config.dropout
...@@ -872,6 +858,7 @@ class ParlerTTSDecoderLayer(nn.Module): ...@@ -872,6 +858,7 @@ class ParlerTTSDecoderLayer(nn.Module):
is_decoder=True, is_decoder=True,
bias=False, bias=False,
rope_embeddings=config.rope_embeddings, rope_embeddings=config.rope_embeddings,
layer_idx=layer_idx,
config=config, config=config,
) )
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
...@@ -889,9 +876,10 @@ class ParlerTTSDecoderLayer(nn.Module): ...@@ -889,9 +876,10 @@ class ParlerTTSDecoderLayer(nn.Module):
encoder_attention_mask: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None,
cross_attn_layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[EncoderDecoderCache] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = True, use_cache: Optional[bool] = True,
cache_position: Optional[torch.LongTensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Args: Args:
...@@ -918,31 +906,24 @@ class ParlerTTSDecoderLayer(nn.Module): ...@@ -918,31 +906,24 @@ class ParlerTTSDecoderLayer(nn.Module):
hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states = self.self_attn_layer_norm(hidden_states)
# Self Attention # Self Attention
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
# add present self-attn cache to positions 1,2 of present_key_value tuple
hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states, hidden_states=hidden_states,
past_key_value=self_attn_past_key_value, past_key_value=past_key_value,
attention_mask=attention_mask, attention_mask=attention_mask,
cos=cos, cos=cos,
sin=sin, sin=sin,
layer_head_mask=layer_head_mask, layer_head_mask=layer_head_mask,
output_attentions=output_attentions, output_attentions=output_attentions,
cache_position=cache_position,
) )
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
# Cross-Attention Block # Cross-Attention Block
cross_attn_present_key_value = None
cross_attn_weights = None cross_attn_weights = None
if encoder_hidden_states is not None: if encoder_hidden_states is not None:
residual = hidden_states residual = hidden_states
hidden_states = self.encoder_attn_layer_norm(hidden_states) hidden_states = self.encoder_attn_layer_norm(hidden_states)
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
hidden_states=hidden_states, hidden_states=hidden_states,
key_value_states=encoder_hidden_states, key_value_states=encoder_hidden_states,
...@@ -950,14 +931,14 @@ class ParlerTTSDecoderLayer(nn.Module): ...@@ -950,14 +931,14 @@ class ParlerTTSDecoderLayer(nn.Module):
cos=cos, cos=cos,
sin=sin, sin=sin,
layer_head_mask=cross_attn_layer_head_mask, layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value, past_key_value=past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
# add cross-attn to positions 3,4 of present_key_value tuple # add cross-attn to positions 1 of present_key_value tuple
present_key_value = present_key_value + cross_attn_present_key_value present_key_value = (present_key_value, cross_attn_present_key_value)
# Fully Connected # Fully Connected
residual = hidden_states residual = hidden_states
...@@ -992,6 +973,8 @@ class ParlerTTSPreTrainedModel(PreTrainedModel): ...@@ -992,6 +973,8 @@ class ParlerTTSPreTrainedModel(PreTrainedModel):
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
_supports_sdpa = True _supports_sdpa = True
_no_split_modules = ["ParlerTTSDecoderLayer", "ParlerTTSAttention"] _no_split_modules = ["ParlerTTSDecoderLayer", "ParlerTTSAttention"]
_supports_cache_class = True
_supports_static_cache = True
def _init_weights(self, module): def _init_weights(self, module):
std = self.config.initializer_factor std = self.config.initializer_factor
...@@ -1088,14 +1071,18 @@ MUSICGEN_INPUTS_DOCSTRING = r""" ...@@ -1088,14 +1071,18 @@ MUSICGEN_INPUTS_DOCSTRING = r"""
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
TODO: it's passed through enc_to_dec_proj and optionnally we concat the prompt hidden states in certain cases. TODO: it's passed through enc_to_dec_proj and optionnally we concat the prompt hidden states in certain cases.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): past_key_values (`EncoderDecoderCache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape Pre-computed hidden-states that can be used to speed up auto-regressive (sequential) decoding. There are
four sets of pre-computed hidden-states: key and values states in the self-attention blocks (2) and
in the cross-attention blocks (2). The `past_key_values` are returned when `use_cache=True` is passed or
when `config.use_cache=True`
Two formats are allowed:
- An [`~cache_utils.EncoderDecoderCache`] instance;
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`. `decoder_input_ids` of shape `(batch_size, sequence_length)`.
...@@ -1141,6 +1128,9 @@ MUSICGEN_INPUTS_DOCSTRING = r""" ...@@ -1141,6 +1128,9 @@ MUSICGEN_INPUTS_DOCSTRING = r"""
more detail. more detail.
return_dict (`bool`, *optional*): return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. It is used to update the cache
in the correct position and to infer the complete sequence length.
""" """
MUSICGEN_DECODER_INPUTS_DOCSTRING = r""" MUSICGEN_DECODER_INPUTS_DOCSTRING = r"""
...@@ -1263,8 +1253,9 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel): ...@@ -1263,8 +1253,9 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel):
max_position_embeddings=config.max_position_embeddings, max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta, base=config.rope_theta,
) )
self.layers = nn.ModuleList(
self.layers = nn.ModuleList([ParlerTTSDecoderLayer(config) for _ in range(config.num_hidden_layers)]) [ParlerTTSDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.layer_norm = nn.LayerNorm(config.hidden_size) self.layer_norm = nn.LayerNorm(config.hidden_size)
self.attn_implementation = config._attn_implementation self.attn_implementation = config._attn_implementation
encoder_attn_implementation = config._attn_implementation encoder_attn_implementation = config._attn_implementation
...@@ -1301,6 +1292,7 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel): ...@@ -1301,6 +1292,7 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel):
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
cache_position=None,
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = ( output_hidden_states = (
...@@ -1323,16 +1315,44 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel): ...@@ -1323,16 +1315,44 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel):
else: else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = sum([self.embed_tokens[codebook](input[:, codebook]) for codebook in range(num_codebooks)]) inputs_embeds = sum([self.embed_tokens[codebook](input[:, codebook]) for codebook in range(num_codebooks)])
prepended_sequence_length = 0
# if prompt_hidden_states, fuse to inputs_embeds and update input shape # if prompt_hidden_states, fuse to inputs_embeds and update input shape
if prompt_hidden_states is not None: if prompt_hidden_states is not None:
prepended_sequence_length = prompt_hidden_states.shape[-2]
inputs_embeds = torch.cat([prompt_hidden_states, inputs_embeds], dim=1) inputs_embeds = torch.cat([prompt_hidden_states, inputs_embeds], dim=1)
return_legacy_cache = False
return_self_attention_cache = False
if use_cache or past_key_values is not None:
if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache):
return_self_attention_cache = True
past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())
elif not isinstance(past_key_values, EncoderDecoderCache):
return_legacy_cache = True
logger.warning_once(
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. "
"You should pass an instance of `EncoderDecoderCache` instead, e.g. "
"`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
)
past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
past_key_values_length = 0
if cache_position is not None:
past_key_values_length = cache_position[0]
elif past_key_values is not None:
past_key_values_length = past_key_values.get_seq_length()
if cache_position is None:
cache_position = torch.arange(
past_key_values_length, past_key_values_length + input_shape[1] + prepended_sequence_length, device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
# NOTE: 1. As it is, the masked ids from the prompt will still count in the positions embeddings # NOTE: 1. As it is, the masked ids from the prompt will still count in the positions embeddings
# NOTE: 2. we want to concatenate the prompt attention mask and the decoder attention mask # NOTE: 2. we want to concatenate the prompt attention mask and the decoder attention mask
# i.i.f `prompt_cross_attention=False`. ParlerTTSForConditionalGeneration's taking care of setting # i.i.f `prompt_cross_attention=False`. ParlerTTSForConditionalGeneration's taking care of setting
...@@ -1343,7 +1363,7 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel): ...@@ -1343,7 +1363,7 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel):
logger.warning_once( logger.warning_once(
"`prompt_attention_mask` is specified but `attention_mask` is not. A full `attention_mask` will be created. Make sure this is the intended behaviour." "`prompt_attention_mask` is specified but `attention_mask` is not. A full `attention_mask` will be created. Make sure this is the intended behaviour."
) )
if past_key_values is None: if past_key_values_length == 0:
attention_mask = torch.cat( attention_mask = torch.cat(
[ [
prompt_attention_mask, prompt_attention_mask,
...@@ -1401,23 +1421,14 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel): ...@@ -1401,23 +1421,14 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel):
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
if self.attn_implementation == "flash_attention_2": causal_mask = self._update_causal_mask(
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None attention_mask,
elif self.attn_implementation == "sdpa" and head_mask is None and not output_attentions: inputs_embeds,
# output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on cache_position,
# the manual implementation that requires a 4D causal mask in all cases. past_key_values.self_attention_cache if past_key_values is not None else None,
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( output_attentions,
attention_mask, )
input_shape,
inputs_embeds,
past_key_values_length,
)
else:
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None: if encoder_hidden_states is not None and encoder_attention_mask is not None:
if self.encoder_attn_implementation == "flash_attention_2": if self.encoder_attn_implementation == "flash_attention_2":
encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
...@@ -1442,12 +1453,10 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel): ...@@ -1442,12 +1453,10 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel):
"`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..." "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..."
) )
use_cache = False use_cache = False
# decoder layers # decoder layers
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None all_self_attns = () if output_attentions else None
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
next_decoder_cache = () if use_cache else None
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
...@@ -1465,13 +1474,11 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel): ...@@ -1465,13 +1474,11 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel):
if self.training and (dropout_probability < self.layerdrop): if self.training and (dropout_probability < self.layerdrop):
continue continue
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
decoder_layer.forward, decoder_layer.forward,
hidden_states, hidden_states,
attention_mask, causal_mask,
cos, cos,
sin, sin,
encoder_hidden_states, encoder_hidden_states,
...@@ -1481,11 +1488,12 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel): ...@@ -1481,11 +1488,12 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel):
None, None,
output_attentions, output_attentions,
use_cache, use_cache,
cache_position,
) )
else: else:
layer_outputs = decoder_layer( layer_outputs = decoder_layer(
hidden_states, hidden_states,
attention_mask=attention_mask, attention_mask=causal_mask,
cos=cos, cos=cos,
sin=sin, sin=sin,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
...@@ -1494,15 +1502,13 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel): ...@@ -1494,15 +1502,13 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel):
cross_attn_layer_head_mask=( cross_attn_layer_head_mask=(
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
), ),
past_key_value=past_key_value, past_key_value=past_key_values if use_cache else None,
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
cache_position=cache_position,
) )
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
if output_attentions: if output_attentions:
all_self_attns += (layer_outputs[1],) all_self_attns += (layer_outputs[1],)
...@@ -1515,7 +1521,11 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel): ...@@ -1515,7 +1521,11 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel):
if output_hidden_states: if output_hidden_states:
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None next_cache = past_key_values if use_cache else None
if return_self_attention_cache:
next_cache = past_key_values.self_attention_cache
if return_legacy_cache:
next_cache = past_key_values.to_legacy_cache()
if not return_dict: if not return_dict:
return tuple( return tuple(
v v
...@@ -1530,6 +1540,87 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel): ...@@ -1530,6 +1540,87 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel):
cross_attentions=all_cross_attentions, cross_attentions=all_cross_attentions,
) )
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool,
):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode step due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
is_training=self.training,
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_length()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
if attention_mask is not None and attention_mask.dim() == 4:
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
if attention_mask.max() != 0:
raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
causal_mask = attention_mask
else:
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type == "cuda"
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
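
As a side note, the mask arithmetic above is hard to follow inside a diff. The following is a minimal, self-contained sketch of the same construction for a static cache; shapes and dtypes are illustrative only and it is not part of the change:

import torch

def causal_mask_sketch(sequence_length, target_length, cache_position, dtype=torch.float32):
    # target_length is the static cache size; cache_position holds the absolute
    # positions of the tokens currently being processed.
    min_dtype = torch.finfo(dtype).min
    mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
    if sequence_length != 1:
        mask = torch.triu(mask, diagonal=1)
    # mask out every cache slot strictly beyond the current position
    mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
    return mask[None, None, :, :]  # broadcastable over (batch, heads)

# prefill of 3 tokens into an 8-slot static cache, then one decode step at position 3
prefill_mask = causal_mask_sketch(3, 8, torch.arange(3))
decode_mask = causal_mask_sketch(1, 8, torch.tensor([3]))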
@add_start_docstrings( @add_start_docstrings(
"The bare ParlerTTS decoder model outputting raw hidden-states without any specific head on top.", "The bare ParlerTTS decoder model outputting raw hidden-states without any specific head on top.",
...@@ -1564,12 +1655,13 @@ class ParlerTTSModel(ParlerTTSPreTrainedModel): ...@@ -1564,12 +1655,13 @@ class ParlerTTSModel(ParlerTTSPreTrainedModel):
prompt_attention_mask: Optional[torch.LongTensor] = None, prompt_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, past_key_values: Optional[Union[EncoderDecoderCache, Tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None, use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = ( output_hidden_states = (
...@@ -1595,6 +1687,7 @@ class ParlerTTSModel(ParlerTTSPreTrainedModel): ...@@ -1595,6 +1687,7 @@ class ParlerTTSModel(ParlerTTSPreTrainedModel):
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
cache_position=cache_position,
) )
if not return_dict: if not return_dict:
...@@ -1665,6 +1758,7 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel): ...@@ -1665,6 +1758,7 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*):
...@@ -1692,6 +1786,7 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel): ...@@ -1692,6 +1786,7 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
cache_position=cache_position,
) )
hidden_states = outputs[0] hidden_states = outputs[0]
...@@ -1752,7 +1847,8 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel): ...@@ -1752,7 +1847,8 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
past_key_values=None, past_key_values=None,
use_cache=True, use_cache=True,
delay_pattern_mask=None, delay_pattern_mask=None,
guidance_scale=None, cache_position=None,
inputs_embeds=None,
**kwargs, **kwargs,
): ):
if delay_pattern_mask is None: if delay_pattern_mask is None:
...@@ -1766,22 +1862,11 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel): ...@@ -1766,22 +1862,11 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
# apply the delay pattern mask # apply the delay pattern mask
input_ids = self.apply_delay_pattern_mask(input_ids, delay_pattern_mask) input_ids = self.apply_delay_pattern_mask(input_ids, delay_pattern_mask)
if guidance_scale is not None and guidance_scale > 1: position_ids = kwargs.get("position_ids", None)
# for classifier free guidance we need to replicate the decoder args across the batch dim (we'll split these if attention_mask is not None and position_ids is None:
# before sampling) # create position_ids on the fly for batch generation
input_ids = input_ids.repeat((2, 1)) position_ids = attention_mask.long().cumsum(-1) - 1
if attention_mask is not None: position_ids.masked_fill_(attention_mask == 0, 1)
attention_mask = attention_mask.repeat((2, 1))
if prompt_hidden_states is not None:
prompt_hidden_states = torch.concatenate(
[prompt_hidden_states, torch.zeros_like(prompt_hidden_states)], dim=0
)
if prompt_attention_mask is not None:
prompt_attention_mask = torch.concatenate(
[prompt_attention_mask, torch.zeros_like(prompt_attention_mask)], dim=0
)
position_ids = kwargs.get("position_ids", None) position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None: if attention_mask is not None and position_ids is None:
...@@ -1798,7 +1883,7 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel): ...@@ -1798,7 +1883,7 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
prompt_hidden_states = None prompt_hidden_states = None
return { return {
"input_ids": input_ids, "input_ids": input_ids.contiguous(), # `contiguous()` needed for compilation use cases
"attention_mask": attention_mask, "attention_mask": attention_mask,
"position_ids": position_ids, "position_ids": position_ids,
"encoder_hidden_states": encoder_hidden_states, "encoder_hidden_states": encoder_hidden_states,
...@@ -1809,6 +1894,8 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel): ...@@ -1809,6 +1894,8 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
"cross_attn_head_mask": cross_attn_head_mask, "cross_attn_head_mask": cross_attn_head_mask,
"past_key_values": past_key_values, "past_key_values": past_key_values,
"use_cache": use_cache, "use_cache": use_cache,
"cache_position": cache_position,
"inputs_embeds": inputs_embeds,
} }
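
For readers following the new `position_ids` handling above, here is the cumsum trick in isolation; the values are a toy left-padding example, not taken from the repository:

import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1],   # left-padded sequence
                               [1, 1, 1, 1, 1]])  # unpadded sequence
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])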
# Ignore copy # Ignore copy
...@@ -1948,10 +2035,11 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel): ...@@ -1948,10 +2035,11 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
inputs, generation_config.bos_token_id, model_kwargs inputs, generation_config.bos_token_id, model_kwargs
) )
batch_size = input_ids.shape[0] // self.num_codebooks batch_size = input_ids.shape[0] // self.num_codebooks
kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=input_ids.device)
# 4. Define other model kwargs # 4. Define other model kwargs
model_kwargs["use_cache"] = generation_config.use_cache model_kwargs["use_cache"] = generation_config.use_cache
model_kwargs["guidance_scale"] = generation_config.guidance_scale
requires_attention_mask = "encoder_outputs" not in model_kwargs requires_attention_mask = "encoder_outputs" not in model_kwargs
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask: if model_kwargs.get("attention_mask", None) is None and requires_attention_mask:
...@@ -2016,21 +2104,17 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel): ...@@ -2016,21 +2104,17 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
and generation_config.do_sample is True and generation_config.do_sample is True
) )
# 8. prepare batched CFG externally (to enable coexistance with the unbatched CFG) # 8. prepare distribution pre_processing samplers
if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1:
logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale))
generation_config.guidance_scale = None
# 9. prepare distribution pre_processing samplers
logits_processor = self._get_logits_processor( logits_processor = self._get_logits_processor(
generation_config=generation_config, generation_config=generation_config,
input_ids_seq_length=input_ids_seq_length, input_ids_seq_length=input_ids_seq_length,
encoder_input_ids=input_ids, encoder_input_ids=input_ids,
prefix_allowed_tokens_fn=None, prefix_allowed_tokens_fn=None,
logits_processor=logits_processor, logits_processor=logits_processor,
device=input_ids.device,
) )
# 10. prepare stopping criteria # 9. prepare stopping criteria
stopping_criteria = self._get_stopping_criteria( stopping_criteria = self._get_stopping_criteria(
generation_config=generation_config, stopping_criteria=stopping_criteria generation_config=generation_config, stopping_criteria=stopping_criteria
) )
...@@ -2042,8 +2126,8 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel): ...@@ -2042,8 +2126,8 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
f"but is {generation_config.num_return_sequences}." f"but is {generation_config.num_return_sequences}."
) )
# 11. run greedy search # 10. run greedy search
outputs = self._greedy_search( outputs = self._sample(
input_ids, input_ids,
logits_processor=logits_processor, logits_processor=logits_processor,
stopping_criteria=stopping_criteria, stopping_criteria=stopping_criteria,
...@@ -2054,8 +2138,8 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel): ...@@ -2054,8 +2138,8 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
) )
elif is_sample_gen_mode: elif is_sample_gen_mode:
# 11. prepare logits warper # 10. prepare logits warper
logits_warper = self._get_logits_warper(generation_config) logits_warper = self._get_logits_warper(generation_config, device=input_ids.device)
# expand input_ids with `num_return_sequences` additional sequences per batch # expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids, model_kwargs = self._expand_inputs_for_generation(
...@@ -2064,7 +2148,7 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel): ...@@ -2064,7 +2148,7 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
**model_kwargs, **model_kwargs,
) )
# 12. run sample # 11. run sample
outputs = self._sample( outputs = self._sample(
input_ids, input_ids,
logits_processor=logits_processor, logits_processor=logits_processor,
...@@ -2120,6 +2204,8 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel): ...@@ -2120,6 +2204,8 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
_supports_sdpa = True _supports_sdpa = True
_supports_cache_class = True
_supports_static_cache = True
def __init__( def __init__(
self, self,
...@@ -2498,7 +2584,7 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel): ...@@ -2498,7 +2584,7 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
decoder_input_ids: Optional[torch.LongTensor] = None, decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.BoolTensor] = None, decoder_attention_mask: Optional[torch.BoolTensor] = None,
encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None, encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
past_key_values: Tuple[Tuple[torch.FloatTensor]] = None, past_key_values: Optional[Union[EncoderDecoderCache, Tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None, decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
prompt_input_ids: Optional[torch.FloatTensor] = None, prompt_input_ids: Optional[torch.FloatTensor] = None,
...@@ -2510,6 +2596,7 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel): ...@@ -2510,6 +2596,7 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs, **kwargs,
) -> Union[Tuple, Seq2SeqLMOutput]: ) -> Union[Tuple, Seq2SeqLMOutput]:
r""" r"""
...@@ -2653,6 +2740,7 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel): ...@@ -2653,6 +2740,7 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
past_key_values=past_key_values, past_key_values=past_key_values,
return_dict=return_dict, return_dict=return_dict,
labels=labels, labels=labels,
cache_position=cache_position,
**kwargs_decoder, **kwargs_decoder,
) )
...@@ -2685,7 +2773,8 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel): ...@@ -2685,7 +2773,8 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
use_cache=None, use_cache=None,
encoder_outputs=None, encoder_outputs=None,
decoder_delay_pattern_mask=None, decoder_delay_pattern_mask=None,
guidance_scale=None, cache_position=None,
inputs_embeds=None,
**kwargs, **kwargs,
): ):
if decoder_delay_pattern_mask is None: if decoder_delay_pattern_mask is None:
...@@ -2699,19 +2788,17 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel): ...@@ -2699,19 +2788,17 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
# apply the delay pattern mask # apply the delay pattern mask
decoder_input_ids = self.decoder.apply_delay_pattern_mask(decoder_input_ids, decoder_delay_pattern_mask) decoder_input_ids = self.decoder.apply_delay_pattern_mask(decoder_input_ids, decoder_delay_pattern_mask)
if guidance_scale is not None and guidance_scale > 1: past_length = 0
# for classifier free guidance we need to replicate the decoder args across the batch dim (we'll split these
# before sampling)
decoder_input_ids = decoder_input_ids.repeat((2, 1))
if decoder_attention_mask is not None:
decoder_attention_mask = decoder_attention_mask.repeat((2, 1))
if prompt_hidden_states is not None:
prompt_hidden_states = prompt_hidden_states.repeat((2, 1, 1))
if prompt_attention_mask is not None:
prompt_attention_mask = prompt_attention_mask.repeat((2, 1))
if past_key_values is not None: if past_key_values is not None:
past_length = past_key_values[0][0].shape[2] if isinstance(past_key_values, EncoderDecoderCache):
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
if past_key_values.get_seq_length() > 0:
# we only want to use prompt signal in the 1st generation step
prompt_hidden_states = None
else:
past_length = past_key_values[0][0].shape[2]
# we only want to use prompt signal in the 1st generation step
prompt_hidden_states = None
# Some generation methods already pass only the last input ID # Some generation methods already pass only the last input ID
if decoder_input_ids.shape[1] > past_length: if decoder_input_ids.shape[1] > past_length:
...@@ -2722,15 +2809,50 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel): ...@@ -2722,15 +2809,50 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
# if prompt_cross_attention, if cache_position is None:
# we only want to use prompt signal in the 1st generation step cache_position = torch.arange(
prompt_hidden_states = None past_length, past_length + decoder_input_ids.shape[1], device=decoder_input_ids.device
)
elif use_cache:
cur_len = decoder_input_ids.shape[1]
if prompt_hidden_states is not None and not self.prompt_cross_attention:
# meaning we are in the 1st generation step and prompt_hidden_states will be prepended
cur_len += prompt_hidden_states.shape[1]
cache_position = cache_position[-cur_len:]
if decoder_attention_mask is None and prompt_attention_mask is not None:
input = decoder_input_ids.reshape(-1, self.decoder.num_codebooks, decoder_input_ids.shape[-1])
bsz, _, seq_len = input.shape
input_shape = (bsz, seq_len)
past_key_values_length = 0
if cache_position is not None:
past_key_values_length = cache_position[0]
elif past_key_values is not None:
past_key_values_length = past_key_values.get_seq_length()
logger.warning_once(
"`prompt_attention_mask` is specified but `attention_mask` is not. A full `attention_mask` will be created. Make sure this is the intended behaviour."
)
if past_key_values is None or (
isinstance(past_key_values, EncoderDecoderCache) and past_key_values.get_seq_length() == 0
):
decoder_attention_mask = torch.ones(input_shape, device=self.device, dtype=decoder_input_ids.dtype)
elif prompt_attention_mask is not None:
# In the generation case of `prompt_cross_attention=True`, we need to recreate an attention mask from scratch
# to be able to prepend the prompt attention mask.
# Since we generate token per token, we can recompute the generated length from the information we have.
generated_length = past_key_values_length - prompt_attention_mask.shape[1] + 1
decoder_attention_mask = torch.ones(
(input_shape[0], generated_length), device=self.device, dtype=prompt_attention_mask.dtype
)
return { return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed "input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs, "encoder_outputs": encoder_outputs,
"past_key_values": past_key_values, "past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids, "decoder_input_ids": decoder_input_ids.contiguous(),
"attention_mask": attention_mask, "attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask, "decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask, "head_mask": head_mask,
...@@ -2739,6 +2861,8 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel): ...@@ -2739,6 +2861,8 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
"prompt_hidden_states": prompt_hidden_states, "prompt_hidden_states": prompt_hidden_states,
"prompt_attention_mask": prompt_attention_mask, "prompt_attention_mask": prompt_attention_mask,
"use_cache": use_cache, "use_cache": use_cache,
"cache_position": cache_position,
"inputs_embeds": inputs_embeds,
} }
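
The `cache_position` bookkeeping above reduces to a simple rule: the prefill step covers a range of cache slots, and each later decode step covers exactly one. A tiny illustration with made-up lengths:

import torch

def cache_position_for_step(past_length, num_new_tokens):
    # prefill: past_length == 0 and num_new_tokens > 1; decode: num_new_tokens == 1
    return torch.arange(past_length, past_length + num_new_tokens)

print(cache_position_for_step(0, 5))  # tensor([0, 1, 2, 3, 4])  first (prefill) step
print(cache_position_for_step(5, 1))  # tensor([5])              next decode step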
def _prepare_decoder_input_ids_for_generation( def _prepare_decoder_input_ids_for_generation(
...@@ -2786,6 +2910,19 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel): ...@@ -2786,6 +2910,19 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
) )
model_kwargs["decoder_attention_mask"] = decoder_attention_mask model_kwargs["decoder_attention_mask"] = decoder_attention_mask
if not self.prompt_cross_attention:
prompt_hidden_states = model_kwargs["prompt_hidden_states"]
num_codebooks = self.decoder.num_codebooks
input = decoder_input_ids.reshape(-1, num_codebooks, decoder_input_ids.shape[-1])
inputs_embeds = sum(
[
self.decoder.model.decoder.embed_tokens[codebook](input[:, codebook])
for codebook in range(num_codebooks)
]
)
inputs_embeds = torch.cat([prompt_hidden_states, inputs_embeds], dim=1)
model_kwargs["inputs_embeds"] = inputs_embeds
return decoder_input_ids, model_kwargs return decoder_input_ids, model_kwargs
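
The per-codebook embedding sum used above can be sketched standalone; the sizes below are toy values and the embedding weights are random, so this only illustrates the shapes involved:

import torch
import torch.nn as nn

num_codebooks, vocab_size, hidden, bsz, seq_len = 4, 10, 8, 2, 3
embed_tokens = nn.ModuleList(nn.Embedding(vocab_size, hidden) for _ in range(num_codebooks))
codes = torch.randint(0, vocab_size, (bsz, num_codebooks, seq_len))

# one embedding table per codebook, summed into a single decoder input stream
inputs_embeds = sum(embed_tokens[k](codes[:, k]) for k in range(num_codebooks))
print(inputs_embeds.shape)  # torch.Size([2, 3, 8])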
def _prepare_text_encoder_kwargs_for_generation( def _prepare_text_encoder_kwargs_for_generation(
...@@ -2817,7 +2954,6 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel): ...@@ -2817,7 +2954,6 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
} }
encoder_kwargs["output_attentions"] = generation_config.output_attentions encoder_kwargs["output_attentions"] = generation_config.output_attentions
encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states
guidance_scale = generation_config.guidance_scale
# 3. make sure that encoder returns `ModelOutput` # 3. make sure that encoder returns `ModelOutput`
model_input_name = model_input_name if model_input_name is not None else self.text_encoder.main_input_name model_input_name = model_input_name if model_input_name is not None else self.text_encoder.main_input_name
...@@ -2825,14 +2961,6 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel): ...@@ -2825,14 +2961,6 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
encoder_kwargs[model_input_name] = inputs_tensor encoder_kwargs[model_input_name] = inputs_tensor
last_hidden_state = encoder(**encoder_kwargs).last_hidden_state last_hidden_state = encoder(**encoder_kwargs).last_hidden_state
# for classifier free guidance we need to add a 'null' input to our encoder hidden states
if guidance_scale is not None and guidance_scale > 1:
last_hidden_state = torch.concatenate([last_hidden_state, torch.zeros_like(last_hidden_state)], dim=0)
if "attention_mask" in model_kwargs:
model_kwargs["attention_mask"] = torch.concatenate(
[model_kwargs["attention_mask"], torch.zeros_like(model_kwargs["attention_mask"])], dim=0
)
# we optionally project last_hidden_state to avoid recomputing it every time
encoder_hidden_states = last_hidden_state encoder_hidden_states = last_hidden_state
if ( if (
...@@ -2933,7 +3061,9 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel): ...@@ -2933,7 +3061,9 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
return model_kwargs return model_kwargs
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id).transpose(1, 2) return shift_tokens_right(
labels, self.config.decoder.pad_token_id, self.config.decoder.bos_token_id
).transpose(1, 2)
def resize_token_embeddings(self, *args, **kwargs): def resize_token_embeddings(self, *args, **kwargs):
raise NotImplementedError( raise NotImplementedError(
...@@ -2970,6 +3100,81 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel): ...@@ -2970,6 +3100,81 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
break break
return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id
def _get_decoder_start_token_id(
self, decoder_start_token_id: Union[int, List[int]] = None, bos_token_id: int = None
) -> int:
decoder_start_token_id = (
decoder_start_token_id
if decoder_start_token_id is not None
else self.generation_config.decoder_start_token_id
)
bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id
if decoder_start_token_id is not None:
return decoder_start_token_id
elif bos_token_id is not None:
return bos_token_id
raise ValueError(
"`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
)
def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_len: int, model_kwargs) -> Cache:
"""
Sets a cache for `generate` that will persist across calls. A new cache will only be initialized if a
new `generate` call requires a larger cache.
Returns the resulting cache object.
"""
cache_cls: Cache = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation]
requires_cross_attention_cache = (
self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
)
if hasattr(self, "_cache"):
cache_to_check = self._cache.self_attention_cache if requires_cross_attention_cache else self._cache
if cache_implementation == "sliding_window":
max_cache_len = min(self.config.sliding_window, max_cache_len)
need_new_cache = (
not hasattr(self, "_cache")
or (not isinstance(cache_to_check, cache_cls))
or cache_to_check.max_batch_size != max_batch_size
or cache_to_check.max_cache_len < max_cache_len
)
if requires_cross_attention_cache and hasattr(self, "_cache"):
need_new_cache = (
need_new_cache
or self._cache.cross_attention_cache.max_cache_len != model_kwargs["encoder_outputs"][0].shape[1]
)
if need_new_cache:
if hasattr(self.config, "_pre_quantization_dtype"):
cache_dtype = self.config._pre_quantization_dtype
else:
cache_dtype = self.dtype
cache_kwargs = {
"config": self.config.decoder,
"max_batch_size": max_batch_size,
"max_cache_len": max_cache_len,
"device": self.device,
"dtype": cache_dtype,
}
self._cache = cache_cls(**cache_kwargs)
if requires_cross_attention_cache:
encoder_kwargs = cache_kwargs.copy()
encoder_kwargs["max_cache_len"] = model_kwargs["encoder_outputs"][0].shape[1]
config_cross_attention_cache = copy.deepcopy(self.config.decoder)
config_cross_attention_cache.update(
{"num_key_value_heads": self.config.decoder.num_cross_attention_key_value_heads}
)
encoder_kwargs["config"] = config_cross_attention_cache
self._cache = EncoderDecoderCache(self._cache, cache_cls(**encoder_kwargs))
else:
self._cache.reset()
return self._cache
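
For completeness, a hedged usage sketch of how this cache set-up is typically reached from `generate`; the model id, prompt text, and the `torch.compile` call are illustrative assumptions rather than part of this change:

import torch
from transformers import AutoTokenizer
from parler_tts import ParlerTTSForConditionalGeneration

device = "cuda" if torch.cuda.is_available() else "cpu"
model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler_tts_mini_v0.1").to(device)
tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler_tts_mini_v0.1")

# static cache + (optionally) torch.compile for faster, shape-stable decoding
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead")

description = tokenizer("A calm female speaker.", return_tensors="pt").to(device)
prompt = tokenizer("Hello, world.", return_tensors="pt").to(device)
audio = model.generate(input_ids=description.input_ids, prompt_input_ids=prompt.input_ids)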
def freeze_encoders(self, freeze_text_encoder=True): def freeze_encoders(self, freeze_text_encoder=True):
if freeze_text_encoder: if freeze_text_encoder:
for param in self.text_encoder.parameters(): for param in self.text_encoder.parameters():
...@@ -3070,46 +3275,28 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel): ...@@ -3070,46 +3275,28 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
if generation_config.pad_token_id is None and generation_config.eos_token_id is not None: requires_attention_mask = "encoder_outputs" not in model_kwargs
if model_kwargs.get("attention_mask", None) is None: kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
logger.warning(
"The attention mask and the pad token id were not set. As a consequence, you may observe "
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
)
eos_token_id = generation_config.eos_token_id
if isinstance(eos_token_id, list):
eos_token_id = eos_token_id[0]
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
generation_config.pad_token_id = eos_token_id
# 3. Define model inputs # 3. Define model inputs
# inputs_tensor has to be defined
# model_input_name is defined if model-specific keyword input is passed
# otherwise model_input_name is None
# all model-specific keyword inputs are removed from `model_kwargs`
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
inputs, generation_config.bos_token_id, model_kwargs inputs, generation_config.bos_token_id, model_kwargs
) )
batch_size = inputs_tensor.shape[0] batch_size = inputs_tensor.shape[0]
self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=inputs_tensor.device)
# 4. Define other model kwargs # 4. Define other model kwargs
model_kwargs["use_cache"] = generation_config.use_cache model_kwargs["use_cache"] = generation_config.use_cache
model_kwargs["guidance_scale"] = generation_config.guidance_scale
requires_attention_mask = "encoder_outputs" not in model_kwargs
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask: if model_kwargs.get("attention_mask", None) is None and requires_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor
) )
if "encoder_outputs" not in model_kwargs: if "encoder_outputs" not in model_kwargs:
# encoder_outputs are created and added to `model_kwargs` # encoder_outputs are created and added to `model_kwargs`
model_kwargs = self._prepare_text_encoder_kwargs_for_generation( model_kwargs = self._prepare_text_encoder_kwargs_for_generation(
inputs_tensor, inputs_tensor, model_kwargs, model_input_name, generation_config
model_kwargs,
model_input_name,
generation_config,
) )
if "prompt_hidden_states" not in model_kwargs and "prompt_input_ids" in model_kwargs: if "prompt_hidden_states" not in model_kwargs and "prompt_input_ids" in model_kwargs:
...@@ -3130,46 +3317,80 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel): ...@@ -3130,46 +3317,80 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
batch_size=batch_size, batch_size=batch_size,
model_input_name=model_input_name, model_input_name=model_input_name,
model_kwargs=model_kwargs, model_kwargs=model_kwargs,
decoder_start_token_id=generation_config.decoder_start_token_id, decoder_start_token_id=generation_config._decoder_start_token_tensor,
bos_token_id=generation_config.bos_token_id, bos_token_id=generation_config._bos_token_tensor,
device=inputs_tensor.device, device=inputs_tensor.device,
) )
# 6. Prepare `max_length` depending on other stopping criteria. # 6. Prepare `max_length` depending on other stopping criteria.
input_ids_seq_length = input_ids.shape[-1] input_ids_length = input_ids.shape[-1]
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
if has_default_max_length and generation_config.max_new_tokens is None: has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
logger.warning( generation_config = self._prepare_generated_length(
f"Using the model-agnostic default `max_length` (={generation_config.max_length}) " generation_config=generation_config,
"to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation." has_default_max_length=has_default_max_length,
) has_default_min_length=has_default_min_length,
elif generation_config.max_new_tokens is not None: model_input_name=model_input_name,
if not has_default_max_length: inputs_tensor=inputs_tensor,
logger.warning( input_ids_length=input_ids_length,
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" )
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
"Please refer to the documentation for more information. "
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
)
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length: if generation_config.cache_implementation is not None and model_kwargs.get("past_key_values") is not None:
raise ValueError( raise ValueError(
f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than" "Passing both `cache_implementation` (used to initialize certain caches) and `past_key_values` (a "
f" the maximum length ({generation_config.max_length})" "Cache object) is unsupported. Please use only one of the two."
) )
if input_ids_seq_length >= generation_config.max_length: elif generation_config.cache_implementation is not None:
logger.warning( if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
f"Input length of decoder_input_ids is {input_ids_seq_length}, but `max_length` is set to" if generation_config.cache_implementation == "static" and not self._supports_static_cache:
f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" raise ValueError(
" increasing `max_new_tokens`." "This model does not support `cache_implementation='static'`. Please check the following "
"issue: https://github.com/huggingface/transformers/issues/28981"
)
if not self.prompt_cross_attention:
# when we prepend prompt_hidden_state to inputs_embeds, max_cache_len needs to be actualised
# generation_config.max_length has already been increased by input_ids_length which is
# already counted in input_embeds_seq_length so we remove it
input_embeds_seq_length = model_kwargs["inputs_embeds"].shape[1]
max_cache_len = generation_config.max_length + input_embeds_seq_length - input_ids_length
else:
max_cache_len = self.generation_config.max_length
model_kwargs["past_key_values"] = self._get_cache(
generation_config.cache_implementation,
getattr(generation_config, "num_beams", 1) * batch_size,
max_cache_len,
model_kwargs,
)
elif generation_config.cache_implementation == "quantized":
raise ValueError(
"This model does not support the quantized cache. If you want your model to support quantized "
"cache, please open an issue on the Parler-TTS repository https://github.com/huggingface/parler-tts"
)
# Use DynamicCache() instance by default. This will avoid back and forth from legacy format that
# keeps copying the cache thus using much more memory
elif generation_config.cache_implementation is None and self._supports_default_dynamic_cache():
past = model_kwargs.get("past_key_values", None)
requires_cross_attention_cache = (
self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
) )
if past is None:
model_kwargs["past_key_values"] = (
DynamicCache()
if not requires_cross_attention_cache
else EncoderDecoderCache(DynamicCache(), DynamicCache())
)
elif isinstance(past, tuple):
model_kwargs["past_key_values"] = (
DynamicCache.from_legacy_cache(past)
if not requires_cross_attention_cache
else EncoderDecoderCache.from_legacy_cache(past)
)
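
A short sketch of the cache objects created by the branch above, assuming the transformers release pinned later in this diff:

from transformers.cache_utils import DynamicCache, EncoderDecoderCache

# default for encoder-decoder generation when no cache is passed in
cache = EncoderDecoderCache(DynamicCache(), DynamicCache())
print(cache.get_seq_length())                     # 0 before the first forward pass
print(type(cache.self_attention_cache).__name__)  # DynamicCache
legacy = cache.to_legacy_cache()                  # round-trip to the old tuple format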
# build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Parler-TTS) # build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Parler-TTS)
input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask( input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
input_ids, input_ids,
bos_token_id=generation_config.bos_token_id, bos_token_id=generation_config._bos_token_tensor,
pad_token_id=generation_config.pad_token_id, pad_token_id=generation_config._pad_token_tensor,
max_length=generation_config.max_length, max_length=generation_config.max_length,
) )
# stash the delay mask so that we don't have to recompute in each forward pass # stash the delay mask so that we don't have to recompute in each forward pass
...@@ -3191,21 +3412,17 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel): ...@@ -3191,21 +3412,17 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
and generation_config.do_sample is True and generation_config.do_sample is True
) )
# 8. prepare batched CFG externally (to enable coexistance with the unbatched CFG) # 8. prepare distribution pre_processing samplers
if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1:
logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale))
generation_config.guidance_scale = None
# 9. prepare distribution pre_processing samplers
logits_processor = self._get_logits_processor( logits_processor = self._get_logits_processor(
generation_config=generation_config, generation_config=generation_config,
input_ids_seq_length=input_ids_seq_length, input_ids_seq_length=input_ids_length,
encoder_input_ids=inputs_tensor, encoder_input_ids=inputs_tensor,
prefix_allowed_tokens_fn=None, prefix_allowed_tokens_fn=None,
logits_processor=logits_processor, logits_processor=logits_processor,
device=input_ids.device,
) )
# 10. prepare stopping criteria # 9. prepare stopping criteria
stopping_criteria = self._get_stopping_criteria( stopping_criteria = self._get_stopping_criteria(
generation_config=generation_config, stopping_criteria=stopping_criteria generation_config=generation_config, stopping_criteria=stopping_criteria
) )
...@@ -3217,8 +3434,8 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel): ...@@ -3217,8 +3434,8 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
f"but is {generation_config.num_return_sequences}." f"but is {generation_config.num_return_sequences}."
) )
# 11. run greedy search # 10. run greedy search
outputs = self._greedy_search( outputs = self._sample(
input_ids, input_ids,
logits_processor=logits_processor, logits_processor=logits_processor,
stopping_criteria=stopping_criteria, stopping_criteria=stopping_criteria,
...@@ -3229,8 +3446,8 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel): ...@@ -3229,8 +3446,8 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
) )
elif is_sample_gen_mode: elif is_sample_gen_mode:
# 11. prepare logits warper # 10. prepare logits warper
logits_warper = self._get_logits_warper(generation_config) logits_warper = self._get_logits_warper(generation_config, device=input_ids.device)
# expand input_ids with `num_return_sequences` additional sequences per batch # expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids, model_kwargs = self._expand_inputs_for_generation(
...@@ -3240,7 +3457,7 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel): ...@@ -3240,7 +3457,7 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
**model_kwargs, **model_kwargs,
) )
# 12. run sample # 11. run sample
outputs = self._sample( outputs = self._sample(
input_ids, input_ids,
logits_processor=logits_processor, logits_processor=logits_processor,
...@@ -3269,8 +3486,8 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel): ...@@ -3269,8 +3486,8 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
# Revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask # Revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask
_, mask = self.decoder.build_delay_pattern_mask( _, mask = self.decoder.build_delay_pattern_mask(
input_ids, input_ids,
bos_token_id=generation_config.bos_token_id, bos_token_id=generation_config._bos_token_tensor,
pad_token_id=generation_config.pad_token_id, pad_token_id=generation_config._pad_token_tensor,
max_length=output_ids.shape[1], max_length=output_ids.shape[1],
) )
...
...@@ -13,11 +13,12 @@ ...@@ -13,11 +13,12 @@
# limitations under the License. # limitations under the License.
import os import os
import setuptools import setuptools
_deps = [ _deps = [
"transformers>=4.39.0,<4.41.0", "transformers>=4.43.0,<=4.43.3",
"torch", "torch",
"sentencepiece", "sentencepiece",
"descript-audio-codec", "descript-audio-codec",
...@@ -60,7 +61,7 @@ setuptools.setup( ...@@ -60,7 +61,7 @@ setuptools.setup(
packages=setuptools.find_packages(), packages=setuptools.find_packages(),
install_requires=_deps, install_requires=_deps,
extras_require={ extras_require={
"dev": [_extras_dev_deps], "dev": _extras_dev_deps,
"train": [_extras_training_deps], "train": _extras_training_deps,
}, },
) )
import logging import logging
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional, Union, Set from typing import Dict, List, Optional, Set, Union
import torch
import numpy as np
import datasets import datasets
from datasets import load_dataset, Dataset, IterableDataset, interleave_datasets, concatenate_datasets import numpy as np
from transformers import AutoFeatureExtractor, AutoTokenizer import torch
from tqdm import tqdm
from accelerate import Accelerator from accelerate import Accelerator
from datasets import Dataset, IterableDataset, concatenate_datasets, interleave_datasets, load_dataset
from tqdm import tqdm
from transformers import AutoFeatureExtractor, AutoTokenizer
@dataclass @dataclass
...
import os import os
import re import re
import shutil import shutil
from pathlib import Path
from dataclasses import field from dataclasses import field
from pathlib import Path
from typing import Dict, List from typing import Dict, List
import torch import torch
from datasets import concatenate_datasets, load_from_disk
from wandb import Audio from wandb import Audio
from datasets import load_from_disk, concatenate_datasets from datasets import load_from_disk, concatenate_datasets
...