Commit d10775d8 authored by sanchit-gandhi

more audio class

parent 94f40c57
@@ -3,4 +3,91 @@
Work in-progress reproduction of the text-to-speech (TTS) model from the paper [Natural language guidance of high-fidelity text-to-speech with synthetic annotations](https://www.text-description-to-speech.com)
by Dan Lyth and Simon King, from Stability AI and Edinburgh University respectively.

Reproducing the TTS model requires the following 5 steps to be completed in order:
1. Train the Accent Classifier
2. Annotate the Training Set
3. Aggregate Statistics
4. Create Descriptions
5. Train the Model
## Step 1: Train the Accent Classifier
The script [`run_audio_classification.py`](run_audio_classification.py) can be used to train an audio encoder model from
the [Transformers library](https://github.com/huggingface/transformers) (e.g. Wav2Vec2, MMS, or Whisper) for the accent
classification task.
Starting with a pre-trained audio encoder model, a simple linear classifier is appended to the last hidden layer to map the
audio embeddings to class label predictions. The audio encoder can either be frozen (`--freeze_base_model`) or trained.
The linear classifier is randomly initialised, and is thus always trained.
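As a rough sketch of this setup with the Transformers API (the exact wiring lives in `run_audio_classification.py`, and the label set below is hypothetical), the frozen encoder and randomly initialised classification head can be assembled as follows:

```python
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification

# Hypothetical accent labels -- the real label set is built from the training datasets
labels = ["american", "indian", "irish", "scottish"]
label2id = {label: str(i) for i, label in enumerate(labels)}
id2label = {str(i): label for i, label in enumerate(labels)}

feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/mms-lid-126")
model = AutoModelForAudioClassification.from_pretrained(
    "facebook/mms-lid-126",
    num_labels=len(labels),
    label2id=label2id,
    id2label=id2label,
    ignore_mismatched_sizes=True,  # re-initialise the classifier head for the new label set
)

# Freeze the base encoder (and its convolutional feature encoder) so that
# only the randomly initialised linear classifier receives gradient updates
model.freeze_base_model()
model.freeze_feature_encoder()
```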
The script can be used to train on a single accent dataset, or a combination of datasets, which should be specified by
separating dataset names, configs and splits by the `+` character in the launch command (see below for an example).
In the following example, we follow Stability's approach by taking audio embeddings from a frozen [MMS-LID](https://huggingface.co/facebook/mms-lid-126)
model, and training the linear classifier on a combination of three open-source datasets:
1. The English Accented (`en_accented`) subset of [Voxpopuli](https://huggingface.co/datasets/facebook/voxpopuli)
2. The train split of [VCTK](https://huggingface.co/datasets/vctk)
3. The dev split of [EdAcc](https://huggingface.co/datasets/sanchit-gandhi/edacc)
The model is subsequently evaluated on the test split of [EdAcc](https://huggingface.co/datasets/sanchit-gandhi/edacc)
to give the final classification accuracy.
```bash
#!/usr/bin/env bash
python run_audio_classification.py \
--model_name_or_path "facebook/mms-lid-126" \
--train_dataset_name "vctk+facebook/voxpopuli+sanchit-gandhi/edacc" \
--train_dataset_config_name "main+en_accented+default" \
--train_split_name "train+test+validation" \
--train_label_column_name "accent+accent+accent" \
--eval_dataset_name "sanchit-gandhi/edacc" \
--eval_dataset_config_name "default" \
--eval_split_name "test" \
--eval_label_column_name "accent" \
--output_dir "./" \
--do_train \
--do_eval \
--overwrite_output_dir \
--remove_unused_columns False \
--fp16 \
--learning_rate 1e-4 \
--max_length_seconds 20 \
--attention_mask False \
--warmup_ratio 0.1 \
--num_train_epochs 5 \
--per_device_train_batch_size 32 \
--per_device_eval_batch_size 32 \
--preprocessing_num_workers 16 \
--dataloader_num_workers 4 \
--logging_strategy "steps" \
--logging_steps 10 \
--evaluation_strategy "epoch" \
--save_strategy "epoch" \
--load_best_model_at_end True \
--metric_for_best_model "accuracy" \
--save_total_limit 3 \
--freeze_base_model \
--push_to_hub \
--trust_remote_code
```
Tips:
1. **Number of labels:** normalisation should be applied to the target class labels to group linguistically similar accents together (e.g. "Southern Irish" and "Irish" should both be "Irish"). This helps _balance_ the dataset by removing labels with very few examples. You can modify the function `preprocess_labels` to implement any custom normalisation strategy, as sketched below.
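For illustration, a minimal sketch of this kind of normalisation is given below; the mapping and function signature are hypothetical and may differ from the actual `preprocess_labels` implementation:

```python
ACCENT_MAPPING = {
    # illustrative entries only -- extend with the raw labels present in your datasets
    "southern irish": "irish",
    "northern irish": "irish",
    "scottish english": "scottish",
    "general american": "american",
}


def normalise_accent(raw_label: str) -> str:
    """Lower-case the raw accent label and merge linguistically similar accents."""
    label = raw_label.strip().lower()
    return ACCENT_MAPPING.get(label, label)
```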
## Step 2: Annotate the Training Set
Annotate the training dataset with information on signal-to-noise ratio (SNR), reverberation (C50), pitch and speaking rate.
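SNR and C50 are typically estimated with a dedicated model, so the hedged sketch below covers only the two annotations that can be derived directly from the audio and transcript: pitch via `librosa.pyin` and speaking rate as characters per second. The function is illustrative and not part of this repository:

```python
import librosa
import numpy as np


def annotate_utterance(audio: np.ndarray, sr: int, transcript: str) -> dict:
    """Illustrative per-utterance annotations: mean pitch (Hz) and speaking rate."""
    # Pitch: frame-level fundamental frequency, averaged over voiced frames only
    f0, voiced_flag, _ = librosa.pyin(audio, fmin=65.0, fmax=400.0, sr=sr)
    mean_pitch = float(np.nanmean(f0[voiced_flag])) if voiced_flag.any() else 0.0

    # Speaking rate: characters per second as a cheap stand-in for phonemes per second
    duration_seconds = len(audio) / sr
    speaking_rate = len(transcript.replace(" ", "")) / duration_seconds

    # SNR and C50 are left to a dedicated estimator and are not computed here
    return {"pitch": mean_pitch, "speaking_rate": speaking_rate}
```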
## Step 3: Aggregate Statistics
Aggregate the statistics computed in Step 2 across the training set, then convert the continuous values to discrete labels.
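As an illustration, one simple way to discretise the aggregated statistics is percentile binning over the full training set; the bin names and edges below are assumptions, not necessarily those used in the paper:

```python
import numpy as np

BIN_NAMES = ("very low", "low", "moderate", "high", "very high")


def bin_values(values: np.ndarray) -> list:
    """Map continuous annotation values to discrete labels via dataset-level percentiles."""
    # Interior bin edges at equally spaced percentiles (20th, 40th, 60th, 80th for five bins)
    edges = np.percentile(values, np.linspace(0, 100, len(BIN_NAMES) + 1)[1:-1])
    return [BIN_NAMES[i] for i in np.digitize(values, edges)]


# Example: discretise speaking rates aggregated over the training set
speaking_rates = np.array([2.1, 3.4, 5.0, 4.2, 2.8, 6.1])
print(bin_values(speaking_rates))  # prints one discrete label per utterance
```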
## Step 4: Create Descriptions
Convert the sequence of discrete labels into a natural-language text description (using an LLM).
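A rough sketch of this step with the Transformers `text-generation` pipeline is given below; the prompt wording, the label keys and the Mistral checkpoint are placeholder assumptions rather than the setup used in the paper:

```python
from transformers import pipeline


def labels_to_prompt(labels: dict) -> str:
    """Format the discrete labels into an instruction for the LLM."""
    keywords = ", ".join(f"{name}: {value}" for name, value in labels.items())
    return (
        "Describe the following voice recording in one natural sentence, "
        f"using these attributes: {keywords}."
    )


# Any instruction-tuned checkpoint could be swapped in here
generator = pipeline("text-generation", model="mistralai/Mistral-7B-Instruct-v0.1")

labels = {"accent": "Irish", "pitch": "high", "speaking rate": "fast", "noise": "very clear"}
description = generator(
    labels_to_prompt(labels), max_new_tokens=60, return_full_text=False
)[0]["generated_text"]
print(description)
```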
## Step 5: Train the Model
Train a MusicGen-style model on the TTS task.
Changes to the launch script:

```diff
@@ -2,10 +2,10 @@
 python run_audio_classification.py \
 --model_name_or_path "facebook/mms-lid-126" \
---train_dataset_name "vctk+facebook/voxpopuli+sanchit-gandhi/edacc" \
---train_dataset_config_name "default+en_accented+default" \
+--train_dataset_name "sanchit-gandhi/vctk+facebook/voxpopuli+sanchit-gandhi/edacc" \
+--train_dataset_config_name "main+en_accented+default" \
 --train_split_name "train+test+validation" \
---train_label_column_name "accent" \
+--train_label_column_name "accent+accent+accent" \
 --eval_dataset_name "sanchit-gandhi/edacc" \
 --eval_dataset_config_name "default" \
 --eval_split_name "test" \
@@ -17,13 +17,13 @@ python run_audio_classification.py \
 --remove_unused_columns False \
 --fp16 \
 --learning_rate 1e-4 \
---min_length_seconds 5 \
 --max_length_seconds 20 \
 --attention_mask False \
 --warmup_ratio 0.1 \
 --num_train_epochs 5 \
 --per_device_train_batch_size 32 \
 --per_device_eval_batch_size 32 \
+--preprocessing_num_workers 16 \
 --dataloader_num_workers 4 \
 --logging_strategy "steps" \
 --logging_steps 10 \
@@ -32,6 +32,6 @@ python run_audio_classification.py \
 --load_best_model_at_end True \
 --metric_for_best_model "accuracy" \
 --save_total_limit 3 \
---seed 0 \
+--freeze_base_model \
 --push_to_hub \
 --trust_remote_code
```
Changes to `run_audio_classification.py`:

```diff
@@ -35,7 +35,7 @@ from transformers import (
     HfArgumentParser,
     Trainer,
     TrainingArguments,
-    set_seed,
+    set_seed, WhisperForAudioClassification,
 )
 from transformers.models.whisper.tokenization_whisper import LANGUAGES
 from transformers.trainer_utils import get_last_checkpoint
@@ -163,10 +163,6 @@ class DataTrainingArguments:
             )
         },
     )
-    min_length_seconds: float = field(
-        default=5,
-        metadata={"help": "Audio clips less than this value will be filtered during training."},
-    )
     max_length_seconds: float = field(
         default=20,
         metadata={"help": "Audio clips will be randomly cut to this length during training if the value is set."},
@@ -201,7 +197,10 @@ class ModelArguments:
         default=None, metadata={"help": "Name or path of preprocessor config."}
     )
     freeze_feature_encoder: bool = field(
-        default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model."}
+        default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model. Only relevant for Wav2Vec2-style models."}
+    )
+    freeze_base_model: bool = field(
+        default=True, metadata={"help": "Whether to freeze the base encoder of the model."}
     )
     attention_mask: bool = field(
         default=True, metadata={"help": "Whether to generate an attention mask in the feature extractor."}
@@ -438,6 +437,7 @@ def main():
         cache_dir=model_args.cache_dir,
         token=True if model_args.token else None,
         trust_remote_code=model_args.trust_remote_code,
+        num_proc=data_args.preprocessing_num_workers,
         # streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
     )
@@ -466,6 +466,7 @@ def main():
         cache_dir=model_args.cache_dir,
         token=True if model_args.token else None,
         trust_remote_code=model_args.trust_remote_code,
+        num_proc=data_args.preprocessing_num_workers,
         # streaming=data_args.streaming,
     )
     features = raw_datasets[pretty_name].features.keys()
@@ -620,6 +621,17 @@ def main():
     if model_args.freeze_feature_encoder:
         model.freeze_feature_encoder()
 
+    if model_args.freeze_base_model:
+        if hasattr(model, "freeze_base_model"):
+            # wav2vec2-style models
+            model.freeze_base_model()
+            model.freeze_feature_encoder()
+        elif hasattr(model, "freeze_encoder"):
+            # whisper-style models
+            model.freeze_encoder()
+        else:
+            raise ValueError("Method for freezing the base module of the audio encoder is not defined")
+
     # Initialize our trainer
     trainer = Trainer(
         model=model,
```
......