Unverified Commit a53577f2 authored by Yoach Lacombe's avatar Yoach Lacombe Committed by GitHub
Browse files

Merge pull request #2 from ylacombe/main

Release
parents 85b8cac7 5eae102f
......@@ -186,7 +186,7 @@
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Copyright [2024] [The HuggingFace Inc. team]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......
# Stable Speech
Work in-progress reproduction of the text-to-speech (TTS) model from the paper [Natural language guidance of high-fidelity text-to-speech with synthetic annotations](https://www.text-description-to-speech.com)
by Dan Lyth and Simon King, from Stability AI and Edinburgh University respectively.
Reproducing the TTS model requires the following 5 steps to be completed in order:
1. Train the Accent Classifier
2. Annotate the Training Set
3. Aggregate Statistics
4. Create Descriptions
5. Train the Model
## Step 1: Train the Accent Classifier
The script [`run_audio_classification.py`](run_audio_classification.py) can be used to train an audio encoder model from
the [Transformers library](https://github.com/huggingface/transformers) (e.g. Wav2Vec2, MMS, or Whisper) for the accent
classification task.
Starting with a pre-trained audio encoder model, a simple linear classifier is appended to the last hidden-layer to map the
audio embeddings to class label predictions. The audio encoder can either be frozen (`--freeze_base_model`) or trained.
The linear classifier is randomly initialised, and is thus always trained.
The script can be used to train on a single accent dataset, or a combination of datasets, which should be specified by
separating dataset names, configs and splits by the `+` character in the launch command (see below for an example).
In the proceeding example, we follow Stability's approach by taking audio embeddings from a frozen [MMS-LID](https://huggingface.co/facebook/mms-lid-126)
model, and training the linear classifier on a combination of three open-source datasets:
1. The English Accented (`en_accented`) subset of [Voxpopuli](https://huggingface.co/datasets/facebook/voxpopuli)
2. The train split of [VCTK](https://huggingface.co/datasets/vctk)
3. The dev split of [EdAcc](https://huggingface.co/datasets/edinburghcstr/edacc)
The model is subsequently evaluated on the test split of [EdAcc](https://huggingface.co/datasets/edinburghcstr/edacc)
to give the final classification accuracy.
```bash
#!/usr/bin/env bash
python run_audio_classification.py \
--model_name_or_path "facebook/mms-lid-126" \
--train_dataset_name "vctk+facebook/voxpopuli+edinburghcstr/edacc" \
--train_dataset_config_name "main+en_accented+default" \
--train_split_name "train+test+validation" \
--train_label_column_name "accent+accent+accent" \
--eval_dataset_name "edinburghcstr/edacc" \
--eval_dataset_config_name "default" \
--eval_split_name "test" \
--eval_label_column_name "accent" \
--output_dir "./" \
--do_train \
--do_eval \
--overwrite_output_dir \
--remove_unused_columns False \
--fp16 \
--learning_rate 1e-4 \
--max_length_seconds 20 \
--attention_mask False \
--warmup_ratio 0.1 \
--num_train_epochs 5 \
--per_device_train_batch_size 32 \
--per_device_eval_batch_size 32 \
--preprocessing_num_workers 16 \
--dataloader_num_workers 4 \
--logging_strategy "steps" \
--logging_steps 10 \
--evaluation_strategy "epoch" \
--save_strategy "epoch" \
--load_best_model_at_end True \
--metric_for_best_model "accuracy" \
--save_total_limit 3 \
--freeze_base_model \
--push_to_hub \
--trust_remote_code
# Parler-TTS
Parler-TTS is a lightweight text-to-speech (TTS) model that can generate high-quality, natural sounding speech in the style of a given speaker (gender, pitch, speaking style, etc). It is a reproduction of work from the paper [Natural language guidance of high-fidelity text-to-speech with synthetic annotations](https://www.text-description-to-speech.com) by Dan Lyth and Simon King, from Stability AI and Edinburgh University respectively.
Contrarily to other TTS models, Parler-TTS is a **fully open-source** release. All of the datasets, pre-processing, training code and weights are released publicly under permissive license, enabling the community to build on our work and develop their own powerful TTS models.
This repository contains the inference and training code for Parler-TTS. It is designed to accompany the [Data-Speech](https://github.com/huggingface/dataspeech) repository for dataset annotation.
> [!IMPORTANT]
> We're proud to release Parler-TTS v0.1, our first 300M parameter model, trained on 10.5K hours of audio data.
> In the coming weeks, we'll be working on scaling up to 50k hours of data, in preparation for the v1 model.
## 📖 Quick Index
* [Installation](#installation)
* [Usage](#usage)
* [Training](#training)
* [Demo](https://huggingface.co/spaces/parler-tts/parler_tts_mini)
* [Model weights and datasets](https://huggingface.co/parler-tts)
## Usage
> [!TIP]
> You can directly try it out in an interactive demo [here](https://huggingface.co/spaces/parler-tts/parler_tts_mini)!
Using Parler-TTS is as simple as "bonjour". Simply use the following inference snippet.
```py
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer
import soundfile as sf
import torch
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler_tts_300M_v0.1").to(device)
tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler_tts_300M_v0.1")
prompt = "Hey, how are you doing today?"
description = "A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. She speaks very fast."
input_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
prompt_input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
audio_arr = generation.cpu().numpy().squeeze()
sf.write("parler_tts_out.wav", audio_arr, model.config.sampling_rate)
```
## Installation
Parler-TTS has light-weight dependencies and can be installed in one line:
```sh
pip install git+https://github.com/huggingface/parler-tts.git
```
## Training
The [training folder](/training/) contains all the information to train or fine-tune your own Parler-TTS model. It consists of:
- [1. An introduction to the Parler-TTS architecture](/training/README.md#1-architecture)
- [2. The first steps to get started](/training/README.md#2-getting-started)
- [3. A training guide](/training/README.md#3-training)
> [!IMPORTANT]
> **TL;DR:** After having followed the [installation steps](/training/README.md#requirements), you can reproduce the Parler-TTS v0.1 training recipe with the following command line:
```sh
accelerate launch ./training/run_parler_tts_training.py ./helpers/training_configs/starting_point_0.01.json
```
Tips:
1. **Number of labels:** normalisation should be applied to the target class labels to group linguistically similar accents together (e.g. "Southern Irish" and "Irish" should both be "Irish"). This helps _balance_ the dataset by removing labels with very few examples. You can modify the function `preprocess_labels` to implement any custom normalisation strategy.
## Acknowledgements
## Step 2: Annotate the Training Set
This library builds on top of a number of open-source giants, to whom we'd like to extend our warmest thanks for providing these tools!
Annotate the training dataset with information on: SNR, C50, pitch and speaking rate.
Special thanks to:
- Dan Lyth and Simon King, from Stability AI and Edinburgh University respectively, for publishing such a promising and clear research paper: [Natural language guidance of high-fidelity text-to-speech with synthetic annotations](https://arxiv.org/abs/2402.01912).
- the many libraries used, namely [🤗 datasets](https://huggingface.co/docs/datasets/v2.17.0/en/index), [🤗 accelerate](https://huggingface.co/docs/accelerate/en/index), [jiwer](https://github.com/jitsi/jiwer), [wandb](https://wandb.ai/), and [🤗 transformers](https://huggingface.co/docs/transformers/index).
- Descript for the [DAC codec model](https://github.com/descriptinc/descript-audio-codec)
- Hugging Face 🤗 for providing compute resources and time to explore!
## Step 3: Aggregate Statistics
Aggregate statistics from Step 2. Convert continuous values to discrete labels.
## Citation
## Step 4: Create Descriptions
If you found this repository useful, please consider citing this work and also the original Stability AI paper:
Convert sequence of discrete labels to text description (using an LLM).
```
@misc{lacombe-etal-2024-parler-tts,
author = {Yoach Lacombe and Vaibhav Srivastav and Sanchit Gandhi},
title = {Parler-TTS},
year = {2024},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/huggingface/parler-tts}}
}
```
## Step 5: Train the Model
```
@misc{lyth2024natural,
title={Natural language guidance of high-fidelity text-to-speech with synthetic annotations},
author={Dan Lyth and Simon King},
year={2024},
eprint={2402.01912},
archivePrefix={arXiv},
primaryClass={cs.SD}
}
```
Train MusicGen-style model on the TTS task.
## Contribution
Contributions are welcome, as the project offers many possibilities for improvement and exploration.
Namely, we're looking at ways to improve both quality and speed:
- Datasets:
- Train on more data
- Add more features such as accents
- Training:
- Add PEFT compatibility to do Lora fine-tuning.
- Add possibility to train without description column.
- Add notebook training.
- Explore multilingual training.
- Explore mono-speaker finetuning.
- Explore more architectures.
- Optimization:
- Compilation and static cache
- Support to FA2 and SDPA
- Evaluation:
- Add more evaluation metrics
#!/usr/bin/env bash
python run_audio_classification.py \
--model_name_or_path "hf-internal-testing/tiny-random-wav2vec2" \
--train_dataset_name "facebook/voxpopuli" \
--train_dataset_config_name "en_accented" \
--train_split_name "test" \
--train_label_column_name "accent" \
--eval_dataset_name "facebook/voxpopuli" \
--eval_dataset_config_name "en_accented" \
--eval_split_name "test" \
--eval_label_column_name "accent" \
--trust_remote_code \
--output_dir "./" \
--do_train \
--do_eval \
--max_train_samples 100 \
--max_eval_samples 100 \
--overwrite_output_dir \
--remove_unused_columns False \
--fp16 \
--learning_rate 1e-4 \
--min_length_seconds 5 \
--max_length_seconds 10 \
--attention_mask False \
--warmup_ratio 0.1 \
--num_train_epochs 5 \
--per_device_train_batch_size 4 \
--per_device_eval_batch_size 4 \
--dataloader_num_workers 0 \
--logging_strategy "steps" \
--logging_steps 10 \
--evaluation_strategy "epoch" \
--save_strategy "epoch" \
--load_best_model_at_end True \
--metric_for_best_model "accuracy" \
--save_total_limit 3 \
--seed 0
import gradio as gr
import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed
device = "cuda:0" if torch.cuda.is_available() else "cpu"
repo_id = "parler-tts/parler_tts_300M_v0.1"
model = ParlerTTSForConditionalGeneration.from_pretrained(repo_id).to(device)
tokenizer = AutoTokenizer.from_pretrained(repo_id)
feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
SAMPLE_RATE = feature_extractor.sampling_rate
SEED = 41
default_text = "Please surprise me and speak in whatever voice you enjoy."
title = "# Parler-TTS </div>"
examples = [
[
"'This is the best time of my life, Bartley,' she said happily.",
"A female speaker with a slightly low-pitched, quite monotone voice delivers her words at a slightly faster-than-average pace in a confined space with very clear audio.",
],
[
"Montrose also, after having experienced still more variety of good and bad fortune, threw down his arms, and retired out of the kingdom. ",
"A male speaker with a slightly high-pitched voice delivering his words at a slightly slow pace in a small, confined space with a touch of background noise and a quite monotone tone.",
],
[
"montrose also after having experienced still more variety of good and bad fortune threw down his arms and retired out of the kingdom",
"A male speaker with a low-pitched voice delivering his words at a fast pace in a small, confined space with a lot of background noise and an animated tone.",
],
]
def gen_tts(text, description):
inputs = tokenizer(description, return_tensors="pt").to(device)
prompt = tokenizer(text, return_tensors="pt").to(device)
set_seed(SEED)
generation = model.generate(
input_ids=inputs.input_ids, prompt_input_ids=prompt.input_ids, do_sample=True, temperature=1.0
)
audio_arr = generation.cpu().numpy().squeeze()
return (SAMPLE_RATE, audio_arr)
css = """
#share-btn-container {
display: flex;
padding-left: 0.5rem !important;
padding-right: 0.5rem !important;
background-color: #000000;
justify-content: center;
align-items: center;
border-radius: 9999px !important;
width: 13rem;
margin-top: 10px;
margin-left: auto;
flex: unset !important;
}
#share-btn {
all: initial;
color: #ffffff;
font-weight: 600;
cursor: pointer;
font-family: 'IBM Plex Sans', sans-serif;
margin-left: 0.5rem !important;
padding-top: 0.25rem !important;
padding-bottom: 0.25rem !important;
right:0;
}
#share-btn * {
all: unset !important;
}
#share-btn-container div:nth-child(-n+2){
width: auto !important;
min-height: 0px !important;
}
#share-btn-container .wrap {
display: none !important;
}
"""
with gr.Blocks(css=css) as block:
gr.Markdown(title)
with gr.Row():
with gr.Column():
input_text = gr.Textbox(label="Input Text", lines=2, value=default_text, elem_id="input_text")
description = gr.Textbox(label="Description", lines=2, value="", elem_id="input_description")
run_button = gr.Button("Generate Audio", variant="primary")
with gr.Column():
audio_out = gr.Audio(label="Parler-TTS generation", type="numpy", elem_id="audio_out")
inputs = [input_text, description]
outputs = [audio_out]
gr.Examples(examples=examples, fn=gen_tts, inputs=inputs, outputs=outputs, cache_examples=True)
run_button.click(fn=gen_tts, inputs=inputs, outputs=outputs, queue=True)
block.queue()
block.launch(share=True)
from parler_tts import ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration, ParlerTTSDecoderConfig
from transformers import AutoConfig
import os
import argparse
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("save_directory", type=str, help="Directory where to save the model and the decoder.")
parser.add_argument("--text_model", type=str, help="Repository id or path to the text encoder.")
parser.add_argument("--audio_model", type=str, help="Repository id or path to the audio encoder.")
args = parser.parse_args()
text_model = args.text_model
encodec_version = args.audio_model
t5 = AutoConfig.from_pretrained(text_model)
encodec = AutoConfig.from_pretrained(encodec_version)
encodec_vocab_size = encodec.codebook_size
num_codebooks = encodec.num_codebooks
print("num_codebooks", num_codebooks)
decoder_config = ParlerTTSDecoderConfig(
vocab_size=encodec_vocab_size + 1,
max_position_embeddings=2048,
num_hidden_layers=4,
ffn_dim=512,
num_attention_heads=8,
layerdrop=0.0,
use_cache=True,
activation_function="gelu",
hidden_size=512,
dropout=0.0,
attention_dropout=0.0,
activation_dropout=0.0,
pad_token_id=encodec_vocab_size,
eos_token_id=encodec_vocab_size,
bos_token_id=encodec_vocab_size + 1,
num_codebooks=num_codebooks,
)
decoder = ParlerTTSForCausalLM(decoder_config)
decoder.save_pretrained(os.path.join(args.save_directory, "decoder"))
model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained(
text_encoder_pretrained_model_name_or_path=text_model,
audio_encoder_pretrained_model_name_or_path=encodec_version,
decoder_pretrained_model_name_or_path=os.path.join(args.save_directory, "decoder"),
vocab_size=t5.vocab_size,
)
# set the appropriate bos/pad token ids
model.generation_config.decoder_start_token_id = encodec_vocab_size + 1
model.generation_config.pad_token_id = encodec_vocab_size
model.generation_config.eos_token_id = encodec_vocab_size
# set other default generation config params
model.generation_config.max_length = int(30 * model.audio_encoder.config.frame_rate)
model.generation_config.do_sample = True # True
model.generation_config.guidance_scale = 1 # 3.0
model.config.pad_token_id = encodec_vocab_size
model.config.decoder_start_token_id = encodec_vocab_size+1
model.save_pretrained(os.path.join(args.save_directory, "tiny-model"))
from parler_tts import ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration, ParlerTTSDecoderConfig
from transformers import AutoConfig
import os
import argparse
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("save_directory", type=str, help="Directory where to save the model and the decoder.")
args = parser.parse_args()
text_model = "google-t5/t5-small"
encodec_version = "facebook/encodec_24khz"
t5 = AutoConfig.from_pretrained(text_model)
encodec = AutoConfig.from_pretrained(encodec_version)
encodec_vocab_size = encodec.codebook_size
num_codebooks = 8
print("num_codebooks", num_codebooks)
decoder_config = ParlerTTSDecoderConfig(
vocab_size=encodec_vocab_size + 1,
max_position_embeddings=2048,
num_hidden_layers=4,
ffn_dim=512,
num_attention_heads=8,
layerdrop=0.0,
use_cache=True,
activation_function="gelu",
hidden_size=512,
dropout=0.0,
attention_dropout=0.0,
activation_dropout=0.0,
pad_token_id=encodec_vocab_size,
eos_token_id=encodec_vocab_size,
bos_token_id=encodec_vocab_size + 1,
num_codebooks=num_codebooks,
)
decoder = ParlerTTSForCausalLM(decoder_config)
decoder.save_pretrained(os.path.join(args.save_directory, "decoder"))
model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained(
text_encoder_pretrained_model_name_or_path=text_model,
audio_encoder_pretrained_model_name_or_path=encodec_version,
decoder_pretrained_model_name_or_path=os.path.join(args.save_directory, "decoder"),
vocab_size=t5.vocab_size,
)
# set the appropriate bos/pad token ids
model.generation_config.decoder_start_token_id = encodec_vocab_size + 1
model.generation_config.pad_token_id = encodec_vocab_size
model.generation_config.eos_token_id = encodec_vocab_size
# set other default generation config params
model.generation_config.max_length = int(30 * model.audio_encoder.config.frame_rate)
model.generation_config.do_sample = True # True
model.generation_config.guidance_scale = 1 # 3.0
model.save_pretrained(os.path.join(args.save_directory, "tiny-model"))
from parler_tts import ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration, ParlerTTSDecoderConfig
from transformers import AutoConfig
import os
import argparse
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("save_directory", type=str, help="Directory where to save the model and the decoder.")
parser.add_argument("--text_model", type=str, help="Repository id or path to the text encoder.")
parser.add_argument("--audio_model", type=str, help="Repository id or path to the audio encoder.")
args = parser.parse_args()
text_model = args.text_model
encodec_version = args.audio_model
t5 = AutoConfig.from_pretrained(text_model)
encodec = AutoConfig.from_pretrained(encodec_version)
encodec_vocab_size = encodec.codebook_size
num_codebooks = encodec.num_codebooks
print("num_codebooks", num_codebooks)
decoder_config = ParlerTTSDecoderConfig(
vocab_size=encodec_vocab_size + 64, # + 64 instead of +1 to have a multiple of 64
max_position_embeddings=4096, # 30 s = 2580
num_hidden_layers=24,
ffn_dim=4096,
num_attention_heads=16,
layerdrop=0.0,
use_cache=True,
activation_function="gelu",
hidden_size=1024,
dropout=0.1,
attention_dropout=0.0,
activation_dropout=0.0,
pad_token_id=encodec_vocab_size,
eos_token_id=encodec_vocab_size,
bos_token_id=encodec_vocab_size + 1,
num_codebooks=num_codebooks,
)
decoder = ParlerTTSForCausalLM(decoder_config)
decoder.save_pretrained(os.path.join(args.save_directory, "decoder"))
model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained(
text_encoder_pretrained_model_name_or_path=text_model,
audio_encoder_pretrained_model_name_or_path=encodec_version,
decoder_pretrained_model_name_or_path=os.path.join(args.save_directory, "decoder"),
vocab_size=t5.vocab_size,
)
# set the appropriate bos/pad token ids
model.generation_config.decoder_start_token_id = encodec_vocab_size + 1
model.generation_config.pad_token_id = encodec_vocab_size
model.generation_config.eos_token_id = encodec_vocab_size
# set other default generation config params
model.generation_config.max_length = int(30 * model.audio_encoder.config.frame_rate)
model.generation_config.do_sample = True # True
model.generation_config.guidance_scale = 1 # 3.0
model.config.pad_token_id = encodec_vocab_size
model.config.decoder_start_token_id = encodec_vocab_size+1
model.save_pretrained(os.path.join(args.save_directory, "parler-tts-untrained-300M/"))
import dac
from parler_tts import DACConfig, DACModel
from transformers import AutoConfig, AutoModel
from transformers import EncodecFeatureExtractor
AutoConfig.register("dac", DACConfig)
AutoModel.register(DACConfig, DACModel)
# Download a model
model_path = dac.utils.download(model_type="44khz")
model = dac.DAC.load(model_path)
hf_dac = DACModel(DACConfig())
hf_dac.model.load_state_dict(model.state_dict())
hf_dac.push_to_hub("parler-tts/dac_44khZ_8kbps")
EncodecFeatureExtractor(sampling_rate=44100).push_to_hub("parler-tts/dac_44khZ_8kbps")
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer, AutoFeatureExtractor
path = "TODO"
repo_id = "parler_tts_300M"
AutoFeatureExtractor.from_pretrained("ylacombe/dac_44khZ_8kbps").push_to_hub(repo_id)
AutoTokenizer.from_pretrained("google/t5-v1_1-base").push_to_hub(repo_id)
ParlerTTSForConditionalGeneration.from_pretrained(path).push_to_hub(repo_id)
{
"model_name_or_path": "./parler-tts-untrained-300M/parler-tts-untrained-300M/",
"save_to_disk": "./tmp_dataset_audio/",
"temporary_save_to_disk": "./audio_code_tmp/",
"feature_extractor_name":"ylacombe/dac_44khZ_8kbps",
"description_tokenizer_name":"google/flan-t5-base",
"prompt_tokenizer_name":"google/flan-t5-base",
"report_to": ["wandb"],
"overwrite_output_dir": true,
"output_dir": "./output_dir_training",
"train_dataset_name": "blabble-io/libritts_r",
"train_metadata_dataset_name": "parler-tts/libritts_r_tags_tagged_10k_generated",
"train_dataset_config_name": "clean",
"train_split_name": "test.clean",
"eval_dataset_name": "blabble-io/libritts_r",
"eval_metadata_dataset_name": "parler-tts/libritts_r_tags_tagged_10k_generated",
"eval_dataset_config_name": "clean",
"eval_split_name": "test.clean",
"target_audio_column_name": "audio",
"description_column_name": "text_description",
"prompt_column_name": "text",
"max_eval_samples": 48,
"max_train_samples": 96,
"max_duration_in_seconds": 20,
"min_duration_in_seconds": 2.0,
"add_audio_samples_to_wandb": true,
"id_column_name": "id",
"preprocessing_num_workers": 8,
"do_train": true,
"num_train_epochs": 50,
"gradient_accumulation_steps": 1,
"gradient_checkpointing": false,
"per_device_train_batch_size": 4,
"learning_rate": 1e-3,
"adam_beta1": 0.9,
"adam_beta2": 0.99,
"weight_decay": 0.01,
"lr_scheduler_type": "cosine",
"warmup_steps": 40,
"logging_steps": 2,
"freeze_text_encoder": true,
"do_eval": true,
"predict_with_generate": true,
"include_inputs_for_metrics": true,
"evaluation_strategy": "steps",
"eval_steps": 500,
"save_steps": 5000,
"per_device_eval_batch_size": 12,
"audio_encoder_per_device_batch_size":24,
"dtype": "bfloat16",
"seed": 456,
"dataloader_num_workers":8
}
{
"model_name_or_path": "./parler-tts-untrained-300M/parler-tts-untrained-300M/",
"save_to_disk": "./tmp_dataset_audio/",
"temporary_save_to_disk": "./audio_code_tmp/",
"feature_extractor_name":"ylacombe/dac_44khZ_8kbps",
"description_tokenizer_name":"google/flan-t5-base",
"prompt_tokenizer_name":"google/flan-t5-base",
"report_to": ["wandb"],
"overwrite_output_dir": true,
"output_dir": "./output_dir_training",
"train_dataset_name": "blabble-io/libritts_r+blabble-io/libritts_r+blabble-io/libritts_r+parler-tts/mls_eng_10k",
"train_metadata_dataset_name": "parler-tts/libritts_r_tags_tagged_10k_generated+parler-tts/libritts_r_tags_tagged_10k_generated+parler-tts/libritts_r_tags_tagged_10k_generated+parler-tts/mls-eng-10k-tags_tagged_10k_generated",
"train_dataset_config_name": "clean+clean+other+default",
"train_split_name": "train.clean.360+train.clean.100+train.other.500+train",
"eval_dataset_name": "blabble-io/libritts_r+parler-tts/mls_eng_10k",
"eval_metadata_dataset_name": "parler-tts/libritts_r_tags_tagged_10k_generated+parler-tts/mls-eng-10k-tags_tagged_10k_generated",
"eval_dataset_config_name": "other+default",
"eval_split_name": "test.other+test",
"target_audio_column_name": "audio",
"description_column_name": "text_description",
"prompt_column_name": "text",
"max_eval_samples": 96,
"max_duration_in_seconds": 30,
"min_duration_in_seconds": 2.0,
"max_text_length": 400,
"group_by_length": true,
"add_audio_samples_to_wandb": true,
"id_column_name": "id",
"preprocessing_num_workers": 8,
"do_train": true,
"num_train_epochs": 40,
"gradient_accumulation_steps": 8,
"gradient_checkpointing": false,
"per_device_train_batch_size": 3,
"learning_rate": 0.00095,
"adam_beta1": 0.9,
"adam_beta2": 0.99,
"weight_decay": 0.01,
"lr_scheduler_type": "constant_with_warmup",
"warmup_steps": 20000,
"logging_steps": 1000,
"freeze_text_encoder": true,
"do_eval": true,
"predict_with_generate": true,
"include_inputs_for_metrics": true,
"evaluation_strategy": "steps",
"eval_steps": 10000,
"save_steps": 10000,
"per_device_eval_batch_size": 12,
"audio_encoder_per_device_batch_size":20,
"dtype": "bfloat16",
"seed": 456,
"dataloader_num_workers":8
}
__version__ = "0.1"
from .configuration_parler_tts import ParlerTTSConfig, ParlerTTSDecoderConfig
from .modeling_parler_tts import (
ParlerTTSForCausalLM,
ParlerTTSForConditionalGeneration,
apply_delay_pattern_mask,
build_delay_pattern_mask,
)
from .dac_wrapper import DACConfig, DACModel
from transformers import AutoConfig, AutoModel
AutoConfig.register("dac", DACConfig)
AutoModel.register(DACConfig, DACModel)
# coding=utf-8
# Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved.
# Copyright 2024 and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Stable Speech model configuration"""
""" Parler-TTS model configuration"""
from transformers import AutoConfig, logging
from transformers.configuration_utils import PretrainedConfig
......@@ -21,26 +21,26 @@ from transformers.configuration_utils import PretrainedConfig
logger = logging.get_logger(__name__)
MUSICGEN_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"facebook/stable_speech-small": "https://huggingface.co/facebook/stable_speech-small/resolve/main/config.json",
# See all StableSpeech models at https://huggingface.co/models?filter=stable_speech
"facebook/parler_tts-small": "https://huggingface.co/facebook/parler_tts-small/resolve/main/config.json",
# See all ParlerTTS models at https://huggingface.co/models?filter=parler_tts
}
class StableSpeechDecoderConfig(PretrainedConfig):
class ParlerTTSDecoderConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of an [`StableSpeechDecoder`]. It is used to instantiate a
Stable Speech decoder according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the Stable Speech
[facebook/stable_speech-small](https://huggingface.co/facebook/stable_speech-small) architecture.
This is the configuration class to store the configuration of an [`ParlerTTSDecoder`]. It is used to instantiate a
Parler-TTS decoder according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the Parler-TTS
[facebook/parler_tts-small](https://huggingface.co/facebook/parler_tts-small) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 2048):
Vocabulary size of the StableSpeechDecoder model. Defines the number of different tokens that can be
represented by the `inputs_ids` passed when calling [`StableSpeechDecoder`].
vocab_size (`int`, *optional*, defaults to 2049):
Vocabulary size of the ParlerTTSDecoder model. Defines the number of different tokens that can be
represented by the `inputs_ids` passed when calling [`ParlerTTSDecoder`].
hidden_size (`int`, *optional*, defaults to 1024):
Dimensionality of the layers and the pooler layer.
num_hidden_layers (`int`, *optional*, defaults to 24):
......@@ -76,12 +76,12 @@ class StableSpeechDecoderConfig(PretrainedConfig):
Whether input and output word embeddings should be tied.
"""
model_type = "stable_speech_decoder"
model_type = "parler_tts_decoder"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=2048,
vocab_size=2049, # vocab size = 2048 (encodec vocab size) + 1 (eos)
max_position_embeddings=2048,
num_hidden_layers=24,
ffn_dim=4096,
......@@ -97,8 +97,8 @@ class StableSpeechDecoderConfig(PretrainedConfig):
scale_embedding=False,
num_codebooks=4,
pad_token_id=2048,
bos_token_id=2048,
eos_token_id=None,
bos_token_id=2049,
eos_token_id=2048,
tie_word_embeddings=False,
**kwargs,
):
......@@ -127,16 +127,19 @@ class StableSpeechDecoderConfig(PretrainedConfig):
)
class StableSpeechConfig(PretrainedConfig):
class ParlerTTSConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`StableSpeechModel`]. It is used to instantiate a
Stable Speech model according to the specified arguments, defining the text encoder, audio encoder and Stable Speech decoder
This is the configuration class to store the configuration of a [`ParlerTTSModel`]. It is used to instantiate a
Parler-TTS model according to the specified arguments, defining the text encoder, audio encoder and Parler-TTS decoder
configs.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 1024):
Vocabulary size of the prompt token ids. Defines the number of different tokens that can be
represented by the `prompt_inputs_ids`.
kwargs (*optional*):
Dictionary of keyword arguments. Notably:
......@@ -151,24 +154,24 @@ class StableSpeechConfig(PretrainedConfig):
```python
>>> from transformers import (
... StableSpeechConfig,
... StableSpeechDecoderConfig,
... ParlerTTSConfig,
... ParlerTTSDecoderConfig,
... T5Config,
... EncodecConfig,
... StableSpeechForConditionalGeneration,
... ParlerTTSForConditionalGeneration,
... )
>>> # Initializing text encoder, audio encoder, and decoder model configurations
>>> text_encoder_config = T5Config()
>>> audio_encoder_config = EncodecConfig()
>>> decoder_config = StableSpeechDecoderConfig()
>>> decoder_config = ParlerTTSDecoderConfig()
>>> configuration = StableSpeechConfig.from_sub_models_config(
>>> configuration = ParlerTTSConfig.from_sub_models_config(
... text_encoder_config, audio_encoder_config, decoder_config
... )
>>> # Initializing a StableSpeechForConditionalGeneration (with random weights) from the facebook/stable_speech-small style configuration
>>> model = StableSpeechForConditionalGeneration(configuration)
>>> # Initializing a ParlerTTSForConditionalGeneration (with random weights) from the facebook/parler_tts-small style configuration
>>> model = ParlerTTSForConditionalGeneration(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
......@@ -177,17 +180,17 @@ class StableSpeechConfig(PretrainedConfig):
>>> config_decoder = model.config.decoder
>>> # Saving the model, including its configuration
>>> model.save_pretrained("stable_speech-model")
>>> model.save_pretrained("parler_tts-model")
>>> # loading model and config from pretrained folder
>>> stable_speech_config = StableSpeechConfig.from_pretrained("stable_speech-model")
>>> model = StableSpeechForConditionalGeneration.from_pretrained("stable_speech-model", config=stable_speech_config)
>>> parler_tts_config = ParlerTTSConfig.from_pretrained("parler_tts-model")
>>> model = ParlerTTSForConditionalGeneration.from_pretrained("parler_tts-model", config=parler_tts_config)
```"""
model_type = "stable_speech"
model_type = "parler_tts"
is_composition = True
def __init__(self, **kwargs):
def __init__(self, vocab_size=1024, **kwargs):
super().__init__(**kwargs)
if "text_encoder" not in kwargs or "audio_encoder" not in kwargs or "decoder" not in kwargs:
raise ValueError("Config has to be initialized with text_encoder, audio_encoder and decoder config")
......@@ -200,9 +203,10 @@ class StableSpeechConfig(PretrainedConfig):
decoder_config = kwargs.pop("decoder")
self.vocab_size = vocab_size
self.text_encoder = AutoConfig.for_model(text_encoder_model_type, **text_encoder_config)
self.audio_encoder = AutoConfig.for_model(audio_encoder_model_type, **audio_encoder_config)
self.decoder = StableSpeechDecoderConfig(**decoder_config)
self.decoder = ParlerTTSDecoderConfig(**decoder_config)
self.is_encoder_decoder = True
@classmethod
......@@ -210,15 +214,15 @@ class StableSpeechConfig(PretrainedConfig):
cls,
text_encoder_config: PretrainedConfig,
audio_encoder_config: PretrainedConfig,
decoder_config: StableSpeechDecoderConfig,
decoder_config: ParlerTTSDecoderConfig,
**kwargs,
):
r"""
Instantiate a [`StableSpeechConfig`] (or a derived class) from text encoder, audio encoder and decoder
Instantiate a [`ParlerTTSConfig`] (or a derived class) from text encoder, audio encoder and decoder
configurations.
Returns:
[`StableSpeechConfig`]: An instance of a configuration object
[`ParlerTTSConfig`]: An instance of a configuration object
"""
return cls(
......
from .configuration_dac import DACConfig
from .modeling_dac import DACModel
from transformers import PretrainedConfig
from typing import List
class DACConfig(PretrainedConfig):
model_type = "dac"
def __init__(
self,
num_codebooks: int = 9,
model_bitrate: int = 8, # kbps
codebook_size: int = 1024,
latent_dim: int = 1024,
frame_rate: int = 86,
sampling_rate: int = 44100,
**kwargs,
):
self.codebook_size = codebook_size
self.model_bitrate = model_bitrate
self.latent_dim = latent_dim
self.num_codebooks = num_codebooks
self.frame_rate = frame_rate
self.sampling_rate = sampling_rate
super().__init__(**kwargs)
import torch
from transformers import PreTrainedModel
from transformers.models.encodec.modeling_encodec import EncodecEncoderOutput, EncodecDecoderOutput
from .configuration_dac import DACConfig
from dac.model import DAC
# model doesn't support batching yet
class DACModel(PreTrainedModel):
config_class = DACConfig
def __init__(self, config):
super().__init__(config)
self.model = DAC(
n_codebooks=config.num_codebooks,
latent_dim=config.latent_dim,
codebook_size=config.codebook_size,
)
def encode(
self, input_values, padding_mask=None, bandwidth=None, return_dict=None, n_quantizers=None, sample_rate=None
):
"""
Encodes the input audio waveform into discrete codes.
Args:
input_values (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
Float values of the input audio waveform.
padding_mask (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
Padding mask used to pad the `input_values`.
bandwidth (`float`, *optional*):
Not used, kept to have the same inferface as HF encodec.
n_quantizers (`int`, *optional*) :
Number of quantizers to use, by default None
If None, all quantizers are used.
sample_rate (`int`, *optional*) :
Signal sampling_rate
Returns:
A list of frames containing the discrete encoded codes for the input audio waveform, along with rescaling
factors for each chunk when `normalize` is True. Each frames is a tuple `(codebook, scale)`, with
`codebook` of shape `[batch_size, num_codebooks, frames]`.
Scale is not used here.
"""
_, channels, input_length = input_values.shape
if channels < 1 or channels > 2:
raise ValueError(f"Number of audio channels must be 1 or 2, but got {channels}")
audio_data = self.model.preprocess(input_values, sample_rate)
return_dict = return_dict if return_dict is not None else self.config.return_dict
# TODO: for now, no chunk length
chunk_length = None # self.config.chunk_length
if chunk_length is None:
chunk_length = input_length
stride = input_length
else:
stride = self.config.chunk_stride
if padding_mask is None:
padding_mask = torch.ones_like(input_values).bool()
encoded_frames = []
scales = []
step = chunk_length - stride
if (input_length % stride) - step != 0:
raise ValueError(
"The input length is not properly padded for batched chunked decoding. Make sure to pad the input correctly."
)
for offset in range(0, input_length - step, stride):
mask = padding_mask[..., offset : offset + chunk_length].bool()
frame = audio_data[:, :, offset : offset + chunk_length]
scale = None
_, encoded_frame, _, _, _ = self.model.encode(frame, n_quantizers=n_quantizers)
encoded_frames.append(encoded_frame)
scales.append(scale)
encoded_frames = torch.stack(encoded_frames)
if not return_dict:
return (encoded_frames, scales)
return EncodecEncoderOutput(encoded_frames, scales)
def decode(
self,
audio_codes,
audio_scales,
padding_mask=None,
return_dict=None,
):
"""
Decodes the given frames into an output audio waveform.
Note that the output might be a bit bigger than the input. In that case, any extra steps at the end can be
trimmed.
Args:
audio_codes (`torch.FloatTensor` of shape `(batch_size, nb_chunks, chunk_length)`, *optional*):
Discret code embeddings computed using `model.encode`.
audio_scales (`torch.Tensor` of shape `(batch_size, nb_chunks)`, *optional*):
Not used, kept to have the same inferface as HF encodec.
padding_mask (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
Padding mask used to pad the `input_values`.
Not used yet, kept to have the same inferface as HF encodec.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
return_dict = return_dict or self.config.return_dict
# TODO: for now, no chunk length
if len(audio_codes) != 1:
raise ValueError(f"Expected one frame, got {len(audio_codes)}")
audio_values = self.model.quantizer.from_codes(audio_codes.squeeze(0))[0]
audio_values = self.model.decode(audio_values)
if not return_dict:
return (audio_values,)
return EncodecDecoderOutput(audio_values)
def forward(self, tensor):
raise ValueError(f"`DACModel.forward` not implemented yet")
# coding=utf-8
# Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved.
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch StableSpeech model."""
""" PyTorch ParlerTTS model."""
import copy
import inspect
import math
......@@ -23,7 +23,7 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import AutoConfig, AutoModel
from transformers import AutoConfig, AutoModel, AutoModelForTextEncoding
from transformers.activations import ACT2FN
from transformers.generation.configuration_utils import GenerationConfig
from transformers.generation.logits_process import ClassifierFreeGuidanceLogitsProcessor, LogitsProcessorList
......@@ -44,25 +44,103 @@ from transformers.utils import (
replace_return_docstrings,
)
from .configuration_stable_speech import StableSpeechConfig, StableSpeechDecoderConfig
from .configuration_parler_tts import ParlerTTSConfig, ParlerTTSDecoderConfig
from .dac_wrapper import DACConfig, DACModel
from transformers import AutoConfig, AutoModel
AutoConfig.register("dac", DACConfig)
AutoModel.register(DACConfig, DACModel)
if TYPE_CHECKING:
from transformers.generation.streamers import BaseStreamer
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "StableSpeechConfig"
_CHECKPOINT_FOR_DOC = "facebook/stable_speech-small"
_CONFIG_FOR_DOC = "ParlerTTSConfig"
_CHECKPOINT_FOR_DOC = "facebook/parler_tts-small"
MUSICGEN_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/stable_speech-small",
# See all StableSpeech models at https://huggingface.co/models?filter=stable_speech
"facebook/parler_tts-small",
# See all ParlerTTS models at https://huggingface.co/models?filter=parler_tts
]
def apply_delay_pattern_mask(input_ids, decoder_pad_token_mask):
"""Apply a delay pattern mask to the decoder input ids, only preserving predictions where
the mask is set to -1, and otherwise setting to the value detailed in the mask."""
seq_len = input_ids.shape[-1]
decoder_pad_token_mask = decoder_pad_token_mask[..., :seq_len]
input_ids = torch.where(decoder_pad_token_mask == -1, input_ids, decoder_pad_token_mask)
return input_ids
def build_delay_pattern_mask(
input_ids: torch.LongTensor, bos_token_id: int, pad_token_id: int, max_length: int, num_codebooks: int
):
"""Build a delayed pattern mask to the input_ids. Each codebook is offset by the previous codebook by
one, giving a delayed pattern mask at the start of sequence and end of sequence. Take the example where there
are 4 codebooks and a max sequence length of 8, we have the delayed pattern mask of shape `(codebooks,
seq_len)`:
- [B, -1, -1, -1, -1, P, P, P]
- [B, B, -1, -1, -1, -1, P, P]
- [B, B, B, -1, -1, -1, -1, P]
- [B, B, B, B, -1, -1, -1, -1]
where P is the special padding token id and -1 indicates that the token is valid for prediction. If we include
a prompt (decoder input ids), the -1 positions indicate where new tokens should be predicted. Otherwise, the
mask is set to the value in the prompt:
- [B, a, b, -1, -1, P, P, P]
- [B, B, c, d, -1, -1, P, P]
- [B, B, B, e, f, -1, -1, P]
- [B, B, B, B, g, h, -1, -1]
where a-h indicate the input prompt (decoder input ids) that are offset by 1. Now, we only override the -1
tokens in our prediction.
"""
# (bsz * num_codebooks, seq_len) -> (bsz, num_codebooks, seq_len)
input_ids = input_ids.reshape(-1, num_codebooks, input_ids.shape[-1])
bsz, num_codebooks, seq_len = input_ids.shape
input_ids_shifted = torch.ones((bsz, num_codebooks, max_length), dtype=torch.long, device=input_ids.device) * -1
# we only apply the mask if we have a large enough seq len - otherwise we return as is
if max_length < 2 * num_codebooks - 1:
return input_ids.reshape(bsz * num_codebooks, -1), input_ids_shifted.reshape(bsz * num_codebooks, -1)
# fill the shifted ids with the prompt entries, offset by the codebook idx
for codebook in range(num_codebooks):
# mono channel - loop over the codebooks one-by-one
input_ids_shifted[:, codebook, codebook : seq_len + codebook] = input_ids[:, codebook]
# construct a pattern mask that indicates the positions of padding tokens for each codebook
# first fill the upper triangular part (the EOS padding)
eos_delay_pattern = torch.triu(
torch.ones((num_codebooks, max_length), dtype=torch.bool), diagonal=max_length - num_codebooks + 1
)
# then fill the lower triangular part (the BOS padding)
bos_delay_pattern = torch.tril(torch.ones((num_codebooks, max_length), dtype=torch.bool))
bos_mask = ~(bos_delay_pattern).to(input_ids.device)
eos_mask = ~(eos_delay_pattern).to(input_ids.device)
mask = ~(bos_delay_pattern + eos_delay_pattern).to(input_ids.device)
input_ids = mask * input_ids_shifted + ~bos_mask * bos_token_id + ~eos_mask * pad_token_id
# find the first position to start generating - this is the first place we have the -1 token
# and will always be in the first codebook (since it has no codebook offset)
first_codebook_ids = input_ids[:, 0, :]
start_ids = (first_codebook_ids == -1).nonzero()[:, 1]
if len(start_ids) > 0:
first_start_id = min(start_ids)
else:
# we have no tokens that need to be filled - return entire matrix of input ids
first_start_id = seq_len
# (bsz * num_codebooks, seq_len) -> (bsz, num_codebooks, seq_len)
pattern_mask = input_ids.reshape(bsz * num_codebooks, -1)
input_ids = input_ids[..., :first_start_id].reshape(bsz * num_codebooks, -1)
return input_ids, pattern_mask
@dataclass
class StableSpeechUnconditionalInput(ModelOutput):
class ParlerTTSUnconditionalInput(ModelOutput):
"""
Args:
encoder_outputs (`Tuple[torch.FloatTensor]` of length 1, with tensor shape `(batch_size, sequence_length, hidden_size)`):
......@@ -99,8 +177,8 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start
return shifted_input_ids
# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenSinusoidalPositionalEmbedding with Musicgen->StableSpeech
class StableSpeechSinusoidalPositionalEmbedding(nn.Module):
# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenSinusoidalPositionalEmbedding with Musicgen->ParlerTTS
class ParlerTTSSinusoidalPositionalEmbedding(nn.Module):
"""This module produces sinusoidal positional embeddings of any length."""
def __init__(self, num_positions: int, embedding_dim: int):
......@@ -136,7 +214,7 @@ class StableSpeechSinusoidalPositionalEmbedding(nn.Module):
@torch.no_grad()
def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
bsz, codebooks, seq_len = input_ids.size()
bsz, seq_len, _ = input_ids.size()
# Create the position ids from the input token ids.
position_ids = (torch.arange(seq_len) + past_key_values_length).to(input_ids.device)
# expand embeddings if needed
......@@ -145,8 +223,8 @@ class StableSpeechSinusoidalPositionalEmbedding(nn.Module):
return self.weights.index_select(0, position_ids.view(-1)).detach()
# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenAttention with Musicgen->StableSpeech
class StableSpeechAttention(nn.Module):
# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenAttention with Musicgen->ParlerTTS
class ParlerTTSAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
......@@ -157,7 +235,7 @@ class StableSpeechAttention(nn.Module):
is_decoder: bool = False,
bias: bool = True,
is_causal: bool = False,
config: Optional[StableSpeechConfig] = None,
config: Optional[ParlerTTSConfig] = None,
):
super().__init__()
self.embed_dim = embed_dim
......@@ -304,13 +382,13 @@ class StableSpeechAttention(nn.Module):
return attn_output, attn_weights_reshaped, past_key_value
# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoderLayer with Musicgen->StableSpeech
class StableSpeechDecoderLayer(nn.Module):
def __init__(self, config: StableSpeechDecoderConfig):
# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoderLayer with Musicgen->ParlerTTS
class ParlerTTSDecoderLayer(nn.Module):
def __init__(self, config: ParlerTTSDecoderConfig):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = StableSpeechAttention(
self.self_attn = ParlerTTSAttention(
embed_dim=self.embed_dim,
num_heads=config.num_attention_heads,
dropout=config.attention_dropout,
......@@ -322,7 +400,7 @@ class StableSpeechDecoderLayer(nn.Module):
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.encoder_attn = StableSpeechAttention(
self.encoder_attn = ParlerTTSAttention(
self.embed_dim,
config.num_attention_heads,
dropout=config.attention_dropout,
......@@ -424,17 +502,17 @@ class StableSpeechDecoderLayer(nn.Module):
return outputs
# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenPreTrainedModel with Musicgen->StableSpeech
class StableSpeechPreTrainedModel(PreTrainedModel):
# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenPreTrainedModel with Musicgen->ParlerTTS
class ParlerTTSPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = StableSpeechDecoderConfig
config_class = ParlerTTSDecoderConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["StableSpeechDecoderLayer", "StableSpeechAttention"]
_no_split_modules = ["ParlerTTSDecoderLayer", "ParlerTTSAttention"]
def _init_weights(self, module):
std = self.config.initializer_factor
......@@ -450,7 +528,7 @@ class StableSpeechPreTrainedModel(PreTrainedModel):
MUSICGEN_START_DOCSTRING = r"""
The StableSpeech model was proposed in [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by
The ParlerTTS model was proposed in [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by
Jade Copet, Felix Kreuk, Itai Gat, Tal Remez, David Kant, Gabriel Synnaeve, Yossi Adi, Alexandre Défossez. It is an
encoder decoder transformer trained on the task of conditional music generation
......@@ -463,7 +541,7 @@ MUSICGEN_START_DOCSTRING = r"""
and behavior.
Parameters:
config ([`StableSpeechConfig`]): Model configuration class with all the parameters of the model.
config ([`ParlerTTSConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
......@@ -553,6 +631,25 @@ MUSICGEN_INPUTS_DOCSTRING = r"""
If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
of `inputs_embeds`.
prompt_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input prompt sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
prompt_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding prompt token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
prompt_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `prompt_input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `prompt_input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
......@@ -604,6 +701,16 @@ MUSICGEN_DECODER_INPUTS_DOCSTRING = r"""
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
prompt_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
Sequence of prompt hidden-states at the output of the initial embedding layer. Concatenated to the input embeds.
prompt_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
Mask to avoid performing cross-attention on padding tokens indices of prompt input_ids. Mask values
selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
......@@ -644,13 +751,13 @@ MUSICGEN_DECODER_INPUTS_DOCSTRING = r"""
"""
# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder with Musicgen->StableSpeech
class StableSpeechDecoder(StableSpeechPreTrainedModel):
# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder with Musicgen->ParlerTTS
class ParlerTTSDecoder(ParlerTTSPreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`StableSpeechDecoderLayer`]
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`ParlerTTSDecoderLayer`]
"""
def __init__(self, config: StableSpeechDecoderConfig):
def __init__(self, config: ParlerTTSDecoderConfig):
super().__init__(config)
self.dropout = config.dropout
self.layerdrop = config.layerdrop
......@@ -659,17 +766,18 @@ class StableSpeechDecoder(StableSpeechPreTrainedModel):
self.num_codebooks = config.num_codebooks
self.embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0
embed_dim = config.vocab_size + 1
# TODO(YL): actually doesn't need the +1 if initialized correctly. Too late to change now.
embed_dim = config.vocab_size + 1 # + 1 for pad token id
self.embed_tokens = nn.ModuleList(
[nn.Embedding(embed_dim, config.hidden_size) for _ in range(config.num_codebooks)]
)
self.embed_positions = StableSpeechSinusoidalPositionalEmbedding(
self.embed_positions = ParlerTTSSinusoidalPositionalEmbedding(
config.max_position_embeddings,
config.hidden_size,
)
self.layers = nn.ModuleList([StableSpeechDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self.layers = nn.ModuleList([ParlerTTSDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self.layer_norm = nn.LayerNorm(config.hidden_size)
self.gradient_checkpointing = False
......@@ -689,6 +797,8 @@ class StableSpeechDecoder(StableSpeechPreTrainedModel):
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.LongTensor] = None,
prompt_hidden_states: Optional[torch.FloatTensor] = None,
prompt_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
......@@ -725,6 +835,38 @@ class StableSpeechDecoder(StableSpeechPreTrainedModel):
if inputs_embeds is None:
inputs_embeds = sum([self.embed_tokens[codebook](input[:, codebook]) for codebook in range(num_codebooks)])
# if prompt_hidden_states, fuse to inputs_embeds and update input shape
if prompt_hidden_states is not None:
inputs_embeds = torch.cat([prompt_hidden_states, inputs_embeds], dim=1)
# As it is, the masked ids from the prompt will still count in the positions embeddings
if prompt_attention_mask is not None and attention_mask is not None:
attention_mask = torch.cat([prompt_attention_mask, attention_mask], dim=1)
elif prompt_attention_mask is not None:
logger.warning_once(
"`prompt_attention_mask` is specified but `attention_mask` is not. A full `attention_mask` will be created. Make sure this is the intended behaviour."
)
if past_key_values is None:
attention_mask = torch.cat(
[
prompt_attention_mask,
torch.ones(input_shape, device=self.device, dtype=prompt_attention_mask.dtype),
],
dim=1,
)
else:
generated_length = past_key_values_length - prompt_attention_mask.shape[1] + 1
attention_mask = torch.cat(
[
prompt_attention_mask,
torch.ones(
(input_shape[0], generated_length), device=self.device, dtype=prompt_attention_mask.dtype
),
],
dim=1,
)
input_shape = inputs_embeds.size()[:-1]
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)
......@@ -737,7 +879,9 @@ class StableSpeechDecoder(StableSpeechPreTrainedModel):
)
# embed positions
positions = self.embed_positions(input, past_key_values_length)
# TODO: As it is, the masked ids from the prompt will still count in the positions embeddings
# maybe should modify position embeddings
positions = self.embed_positions(inputs_embeds, past_key_values_length)
hidden_states = inputs_embeds + positions.to(inputs_embeds.device)
......@@ -835,14 +979,14 @@ class StableSpeechDecoder(StableSpeechPreTrainedModel):
@add_start_docstrings(
"The bare StableSpeech decoder model outputting raw hidden-states without any specific head on top.",
"The bare ParlerTTS decoder model outputting raw hidden-states without any specific head on top.",
MUSICGEN_START_DOCSTRING,
)
# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenModel with Musicgen->StableSpeech
class StableSpeechModel(StableSpeechPreTrainedModel):
def __init__(self, config: StableSpeechDecoderConfig):
# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenModel with Musicgen->ParlerTTS
class ParlerTTSModel(ParlerTTSPreTrainedModel):
def __init__(self, config: ParlerTTSDecoderConfig):
super().__init__(config)
self.decoder = StableSpeechDecoder(config)
self.decoder = ParlerTTSDecoder(config)
# Initialize weights and apply final processing
self.post_init()
......@@ -862,6 +1006,8 @@ class StableSpeechModel(StableSpeechPreTrainedModel):
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.LongTensor] = None,
prompt_hidden_states: Optional[torch.FloatTensor] = None,
prompt_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
......@@ -884,6 +1030,8 @@ class StableSpeechModel(StableSpeechPreTrainedModel):
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
encoder_hidden_states=encoder_hidden_states,
prompt_hidden_states=prompt_hidden_states,
prompt_attention_mask=prompt_attention_mask,
head_mask=head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
......@@ -907,15 +1055,15 @@ class StableSpeechModel(StableSpeechPreTrainedModel):
@add_start_docstrings(
"The Stable Speech decoder model with a language modelling head on top.",
"The Parler-TTS decoder model with a language modelling head on top.",
MUSICGEN_START_DOCSTRING,
)
# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenForCausalLM with Musicgen->StableSpeech
class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
def __init__(self, config: StableSpeechDecoderConfig):
# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenForCausalLM with Musicgen->ParlerTTS
class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
def __init__(self, config: ParlerTTSDecoderConfig):
super().__init__(config)
self.model = StableSpeechModel(config)
self.model = ParlerTTSModel(config)
self.num_codebooks = config.num_codebooks
self.lm_heads = nn.ModuleList(
......@@ -951,6 +1099,8 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.LongTensor] = None,
prompt_hidden_states: Optional[torch.FloatTensor] = None,
prompt_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
......@@ -962,11 +1112,11 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
Returns:
Returns:
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
......@@ -976,6 +1126,8 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
prompt_hidden_states=prompt_hidden_states,
prompt_attention_mask=prompt_attention_mask,
head_mask=head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
......@@ -992,7 +1144,29 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
loss = None
if labels is not None:
raise NotImplementedError("Training is not implemented for StableSpeech.")
loss = torch.zeros([], device=self.device)
# since encoder hidden states have concatenated to hidden states, take the last hidden states corresponding to labels
logits = lm_logits[:, :, -labels.shape[1] :]
loss_fct = CrossEntropyLoss()
loss = torch.zeros([], device=self.device)
# (bsz, vocab_size, seq_len, num_codebooks), (bsz, seq_len, num_codebooks)
labels = labels.masked_fill(labels == self.config.bos_token_id, -100)
# we use every codebooks token AND one single EOS at the end of each codebooks
mask = (input_ids.transpose(1, 2) != self.config.eos_token_id) & ((labels != -100))
# per codebook cross-entropy
for codebook in range(self.config.num_codebooks):
codebook_logits = logits[:, codebook].contiguous().view(-1, logits.shape[-1])
codebook_mask = mask[..., codebook].contiguous().view(-1)
codebook_labels = labels[..., codebook].contiguous().view(-1)
codebook_loss = loss_fct(codebook_logits[codebook_mask], codebook_labels[codebook_mask])
loss += codebook_loss
loss = loss / self.config.num_codebooks
# (bsz, num_codebooks, seq_len, vocab_size) -> (bsz * num_codebooks, seq_len, vocab_size)
lm_logits = lm_logits.reshape(-1, *lm_logits.shape[2:])
......@@ -1016,6 +1190,8 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
prompt_hidden_states=None,
prompt_attention_mask=None,
head_mask=None,
cross_attn_head_mask=None,
past_key_values=None,
......@@ -1027,6 +1203,7 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
if delay_pattern_mask is None:
input_ids, delay_pattern_mask = self.build_delay_pattern_mask(
input_ids,
bos_token_id=self.generation_config.bos_token_id,
pad_token_id=self.generation_config.pad_token_id,
max_length=self.generation_config.max_length,
)
......@@ -1041,14 +1218,29 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
if attention_mask is not None:
attention_mask = attention_mask.repeat((2, 1))
if prompt_hidden_states is not None:
prompt_hidden_states = torch.concatenate(
[prompt_hidden_states, torch.zeros_like(prompt_hidden_states)], dim=0
)
if prompt_attention_mask is not None:
prompt_attention_mask = torch.concatenate(
[prompt_attention_mask, torch.zeros_like(prompt_attention_mask)], dim=0
)
if past_key_values is not None:
input_ids = input_ids[:, -1:]
# we only want to use prompt signal in the 1st generation step but keeping the attention mask
prompt_hidden_states = None
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"encoder_hidden_states": encoder_hidden_states,
"encoder_attention_mask": encoder_attention_mask,
"prompt_hidden_states": prompt_hidden_states,
"prompt_attention_mask": prompt_attention_mask,
"head_mask": head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"past_key_values": past_key_values,
......@@ -1056,77 +1248,35 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
}
# Ignore copy
def build_delay_pattern_mask(self, input_ids: torch.LongTensor, pad_token_id: int, max_length: int = None):
def build_delay_pattern_mask(
self, input_ids: torch.LongTensor, bos_token_id: int, pad_token_id: int, max_length: int = None
):
"""Build a delayed pattern mask to the input_ids. Each codebook is offset by the previous codebook by
one, giving a delayed pattern mask at the start of sequence and end of sequence. Take the example where there
are 4 codebooks and a max sequence length of 8, we have the delayed pattern mask of shape `(codebooks,
seq_len)`:
- [P, -1, -1, -1, -1, P, P, P]
- [P, P, -1, -1, -1, -1, P, P]
- [P, P, P, -1, -1, -1, -1, P]
- [P, P, P, P, -1, -1, -1, -1]
- [B, -1, -1, -1, -1, P, P, P]
- [B, B, -1, -1, -1, -1, P, P]
- [B, B, B, -1, -1, -1, -1, P]
- [B, B, B, B, -1, -1, -1, -1]
where P is the special padding token id and -1 indicates that the token is valid for prediction. If we include
a prompt (decoder input ids), the -1 positions indicate where new tokens should be predicted. Otherwise, the
mask is set to the value in the prompt:
- [P, a, b, -1, -1, P, P, P]
- [P, P, c, d, -1, -1, P, P]
- [P, P, P, e, f, -1, -1, P]
- [P, P, P, P, g, h, -1, -1]
- [B, a, b, -1, -1, P, P, P]
- [B, B, c, d, -1, -1, P, P]
- [B, B, B, e, f, -1, -1, P]
- [B, B, B, B, g, h, -1, -1]
where a-h indicate the input prompt (decoder input ids) that are offset by 1. Now, we only override the -1
tokens in our prediction.
"""
# (bsz * num_codebooks, seq_len) -> (bsz, num_codebooks, seq_len)
input_ids = input_ids.reshape(-1, self.num_codebooks, input_ids.shape[-1])
bsz, num_codebooks, seq_len = input_ids.shape
max_length = max_length if max_length is not None else self.generation_config.max_length
input_ids_shifted = (
torch.ones((bsz, num_codebooks, max_length), dtype=torch.long, device=input_ids.device) * -1
)
# we only apply the mask if we have a large enough seq len - otherwise we return as is
if max_length < 2 * num_codebooks - 1:
return input_ids.reshape(bsz * num_codebooks, -1), input_ids_shifted.reshape(bsz * num_codebooks, -1)
# fill the shifted ids with the prompt entries, offset by the codebook idx
for codebook in range(num_codebooks):
# mono channel - loop over the codebooks one-by-one
input_ids_shifted[:, codebook, codebook : seq_len + codebook] = input_ids[:, codebook]
# construct a pattern mask that indicates the positions of padding tokens for each codebook
# first fill the upper triangular part (the EOS padding)
delay_pattern = torch.triu(
torch.ones((num_codebooks, max_length), dtype=torch.bool), diagonal=max_length - num_codebooks + 1
)
# then fill the lower triangular part (the BOS padding)
delay_pattern = delay_pattern + torch.tril(torch.ones((num_codebooks, max_length), dtype=torch.bool))
mask = ~delay_pattern.to(input_ids.device)
input_ids = mask * input_ids_shifted + ~mask * pad_token_id
# find the first position to start generating - this is the first place we have the -1 token
# and will always be in the first codebook (since it has no codebook offset)
first_codebook_ids = input_ids[:, 0, :]
start_ids = (first_codebook_ids == -1).nonzero()[:, 1]
if len(start_ids) > 0:
first_start_id = min(start_ids)
else:
# we have no tokens that need to be filled - return entire matrix of input ids
first_start_id = seq_len
# (bsz * num_codebooks, seq_len) -> (bsz, num_codebooks, seq_len)
pattern_mask = input_ids.reshape(bsz * num_codebooks, -1)
input_ids = input_ids[..., :first_start_id].reshape(bsz * num_codebooks, -1)
return input_ids, pattern_mask
return build_delay_pattern_mask(input_ids, bos_token_id, pad_token_id, max_length, self.num_codebooks)
@staticmethod
def apply_delay_pattern_mask(input_ids, decoder_pad_token_mask):
"""Apply a delay pattern mask to the decoder input ids, only preserving predictions where
the mask is set to -1, and otherwise setting to the value detailed in the mask."""
seq_len = input_ids.shape[-1]
decoder_pad_token_mask = decoder_pad_token_mask[..., :seq_len]
input_ids = torch.where(decoder_pad_token_mask == -1, input_ids, decoder_pad_token_mask)
return input_ids
return apply_delay_pattern_mask(input_ids, decoder_pad_token_mask)
@torch.no_grad()
def generate(
......@@ -1140,7 +1290,6 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
**kwargs,
):
"""
Generates sequences of token ids for models with a language modeling head.
<Tip warning={true}>
......@@ -1279,10 +1428,11 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
)
# 6. Prepare `input_ids` which will be used for auto-regressive generation
# Build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Stable Speech)
# Build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Parler-TTS)
input_ids, delay_pattern_mask = self.build_delay_pattern_mask(
input_ids,
pad_token_id=generation_config.decoder_start_token_id,
bos_token_id=generation_config.bos_token_id,
pad_token_id=generation_config.pad_token_id,
max_length=generation_config.max_length,
)
......@@ -1380,15 +1530,21 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
output_ids = outputs.sequences
else:
output_ids = outputs
# apply the pattern mask to the final ids
output_ids = self.apply_delay_pattern_mask(output_ids, model_kwargs["delay_pattern_mask"])
# revert the pattern delay mask by filtering the pad token id
output_ids = output_ids[output_ids != generation_config.pad_token_id].reshape(
batch_size, self.num_codebooks, -1
# revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask
_, mask = self.build_delay_pattern_mask(
input_ids,
bos_token_id=generation_config.bos_token_id,
pad_token_id=generation_config.pad_token_id,
max_length=output_ids.shape[1],
)
mask = (mask != generation_config.bos_token_id) & (mask != generation_config.pad_token_id)
output_ids = output_ids[mask].reshape(batch_size, self.num_codebooks, -1)
if generation_config.return_dict_in_generate:
outputs.sequences = output_ids
return outputs
......@@ -1397,31 +1553,29 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
@add_start_docstrings(
"The composite Stable Speech model with a text encoder, audio encoder and StableSpeech decoder, "
"The composite Parler-TTS model with a text encoder, audio encoder and ParlerTTS decoder, "
"for music generation tasks with one or both of text and audio prompts.",
MUSICGEN_START_DOCSTRING,
)
class StableSpeechForConditionalGeneration(PreTrainedModel):
config_class = StableSpeechConfig
class ParlerTTSForConditionalGeneration(PreTrainedModel):
config_class = ParlerTTSConfig
base_model_prefix = "encoder_decoder"
main_input_name = "input_ids"
supports_gradient_checkpointing = True
def __init__(
self,
config: Optional[StableSpeechConfig] = None,
config: Optional[ParlerTTSConfig] = None,
text_encoder: Optional[PreTrainedModel] = None,
audio_encoder: Optional[PreTrainedModel] = None,
decoder: Optional[StableSpeechForCausalLM] = None,
decoder: Optional[ParlerTTSForCausalLM] = None,
):
if config is None and (text_encoder is None or audio_encoder is None or decoder is None):
raise ValueError(
"Either a configuration has to be provided, or all three of text encoder, audio encoder and Stable Speech decoder."
"Either a configuration has to be provided, or all three of text encoder, audio encoder and Parler-TTS decoder."
)
if config is None:
config = StableSpeechConfig.from_sub_models_config(
text_encoder.config, audio_encoder.config, decoder.config
)
config = ParlerTTSConfig.from_sub_models_config(text_encoder.config, audio_encoder.config, decoder.config)
else:
if not isinstance(config, self.config_class):
raise ValueError(f"Config: {config} has to be of type {self.config_class}")
......@@ -1429,7 +1583,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
if config.decoder.cross_attention_hidden_size is not None:
if config.decoder.cross_attention_hidden_size != config.text_encoder.hidden_size:
raise ValueError(
"If `cross_attention_hidden_size` is specified in the Stable Speech decoder's configuration, it has to be equal"
"If `cross_attention_hidden_size` is specified in the Parler-TTS decoder's configuration, it has to be equal"
f" to the text encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for"
f" `config.decoder.cross_attention_hidden_size` and {config.text_encoder.hidden_size} for"
" `config.text_encoder.hidden_size`."
......@@ -1449,7 +1603,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
audio_encoder = AutoModel.from_config(config.audio_encoder)
if decoder is None:
decoder = StableSpeechForCausalLM(config.decoder)
decoder = ParlerTTSForCausalLM(config.decoder)
self.text_encoder = text_encoder
self.audio_encoder = audio_encoder
......@@ -1484,6 +1638,9 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
):
self.enc_to_dec_proj = nn.Linear(self.text_encoder.config.hidden_size, self.decoder.config.hidden_size)
# prompt embeddings
self.embed_prompts = nn.Embedding(config.vocab_size, self.decoder.config.hidden_size)
if self.text_encoder.get_output_embeddings() is not None:
raise ValueError(
f"The encoder {self.text_encoder} should not have a LM Head. Please use a model without and LM Head"
......@@ -1496,8 +1653,19 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
"following discussion on GitHub: https://github.com/huggingface/transformers/issues/23350"
)
# tie text encoder, decoder weights if config set accordingly
self.tie_weights()
# Initialize projection and embedding layers and tie text encoder and decoder weights if set accordingly
self.post_init()
def _init_weights(self, module):
std = self.decoder.config.initializer_factor
if isinstance(module, (nn.Linear, nn.Conv1d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def tie_weights(self):
# tie text encoder & decoder if needed
......@@ -1536,15 +1704,15 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
Example:
```python
>>> from transformers import StableSpeechForConditionalGeneration
>>> from parler_tts import ParlerTTSForConditionalGeneration
>>> model = StableSpeechForConditionalGeneration.from_pretrained("facebook/stable_speech-small")
>>> model = ParlerTTSForConditionalGeneration.from_pretrained("facebook/parler_tts-small")
```"""
# At the moment fast initialization is not supported for composite models
if kwargs.get("_fast_init", False):
logger.warning(
"Fast initialization is currently not supported for StableSpeechForConditionalGeneration. "
"Fast initialization is currently not supported for ParlerTTSForConditionalGeneration. "
"Falling back to slow initialization..."
)
kwargs["_fast_init"] = False
......@@ -1561,7 +1729,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
**kwargs,
) -> PreTrainedModel:
r"""
Instantiate a text encoder, an audio encoder, and a Stable Speech decoder from one, two or three base classes of the
Instantiate a text encoder, an audio encoder, and a Parler-TTS decoder from one, two or three base classes of the
library from pretrained model checkpoints.
......@@ -1592,7 +1760,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
Valid model ids can be located at the root-level, like `gpt2`, or namespaced under a user or
organization name, like `facebook/stable_speech-small`.
organization name, like `facebook/parler_tts-small`.
- A path to a *directory* containing model weights saved using
[`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
......@@ -1615,18 +1783,18 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
Example:
```python
>>> from transformers import StableSpeechForConditionalGeneration
>>> from parler_tts import ParlerTTSForConditionalGeneration
>>> # initialize a stable_speech model from a t5 text encoder, encodec audio encoder, and stable_speech decoder
>>> model = StableSpeechForConditionalGeneration.from_sub_models_pretrained(
>>> # initialize a parler_tts model from a t5 text encoder, encodec audio encoder, and parler_tts decoder
>>> model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained(
... text_encoder_pretrained_model_name_or_path="t5-base",
... audio_encoder_pretrained_model_name_or_path="facebook/encodec_24khz",
... decoder_pretrained_model_name_or_path="facebook/stable_speech-small",
... decoder_pretrained_model_name_or_path="facebook/parler_tts-small",
... )
>>> # saving model after fine-tuning
>>> model.save_pretrained("./stable_speech-ft")
>>> model.save_pretrained("./parler_tts-ft")
>>> # load fine-tuned model
>>> model = StableSpeechForConditionalGeneration.from_pretrained("./stable_speech-ft")
>>> model = ParlerTTSForConditionalGeneration.from_pretrained("./parler_tts-ft")
```"""
kwargs_text_encoder = {
......@@ -1679,7 +1847,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
kwargs_text_encoder["config"] = encoder_config
text_encoder = AutoModel.from_pretrained(
text_encoder = AutoModelForTextEncoding.from_pretrained(
text_encoder_pretrained_model_name_or_path, *model_args, **kwargs_text_encoder
)
......@@ -1719,11 +1887,11 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
)
if "config" not in kwargs_decoder:
decoder_config, kwargs_decoder = AutoConfig.from_pretrained(
decoder_config, kwargs_decoder = ParlerTTSDecoderConfig.from_pretrained(
decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
)
if isinstance(decoder_config, StableSpeechConfig):
if isinstance(decoder_config, ParlerTTSConfig):
decoder_config = decoder_config.decoder
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
......@@ -1746,10 +1914,10 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
"`decoder_config` to `.from_sub_models_pretrained(...)`"
)
decoder = StableSpeechForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
decoder = ParlerTTSForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
# instantiate config with corresponding kwargs
config = StableSpeechConfig.from_sub_models_config(
config = ParlerTTSConfig.from_sub_models_config(
text_encoder.config, audio_encoder.config, decoder.config, **kwargs
)
return cls(text_encoder=text_encoder, audio_encoder=audio_encoder, decoder=decoder, config=config)
......@@ -1768,6 +1936,9 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
past_key_values: Tuple[Tuple[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
prompt_input_ids: Optional[torch.FloatTensor] = None,
prompt_attention_mask: Optional[torch.LongTensor] = None,
prompt_hidden_states: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
......@@ -1780,11 +1951,11 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
Examples:
```python
>>> from transformers import AutoProcessor, StableSpeechForConditionalGeneration
>>> from transformers import AutoProcessor, ParlerTTSForConditionalGeneration
>>> import torch
>>> processor = AutoProcessor.from_pretrained("facebook/stable_speech-small")
>>> model = StableSpeechForConditionalGeneration.from_pretrained("facebook/stable_speech-small")
>>> processor = AutoProcessor.from_pretrained("facebook/parler_tts-small")
>>> model = ParlerTTSForConditionalGeneration.from_pretrained("facebook/parler_tts-small")
>>> inputs = processor(
... text=["80s pop track with bassy drums and synth", "90s rock song with loud guitars and heavy drums"],
......@@ -1845,10 +2016,14 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
if attention_mask is not None:
encoder_hidden_states = encoder_hidden_states * attention_mask[..., None]
if prompt_hidden_states is None:
if prompt_input_ids is not None:
prompt_hidden_states = self.embed_prompts(prompt_input_ids)
if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
)
).transpose(1, 2)
elif decoder_input_ids is None and decoder_inputs_embeds is None:
audio_encoder_outputs = self.audio_encoder(
......@@ -1876,29 +2051,23 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
attention_mask=decoder_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=attention_mask,
prompt_hidden_states=prompt_hidden_states,
prompt_attention_mask=prompt_attention_mask,
inputs_embeds=decoder_inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
use_cache=use_cache,
past_key_values=past_key_values,
return_dict=return_dict,
labels=labels,
**kwargs_decoder,
)
loss = None
if labels is not None:
logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
if not return_dict:
if loss is not None:
return (loss,) + decoder_outputs + encoder_outputs
else:
return decoder_outputs + encoder_outputs
return decoder_outputs + (encoder_hidden_states,)
return Seq2SeqLMOutput(
loss=loss,
loss=decoder_outputs.loss,
logits=decoder_outputs.logits,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
......@@ -1917,6 +2086,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
head_mask=None,
decoder_attention_mask=None,
decoder_head_mask=None,
prompt_hidden_states=None,
prompt_attention_mask=None,
cross_attn_head_mask=None,
use_cache=None,
encoder_outputs=None,
......@@ -1927,7 +2098,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
if decoder_delay_pattern_mask is None:
decoder_input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
decoder_input_ids,
self.generation_config.pad_token_id,
bos_token_id=self.generation_config.bos_token_id,
pad_token_id=self.generation_config.pad_token_id,
max_length=self.generation_config.max_length,
)
......@@ -1940,6 +2112,10 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
decoder_input_ids = decoder_input_ids.repeat((2, 1))
if decoder_attention_mask is not None:
decoder_attention_mask = decoder_attention_mask.repeat((2, 1))
if prompt_hidden_states is not None:
prompt_hidden_states = prompt_hidden_states.repeat((2, 1, 1))
if prompt_attention_mask is not None:
prompt_attention_mask = prompt_attention_mask.repeat((2, 1))
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
......@@ -1953,6 +2129,9 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
# we only want to use prompt signal in the 1st generation step but keeping the attention mask
prompt_hidden_states = None
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
......@@ -1963,6 +2142,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"prompt_hidden_states": prompt_hidden_states,
"prompt_attention_mask": prompt_attention_mask,
"use_cache": use_cache,
}
......@@ -2059,6 +2240,10 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
return model_kwargs
def _prepare_prompt_kwargs_for_generation(self, prompt_input_ids, model_kwargs):
model_kwargs["prompt_hidden_states"] = self.embed_prompts(prompt_input_ids)
return model_kwargs
def _prepare_audio_encoder_kwargs_for_generation(
self, input_values, model_kwargs, model_input_name: Optional[str] = None
):
......@@ -2107,7 +2292,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
return model_kwargs
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id).transpose(1, 2)
def resize_token_embeddings(self, *args, **kwargs):
raise NotImplementedError(
......@@ -2144,6 +2329,16 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
break
return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id
def freeze_encoders(self, freeze_text_encoder=True):
if freeze_text_encoder:
for param in self.text_encoder.parameters():
param.requires_grad = False
self.text_encoder._requires_grad = False
for param in self.audio_encoder.parameters():
param.requires_grad = False
self.audio_encoder._requires_grad = False
@torch.no_grad()
def generate(
self,
......@@ -2278,6 +2473,13 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
guidance_scale=generation_config.guidance_scale,
)
if "prompt_hidden_states" not in model_kwargs and "prompt_input_ids" in model_kwargs:
# `prompt_hidden_states` are created and added to `model_kwargs`
model_kwargs = self._prepare_prompt_kwargs_for_generation(
model_kwargs["prompt_input_ids"],
model_kwargs,
)
if "decoder_input_ids" not in model_kwargs and "input_values" in model_kwargs:
model_kwargs = self._prepare_audio_encoder_kwargs_for_generation(
model_kwargs["input_values"],
......@@ -2324,10 +2526,11 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
" increasing `max_new_tokens`."
)
# build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Stable Speech)
# build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Parler-TTS)
input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
input_ids,
pad_token_id=generation_config.decoder_start_token_id,
bos_token_id=generation_config.bos_token_id,
pad_token_id=generation_config.pad_token_id,
max_length=generation_config.max_length,
)
# stash the delay mask so that we don't have to recompute in each forward pass
......@@ -2430,11 +2633,17 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
# apply the pattern mask to the final ids
output_ids = self.decoder.apply_delay_pattern_mask(output_ids, model_kwargs["decoder_delay_pattern_mask"])
# revert the pattern delay mask by filtering the pad token id
output_ids = output_ids[output_ids != generation_config.pad_token_id].reshape(
batch_size, self.decoder.num_codebooks, -1
# revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask
_, mask = self.decoder.build_delay_pattern_mask(
input_ids,
bos_token_id=generation_config.bos_token_id,
pad_token_id=generation_config.pad_token_id,
max_length=output_ids.shape[1],
)
mask = (mask != generation_config.bos_token_id) & (mask != generation_config.pad_token_id)
output_ids = output_ids[mask].reshape(batch_size, self.decoder.num_codebooks, -1)
# append the frame dimension back to the audio codes
output_ids = output_ids[None, ...]
......@@ -2442,47 +2651,36 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
if audio_scales is None:
audio_scales = [None] * batch_size
output_values = self.audio_encoder.decode(
output_ids,
audio_scales=audio_scales,
).audio_values
decode_sequentially = (
generation_config.bos_token_id in output_ids
or generation_config.pad_token_id in output_ids
or generation_config.eos_token_id in output_ids
)
if not decode_sequentially:
output_values = self.audio_encoder.decode(
output_ids,
audio_scales=audio_scales,
).audio_values.squeeze(1)
else:
output_values = []
for sample_id in range(batch_size):
sample = output_ids[:, sample_id]
sample_mask = (sample >= self.audio_encoder.config.codebook_size).sum(dim=(0, 1)) == 0
if sample_mask.sum() > 0:
sample = sample[:, :, sample_mask]
sample = self.audio_encoder.decode(sample[None, ...], [audio_scales[sample_id]]).audio_values
output_values.append(sample.transpose(0, 2))
else:
output_values.append(torch.zeros((1, 1, 1)).to(self.device))
# TODO: we should keep track of output length as well. Not really straightfoward tbh
output_values = (
torch.nn.utils.rnn.pad_sequence(output_values, batch_first=True, padding_value=0)
.squeeze(-1)
.squeeze(-1)
)
if generation_config.return_dict_in_generate:
outputs.sequences = output_values
return outputs
else:
return output_values
def get_unconditional_inputs(self, num_samples=1):
"""
Helper function to get null inputs for unconditional generation, enabling the model to be used without the
feature extractor or tokenizer.
Args:
num_samples (int, *optional*):
Number of audio samples to unconditionally generate.
max_new_tokens (int, *optional*):
Number of tokens to generate for each sample. More tokens means longer audio samples, at the expense of
longer inference (since more audio tokens need to be generated per sample).
Example:
```python
>>> from transformers import StableSpeechForConditionalGeneration
>>> model = StableSpeechForConditionalGeneration.from_pretrained("facebook/stable_speech-small")
>>> # get the unconditional (or 'null') inputs for the model
>>> unconditional_inputs = model.get_unconditional_inputs(num_samples=1)
>>> audio_samples = model.generate(**unconditional_inputs, max_new_tokens=256)
```"""
last_hidden_state = torch.zeros(
(num_samples, 1, self.config.text_encoder.hidden_size), device=self.device, dtype=self.dtype
)
attention_mask = torch.zeros((num_samples, 1), device=self.device, dtype=torch.long)
return StableSpeechUnconditionalInput(
encoder_outputs=(last_hidden_state,),
attention_mask=attention_mask,
guidance_scale=1.0,
)
return output_values
\ No newline at end of file
#!/usr/bin/env bash
python run_prompt_creation.py \
--dataset_name "ylacombe/libritts_r_tags_and_text" \
--dataset_config_name "clean" \
--dataset_split_name "dev.clean" \
--model_name_or_path "mistralai/Mistral-7B-Instruct-v0.2" \
--per_device_eval_batch_size 2 \
--attn_implementation "sdpa" \
--dataloader_num_workers 0 \
--output_dir "./" \
--load_in_4bit
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import re
import sys
from collections import Counter
from dataclasses import dataclass, field
from random import randint
from typing import List, Optional, Union
import datasets
import evaluate
import numpy as np
import transformers
from datasets import Dataset, DatasetDict, IterableDataset, concatenate_datasets, interleave_datasets, load_dataset
from tqdm import tqdm
from transformers import (
AutoConfig,
AutoFeatureExtractor,
AutoModelForAudioClassification,
HfArgumentParser,
Trainer,
TrainingArguments,
set_seed,
)
from transformers.models.whisper.tokenization_whisper import LANGUAGES
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.38.0.dev0")
def random_subsample(wav: np.ndarray, max_length: float, sample_rate: int = 16000) -> np.ndarray:
"""Randomly sample chunks of `max_length` seconds from the input audio"""
sample_length = int(round(sample_rate * max_length))
if len(wav) <= sample_length:
return wav
random_offset = randint(0, len(wav) - sample_length - 1)
return wav[random_offset : random_offset + sample_length]
def deterministic_subsample(wav: np.ndarray, max_length: float, sample_rate: int = 16000) -> np.ndarray:
"""Take first `max_length` seconds from the input audio"""
sample_length = int(round(sample_rate * max_length))
if len(wav) <= sample_length:
return wav
return wav[0:sample_length]
# This list first defines the accent prefixes, which we use to strip the accent from CV
# e.g. England, southern accent, slight west-country expression -> England
# TODO(YL): update this with any CV test prefixes not present in the train set
STARTS_WITH = [
"Afrikaans",
"American",
"Australian",
"Bangladeshi",
"Canadian",
"Chinese",
"Dutch",
"Eastern European",
"European",
"England",
"English",
"German",
"Filipino",
"India",
"Irish" "Israeli",
"Italian",
"Japanese",
"Kenyan",
"Northern Irish",
"New Zealand",
"Nigerian",
"Malaysian",
"Russian",
"Scottish",
"Singaporean",
"Slavic",
"South African",
"Southern African",
"Swedish",
"Swiss",
"United States English",
"West Indies",
"french",
"polish",
"serbian",
]
# This dictionary is used to map the un-normalised accent names to normalised ones
# TODO(YL): update this with any CV test mappings not present in the train set
ACCENT_MAPPING = {
"British": "English",
# "Canadian": "American", TODO(SG): decide whether to normalize these to closely related accents
# "New zealand": "Australian",
"Northern irish": "Irish",
"Pakistani": "Indian",
"Mainstream u s english": "American",
"Southern british english": "English",
"Indian english": "Indian",
"Scottish english": "Scottish",
"Don't know": "Unknown",
"Nigerian english": "Nigerian",
"Kenyan english": "Kenyan",
"Ghanain english": "Ghanain",
"Jamaican english": "Jamaican",
"Indonesian english": "Indonesian",
"South african english": "South african",
"Irish english": "Irish",
"Latin": "Latin american",
"European": "Unknown", # Too general
"Eastern european": "Eastern european", # TODO(SG): keep for now, but maybe remove later as too general
"Bangladeshi": "Indian",
"England": "English",
"India": "Indian",
"Afrikaans": "South african",
"California": "American",
"Nepali": "Indian",
"New york city": "American",
"New jerseyan": "American",
"Northumbrian british english": "English",
"Nottinghamshire,east midlands": "English",
"Southern african": "South african",
"United states english": "American",
"West indies": "Jamaican",
"2nd language": "Unknown", # Too vague
"A savage texas gentleman": "American",
"A variety of texan english with some german influence that has undergone the cot-caught merger": "American",
"A'lo": "Unknown", # Unclear
"Academic southern english,england english": "English",
"Argentinian english": "Latin american",
"Austrian": "German",
"Bangladesh,india and south asia (india, pakistan, sri lanka)": "Indian",
"Brazillian accent": "Brazilian",
"British accent": "English",
"Caribbean canadian": "Unknown", # Specific combination not listed
"Colombian accent": "Latin american",
"Czech accent": "Czech",
"East african khoja": "Unknown", # Specific community
"East indian": "Indian",
"East london": "English",
"England,london,academic": "English",
"Filipino": "Unknown", # Unique blend
"Fluent,e sl,european": "Unknown", # Too vague
"Generic european": "Unknown", # Too vague
"Georgian english": "Unknown", # No direct match
"Ghanaian english accent,african regular reader": "Unknown", # Specific category not listed
"Haitian creole": "Unknown", # Unique blend
"Hispanic": "Latin american",
"Hispanic/latino": "Latin american",
"Hong kong english": "Chinese",
"Hong kong english,scottish english": "Chinese",
"Hunglish": "Hungarian",
"I think mine accent is influenced by indian accent ,yes please. ,india and south asia (india, pakistan, sri lanka)": "Indian",
"I was born in england and have lived in australia, canada and france.": "English",
"International english,united states english,australian english": "American",
"Israeli": "Unknown", # No direct match
"Israeli english": "Unknown", # No direct match
"Javanese,indonesian english,malaysian english": "Indonesian",
"Kazakhstan english": "Unknown", # No direct match
"Kiwi": "New zealand", # Could be generalised to Australian
"Latin america,united states english": "Latin american",
"Latin american accent": "Latin american",
"Latin english": "Unknown", # Too vague
"Latino": "Latin american",
"Latvian": "Latvian", # Note: added new
"Little latino,united states english,second language": "Latin american",
"Liverpool english,lancashire english,england english": "English",
"Liverpudlian english": "English",
"Malaysian english": "Malaysian", # Note: added new
"Mexican accent": "Latin american",
"Mid-atlantic united states english,philadelphia, pennsylvania, united states english,united states english,philadelphia style united states english": "American",
"Mid-atlantic,england english,united states english": "American",
"Midatlantic,england english": "American",
"Midwestern states (michigan),united states english": "American",
"Mild northern england english": "English",
"Minor french accent": "French",
"Mix of american and british ,native polish": "Polish",
"Mix of american and british accent": "Unknown", # Combination not clearly mapped
"Mostly american with some british and australian inflections": "Unknown", # Combination not clearly mapped
"My accent is influenced by the phones of all letters within a sentence.,southern african (south africa, zimbabwe, namibia)": "South african",
"New zealand english": "New Zealand English",
"Nigeria english": "Nigerian", # Note: added new
"Non native speaker from france": "French",
"Non-native": "Unknown", # Too vague
"Non-native,german accent": "German",
"North european english": "Unknown", # Too broad
"Norwegian": "Norwegian", # Note: added new
"Ontario,canadian english": "Canadian", # Note: added new
"Polish english": "Polish",
"Rhode island new england accent": "American",
"Singaporean english": "Singaporean", # Note: added new
"Slavic": "Eastern european",
"Slighty southern affected by decades in the midwest, 4 years in spain and germany, speak some german, spanish, polish. have lived in nine states.": "Unknown", # Complex blend
"South african": "South african",
"South atlantic (falkland islands, saint helena)": "Unknown", # Specific regions not listed
"South australia": "Australian",
"South indian": "Indian",
"Southern drawl": "American",
"Southern texas accent,united states english": "American",
"Southern united states,united states english": "American",
"Spanish bilingual": "Spanish",
"Spanish,foreign,non-native": "Spanish",
"Strong latvian accent": "Latvian",
"Swedish accent": "Swedish", # Note: added new
"Transnational englishes blend": "Unknown", # Too vague
"U.k. english": "English",
"Very slight russian accent,standard american english,boston influence": "American",
"Welsh english": "Welsh",
"West african": "Unknown", # No specific West African category
"West indian": "Unknown", # Caribbean, but no specific match
"Western europe": "Unknown", # Too broad
"With heavy cantonese accent": "Chinese",
}
def preprocess_labels(label: str) -> str:
"""Apply pre-processing formatting to the accent labels"""
if "_" in label:
# voxpopuli stylises the accent as a language code (e.g. en_pl for "polish") - convert to full accent
language_code = label.split("_")[-1]
label = LANGUAGES[language_code]
# VCTK labels for two words are concatenated into one (NewZeleand-> New Zealand)
label = re.sub(r"(\w)([A-Z])", r"\1 \2", label).strip()
for prefix in STARTS_WITH:
if label.startswith(prefix):
label = prefix
# convert Whisper language code (polish) to capitalised (Polish)
label = label.capitalize()
if label in ACCENT_MAPPING:
label = ACCENT_MAPPING[label]
return label
@dataclass
class DataTrainingArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
Using `HfArgumentParser` we can turn this class
into argparse arguments to be able to specify them on
the command line.
"""
train_dataset_name: str = field(
default=None,
metadata={
"help": "The name of the training dataset to use (via the datasets library). Load and combine "
"multiple datasets by separating dataset ids by a '+' symbol. For example, to load and combine "
" librispeech and common voice, set `train_dataset_name='librispeech_asr+common_voice'`."
},
)
train_dataset_config_name: Optional[str] = field(
default=None,
metadata={
"help": "The configuration name of the training dataset to use (via the datasets library). Load and combine "
"multiple datasets by separating dataset configs by a '+' symbol."
},
)
train_split_name: str = field(
default="train",
metadata={
"help": ("The name of the training data set split to use (via the datasets library). Defaults to 'train'")
},
)
train_dataset_samples: str = field(
default=None,
metadata={
"help": "Number of samples in the training data. Load and combine "
"multiple datasets by separating dataset samples by a '+' symbol."
},
)
eval_dataset_name: str = field(
default=None,
metadata={
"help": "The name of the evaluation dataset to use (via the datasets library). Defaults to the training dataset name if unspecified."
},
)
eval_dataset_config_name: Optional[str] = field(
default=None,
metadata={
"help": "The configuration name of the evaluation dataset to use (via the datasets library). Defaults to the training dataset config name if unspecified"
},
)
eval_split_name: str = field(
default="validation",
metadata={
"help": (
"The name of the evaluation data set split to use (via the datasets"
" library). Defaults to 'validation'"
)
},
)
audio_column_name: str = field(
default="audio",
metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
)
train_label_column_name: str = field(
default="labels",
metadata={
"help": "The name of the dataset column containing the labels in the train set. Defaults to 'label'"
},
)
eval_label_column_name: str = field(
default="labels",
metadata={"help": "The name of the dataset column containing the labels in the eval set. Defaults to 'label'"},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
"help": (
"For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
)
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
"help": (
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
"value if set."
)
},
)
max_length_seconds: Optional[float] = field(
default=20,
metadata={"help": "Audio samples will be randomly cut to this length during training if the value is set."},
)
min_length_seconds: Optional[float] = field(
default=5,
metadata={"help": "Audio samples less than this value will be filtered during training if the value is set."},
)
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the preprocessing."},
)
filter_threshold: Optional[float] = field(
default=1.0,
metadata={"help": "Filter labels that occur less than `filter_threshold` percent in the training/eval data."},
)
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
"""
model_name_or_path: str = field(
default="facebook/wav2vec2-base",
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"},
)
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
cache_dir: Optional[str] = field(
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from the Hub"}
)
model_revision: str = field(
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
feature_extractor_name: Optional[str] = field(
default=None, metadata={"help": "Name or path of preprocessor config."}
)
freeze_feature_encoder: bool = field(
default=False,
metadata={
"help": "Whether to freeze the feature encoder layers of the model. Only relevant for Wav2Vec2-style models."
},
)
freeze_base_model: bool = field(
default=True, metadata={"help": "Whether to freeze the base encoder of the model."}
)
attention_mask: bool = field(
default=True, metadata={"help": "Whether to generate an attention mask in the feature extractor."}
)
token: str = field(
default=None,
metadata={
"help": (
"The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
"generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
)
},
)
trust_remote_code: bool = field(
default=False,
metadata={
"help": (
"Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
"should only be set to `True` for repositories you trust and in which you have read the code, as it will "
"execute code present on the Hub on your local machine."
)
},
)
ignore_mismatched_sizes: bool = field(
default=True,
metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."},
)
attention_dropout: float = field(
default=0.0, metadata={"help": "The dropout ratio for the attention probabilities."}
)
activation_dropout: float = field(
default=0.0, metadata={"help": "The dropout ratio for activations inside the fully connected layer."}
)
feat_proj_dropout: float = field(default=0.0, metadata={"help": "The dropout ratio for the projected features."})
hidden_dropout: float = field(
default=0.0,
metadata={
"help": "The dropout probability for all fully connected layers in the embeddings, encoder, and pooler."
},
)
final_dropout: float = field(
default=0.0,
metadata={"help": "The dropout probability for the final projection layer."},
)
mask_time_prob: float = field(
default=0.05,
metadata={
"help": (
"Probability of each feature vector along the time axis to be chosen as the start of the vector "
"span to be masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature "
"vectors will be masked along the time axis."
)
},
)
mask_time_length: int = field(
default=10,
metadata={"help": "Length of vector span to mask along the time axis."},
)
mask_feature_prob: float = field(
default=0.0,
metadata={
"help": (
"Probability of each feature vector along the feature axis to be chosen as the start of the vectorspan"
" to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature"
" bins will be masked along the time axis."
)
},
)
mask_feature_length: int = field(
default=10,
metadata={"help": "Length of vector span to mask along the feature axis."},
)
layerdrop: float = field(default=0.0, metadata={"help": "The LayerDrop probability."})
def convert_dataset_str_to_list(
dataset_names,
dataset_config_names,
splits=None,
label_column_names=None,
dataset_samples=None,
default_split="train",
):
if isinstance(dataset_names, str):
dataset_names = dataset_names.split("+")
dataset_config_names = dataset_config_names.split("+")
splits = splits.split("+") if splits is not None else None
label_column_names = label_column_names.split("+") if label_column_names is not None else None
dataset_samples = dataset_samples.split("+") if dataset_samples is not None else None
# basic checks to ensure we've got the right number of datasets/configs/splits/columns/probs
if len(dataset_names) != len(dataset_config_names):
raise ValueError(
f"Ensure one config is passed for each dataset, got {len(dataset_names)} datasets and"
f" {len(dataset_config_names)} configs."
)
if splits is not None and len(splits) != len(dataset_names):
raise ValueError(
f"Ensure one split is passed for each dataset, got {len(dataset_names)} datasets and {len(splits)} splits."
)
if label_column_names is not None and len(label_column_names) != len(dataset_names):
raise ValueError(
f"Ensure one label column name is passed for each dataset, got {len(dataset_names)} datasets and"
f" {len(label_column_names)} label column names."
)
if dataset_samples is not None:
if len(dataset_samples) != len(dataset_names):
raise ValueError(
f"Ensure one sample is passed for each dataset, got {len(dataset_names)} datasets and "
f"{len(dataset_samples)} samples."
)
dataset_samples = [float(ds_sample) for ds_sample in dataset_samples]
else:
dataset_samples = [None] * len(dataset_names)
label_column_names = (
label_column_names if label_column_names is not None else ["labels" for _ in range(len(dataset_names))]
)
splits = splits if splits is not None else [default_split for _ in range(len(dataset_names))]
dataset_names_dict = []
for i, ds_name in enumerate(dataset_names):
dataset_names_dict.append(
{
"name": ds_name,
"config": dataset_config_names[i],
"split": splits[i],
"label_column_name": label_column_names[i],
"samples": dataset_samples[i],
}
)
return dataset_names_dict
def load_multiple_datasets(
dataset_names: Union[List, str],
dataset_config_names: Union[List, str],
splits: Optional[Union[List, str]] = None,
label_column_names: Optional[List] = None,
sampling_rate: Optional[int] = 16000,
stopping_strategy: Optional[str] = "first_exhausted",
dataset_samples: Optional[Union[List, np.array]] = None,
streaming: Optional[bool] = False,
seed: Optional[int] = None,
audio_column_name: Optional[str] = "audio",
**kwargs,
) -> Union[Dataset, IterableDataset]:
dataset_names_dict = convert_dataset_str_to_list(
dataset_names, dataset_config_names, splits, label_column_names, dataset_samples
)
if dataset_samples is not None:
dataset_samples = [ds_dict["samples"] for ds_dict in dataset_names_dict]
probabilities = np.array(dataset_samples) / np.sum(dataset_samples)
else:
probabilities = None
all_datasets = []
# iterate over the datasets we want to interleave
for dataset_dict in tqdm(dataset_names_dict, desc="Combining datasets..."):
dataset = load_dataset(
dataset_dict["name"],
dataset_dict["config"],
split=dataset_dict["split"],
streaming=streaming,
**kwargs,
)
dataset_features = dataset.features.keys()
if audio_column_name not in dataset_features:
raise ValueError(
f"Audio column name '{audio_column_name}' not found in dataset"
f" '{dataset_dict['name']}'. Make sure to set `--audio_column_name` to"
f" the correct audio column - one of {', '.join(dataset_features)}."
)
# resample to specified sampling rate
dataset = dataset.cast_column("audio", datasets.features.Audio(sampling_rate))
if dataset_dict["label_column_name"] not in dataset_features:
raise ValueError(
f"Label column name {dataset_dict['label_column_name']} not found in dataset"
f" '{dataset_dict['name']}'. Make sure to set `--label_column_name` to the"
f" correct text column - one of {', '.join(dataset_features)}."
)
# blanket renaming of all label columns to label
if dataset_dict["label_column_name"] != "labels":
dataset = dataset.rename_column(dataset_dict["label_column_name"], "labels")
dataset_features = dataset.features.keys()
columns_to_keep = {"audio", "labels"}
dataset = dataset.remove_columns(set(dataset_features - columns_to_keep))
all_datasets.append(dataset)
if len(all_datasets) == 1:
# we have a single dataset so just return it as is
return all_datasets[0]
if streaming:
interleaved_dataset = interleave_datasets(
all_datasets,
stopping_strategy=stopping_strategy,
probabilities=probabilities,
seed=seed,
)
else:
interleaved_dataset = concatenate_datasets(all_datasets)
return interleaved_dataset
def main():
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns.
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)
if training_args.should_log:
# The default of training_args.log_level is passive, so we set log level at info here to have that default.
transformers.utils.logging.set_verbosity_info()
log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
# Log on each process the small summary:
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
+ f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
)
logger.info(f"Training/evaluation parameters {training_args}")
# Set seed before initializing model.
set_seed(training_args.seed)
# Detecting last checkpoint.
last_checkpoint = None
if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
last_checkpoint = get_last_checkpoint(training_args.output_dir)
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
"Use --overwrite_output_dir to train from scratch."
)
elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
logger.info(
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
)
# Initialize our dataset and prepare it for the audio classification task.
raw_datasets = DatasetDict()
# set seed for determinism
set_seed(training_args.seed)
if training_args.do_train:
raw_datasets["train"] = load_multiple_datasets(
data_args.train_dataset_name,
data_args.train_dataset_config_name,
splits=data_args.train_split_name,
label_column_names=data_args.train_label_column_name,
dataset_samples=data_args.train_dataset_samples,
seed=training_args.seed,
cache_dir=model_args.cache_dir,
token=True if model_args.token else None,
trust_remote_code=model_args.trust_remote_code,
num_proc=data_args.preprocessing_num_workers,
# streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
)
if training_args.do_eval:
dataset_names_dict = convert_dataset_str_to_list(
data_args.eval_dataset_name if data_args.eval_dataset_name else data_args.train_dataset_name,
(
data_args.eval_dataset_config_name
if data_args.eval_dataset_config_name
else data_args.train_dataset_config_name
),
splits=data_args.eval_split_name,
label_column_names=data_args.eval_label_column_name,
)
all_eval_splits = []
# load multiple eval sets
for dataset_dict in dataset_names_dict:
pretty_name = (
f"{dataset_dict['name'].split('/')[-1]}/{dataset_dict['split'].replace('.', '-')}"
if len(dataset_names_dict) > 1
else "eval"
)
all_eval_splits.append(pretty_name)
raw_datasets[pretty_name] = load_dataset(
dataset_dict["name"],
dataset_dict["config"],
split=dataset_dict["split"],
cache_dir=model_args.cache_dir,
token=True if model_args.token else None,
trust_remote_code=model_args.trust_remote_code,
num_proc=data_args.preprocessing_num_workers,
# streaming=data_args.streaming,
)
features = raw_datasets[pretty_name].features.keys()
if dataset_dict["label_column_name"] not in features:
raise ValueError(
f"--label_column_name {data_args.eval_label_column_name} not found in dataset '{data_args.dataset_name}'. "
"Make sure to set `--label_column_name` to the correct text column - one of "
f"{', '.join(raw_datasets['train'].column_names)}."
)
elif dataset_dict["label_column_name"] != "labels":
raw_datasets[pretty_name] = raw_datasets[pretty_name].rename_column(
dataset_dict["label_column_name"], "labels"
)
raw_datasets[pretty_name] = raw_datasets[pretty_name].remove_columns(
set(raw_datasets[pretty_name].features.keys()) - {"audio", "labels"}
)
if not training_args.do_train and not training_args.do_eval:
raise ValueError(
"Cannot not train and not do evaluation. At least one of training or evaluation has to be performed."
)
# Setting `return_attention_mask=True` is the way to get a correctly masked mean-pooling over
# transformer outputs in the classifier, but it doesn't always lead to better accuracy
feature_extractor = AutoFeatureExtractor.from_pretrained(
model_args.feature_extractor_name or model_args.model_name_or_path,
return_attention_mask=model_args.attention_mask,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
token=model_args.token,
trust_remote_code=model_args.trust_remote_code,
)
# `datasets` takes care of automatically loading and resampling the audio,
# so we just need to set the correct target sampling rate.
raw_datasets = raw_datasets.cast_column(
data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
)
if training_args.do_train:
if data_args.max_train_samples is not None:
raw_datasets["train"] = (
raw_datasets["train"].shuffle(seed=training_args.seed).select(range(data_args.max_train_samples))
)
if training_args.do_eval:
if data_args.max_eval_samples is not None:
raw_datasets["eval"] = (
raw_datasets["eval"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
)
sampling_rate = feature_extractor.sampling_rate
model_input_name = feature_extractor.model_input_names[0]
def prepare_dataset(batch):
batch["length"] = len(batch["audio"]["array"])
batch["labels"] = preprocess_labels(batch["labels"])
return batch
raw_datasets = raw_datasets.map(
prepare_dataset,
num_proc=data_args.preprocessing_num_workers,
desc="Computing audio length",
)
# filter training data with inputs < min_input_length
min_input_length = data_args.min_length_seconds * sampling_rate
def is_audio_valid(input_length):
return input_length > min_input_length
raw_datasets = raw_datasets.filter(
is_audio_valid,
input_columns=["length"],
num_proc=data_args.preprocessing_num_workers,
desc="Filtering by audio length",
)
# filter training data with non-valid labels
def is_label_valid(label):
return label != "Unknown"
raw_datasets = raw_datasets.filter(
is_label_valid,
input_columns=["labels"],
num_proc=data_args.preprocessing_num_workers,
desc="Filtering by labels",
)
# Print a summary of the labels to the stddout (helps identify low-label classes that could be filtered)
# sort by freq
count_labels_dict = Counter(raw_datasets["train"]["labels"])
count_labels_dict = sorted(count_labels_dict.items(), key=lambda item: (-item[1], item[0]))
labels, frequencies = zip(*count_labels_dict)
total_labels = sum(frequencies)
labels_to_remove = []
logger.info(f"{'Accent':<15} {'Perc.':<5}")
logger.info("-" * 20)
for lab, freq in zip(labels, frequencies):
freq = 100 * freq / total_labels
logger.info(f"{lab:<15} {freq:<5}")
if freq < data_args.filter_threshold:
labels_to_remove.append(lab)
if len(labels_to_remove):
# filter training data with label freq below threshold
def is_label_valid(label):
return label not in labels_to_remove
raw_datasets = raw_datasets.filter(
is_label_valid,
input_columns=["labels"],
num_proc=data_args.preprocessing_num_workers,
desc="Filtering low freq labels",
)
# We'll include these in the model's config to get human readable labels in the Inference API.
set_labels = set(raw_datasets["train"]["labels"])
if training_args.do_eval:
set_labels = set_labels.union(set(raw_datasets["eval"]["labels"]))
label2id, id2label = {}, {}
for i, label in enumerate(set(set_labels)):
label2id[label] = str(i)
id2label[str(i)] = label
def train_transforms(batch):
"""Apply train_transforms across a batch."""
subsampled_wavs = []
for audio in batch["audio"]:
wav = deterministic_subsample(
audio["array"], max_length=data_args.max_length_seconds, sample_rate=feature_extractor.sampling_rate
)
subsampled_wavs.append(wav)
inputs = feature_extractor(
subsampled_wavs, return_attention_mask=model_args.attention_mask, sampling_rate=sampling_rate
)
output_batch = {
model_input_name: inputs.get(model_input_name),
"attention_mask": inputs.get("attention_mask"),
"labels": [int(label2id[label]) for label in batch["labels"]],
}
return output_batch
if training_args.do_train:
# Set the training transforms
raw_datasets["train"].set_transform(train_transforms, output_all_columns=False)
if training_args.do_eval:
# Set the validation transforms
raw_datasets["eval"].set_transform(train_transforms, output_all_columns=False)
# Load the accuracy metric from the datasets package
metric = evaluate.load("accuracy", cache_dir=model_args.cache_dir)
# Define our compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with
# `predictions` and `label_ids` fields) and has to return a dictionary string to float.
def compute_metrics(eval_pred):
"""Computes accuracy on a batch of predictions"""
predictions = np.argmax(eval_pred.predictions, axis=1)
return metric.compute(predictions=predictions, references=eval_pred.label_ids)
config = AutoConfig.from_pretrained(
model_args.config_name or model_args.model_name_or_path,
num_labels=len(label2id),
label2id=label2id,
id2label=id2label,
finetuning_task="audio-classification",
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
token=model_args.token,
trust_remote_code=model_args.trust_remote_code,
)
# adapt config with regularization
config.update(
{
"feat_proj_dropout": model_args.feat_proj_dropout,
"attention_dropout": model_args.attention_dropout,
"hidden_dropout": model_args.hidden_dropout,
"final_dropout": model_args.final_dropout,
"mask_time_prob": model_args.mask_time_prob,
"mask_time_length": model_args.mask_time_length,
"mask_feature_prob": model_args.mask_feature_prob,
"mask_feature_length": model_args.mask_feature_length,
"layerdrop": model_args.layerdrop,
"activation_dropout": model_args.activation_dropout,
}
)
model = AutoModelForAudioClassification.from_pretrained(
model_args.model_name_or_path,
from_tf=bool(".ckpt" in model_args.model_name_or_path),
config=config,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
token=model_args.token,
trust_remote_code=model_args.trust_remote_code,
ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
)
# freeze the convolutional waveform encoder for wav2vec2-style models
if model_args.freeze_feature_encoder:
if hasattr(model, "freeze_feature_encoder"):
model.freeze_feature_encoder()
else:
raise ValueError("Method for freezing the feature encoder is not defined for Whisper-style models.")
if model_args.freeze_base_model:
if hasattr(model, "freeze_base_model"):
# wav2vec2-style models
model.freeze_base_model()
if hasattr(model, "freeze_feature_encoder"):
model.freeze_feature_encoder()
elif hasattr(model, "freeze_encoder"):
# whisper-style models
model.freeze_encoder()
else:
raise ValueError("Method for freezing the base module of the audio encoder is not defined")
# Initialize our trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=raw_datasets["train"] if training_args.do_train else None,
eval_dataset=raw_datasets["eval"] if training_args.do_eval else None,
compute_metrics=compute_metrics,
tokenizer=feature_extractor,
)
# Training
if training_args.do_train:
checkpoint = None
if training_args.resume_from_checkpoint is not None:
checkpoint = training_args.resume_from_checkpoint
elif last_checkpoint is not None:
checkpoint = last_checkpoint
train_result = trainer.train(resume_from_checkpoint=checkpoint)
trainer.save_model()
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
# Evaluation
if training_args.do_eval:
metrics = trainer.evaluate()
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
# Write model card and (optionally) push to hub
kwargs = {
"finetuned_from": model_args.model_name_or_path,
"tasks": "audio-classification",
"dataset": data_args.train_dataset_name.split("+")[0],
"tags": ["audio-classification"],
}
if training_args.push_to_hub:
trainer.push_to_hub(**kwargs)
else:
trainer.create_model_card(**kwargs)
if __name__ == "__main__":
main()
import os
import sys
from dataclasses import dataclass, field
from pathlib import Path
import numpy as np
from datasets import Audio, concatenate_datasets, load_dataset
from huggingface_hub import get_full_repo_name
from transformers import HfArgumentParser, WhisperTokenizerFast
@dataclass
class DataTrainingArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
"""
dataset_name: str = field(
default=None,
metadata={"help": "The name of the dataset to use (via the datasets library)."},
)
dataset_config_name: str = field(
default=None,
metadata={"help": "The configuration name of the dataset to use (via the datasets library)."},
)
dataset_split_name: str = field(
default=None,
metadata={
"help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
},
)
label_column_name: str = field(
default="labels",
metadata={"help": "The name of the dataset column containing the labels in the dataset. Defaults to 'label'"},
)
text_column_name: str = field(
default="text",
metadata={
"help": "The name of the dataset column containing the text transcriptions in the dataset. Defaults to 'text'"
},
)
speaker_column_name: str = field(
default="speaker_id",
metadata={
"help": "The name of the dataset column containing the speaker ids in the dataset. Defaults to 'speaker_id'"
},
)
dataset_cache_dir: str = field(
default=None,
metadata={"help": "Path to cache directory for saving and loading datasets"},
)
preprocessing_num_workers: int = field(
default=None,
metadata={"help": "The number of processes to use for the preprocessing."},
)
batch_size: int = field(
default=500,
metadata={"help": "Number of examples per batch provided to the preprocessing function."},
)
download_only: bool = field(
default=False,
metadata={"help": "Whether to only do data download and skip pre-processing."},
)
audio_column_name: str = field(
default="audio",
metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
)
max_duration_in_seconds: float = field(
default=20.0,
metadata={"help": "Filter audio files that are longer than `max_duration_in_seconds` seconds"},
)
sampling_rate: int = field(
default=16_000,
metadata={
"help": "Sampling rate at which to resample the audio data. Should be set to the same sampling rate as the target model."
},
)
max_samples: int = field(
default=None,
metadata={
"help": "For debugging purposes, truncate the number of examples in the dataset to this value if set."
},
)
output_dir: str = field(
default=None,
metadata={
"help": "Where to save the processed dataset to disk. If unspecified, uses a 'pretty' version of the "
"original dataset name. E.g. 'facebook/voxpopuli' will be saved under 'voxpopuli'."
},
)
push_to_hub: bool = field(
default=False,
metadata={"help": "Whether or not to push the processed dataset to the Hub."},
)
seed: int = field(
default=0,
metadata={"help": "RNG seed for reproducibility. Used during the final shuffling of the combined dataset."},
)
def convert_dataset_str_to_list(
dataset_names,
dataset_config_names,
splits=None,
label_column_names=None,
text_column_names=None,
speaker_column_names=None,
dataset_samples=None,
default_split="train",
):
if isinstance(dataset_names, str):
dataset_names = dataset_names.split("+")
dataset_config_names = dataset_config_names.split("+")
splits = splits.split("+") if splits is not None else None
label_column_names = label_column_names.split("+") if label_column_names is not None else None
text_column_names = text_column_names.split("+") if text_column_names is not None else None
speaker_column_names = speaker_column_names.split("+") if speaker_column_names is not None else None
dataset_samples = dataset_samples.split("+") if dataset_samples is not None else None
# basic checks to ensure we've got the right number of datasets/configs/splits/columns/probs
if len(dataset_names) != len(dataset_config_names):
raise ValueError(
f"Ensure one config is passed for each dataset, got {len(dataset_names)} datasets and"
f" {len(dataset_config_names)} configs."
)
if splits is not None and len(splits) != len(dataset_names):
raise ValueError(
f"Ensure one split is passed for each dataset, got {len(dataset_names)} datasets and {len(splits)} splits."
)
if label_column_names is not None and len(label_column_names) != len(dataset_names):
raise ValueError(
f"Ensure one label column name is passed for each dataset, got {len(dataset_names)} datasets and"
f" {len(label_column_names)} label column names."
)
if text_column_names is not None and len(text_column_names) != len(dataset_names):
raise ValueError(
f"Ensure one text column name is passed for each dataset, got {len(dataset_names)} datasets and"
f" {len(text_column_names)} text column names."
)
if speaker_column_names is not None and len(speaker_column_names) != len(dataset_names):
raise ValueError(
f"Ensure one text column name is passed for each dataset, got {len(dataset_names)} datasets and"
f" {len(speaker_column_names)} speaker column names."
)
if dataset_samples is not None:
if len(dataset_samples) != len(dataset_names):
raise ValueError(
f"Ensure one sample is passed for each dataset, got {len(dataset_names)} datasets and "
f"{len(dataset_samples)} samples."
)
dataset_samples = [float(ds_sample) for ds_sample in dataset_samples]
else:
dataset_samples = [None] * len(dataset_names)
label_column_names = (
label_column_names if label_column_names is not None else ["labels" for _ in range(len(dataset_names))]
)
text_column_names = (
text_column_names if text_column_names is not None else ["text" for _ in range(len(dataset_names))]
)
speaker_column_names = (
speaker_column_names if speaker_column_names is not None else ["speaker_id" for _ in range(len(dataset_names))]
)
splits = splits if splits is not None else [default_split for _ in range(len(dataset_names))]
dataset_names_dict = []
for i, ds_name in enumerate(dataset_names):
dataset_names_dict.append(
{
"name": ds_name,
"config": dataset_config_names[i],
"split": splits[i],
"label_column_name": label_column_names[i],
"text_column_name": text_column_names[i],
"speaker_column_name": speaker_column_names[i],
"samples": dataset_samples[i],
}
)
return dataset_names_dict
def main():
# 1. Parse input arguments
parser = HfArgumentParser(DataTrainingArguments)
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
data_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))[0]
else:
data_args = parser.parse_args_into_dataclasses()[0]
dataset_names_dict = convert_dataset_str_to_list(
data_args.dataset_name,
data_args.dataset_config_name,
splits=data_args.dataset_split_name,
label_column_names=data_args.label_column_name,
text_column_names=data_args.text_column_name,
speaker_column_names=data_args.speaker_column_name,
)
# load whisper tokenizer for normalisation
sampling_rate = data_args.sampling_rate
tokenizer = WhisperTokenizerFast.from_pretrained("openai/whisper-tiny.en")
max_input_length = int(data_args.max_duration_in_seconds * sampling_rate)
batch_size = data_args.batch_size
preprocessing_num_workers = data_args.preprocessing_num_workers
all_vectorized_datasets = []
for dataset_dict in dataset_names_dict:
print(10 * "=", dataset_dict["name"], 10 * "=")
raw_datasets = load_dataset(
dataset_dict["name"],
dataset_dict["config"],
split=dataset_dict["split"],
cache_dir=data_args.dataset_cache_dir,
num_proc=data_args.preprocessing_num_workers,
)
if data_args.download_only:
continue
features = raw_datasets.column_names
if dataset_dict["label_column_name"] not in features:
raise ValueError(
f"--label_column_name {dataset_dict['label_column_name']} not found in dataset '{dataset_dict['name']}'. "
"Make sure to set `--label_column_name` to the correct text column - one of "
f"{', '.join(features)}."
)
elif dataset_dict["label_column_name"] != "labels":
raw_datasets = raw_datasets.rename_column(dataset_dict["label_column_name"], "labels")
if dataset_dict["text_column_name"] not in features:
raise ValueError(
f"--text_column_name {dataset_dict['text_column_name']} not found in dataset '{dataset_dict['name']}'. "
"Make sure to set `--text_column_name` to the correct text column - one of "
f"{', '.join(features)}."
)
elif dataset_dict["text_column_name"] != "text":
raw_datasets = raw_datasets.rename_column(dataset_dict["text_column_name"], "text")
if dataset_dict["speaker_column_name"] not in features:
raise ValueError(
f"--speaker_column_name {dataset_dict['speaker_column_name']} not found in dataset '{dataset_dict['name']}'. "
"Make sure to set `--speaker_column_name` to the correct speaker id column - one of "
f"{', '.join(features)}."
)
elif dataset_dict["speaker_column_name"] != "speaker_id":
raw_datasets = raw_datasets.rename_column(dataset_dict["speaker_column_name"], "speaker_id")
raw_datasets = raw_datasets.remove_columns(
set(raw_datasets.features.keys()) - {"audio", "labels", "text", "speaker_id"}
)
if data_args.max_samples is not None:
raw_datasets = raw_datasets.select(range(data_args.max_samples))
raw_datasets = raw_datasets.cast_column(data_args.audio_column_name, Audio(sampling_rate=sampling_rate))
raw_datasets = raw_datasets.sort("speaker_id")
def filter_transcriptions(text):
normalized_text = tokenizer.normalize(text).strip()
return bool(normalized_text) and text.lower() != "ignore_time_segment_in_scoring"
raw_datasets = raw_datasets.filter(
filter_transcriptions, input_columns=["text"], desc="Filtering non-speech transcriptions"
)
def prepare_dataset(batch):
audio = [sample["array"] for sample in batch["audio"]]
input_lengths = [len(sample) for sample in audio]
concatenated_audio = []
concatenated_text = []
concatenated_speaker = []
concatenated_labels = []
audio_sample = audio[0]
text_sample = batch["text"][0]
label_sample = batch["labels"][0]
for idx in range(1, len(audio)):
prev_speaker = batch["speaker_id"][idx - 1]
speaker = batch["speaker_id"][idx]
if len(audio_sample) + input_lengths[idx] < max_input_length:
if speaker == prev_speaker:
# we have no information about whether the segments follow on sequentially
# so we just ensure the same speaker as we concatenate across files
audio_sample = np.append(audio_sample, audio[idx])
# extra spaces in the text transcription don't matter, since we only use it for the WER computation
text_sample += " " + batch["text"][idx]
else:
# segments do not follow sequentially, save the audio and start looping again
concatenated_audio.append(audio_sample)
concatenated_text.append(text_sample)
concatenated_labels.append(label_sample)
concatenated_speaker.append(speaker)
audio_sample = audio[idx]
text_sample = batch["text"][idx]
label_sample = batch["labels"][idx]
else:
# concatenated audio exceeds max length, save the audio and start looping again
concatenated_audio.append(audio_sample)
concatenated_text.append(text_sample)
concatenated_labels.append(label_sample)
concatenated_speaker.append(speaker)
audio_sample = audio[idx]
text_sample = batch["text"][idx]
label_sample = batch["labels"][idx]
batch["audio"] = [{"array": array, "sampling_rate": sampling_rate} for array in concatenated_audio]
batch["text"] = concatenated_text
batch["labels"] = concatenated_labels
batch["speaker_id"] = concatenated_speaker
return batch
raw_datasets = raw_datasets.map(
prepare_dataset,
batched=True,
batch_size=batch_size,
num_proc=preprocessing_num_workers,
desc="Concatenating dataset...",
)
pretty_name = dataset_dict["name"].split("/")[-1]
def postprocess_ids(speaker_id, idx):
formatted_idx = f"{pretty_name}-{speaker_id}-{idx}"
return {"id": formatted_idx}
raw_datasets = raw_datasets.map(
postprocess_ids,
input_columns=["speaker_id"],
with_indices=True,
desc="Setting sample idxs...",
num_proc=preprocessing_num_workers,
)
print(f"Final length {pretty_name}: ", len(raw_datasets))
# Re-format transcriptions and condition on prev as numpy arrays
raw_datasets = raw_datasets.with_format("np")
all_vectorized_datasets.append(raw_datasets)
all_vectorized_datasets = concatenate_datasets(all_vectorized_datasets)
dataset_features = all_vectorized_datasets.features.copy()
dataset_features["audio"] = Audio(sampling_rate=sampling_rate)
all_vectorized_datasets = all_vectorized_datasets.cast(
dataset_features, batch_size=batch_size, writer_batch_size=batch_size, num_proc=preprocessing_num_workers
)
all_vectorized_datasets = all_vectorized_datasets.shuffle(seed=data_args.seed)
all_vectorized_datasets.save_to_disk(data_args.output_dir)
repo_name = get_full_repo_name(Path(data_args.output_dir).absolute().name)
if data_args.push_to_hub:
all_vectorized_datasets.push_to_hub(repo_name, config_name="train", max_shard_size="1GB")
if __name__ == "__main__":
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment