Commit 43087d4a authored by Yoach Lacombe

clean training script

parent c734f3ec
@@ -52,9 +52,7 @@ from transformers import (
 from transformers.trainer_pt_utils import LengthGroupedSampler
 from transformers import pipeline
 from transformers.optimization import get_scheduler
-from transformers.utils import check_min_version, send_example_telemetry
-from transformers.utils.versions import require_version
-from transformers.integrations import is_wandb_available
+from transformers.utils import send_example_telemetry
 from transformers import AutoModel
@@ -68,13 +66,7 @@ from parler_tts import (
     build_delay_pattern_mask,
 )
-if is_wandb_available():
-    from wandb import Audio
+from wandb import Audio
-# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.38.0.dev0")
-require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
 logger = logging.getLogger(__name__)
@@ -202,7 +194,6 @@ class ModelArguments:
     Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
     """
-    # TODO: pretrain from scratch
     model_name_or_path: str = field(
         metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
     )
@@ -256,9 +247,18 @@ class ModelArguments:
         metadata={"help": "Generation max length."},
     )
     bandwidth: float = field(
-        default=6, # TODO
+        default=6,
         metadata={"help": "Audio encoder bandwidth."},
     )
+    asr_model_name_or_path: str = field(
+        default="distil-whisper/distil-large-v2",
+        metadata={"help": "Used to compute WER during evaluation. Path to pretrained model or model identifier from huggingface.co/models"}
+    )
+    clap_model_name_or_path: str = field(
+        default="laion/larger_clap_music_and_speech",
+        metadata={"help": "Used to compute audio similarity during evaluation. Path to pretrained model or model identifier from huggingface.co/models"}
+    )
 @dataclass
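Note: the two new fields surface as command-line flags through transformers' HfArgumentParser, which this script uses to build its configuration. A minimal sketch of how the override works; the trimmed-down ModelArguments below is illustrative only, not the full class:

    from dataclasses import dataclass, field
    from transformers import HfArgumentParser

    @dataclass
    class ModelArguments:  # trimmed to the two new fields for illustration
        asr_model_name_or_path: str = field(default="distil-whisper/distil-large-v2")
        clap_model_name_or_path: str = field(default="laion/larger_clap_music_and_speech")

    parser = HfArgumentParser(ModelArguments)
    # Each dataclass field becomes a --flag; defaults apply when omitted.
    (model_args,) = parser.parse_args_into_dataclasses(
        ["--asr_model_name_or_path", "openai/whisper-large-v3"]
    )
    print(model_args.clap_model_name_or_path)  # default kept: laion/larger_clap_music_and_speech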
@@ -333,17 +333,17 @@ class DataTrainingArguments:
             " librispeech and common voice, set `train_dataset_name='librispeech_asr+common_voice'`."
         },
     )
-    target_audio_column_name: str = field( # TODO
+    target_audio_column_name: str = field(
         default="audio",
         metadata={"help": "The name of the dataset column containing the target audio data. Defaults to 'audio'"},
     )
-    description_column_name: str = field( # TODO
+    description_column_name: str = field(
         default=None,
-        metadata={"help": "The name of the dataset column containing the text data. Defaults to 'None'."},
+        metadata={"help": "The name of the dataset column containing the description text data. Defaults to 'None'."},
     )
-    prompt_column_name: str = field( # TODO
+    prompt_column_name: str = field(
         default=None,
-        metadata={"help": "The name of the dataset column containing the text data. Defaults to 'None'."},
+        metadata={"help": "The name of the dataset column containing the prompt text data. Defaults to 'None'."},
     )
     overwrite_cache: bool = field(
         default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
@@ -482,9 +482,9 @@ class ParlerTTSTrainingArguments(Seq2SeqTrainingArguments):
             )
         },
     )
-    audio_encode_per_device_eval_batch_size: int = field(
+    audio_encoder_per_device_batch_size: int = field(
         default=8,
-        metadata={"help": ("TODO")},
+        metadata={"help": ("Specify the batch size of the audio encoding pre-processing steps.")},
     )
@@ -521,8 +521,6 @@ class DataCollatorParlerTTSWithPadding:
             The prompt_tokenizer used for processing the data.
         description_tokenizer (:class:`~transformers.AutoTokenizer`)
             The description_tokenizer used for processing the data.
-        audio_feature_extractor (:class:`~transformers.AutoFeatureExtractor`)
-            The audio_feature_extractor used for processing the data.
         padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
             Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
             among:
@@ -540,8 +538,6 @@ class DataCollatorParlerTTSWithPadding:
     prompt_tokenizer: AutoTokenizer
     description_tokenizer: AutoTokenizer
-    audio_feature_extractor: AutoFeatureExtractor
-    feature_extractor_input_name: Optional[str] = "input_values"
     padding: Union[bool, str] = "longest"
     pad_to_multiple_of: Optional[int] = None
     prompt_max_length: Optional[int] = None
@@ -588,15 +584,6 @@ class DataCollatorParlerTTSWithPadding:
         if "attention_mask" in prompt_input_ids:
             batch["prompt_attention_mask"] = prompt_input_ids["attention_mask"]
-        if self.feature_extractor_input_name in features[0]:
-            # TODO (YL): verify that it works - IMPORTANT -> probably not working
-            input_values = [
-                {self.feature_extractor_input_name: feature[self.feature_extractor_input_name]} for feature in features
-            ]
-            input_values = self.feature_extractor.pad(input_values, return_tensors="pt")
-            batch[self.feature_extractor_input_name : input_values]
         return batch
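Worth noting: the removed block was indeed broken, as its own TODO suspected. `batch[self.feature_extractor_input_name : input_values]` evaluates a slice lookup instead of performing an assignment, and it calls `self.feature_extractor` while the dataclass field was named `audio_feature_extractor`. Had the path been kept, a corrected version might have looked like the sketch below (not code from this commit):

    if self.feature_extractor_input_name in features[0]:
        input_values = [
            {self.feature_extractor_input_name: feature[self.feature_extractor_input_name]}
            for feature in features
        ]
        # pad() returns a BatchFeature; store the padded tensor under the input name
        padded = self.audio_feature_extractor.pad(input_values, return_tensors="pt")
        batch[self.feature_extractor_input_name] = padded[self.feature_extractor_input_name]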
@@ -1019,7 +1006,6 @@ def main():
     )
     # 3. Next, let's load the config.
-    # TODO(YL): add the option to create the config from scratch
     config = ParlerTTSConfig.from_pretrained(
         model_args.model_name_or_path,
         cache_dir=model_args.cache_dir,
@@ -1028,7 +1014,6 @@ def main():
     )
     # update pad token id and decoder_start_token_id
-    # TODO(YL): verify if this makes sense, maybe should do it for model.decoder
     config.update(
         {
             "pad_token_id": model_args.pad_token_id
@@ -1040,7 +1025,7 @@ def main():
         }
     )
-    # create model + TODO(YL): not from_pretrained probably
+    # create model
     model = ParlerTTSForConditionalGeneration.from_pretrained(
         model_args.model_name_or_path,
         cache_dir=model_args.cache_dir,
@@ -1076,7 +1061,6 @@ def main():
     # Freeze Encoders
     model.freeze_encoders(model_args.freeze_text_encoder)
-    # TODO: remove when releasing
     # Test all gather - used for warmup and avoiding timeout
     test_tensor = torch.tensor([accelerator.process_index], device=accelerator.device)
     gathered_tensor = accelerator.gather(test_tensor)
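The all-gather smoke test kept here exercises the communication backend once before the long preprocessing and training phases, so NCCL misconfiguration or timeouts surface early. The same pattern in isolation, as a standalone sketch assuming accelerate is installed:

    import torch
    from accelerate import Accelerator

    accelerator = Accelerator()
    # Each rank contributes its process index; gather() forces a collective op up front.
    test_tensor = torch.tensor([accelerator.process_index], device=accelerator.device)
    gathered_tensor = accelerator.gather(test_tensor)
    # On N processes this prints tensor([0, 1, ..., N-1]).
    print(gathered_tensor)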
@@ -1100,7 +1084,6 @@ def main():
         batch = {}
         batch["input_ids"] = description_tokenizer(description.strip())["input_ids"]
-        # TODO: add possibility to train without description column
         batch["prompt_input_ids"] = prompt_tokenizer(prompt.strip())["input_ids"]
         return batch
@@ -1154,7 +1137,7 @@ def main():
     for split in vectorized_datasets:
         data_loader = DataLoader(
             raw_datasets[split],
-            batch_size=training_args.audio_encode_per_device_eval_batch_size,
+            batch_size=training_args.audio_encoder_per_device_batch_size,
             collate_fn=encoder_data_collator,
             num_workers=training_args.dataloader_num_workers,
             pin_memory=True,
@@ -1221,7 +1204,6 @@ def main():
             output = {"labels": labels[:, 1:]}
             return output
-        # TODO(YL): done multiple times, how to deal with it.
         with accelerator.main_process_first():
             vectorized_datasets[split] = vectorized_datasets[split].map(
                 postprocess_dataset,
@@ -1302,9 +1284,9 @@ def main():
     # Let's use CLAP similarity and WER metrics as our evaluation metrics,
-    # Define evaluation metrics during training, *i.e.* CLAP similarity TODO: allow using another CLAP
-    clap = AutoModel.from_pretrained("laion/larger_clap_music_and_speech")
-    clap_processor = AutoProcessor.from_pretrained("laion/larger_clap_music_and_speech")
+    # Define evaluation metrics during training, *i.e.* CLAP similarity
+    clap = AutoModel.from_pretrained(model_args.clap_model_name_or_path)
+    clap_processor = AutoProcessor.from_pretrained(model_args.clap_model_name_or_path)
     metric = evaluate.load("wer")
     def clap_similarity(texts, audios, device):
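The diff only shows the signature of `clap_similarity` and, in the next hunk, its return line. For orientation, a plausible body consistent with those two lines, written against the standard ClapModel/ClapProcessor API; this is a sketch assuming the script's surrounding names (`clap`, `clap_processor`, `torch`) and assuming the audios were already resampled to CLAP's expected rate upstream, not the commit's exact code:

    def clap_similarity(texts, audios, device):
        clap_inputs = clap_processor(text=texts, audios=audios, padding=True, return_tensors="pt").to(device)
        clap.to(device)
        with torch.no_grad():
            text_features = clap.get_text_features(
                clap_inputs["input_ids"], attention_mask=clap_inputs.get("attention_mask")
            )
            audio_features = clap.get_audio_features(clap_inputs["input_features"])
        # Cosine similarity between paired description and generated-audio embeddings.
        cosine_sim = torch.nn.functional.cosine_similarity(audio_features, text_features, dim=1)
        clap.to("cpu")
        return cosine_sim.mean().to("cpu")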
@@ -1323,7 +1305,7 @@ def main():
         return cosine_sim.mean().to("cpu")
     def wer(prompts, audios, device):
-        asr_pipeline = pipeline(model="distil-whisper/distil-large-v2", device=device)
+        asr_pipeline = pipeline(model=model_args.asr_model_name_or_path, device=device)
         transcriptions = asr_pipeline(
             [{"raw": audio, "sampling_rate": sampling_rate} for audio in audios],
             batch_size=int(training_args.per_device_eval_batch_size),
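Likewise, the `wer` function is cut off after the pipeline call. A plausible completion that scores word error rate against the input prompts; a hedged sketch assuming the script's surrounding names (`metric`, `model_args`, `training_args`, `sampling_rate`), and the real script may normalize text differently:

    def wer(prompts, audios, device):
        asr_pipeline = pipeline(model=model_args.asr_model_name_or_path, device=device)
        transcriptions = asr_pipeline(
            [{"raw": audio, "sampling_rate": sampling_rate} for audio in audios],
            batch_size=int(training_args.per_device_eval_batch_size),
        )
        # Lowercasing is one plausible normalization; evaluate's "wer" metric
        # returns a fraction, scaled here to a percentage.
        word_error = 100 * metric.compute(
            predictions=[t["text"].lower() for t in transcriptions],
            references=[p.lower() for p in prompts],
        )
        return word_error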
@@ -1394,8 +1376,6 @@ def main():
     # Instantiate custom data collator
     data_collator = DataCollatorParlerTTSWithPadding(
-        audio_feature_extractor=feature_extractor,
-        feature_extractor_input_name=feature_extractor_input_name,
         prompt_tokenizer=prompt_tokenizer,
         description_tokenizer=description_tokenizer,
         pad_to_multiple_of=data_args.pad_to_multiple_of,
@@ -1531,7 +1511,6 @@ def main():
         outputs = model(**batch)
         # CE (data) loss
         ce_loss = outputs.loss
-        # TODO: add CE per codebook
         metrics = {"loss": ce_loss}
         return ce_loss, metrics
@@ -1578,7 +1557,8 @@ def main():
     for epoch in range(epochs_trained, num_epochs):
         vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
-        # TODO(YL): add args
-        sampler = LengthGroupedSampler(train_batch_size, lengths=vectorized_datasets["train"]["target_length"])
+        sampler = None
+        if training_args.group_by_length:
+            sampler = LengthGroupedSampler(train_batch_size, lengths=vectorized_datasets["train"]["target_length"])
         train_dataloader = DataLoader(
             vectorized_datasets["train"],
@@ -1631,7 +1611,7 @@ def main():
         # save checkpoint and weights after each save_steps and at the end of training
         if (cur_step % training_args.save_steps == 0) or cur_step == total_train_steps:
             intermediate_dir = os.path.join(training_args.output_dir, f"checkpoint-{cur_step}-epoch-{epoch}")
-            # safe_serialization=False to avoid shared tensors saving issue (TODO: it's a temporary fix)
+            # safe_serialization=False to avoid shared tensors saving issue (TODO(YL): it's a temporary fix)
             # https://github.com/huggingface/transformers/issues/27293#issuecomment-1872560074
             accelerator.save_state(output_dir=intermediate_dir, safe_serialization=False)
             accelerator.wait_for_everyone()
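For background on the temporary fix: safetensors refuses to serialize tensors that share memory, such as tied input/output embeddings, which is what the linked issue tracks; torch's pickle-based format does not complain, hence `safe_serialization=False`. A minimal illustration of the failure mode, assuming safetensors is installed:

    import torch
    from safetensors.torch import save_file

    embedding = torch.nn.Embedding(10, 4)
    tied_head = embedding.weight  # shares storage, as with tied embeddings
    try:
        save_file({"embed": embedding.weight, "lm_head": tied_head}, "ckpt.safetensors")
    except RuntimeError as err:
        # safetensors raises on shared tensors; pickle-based saving does not.
        print(err)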
@@ -1701,8 +1681,6 @@ def main():
             ):
                 generated_audios = generate_step(batch)
                 # Gather all predictions and targets
-                # TODO: also add prompt ids
-                # TODO: better gather
                 generated_audios, input_ids, prompts = accelerator.pad_across_processes(
                     (generated_audios, batch["input_ids"], batch["prompt_input_ids"]), dim=1, pad_index=0
                 )
...