Unverified Commit 11b209e1 authored by Yoach Lacombe, committed by GitHub

Architecture improvements (#65)



* add RoPE

* don't include padding in rope

* possibly use cross-attn for prompt

* fix rope

* fix cross-attn

* fix self-attn

* fix dummy model

* clean-up rope

* first gqa implementation

* fix wer eval

* feat: add flash attention and SDPA

* chore: add README for flash attention

* chore: add benchmark script

* chore: add benchmark attention approach

* multi node and fix wer and fix compile

* Update modeling_parler_tts.py

* fix FA2, SDPA and add cross-attn MHA and attention type forcing

* better cross_attention key values number of heads default + add training arguments for attn implementation

* fix audio padding when torch compile or pad_to_max_length=True

* correct multi node

* make rope faster

* fix encoder sdpa

* fix training with cross attention + with FA2

* use fp32 as default model dtype + fix generation when using FA2 with autocast

* remove redundant passes in generate + clean and fix attentions

* fix edge case in WER evaluation with long-form generation

* better multi-node mapping and saving / add eval dataloader num workers

* remove old benchmarks

* faster audio encoding + checkpointing + fix generation step

* better eval + add right padding + fix eval loss compute

* correct README

* correct config docstrings

* remove comment

* make style

---------
Co-authored-by: sanchit-gandhi <sanchit@huggingface.co>
Co-authored-by: sang-nguyen-ts <sang.nguyen@trustingsocial.com>
Co-authored-by: Yoach Lacombe <yoach@huggingface.co>
parent 8b8c576e
@@ -53,7 +53,8 @@ if torch.xpu.is_available():
device = "xpu"
torch_dtype = torch.float16 if device != "cpu" else torch.float32
model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler_tts_mini_v0.1", torch_dtype=torch_dtype).to(device)
tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler_tts_mini_v0.1")
prompt = "Hey, how are you doing today?"
...
@@ -60,8 +60,8 @@ if __name__ == "__main__":
model.generation_config.max_length = int(30 * model.audio_encoder.config.frame_rate)
model.generation_config.do_sample = True # True
model.generation_config.guidance_scale = 1 # 3.0
model.config.pad_token_id = encodec_vocab_size
model.config.decoder_start_token_id = encodec_vocab_size + 1
model.save_pretrained(os.path.join(args.save_directory, "tiny-model"))
@@ -58,4 +58,7 @@ if __name__ == "__main__":
model.generation_config.do_sample = True # True
model.generation_config.guidance_scale = 1 # 3.0
model.config.pad_token_id = encodec_vocab_size
model.config.decoder_start_token_id = encodec_vocab_size + 1
model.save_pretrained(os.path.join(args.save_directory, "tiny-model")) model.save_pretrained(os.path.join(args.save_directory, "tiny-model"))
@@ -60,8 +60,8 @@ if __name__ == "__main__":
model.generation_config.max_length = int(30 * model.audio_encoder.config.frame_rate)
model.generation_config.do_sample = True # True
model.generation_config.guidance_scale = 1 # 3.0
model.config.pad_token_id = encodec_vocab_size
model.config.decoder_start_token_id = encodec_vocab_size + 1
model.save_pretrained(os.path.join(args.save_directory, "parler-tts-untrained-600M/"))
@@ -2,6 +2,7 @@ import dac
from parler_tts import DACConfig, DACModel
from transformers import AutoConfig, AutoModel
from transformers import EncodecFeatureExtractor from transformers import EncodecFeatureExtractor
AutoConfig.register("dac", DACConfig) AutoConfig.register("dac", DACConfig)
AutoModel.register(DACConfig, DACModel) AutoModel.register(DACConfig, DACModel)
......
@@ -47,6 +47,17 @@ class ParlerTTSDecoderConfig(PretrainedConfig):
Number of decoder layers.
num_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer block.
num_key_value_heads (`int`, *optional*):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by mean-pooling all the original heads within that group. For more details, check out [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, it will default to
`num_attention_heads`.
num_cross_attention_key_value_heads (`int`, *optional*):
This is the number of key_value heads that should be used to implement Grouped Query Attention in the cross-attention layers.
If it is not specified, will default to `num_key_value_heads`.
ffn_dim (`int`, *optional*, defaults to 4096):
Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer block.
activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
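For a quick illustration of how `num_key_value_heads` selects the attention variant, here is a minimal sketch, assuming `ParlerTTSDecoderConfig` is importable from the package as elsewhere in this PR:

```python
from parler_tts import ParlerTTSDecoderConfig

# 16 query heads sharing 4 key/value heads -> grouped-query attention (4 query heads per KV head).
gqa_config = ParlerTTSDecoderConfig(num_attention_heads=16, num_key_value_heads=4)

# A single KV head -> multi-query attention; leaving num_key_value_heads unset keeps plain multi-head attention.
mqa_config = ParlerTTSDecoderConfig(num_attention_heads=16, num_key_value_heads=1)
```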
@@ -74,6 +85,12 @@ class ParlerTTSDecoderConfig(PretrainedConfig):
The number of parallel codebooks forwarded to the model.
tie_word_embeddings(`bool`, *optional*, defaults to `False`):
Whether input and output word embeddings should be tied.
rope_embeddings (`bool`, *optional*, defaults to `False`):
Whether to use RoPE or absolute positional embeddings.
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
cross_attention_implementation_strategy (`str`, *optional*):
If not specified, the cross-attention implementation will be the same as `_attn_implementation`. If `always_eager`, it will always be the eager implementation. If `always_sdpa`, it will always be the sdpa implementation.
""" """
model_type = "parler_tts_decoder" model_type = "parler_tts_decoder"
@@ -86,6 +103,8 @@ class ParlerTTSDecoderConfig(PretrainedConfig):
num_hidden_layers=24,
ffn_dim=4096,
num_attention_heads=16,
num_key_value_heads=None,
num_cross_attention_key_value_heads=None,
layerdrop=0.0,
use_cache=True,
activation_function="gelu",
@@ -100,6 +119,9 @@ class ParlerTTSDecoderConfig(PretrainedConfig):
bos_token_id=2049,
eos_token_id=2048,
tie_word_embeddings=False,
rope_embeddings=False,
rope_theta=10_000.0,
cross_attention_implementation_strategy=None,
**kwargs,
):
self.vocab_size = vocab_size
@@ -108,6 +130,12 @@ class ParlerTTSDecoderConfig(PretrainedConfig):
self.ffn_dim = ffn_dim
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
if num_cross_attention_key_value_heads is None:
num_cross_attention_key_value_heads = num_key_value_heads
self.num_cross_attention_key_value_heads = num_cross_attention_key_value_heads
self.dropout = dropout
self.attention_dropout = attention_dropout
self.activation_dropout = activation_dropout
@@ -117,6 +145,9 @@ class ParlerTTSDecoderConfig(PretrainedConfig):
self.use_cache = use_cache
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
self.num_codebooks = num_codebooks
self.rope_embeddings = rope_embeddings
self.rope_theta = rope_theta
self.cross_attention_implementation_strategy = cross_attention_implementation_strategy
super().__init__(
pad_token_id=pad_token_id,
@@ -140,6 +171,8 @@ class ParlerTTSConfig(PretrainedConfig):
vocab_size (`int`, *optional*, defaults to 1024):
Vocabulary size of the prompt token ids. Defines the number of different tokens that can be
represented by the `prompt_inputs_ids`.
prompt_cross_attention (`bool`, *optional*, defaults to `False`):
Whether to use cross-attention conditioning for the prompt (as well as the description).
kwargs (*optional*):
Dictionary of keyword arguments. Notably:
@@ -190,7 +223,7 @@ class ParlerTTSConfig(PretrainedConfig):
model_type = "parler_tts"
is_composition = True
def __init__(self, vocab_size=1024, prompt_cross_attention=False, **kwargs):
super().__init__(**kwargs)
if "text_encoder" not in kwargs or "audio_encoder" not in kwargs or "decoder" not in kwargs:
raise ValueError("Config has to be initialized with text_encoder, audio_encoder and decoder config")
@@ -204,6 +237,7 @@ class ParlerTTSConfig(PretrainedConfig):
decoder_config = kwargs.pop("decoder")
self.vocab_size = vocab_size
self.prompt_cross_attention = prompt_cross_attention
self.text_encoder = AutoConfig.for_model(text_encoder_model_type, **text_encoder_config)
self.audio_encoder = AutoConfig.for_model(audio_encoder_model_type, **audio_encoder_config)
self.decoder = ParlerTTSDecoderConfig(**decoder_config)
@@ -236,3 +270,21 @@ class ParlerTTSConfig(PretrainedConfig):
# This is a property because you might want to change the codec model on the fly
def sampling_rate(self):
return self.audio_encoder.sampling_rate
# Copy from musicgen
@property
def _attn_implementation(self):
# This property is made private for now (as it cannot be changed and a PreTrainedModel.use_attn_implementation method needs to be implemented.)
if hasattr(self, "_attn_implementation_internal"):
if self._attn_implementation_internal is None:
# `config.attn_implementation` should never be None, for backward compatibility.
return "eager"
else:
return self._attn_implementation_internal
else:
return "eager"
@_attn_implementation.setter
def _attn_implementation(self, value):
self._attn_implementation_internal = value
self.decoder._attn_implementation = value
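As a hedged sketch of how the backend is expected to be chosen at load time (kernel availability permitting), with the setter above mirroring the choice onto the decoder config:

```python
from parler_tts import ParlerTTSForConditionalGeneration

model = ParlerTTSForConditionalGeneration.from_pretrained(
    "parler-tts/parler_tts_mini_v0.1", attn_implementation="sdpa"
)
# Both configs should now report the chosen backend (propagated by the setter above).
print(model.config._attn_implementation, model.config.decoder._attn_implementation)
```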
@@ -18,7 +18,7 @@ import inspect
import math
import random
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union, List
import torch
import torch.nn as nn
@@ -28,7 +28,12 @@ from transformers.activations import ACT2FN
from transformers.generation.configuration_utils import GenerationConfig
from transformers.generation.logits_process import ClassifierFreeGuidanceLogitsProcessor, LogitsProcessorList
from transformers.generation.stopping_criteria import StoppingCriteriaList
from transformers.modeling_attn_mask_utils import (
_prepare_4d_attention_mask,
_prepare_4d_causal_attention_mask,
_prepare_4d_attention_mask_for_sdpa,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
@@ -43,6 +48,9 @@ from transformers.utils import (
logging,
replace_return_docstrings,
)
import torch.nn.functional as F
from transformers.utils.import_utils import is_flash_attn_greater_or_equal_2_10, is_flash_attn_2_available
from .configuration_parler_tts import ParlerTTSConfig, ParlerTTSDecoderConfig
from .dac_wrapper import DACConfig, DACModel
@@ -56,6 +64,13 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)
if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
else:
logger.warn("Flash attention 2 is not installed")
_CONFIG_FOR_DOC = "ParlerTTSConfig"
_CHECKPOINT_FOR_DOC = "facebook/parler_tts-small"
@@ -139,6 +154,19 @@ def build_delay_pattern_mask(
return input_ids, pattern_mask
# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
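A standalone check with toy sizes (run outside the model) that the expand/reshape in `repeat_kv` matches `torch.repeat_interleave` on the head dimension:

```python
import torch

batch, num_kv_heads, seq_len, head_dim, n_rep = 1, 2, 5, 16, 4
kv = torch.randn(batch, num_kv_heads, seq_len, head_dim)

expanded = kv[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
expanded = expanded.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)

# expand creates a view, so the copy only happens at the reshape; the result matches repeat_interleave.
assert torch.equal(expanded, kv.repeat_interleave(n_rep, dim=1))
```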
@dataclass
class ParlerTTSUnconditionalInput(ModelOutput):
"""
@@ -223,25 +251,95 @@ class ParlerTTSSinusoidalPositionalEmbedding(nn.Module):
return self.weights.index_select(0, position_ids.view(-1)).detach()
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->ParlerTTS
class ParlerTTSRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
super().__init__()
self.scaling_factor = scaling_factor
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
# For BC we register cos and sin cached
self.max_seq_len_cached = max_position_embeddings
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
t = t / self.scaling_factor
freqs = torch.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False)
self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False)
# Ignore copy
@torch.no_grad()
def forward(self, device_type, position_ids):
# x: [bs, num_attention_heads, seq_len, head_size]
inv_freq_expanded = self.inv_freq[None, :, None].expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :]
# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return cos, sin
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(x, cos, sin, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
x (`torch.Tensor`): The tensor over which to apply the rope embeddings
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
x_embed = (x * cos) + (rotate_half(x) * sin)
return x_embed
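A toy usage of the rotary helpers above (sizes are illustrative; assumes the class and function are importable, e.g. from `parler_tts.modeling_parler_tts`):

```python
import torch

rotary = ParlerTTSRotaryEmbedding(dim=64, max_position_embeddings=2048, base=10_000.0)
position_ids = torch.arange(6).unsqueeze(0)      # (1, 6)
cos, sin = rotary("cpu", position_ids)           # each of shape (1, 6, 64)

q = torch.randn(1, 2, 6, 64)                     # (batch, heads, seq_len, head_dim)
q_rot = apply_rotary_pos_emb(q, cos, sin)        # unsqueeze_dim=1 broadcasts over the head dimension

# The rotation preserves norms, so when applied to both q and k the attention logits
# depend only on relative positions.
print(torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5))  # True
```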
class ParlerTTSAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper. Modified to use GQA and MQA."""
def __init__(
self,
embed_dim: int,
num_heads: int,
num_key_value_heads: int,
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
is_causal: bool = False,
rope_embeddings: bool = False,
config: Optional[ParlerTTSDecoderConfig] = None,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
self.num_key_value_heads = num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.config = config
if (self.head_dim * num_heads) != self.embed_dim:
@@ -253,20 +351,27 @@ class ParlerTTSAttention(nn.Module):
self.is_decoder = is_decoder
self.is_causal = is_causal
self.k_proj = nn.Linear(embed_dim, self.num_key_value_heads * self.head_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, self.num_key_value_heads * self.head_dim, bias=bias)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.rope_embeddings = rope_embeddings
def _shape_query(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def _shape_key_value(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
cos: Optional[torch.LongTensor] = None,
sin: Optional[torch.LongTensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
@@ -276,13 +381,18 @@ class ParlerTTSAttention(nn.Module):
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len = hidden_states.shape[:2]
# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
query_states = self._shape_query(query_states, tgt_len, bsz)
if self.rope_embeddings:
query_states = apply_rotary_pos_emb(query_states, cos, sin)
current_states = key_value_states if key_value_states is not None else hidden_states
# checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
@@ -292,20 +402,17 @@ class ParlerTTSAttention(nn.Module):
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
else:
key_states = self._shape_key_value(self.k_proj(current_states), -1, bsz)
value_states = self._shape_key_value(self.v_proj(current_states), -1, bsz)
if not is_cross_attention:
# cached key states already have rope applied - only apply to new state
key_states = apply_rotary_pos_emb(key_states, cos, sin) if self.rope_embeddings else key_states
if past_key_value is not None:
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
@@ -317,8 +424,12 @@ class ParlerTTSAttention(nn.Module):
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = query_states.reshape(*proj_shape)
key_states = key_states.reshape(*proj_shape)
value_states = value_states.reshape(*proj_shape)
@@ -351,10 +462,6 @@ class ParlerTTSAttention(nn.Module):
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if output_attentions:
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
else:
@@ -382,30 +489,390 @@ class ParlerTTSAttention(nn.Module):
return attn_output, attn_weights_reshaped, past_key_value
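For reference, a standalone forward pass through the eager attention module with grouped KV heads (a sketch with made-up sizes, assuming the class is imported from `parler_tts.modeling_parler_tts`):

```python
import torch

attn = ParlerTTSAttention(
    embed_dim=1024,
    num_heads=16,           # 16 query heads ...
    num_key_value_heads=4,  # ... sharing 4 key/value heads (GQA)
    is_decoder=True,
    is_causal=True,
    bias=False,
)
hidden = torch.randn(2, 10, 1024)
attn_output, attn_weights, past_key_value = attn(hidden)
print(attn_output.shape)        # torch.Size([2, 10, 1024])
print(past_key_value[0].shape)  # torch.Size([2, 4, 10, 64]) -- the cache keeps the un-repeated KV heads
```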
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
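A toy illustration of what `_get_unpad_data` computes for the varlen flash-attention path (values below are purely for demonstration):

```python
import torch
import torch.nn.functional as F

# Two sequences of true lengths 3 and 1 in a padded batch of width 4 (1 = token, 0 = padding).
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 0, 0, 0]])

seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)                      # tensor([3, 1])
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()           # tensor([0, 1, 2, 4])
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))  # tensor([0, 3, 4])
max_seqlen_in_batch = seqlens_in_batch.max().item()                                   # 3
```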
# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenFlashAttention2 with Musicgen->ParlerTTS
class ParlerTTSFlashAttention2(ParlerTTSAttention):
"""
ParlerTTS flash attention module. This module inherits from `ParlerTTSAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
cos: Optional[torch.LongTensor] = None,
sin: Optional[torch.LongTensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# ParlerTTSFlashAttention2 attention does not support output_attentions
if output_attentions:
raise ValueError("ParlerTTSFlashAttention2 attention does not support output_attentions")
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, q_len = hidden_states.shape[:2]
# get query proj
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
if self.rope_embeddings:
query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2)
current_states = key_value_states if key_value_states is not None else hidden_states
# checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0].transpose(1, 2)
value_states = past_key_value[1].transpose(1, 2)
else:
key_states = self.k_proj(current_states).view(bsz, -1, self.num_key_value_heads, self.head_dim)
value_states = self.v_proj(current_states).view(bsz, -1, self.num_key_value_heads, self.head_dim)
if not is_cross_attention:
# cached key states already have rope applied - only apply to new state
key_states = (
apply_rotary_pos_emb(key_states, cos, sin, unsqueeze_dim=2) if self.rope_embeddings else key_states
)
if past_key_value is not None:
key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in the correct dtype just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32. (LlamaRMSNorm handles it correctly)
if query_states.dtype == torch.float32 or value_states.dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
attn_output = self._flash_attention_forward(
query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout
)
attn_output = attn_output.reshape(bsz, q_len, -1)
attn_output = self.out_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
def _flash_attention_forward(
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
first unpad the input, then computes the attention scores and pad the final attention scores.
Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API
attention_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`float`):
Attention dropout
softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
"""
if not self._flash_attn_uses_top_left_mask:
causal = self.is_causal
else:
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
causal = self.is_causal and query_length != 1
# Contains at least one padding token in the sequence
if attention_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
query_states, key_states, value_states, attention_mask, query_length
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
)
return attn_output
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
key_layer = index_first_axis(
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
value_layer = index_first_axis(
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
if query_length == kv_seq_len:
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention with Bart->Musicgen
class ParlerTTSSdpaAttention(ParlerTTSAttention):
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
cos: Optional[torch.LongTensor] = None,
sin: Optional[torch.LongTensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
if output_attentions or layer_head_mask is not None:
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"ParlerTTSModel is using ParlerTTSSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states,
key_value_states=key_value_states,
past_key_value=past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len = hidden_states.shape[:2]
# get query proj
query_states = self.q_proj(hidden_states)
query_states = self._shape_query(query_states, tgt_len, bsz)
if self.rope_embeddings:
query_states = apply_rotary_pos_emb(query_states, cos, sin)
current_states = key_value_states if key_value_states is not None else hidden_states
# checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
else:
key_states = self._shape_key_value(self.k_proj(current_states), -1, bsz)
value_states = self._shape_key_value(self.v_proj(current_states), -1, bsz)
if not is_cross_attention:
# cached key states already have rope applied - only apply to new state
key_states = apply_rotary_pos_emb(key_states, cos, sin) if self.rope_embeddings else key_states
if past_key_value is not None:
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False
# NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
# but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.dropout if self.training else 0.0,
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
is_causal=is_causal,
)
if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2)
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned across GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, None, past_key_value
PARLERTTS_ATTENTION_CLASSES = {
"eager": ParlerTTSAttention,
"sdpa": ParlerTTSSdpaAttention,
"flash_attention_2": ParlerTTSFlashAttention2,
}
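A sketch of how this mapping is consumed, mirroring the decoder layer constructor just below (the strategy values come from `cross_attention_implementation_strategy`; flash-attn only needs to be installed for the "flash_attention_2" entry):

```python
# Self-attention follows the model-wide backend; cross-attention can be forced to a
# different one, e.g. keeping SDPA for cross-attention while self-attention uses FA2.
self_attn_cls = PARLERTTS_ATTENTION_CLASSES["flash_attention_2"]
cross_attn_cls = PARLERTTS_ATTENTION_CLASSES["sdpa"]  # corresponds to the "always_sdpa" strategy
```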
class ParlerTTSDecoderLayer(nn.Module):
def __init__(self, config: ParlerTTSDecoderConfig):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = PARLERTTS_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
dropout=config.attention_dropout,
is_decoder=True,
is_causal=True,
bias=False,
rope_embeddings=config.rope_embeddings,
config=config,
)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
cross_attn_implementation = config._attn_implementation
if config.cross_attention_implementation_strategy == "always_eager":
cross_attn_implementation = "eager"
elif config.cross_attention_implementation_strategy == "always_sdpa":
cross_attn_implementation = "sdpa"
self.encoder_attn = PARLERTTS_ATTENTION_CLASSES[cross_attn_implementation](
self.embed_dim,
config.num_attention_heads,
num_key_value_heads=config.num_cross_attention_key_value_heads,
dropout=config.attention_dropout,
is_decoder=True,
bias=False,
rope_embeddings=config.rope_embeddings,
config=config,
)
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim, bias=False)
@@ -416,6 +883,8 @@ class ParlerTTSDecoderLayer(nn.Module):
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
cos: Optional[torch.LongTensor] = None,
sin: Optional[torch.LongTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
@@ -429,6 +898,9 @@ class ParlerTTSDecoderLayer(nn.Module):
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`.
encoder_hidden_states (`torch.FloatTensor`):
cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
@@ -449,10 +921,13 @@ class ParlerTTSDecoderLayer(nn.Module):
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
# add present self-attn cache to positions 1,2 of present_key_value tuple
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
past_key_value=self_attn_past_key_value,
attention_mask=attention_mask,
cos=cos,
sin=sin,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
@@ -472,6 +947,8 @@ class ParlerTTSDecoderLayer(nn.Module):
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
cos=cos,
sin=sin,
layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions,
@@ -512,6 +989,8 @@ class ParlerTTSPreTrainedModel(PreTrainedModel):
config_class = ParlerTTSDecoderConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_no_split_modules = ["ParlerTTSDecoderLayer", "ParlerTTSAttention"] _no_split_modules = ["ParlerTTSDecoderLayer", "ParlerTTSAttention"]
def _init_weights(self, module): def _init_weights(self, module):
@@ -608,6 +1087,7 @@ MUSICGEN_INPUTS_DOCSTRING = r"""
Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
TODO: it's passed through enc_to_dec_proj and optionally we concat the prompt hidden states in certain cases.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
@@ -751,7 +1231,6 @@ MUSICGEN_DECODER_INPUTS_DOCSTRING = r"""
"""
class ParlerTTSDecoder(ParlerTTSPreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`ParlerTTSDecoderLayer`]
@@ -772,14 +1251,28 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel):
[nn.Embedding(embed_dim, config.hidden_size) for _ in range(config.num_codebooks)]
)
self.rope_embeddings = config.rope_embeddings
if not config.rope_embeddings:
self.embed_positions = ParlerTTSSinusoidalPositionalEmbedding(
config.max_position_embeddings,
config.hidden_size,
)
else:
self.rotary_emb = ParlerTTSRotaryEmbedding(
config.hidden_size // config.num_attention_heads,
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
)
self.layers = nn.ModuleList([ParlerTTSDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self.layer_norm = nn.LayerNorm(config.hidden_size)
self.attn_implementation = config._attn_implementation
encoder_attn_implementation = config._attn_implementation
if config.cross_attention_implementation_strategy is not None:
encoder_attn_implementation = (
"sdpa" if config.cross_attention_implementation_strategy == "always_sdpa" else "eager"
)
self.encoder_attn_implementation = encoder_attn_implementation
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
@@ -803,6 +1296,7 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel):
cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
@@ -839,7 +1333,10 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel):
if prompt_hidden_states is not None:
inputs_embeds = torch.cat([prompt_hidden_states, inputs_embeds], dim=1)
# NOTE: 1. As it is, the masked ids from the prompt will still count in the positions embeddings
# NOTE: 2. we want to concatenate the prompt attention mask and the decoder attention mask
# i.i.f `prompt_cross_attention=False`. ParlerTTSForConditionalGeneration's taking care of setting
# `prompt_attention_mask=None`
if prompt_attention_mask is not None and attention_mask is not None:
attention_mask = torch.cat([prompt_attention_mask, attention_mask], dim=1)
elif prompt_attention_mask is not None:
@@ -855,6 +1352,9 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel):
dim=1,
)
else:
# In the generation case of `prompt_cross_attention=True`, we need to recreate an attention mask from scratch
# to be able to prepend the prompt attention mask.
# Since we generate token per token, we can recompute the generated length from the information we have.
generated_length = past_key_values_length - prompt_attention_mask.shape[1] + 1
attention_mask = torch.cat(
[
...@@ -867,26 +1367,75 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel): ...@@ -867,26 +1367,75 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel):
) )
input_shape = inputs_embeds.size()[:-1] input_shape = inputs_embeds.size()[:-1]
attention_mask = _prepare_4d_causal_attention_mask( cos, sin = None, None
attention_mask, input_shape, inputs_embeds, past_key_values_length
) if not self.rope_embeddings:
# embed positions
# TODO: As it is, the masked ids from the prompt will still count in the positions embeddings
# maybe should modify position embeddings
positions = self.embed_positions(inputs_embeds, past_key_values_length)
hidden_states = inputs_embeds + positions.to(inputs_embeds.device)
else:
hidden_states = inputs_embeds
# expand encoder attention mask if position_ids is None:
if encoder_hidden_states is not None and encoder_attention_mask is not None: if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] # masked ids will **not** count in the position embeddings
encoder_attention_mask = _prepare_4d_attention_mask( position_ids = attention_mask.long().cumsum(-1) - 1
encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] position_ids.masked_fill_(attention_mask == 0, 1)
) else:
position_ids = torch.arange(
past_key_values_length,
input_shape[1] + past_key_values_length,
dtype=torch.long,
device=inputs_embeds.device,
)
position_ids = position_ids.unsqueeze(0)
# embed positions # Some generation methods already pass only the last input ID
# TODO: As it is, the masked ids from the prompt will still count in the positions embeddings if position_ids.shape[1] > input_shape[1]:
# maybe should modify position embeddings position_ids = position_ids[:, -input_shape[1] :]
positions = self.embed_positions(inputs_embeds, past_key_values_length)
hidden_states = inputs_embeds + positions.to(inputs_embeds.device) cos, sin = self.rotary_emb(hidden_states.device.type, position_ids)
cos, sin = cos.to(hidden_states.dtype), sin.to(hidden_states.dtype)
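A small standalone illustration of the two tricks in the block above: the prompt mask is prepended to the decoder mask, and position ids are then derived from the combined mask with a cumulative sum so that padded prompt slots do not shift the rotary positions (toy shapes and values):

import torch

prompt_attention_mask = torch.tensor([[0, 0, 1]])    # left-padded prompt: 3 slots, 1 real token
decoder_attention_mask = torch.tensor([[1, 1, 1]])   # 3 codebook frames
attention_mask = torch.cat([prompt_attention_mask, decoder_attention_mask], dim=1)

position_ids = attention_mask.long().cumsum(-1) - 1  # [[-1, -1, 0, 1, 2, 3]]
position_ids.masked_fill_(attention_mask == 0, 1)    # [[ 1,  1, 0, 1, 2, 3]]
# padded slots get a dummy position, real tokens get consecutive rotary positions starting at 0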
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
if self.attn_implementation == "flash_attention_2":
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self.attn_implementation == "sdpa" and head_mask is None and not output_attentions:
# output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
input_shape,
inputs_embeds,
past_key_values_length,
)
else:
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
if self.encoder_attn_implementation == "flash_attention_2":
encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
elif self.encoder_attn_implementation == "sdpa" and cross_attn_head_mask is None and not output_attentions:
# output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
encoder_attention_mask,
inputs_embeds.dtype,
tgt_len=input_shape[-1],
)
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask(
encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
)
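For readers unfamiliar with the three code paths above: flash-attention-2 consumes the raw 2-D padding mask (or None when nothing is masked), while the sdpa/eager paths expand it into an additive 4-D mask. A toy illustration of what that expansion produces (not the transformers helper itself):

import torch

attention_mask = torch.tensor([[0, 1, 1, 1]])          # batch of 1, first position padded
tgt_len = attention_mask.shape[-1]
big_neg = torch.finfo(torch.float32).min

causal = torch.triu(torch.ones(tgt_len, tgt_len, dtype=torch.bool), diagonal=1)  # future positions
padded = attention_mask[:, None, None, :] == 0                                   # padded key positions
blocked = causal | padded                                                        # (1, 1, 4, 4)
additive_mask = torch.zeros(1, 1, tgt_len, tgt_len).masked_fill(blocked, big_neg)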
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
if use_cache: if use_cache:
logger.warning_once( logger.warning_once(
...@@ -923,6 +1472,8 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel): ...@@ -923,6 +1472,8 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel):
decoder_layer.forward, decoder_layer.forward,
hidden_states, hidden_states,
attention_mask, attention_mask,
cos,
sin,
encoder_hidden_states, encoder_hidden_states,
encoder_attention_mask, encoder_attention_mask,
head_mask[idx] if head_mask is not None else None, head_mask[idx] if head_mask is not None else None,
...@@ -935,6 +1486,8 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel): ...@@ -935,6 +1486,8 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel):
layer_outputs = decoder_layer( layer_outputs = decoder_layer(
hidden_states, hidden_states,
attention_mask=attention_mask, attention_mask=attention_mask,
cos=cos,
sin=sin,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None), layer_head_mask=(head_mask[idx] if head_mask is not None else None),
...@@ -1004,6 +1557,7 @@ class ParlerTTSModel(ParlerTTSPreTrainedModel): ...@@ -1004,6 +1557,7 @@ class ParlerTTSModel(ParlerTTSPreTrainedModel):
self, self,
input_ids: torch.LongTensor = None, input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.LongTensor] = None, encoder_attention_mask: Optional[torch.LongTensor] = None,
prompt_hidden_states: Optional[torch.FloatTensor] = None, prompt_hidden_states: Optional[torch.FloatTensor] = None,
...@@ -1028,6 +1582,7 @@ class ParlerTTSModel(ParlerTTSPreTrainedModel): ...@@ -1028,6 +1582,7 @@ class ParlerTTSModel(ParlerTTSPreTrainedModel):
decoder_outputs = self.decoder( decoder_outputs = self.decoder(
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
position_ids=position_ids,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
prompt_hidden_states=prompt_hidden_states, prompt_hidden_states=prompt_hidden_states,
...@@ -1058,7 +1613,6 @@ class ParlerTTSModel(ParlerTTSPreTrainedModel): ...@@ -1058,7 +1613,6 @@ class ParlerTTSModel(ParlerTTSPreTrainedModel):
"The Parler-TTS decoder model with a language modelling head on top.", "The Parler-TTS decoder model with a language modelling head on top.",
MUSICGEN_START_DOCSTRING, MUSICGEN_START_DOCSTRING,
) )
# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenForCausalLM with Musicgen->ParlerTTS
class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel): class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
def __init__(self, config: ParlerTTSDecoderConfig): def __init__(self, config: ParlerTTSDecoderConfig):
super().__init__(config) super().__init__(config)
...@@ -1097,6 +1651,7 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel): ...@@ -1097,6 +1651,7 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
self, self,
input_ids: torch.LongTensor = None, input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.LongTensor] = None, encoder_attention_mask: Optional[torch.LongTensor] = None,
prompt_hidden_states: Optional[torch.FloatTensor] = None, prompt_hidden_states: Optional[torch.FloatTensor] = None,
...@@ -1124,6 +1679,7 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel): ...@@ -1124,6 +1679,7 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
outputs = self.model( outputs = self.model(
input_ids, input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
position_ids=position_ids,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
prompt_hidden_states=prompt_hidden_states, prompt_hidden_states=prompt_hidden_states,
...@@ -1144,7 +1700,6 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel): ...@@ -1144,7 +1700,6 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
loss = None loss = None
if labels is not None: if labels is not None:
loss = torch.zeros([], device=self.device)
# since encoder hidden states have concatenated to hidden states, take the last hidden states corresponding to labels # since encoder hidden states have concatenated to hidden states, take the last hidden states corresponding to labels
logits = lm_logits[:, :, -labels.shape[1] :] logits = lm_logits[:, :, -labels.shape[1] :]
...@@ -1228,8 +1783,16 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel): ...@@ -1228,8 +1783,16 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
[prompt_attention_mask, torch.zeros_like(prompt_attention_mask)], dim=0 [prompt_attention_mask, torch.zeros_like(prompt_attention_mask)], dim=0
) )
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values is not None: if past_key_values is not None:
input_ids = input_ids[:, -1:] input_ids = input_ids[:, -1:]
if position_ids is not None:
position_ids = position_ids[:, -input_ids.shape[1] :]
# we only want to use prompt signal in the 1st generation step but keeping the attention mask # we only want to use prompt signal in the 1st generation step but keeping the attention mask
prompt_hidden_states = None prompt_hidden_states = None
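The prepare-inputs hunk above implements standard KV-cache decoding: once a cache exists, only the newest token and its matching position id are fed back to the model. A toy version with made-up values:

import torch

input_ids = torch.tensor([[5, 9, 2, 7]])        # everything generated so far
position_ids = torch.tensor([[0, 1, 2, 3]])
past_key_values = "non-empty cache"             # stand-in for the real cache object

if past_key_values is not None:
    input_ids = input_ids[:, -1:]                          # [[7]]
    position_ids = position_ids[:, -input_ids.shape[1]:]   # [[3]]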
...@@ -1237,6 +1800,7 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel): ...@@ -1237,6 +1800,7 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
return { return {
"input_ids": input_ids, "input_ids": input_ids,
"attention_mask": attention_mask, "attention_mask": attention_mask,
"position_ids": position_ids,
"encoder_hidden_states": encoder_hidden_states, "encoder_hidden_states": encoder_hidden_states,
"encoder_attention_mask": encoder_attention_mask, "encoder_attention_mask": encoder_attention_mask,
"prompt_hidden_states": prompt_hidden_states, "prompt_hidden_states": prompt_hidden_states,
...@@ -1554,6 +2118,8 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel): ...@@ -1554,6 +2118,8 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
base_model_prefix = "encoder_decoder" base_model_prefix = "encoder_decoder"
main_input_name = "input_ids" main_input_name = "input_ids"
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_supports_flash_attn_2 = True
_supports_sdpa = True
def __init__( def __init__(
self, self,
...@@ -1633,6 +2199,13 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel): ...@@ -1633,6 +2199,13 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
# prompt embeddings # prompt embeddings
self.embed_prompts = nn.Embedding(config.vocab_size, self.decoder.config.hidden_size) self.embed_prompts = nn.Embedding(config.vocab_size, self.decoder.config.hidden_size)
self.prompt_cross_attention = config.prompt_cross_attention
if config.prompt_cross_attention:
self.embed_positions = ParlerTTSSinusoidalPositionalEmbedding(
config.decoder.max_position_embeddings,
config.decoder.hidden_size,
)
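When `prompt_cross_attention` is enabled, the prompt keeps a sinusoidal (non-rotary) position signal before it is appended to the cross-attention memory. A generic sketch of such a table (the exact initialisation of ParlerTTSSinusoidalPositionalEmbedding may differ):

import math
import torch

def sinusoidal_table(num_positions: int, dim: int) -> torch.Tensor:
    assert dim % 2 == 0
    position = torch.arange(num_positions, dtype=torch.float32)[:, None]
    div_term = torch.exp(torch.arange(0, dim, 2, dtype=torch.float32) * (-math.log(10000.0) / dim))
    table = torch.zeros(num_positions, dim)
    table[:, 0::2] = torch.sin(position * div_term)   # even dims
    table[:, 1::2] = torch.cos(position * div_term)   # odd dims
    return table                                      # (num_positions, dim), added to the prompt embeddings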
if self.text_encoder.get_output_embeddings() is not None: if self.text_encoder.get_output_embeddings() is not None:
raise ValueError( raise ValueError(
f"The encoder {self.text_encoder} should not have a LM Head. Please use a model without and LM Head" f"The encoder {self.text_encoder} should not have a LM Head. Please use a model without and LM Head"
...@@ -1931,6 +2504,7 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel): ...@@ -1931,6 +2504,7 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
prompt_input_ids: Optional[torch.FloatTensor] = None, prompt_input_ids: Optional[torch.FloatTensor] = None,
prompt_attention_mask: Optional[torch.LongTensor] = None, prompt_attention_mask: Optional[torch.LongTensor] = None,
prompt_hidden_states: Optional[torch.FloatTensor] = None, prompt_hidden_states: Optional[torch.FloatTensor] = None,
decoder_position_ids: Optional[torch.LongTensor] = None,
labels: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None, use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
...@@ -1983,6 +2557,10 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel): ...@@ -1983,6 +2557,10 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
} }
if prompt_hidden_states is None:
if prompt_input_ids is not None:
prompt_hidden_states = self.embed_prompts(prompt_input_ids)
if encoder_outputs is None: if encoder_outputs is None:
encoder_outputs = self.text_encoder( encoder_outputs = self.text_encoder(
input_ids=input_ids, input_ids=input_ids,
...@@ -1993,24 +2571,46 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel): ...@@ -1993,24 +2571,46 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
return_dict=return_dict, return_dict=return_dict,
**kwargs_text_encoder, **kwargs_text_encoder,
) )
elif isinstance(encoder_outputs, tuple): encoder_hidden_states = encoder_outputs[0]
encoder_outputs = BaseModelOutput(*encoder_outputs)
encoder_hidden_states = encoder_outputs[0] # optionally project encoder_hidden_states
if (
self.text_encoder.config.hidden_size != self.decoder.config.hidden_size
and self.decoder.config.cross_attention_hidden_size is None
):
encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
# optionally project encoder_hidden_states if attention_mask is not None:
if ( encoder_hidden_states = encoder_hidden_states * attention_mask[..., None]
self.text_encoder.config.hidden_size != self.decoder.config.hidden_size
and self.decoder.config.cross_attention_hidden_size is None
):
encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
if attention_mask is not None: if prompt_hidden_states is not None and self.prompt_cross_attention:
encoder_hidden_states = encoder_hidden_states * attention_mask[..., None] # add sinusoidal positional embedding
positions = self.embed_positions(prompt_hidden_states, 0)
prompt_hidden_states = prompt_hidden_states + positions.to(prompt_hidden_states.device)
if prompt_hidden_states is None: if prompt_attention_mask is not None and attention_mask is None:
if prompt_input_ids is not None: attention_mask = torch.ones(
prompt_hidden_states = self.embed_prompts(prompt_input_ids) encoder_hidden_states.shape[:2], device=self.device, dtype=prompt_attention_mask.dtype
)
elif attention_mask is not None and prompt_attention_mask is None:
prompt_attention_mask = torch.ones(
prompt_hidden_states.shape[:2], device=self.device, dtype=attention_mask.dtype
)
# concatenate text description states with prompt description states
encoder_hidden_states = torch.cat([encoder_hidden_states, prompt_hidden_states], dim=1)
if prompt_attention_mask is not None:
attention_mask = torch.cat([attention_mask, prompt_attention_mask], dim=1)
prompt_hidden_states = None
prompt_attention_mask = None
encoder_outputs["last_hidden_state"] = encoder_hidden_states
elif isinstance(encoder_outputs, tuple):
encoder_outputs = BaseModelOutput(*encoder_outputs)
encoder_hidden_states = encoder_outputs.last_hidden_state
if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None): if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
decoder_input_ids = shift_tokens_right( decoder_input_ids = shift_tokens_right(
...@@ -2041,6 +2641,7 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel): ...@@ -2041,6 +2641,7 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
decoder_outputs = self.decoder( decoder_outputs = self.decoder(
input_ids=decoder_input_ids, input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask, attention_mask=decoder_attention_mask,
position_ids=decoder_position_ids,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=attention_mask, encoder_attention_mask=attention_mask,
prompt_hidden_states=prompt_hidden_states, prompt_hidden_states=prompt_hidden_states,
...@@ -2121,7 +2722,8 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel): ...@@ -2121,7 +2722,8 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
# we only want to use prompt signal in the 1st generation step but keeping the attention mask # if prompt_cross_attention,
# we only want to use prompt signal in the 1st generation step
prompt_hidden_states = None prompt_hidden_states = None
return { return {
...@@ -2231,12 +2833,56 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel): ...@@ -2231,12 +2833,56 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
[model_kwargs["attention_mask"], torch.zeros_like(model_kwargs["attention_mask"])], dim=0 [model_kwargs["attention_mask"], torch.zeros_like(model_kwargs["attention_mask"])], dim=0
) )
model_kwargs["encoder_outputs"] = BaseModelOutput(last_hidden_state=last_hidden_state) # we optionnally project last_hidden_state to avoid recomputing every time
encoder_hidden_states = last_hidden_state
if (
self.text_encoder.config.hidden_size != self.decoder.config.hidden_size
and self.decoder.config.cross_attention_hidden_size is None
):
encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
if model_kwargs["attention_mask"] is not None:
encoder_hidden_states = encoder_hidden_states * model_kwargs["attention_mask"][..., None]
model_kwargs["encoder_outputs"] = BaseModelOutput(last_hidden_state=encoder_hidden_states)
return model_kwargs return model_kwargs
def _prepare_prompt_kwargs_for_generation(self, prompt_input_ids, model_kwargs): def _prepare_prompt_kwargs_for_generation(self, prompt_input_ids, model_kwargs):
model_kwargs["prompt_hidden_states"] = self.embed_prompts(prompt_input_ids) prompt_hidden_states = self.embed_prompts(prompt_input_ids)
if self.prompt_cross_attention:
# add sinusoidal positional embedding
positions = self.embed_positions(prompt_hidden_states, 0)
prompt_hidden_states = prompt_hidden_states + positions.to(prompt_hidden_states.device)
attention_mask = model_kwargs.get("attention_mask", None)
prompt_attention_mask = model_kwargs.get("prompt_attention_mask", None)
encoder_hidden_states = model_kwargs["encoder_outputs"].last_hidden_state
if prompt_attention_mask is not None and attention_mask is None:
attention_mask = torch.ones(
encoder_hidden_states.shape[:2], device=self.device, dtype=prompt_attention_mask.dtype
)
elif attention_mask is not None and prompt_attention_mask is None:
prompt_attention_mask = torch.ones(
prompt_hidden_states.shape[:2], device=self.device, dtype=attention_mask.dtype
)
# concatenate text description states with prompt description states
encoder_hidden_states = torch.cat([encoder_hidden_states, prompt_hidden_states], dim=1)
if prompt_attention_mask is not None:
attention_mask = torch.cat([attention_mask, prompt_attention_mask], dim=1)
model_kwargs["encoder_outputs"].last_hidden_state = encoder_hidden_states
model_kwargs["attention_mask"] = attention_mask
# in this case, since we already concatenated the prompt hidden states and attention mask, we don't need them anymore.
model_kwargs["prompt_hidden_states"] = None
model_kwargs["prompt_attention_mask"] = None
else:
model_kwargs["prompt_hidden_states"] = prompt_hidden_states
# we're keeping the prompt attention mask because it has to be prepended to the decoder attention mask on the fly
return model_kwargs return model_kwargs
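To summarise the two conditioning paths handled above: with `prompt_cross_attention=True` the embedded prompt is appended to the text-description encoder states so the decoder sees it through cross-attention, while with `prompt_cross_attention=False` it is kept as `prompt_hidden_states` and prepended to the decoder inputs instead. A toy version of the first path (hypothetical shapes):

import torch

encoder_hidden_states = torch.randn(1, 6, 8)            # description encoder output
prompt_hidden_states = torch.randn(1, 4, 8)             # embedded prompt (+ sinusoidal positions)
attention_mask = torch.ones(1, 6, dtype=torch.long)
prompt_attention_mask = torch.ones(1, 4, dtype=torch.long)

encoder_hidden_states = torch.cat([encoder_hidden_states, prompt_hidden_states], dim=1)  # (1, 10, 8)
attention_mask = torch.cat([attention_mask, prompt_attention_mask], dim=1)               # (1, 10)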
def _prepare_audio_encoder_kwargs_for_generation( def _prepare_audio_encoder_kwargs_for_generation(
...@@ -2617,10 +3263,10 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel): ...@@ -2617,10 +3263,10 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
else: else:
output_ids = outputs output_ids = outputs
# apply the pattern mask to the final ids # Apply the pattern mask to the final ids
output_ids = self.decoder.apply_delay_pattern_mask(output_ids, model_kwargs["decoder_delay_pattern_mask"]) output_ids = self.decoder.apply_delay_pattern_mask(output_ids, model_kwargs["decoder_delay_pattern_mask"])
# revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask # Revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask
_, mask = self.decoder.build_delay_pattern_mask( _, mask = self.decoder.build_delay_pattern_mask(
input_ids, input_ids,
bos_token_id=generation_config.bos_token_id, bos_token_id=generation_config.bos_token_id,
...@@ -2659,13 +3305,12 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel): ...@@ -2659,13 +3305,12 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
output_values.append(sample.transpose(0, 2)) output_values.append(sample.transpose(0, 2))
else: else:
output_values.append(torch.zeros((1, 1, 1)).to(self.device)) output_values.append(torch.zeros((1, 1, 1)).to(self.device))
# TODO: we should keep track of output length as well. Not really straightfoward tbh # TODO: we should keep track of output length as well. Not really straightforward tbh
output_values = ( output_values = (
torch.nn.utils.rnn.pad_sequence(output_values, batch_first=True, padding_value=0) torch.nn.utils.rnn.pad_sequence(output_values, batch_first=True, padding_value=0)
.squeeze(-1) .squeeze(-1)
.squeeze(-1) .squeeze(-1)
) )
if generation_config.return_dict_in_generate: if generation_config.return_dict_in_generate:
outputs.sequences = output_values outputs.sequences = output_values
return outputs return outputs
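The decoded waveforms come back with different lengths, so they are right-padded into a single (batch, samples) tensor before being returned. A toy illustration of the padding call used above:

import torch

outs = [torch.randn(16000, 1, 1), torch.randn(24000, 1, 1), torch.zeros(1, 1, 1)]   # (samples, 1, 1) each
batch = (
    torch.nn.utils.rnn.pad_sequence(outs, batch_first=True, padding_value=0)
    .squeeze(-1)
    .squeeze(-1)
)
print(batch.shape)   # torch.Size([3, 24000])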
......
...@@ -78,6 +78,22 @@ class ModelArguments: ...@@ -78,6 +78,22 @@ class ModelArguments:
"help": "Used to compute audio similarity during evaluation. Path to pretrained model or model identifier from huggingface.co/models" "help": "Used to compute audio similarity during evaluation. Path to pretrained model or model identifier from huggingface.co/models"
}, },
) )
attn_implementation: str = field(
default="eager",
metadata={"help": "Attention implementation used. One of `eager`, `sdpa`, `flash_attention_2`"},
)
cross_attention_implementation_strategy: str = field(
default=None,
metadata={
"help": "If not specified, the cross-attention implementation will be the same as `_attn_implementation`. If `always_eager`, it will always be the eager implementation. If `always_sdpa`, it will always be the sdpa implementation."
},
)
prompt_padding_side: Optional[str] = field(
default="left",
metadata={
"help": "Prompt tokenizer padding side. Defaults to `left`. If the prompt is pre-pended to the codebooks hidden states, it should be padded on the left."
},
)
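A hedged usage sketch of how these new arguments are meant to reach the model; the checkpoint path is only a placeholder, and the pattern mirrors the training script further down, which forwards `attn_implementation` to `from_pretrained` and writes `cross_attention_implementation_strategy` into the decoder config:

from parler_tts import ParlerTTSConfig, ParlerTTSForConditionalGeneration

config = ParlerTTSConfig.from_pretrained("path/to/checkpoint")           # placeholder path
config.decoder.update({"cross_attention_implementation_strategy": "always_sdpa"})

model = ParlerTTSForConditionalGeneration.from_pretrained(
    "path/to/checkpoint",
    config=config,
    attn_implementation="sdpa",   # one of "eager", "sdpa", "flash_attention_2"
)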
@dataclass @dataclass
...@@ -290,6 +306,10 @@ class DataTrainingArguments: ...@@ -290,6 +306,10 @@ class DataTrainingArguments:
}, },
) )
temporary_save_to_disk: str = field(default=None, metadata={"help": "Temporarily save audio labels here."}) temporary_save_to_disk: str = field(default=None, metadata={"help": "Temporarily save audio labels here."})
save_codec_steps: Optional[int] = field(
default=500,
metadata={"help": "Temporarily save the audio labels every `save_steps`."},
)
pad_to_multiple_of: Optional[int] = field( pad_to_multiple_of: Optional[int] = field(
default=2, default=2,
metadata={"help": ("Pad to multiple of for tokenizers.")}, metadata={"help": ("Pad to multiple of for tokenizers.")},
...@@ -311,3 +331,32 @@ class ParlerTTSTrainingArguments(Seq2SeqTrainingArguments): ...@@ -311,3 +331,32 @@ class ParlerTTSTrainingArguments(Seq2SeqTrainingArguments):
default=8, default=8,
metadata={"help": ("Specify the batch size of the audio encoding pre-processing steps.")}, metadata={"help": ("Specify the batch size of the audio encoding pre-processing steps.")},
) )
eval_dataloader_num_workers: Optional[int] = field(
default=0,
metadata={
"help": (
"Number of subprocesses to use for evaluation data loading (PyTorch only). 0 means that the data will be loaded in the main process."
)
},
)
compute_clap_similarity_metric: bool = field(
default=True,
metadata={
"help": (
"Whether or not to compute the clap similarity metric between the description and the generation during evalution."
)
},
)
compute_noise_level_metric: bool = field(
default=True,
metadata={"help": ("Whether or not to compute the squim si-sdr measure of the generations.")},
)
noise_level_to_compute_clean_wer: float = field(
default=25,
metadata={
"help": (
"if `compute_noise_level_metric=True`, will compute a 'clean' WER on samples with generated noise higher than `noise_level_to_compute_clean_wer`."
"This is a proxy measure to compute WER on clean audios, provided that the model learn to generate clean audios."
)
},
)
...@@ -30,6 +30,8 @@ class DataCollatorEncodecWithPadding: ...@@ -30,6 +30,8 @@ class DataCollatorEncodecWithPadding:
# different padding methods # different padding methods
audios = [feature[self.audio_column_name]["array"] for feature in features] audios = [feature[self.audio_column_name]["array"] for feature in features]
len_audio = [len(audio) for audio in audios] len_audio = [len(audio) for audio in audios]
if self.max_length is not None:
audios = [audio[: min(l, self.max_length)] for audio, l in zip(audios, len_audio)]
# since resampling has already been performed in the 'load_multiple_datasets' function, # since resampling has already been performed in the 'load_multiple_datasets' function,
# a fixed sampling_rate(44100hz) is passed to the feature_extractor. # a fixed sampling_rate(44100hz) is passed to the feature_extractor.
...@@ -81,7 +83,9 @@ class DataCollatorParlerTTSWithPadding: ...@@ -81,7 +83,9 @@ class DataCollatorParlerTTSWithPadding:
# (bsz, seq_len, num_codebooks) # (bsz, seq_len, num_codebooks)
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100) labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
if self.audio_max_length is not None and self.padding == "max_length": if self.audio_max_length is not None and self.padding == "max_length":
labels = torch.nn.functional.pad(labels, pad=(0, 0, 0, max(self.audio_max_length - labels.shape[1], 0))) labels = torch.nn.functional.pad(
labels, pad=(0, 0, 0, max(self.audio_max_length - labels.shape[1], 0)), value=-100
)
input_ids = [{"input_ids": feature["input_ids"]} for feature in features] input_ids = [{"input_ids": feature["input_ids"]} for feature in features]
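The padding value of -100 used for the labels above matters: it is the default ignore_index of PyTorch's cross-entropy, so the frames added to reach `audio_max_length` contribute nothing to the loss. A minimal sketch:

import torch
import torch.nn.functional as F

logits = torch.randn(1, 4, 10)               # (bsz, seq_len, vocab_size)
labels = torch.tensor([[3, 7, -100, -100]])  # last two frames are padding
loss = F.cross_entropy(logits.transpose(1, 2), labels, ignore_index=-100)  # only 2 frames counted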
...@@ -95,11 +99,6 @@ class DataCollatorParlerTTSWithPadding: ...@@ -95,11 +99,6 @@ class DataCollatorParlerTTSWithPadding:
batch = {"labels": labels, **input_ids} batch = {"labels": labels, **input_ids}
if self.audio_max_length is not None and self.padding == "max_length":
# if we do torch.compile, we need to also specify the attention_mask
decoder_attention_mask = torch.ones(labels.shape[:2], dtype=input_ids["attention_mask"].dtype)
batch["decoder_attention_mask"] = decoder_attention_mask
prompt_input_ids = [{"input_ids": feature["prompt_input_ids"]} for feature in features] prompt_input_ids = [{"input_ids": feature["prompt_input_ids"]} for feature in features]
prompt_input_ids = self.prompt_tokenizer.pad( prompt_input_ids = self.prompt_tokenizer.pad(
prompt_input_ids, prompt_input_ids,
...@@ -206,7 +205,7 @@ def load_multiple_datasets( ...@@ -206,7 +205,7 @@ def load_multiple_datasets(
all_datasets = [] all_datasets = []
# iterate over the datasets we want to interleave # iterate over the datasets we want to interleave
for dataset_dict in tqdm(dataset_names_dict, desc="Combining datasets..."): for dataset_dict in tqdm(dataset_names_dict, desc="Combining datasets..."):
with accelerator.main_process_first(): with accelerator.local_main_process_first():
dataset = load_dataset( dataset = load_dataset(
dataset_dict["name"], dataset_dict["name"],
dataset_dict["config"], dataset_dict["config"],
...@@ -242,7 +241,7 @@ def load_multiple_datasets( ...@@ -242,7 +241,7 @@ def load_multiple_datasets(
# metadata_dataset = metadata_dataset.map(concat_ids, input_columns=["book_id", "speaker_id", "begin_time"], num_proc=24) # metadata_dataset = metadata_dataset.map(concat_ids, input_columns=["book_id", "speaker_id", "begin_time"], num_proc=24)
# metadata_dataset = metadata_dataset.rename_column(id_column_name, f"metadata_{id_column_name}") # metadata_dataset = metadata_dataset.rename_column(id_column_name, f"metadata_{id_column_name}")
if dataset_dict["name"] != "parler-tts/mls_eng_10k": if dataset_dict["name"] not in {"parler-tts/mls_eng_10k", "parler-tts/mls_eng"}:
if id_column_name is not None and id_column_name not in dataset.column_names: if id_column_name is not None and id_column_name not in dataset.column_names:
raise ValueError( raise ValueError(
f"id_column_name={id_column_name} but has not been found in the dataset columns" f"id_column_name={id_column_name} but has not been found in the dataset columns"
...@@ -272,7 +271,10 @@ def load_multiple_datasets( ...@@ -272,7 +271,10 @@ def load_multiple_datasets(
dataset = concatenate_datasets([dataset, metadata_dataset], axis=1) dataset = concatenate_datasets([dataset, metadata_dataset], axis=1)
if id_column_name is not None and dataset_dict["name"] != "parler-tts/mls_eng_10k": if id_column_name is not None and dataset_dict["name"] not in {
"parler-tts/mls_eng_10k",
"parler-tts/mls_eng",
}:
if ( if (
len( len(
dataset.filter( dataset.filter(
...@@ -304,7 +306,7 @@ def load_multiple_datasets( ...@@ -304,7 +306,7 @@ def load_multiple_datasets(
seed=seed, seed=seed,
) )
else: else:
with accelerator.main_process_first(): with accelerator.local_main_process_first():
interleaved_dataset = concatenate_datasets(all_datasets) interleaved_dataset = concatenate_datasets(all_datasets)
return interleaved_dataset return interleaved_dataset
import torch import torch
from torchaudio.pipelines import SQUIM_OBJECTIVE
import torchaudio
import evaluate import evaluate
from transformers import AutoModel, AutoProcessor, pipeline, WhisperForConditionalGeneration, WhisperTokenizer, WhisperTokenizerFast from transformers import (
AutoModel,
AutoProcessor,
pipeline,
WhisperForConditionalGeneration,
WhisperTokenizer,
WhisperTokenizerFast,
)
from accelerate.utils.memory import release_memory
import numpy as np
def clap_similarity(clap_model_name_or_path, texts, audios, device): def clap_similarity(clap_model_name_or_path, texts, audios, device, input_sampling_rate=44100):
clap = AutoModel.from_pretrained(clap_model_name_or_path) clap = AutoModel.from_pretrained(clap_model_name_or_path)
clap_processor = AutoProcessor.from_pretrained(clap_model_name_or_path) clap_processor = AutoProcessor.from_pretrained(clap_model_name_or_path)
clap_inputs = clap_processor(text=texts, audios=audios, padding=True, return_tensors="pt").to(device) output_sampling_rate = clap_processor.feature_extractor.sampling_rate
if input_sampling_rate != output_sampling_rate:
audios = [
torchaudio.functional.resample(torch.from_numpy(audio), input_sampling_rate, output_sampling_rate).numpy()
for audio in audios
]
clap_inputs = clap_processor(
text=texts, audios=audios, padding=True, return_tensors="pt", sampling_rate=output_sampling_rate
).to(device)
clap.to(device) clap.to(device)
with torch.no_grad(): with torch.no_grad():
text_features = clap.get_text_features( text_features = clap.get_text_features(
...@@ -14,16 +34,52 @@ def clap_similarity(clap_model_name_or_path, texts, audios, device): ...@@ -14,16 +34,52 @@ def clap_similarity(clap_model_name_or_path, texts, audios, device):
) )
audio_features = clap.get_audio_features(clap_inputs["input_features"]) audio_features = clap.get_audio_features(clap_inputs["input_features"])
cosine_sim = torch.nn.functional.cosine_similarity(audio_features, text_features, dim=1, eps=1e-8) cosine_sim = torch.nn.functional.cosine_similarity(audio_features, text_features, dim=1, eps=1e-8).mean()
cosine_sim = cosine_sim.to("cpu")
clap.to("cpu") clap.to("cpu")
clap_inputs.to("cpu") clap, clap_inputs, audio_features, text_features = release_memory(clap, clap_inputs, audio_features, text_features)
return cosine_sim.mean().to("cpu") return cosine_sim
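The resampling step above matters because generations come out at the audio encoder's rate while CLAP's feature extractor expects its own rate; the score is then a mean text/audio cosine similarity. A toy sketch (48 kHz is only an assumed target rate, and the feature tensors are stand-ins):

import torch
import torchaudio

audio_44k = torch.randn(44100)                                     # 1 s of generated audio
audio_48k = torchaudio.functional.resample(audio_44k, orig_freq=44100, new_freq=48000)

audio_features = torch.randn(4, 512)                               # stand-ins for CLAP embeddings
text_features = torch.randn(4, 512)
score = torch.nn.functional.cosine_similarity(audio_features, text_features, dim=1).mean()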
def si_sdr(audios, device, input_sampling_rate=44100):
max_audio_length = 15 * SQUIM_OBJECTIVE.sample_rate
model = SQUIM_OBJECTIVE.get_model().to(device)
output_sampling_rate = SQUIM_OBJECTIVE.sample_rate
if input_sampling_rate != output_sampling_rate:
audios = [
torchaudio.functional.resample(
torch.tensor(audio)[None, :].to(device).float(), input_sampling_rate, output_sampling_rate
)
for audio in audios
]
def apply_squim(waveform):
with torch.no_grad():
waveform = waveform[:, : min(max_audio_length, waveform.shape[1])]
_, _, sdr_sample = model(waveform)
sdr_sample = sdr_sample.cpu()[0]
return sdr_sample
def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_size, sampling_rate): si_sdrs = [apply_squim(audio) for audio in audios]
audios, model = release_memory(audios, model)
return si_sdrs
def wer(
asr_model_name_or_path,
prompts,
audios,
device,
per_device_eval_batch_size,
sampling_rate,
noise_level_to_compute_clean_wer,
si_sdr_measures,
):
metric = evaluate.load("wer") metric = evaluate.load("wer")
asr_pipeline = pipeline(model=asr_model_name_or_path, device=device) asr_pipeline = pipeline(model=asr_model_name_or_path, device=device, chunk_length_s=25.0)
return_language = None return_language = None
if isinstance(asr_pipeline.model, WhisperForConditionalGeneration): if isinstance(asr_pipeline.model, WhisperForConditionalGeneration):
...@@ -47,7 +103,11 @@ def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_s ...@@ -47,7 +103,11 @@ def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_s
normalized_references = [] normalized_references = []
for pred, ref in zip(transcriptions, prompts): for pred, ref in zip(transcriptions, prompts):
normalizer = english_normalizer if return_language and pred["chunks"][0]["language"] == "english" else basic_normalizer normalizer = (
english_normalizer
if isinstance(pred.get("chunks", None), list) and pred["chunks"][0].get("language", None) == "english"
else basic_normalizer
)
norm_ref = normalizer(ref) norm_ref = normalizer(ref)
if len(norm_ref) > 0: if len(norm_ref) > 0:
norm_pred = normalizer(pred["text"]) norm_pred = normalizer(pred["text"])
...@@ -56,4 +116,21 @@ def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_s ...@@ -56,4 +116,21 @@ def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_s
word_error = 100 * metric.compute(predictions=normalized_predictions, references=normalized_references) word_error = 100 * metric.compute(predictions=normalized_predictions, references=normalized_references)
return word_error, [t["text"] for t in transcriptions] clean_word_error = None
noisy_word_error = None
percent_clean_samples = 0
if noise_level_to_compute_clean_wer and si_sdr_measures:
si_sdr_measures = np.array(si_sdr_measures)
mask = si_sdr_measures >= noise_level_to_compute_clean_wer
if mask.any():
clean_word_error = 100 * metric.compute(
predictions=np.array(normalized_predictions)[mask], references=np.array(normalized_references)[mask]
)
noisy_word_error = 100 * metric.compute(
predictions=np.array(normalized_predictions)[~mask], references=np.array(normalized_references)[~mask]
)
percent_clean_samples = mask.sum() / len(mask)
asr_pipeline.model.to("cpu")
asr_pipeline = release_memory(asr_pipeline)
return word_error, [t["text"] for t in transcriptions], clean_word_error, noisy_word_error, percent_clean_samples
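A small numeric illustration of the new clean/noisy split: samples whose estimated SI-SDR clears the threshold count as "clean", and WER is recomputed separately on each subset (made-up numbers):

import numpy as np

si_sdr_measures = np.array([28.3, 12.1, 31.0, 19.7])   # one SQUIM estimate per generated sample
threshold = 25.0
mask = si_sdr_measures >= threshold                     # [ True, False,  True, False]
percent_clean_samples = mask.sum() / len(mask)          # 0.5
# predictions/references are then indexed with `mask` and `~mask` to get clean vs noisy WER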
...@@ -42,7 +42,7 @@ from transformers.optimization import get_scheduler ...@@ -42,7 +42,7 @@ from transformers.optimization import get_scheduler
from transformers.utils import send_example_telemetry from transformers.utils import send_example_telemetry
from accelerate import Accelerator from accelerate import Accelerator, skip_first_batches
from accelerate.utils import set_seed, AutocastKwargs, InitProcessGroupKwargs, TorchDynamoPlugin from accelerate.utils import set_seed, AutocastKwargs, InitProcessGroupKwargs, TorchDynamoPlugin
from accelerate.utils.memory import release_memory from accelerate.utils.memory import release_memory
...@@ -52,11 +52,18 @@ from parler_tts import ( ...@@ -52,11 +52,18 @@ from parler_tts import (
build_delay_pattern_mask, build_delay_pattern_mask,
) )
from training.utils import get_last_checkpoint, rotate_checkpoints, log_pred, log_metric from training.utils import (
get_last_checkpoint,
rotate_checkpoints,
log_pred,
log_metric,
load_all_codec_checkpoints,
save_codec_checkpoint,
get_last_codec_checkpoint_step,
)
from training.arguments import ModelArguments, DataTrainingArguments, ParlerTTSTrainingArguments from training.arguments import ModelArguments, DataTrainingArguments, ParlerTTSTrainingArguments
from training.data import load_multiple_datasets, DataCollatorParlerTTSWithPadding, DataCollatorEncodecWithPadding from training.data import load_multiple_datasets, DataCollatorParlerTTSWithPadding, DataCollatorEncodecWithPadding
from training.eval import clap_similarity, wer from training.eval import clap_similarity, wer, si_sdr
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -80,10 +87,13 @@ def main(): ...@@ -80,10 +87,13 @@ def main():
if training_args.dtype == "float16": if training_args.dtype == "float16":
mixed_precision = "fp16" mixed_precision = "fp16"
torch_dtype = torch.float16
elif training_args.dtype == "bfloat16": elif training_args.dtype == "bfloat16":
mixed_precision = "bf16" mixed_precision = "bf16"
torch_dtype = torch.bfloat16
else: else:
mixed_precision = "no" mixed_precision = "no"
torch_dtype = torch.float32
if data_args.pad_to_max_length and ( if data_args.pad_to_max_length and (
data_args.max_duration_in_seconds is None data_args.max_duration_in_seconds is None
...@@ -97,7 +107,7 @@ def main(): ...@@ -97,7 +107,7 @@ def main():
padding = "max_length" if data_args.pad_to_max_length else "longest" padding = "max_length" if data_args.pad_to_max_length else "longest"
####### A. Preparation ####### A. Preparation
kwargs_handlers = [InitProcessGroupKwargs(timeout=timedelta(minutes=60))] kwargs_handlers = [InitProcessGroupKwargs(timeout=timedelta(minutes=120))]
accelerator = Accelerator( accelerator = Accelerator(
gradient_accumulation_steps=training_args.gradient_accumulation_steps, gradient_accumulation_steps=training_args.gradient_accumulation_steps,
...@@ -192,7 +202,7 @@ def main(): ...@@ -192,7 +202,7 @@ def main():
token=data_args.token, token=data_args.token,
trust_remote_code=data_args.trust_remote_code, trust_remote_code=data_args.trust_remote_code,
use_fast=model_args.use_fast_tokenizer, use_fast=model_args.use_fast_tokenizer,
padding_side="left", # prompt has to be padded on the left bc it's preprend to codebooks hidden states padding_side=model_args.prompt_padding_side,
) )
# load description tokenizer # load description tokenizer
...@@ -219,7 +229,8 @@ def main(): ...@@ -219,7 +229,8 @@ def main():
# assume that the dataset has been saved to `save_to_disk` if the latter is not empty # assume that the dataset has been saved to `save_to_disk` if the latter is not empty
dataset_was_precomputed = len(os.listdir(data_args.save_to_disk)) > 0 dataset_was_precomputed = len(os.listdir(data_args.save_to_disk)) > 0
if dataset_was_precomputed: if dataset_was_precomputed:
vectorized_datasets = datasets.load_from_disk(data_args.save_to_disk) with accelerator.local_main_process_first():
vectorized_datasets = datasets.load_from_disk(data_args.save_to_disk)
else: else:
raw_datasets = DatasetDict() raw_datasets = DatasetDict()
...@@ -282,9 +293,10 @@ def main(): ...@@ -282,9 +293,10 @@ def main():
) )
if data_args.max_eval_samples is not None: if data_args.max_eval_samples is not None:
raw_datasets["eval"] = ( with accelerator.local_main_process_first():
raw_datasets["eval"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples)) raw_datasets["eval"] = (
) raw_datasets["eval"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
)
# 3. Next, let's load the config. # 3. Next, let's load the config.
config = ParlerTTSConfig.from_pretrained( config = ParlerTTSConfig.from_pretrained(
...@@ -295,6 +307,13 @@ def main(): ...@@ -295,6 +307,13 @@ def main():
) )
# update pad token id and decoder_start_token_id # update pad token id and decoder_start_token_id
config.decoder.update(
{
"cross_attention_implementation_strategy": model_args.cross_attention_implementation_strategy
if model_args.cross_attention_implementation_strategy is not None
else None
}
)
config.update( config.update(
{ {
"pad_token_id": model_args.pad_token_id if model_args.pad_token_id is not None else config.pad_token_id, "pad_token_id": model_args.pad_token_id if model_args.pad_token_id is not None else config.pad_token_id,
...@@ -311,6 +330,7 @@ def main(): ...@@ -311,6 +330,7 @@ def main():
config=config, config=config,
token=data_args.token, token=data_args.token,
trust_remote_code=data_args.trust_remote_code, trust_remote_code=data_args.trust_remote_code,
attn_implementation=model_args.attn_implementation,
) )
# enable gradient checkpointing if necessary # enable gradient checkpointing if necessary
...@@ -336,11 +356,13 @@ def main(): ...@@ -336,11 +356,13 @@ def main():
max_length = model.generation_config.max_length max_length = model.generation_config.max_length
num_codebooks = model.decoder.config.num_codebooks num_codebooks = model.decoder.config.num_codebooks
bandwidth = model_args.bandwidth bandwidth = model_args.bandwidth
attn_implementation = model_args.attn_implementation
# Freeze Encoders # Freeze Encoders
model.freeze_encoders(model_args.freeze_text_encoder) model.freeze_encoders(model_args.freeze_text_encoder)
# Test all gather - used for warmup and avoiding timeout # Test all gather - used for warmup and avoiding timeout
logger.debug(str(accelerator.process_index), main_process_only=False, in_order=True)
test_tensor = torch.tensor([accelerator.process_index], device=accelerator.device) test_tensor = torch.tensor([accelerator.process_index], device=accelerator.device)
gathered_tensor = accelerator.gather(test_tensor) gathered_tensor = accelerator.gather(test_tensor)
print("gathered_tensor", gathered_tensor) print("gathered_tensor", gathered_tensor)
...@@ -349,7 +371,7 @@ def main(): ...@@ -349,7 +371,7 @@ def main():
if not dataset_was_precomputed: if not dataset_was_precomputed:
# Filter on text length # Filter on text length
if description_column_name is not None and data_args.max_text_length is not None: if description_column_name is not None and data_args.max_text_length is not None:
with accelerator.main_process_first(): with accelerator.local_main_process_first():
# filter description that is shorter than max_text_length # filter description that is shorter than max_text_length
raw_datasets = raw_datasets.filter( raw_datasets = raw_datasets.filter(
lambda x: len(x) < data_args.max_text_length, lambda x: len(x) < data_args.max_text_length,
...@@ -367,7 +389,7 @@ def main(): ...@@ -367,7 +389,7 @@ def main():
return batch return batch
with accelerator.main_process_first(): with accelerator.local_main_process_first():
# this is a trick to avoid to rewrite the entire audio column which takes ages # this is a trick to avoid to rewrite the entire audio column which takes ages
vectorized_datasets = raw_datasets.map( vectorized_datasets = raw_datasets.map(
pass_through_processors, pass_through_processors,
...@@ -410,7 +432,41 @@ def main(): ...@@ -410,7 +432,41 @@ def main():
output["len_audio"] = len_audio output["len_audio"] = len_audio
# (1, bsz, codebooks, seq_len) -> (bsz, seq_len, codebooks) # (1, bsz, codebooks, seq_len) -> (bsz, seq_len, codebooks)
output["labels"] = labels.squeeze(0).transpose(1, 2) output["labels"] = labels.squeeze(0).transpose(1, 2)
output["ratio"] = torch.ones_like(len_audio) * labels.shape[-1] / len_audio.max()
# if `pad_to_max_length`, the maximum corresponding audio length of the current batch is max_duration*sampling_rate
max_length = len_audio.max() if padding != "max_length" else max_target_length
output["ratio"] = torch.ones_like(len_audio) * labels.shape[-1] / max_length
return output
# (1, codebooks, seq_len) where seq_len=1
bos_labels = torch.ones((1, num_codebooks, 1)) * audio_encoder_bos_token_id
def postprocess_dataset(labels):
# (1, codebooks, seq_len)
labels = torch.tensor(labels).unsqueeze(0)
# add bos
labels = torch.cat([bos_labels, labels], dim=-1)
labels, delay_pattern_mask = build_delay_pattern_mask(
labels,
bos_token_id=audio_encoder_bos_token_id,
pad_token_id=audio_encoder_eos_token_id,
max_length=labels.shape[-1] + num_codebooks,
num_codebooks=num_codebooks,
)
# the first ids of the delay pattern mask are precisely labels, we use the rest of the labels mask
# to take care of EOS
# we want labels to look like this:
# - [B, a, b, E, E, E, E]
# - [B, B, c, d, E, E, E]
# - [B, B, B, e, f, E, E]
# - [B, B, B, B, g, h, E]
labels = torch.where(delay_pattern_mask == -1, audio_encoder_eos_token_id, delay_pattern_mask)
# the first timestamp is associated to a row full of BOS, let's get rid of it
# we also remove the last timestamps (full of PAD)
output = {"labels": labels[:, 1:]}
return output return output
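To make the B/E diagram above concrete, here is a tiny standalone reproduction of the delayed codebook layout (toy ids; the actual script relies on build_delay_pattern_mask as shown above):

import torch

B, E = 1024, 1025                              # illustrative BOS / EOS-pad ids
codes = torch.tensor([[10, 11],
                      [20, 21],
                      [30, 31],
                      [40, 41]])               # 4 codebooks, 2 real frames each
num_codebooks, seq_len = codes.shape
codes = torch.cat([torch.full((num_codebooks, 1), B), codes], dim=-1)   # prepend the BOS frame
seq_len += 1
delayed = torch.full((num_codebooks, seq_len + num_codebooks), E)
for k in range(num_codebooks):
    delayed[k, :k] = B
    delayed[k, k:k + seq_len] = codes[k]
# delayed (the all-B first column is dropped afterwards, as in the code above):
# [[ B, 10, 11,  E,  E,  E,  E],
#  [ B,  B, 20, 21,  E,  E,  E],
#  [ B,  B,  B, 30, 31,  E,  E],
#  [ B,  B,  B,  B, 40, 41,  E]]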
for split in vectorized_datasets: for split in vectorized_datasets:
...@@ -422,79 +478,79 @@ def main(): ...@@ -422,79 +478,79 @@ def main():
pin_memory=True, pin_memory=True,
) )
data_loader = accelerator.prepare(data_loader) data_loader = accelerator.prepare(data_loader)
total_inference_steps = len(data_loader)
start_step = get_last_codec_checkpoint_step(os.path.join(data_args.temporary_save_to_disk, split))
accelerator.wait_for_everyone()
if start_step > 0:
logger.info(f"Resuming {split} from step {start_step}")
# efficiently skip the first n batches
start_step += 1
data_loader = skip_first_batches(data_loader, start_step)
all_generated_labels = [] all_generated_labels = []
all_lens = [] all_lens = []
for batch in tqdm(data_loader, disable=not accelerator.is_local_main_process): if start_step < total_inference_steps:
generate_labels = apply_audio_decoder(batch) for i, batch in enumerate(tqdm(data_loader, disable=not accelerator.is_local_main_process)):
generate_labels = accelerator.pad_across_processes(generate_labels, dim=1, pad_index=0) cur_step = start_step + i
generate_labels = accelerator.gather_for_metrics(generate_labels) generate_labels = apply_audio_decoder(batch)
generate_labels = accelerator.pad_across_processes(generate_labels, dim=1, pad_index=0)
if accelerator.is_main_process: generate_labels = accelerator.gather_for_metrics(generate_labels)
lab = generate_labels["labels"].cpu().transpose(1, 2).to(torch.int16)
rat = generate_labels["ratio"].cpu().squeeze()
lens = generate_labels["len_audio"].cpu().squeeze()
lab = [l[:, : int(ratio * length)] for (l, ratio, length) in zip(lab, rat, lens)]
all_generated_labels.extend(lab) if accelerator.is_main_process:
all_lens.extend(lens) lab = generate_labels["labels"].cpu().transpose(1, 2).to(torch.int16)
rat = generate_labels["ratio"].cpu().squeeze(1)
# (1, codebooks, seq_len) where seq_len=1 lens = generate_labels["len_audio"].cpu().squeeze(1)
bos_labels = torch.ones((1, num_codebooks, 1)) * audio_encoder_bos_token_id lab = [l[:, : int(ratio * length)] for (l, ratio, length) in zip(lab, rat, lens)]
if accelerator.is_main_process: all_generated_labels.extend(lab)
tmp_labels = Dataset.from_dict({"labels": all_generated_labels, "target_length": all_lens}) all_lens.extend(lens)
tmp_labels.save_to_disk(
os.path.join(data_args.temporary_save_to_disk, split),
num_proc=1 if split == "eval" else data_args.preprocessing_num_workers,
)
accelerator.wait_for_everyone()
del all_generated_labels
tmp_labels = datasets.load_from_disk(os.path.join(data_args.temporary_save_to_disk, split)) if ((cur_step + 1) % data_args.save_codec_steps == 0) or (
with accelerator.main_process_first(): cur_step == total_inference_steps - 1
vectorized_datasets[split] = concatenate_datasets([vectorized_datasets[split], tmp_labels], axis=1) ):
tmp_labels = Dataset.from_dict({"labels": all_generated_labels, "target_length": all_lens})
tmp_labels = tmp_labels.map(
postprocess_dataset,
num_proc=data_args.preprocessing_num_workers,  # this one is resource-consuming when using many processes.
input_columns=["labels"],
desc="Postprocessing labeling",
)
save_codec_checkpoint(
os.path.join(data_args.temporary_save_to_disk, split), tmp_labels, cur_step
)
all_generated_labels = []
all_lens = []
def postprocess_dataset(labels): accelerator.wait_for_everyone()
# (1, codebooks, seq_len)
labels = torch.tensor(labels).unsqueeze(0)
# add bos
labels = torch.cat([bos_labels, labels], dim=-1)
labels, delay_pattern_mask = build_delay_pattern_mask(
labels,
bos_token_id=audio_encoder_bos_token_id,
pad_token_id=audio_encoder_eos_token_id,
max_length=labels.shape[-1] + num_codebooks,
num_codebooks=num_codebooks,
)
# the first ids of the delay pattern mask are precisely labels, we use the rest of the labels mask if accelerator.is_main_process and len(all_generated_labels) > 0:
# to take care of EOS tmp_labels = Dataset.from_dict({"labels": all_generated_labels, "target_length": all_lens})
# we want labels to look like this: tmp_labels = tmp_labels.map(
# - [B, a, b, E, E, E, E]
# - [B, B, c, d, E, E, E]
# - [B, B, B, e, f, E, E]
# - [B, B, B, B, g, h, E]
labels = torch.where(delay_pattern_mask == -1, audio_encoder_eos_token_id, delay_pattern_mask)
# the first timestamp is associated to a row full of BOS, let's get rid of it
# we also remove the last timestampts (full of PAD)
output = {"labels": labels[:, 1:]}
return output
with accelerator.main_process_first():
vectorized_datasets[split] = vectorized_datasets[split].map(
postprocess_dataset, postprocess_dataset,
num_proc=data_args.preprocessing_num_workers,  # this one is resource-consuming when using many processes. num_proc=data_args.preprocessing_num_workers,  # this one is resource-consuming when using many processes.
input_columns=["labels"], input_columns=["labels"],
desc="Postprocessing labeling", desc="Postprocessing labeling",
) )
save_codec_checkpoint(os.path.join(data_args.temporary_save_to_disk, split), tmp_labels, cur_step)
all_generated_labels = []
all_lens = []
accelerator.wait_for_everyone()
del all_generated_labels
accelerator.wait_for_everyone()
with accelerator.local_main_process_first():
tmp_labels = load_all_codec_checkpoints(os.path.join(data_args.temporary_save_to_disk, split)).select(
range(len(vectorized_datasets[split]))
)
logger.info(f"Concatenating {split}: {tmp_labels} with {vectorized_datasets[split]}")
vectorized_datasets[split] = concatenate_datasets([vectorized_datasets[split], tmp_labels], axis=1)
accelerator.free_memory() accelerator.free_memory()
del generate_labels, all_lens del generate_labels, all_lens
with accelerator.main_process_first(): with accelerator.local_main_process_first():
# NOTE: filtering is done at the end because in the `datasets` library, caching audio files is done after most operations # NOTE: filtering is done at the end because in the `datasets` library, caching audio files is done after most operations
# caching audio files is time and disk-space consuming, so we want to avoid it at all costs, especially for large (>1Kh) audio datasets. # caching audio files is time and disk-space consuming, so we want to avoid it at all costs, especially for large (>1Kh) audio datasets.
# That's also why we avoid to concat the processed datasets (vectorized_datasets) with the audio column present in raw_datasets. # That's also why we avoid to concat the processed datasets (vectorized_datasets) with the audio column present in raw_datasets.
...@@ -509,23 +565,23 @@ def main(): ...@@ -509,23 +565,23 @@ def main():
input_columns=["target_length"], input_columns=["target_length"],
) )
if description_column_name is not None and data_args.max_description_token_length is not None: if description_column_name is not None and data_args.max_description_token_length is not None:
with accelerator.main_process_first(): with accelerator.local_main_process_first():
# filter description that is shorter than max_text_length # filter description that is shorter than max_text_length
vectorized_datasets = vectorized_datasets.filter( vectorized_datasets = vectorized_datasets.filter(
lambda x: len(x) < data_args.max_description_token_length, lambda x: len(x) < data_args.max_description_token_length,
num_proc=num_workers, num_proc=num_workers,
input_columns=["input_ids"], input_columns=["input_ids"],
) )
if data_args.max_prompt_token_length is not None: if data_args.max_prompt_token_length is not None:
with accelerator.main_process_first(): with accelerator.local_main_process_first():
# filter description that is shorter than max_text_length # filter description that is shorter than max_text_length
vectorized_datasets = vectorized_datasets.filter( vectorized_datasets = vectorized_datasets.filter(
lambda x: len(x) < data_args.max_prompt_token_length, lambda x: len(x) < data_args.max_prompt_token_length,
num_proc=num_workers, num_proc=num_workers,
input_columns=["prompt_input_ids"], input_columns=["prompt_input_ids"],
) )
if data_args.save_to_disk is not None and not dataset_was_precomputed: if data_args.save_to_disk is not None and not dataset_was_precomputed:
if accelerator.is_main_process: if accelerator.is_main_process:
...@@ -533,25 +589,44 @@ def main(): ...@@ -533,25 +589,44 @@ def main():
data_args.save_to_disk, data_args.save_to_disk,
num_proc=min(data_args.preprocessing_num_workers, len(vectorized_datasets["eval"]) - 1), num_proc=min(data_args.preprocessing_num_workers, len(vectorized_datasets["eval"]) - 1),
) )
accelerator.wait_for_everyone()
logger.info(f"Dataset saved at {data_args.save_to_disk}") logger.info(f"Dataset saved at {data_args.save_to_disk}")
audio_max_length = None audio_max_length = None
if padding == "max_length": if padding == "max_length":
audio_max_length = max(vectorized_datasets["train"]["target_length"]) audio_max_length = max(vectorized_datasets["train"]["target_length"])
with accelerator.main_process_first(): with accelerator.local_main_process_first():
max_sample = vectorized_datasets["train"].filter( max_sample = vectorized_datasets["train"].filter(
lambda x: x == audio_max_length, lambda x: x == audio_max_length,
num_proc=num_workers, num_proc=num_workers,
input_columns=["target_length"], input_columns=["target_length"],
) )
audio_max_length = torch.tensor(max_sample[0]["labels"]).shape[1] audio_max_length = max([len(l[0]) for l in max_sample["labels"]])
if description_column_name is not None and data_args.max_description_token_length is not None:
with accelerator.local_main_process_first():
# filter description that is shorter than max_text_length
vectorized_datasets = vectorized_datasets.filter(
lambda x: len(x) < data_args.max_description_token_length,
num_proc=num_workers,
input_columns=["input_ids"],
)
if data_args.max_prompt_token_length is not None:
with accelerator.local_main_process_first():
# filter description that is shorter than max_text_length
vectorized_datasets = vectorized_datasets.filter(
lambda x: len(x) < data_args.max_prompt_token_length,
num_proc=num_workers,
input_columns=["prompt_input_ids"],
)
if training_args.group_by_length: if training_args.group_by_length:
# apply a simple heuristic to take into account audio and text lengths # apply a simple heuristic to take into account audio and text lengths
def add_target_lengths(target_length, prompt, description): def add_target_lengths(target_length, prompt, description):
return {"target_length": target_length + len(prompt) + len(description)} return {"target_length": target_length + len(prompt) + len(description)}
with accelerator.main_process_first(): with accelerator.local_main_process_first():
vectorized_datasets = vectorized_datasets.map( vectorized_datasets = vectorized_datasets.map(
add_target_lengths, add_target_lengths,
num_proc=num_workers, num_proc=num_workers,
...@@ -574,27 +649,48 @@ def main(): ...@@ -574,27 +649,48 @@ def main():
# 6. Next, we can prepare the training. # 6. Next, we can prepare the training.
# Let's use CLAP similarity and WER metrics as our evaluation metrics, # Let's use CLAP similarity and WER metrics as our evaluation metrics,
def compute_metrics(audios, descriptions, prompts, device="cpu"): def compute_metrics(
audios,
descriptions,
prompts,
device="cpu",
compute_clap_similarity_metric=False,
compute_noise_level_metric=False,
noise_level_to_compute_clean_wer=None,
):
results = {} results = {}
input_ids = descriptions input_ids = descriptions
texts = description_tokenizer.batch_decode(input_ids, skip_special_tokens=True) texts = description_tokenizer.batch_decode(input_ids, skip_special_tokens=True)
prompts = prompt_tokenizer.batch_decode(prompts, skip_special_tokens=True) prompts = prompt_tokenizer.batch_decode(prompts, skip_special_tokens=True)
audios = [a.cpu().numpy() for a in audios] audios = [a.float().cpu().numpy() for a in audios]
clap_score = clap_similarity(model_args.clap_model_name_or_path, texts, audios, device) if compute_clap_similarity_metric:
results["clap"] = clap_score clap_score = clap_similarity(
model_args.clap_model_name_or_path, texts, audios, device, input_sampling_rate=sampling_rate
)
results["clap"] = clap_score
si_sdr_measures = None
if compute_noise_level_metric:
si_sdr_measures = si_sdr(audios, device, input_sampling_rate=sampling_rate)
word_error, transcriptions = wer( word_error, transcriptions, clean_word_error, noisy_word_error, percent_clean_samples = wer(
model_args.asr_model_name_or_path, model_args.asr_model_name_or_path,
prompts, prompts,
audios, audios,
device, device,
training_args.per_device_eval_batch_size, training_args.per_device_eval_batch_size,
sampling_rate, sampling_rate,
noise_level_to_compute_clean_wer,
si_sdr_measures,
) )
results["wer"] = word_error results["wer"] = word_error
if clean_word_error is not None:
results["clean_wer"] = clean_word_error
results["noisy_word_error"] = noisy_word_error
results["percent_clean_samples"] = percent_clean_samples
return results, texts, prompts, audios, transcriptions return results, texts, prompts, audios, transcriptions, si_sdr_measures
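The new `clean_wer` / `noisy_word_error` / `percent_clean_samples` outputs rely on per-sample SI-SDR estimates and a noise threshold; the real split happens inside the `wer` helper. A hypothetical, self-contained sketch of that split with made-up numbers (`threshold` stands in for `noise_level_to_compute_clean_wer`):

```python
import numpy as np

si_sdr_measures = np.array([18.0, 5.0, 25.0, 9.0])   # made-up per-sample SI-SDR estimates (dB)
per_sample_wer = np.array([0.10, 0.45, 0.05, 0.30])  # made-up per-sample word error rates
threshold = 15.0                                      # stand-in for noise_level_to_compute_clean_wer

clean = si_sdr_measures >= threshold                  # samples considered clean enough
clean_wer = per_sample_wer[clean].mean() if clean.any() else None
noisy_wer = per_sample_wer[~clean].mean() if (~clean).any() else None
percent_clean_samples = 100.0 * clean.mean()
print(clean_wer, noisy_wer, percent_clean_samples)    # clean ~0.075, noisy ~0.375, 50% clean
```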
# Define Training Schedule # Define Training Schedule
# Store some constants # Store some constants
...@@ -698,24 +794,24 @@ def main(): ...@@ -698,24 +794,24 @@ def main():
# Now save everything to be able to create a single processor later # Now save everything to be able to create a single processor later
# make sure all processes wait until data is saved # make sure all processes wait until data is saved
with accelerator.main_process_first(): # only the main process saves them
# only the main process saves them if accelerator.is_main_process:
if accelerator.is_main_process: # save feature extractor, tokenizer and config
# save feature extractor, tokenizer and config if (
if ( model_args.prompt_tokenizer_name is None
model_args.prompt_tokenizer_name is None and model_args.description_tokenizer_name
and model_args.description_tokenizer_name or (model_args.prompt_tokenizer_name == model_args.description_tokenizer_name)
or (model_args.prompt_tokenizer_name == model_args.description_tokenizer_name) ):
): prompt_tokenizer.save_pretrained(training_args.output_dir)
prompt_tokenizer.save_pretrained(training_args.output_dir) else:
else: logger.warning(
logger.warning( f"Prompt tokenizer ('{model_args.prompt_tokenizer_name}') and description tokenizer ('{model_args.description_tokenizer_name}') are not the same. Saving only the prompt tokenizer."
f"Prompt tokenizer ('{model_args.prompt_tokenizer_name}') and description tokenizer ('{model_args.description_tokenizer_name}') are not the same. Saving only the prompt tokenizer." )
) prompt_tokenizer.save_pretrained(training_args.output_dir)
prompt_tokenizer.save_pretrained(training_args.output_dir)
feature_extractor.save_pretrained(training_args.output_dir) feature_extractor.save_pretrained(training_args.output_dir)
config.save_pretrained(training_args.output_dir) config.save_pretrained(training_args.output_dir)
accelerator.wait_for_everyone()
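A minimal sketch of the save-then-synchronize pattern used above, assuming an Accelerate setup: rank 0 writes the artifacts and every rank waits at the barrier before the checkpoint is loaded (the body of the `if` is elided).

```python
from accelerate import Accelerator

accelerator = Accelerator()
if accelerator.is_main_process:
    # only rank 0 writes tokenizer / feature extractor / config to the output directory
    ...
accelerator.wait_for_everyone()  # other ranks block here until the files exist on disk
```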
if checkpoint is not None: if checkpoint is not None:
accelerator.load_state(checkpoint) accelerator.load_state(checkpoint)
...@@ -732,7 +828,8 @@ def main(): ...@@ -732,7 +828,8 @@ def main():
steps_trained_progress_bar.update(cur_step) steps_trained_progress_bar.update(cur_step)
for epoch in range(0, epochs_trained): for epoch in range(0, epochs_trained):
vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed) with accelerator.local_main_process_first():
vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
if training_args.max_steps < 0: if training_args.max_steps < 0:
# we know exactly the number of steps per epoch, so can skip through the required number of batches # we know exactly the number of steps per epoch, so can skip through the required number of batches
...@@ -742,7 +839,8 @@ def main(): ...@@ -742,7 +839,8 @@ def main():
# So we just shuffle the dataset one extra time and start from a fresh epoch # So we just shuffle the dataset one extra time and start from a fresh epoch
# This is "good enough" for our purposes but not fully correct # This is "good enough" for our purposes but not fully correct
resume_step = None resume_step = None
vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed) with accelerator.local_main_process_first():
vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
else: else:
resume_step = None resume_step = None
...@@ -762,8 +860,6 @@ def main(): ...@@ -762,8 +860,6 @@ def main():
accelerator, accelerator,
autocast_kwargs, autocast_kwargs,
): ):
model.train()
if mixed_precision == "fp16": if mixed_precision == "fp16":
# fp16 doesn't work with T5-like models # fp16 doesn't work with T5-like models
with accelerator.autocast(autocast_handler=autocast_kwargs): with accelerator.autocast(autocast_handler=autocast_kwargs):
...@@ -775,6 +871,22 @@ def main(): ...@@ -775,6 +871,22 @@ def main():
encoder_outputs = model.module.text_encoder( encoder_outputs = model.module.text_encoder(
input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None) input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None)
) )
# we optionally project last_hidden_state to avoid recomputing it every time
encoder_hidden_states = encoder_outputs.last_hidden_state
if (
config.text_encoder.hidden_size != config.decoder.hidden_size
and config.decoder.cross_attention_hidden_size is None
):
encoder_hidden_states = (
model.enc_to_dec_proj(encoder_hidden_states)
if training_args.parallel_mode.value != "distributed"
else model.module.enc_to_dec_proj(encoder_hidden_states)
)
if batch.get("attention_mask", None) is not None:
encoder_hidden_states = encoder_hidden_states * batch.get("attention_mask", None)[..., None]
encoder_outputs.last_hidden_state = encoder_hidden_states
batch["encoder_outputs"] = encoder_outputs batch["encoder_outputs"] = encoder_outputs
outputs = model(**batch) outputs = model(**batch)
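The block added above computes the text-encoder states once per step, optionally projects them to the decoder width, and zeroes padded positions before reusing them as `encoder_outputs`. A toy sketch of the projection-plus-masking step with made-up dimensions (the `Linear` stands in for `enc_to_dec_proj`):

```python
import torch

batch_size, seq_len, enc_dim, dec_dim = 2, 5, 8, 16   # made-up sizes
encoder_hidden_states = torch.randn(batch_size, seq_len, enc_dim)
attention_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
enc_to_dec_proj = torch.nn.Linear(enc_dim, dec_dim)   # stand-in for model.enc_to_dec_proj

projected = enc_to_dec_proj(encoder_hidden_states)
# broadcast the mask over the hidden dimension so padded positions contribute zeros
projected = projected * attention_mask[..., None]
print(projected.shape)  # torch.Size([2, 5, 16])
```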
...@@ -791,20 +903,34 @@ def main(): ...@@ -791,20 +903,34 @@ def main():
autocast_kwargs, autocast_kwargs,
): ):
eval_model = model if not training_args.torch_compile else model._orig_mod eval_model = model if not training_args.torch_compile else model._orig_mod
eval_model.eval()
if mixed_precision == "fp16": if mixed_precision == "fp16":
# fp16 doesn't work with T5-like models # fp16 doesn't work with T5-like models
with accelerator.autocast(autocast_handler=autocast_kwargs): with accelerator.autocast(autocast_handler=autocast_kwargs):
with torch.no_grad(): if training_args.parallel_mode.value != "distributed":
if training_args.parallel_mode.value != "distributed" or training_args.torch_compile: encoder_outputs = model.text_encoder(
encoder_outputs = eval_model.text_encoder( input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None)
input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None) )
) else:
else: encoder_outputs = model.module.text_encoder(
encoder_outputs = eval_model.module.text_encoder( input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None)
input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None) )
) # we optionally project last_hidden_state to avoid recomputing it every time
encoder_hidden_states = encoder_outputs.last_hidden_state
if (
config.text_encoder.hidden_size != config.decoder.hidden_size
and config.decoder.cross_attention_hidden_size is None
):
encoder_hidden_states = (
model.enc_to_dec_proj(encoder_hidden_states)
if training_args.parallel_mode.value != "distributed"
else model.module.enc_to_dec_proj(encoder_hidden_states)
)
if batch.get("attention_mask", None) is not None:
encoder_hidden_states = encoder_hidden_states * batch.get("attention_mask", None)[..., None]
encoder_outputs.last_hidden_state = encoder_hidden_states
batch["encoder_outputs"] = encoder_outputs batch["encoder_outputs"] = encoder_outputs
with torch.no_grad(): with torch.no_grad():
...@@ -814,18 +940,24 @@ def main(): ...@@ -814,18 +940,24 @@ def main():
metrics = {"loss": ce_loss} metrics = {"loss": ce_loss}
return metrics return metrics
def generate_step(batch): def generate_step(batch, accelerator):
batch.pop("decoder_attention_mask", None) batch.pop("decoder_attention_mask", None)
eval_model = accelerator.unwrap_model(model, keep_fp32_wrapper=mixed_precision != "fp16").eval() eval_model = accelerator.unwrap_model(model, keep_fp32_wrapper=True)
if training_args.torch_compile: if training_args.torch_compile:
# if the model is compiled, we use the original model because compile is not compatible with .generate
eval_model = model._orig_mod eval_model = model._orig_mod
# since we might have loaded the weights in fp32, we have to autocast to ensure the FA2 weights are in half-precision.
# with accelerator.autocast(autocast_handler=AutocastKwargs(enabled=(attn_implementation=="flash_attention_2"))):
output_audios = eval_model.generate(**batch, **gen_kwargs) output_audios = eval_model.generate(**batch, **gen_kwargs)
output_audios = accelerator.pad_across_processes(output_audios, dim=1, pad_index=0) output_audios = accelerator.pad_across_processes(output_audios, dim=1, pad_index=0)
return output_audios return output_audios
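On the gather side of `generate_step`, generations can have different lengths on each rank, so they are right-padded to a common length before any cross-process gather. A single-process toy (the padding is a no-op here), assuming an Accelerate setup:

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()
output_audios = torch.randn(2, 120)  # made-up stand-in for generated waveforms
# pads dim=1 to the longest length across processes so tensors can be gathered
output_audios = accelerator.pad_across_processes(output_audios, dim=1, pad_index=0)
gathered = accelerator.gather_for_metrics(output_audios)
print(gathered.shape)
```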
model.train()
for epoch in range(epochs_trained, num_epochs): for epoch in range(epochs_trained, num_epochs):
vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed) with accelerator.local_main_process_first():
vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
sampler = None sampler = None
if training_args.group_by_length: if training_args.group_by_length:
sampler = LengthGroupedSampler(train_batch_size, lengths=vectorized_datasets["train"]["target_length"]) sampler = LengthGroupedSampler(train_batch_size, lengths=vectorized_datasets["train"]["target_length"])
...@@ -843,8 +975,10 @@ def main(): ...@@ -843,8 +975,10 @@ def main():
if resume_step is not None: if resume_step is not None:
# Skip the first N batches in the dataloader when resuming from a checkpoint # Skip the first N batches in the dataloader when resuming from a checkpoint
logger.info(f" Skip first {resume_step} batches")
train_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step) train_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
resume_step = None resume_step = None
accelerator.wait_for_everyone()
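A runnable toy of the resume path above: `skip_first_batches` drops the batches that were already consumed before the checkpoint, and the barrier keeps ranks aligned before training continues (toy data, made-up `resume_step`).

```python
from accelerate import Accelerator
from torch.utils.data import DataLoader

accelerator = Accelerator()
train_dataloader = accelerator.prepare(DataLoader(list(range(10)), batch_size=2))

resume_step = 2  # made-up number of already-consumed batches
skipped = accelerator.skip_first_batches(train_dataloader, resume_step)
print([batch.tolist() for batch in skipped])  # the first two batches are skipped
accelerator.wait_for_everyone()
```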
for batch in train_dataloader: for batch in train_dataloader:
with accelerator.accumulate(model): with accelerator.accumulate(model):
...@@ -901,10 +1035,12 @@ def main(): ...@@ -901,10 +1035,12 @@ def main():
commit_message=f"Saving train state of step {cur_step}", commit_message=f"Saving train state of step {cur_step}",
run_as_future=True, run_as_future=True,
) )
accelerator.wait_for_everyone()
if training_args.do_eval and (cur_step % eval_steps == 0 or cur_step == total_train_steps): if training_args.do_eval and (cur_step % eval_steps == 0 or cur_step == total_train_steps):
train_time += time.time() - train_start train_time += time.time() - train_start
# ======================== Evaluating ============================== # ======================== Evaluating ==============================
model.eval()
eval_metrics = [] eval_metrics = []
eval_preds = [] eval_preds = []
eval_descriptions = [] eval_descriptions = []
...@@ -919,7 +1055,7 @@ def main(): ...@@ -919,7 +1055,7 @@ def main():
collate_fn=data_collator, collate_fn=data_collator,
batch_size=per_device_eval_batch_size, batch_size=per_device_eval_batch_size,
drop_last=False, drop_last=False,
num_workers=training_args.dataloader_pin_memory, num_workers=training_args.eval_dataloader_num_workers,
pin_memory=training_args.dataloader_pin_memory, pin_memory=training_args.dataloader_pin_memory,
) )
validation_dataloader = accelerator.prepare(validation_dataloader) validation_dataloader = accelerator.prepare(validation_dataloader)
...@@ -952,7 +1088,7 @@ def main(): ...@@ -952,7 +1088,7 @@ def main():
position=2, position=2,
disable=not accelerator.is_local_main_process, disable=not accelerator.is_local_main_process,
): ):
generated_audios = generate_step(batch) generated_audios = generate_step(batch, accelerator)
# Gather all predictions and targets # Gather all predictions and targets
generated_audios, input_ids, prompts = accelerator.pad_across_processes( generated_audios, input_ids, prompts = accelerator.pad_across_processes(
(generated_audios, batch["input_ids"], batch["prompt_input_ids"]), dim=1, pad_index=0 (generated_audios, batch["input_ids"], batch["prompt_input_ids"]), dim=1, pad_index=0
...@@ -967,35 +1103,51 @@ def main(): ...@@ -967,35 +1103,51 @@ def main():
eval_time = time.time() - eval_start eval_time = time.time() - eval_start
# normalize eval metrics # normalize eval metrics
eval_metrics = { eval_metrics = {
key: torch.mean(torch.cat([d[key].unsqueeze(0) for d in eval_metrics])) key: torch.mean(torch.cat([d[key] for d in eval_metrics])).to("cpu") for key in eval_metrics[0]
for key in eval_metrics[0]
} }
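A toy of the normalization above: each dict holds one batch's gathered metric tensors, and the comprehension averages them per key and moves the result to CPU (made-up numbers).

```python
import torch

eval_metrics = [{"loss": torch.tensor([1.0, 2.0])}, {"loss": torch.tensor([3.0])}]
eval_metrics = {
    key: torch.mean(torch.cat([d[key] for d in eval_metrics])).to("cpu") for key in eval_metrics[0]
}
print(eval_metrics)  # -> {'loss': tensor(2.)}
```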
# compute metrics # compute metrics
metrics_desc = "" metrics_desc = ""
if training_args.predict_with_generate: if training_args.predict_with_generate:
metric_values, pred_descriptions, pred_prompts, audios, transcriptions = compute_metrics( if accelerator.is_local_main_process:
eval_preds, eval_descriptions, eval_prompts, accelerator.device (
) metric_values,
eval_metrics.update(metric_values)
metrics_desc = " ".join([f"Eval {key}: {value} |" for key, value in metric_values.items()])
if "wandb" in training_args.report_to:
log_pred(
accelerator,
pred_descriptions, pred_descriptions,
pred_prompts, pred_prompts,
transcriptions,
audios, audios,
sampling_rate=sampling_rate, transcriptions,
step=cur_step, si_sdr_measures,
prefix="eval", ) = compute_metrics(
eval_preds,
eval_descriptions,
eval_prompts,
accelerator.device,
training_args.compute_clap_similarity_metric,
training_args.compute_noise_level_metric,
training_args.noise_level_to_compute_clean_wer,
) )
eval_metrics.update(metric_values)
metrics_desc = " ".join([f"Eval {key}: {value} |" for key, value in metric_values.items()])
if "wandb" in training_args.report_to:
log_pred(
accelerator,
pred_descriptions,
pred_prompts,
transcriptions,
audios,
si_sdr_measures,
sampling_rate=sampling_rate,
step=cur_step,
prefix="eval",
)
accelerator.wait_for_everyone()
# Print metrics and update progress bar # Print metrics and update progress bar
steps_trained_progress_bar.write( if accelerator.is_local_main_process:
f"Eval results for step ({cur_step} / {total_train_steps} | Eval Loss: {eval_metrics['loss']} |" steps_trained_progress_bar.write(
f" {metrics_desc})" f"Eval results for step ({cur_step} / {total_train_steps} | Eval Loss: {eval_metrics['loss']} |"
) f" {metrics_desc})"
)
log_metric( log_metric(
accelerator, accelerator,
...@@ -1007,11 +1159,14 @@ def main(): ...@@ -1007,11 +1159,14 @@ def main():
) )
# release the eval batch and metrics from memory # release the eval batch and metrics from memory
eval_metrics = [] eval_metrics, eval_preds, eval_descriptions, eval_prompts, batch, eval_metric = release_memory(
eval_preds = [] eval_metrics, eval_preds, eval_descriptions, eval_prompts, batch, eval_metric
eval_descriptions = [] )
eval_prompts = [] if training_args.predict_with_generate:
batch = release_memory(batch) generated_audios, input_ids, prompts = release_memory(generated_audios, input_ids, prompts)
# train mode
model.train()
# flush the train metrics # flush the train metrics
train_start = time.time() train_start = time.time()
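`release_memory` is used above to drop the eval tensors; in Accelerate it lives in `accelerate.utils` (the script's own import is outside this hunk, so treat the import below as an assumption). It frees the passed references, runs garbage collection, empties the device cache, and hands back `None` placeholders to reassign:

```python
import torch
from accelerate.utils import release_memory  # assumed import; not shown in this diff

a = torch.randn(1024, 1024)
b = torch.randn(1024, 1024)
a, b = release_memory(a, b)  # references come back as None once the memory is released
print(a, b)  # -> None None
```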
...@@ -1028,5 +1183,4 @@ def main(): ...@@ -1028,5 +1183,4 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
set_start_method("spawn")
main() main()
...@@ -7,6 +7,7 @@ from typing import Dict, List ...@@ -7,6 +7,7 @@ from typing import Dict, List
import torch import torch
from wandb import Audio from wandb import Audio
from datasets import load_from_disk, concatenate_datasets
def list_field(default=None, metadata=None): def list_field(default=None, metadata=None):
...@@ -14,6 +15,8 @@ def list_field(default=None, metadata=None): ...@@ -14,6 +15,8 @@ def list_field(default=None, metadata=None):
_RE_CHECKPOINT = re.compile(r"^checkpoint-(\d+)-epoch-(\d+)$") _RE_CHECKPOINT = re.compile(r"^checkpoint-(\d+)-epoch-(\d+)$")
CHECKPOINT_CODEC_PREFIX = "checkpoint"
_RE_CODEC_CHECKPOINT = re.compile(r"^checkpoint-(\d+)$")
def get_last_checkpoint(folder): def get_last_checkpoint(folder):
...@@ -60,6 +63,59 @@ def rotate_checkpoints(save_total_limit=None, output_dir=None, checkpoint_prefix ...@@ -60,6 +63,59 @@ def rotate_checkpoints(save_total_limit=None, output_dir=None, checkpoint_prefix
shutil.rmtree(checkpoint, ignore_errors=True) shutil.rmtree(checkpoint, ignore_errors=True)
def save_codec_checkpoint(output_dir, dataset, step):
checkpoint_path = f"{CHECKPOINT_CODEC_PREFIX}-{step}"
output_path = os.path.join(output_dir, checkpoint_path)
dataset.save_to_disk(output_path)
def load_codec_checkpoint(checkpoint_path):
dataset = load_from_disk(checkpoint_path)
return dataset
def sorted_codec_checkpoints(output_dir=None) -> List[str]:
"""Helper function to sort saved checkpoints from oldest to newest."""
ordering_and_checkpoint_path = []
glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{CHECKPOINT_CODEC_PREFIX}-*")]
for path in glob_checkpoints:
regex_match = re.match(f".*{CHECKPOINT_CODEC_PREFIX}-([0-9]+)", path)
if regex_match is not None and regex_match.groups() is not None:
ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
checkpoints_sorted = sorted(ordering_and_checkpoint_path)
checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
return checkpoints_sorted
def load_all_codec_checkpoints(output_dir=None) -> List[str]:
"""Helper function to load and concat all checkpoints."""
checkpoints_sorted = sorted_codec_checkpoints(output_dir=output_dir)
datasets = [load_from_disk(checkpoint) for checkpoint in checkpoints_sorted]
datasets = concatenate_datasets(datasets, axis=0)
return datasets
def get_last_codec_checkpoint_step(folder) -> int:
if not os.path.exists(folder) or not os.path.isdir(folder):
os.makedirs(folder, exist_ok=True)
return 0
content = os.listdir(folder)
checkpoints = [path for path in content if _RE_CODEC_CHECKPOINT.search(path) is not None]
if len(checkpoints) == 0:
return 0
last_checkpoint = os.path.join(
folder, max(checkpoints, key=lambda x: int(_RE_CODEC_CHECKPOINT.search(x).groups()[0]))
)
# Extract the number of steps from the checkpoint name pattern
pattern = r"checkpoint-(\d+)"
match = re.search(pattern, last_checkpoint)
cur_step = int(match.group(1))
return cur_step
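A usage sketch for the codec-checkpoint helpers defined above, round-tripping two throwaway shards through a temporary directory (assumes the helpers are in scope; datasets and step numbers are made up):

```python
import tempfile
from datasets import Dataset

tmp_dir = tempfile.mkdtemp()
save_codec_checkpoint(tmp_dir, Dataset.from_dict({"labels": [[1, 2], [3, 4]]}), step=500)
save_codec_checkpoint(tmp_dir, Dataset.from_dict({"labels": [[5, 6]]}), step=1000)

print(get_last_codec_checkpoint_step(tmp_dir))  # -> 1000
merged = load_all_codec_checkpoints(tmp_dir)    # shards concatenated in step order
print(len(merged))                              # -> 3
```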
def log_metric( def log_metric(
accelerator, accelerator,
metrics: Dict, metrics: Dict,
...@@ -86,6 +142,7 @@ def log_pred( ...@@ -86,6 +142,7 @@ def log_pred(
pred_prompts: List[str], pred_prompts: List[str],
transcriptions: List[str], transcriptions: List[str],
audios: List[torch.Tensor], audios: List[torch.Tensor],
si_sdr_measures: List[float],
sampling_rate: int, sampling_rate: int,
step: int, step: int,
prefix: str = "eval", prefix: str = "eval",
...@@ -98,16 +155,33 @@ def log_pred( ...@@ -98,16 +155,33 @@ def log_pred(
cur_step_pretty = f"{int(step // 1000)}k" if step > 1000 else step cur_step_pretty = f"{int(step // 1000)}k" if step > 1000 else step
prefix_pretty = prefix.replace("/", "-") prefix_pretty = prefix.replace("/", "-")
# convert str data to a wandb compatible format if si_sdr_measures is None:
str_data = [[pred_descriptions[i], pred_prompts[i], transcriptions[i]] for i in range(len(pred_descriptions))] # convert str data to a wandb compatible format
# log as a table with the appropriate headers str_data = [
wandb_tracker.log_table( [pred_descriptions[i], pred_prompts[i], transcriptions[i]] for i in range(len(pred_descriptions))
table_name=f"predictions/{prefix_pretty}-step-{cur_step_pretty}", ]
columns=["Target descriptions", "Target prompts", "Predicted transcriptions"], # log as a table with the appropriate headers
data=str_data[:num_lines], wandb_tracker.log_table(
step=step, table_name=f"predictions/{prefix_pretty}-step-{cur_step_pretty}",
commit=False, columns=["Target descriptions", "Target prompts", "Predicted transcriptions"],
) data=str_data[:num_lines],
step=step,
commit=False,
)
else:
# convert str data to a wandb compatible format
str_data = [
[pred_descriptions[i], pred_prompts[i], transcriptions[i], si_sdr_measures[i]]
for i in range(len(pred_descriptions))
]
# log as a table with the appropriate headers
wandb_tracker.log_table(
table_name=f"predictions/{prefix_pretty}-step-{cur_step_pretty}",
columns=["Target descriptions", "Target prompts", "Predicted transcriptions", "Noise estimation"],
data=str_data[:num_lines],
step=step,
commit=False,
)
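The two branches above only differ by the extra per-sample noise column; a small sketch of the row layout with made-up entries:

```python
pred_descriptions = ["a calm female voice with very clear audio"]
pred_prompts = ["The weather is lovely this afternoon."]
transcriptions = ["the weather is lovely this afternoon"]
si_sdr_measures = [21.3]  # or None when the noise-level metric is disabled

columns = ["Target descriptions", "Target prompts", "Predicted transcriptions"]
str_data = [list(row) for row in zip(pred_descriptions, pred_prompts, transcriptions)]
if si_sdr_measures is not None:
    columns.append("Noise estimation")
    str_data = [row + [measure] for row, measure in zip(str_data, si_sdr_measures)]
print(columns, str_data)
```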
# wandb can only load 100 audios per step # wandb can only load 100 audios per step
wandb_tracker.log( wandb_tracker.log(
......