Unverified commit 11b209e1 authored by Yoach Lacombe, committed by GitHub

Architecture improvements (#65)



* add RoPe

* don't include padding in rope

* possibly use cross-attn for prompt

* fix rope

* fix cross-attn

* fix self-attn

* fix dummy model

* clean-up rope

* first gqa implementation

* fix wer eval

* feat: add flash attention and sdpa

* chore: add README for flash attention

* chore: add benchmark script

* chore: add benchmark attention approach

* multi node and fix wer and fix compile

* Update modeling_parler_tts.py

* fix FA2, SDPA and add cross-attn MHA and attention type forcing

* better default for the number of cross-attention key/value heads + add training arguments for attn implementation

* fix audio padding when torch compile or pad_to_max_length=True

* correct multi node

* make rope faster

* fix encoder sdpa

* fix training with cross-attention + with FA2

* use fp32 as default model dtype + fix generation when using FA2 with autocast

* remove redundant passes in generate + clean and fix attentions

* fix edge case in WER evaluation with long-form generation

* better multi-node mapping and saving / add eval dataloader num workers

* remove old benchmarks

* faster audio encoding + checkpointing + fix generation step

* better eval + add right padding + fix eval loss compute

* correct README

* correct config docstrings

* remove comment

* make style

---------
Co-authored-by: sanchit-gandhi <sanchit@huggingface.co>
Co-authored-by: sang-nguyen-ts <sang.nguyen@trustingsocial.com>
Co-authored-by: Yoach Lacombe <yoach@huggingface.co>
parent 8b8c576e
@@ -53,7 +53,8 @@ if torch.xpu.is_available():
     device = "xpu"
 torch_dtype = torch.float16 if device != "cpu" else torch.float32
-model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler_tts_mini_v0.1").to(device, dtype=torch_dtype)
+model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler_tts_mini_v0.1", torch_dtype=torch_dtype).to(device)
 tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler_tts_mini_v0.1")
 prompt = "Hey, how are you doing today?"
@@ -60,8 +60,8 @@ if __name__ == "__main__":
     model.generation_config.max_length = int(30 * model.audio_encoder.config.frame_rate)
     model.generation_config.do_sample = True  # True
     model.generation_config.guidance_scale = 1  # 3.0
     model.config.pad_token_id = encodec_vocab_size
-    model.config.decoder_start_token_id = encodec_vocab_size+1
+    model.config.decoder_start_token_id = encodec_vocab_size + 1
     model.save_pretrained(os.path.join(args.save_directory, "tiny-model"))
@@ -58,4 +58,7 @@ if __name__ == "__main__":
     model.generation_config.do_sample = True  # True
     model.generation_config.guidance_scale = 1  # 3.0
+    model.config.pad_token_id = encodec_vocab_size
+    model.config.decoder_start_token_id = encodec_vocab_size + 1
     model.save_pretrained(os.path.join(args.save_directory, "tiny-model"))
@@ -60,8 +60,8 @@ if __name__ == "__main__":
     model.generation_config.max_length = int(30 * model.audio_encoder.config.frame_rate)
     model.generation_config.do_sample = True  # True
     model.generation_config.guidance_scale = 1  # 3.0
     model.config.pad_token_id = encodec_vocab_size
-    model.config.decoder_start_token_id = encodec_vocab_size+1
+    model.config.decoder_start_token_id = encodec_vocab_size + 1
     model.save_pretrained(os.path.join(args.save_directory, "parler-tts-untrained-600M/"))
@@ -2,6 +2,7 @@ import dac
 from parler_tts import DACConfig, DACModel
 from transformers import AutoConfig, AutoModel
 from transformers import EncodecFeatureExtractor

 AutoConfig.register("dac", DACConfig)
 AutoModel.register(DACConfig, DACModel)
@@ -47,6 +47,17 @@ class ParlerTTSDecoderConfig(PretrainedConfig):
             Number of decoder layers.
         num_attention_heads (`int`, *optional*, defaults to 16):
             Number of attention heads for each attention layer in the Transformer block.
+        num_key_value_heads (`int`, *optional*):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if
+            `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by mean-pooling all the original heads within that group. For more details check out [this
+            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
+            `num_attention_heads`.
+        num_cross_attention_key_value_heads (`int`, *optional*):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention in the cross-attention layers.
+            If it is not specified, will default to `num_key_value_heads`.
         ffn_dim (`int`, *optional*, defaults to 4096):
             Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer block.
         activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
@@ -74,6 +85,12 @@ class ParlerTTSDecoderConfig(PretrainedConfig):
             The number of parallel codebooks forwarded to the model.
         tie_word_embeddings(`bool`, *optional*, defaults to `False`):
             Whether input and output word embeddings should be tied.
+        rope_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether to use RoPE or absolute positional embeddings.
+        rope_theta (`float`, *optional*, defaults to 10000.0):
+            The base period of the RoPE embeddings.
+        cross_attention_implementation_strategy (`str`, *optional*):
+            If not specified, the cross-attention implementation will be the same as `_attn_implementation`. If `always_eager`, it will always be the eager implementation. If `always_sdpa`, it will always be the sdpa implementation.
     """

     model_type = "parler_tts_decoder"
@@ -86,6 +103,8 @@ class ParlerTTSDecoderConfig(PretrainedConfig):
         num_hidden_layers=24,
         ffn_dim=4096,
         num_attention_heads=16,
+        num_key_value_heads=None,
+        num_cross_attention_key_value_heads=None,
         layerdrop=0.0,
         use_cache=True,
         activation_function="gelu",
@@ -100,6 +119,9 @@ class ParlerTTSDecoderConfig(PretrainedConfig):
         bos_token_id=2049,
         eos_token_id=2048,
         tie_word_embeddings=False,
+        rope_embeddings=False,
+        rope_theta=10_000.0,
+        cross_attention_implementation_strategy=None,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -108,6 +130,12 @@ class ParlerTTSDecoderConfig(PretrainedConfig):
         self.ffn_dim = ffn_dim
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+        self.num_key_value_heads = num_key_value_heads
+        if num_cross_attention_key_value_heads is None:
+            num_cross_attention_key_value_heads = num_key_value_heads
+        self.num_cross_attention_key_value_heads = num_cross_attention_key_value_heads
         self.dropout = dropout
         self.attention_dropout = attention_dropout
         self.activation_dropout = activation_dropout
@@ -117,6 +145,9 @@ class ParlerTTSDecoderConfig(PretrainedConfig):
         self.use_cache = use_cache
         self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
         self.num_codebooks = num_codebooks
+        self.rope_embeddings = rope_embeddings
+        self.rope_theta = rope_theta
+        self.cross_attention_implementation_strategy = cross_attention_implementation_strategy
         super().__init__(
             pad_token_id=pad_token_id,
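The `rope_embeddings` / `rope_theta` options enable rotary position embeddings in the decoder, and the commit notes that padding positions are excluded from RoPE. For orientation, a minimal sketch of the standard RoPE computation, assuming the usual cos/sin rotation formulation; this is illustrative, not the repo's exact implementation:

```python
import torch

def rotary_embedding(positions: torch.Tensor, head_dim: int, rope_theta: float = 10_000.0):
    """Return cos/sin tables for rotary position embeddings."""
    inv_freq = 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = positions[:, None].float() * inv_freq[None, :]  # (seq_len, head_dim // 2)
    angles = torch.cat((angles, angles), dim=-1)             # (seq_len, head_dim)
    return angles.cos(), angles.sin()

def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Rotate query/key states; x has shape (..., seq_len, head_dim)."""
    half = x.shape[-1] // 2
    x_rotated = torch.cat((-x[..., half:], x[..., :half]), dim=-1)
    return x * cos + x_rotated * sin

# positions can be computed while skipping padding tokens, which is how
# padding can be excluded from the rotary embedding
q = torch.randn(1, 16, 10, 64)                     # (batch, heads, seq, head_dim)
cos, sin = rotary_embedding(torch.arange(10), 64)
q_rope = apply_rope(q, cos, sin)
```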
@@ -140,6 +171,8 @@ class ParlerTTSConfig(PretrainedConfig):
         vocab_size (`int`, *optional*, defaults to 1024):
             Vocabulary size of the prompt token ids. Defines the number of different tokens that can be
             represented by the `prompt_inputs_ids`.
+        prompt_cross_attention (`bool`, *optional*, defaults to `False`):
+            Whether to use cross-attention conditioning for the prompt (as well as the description).
         kwargs (*optional*):
             Dictionary of keyword arguments. Notably:
@@ -190,7 +223,7 @@ class ParlerTTSConfig(PretrainedConfig):
     model_type = "parler_tts"
     is_composition = True

-    def __init__(self, vocab_size=1024, **kwargs):
+    def __init__(self, vocab_size=1024, prompt_cross_attention=False, **kwargs):
         super().__init__(**kwargs)
         if "text_encoder" not in kwargs or "audio_encoder" not in kwargs or "decoder" not in kwargs:
             raise ValueError("Config has to be initialized with text_encoder, audio_encoder and decoder config")
@@ -204,6 +237,7 @@ class ParlerTTSConfig(PretrainedConfig):
         decoder_config = kwargs.pop("decoder")

         self.vocab_size = vocab_size
+        self.prompt_cross_attention = prompt_cross_attention
         self.text_encoder = AutoConfig.for_model(text_encoder_model_type, **text_encoder_config)
         self.audio_encoder = AutoConfig.for_model(audio_encoder_model_type, **audio_encoder_config)
         self.decoder = ParlerTTSDecoderConfig(**decoder_config)
@@ -236,3 +270,21 @@ class ParlerTTSConfig(PretrainedConfig):
     # This is a property because you might want to change the codec model on the fly
     def sampling_rate(self):
         return self.audio_encoder.sampling_rate

+    # Copied from musicgen
+    @property
+    def _attn_implementation(self):
+        # This property is made private for now (as it cannot be changed and a PreTrainedModel.use_attn_implementation method needs to be implemented.)
+        if hasattr(self, "_attn_implementation_internal"):
+            if self._attn_implementation_internal is None:
+                # `config.attn_implementation` should never be None, for backward compatibility.
+                return "eager"
+            else:
+                return self._attn_implementation_internal
+        else:
+            return "eager"
+
+    @_attn_implementation.setter
+    def _attn_implementation(self, value):
+        self._attn_implementation_internal = value
+        self.decoder._attn_implementation = value
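The setter above propagates the attention implementation to the decoder config, while `cross_attention_implementation_strategy` (documented earlier in the decoder config) can force a different backend for cross-attention only. A hedged sketch of how such a strategy might be resolved; the helper name is illustrative and not necessarily how the modeling code does it:

```python
from typing import Optional

def resolve_cross_attention_implementation(attn_implementation: str, strategy: Optional[str]) -> str:
    """Pick the attention backend for cross-attention layers.

    `strategy` mirrors `cross_attention_implementation_strategy`: None keeps the
    model-wide `_attn_implementation`; "always_eager" / "always_sdpa" force one backend.
    """
    if strategy == "always_eager":
        return "eager"
    if strategy == "always_sdpa":
        return "sdpa"
    return attn_implementation

# e.g. Flash Attention 2 for self-attention combined with SDPA cross-attention
print(resolve_cross_attention_implementation("flash_attention_2", "always_sdpa"))  # sdpa
```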
@@ -78,6 +78,22 @@ class ModelArguments:
             "help": "Used to compute audio similarity during evaluation. Path to pretrained model or model identifier from huggingface.co/models"
         },
     )
+    attn_implementation: str = field(
+        default="eager",
+        metadata={"help": "Attention implementation used. One of `eager`, `sdpa`, `flash_attention_2`."},
+    )
+    cross_attention_implementation_strategy: str = field(
+        default=None,
+        metadata={
+            "help": "If not specified, the cross-attention implementation will be the same as `_attn_implementation`. If `always_eager`, it will always be the eager implementation. If `always_sdpa`, it will always be the sdpa implementation."
+        },
+    )
+    prompt_padding_side: Optional[str] = field(
+        default="left",
+        metadata={
+            "help": "Prompt tokenizer padding side. Defaults to `left`. If the prompt is prepended to the codebook hidden states, it should be padded on the left."
+        },
+    )


 @dataclass
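The new `attn_implementation` model argument maps onto the standard `from_pretrained` keyword. A short usage sketch, assuming the checkpoint and installed backends support the chosen implementation (swap in `"flash_attention_2"` when flash-attn is available):

```python
import torch
from parler_tts import ParlerTTSForConditionalGeneration

# SDPA shown here; "eager" and "flash_attention_2" are the other documented options
model = ParlerTTSForConditionalGeneration.from_pretrained(
    "parler-tts/parler_tts_mini_v0.1",
    attn_implementation="sdpa",
    torch_dtype=torch.float16,
).to("cuda:0")
```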
@@ -290,6 +306,10 @@ class DataTrainingArguments:
         },
     )
     temporary_save_to_disk: str = field(default=None, metadata={"help": "Temporarily save audio labels here."})
+    save_codec_steps: Optional[int] = field(
+        default=500,
+        metadata={"help": "Temporarily save the audio labels every `save_codec_steps`."},
+    )
     pad_to_multiple_of: Optional[int] = field(
         default=2,
         metadata={"help": ("Pad to multiple of for tokenizers.")},
@@ -311,3 +331,32 @@ class ParlerTTSTrainingArguments(Seq2SeqTrainingArguments):
         default=8,
         metadata={"help": ("Specify the batch size of the audio encoding pre-processing steps.")},
     )
+    eval_dataloader_num_workers: Optional[int] = field(
+        default=0,
+        metadata={
+            "help": (
+                "Number of subprocesses to use for evaluation data loading (PyTorch only). 0 means that the data will be loaded in the main process."
+            )
+        },
+    )
+    compute_clap_similarity_metric: bool = field(
+        default=True,
+        metadata={
+            "help": (
+                "Whether or not to compute the CLAP similarity metric between the description and the generation during evaluation."
+            )
+        },
+    )
+    compute_noise_level_metric: bool = field(
+        default=True,
+        metadata={"help": ("Whether or not to compute the SQUIM SI-SDR measure of the generations.")},
+    )
+    noise_level_to_compute_clean_wer: float = field(
+        default=25,
+        metadata={
+            "help": (
+                "If `compute_noise_level_metric=True`, compute a 'clean' WER on samples whose estimated SI-SDR is higher than `noise_level_to_compute_clean_wer`. "
+                "This is a proxy for WER on clean audio, provided that the model learns to generate clean audio."
+            )
+        },
+    )
@@ -30,6 +30,8 @@ class DataCollatorEncodecWithPadding:
         # different padding methods
         audios = [feature[self.audio_column_name]["array"] for feature in features]
         len_audio = [len(audio) for audio in audios]
+        if self.max_length is not None:
+            audios = [audio[: min(l, self.max_length)] for audio, l in zip(audios, len_audio)]
         # since resampling has already been performed in the 'load_multiple_datasets' function,
         # a fixed sampling_rate(44100hz) is passed to the feature_extractor.
@@ -81,7 +83,9 @@ class DataCollatorParlerTTSWithPadding:
         # (bsz, seq_len, num_codebooks)
         labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
         if self.audio_max_length is not None and self.padding == "max_length":
-            labels = torch.nn.functional.pad(labels, pad=(0, 0, 0, max(self.audio_max_length - labels.shape[1], 0)))
+            labels = torch.nn.functional.pad(
+                labels, pad=(0, 0, 0, max(self.audio_max_length - labels.shape[1], 0)), value=-100
+            )

         input_ids = [{"input_ids": feature["input_ids"]} for feature in features]
@@ -95,11 +99,6 @@ class DataCollatorParlerTTSWithPadding:
         batch = {"labels": labels, **input_ids}

-        if self.audio_max_length is not None and self.padding == "max_length":
-            # if we do torch.compile, we need to also specify the attention_mask
-            decoder_attention_mask = torch.ones(labels.shape[:2], dtype=input_ids["attention_mask"].dtype)
-            batch["decoder_attention_mask"] = decoder_attention_mask
-
         prompt_input_ids = [{"input_ids": feature["prompt_input_ids"]} for feature in features]
         prompt_input_ids = self.prompt_tokenizer.pad(
             prompt_input_ids,
@@ -206,7 +205,7 @@ def load_multiple_datasets(
     all_datasets = []
     # iterate over the datasets we want to interleave
     for dataset_dict in tqdm(dataset_names_dict, desc="Combining datasets..."):
-        with accelerator.main_process_first():
+        with accelerator.local_main_process_first():
             dataset = load_dataset(
                 dataset_dict["name"],
                 dataset_dict["config"],
@@ -242,7 +241,7 @@ def load_multiple_datasets(
             # metadata_dataset = metadata_dataset.map(concat_ids, input_columns=["book_id", "speaker_id", "begin_time"], num_proc=24)
             # metadata_dataset = metadata_dataset.rename_column(id_column_name, f"metadata_{id_column_name}")

-            if dataset_dict["name"] != "parler-tts/mls_eng_10k":
+            if dataset_dict["name"] not in {"parler-tts/mls_eng_10k", "parler-tts/mls_eng"}:
                 if id_column_name is not None and id_column_name not in dataset.column_names:
                     raise ValueError(
                         f"id_column_name={id_column_name} but has not been found in the dataset columns"
@@ -272,7 +271,10 @@ def load_multiple_datasets(
                 dataset = concatenate_datasets([dataset, metadata_dataset], axis=1)

-            if id_column_name is not None and dataset_dict["name"] != "parler-tts/mls_eng_10k":
+            if id_column_name is not None and dataset_dict["name"] not in {
+                "parler-tts/mls_eng_10k",
+                "parler-tts/mls_eng",
+            }:
                 if (
                     len(
                         dataset.filter(
@@ -304,7 +306,7 @@ def load_multiple_datasets(
             seed=seed,
         )
     else:
-        with accelerator.main_process_first():
+        with accelerator.local_main_process_first():
             interleaved_dataset = concatenate_datasets(all_datasets)
     return interleaved_dataset
 import torch
+from torchaudio.pipelines import SQUIM_OBJECTIVE
+import torchaudio
 import evaluate
-from transformers import AutoModel, AutoProcessor, pipeline, WhisperForConditionalGeneration, WhisperTokenizer, WhisperTokenizerFast
+from transformers import (
+    AutoModel,
+    AutoProcessor,
+    pipeline,
+    WhisperForConditionalGeneration,
+    WhisperTokenizer,
+    WhisperTokenizerFast,
+)
+from accelerate.utils.memory import release_memory
+import numpy as np


-def clap_similarity(clap_model_name_or_path, texts, audios, device):
+def clap_similarity(clap_model_name_or_path, texts, audios, device, input_sampling_rate=44100):
     clap = AutoModel.from_pretrained(clap_model_name_or_path)
     clap_processor = AutoProcessor.from_pretrained(clap_model_name_or_path)
-    clap_inputs = clap_processor(text=texts, audios=audios, padding=True, return_tensors="pt").to(device)
+    output_sampling_rate = clap_processor.feature_extractor.sampling_rate
+    if input_sampling_rate != output_sampling_rate:
+        audios = [
+            torchaudio.functional.resample(torch.from_numpy(audio), input_sampling_rate, output_sampling_rate).numpy()
+            for audio in audios
+        ]
+    clap_inputs = clap_processor(
+        text=texts, audios=audios, padding=True, return_tensors="pt", sampling_rate=output_sampling_rate
+    ).to(device)
     clap.to(device)
     with torch.no_grad():
         text_features = clap.get_text_features(
@@ -14,16 +34,52 @@ def clap_similarity(clap_model_name_or_path, texts, audios, device):
         )
         audio_features = clap.get_audio_features(clap_inputs["input_features"])

-        cosine_sim = torch.nn.functional.cosine_similarity(audio_features, text_features, dim=1, eps=1e-8)
+        cosine_sim = torch.nn.functional.cosine_similarity(audio_features, text_features, dim=1, eps=1e-8).mean()
+        cosine_sim = cosine_sim.to("cpu")

     clap.to("cpu")
-    clap_inputs.to("cpu")
-    return cosine_sim.mean().to("cpu")
+    clap, clap_inputs, audio_features, text_features = release_memory(clap, clap_inputs, audio_features, text_features)
+    return cosine_sim


+def si_sdr(audios, device, input_sampling_rate=44100):
+    max_audio_length = 15 * SQUIM_OBJECTIVE.sample_rate
+    model = SQUIM_OBJECTIVE.get_model().to(device)
+
+    output_sampling_rate = SQUIM_OBJECTIVE.sample_rate
+    if input_sampling_rate != output_sampling_rate:
+        audios = [
+            torchaudio.functional.resample(
+                torch.tensor(audio)[None, :].to(device).float(), input_sampling_rate, output_sampling_rate
+            )
+            for audio in audios
+        ]
+
+    def apply_squim(waveform):
+        with torch.no_grad():
+            waveform = waveform[:, : min(max_audio_length, waveform.shape[1])]
+            _, _, sdr_sample = model(waveform)
+        sdr_sample = sdr_sample.cpu()[0]
+        return sdr_sample
+
+    si_sdrs = [apply_squim(audio) for audio in audios]
+    audios, model = release_memory(audios, model)
+    return si_sdrs


-def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_size, sampling_rate):
+def wer(
+    asr_model_name_or_path,
+    prompts,
+    audios,
+    device,
+    per_device_eval_batch_size,
+    sampling_rate,
+    noise_level_to_compute_clean_wer,
+    si_sdr_measures,
+):
metric = evaluate.load("wer") metric = evaluate.load("wer")
asr_pipeline = pipeline(model=asr_model_name_or_path, device=device) asr_pipeline = pipeline(model=asr_model_name_or_path, device=device, chunk_length_s=25.0)
return_language = None return_language = None
if isinstance(asr_pipeline.model, WhisperForConditionalGeneration): if isinstance(asr_pipeline.model, WhisperForConditionalGeneration):
...@@ -47,7 +103,11 @@ def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_s ...@@ -47,7 +103,11 @@ def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_s
normalized_references = [] normalized_references = []
for pred, ref in zip(transcriptions, prompts): for pred, ref in zip(transcriptions, prompts):
normalizer = english_normalizer if return_language and pred["chunks"][0]["language"] == "english" else basic_normalizer normalizer = (
english_normalizer
if isinstance(pred.get("chunks", None), list) and pred["chunks"][0].get("language", None) == "english"
else basic_normalizer
)
norm_ref = normalizer(ref) norm_ref = normalizer(ref)
if len(norm_ref) > 0: if len(norm_ref) > 0:
norm_pred = normalizer(pred["text"]) norm_pred = normalizer(pred["text"])
...@@ -56,4 +116,21 @@ def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_s ...@@ -56,4 +116,21 @@ def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_s
word_error = 100 * metric.compute(predictions=normalized_predictions, references=normalized_references) word_error = 100 * metric.compute(predictions=normalized_predictions, references=normalized_references)
return word_error, [t["text"] for t in transcriptions] clean_word_error = None
noisy_word_error = None
percent_clean_samples = 0
if noise_level_to_compute_clean_wer and si_sdr_measures:
si_sdr_measures = np.array(si_sdr_measures)
mask = si_sdr_measures >= noise_level_to_compute_clean_wer
if mask.any():
clean_word_error = 100 * metric.compute(
predictions=np.array(normalized_predictions)[mask], references=np.array(normalized_references)[mask]
)
noisy_word_error = 100 * metric.compute(
predictions=np.array(normalized_predictions)[~mask], references=np.array(normalized_references)[~mask]
)
percent_clean_samples = mask.sum() / len(mask)
asr_pipeline.model.to("cpu")
asr_pipeline = release_memory(asr_pipeline)
return word_error, [t["text"] for t in transcriptions], clean_word_error, noisy_word_error, percent_clean_samples
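For context, a hedged sketch of how the extended `si_sdr` and `wer` helpers might be called from an evaluation loop. The dummy audio, prompts, and the ASR checkpoint name are placeholders chosen for illustration, and the helpers are assumed to be importable from the evaluation module above; the training script's actual call may differ:

```python
import numpy as np

# dummy 2-second clips at 44.1 kHz stand in for generated audio (illustrative only)
generated_audios = [np.random.randn(88200).astype(np.float32) for _ in range(2)]
prompts = ["Hey, how are you doing today?", "The quick brown fox jumps over the lazy dog."]

si_sdr_measures = si_sdr(generated_audios, device="cpu", input_sampling_rate=44100)
word_error, transcriptions, clean_wer, noisy_wer, percent_clean = wer(
    "distil-whisper/distil-large-v2",  # assumed ASR checkpoint; any Whisper-style model works
    prompts,
    generated_audios,
    device="cpu",
    per_device_eval_batch_size=2,
    sampling_rate=44100,
    noise_level_to_compute_clean_wer=25,
    si_sdr_measures=si_sdr_measures,
)
```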
@@ -7,6 +7,7 @@ from typing import Dict, List

 import torch
 from wandb import Audio
+from datasets import load_from_disk, concatenate_datasets


 def list_field(default=None, metadata=None):
@@ -14,6 +15,8 @@ def list_field(default=None, metadata=None):

 _RE_CHECKPOINT = re.compile(r"^checkpoint-(\d+)-epoch-(\d+)$")
+CHECKPOINT_CODEC_PREFIX = "checkpoint"
+_RE_CODEC_CHECKPOINT = re.compile(r"^checkpoint-(\d+)$")


 def get_last_checkpoint(folder):
@@ -60,6 +63,59 @@ def rotate_checkpoints(save_total_limit=None, output_dir=None, checkpoint_prefix
             shutil.rmtree(checkpoint, ignore_errors=True)


+def save_codec_checkpoint(output_dir, dataset, step):
+    checkpoint_path = f"{CHECKPOINT_CODEC_PREFIX}-{step}"
+    output_path = os.path.join(output_dir, checkpoint_path)
+    dataset.save_to_disk(output_path)
+
+
+def load_codec_checkpoint(checkpoint_path):
+    dataset = load_from_disk(checkpoint_path)
+    return dataset
+
+
+def sorted_codec_checkpoints(output_dir=None) -> List[str]:
+    """Helper function to sort saved checkpoints from oldest to newest."""
+    ordering_and_checkpoint_path = []
+
+    glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{CHECKPOINT_CODEC_PREFIX}-*")]
+
+    for path in glob_checkpoints:
+        regex_match = re.match(f".*{CHECKPOINT_CODEC_PREFIX}-([0-9]+)", path)
+        if regex_match is not None and regex_match.groups() is not None:
+            ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
+
+    checkpoints_sorted = sorted(ordering_and_checkpoint_path)
+    checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
+    return checkpoints_sorted
+
+
+def load_all_codec_checkpoints(output_dir=None) -> List[str]:
+    """Helper function to load and concat all checkpoints."""
+    checkpoints_sorted = sorted_codec_checkpoints(output_dir=output_dir)
+    datasets = [load_from_disk(checkpoint) for checkpoint in checkpoints_sorted]
+    datasets = concatenate_datasets(datasets, axis=0)
+    return datasets
+
+
+def get_last_codec_checkpoint_step(folder) -> int:
+    if not os.path.exists(folder) or not os.path.isdir(folder):
+        os.makedirs(folder, exist_ok=True)
+        return 0
+    content = os.listdir(folder)
+    checkpoints = [path for path in content if _RE_CODEC_CHECKPOINT.search(path) is not None]
+    if len(checkpoints) == 0:
+        return 0
+    last_checkpoint = os.path.join(
+        folder, max(checkpoints, key=lambda x: int(_RE_CODEC_CHECKPOINT.search(x).groups()[0]))
+    )
+    # Find num steps saved state string pattern
+    pattern = r"checkpoint-(\d+)"
+    match = re.search(pattern, last_checkpoint)
+    cur_step = int(match.group(1))
+    return cur_step
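These helpers let the audio-encoding pass checkpoint its work every `save_codec_steps` and resume after an interruption. A hedged sketch of the intended flow, assuming the helpers above are in scope; the toy rows, scratch directory, and loop are illustrative and not the training script's actual code:

```python
from datasets import Dataset

tmp_dir = "/tmp/parler_codec_cache"   # assumed scratch directory (temporary_save_to_disk)
save_codec_steps = 2                  # tiny value just for the demo

# pretend these rows are audio labels produced by the codec
rows = [{"labels": [i, i + 1, i + 2]} for i in range(6)]

start_step = get_last_codec_checkpoint_step(tmp_dir)  # 0 on a fresh run
buffer = []
for step, row in enumerate(rows):
    if step < start_step:
        continue  # already encoded in a previous run
    buffer.append(row)
    if (step + 1) % save_codec_steps == 0:
        save_codec_checkpoint(tmp_dir, Dataset.from_list(buffer), step + 1)
        buffer = []

encoded = load_all_codec_checkpoints(tmp_dir)          # checkpoints concatenated oldest to newest
print(len(encoded), get_last_codec_checkpoint_step(tmp_dir))  # 6 6
```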
 def log_metric(
     accelerator,
     metrics: Dict,
@@ -86,6 +142,7 @@ def log_pred(
     pred_prompts: List[str],
     transcriptions: List[str],
     audios: List[torch.Tensor],
+    si_sdr_measures: List[float],
     sampling_rate: int,
     step: int,
     prefix: str = "eval",
@@ -98,16 +155,33 @@ def log_pred(
         cur_step_pretty = f"{int(step // 1000)}k" if step > 1000 else step
         prefix_pretty = prefix.replace("/", "-")

-        # convert str data to a wandb compatible format
-        str_data = [[pred_descriptions[i], pred_prompts[i], transcriptions[i]] for i in range(len(pred_descriptions))]
-
-        # log as a table with the appropriate headers
-        wandb_tracker.log_table(
-            table_name=f"predictions/{prefix_pretty}-step-{cur_step_pretty}",
-            columns=["Target descriptions", "Target prompts", "Predicted transcriptions"],
-            data=str_data[:num_lines],
-            step=step,
-            commit=False,
-        )
+        if si_sdr_measures is None:
+            # convert str data to a wandb compatible format
+            str_data = [
+                [pred_descriptions[i], pred_prompts[i], transcriptions[i]] for i in range(len(pred_descriptions))
+            ]
+
+            # log as a table with the appropriate headers
+            wandb_tracker.log_table(
+                table_name=f"predictions/{prefix_pretty}-step-{cur_step_pretty}",
+                columns=["Target descriptions", "Target prompts", "Predicted transcriptions"],
+                data=str_data[:num_lines],
+                step=step,
+                commit=False,
+            )
+        else:
+            # convert str data to a wandb compatible format
+            str_data = [
+                [pred_descriptions[i], pred_prompts[i], transcriptions[i], si_sdr_measures[i]]
+                for i in range(len(pred_descriptions))
+            ]
+
+            # log as a table with the appropriate headers
+            wandb_tracker.log_table(
+                table_name=f"predictions/{prefix_pretty}-step-{cur_step_pretty}",
+                columns=["Target descriptions", "Target prompts", "Predicted transcriptions", "Noise estimation"],
+                data=str_data[:num_lines],
+                step=step,
+                commit=False,
+            )
         # wandb can only load 100 audios per step
         wandb_tracker.log(