Unverified Commit 862f8418 authored by eustlb, committed by GitHub

Add static cache (#89)



* add rope

* don't include padding in rope

* possibly use cross-attn for prompt

* fix rope

* fix cross-attn

* fix self-attn

* fix dummy model

* clean-up rope

* first gqa implementation

* fix wer eval

* feat: add flash attention and sdpa

* chore: add README for flash attention

* chore: add benchmark script

* chore: add benchmark attention approach

* multi node and fix wer and fix compile

* Update modeling_parler_tts.py

* fix FA2, SDPA and add cross-attn MHA and attention type forcing

* better cross_attention key values number of heads default + add training arguments for attn implementation

* fix audio padding when torch compile or pad_to_max_length=True

* correct multi node

* make rope faster

* fix encoder sdpa

* fix training with cross attention + with FA2

* use fp32 as default model dtype + fix generation when using FA2 with autocast

* remove redundant passes in generate + clean and fix attentions

* fix edge case in WER evaluation when longform generation

* better multi-node mapping and saving / add eval dataloader num workers

* remove old benchmarks

* faster audio encoding + checkpointing + fix generation step

* unpin trfms

* remove CFG

* imports and constants
Co-Authored-By: sang-nguyen-ts <sang-nguyen-ts@users.noreply.github.com>

* attention modifications to handle static cache
Co-Authored-By: sang-nguyen-ts <sang-nguyen-ts@users.noreply.github.com>

* decoder layer modification to handle static cache
Co-Authored-By: sang-nguyen-ts <sang-nguyen-ts@users.noreply.github.com>

* ParlerTTSPreTrainedModel modifs to handle static cache
Co-Authored-By: sang-nguyen-ts <sang-nguyen-ts@users.noreply.github.com>

* ParlerTTSDecoder modifs to handle static cache
Co-Authored-By: sang-nguyen-ts <sang-nguyen-ts@users.noreply.github.com>

* ParlerTTSModel + ParlerTTSForCausalLM modifs
Co-Authored-By: sang-nguyen-ts <sang-nguyen-ts@users.noreply.github.com>

* ParlerTTSForConditionalGeneration modifs
Co-Authored-By: sang-nguyen-ts <sang-nguyen-ts@users.noreply.github.com>

* decoder_attention_mask for static cache
Co-Authored-By: sang-nguyen-ts <sang-nguyen-ts@users.noreply.github.com>

* create inputs_embeds early to have a good cache initialization
Co-Authored-By: sang-nguyen-ts <sang-nguyen-ts@users.noreply.github.com>

* _get_cache method
Co-Authored-By: sang-nguyen-ts <sang-nguyen-ts@users.noreply.github.com>

* init the cache
Co-Authored-By: sang-nguyen-ts <sang-nguyen-ts@users.noreply.github.com>

* ensure good device
Co-Authored-By: sang-nguyen-ts <sang-nguyen-ts@users.noreply.github.com>

* pin trfms version
Co-Authored-By: sang-nguyen-ts <sang-nguyen-ts@users.noreply.github.com>

* fix attention_mask FA2

* remove unnecessary method

* Update parler_tts/modeling_parler_tts.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update parler_tts/modeling_parler_tts.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* remove unnecessary imports

* replace the hardcoded cache_position with a more elegant approach

* make style

* unpin transformers

* pin transformers

* pin torch

* refactor + unpin torch

* Update parler_tts/modeling_parler_tts.py
Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com>

* update training script to match 11b209e1



* Update parler_tts/modeling_parler_tts.py
Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com>

* ensure compatibility with trfms 4.43.3, changes taken from #31980 on trfms

* fix input_ids_length

* warning full attention mask creation

* changes for training compatibility

---------
Co-authored-by: sanchit-gandhi <sanchit@huggingface.co>
Co-authored-by: Yoach Lacombe <yoach.lacombe@gmail.com>
Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com>
Co-authored-by: sang-nguyen-ts <sang.nguyen@trustingsocial.com>
Co-authored-by: Yoach Lacombe <yoach@huggingface.co>
Co-authored-by: sang-nguyen-ts <sang-nguyen-ts@users.noreply.github.com>
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
parent 11b209e1
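Before the diff itself, a minimal usage sketch of the static-cache generation path this commit introduces. The checkpoint name is a placeholder, and routing through `cache_implementation="static"` is an assumption inferred from the `_get_cache` / `NEED_SETUP_CACHE_CLASSES_MAPPING` changes below, not a documented API guarantee:

```python
import torch
from transformers import AutoTokenizer
from parler_tts import ParlerTTSForConditionalGeneration

# Hedged sketch: checkpoint name and static-cache wiring are assumptions.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
repo_id = "parler-tts/parler_tts_mini_v0.1"  # placeholder checkpoint

model = ParlerTTSForConditionalGeneration.from_pretrained(repo_id).to(device)
tokenizer = AutoTokenizer.from_pretrained(repo_id)

# Ask generate() to allocate a fixed-shape StaticCache instead of a
# dynamically growing one (assumed to be supported after this commit).
model.generation_config.cache_implementation = "static"

description = tokenizer("A calm female voice with studio quality.", return_tensors="pt").to(device)
prompt = tokenizer("Hello, this is a test.", return_tensors="pt").to(device)

audio = model.generate(
    input_ids=description.input_ids,
    prompt_input_ids=prompt.input_ids,
)
```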
import gradio as gr
import torch
from transformers import AutoFeatureExtractor, AutoTokenizer, set_seed
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed
device = "cuda:0" if torch.cuda.is_available() else "cpu"
......@@ -57,7 +58,7 @@ css = """
background-color: #000000;
justify-content: center;
align-items: center;
border-radius: 9999px !important;
border-radius: 9999px !important;
width: 13rem;
margin-top: 10px;
margin-left: auto;
......
from parler_tts import ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration, ParlerTTSDecoderConfig
from transformers import AutoConfig
import os
import argparse
import os
from transformers import AutoConfig
from parler_tts import ParlerTTSDecoderConfig, ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration
if __name__ == "__main__":
......
from parler_tts import ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration, ParlerTTSDecoderConfig
from transformers import AutoConfig
import os
import argparse
import os
from transformers import AutoConfig
from parler_tts import ParlerTTSDecoderConfig, ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration
if __name__ == "__main__":
parser = argparse.ArgumentParser()
......
from parler_tts import ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration, ParlerTTSDecoderConfig
from transformers import AutoConfig
import os
import argparse
import os
from transformers import AutoConfig
from parler_tts import ParlerTTSDecoderConfig, ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration
if __name__ == "__main__":
......
import dac
from transformers import AutoConfig, AutoModel, EncodecFeatureExtractor
from parler_tts import DACConfig, DACModel
from transformers import AutoConfig, AutoModel
from transformers import EncodecFeatureExtractor
......
from transformers import AutoFeatureExtractor, AutoTokenizer
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer, AutoFeatureExtractor
path = "TODO"
repo_id = "parler_tts_600M"
......
__version__ = "0.1"
from transformers import AutoConfig, AutoModel
from .configuration_parler_tts import ParlerTTSConfig, ParlerTTSDecoderConfig
from .dac_wrapper import DACConfig, DACModel
from .modeling_parler_tts import (
ParlerTTSForCausalLM,
ParlerTTSForConditionalGeneration,
......@@ -9,8 +12,6 @@ from .modeling_parler_tts import (
build_delay_pattern_mask,
)
from .dac_wrapper import DACConfig, DACModel
from transformers import AutoConfig, AutoModel
AutoConfig.register("dac", DACConfig)
AutoModel.register(DACConfig, DACModel)
from transformers import PretrainedConfig
from typing import List
class DACConfig(PretrainedConfig):
......
import torch
from dac.model import DAC
from transformers import PreTrainedModel
from transformers.models.encodec.modeling_encodec import EncodecEncoderOutput, EncodecDecoderOutput
from .configuration_dac import DACConfig
from transformers.models.encodec.modeling_encodec import EncodecDecoderOutput, EncodecEncoderOutput
from dac.model import DAC
from .configuration_dac import DACConfig
# model doesn't support batching yet
......@@ -134,4 +133,4 @@ class DACModel(PreTrainedModel):
return EncodecDecoderOutput(audio_values)
def forward(self, tensor):
raise ValueError(f"`DACModel.forward` not implemented yet")
raise ValueError("`DACModel.forward` not implemented yet")
......@@ -18,21 +18,28 @@ import inspect
import math
import random
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union, List
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers import AutoConfig, AutoModel, AutoModelForTextEncoding
from transformers.activations import ACT2FN
from transformers.cache_utils import (
Cache,
DynamicCache,
EncoderDecoderCache,
SlidingWindowCache,
StaticCache,
)
from transformers.generation.configuration_utils import GenerationConfig
from transformers.generation.logits_process import ClassifierFreeGuidanceLogitsProcessor, LogitsProcessorList
from transformers.generation.logits_process import LogitsProcessorList
from transformers.generation.stopping_criteria import StoppingCriteriaList
from transformers.modeling_attn_mask_utils import (
AttentionMaskConverter,
_prepare_4d_attention_mask,
_prepare_4d_causal_attention_mask,
_prepare_4d_attention_mask_for_sdpa,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.modeling_outputs import (
BaseModelOutput,
......@@ -48,13 +55,11 @@ from transformers.utils import (
logging,
replace_return_docstrings,
)
import torch.nn.functional as F
from transformers.utils.import_utils import is_flash_attn_greater_or_equal_2_10, is_flash_attn_2_available
from transformers.utils.import_utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10
from .configuration_parler_tts import ParlerTTSConfig, ParlerTTSDecoderConfig
from .dac_wrapper import DACConfig, DACModel
from transformers import AutoConfig, AutoModel
AutoConfig.register("dac", DACConfig)
AutoModel.register(DACConfig, DACModel)
......@@ -80,6 +85,9 @@ MUSICGEN_PRETRAINED_MODEL_ARCHIVE_LIST = [
]
NEED_SETUP_CACHE_CLASSES_MAPPING = {"static": StaticCache, "sliding_window": SlidingWindowCache}
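As a hedged illustration only: the mapping above is the kind of lookup a `_get_cache`-style helper (added by this commit but collapsed out of this hunk) would consult to allocate a fixed-shape cache. The config values and sizes below are placeholders:

```python
import torch
from parler_tts import ParlerTTSDecoderConfig
from parler_tts.modeling_parler_tts import NEED_SETUP_CACHE_CLASSES_MAPPING  # module shown in this diff

# Hypothetical sketch: allocate a StaticCache for a default decoder config.
decoder_config = ParlerTTSDecoderConfig()
cache_cls = NEED_SETUP_CACHE_CLASSES_MAPPING["static"]  # -> StaticCache
past_key_values = cache_cls(
    config=decoder_config,
    max_batch_size=1,      # placeholder batch size
    max_cache_len=2580,    # placeholder maximum generation length
    device="cpu",
    dtype=torch.float32,
)
```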
def apply_delay_pattern_mask(input_ids, decoder_pad_token_mask):
"""Apply a delay pattern mask to the decoder input ids, only preserving predictions where
the mask is set to -1, and otherwise setting to the value detailed in the mask."""
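For intuition, the behaviour described in this docstring boils down to a `torch.where`; a toy illustration (not the function body, which is collapsed out of this hunk):

```python
import torch

# Positions where the delay mask is -1 keep the model's prediction; every
# other position is overwritten with the mask value (e.g. the pad token).
input_ids = torch.tensor([[11, 12, 13, 14]])
delay_mask = torch.tensor([[1024, -1, -1, 1024]])
print(torch.where(delay_mask == -1, input_ids, delay_mask))
# tensor([[1024,   12,   13, 1024]])
```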
......@@ -176,14 +184,10 @@ class ParlerTTSUnconditionalInput(ModelOutput):
attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Encoder attention mask to avoid performing attention on padding token indices. Mask values selected in `[0,
1]`: 1 for tokens that are **not masked**, 0 for tokens that are **masked**.
guidance_scale (`float`, *optional*):
Guidance scale for classifier free guidance, setting the balance between the conditional logits (predicted
from the prompts) and the unconditional logits (predicted without prompts).
"""
encoder_outputs: Tuple[torch.FloatTensor] = None
attention_mask: torch.LongTensor = None
guidance_scale: float = None
# Copied from transformers.models.encoder_decoder.modeling_encoder_decoder.shift_tokens_right
......@@ -244,7 +248,7 @@ class ParlerTTSSinusoidalPositionalEmbedding(nn.Module):
def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
bsz, seq_len, _ = input_ids.size()
# Create the position ids from the input token ids.
position_ids = (torch.arange(seq_len) + past_key_values_length).to(input_ids.device)
position_ids = torch.arange(seq_len, device=input_ids.device) + past_key_values_length
# expand embeddings if needed
if seq_len > self.weights.size(0):
self.make_weights(seq_len + self.offset, self.embedding_dim)
......@@ -331,6 +335,7 @@ class ParlerTTSAttention(nn.Module):
bias: bool = True,
is_causal: bool = False,
rope_embeddings: bool = False,
layer_idx: Optional[int] = None,
config: Optional[ParlerTTSDecoderConfig] = None,
):
super().__init__()
......@@ -351,6 +356,14 @@ class ParlerTTSAttention(nn.Module):
self.is_decoder = is_decoder
self.is_causal = is_causal
if layer_idx is None and is_decoder:
logger.warning_once(
f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
"will lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
self.layer_idx = layer_idx
self.k_proj = nn.Linear(embed_dim, self.num_key_value_heads * self.head_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, self.num_key_value_heads * self.head_dim, bias=bias)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
......@@ -368,40 +381,42 @@ class ParlerTTSAttention(nn.Module):
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
past_key_value: Optional[EncoderDecoderCache] = None,
attention_mask: Optional[torch.Tensor] = None,
cos: Optional[torch.LongTensor] = None,
sin: Optional[torch.LongTensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len = hidden_states.shape[:2]
# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
query_states = self._shape_query(query_states, tgt_len, bsz)
if self.rope_embeddings:
query_states = apply_rotary_pos_emb(query_states, cos, sin)
current_states = key_value_states if key_value_states is not None else hidden_states
if past_key_value is not None:
is_updated = past_key_value.is_updated.get(self.layer_idx)
if is_cross_attention:
# after the first generated id, we can subsequently re-use all key/value_states from cache
past_key_value.is_updated[self.layer_idx] = True
past_key_value = past_key_value.cross_attention_cache
else:
past_key_value = past_key_value.self_attention_cache
# checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# use key_value_states if cross attention
current_states = key_value_states if key_value_states is not None else hidden_states
if is_cross_attention and past_key_value and is_updated:
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
key_states = past_key_value.key_cache[self.layer_idx]
value_states = past_key_value.value_cache[self.layer_idx]
else:
key_states = self._shape_key_value(self.k_proj(current_states), -1, bsz)
value_states = self._shape_key_value(self.v_proj(current_states), -1, bsz)
......@@ -410,45 +425,21 @@ class ParlerTTSAttention(nn.Module):
# cached key states already have rope applied - only apply to new state
key_states = apply_rotary_pos_emb(key_states, cos, sin) if self.rope_embeddings else key_states
if past_key_value is not None:
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
if past_key_value is not None:
# save all key/value_states to cache to be re-used for fast auto-regressive generation
cache_position = cache_position if not is_cross_attention else None
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
)
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = query_states.reshape(*proj_shape)
key_states = key_states.reshape(*proj_shape)
value_states = value_states.reshape(*proj_shape)
src_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
......@@ -458,35 +449,25 @@ class ParlerTTSAttention(nn.Module):
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
f" {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if output_attentions:
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
else:
attn_weights_reshaped = None
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.matmul(attn_probs, value_states)
attn_output = torch.bmm(attn_probs, value_states)
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned across GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights_reshaped, past_key_value
return attn_output, attn_weights, past_key_value
def _get_unpad_data(attention_mask):
......@@ -522,64 +503,66 @@ class ParlerTTSFlashAttention2(ParlerTTSAttention):
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
past_key_value: Optional[EncoderDecoderCache] = None,
attention_mask: Optional[torch.Tensor] = None,
cos: Optional[torch.LongTensor] = None,
sin: Optional[torch.LongTensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# ParlerTTSFlashAttention2 attention does not support output_attentions
if output_attentions:
raise ValueError("ParlerTTSFlashAttention2 attention does not support output_attentions")
if isinstance(past_key_value, StaticCache):
raise ValueError(
"The `static` cache implementation is not compatible with `attn_implementation='flash_attention_2'`. "
"Use `attn_implementation='sdpa'` in the meantime, and open an issue at https://github.com/huggingface/transformers"
)
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, q_len = hidden_states.shape[:2]
bsz, tgt_len = hidden_states.shape[:2]
# get query proj
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
query_states = self.q_proj(hidden_states).view(bsz, tgt_len, self.num_heads, self.head_dim)
if self.rope_embeddings:
query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2)
current_states = key_value_states if key_value_states is not None else hidden_states
if past_key_value is not None:
is_updated = past_key_value.is_updated.get(self.layer_idx)
if is_cross_attention:
# after the first generated id, we can subsequently re-use all key/value_states from cache
past_key_value.is_updated[self.layer_idx] = True
past_key_value = past_key_value.cross_attention_cache
else:
past_key_value = past_key_value.self_attention_cache
# checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# use key_value_states if cross attention
current_states = key_value_states if key_value_states is not None else hidden_states
if is_cross_attention and past_key_value and is_updated:
# reuse k,v, cross_attentions
key_states = past_key_value[0].transpose(1, 2)
value_states = past_key_value[1].transpose(1, 2)
key_states = past_key_value.key_cache[self.layer_idx]
value_states = past_key_value.value_cache[self.layer_idx]
else:
key_states = self.k_proj(current_states).view(bsz, -1, self.num_key_value_heads, self.head_dim)
value_states = self.v_proj(current_states).view(bsz, -1, self.num_key_value_heads, self.head_dim)
key_states = self._shape_key_value(self.k_proj(current_states), -1, bsz)
value_states = self._shape_key_value(self.v_proj(current_states), -1, bsz)
if not is_cross_attention:
if not is_cross_attention and self.rope_embeddings:
# cached key states already have rope applied - only apply to new state
key_states = (
apply_rotary_pos_emb(key_states, cos, sin, unsqueeze_dim=2) if self.rope_embeddings else key_states
)
key_states = apply_rotary_pos_emb(key_states, cos, sin)
if past_key_value is not None:
key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
if past_key_value is not None:
# save all key/value_states to cache to be re-used for fast auto-regressive generation
cache_position = cache_position if not is_cross_attention else None
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
# TODO: These transposes are quite inefficient, but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim].
# We would need to refactor the KV cache to be able to avoid many of these transpose/reshape/view.
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
......@@ -607,10 +590,10 @@ class ParlerTTSFlashAttention2(ParlerTTSAttention):
value_states = value_states.to(target_dtype)
attn_output = self._flash_attention_forward(
query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout
query_states, key_states, value_states, attention_mask, tgt_len, dropout=self.dropout
)
attn_output = attn_output.reshape(bsz, q_len, -1)
attn_output = attn_output.reshape(bsz, tgt_len, -1)
attn_output = self.out_proj(attn_output)
if not output_attentions:
......@@ -723,12 +706,13 @@ class ParlerTTSSdpaAttention(ParlerTTSAttention):
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
past_key_value: Optional[EncoderDecoderCache] = None,
attention_mask: Optional[torch.Tensor] = None,
cos: Optional[torch.LongTensor] = None,
sin: Optional[torch.LongTensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
if output_attentions or layer_head_mask is not None:
......@@ -744,6 +728,7 @@ class ParlerTTSSdpaAttention(ParlerTTSAttention):
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
cache_position=cache_position,
)
# if key_value_states are provided this layer is used as a cross-attention layer
......@@ -759,39 +744,39 @@ class ParlerTTSSdpaAttention(ParlerTTSAttention):
if self.rope_embeddings:
query_states = apply_rotary_pos_emb(query_states, cos, sin)
current_states = key_value_states if key_value_states is not None else hidden_states
if past_key_value is not None:
is_updated = past_key_value.is_updated.get(self.layer_idx)
if is_cross_attention:
# after the first generated id, we can subsequently re-use all key/value_states from cache
past_key_value.is_updated[self.layer_idx] = True
past_key_value = past_key_value.cross_attention_cache
else:
past_key_value = past_key_value.self_attention_cache
# checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# use key_value_states if cross attention
current_states = key_value_states if key_value_states is not None else hidden_states
if is_cross_attention and past_key_value and is_updated:
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
key_states = past_key_value.key_cache[self.layer_idx]
value_states = past_key_value.value_cache[self.layer_idx]
else:
key_states = self._shape_key_value(self.k_proj(current_states), -1, bsz)
value_states = self._shape_key_value(self.v_proj(current_states), -1, bsz)
if not is_cross_attention:
if not is_cross_attention and self.rope_embeddings:
# cached key states already have rope applied - only apply to new state
key_states = apply_rotary_pos_emb(key_states, cos, sin) if self.rope_embeddings else key_states
key_states = apply_rotary_pos_emb(key_states, cos, sin)
if past_key_value is not None:
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
if past_key_value is not None:
# save all key/value_states to cache to be re-used for fast auto-regressive generation
cache_position = cache_position if not is_cross_attention else None
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
causal_mask = attention_mask
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
......@@ -800,7 +785,7 @@ class ParlerTTSSdpaAttention(ParlerTTSAttention):
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False
is_causal = True if self.is_causal and causal_mask is None and tgt_len > 1 else False
# NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
# but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
......@@ -808,7 +793,7 @@ class ParlerTTSSdpaAttention(ParlerTTSAttention):
query_states,
key_states,
value_states,
attn_mask=attention_mask,
attn_mask=causal_mask,
dropout_p=self.dropout if self.training else 0.0,
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
is_causal=is_causal,
......@@ -839,7 +824,7 @@ PARLERTTS_ATTENTION_CLASSES = {
class ParlerTTSDecoderLayer(nn.Module):
def __init__(self, config: ParlerTTSDecoderConfig):
def __init__(self, config: ParlerTTSDecoderConfig, layer_idx: int = None):
super().__init__()
self.embed_dim = config.hidden_size
......@@ -852,6 +837,7 @@ class ParlerTTSDecoderLayer(nn.Module):
is_causal=True,
bias=False,
rope_embeddings=config.rope_embeddings,
layer_idx=layer_idx,
config=config,
)
self.dropout = config.dropout
......@@ -872,6 +858,7 @@ class ParlerTTSDecoderLayer(nn.Module):
is_decoder=True,
bias=False,
rope_embeddings=config.rope_embeddings,
layer_idx=layer_idx,
config=config,
)
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
......@@ -889,9 +876,10 @@ class ParlerTTSDecoderLayer(nn.Module):
encoder_attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
past_key_value: Optional[EncoderDecoderCache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = True,
cache_position: Optional[torch.LongTensor] = None,
) -> torch.Tensor:
"""
Args:
......@@ -918,31 +906,24 @@ class ParlerTTSDecoderLayer(nn.Module):
hidden_states = self.self_attn_layer_norm(hidden_states)
# Self Attention
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
# add present self-attn cache to positions 1,2 of present_key_value tuple
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
past_key_value=self_attn_past_key_value,
past_key_value=past_key_value,
attention_mask=attention_mask,
cos=cos,
sin=sin,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
cache_position=cache_position,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
# Cross-Attention Block
cross_attn_present_key_value = None
cross_attn_weights = None
if encoder_hidden_states is not None:
residual = hidden_states
hidden_states = self.encoder_attn_layer_norm(hidden_states)
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
......@@ -950,14 +931,14 @@ class ParlerTTSDecoderLayer(nn.Module):
cos=cos,
sin=sin,
layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value,
past_key_value=past_key_value,
output_attentions=output_attentions,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
# add cross-attn to positions 3,4 of present_key_value tuple
present_key_value = present_key_value + cross_attn_present_key_value
# add cross-attn to position 1 of present_key_value tuple
present_key_value = (present_key_value, cross_attn_present_key_value)
# Fully Connected
residual = hidden_states
......@@ -992,6 +973,8 @@ class ParlerTTSPreTrainedModel(PreTrainedModel):
_supports_flash_attn_2 = True
_supports_sdpa = True
_no_split_modules = ["ParlerTTSDecoderLayer", "ParlerTTSAttention"]
_supports_cache_class = True
_supports_static_cache = True
def _init_weights(self, module):
std = self.config.initializer_factor
......@@ -1088,14 +1071,18 @@ MUSICGEN_INPUTS_DOCSTRING = r"""
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
TODO: it's passed through enc_to_dec_proj and optionally concatenated with the prompt hidden states in certain cases.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
past_key_values (`EncoderDecoderCache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
Pre-computed hidden-states that can be used to speed up auto-regressive (sequential) decoding. There are
four sets of pre-computed hidden-states: key and value states in the self-attention blocks (2) and
in the cross-attention blocks (2). The `past_key_values` are returned when `use_cache=True` is passed or
when `config.use_cache=True`.
Two formats are allowed:
- An [`~cache_utils.EncoderDecoderCache`] instance;
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
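A short, hedged sketch of the two cache-related inputs documented in this docstring (`past_key_values` above, `cache_position` further down); it assumes Transformers >= 4.43 and is not code from this repository:

```python
import torch
from transformers.cache_utils import DynamicCache, EncoderDecoderCache

# One cache for decoder self-attention, one for cross-attention over the
# encoder states, as accepted by `past_key_values`.
past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())

# `cache_position` indexes only the new tokens. For a prefill of 8 tokens
# followed by one single-token decoding step, the successive values would be:
prefill_positions = torch.arange(0, 8)  # first forward pass
decode_position = torch.arange(8, 9)    # next generation step
```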
......@@ -1141,6 +1128,9 @@ MUSICGEN_INPUTS_DOCSTRING = r"""
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. It is used to update the cache
in the correct position and to infer the complete sequence length.
"""
MUSICGEN_DECODER_INPUTS_DOCSTRING = r"""
......@@ -1263,8 +1253,9 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel):
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
)
self.layers = nn.ModuleList([ParlerTTSDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self.layers = nn.ModuleList(
[ParlerTTSDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.layer_norm = nn.LayerNorm(config.hidden_size)
self.attn_implementation = config._attn_implementation
encoder_attn_implementation = config._attn_implementation
......@@ -1301,6 +1292,7 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position=None,
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
......@@ -1323,16 +1315,44 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel):
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if inputs_embeds is None:
inputs_embeds = sum([self.embed_tokens[codebook](input[:, codebook]) for codebook in range(num_codebooks)])
prepended_sequence_length = 0
# if prompt_hidden_states, fuse to inputs_embeds and update input shape
if prompt_hidden_states is not None:
prepended_sequence_length = prompt_hidden_states.shape[-2]
inputs_embeds = torch.cat([prompt_hidden_states, inputs_embeds], dim=1)
return_legacy_cache = False
return_self_attention_cache = False
if use_cache or past_key_values is not None:
if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache):
return_self_attention_cache = True
past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())
elif not isinstance(past_key_values, EncoderDecoderCache):
return_legacy_cache = True
logger.warning_once(
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. "
"You should pass an instance of `EncoderDecoderCache` instead, e.g. "
"`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
)
past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
past_key_values_length = 0
if cache_position is not None:
past_key_values_length = cache_position[0]
elif past_key_values is not None:
past_key_values_length = past_key_values.get_seq_length()
if cache_position is None:
cache_position = torch.arange(
past_key_values_length, past_key_values_length + input_shape[1] + prepended_sequence_length, device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
# NOTE: 1. As it is, the masked ids from the prompt will still count in the position embeddings
# NOTE: 2. we want to concatenate the prompt attention mask and the decoder attention mask
# iff `prompt_cross_attention=False`. ParlerTTSForConditionalGeneration takes care of setting
......@@ -1343,7 +1363,7 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel):
logger.warning_once(
"`prompt_attention_mask` is specified but `attention_mask` is not. A full `attention_mask` will be created. Make sure this is the intended behaviour."
)
if past_key_values is None:
if past_key_values_length == 0:
attention_mask = torch.cat(
[
prompt_attention_mask,
......@@ -1401,23 +1421,14 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel):
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
if self.attn_implementation == "flash_attention_2":
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self.attn_implementation == "sdpa" and head_mask is None and not output_attentions:
# output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
input_shape,
inputs_embeds,
past_key_values_length,
)
else:
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)
causal_mask = self._update_causal_mask(
attention_mask,
inputs_embeds,
cache_position,
past_key_values.self_attention_cache if past_key_values is not None else None,
output_attentions,
)
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
if self.encoder_attn_implementation == "flash_attention_2":
encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
......@@ -1442,12 +1453,10 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel):
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
next_decoder_cache = () if use_cache else None
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
......@@ -1465,13 +1474,11 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel):
if self.training and (dropout_probability < self.layerdrop):
continue
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.forward,
hidden_states,
attention_mask,
causal_mask,
cos,
sin,
encoder_hidden_states,
......@@ -1481,11 +1488,12 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel):
None,
output_attentions,
use_cache,
cache_position,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
attention_mask=causal_mask,
cos=cos,
sin=sin,
encoder_hidden_states=encoder_hidden_states,
......@@ -1494,15 +1502,13 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel):
cross_attn_layer_head_mask=(
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
),
past_key_value=past_key_value,
past_key_value=past_key_values if use_cache else None,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
......@@ -1515,7 +1521,11 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
next_cache = past_key_values if use_cache else None
if return_self_attention_cache:
next_cache = past_key_values.self_attention_cache
if return_legacy_cache:
next_cache = past_key_values.to_legacy_cache()
if not return_dict:
return tuple(
v
......@@ -1530,6 +1540,87 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel):
cross_attentions=all_cross_attentions,
)
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool,
):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode step due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
is_training=self.training,
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_length()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
if attention_mask is not None and attention_mask.dim() == 4:
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
if attention_mask.max() != 0:
raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0")
causal_mask = attention_mask
else:
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type == "cuda"
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
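For intuition, the mask recipe in `_update_causal_mask` can be reproduced standalone; a minimal sketch using the same operations (toy sizes, not part of the model code):

```python
import torch

# Toy reproduction of the causal-mask construction above: 3 new tokens,
# a target length of 5 (2 tokens already sit in the cache).
dtype, device = torch.float32, "cpu"
min_dtype = torch.finfo(dtype).min
sequence_length, target_length = 3, 5
cache_position = torch.arange(2, 5)

causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
# Entries equal to 0 are attended to; entries equal to min_dtype stay masked.
print(causal_mask)
```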
@add_start_docstrings(
"The bare ParlerTTS decoder model outputting raw hidden-states without any specific head on top.",
......@@ -1564,12 +1655,13 @@ class ParlerTTSModel(ParlerTTSPreTrainedModel):
prompt_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Union[EncoderDecoderCache, Tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
......@@ -1595,6 +1687,7 @@ class ParlerTTSModel(ParlerTTSPreTrainedModel):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
if not return_dict:
......@@ -1665,6 +1758,7 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*):
......@@ -1692,6 +1786,7 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
hidden_states = outputs[0]
......@@ -1752,7 +1847,8 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
past_key_values=None,
use_cache=True,
delay_pattern_mask=None,
guidance_scale=None,
cache_position=None,
inputs_embeds=None,
**kwargs,
):
if delay_pattern_mask is None:
......@@ -1766,22 +1862,11 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
# apply the delay pattern mask
input_ids = self.apply_delay_pattern_mask(input_ids, delay_pattern_mask)
if guidance_scale is not None and guidance_scale > 1:
# for classifier free guidance we need to replicate the decoder args across the batch dim (we'll split these
# before sampling)
input_ids = input_ids.repeat((2, 1))
if attention_mask is not None:
attention_mask = attention_mask.repeat((2, 1))
if prompt_hidden_states is not None:
prompt_hidden_states = torch.concatenate(
[prompt_hidden_states, torch.zeros_like(prompt_hidden_states)], dim=0
)
if prompt_attention_mask is not None:
prompt_attention_mask = torch.concatenate(
[prompt_attention_mask, torch.zeros_like(prompt_attention_mask)], dim=0
)
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
......@@ -1798,7 +1883,7 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
prompt_hidden_states = None
return {
"input_ids": input_ids,
"input_ids": input_ids.contiguous(), # `contiguous()` needed for compilation use cases
"attention_mask": attention_mask,
"position_ids": position_ids,
"encoder_hidden_states": encoder_hidden_states,
......@@ -1809,6 +1894,8 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
"cross_attn_head_mask": cross_attn_head_mask,
"past_key_values": past_key_values,
"use_cache": use_cache,
"cache_position": cache_position,
"inputs_embeds": inputs_embeds,
}
# Ignore copy
......@@ -1948,10 +2035,11 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
inputs, generation_config.bos_token_id, model_kwargs
)
batch_size = input_ids.shape[0] // self.num_codebooks
kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=input_ids.device)
# 4. Define other model kwargs
model_kwargs["use_cache"] = generation_config.use_cache
model_kwargs["guidance_scale"] = generation_config.guidance_scale
requires_attention_mask = "encoder_outputs" not in model_kwargs
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask:
......@@ -2016,21 +2104,17 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
and generation_config.do_sample is True
)
# 8. prepare batched CFG externally (to enable coexistance with the unbatched CFG)
if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1:
logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale))
generation_config.guidance_scale = None
# 9. prepare distribution pre_processing samplers
# 8. prepare distribution pre_processing samplers
logits_processor = self._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_seq_length,
encoder_input_ids=input_ids,
prefix_allowed_tokens_fn=None,
logits_processor=logits_processor,
device=input_ids.device,
)
# 10. prepare stopping criteria
# 9. prepare stopping criteria
stopping_criteria = self._get_stopping_criteria(
generation_config=generation_config, stopping_criteria=stopping_criteria
)
......@@ -2042,8 +2126,8 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
f"but is {generation_config.num_return_sequences}."
)
# 11. run greedy search
outputs = self._greedy_search(
# 10. run greedy search
outputs = self._sample(
input_ids,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
......@@ -2054,8 +2138,8 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
)
elif is_sample_gen_mode:
# 11. prepare logits warper
logits_warper = self._get_logits_warper(generation_config)
# 10. prepare logits warper
logits_warper = self._get_logits_warper(generation_config, device=input_ids.device)
# expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
......@@ -2064,7 +2148,7 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
**model_kwargs,
)
# 12. run sample
# 11. run sample
outputs = self._sample(
input_ids,
logits_processor=logits_processor,
......@@ -2120,6 +2204,8 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
supports_gradient_checkpointing = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_static_cache = True
def __init__(
self,
......@@ -2498,7 +2584,7 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.BoolTensor] = None,
encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
past_key_values: Tuple[Tuple[torch.FloatTensor]] = None,
past_key_values: Optional[Union[EncoderDecoderCache, Tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
prompt_input_ids: Optional[torch.FloatTensor] = None,
......@@ -2510,6 +2596,7 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Union[Tuple, Seq2SeqLMOutput]:
r"""
......@@ -2653,6 +2740,7 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
past_key_values=past_key_values,
return_dict=return_dict,
labels=labels,
cache_position=cache_position,
**kwargs_decoder,
)
......@@ -2685,7 +2773,8 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
use_cache=None,
encoder_outputs=None,
decoder_delay_pattern_mask=None,
guidance_scale=None,
cache_position=None,
inputs_embeds=None,
**kwargs,
):
if decoder_delay_pattern_mask is None:
......@@ -2699,19 +2788,17 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
# apply the delay pattern mask
decoder_input_ids = self.decoder.apply_delay_pattern_mask(decoder_input_ids, decoder_delay_pattern_mask)
if guidance_scale is not None and guidance_scale > 1:
# for classifier free guidance we need to replicate the decoder args across the batch dim (we'll split these
# before sampling)
decoder_input_ids = decoder_input_ids.repeat((2, 1))
if decoder_attention_mask is not None:
decoder_attention_mask = decoder_attention_mask.repeat((2, 1))
if prompt_hidden_states is not None:
prompt_hidden_states = prompt_hidden_states.repeat((2, 1, 1))
if prompt_attention_mask is not None:
prompt_attention_mask = prompt_attention_mask.repeat((2, 1))
past_length = 0
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
if isinstance(past_key_values, EncoderDecoderCache):
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
if past_key_values.get_seq_length() > 0:
# we only want to use prompt signal in the 1st generation step
prompt_hidden_states = None
else:
past_length = past_key_values[0][0].shape[2]
# we only want to use prompt signal in the 1st generation step
prompt_hidden_states = None
# Some generation methods already pass only the last input ID
if decoder_input_ids.shape[1] > past_length:
......@@ -2722,15 +2809,50 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
# if prompt_cross_attention,
# we only want to use prompt signal in the 1st generation step
prompt_hidden_states = None
if cache_position is None:
cache_position = torch.arange(
past_length, past_length + decoder_input_ids.shape[1], device=decoder_input_ids.device
)
elif use_cache:
cur_len = decoder_input_ids.shape[1]
if prompt_hidden_states is not None and not self.prompt_cross_attention:
# meaning we are in the first generation step and prompt_hidden_states will be prepended
cur_len += prompt_hidden_states.shape[1]
cache_position = cache_position[-cur_len:]
if decoder_attention_mask is None and prompt_attention_mask is not None:
input = decoder_input_ids.reshape(-1, self.decoder.num_codebooks, decoder_input_ids.shape[-1])
bsz, _, seq_len = input.shape
input_shape = (bsz, seq_len)
past_key_values_length = 0
if cache_position is not None:
past_key_values_length = cache_position[0]
elif past_key_values is not None:
past_key_values_length = past_key_values.get_seq_length()
logger.warning_once(
"`prompt_attention_mask` is specified but `decoder_attention_mask` is not. A full `decoder_attention_mask` will be created. Make sure this is the intended behaviour."
)
if past_key_values is None or (
isinstance(past_key_values, EncoderDecoderCache) and past_key_values.get_seq_length() == 0
):
decoder_attention_mask = torch.ones(input_shape, device=self.device, dtype=decoder_input_ids.dtype)
elif prompt_attention_mask is not None:
# In the generation case where the prompt is prepended to the decoder inputs (`prompt_cross_attention=False`),
# we need to recreate the attention mask from scratch so that the prompt attention mask can be prepended.
# Since we generate token by token, we can recompute the generated length from the information we have.
generated_length = past_key_values_length - prompt_attention_mask.shape[1] + 1
decoder_attention_mask = torch.ones(
(input_shape[0], generated_length), device=self.device, dtype=prompt_attention_mask.dtype
)
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"decoder_input_ids": decoder_input_ids.contiguous(),
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask,
......@@ -2739,6 +2861,8 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
"prompt_hidden_states": prompt_hidden_states,
"prompt_attention_mask": prompt_attention_mask,
"use_cache": use_cache,
"cache_position": cache_position,
"inputs_embeds": inputs_embeds,
}
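For intuition, here is a minimal, self-contained sketch of the `cache_position` bookkeeping the method above relies on (the tensors and sizes are made up for illustration, not taken from the model): the first call covers the whole prefix written into the cache, and each later call carries a single, incrementing position.

import torch

# prefill: the prefix (e.g. prompt embeddings plus the first decoder frame) occupies positions 0..L-1
prefix_len = 4
cache_position = torch.arange(0, prefix_len)      # tensor([0, 1, 2, 3])

# decode: one new token per step, so cache_position advances by one position per call
for _ in range(3):
    cache_position = cache_position[-1:] + 1      # tensor([4]), then tensor([5]), then tensor([6])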
def _prepare_decoder_input_ids_for_generation(
......@@ -2786,6 +2910,19 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
)
model_kwargs["decoder_attention_mask"] = decoder_attention_mask
if not self.prompt_cross_attention:
prompt_hidden_states = model_kwargs["prompt_hidden_states"]
num_codebooks = self.decoder.num_codebooks
input = decoder_input_ids.reshape(-1, num_codebooks, decoder_input_ids.shape[-1])
inputs_embeds = sum(
[
self.decoder.model.decoder.embed_tokens[codebook](input[:, codebook])
for codebook in range(num_codebooks)
]
)
inputs_embeds = torch.cat([prompt_hidden_states, inputs_embeds], dim=1)
model_kwargs["inputs_embeds"] = inputs_embeds
return decoder_input_ids, model_kwargs
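As a rough illustration of the block above (all sizes and modules below are hypothetical stand-ins, not the real model weights), the per-codebook token embeddings are summed over codebooks and the prompt hidden states are prepended along the time axis:

import torch
import torch.nn as nn

num_codebooks, vocab_size, hidden_size = 4, 1088, 16   # toy sizes
embed_tokens = nn.ModuleList([nn.Embedding(vocab_size, hidden_size) for _ in range(num_codebooks)])

bsz, seq_len, prompt_len = 2, 1, 5
codes = torch.randint(0, vocab_size, (bsz, num_codebooks, seq_len))
prompt_hidden_states = torch.randn(bsz, prompt_len, hidden_size)

# sum the embeddings of every codebook at each timestep ...
inputs_embeds = sum(embed_tokens[k](codes[:, k]) for k in range(num_codebooks))
# ... then prepend the prompt so that it lands in the self-attention cache from the very first step
inputs_embeds = torch.cat([prompt_hidden_states, inputs_embeds], dim=1)   # (bsz, prompt_len + seq_len, hidden_size)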
def _prepare_text_encoder_kwargs_for_generation(
......@@ -2817,7 +2954,6 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
}
encoder_kwargs["output_attentions"] = generation_config.output_attentions
encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states
guidance_scale = generation_config.guidance_scale
# 3. make sure that encoder returns `ModelOutput`
model_input_name = model_input_name if model_input_name is not None else self.text_encoder.main_input_name
......@@ -2825,14 +2961,6 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
encoder_kwargs[model_input_name] = inputs_tensor
last_hidden_state = encoder(**encoder_kwargs).last_hidden_state
# for classifier free guidance we need to add a 'null' input to our encoder hidden states
if guidance_scale is not None and guidance_scale > 1:
last_hidden_state = torch.concatenate([last_hidden_state, torch.zeros_like(last_hidden_state)], dim=0)
if "attention_mask" in model_kwargs:
model_kwargs["attention_mask"] = torch.concatenate(
[model_kwargs["attention_mask"], torch.zeros_like(model_kwargs["attention_mask"])], dim=0
)
# we optionally project last_hidden_state to avoid recomputing it every time
encoder_hidden_states = last_hidden_state
if (
......@@ -2933,7 +3061,9 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
return model_kwargs
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id).transpose(1, 2)
return shift_tokens_right(
labels, self.config.decoder.pad_token_id, self.config.decoder.bos_token_id
).transpose(1, 2)
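For reference, a simplified 2-D version of what `shift_tokens_right` does here (the real labels carry an extra codebook dimension, hence the `.transpose(1, 2)` above; this sketch follows the usual Hugging Face pattern and is not copied from this repository):

import torch

def shift_tokens_right_2d(labels: torch.Tensor, pad_token_id: int, decoder_start_token_id: int) -> torch.Tensor:
    shifted = labels.new_zeros(labels.shape)
    shifted[:, 1:] = labels[:, :-1].clone()
    shifted[:, 0] = decoder_start_token_id
    # masked label positions (-100) must not be fed to the decoder
    shifted.masked_fill_(shifted == -100, pad_token_id)
    return shifted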
def resize_token_embeddings(self, *args, **kwargs):
raise NotImplementedError(
......@@ -2970,6 +3100,81 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
break
return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id
def _get_decoder_start_token_id(
self, decoder_start_token_id: Union[int, List[int]] = None, bos_token_id: int = None
) -> int:
decoder_start_token_id = (
decoder_start_token_id
if decoder_start_token_id is not None
else self.generation_config.decoder_start_token_id
)
bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id
if decoder_start_token_id is not None:
return decoder_start_token_id
elif bos_token_id is not None:
return bos_token_id
raise ValueError(
"`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
)
def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_len: int, model_kwargs) -> Cache:
"""
Sets a cache for `generate` that will persist across calls. A new cache will only be initialized if a
new `generate` call requires a larger cache.
Returns the resulting cache object.
"""
cache_cls: Cache = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation]
requires_cross_attention_cache = (
self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
)
if hasattr(self, "_cache"):
cache_to_check = self._cache.self_attention_cache if requires_cross_attention_cache else self._cache
if cache_implementation == "sliding_window":
max_cache_len = min(self.config.sliding_window, max_cache_len)
need_new_cache = (
not hasattr(self, "_cache")
or (not isinstance(cache_to_check, cache_cls))
or cache_to_check.max_batch_size != max_batch_size
or cache_to_check.max_cache_len < max_cache_len
)
if requires_cross_attention_cache and hasattr(self, "_cache"):
need_new_cache = (
need_new_cache
or self._cache.cross_attention_cache.max_cache_len != model_kwargs["encoder_outputs"][0].shape[1]
)
if need_new_cache:
if hasattr(self.config, "_pre_quantization_dtype"):
cache_dtype = self.config._pre_quantization_dtype
else:
cache_dtype = self.dtype
cache_kwargs = {
"config": self.config.decoder,
"max_batch_size": max_batch_size,
"max_cache_len": max_cache_len,
"device": self.device,
"dtype": cache_dtype,
}
self._cache = cache_cls(**cache_kwargs)
if requires_cross_attention_cache:
encoder_kwargs = cache_kwargs.copy()
encoder_kwargs["max_cache_len"] = model_kwargs["encoder_outputs"][0].shape[1]
config_cross_attention_cache = copy.deepcopy(self.config.decoder)
config_cross_attention_cache.update(
{"num_key_value_heads": self.config.decoder.num_cross_attention_key_value_heads}
)
encoder_kwargs["config"] = config_cross_attention_cache
self._cache = EncoderDecoderCache(self._cache, cache_cls(**encoder_kwargs))
else:
self._cache.reset()
return self._cache
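A hedged usage sketch of how a caller could opt into this static-cache path (the checkpoint name is only an example and the exact inputs depend on the tokenizer set-up; `cache_implementation` is the standard `transformers` generation option):

from transformers import AutoTokenizer
from parler_tts import ParlerTTSForConditionalGeneration

model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler_tts_mini_v0.1")
tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler_tts_mini_v0.1")

# request the static cache; _get_cache() then builds (or reuses) an EncoderDecoderCache of static caches
model.generation_config.cache_implementation = "static"

description_ids = tokenizer("A calm female speaker.", return_tensors="pt").input_ids
prompt_ids = tokenizer("Hello world.", return_tensors="pt").input_ids
audio = model.generate(input_ids=description_ids, prompt_input_ids=prompt_ids)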
def freeze_encoders(self, freeze_text_encoder=True):
if freeze_text_encoder:
for param in self.text_encoder.parameters():
......@@ -3070,46 +3275,28 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
if model_kwargs.get("attention_mask", None) is None:
logger.warning(
"The attention mask and the pad token id were not set. As a consequence, you may observe "
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
)
eos_token_id = generation_config.eos_token_id
if isinstance(eos_token_id, list):
eos_token_id = eos_token_id[0]
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
generation_config.pad_token_id = eos_token_id
requires_attention_mask = "encoder_outputs" not in model_kwargs
kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
# 3. Define model inputs
# inputs_tensor has to be defined
# model_input_name is defined if model-specific keyword input is passed
# otherwise model_input_name is None
# all model-specific keyword inputs are removed from `model_kwargs`
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
inputs, generation_config.bos_token_id, model_kwargs
)
batch_size = inputs_tensor.shape[0]
self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=inputs_tensor.device)
# 4. Define other model kwargs
model_kwargs["use_cache"] = generation_config.use_cache
model_kwargs["guidance_scale"] = generation_config.guidance_scale
requires_attention_mask = "encoder_outputs" not in model_kwargs
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor
)
if "encoder_outputs" not in model_kwargs:
# encoder_outputs are created and added to `model_kwargs`
model_kwargs = self._prepare_text_encoder_kwargs_for_generation(
inputs_tensor,
model_kwargs,
model_input_name,
generation_config,
inputs_tensor, model_kwargs, model_input_name, generation_config
)
if "prompt_hidden_states" not in model_kwargs and "prompt_input_ids" in model_kwargs:
......@@ -3130,46 +3317,80 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
batch_size=batch_size,
model_input_name=model_input_name,
model_kwargs=model_kwargs,
decoder_start_token_id=generation_config.decoder_start_token_id,
bos_token_id=generation_config.bos_token_id,
decoder_start_token_id=generation_config._decoder_start_token_tensor,
bos_token_id=generation_config._bos_token_tensor,
device=inputs_tensor.device,
)
# 6. Prepare `max_length` depending on other stopping criteria.
input_ids_seq_length = input_ids.shape[-1]
input_ids_length = input_ids.shape[-1]
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
if has_default_max_length and generation_config.max_new_tokens is None:
logger.warning(
f"Using the model-agnostic default `max_length` (={generation_config.max_length}) "
"to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation."
)
elif generation_config.max_new_tokens is not None:
if not has_default_max_length:
logger.warning(
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
"Please refer to the documentation for more information. "
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
)
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
generation_config = self._prepare_generated_length(
generation_config=generation_config,
has_default_max_length=has_default_max_length,
has_default_min_length=has_default_min_length,
model_input_name=model_input_name,
inputs_tensor=inputs_tensor,
input_ids_length=input_ids_length,
)
if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
if generation_config.cache_implementation is not None and model_kwargs.get("past_key_values") is not None:
raise ValueError(
f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than"
f" the maximum length ({generation_config.max_length})"
"Passing both `cache_implementation` (used to initialize certain caches) and `past_key_values` (a "
"Cache object) is unsupported. Please use only one of the two."
)
if input_ids_seq_length >= generation_config.max_length:
logger.warning(
f"Input length of decoder_input_ids is {input_ids_seq_length}, but `max_length` is set to"
f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
" increasing `max_new_tokens`."
elif generation_config.cache_implementation is not None:
if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
if generation_config.cache_implementation == "static" and not self._supports_static_cache:
raise ValueError(
"This model does not support `cache_implementation='static'`. Please check the following "
"issue: https://github.com/huggingface/transformers/issues/28981"
)
if not self.prompt_cross_attention:
# when we prepend prompt_hidden_states to inputs_embeds, max_cache_len needs to be updated:
# generation_config.max_length already includes input_ids_length, which is also counted in
# input_embeds_seq_length, so we subtract it to avoid counting it twice
input_embeds_seq_length = model_kwargs["inputs_embeds"].shape[1]
max_cache_len = generation_config.max_length + input_embeds_seq_length - input_ids_length
else:
max_cache_len = self.generation_config.max_length
model_kwargs["past_key_values"] = self._get_cache(
generation_config.cache_implementation,
getattr(generation_config, "num_beams", 1) * batch_size,
max_cache_len,
model_kwargs,
)
elif generation_config.cache_implementation == "quantized":
raise ValueError(
"This model does not support the quantized cache. If you want your model to support quantized "
"cache, please open an issue on the Parler-TTS repository https://github.com/huggingface/parler-tts"
)
# Use a DynamicCache() instance by default. This avoids converting back and forth to the legacy format,
# which keeps copying the cache and thus uses much more memory
elif generation_config.cache_implementation is None and self._supports_default_dynamic_cache():
past = model_kwargs.get("past_key_values", None)
requires_cross_attention_cache = (
self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
)
if past is None:
model_kwargs["past_key_values"] = (
DynamicCache()
if not requires_cross_attention_cache
else EncoderDecoderCache(DynamicCache(), DynamicCache())
)
elif isinstance(past, tuple):
model_kwargs["past_key_values"] = (
DynamicCache.from_legacy_cache(past)
if not requires_cross_attention_cache
else EncoderDecoderCache.from_legacy_cache(past)
)
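To make the default-cache branch above concrete, a small sketch of the two shapes it can produce (the class names are the real `transformers` ones; `legacy_past_key_values` is a placeholder):

from transformers.cache_utils import DynamicCache, EncoderDecoderCache

# decoder-only style: a single growing self-attention cache
past = DynamicCache()

# encoder-decoder style (Parler-TTS): separate self-attention and cross-attention caches
past = EncoderDecoderCache(DynamicCache(), DynamicCache())

# a legacy tuple-of-tuples cache, if one was passed in, can be converted the same way:
# past = EncoderDecoderCache.from_legacy_cache(legacy_past_key_values)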
# build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Parler-TTS)
input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
input_ids,
bos_token_id=generation_config.bos_token_id,
pad_token_id=generation_config.pad_token_id,
bos_token_id=generation_config._bos_token_tensor,
pad_token_id=generation_config._pad_token_tensor,
max_length=generation_config.max_length,
)
# stash the delay mask so that we don't have to recompute in each forward pass
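For intuition, a toy illustration of the per-codebook delay that `build_delay_pattern_mask` encodes (this is not the library implementation; it ignores the BOS/PAD token ids and the `max_length` handling):

import torch

num_codebooks, seq_len, filler = 3, 6, -1
codes = torch.arange(num_codebooks * seq_len).reshape(num_codebooks, seq_len)

delayed = torch.full_like(codes, filler)
for k in range(num_codebooks):
    # codebook k is shifted right by k frames, so frame t of codebook k is only
    # predicted once the earlier codebooks have already been emitted for that frame
    delayed[k, k:] = codes[k, : seq_len - k]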
......@@ -3191,21 +3412,17 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
and generation_config.do_sample is True
)
# 8. prepare batched CFG externally (to enable coexistence with the unbatched CFG)
if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1:
logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale))
generation_config.guidance_scale = None
# 9. prepare distribution pre_processing samplers
# 8. prepare distribution pre_processing samplers
logits_processor = self._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_seq_length,
input_ids_seq_length=input_ids_length,
encoder_input_ids=inputs_tensor,
prefix_allowed_tokens_fn=None,
logits_processor=logits_processor,
device=input_ids.device,
)
# 10. prepare stopping criteria
# 9. prepare stopping criteria
stopping_criteria = self._get_stopping_criteria(
generation_config=generation_config, stopping_criteria=stopping_criteria
)
......@@ -3217,8 +3434,8 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
f"but is {generation_config.num_return_sequences}."
)
# 11. run greedy search
outputs = self._greedy_search(
# 10. run greedy search
outputs = self._sample(
input_ids,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
......@@ -3229,8 +3446,8 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
)
elif is_sample_gen_mode:
# 11. prepare logits warper
logits_warper = self._get_logits_warper(generation_config)
# 10. prepare logits warper
logits_warper = self._get_logits_warper(generation_config, device=input_ids.device)
# expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
......@@ -3240,7 +3457,7 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
**model_kwargs,
)
# 12. run sample
# 11. run sample
outputs = self._sample(
input_ids,
logits_processor=logits_processor,
......@@ -3269,8 +3486,8 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
# Revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask
_, mask = self.decoder.build_delay_pattern_mask(
input_ids,
bos_token_id=generation_config.bos_token_id,
pad_token_id=generation_config.pad_token_id,
bos_token_id=generation_config._bos_token_tensor,
pad_token_id=generation_config._pad_token_tensor,
max_length=output_ids.shape[1],
)
......
......@@ -13,11 +13,12 @@
# limitations under the License.
import os
import setuptools
_deps = [
"transformers>=4.39.0,<4.41.0",
"transformers>=4.43.0,<=4.43.3",
"torch",
"sentencepiece",
"descript-audio-codec",
......@@ -60,7 +61,7 @@ setuptools.setup(
packages=setuptools.find_packages(),
install_requires=_deps,
extras_require={
"dev": [_extras_dev_deps],
"train": [_extras_training_deps],
"dev": _extras_dev_deps,
"train": _extras_training_deps,
},
)
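A quick, hedged way to sanity-check the pinned range at runtime (not part of the repository; `packaging` is already pulled in through `transformers`' own dependencies):

import transformers
from packaging import version

installed = version.parse(transformers.__version__)
assert version.parse("4.43.0") <= installed <= version.parse("4.43.3"), installed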
import logging
from dataclasses import dataclass
from typing import Dict, List, Optional, Union, Set
from typing import Dict, List, Optional, Set, Union
import torch
import numpy as np
import datasets
from datasets import load_dataset, Dataset, IterableDataset, interleave_datasets, concatenate_datasets
from transformers import AutoFeatureExtractor, AutoTokenizer
from tqdm import tqdm
import numpy as np
import torch
from accelerate import Accelerator
from datasets import Dataset, IterableDataset, concatenate_datasets, interleave_datasets, load_dataset
from tqdm import tqdm
from transformers import AutoFeatureExtractor, AutoTokenizer
@dataclass
......
import os
import re
import shutil
from pathlib import Path
from dataclasses import field
from pathlib import Path
from typing import Dict, List
import torch
from datasets import concatenate_datasets, load_from_disk
from wandb import Audio
from datasets import load_from_disk, concatenate_datasets
......