Commit 43087d4a authored by Yoach Lacombe

clean training script

parent c734f3ec
@@ -52,9 +52,7 @@ from transformers import (
from transformers.trainer_pt_utils import LengthGroupedSampler
from transformers import pipeline
from transformers.optimization import get_scheduler
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
from transformers.integrations import is_wandb_available
from transformers.utils import send_example_telemetry
from transformers import AutoModel
@@ -68,13 +66,7 @@ from parler_tts import (
build_delay_pattern_mask,
)
if is_wandb_available():
from wandb import Audio
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.38.0.dev0")
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
from wandb import Audio
logger = logging.getLogger(__name__)
@@ -202,7 +194,6 @@ class ModelArguments:
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
"""
# TODO: pretrain from scratch
model_name_or_path: str = field(
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
)
@@ -256,9 +247,18 @@ class ModelArguments:
metadata={"help": "Generation max length."},
)
bandwidth: float = field(
default=6, # TODO
default=6,
metadata={"help": "Audio encoder bandwidth."},
)
asr_model_name_or_path: str = field(
default="distil-whisper/distil-large-v2",
metadata={"help": "Used to compute WER during evaluation. Path to pretrained model or model identifier from huggingface.co/models"}
)
clap_model_name_or_path: str = field(
default="laion/larger_clap_music_and_speech",
metadata={"help": "Used to compute audio similarity during evaluation. Path to pretrained model or model identifier from huggingface.co/models"}
)
@dataclass
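Note: the two new fields above make the evaluation models configurable instead of hard-coded (see the metric hunks further down). A minimal sketch of how they are consumed, assuming the declared defaults:

```python
# Hedged sketch: loading the configurable evaluation models.
# Model names mirror the field defaults declared above.
from transformers import AutoModel, AutoProcessor, pipeline

clap = AutoModel.from_pretrained("laion/larger_clap_music_and_speech")
clap_processor = AutoProcessor.from_pretrained("laion/larger_clap_music_and_speech")
asr_pipeline = pipeline(model="distil-whisper/distil-large-v2")
```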
@@ -333,17 +333,17 @@ class DataTrainingArguments:
" librispeech and common voice, set `train_dataset_name='librispeech_asr+common_voice'`."
},
)
target_audio_column_name: str = field( # TODO
target_audio_column_name: str = field(
default="audio",
metadata={"help": "The name of the dataset column containing the target audio data. Defaults to 'audio'"},
)
description_column_name: str = field( # TODO
description_column_name: str = field(
default=None,
metadata={"help": "The name of the dataset column containing the text data. Defaults to 'None'."},
metadata={"help": "The name of the dataset column containing the description text data. Defaults to 'None'."},
)
prompt_column_name: str = field( # TODO
prompt_column_name: str = field(
default=None,
metadata={"help": "The name of the dataset column containing the text data. Defaults to 'None'."},
metadata={"help": "The name of the dataset column containing the prompt text data. Defaults to 'None'."},
)
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
@@ -482,9 +482,9 @@ class ParlerTTSTrainingArguments(Seq2SeqTrainingArguments):
)
},
)
audio_encode_per_device_eval_batch_size: int = field(
audio_encoder_per_device_batch_size: int = field(
default=8,
metadata={"help": ("TODO")},
metadata={"help": ("Specify the batch size of the audio encoding pre-processing steps.")},
)
@@ -521,8 +519,6 @@ class DataCollatorParlerTTSWithPadding:
The prompt_tokenizer used for processing the data.
description_tokenizer (:class:`~transformers.AutoTokenizer`)
The description_tokenizer used for processing the data.
audio_feature_extractor (:class:`~transformers.AutoFeatureExtractor`)
The audio_feature_extractor used for processing the data.
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
among:
@@ -540,8 +538,6 @@ class DataCollatorParlerTTSWithPadding:
prompt_tokenizer: AutoTokenizer
description_tokenizer: AutoTokenizer
audio_feature_extractor: AutoFeatureExtractor
feature_extractor_input_name: Optional[str] = "input_values"
padding: Union[bool, str] = "longest"
pad_to_multiple_of: Optional[int] = None
prompt_max_length: Optional[int] = None
@@ -588,15 +584,6 @@ class DataCollatorParlerTTSWithPadding:
if "attention_mask" in prompt_input_ids:
batch["prompt_attention_mask"] = prompt_input_ids["attention_mask"]
if self.feature_extractor_input_name in features[0]:
# TODO (YL): verify that it works - IMPORTANT -> probably not working
input_values = [
{self.feature_extractor_input_name: feature[self.feature_extractor_input_name]} for feature in features
]
input_values = self.feature_extractor.pad(input_values, return_tensors="pt")
batch[self.feature_extractor_input_name : input_values]
return batch
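Review note: the removed branch was indeed broken, as the deleted TODO suspected. `batch[self.feature_extractor_input_name : input_values]` is a slice expression rather than an assignment, and `self.feature_extractor` does not match the declared field `audio_feature_extractor`. For reference, a hedged sketch of what a working version would have looked like (names follow the deleted fields; this is not a definitive fix):

```python
# Hypothetical corrected version of the removed branch.
if self.feature_extractor_input_name in features[0]:
    input_values = [
        {self.feature_extractor_input_name: feature[self.feature_extractor_input_name]}
        for feature in features
    ]
    padded = self.audio_feature_extractor.pad(input_values, return_tensors="pt")
    # an actual assignment, not a slice expression
    batch[self.feature_extractor_input_name] = padded[self.feature_extractor_input_name]
```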
@@ -1019,7 +1006,6 @@ def main():
)
# 3. Next, let's load the config.
# TODO(YL): add the option to create the config from scratch
config = ParlerTTSConfig.from_pretrained(
model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
@@ -1028,7 +1014,6 @@ def main():
)
# update pad token id and decoder_start_token_id
# TODO(YL): verify if this makes sense, maybe should do it for model.decoder
config.update(
{
"pad_token_id": model_args.pad_token_id
@@ -1040,7 +1025,7 @@ def main():
}
)
# create model + TODO(YL): not from_pretrained probably
# create model
model = ParlerTTSForConditionalGeneration.from_pretrained(
model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
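The removed TODO above asked whether the pad and decoder-start token update should also be applied to `model.decoder`. For illustration only, a hedged sketch of that alternative; this is hypothetical and not what the script does:

```python
# Hypothetical: mirror the token-id update onto the decoder sub-config
# once the model has been created.
model.decoder.config.pad_token_id = config.pad_token_id
model.decoder.config.decoder_start_token_id = config.decoder_start_token_id
```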
@@ -1076,7 +1061,6 @@ def main():
# Freeze Encoders
model.freeze_encoders(model_args.freeze_text_encoder)
# TODO: remove when releasing
# Test all gather - used for warmup and to avoid timeouts
test_tensor = torch.tensor([accelerator.process_index], device=accelerator.device)
gathered_tensor = accelerator.gather(test_tensor)
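The all-gather above acts as a communication warmup: each rank contributes a one-element tensor and receives the concatenation, which forces the collective backend to initialize before the first real training step. A self-contained sketch of the same pattern:

```python
# Minimal sketch of the all-gather warmup, assuming an accelerate setup.
import torch
from accelerate import Accelerator

accelerator = Accelerator()
test_tensor = torch.tensor([accelerator.process_index], device=accelerator.device)
gathered_tensor = accelerator.gather(test_tensor)  # shape: (num_processes,)
accelerator.wait_for_everyone()
```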
@@ -1100,7 +1084,6 @@ def main():
batch = {}
batch["input_ids"] = description_tokenizer(description.strip())["input_ids"]
# TODO: add possibility to train without description column
batch["prompt_input_ids"] = prompt_tokenizer(prompt.strip())["input_ids"]
return batch
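A hedged sketch of how a per-example function like the one above is typically applied with `datasets.map`; the function and column names here are assumptions, not the script's exact call:

```python
# Hypothetical map call applying the tokenization function above.
vectorized_datasets = raw_datasets.map(
    pass_through_processors,  # hypothetical name for the function above
    input_columns=["description", "prompt"],  # assumed column names
    remove_columns=raw_datasets["train"].column_names,
    num_proc=data_args.preprocessing_num_workers,
    desc="preprocess datasets",
)
```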
@@ -1154,7 +1137,7 @@ def main():
for split in vectorized_datasets:
data_loader = DataLoader(
raw_datasets[split],
batch_size=training_args.audio_encode_per_device_eval_batch_size,
batch_size=training_args.audio_encoder_per_device_batch_size,
collate_fn=encoder_data_collator,
num_workers=training_args.dataloader_num_workers,
pin_memory=True,
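This dataloader now reads the renamed `audio_encoder_per_device_batch_size`. A hedged sketch of the pre-encoding loop such a dataloader typically feeds, assuming an EnCodec-style `audio_encoder` whose `encode` method accepts a `bandwidth` argument; variable names are illustrative:

```python
# Hypothetical pre-encoding loop: turn raw waveforms into discrete codes.
import torch

all_codes = []
for enc_batch in data_loader:
    with torch.no_grad():
        encoder_outputs = model.audio_encoder.encode(
            enc_batch["input_values"], bandwidth=model_args.bandwidth
        )
    all_codes.append(encoder_outputs.audio_codes)
```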
@@ -1221,7 +1204,6 @@ def main():
output = {"labels": labels[:, 1:]}
return output
# TODO(YL): done multiple times, how to deal with it.
with accelerator.main_process_first():
vectorized_datasets[split] = vectorized_datasets[split].map(
postprocess_dataset,
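On the label shift earlier in this hunk: `labels[:, 1:]` drops the first position so that labels line up with the decoder's next-token predictions. A toy illustration:

```python
# Toy example of the label shift; 0 stands in for a hypothetical start token.
import torch

labels = torch.tensor([[0, 11, 12, 13]])
shifted = labels[:, 1:]  # tensor([[11, 12, 13]])
```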
@@ -1302,9 +1284,9 @@ def main():
# Let's use CLAP similarity and WER metrics as our evaluation metrics,
# Define evaluation metrics during training, *i.e.* CLAP similarity TODO: allow using another CLAP
clap = AutoModel.from_pretrained("laion/larger_clap_music_and_speech")
clap_processor = AutoProcessor.from_pretrained("laion/larger_clap_music_and_speech")
# Define evaluation metrics during training, *i.e.* CLAP similarity
clap = AutoModel.from_pretrained(model_args.clap_model_name_or_path)
clap_processor = AutoProcessor.from_pretrained(model_args.clap_model_name_or_path)
metric = evaluate.load("wer")
def clap_similarity(texts, audios, device):
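The body of `clap_similarity` is elided between these hunks. A hedged sketch consistent with the `return cosine_sim.mean().to("cpu")` shown below, using the `transformers` CLAP API; exact preprocessing in the script may differ:

```python
# Hypothetical CLAP-similarity metric: cosine similarity between
# text and audio embeddings.
import torch

def clap_similarity(texts, audios, device):
    clap_inputs = clap_processor(
        text=texts, audios=audios, padding=True, return_tensors="pt"
    ).to(device)
    clap.to(device)
    with torch.no_grad():
        text_features = clap.get_text_features(clap_inputs["input_ids"])
        audio_features = clap.get_audio_features(clap_inputs["input_features"])
    clap.to("cpu")
    cosine_sim = torch.nn.functional.cosine_similarity(
        audio_features, text_features, dim=1
    )
    return cosine_sim.mean().to("cpu")
```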
@@ -1323,7 +1305,7 @@ def main():
return cosine_sim.mean().to("cpu")
def wer(prompts, audios, device):
asr_pipeline = pipeline(model="distil-whisper/distil-large-v2", device=device)
asr_pipeline = pipeline(model=model_args.asr_model_name_or_path, device=device)
transcriptions = asr_pipeline(
[{"raw": audio, "sampling_rate": sampling_rate} for audio in audios],
batch_size=int(training_args.per_device_eval_batch_size),
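The rest of `wer` is elided here. A hedged completion using the `evaluate` WER metric loaded above; the lower-casing normalization is an assumption:

```python
# Hypothetical completion of the WER metric: transcribe the generated audio
# with the configurable ASR model and compare against the text prompts.
def wer(prompts, audios, device):
    asr_pipeline = pipeline(model=model_args.asr_model_name_or_path, device=device)
    transcriptions = asr_pipeline(
        [{"raw": audio, "sampling_rate": sampling_rate} for audio in audios],
        batch_size=int(training_args.per_device_eval_batch_size),
    )
    word_error = 100 * metric.compute(
        predictions=[t["text"].lower() for t in transcriptions],
        references=[p.lower() for p in prompts],
    )
    return word_error
```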
@@ -1394,8 +1376,6 @@ def main():
# Instantiate custom data collator
data_collator = DataCollatorParlerTTSWithPadding(
audio_feature_extractor=feature_extractor,
feature_extractor_input_name=feature_extractor_input_name,
prompt_tokenizer=prompt_tokenizer,
description_tokenizer=description_tokenizer,
pad_to_multiple_of=data_args.pad_to_multiple_of,
@@ -1531,7 +1511,6 @@ def main():
outputs = model(**batch)
# CE (data) loss
ce_loss = outputs.loss
# TODO: add CE per codebook
metrics = {"loss": ce_loss}
return ce_loss, metrics
@@ -1578,8 +1557,9 @@ def main():
for epoch in range(epochs_trained, num_epochs):
vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
# TODO(YL): add args
sampler = LengthGroupedSampler(train_batch_size, lengths=vectorized_datasets["train"]["target_length"])
sampler = None
if training_args.group_by_length:
sampler = LengthGroupedSampler(train_batch_size, lengths=vectorized_datasets["train"]["target_length"])
train_dataloader = DataLoader(
vectorized_datasets["train"],
collate_fn=data_collator,
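The new conditional makes length grouping opt-in via `group_by_length`: `LengthGroupedSampler` batches examples of similar target length together, which reduces padding waste. Since the dataset is already shuffled per epoch above, no extra `shuffle=True` is needed when a sampler is passed. A condensed sketch of the pattern:

```python
# Hedged sketch of the opt-in length-grouped sampling shown above.
from torch.utils.data import DataLoader
from transformers.trainer_pt_utils import LengthGroupedSampler

sampler = None
if training_args.group_by_length:
    sampler = LengthGroupedSampler(
        train_batch_size, lengths=vectorized_datasets["train"]["target_length"]
    )
train_dataloader = DataLoader(
    vectorized_datasets["train"],
    collate_fn=data_collator,
    batch_size=train_batch_size,  # assumed variable name
    sampler=sampler,
)
```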
@@ -1631,7 +1611,7 @@ def main():
# save checkpoint and weights after each save_steps and at the end of training
if (cur_step % training_args.save_steps == 0) or cur_step == total_train_steps:
intermediate_dir = os.path.join(training_args.output_dir, f"checkpoint-{cur_step}-epoch-{epoch}")
# safe_serialization=False to avoid shared tensors saving issue (TODO: it's a temporary fix)
# safe_serialization=False to avoid shared tensors saving issue (TODO(YL): it's a temporary fix)
# https://github.com/huggingface/transformers/issues/27293#issuecomment-1872560074
accelerator.save_state(output_dir=intermediate_dir, safe_serialization=False)
accelerator.wait_for_everyone()
@@ -1701,8 +1681,6 @@ def main():
):
generated_audios = generate_step(batch)
# Gather all predictions and targets
# TODO: also add prompt ids
# TODO: better gather
generated_audios, input_ids, prompts = accelerator.pad_across_processes(
(generated_audios, batch["input_ids"], batch["prompt_input_ids"]), dim=1, pad_index=0
)
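The pad-then-gather pattern above keeps ranks in sync: each process right-pads its tensors to a common length along `dim=1` before the collective, since gathering requires identical shapes across processes. A minimal sketch:

```python
# Hedged sketch of the cross-process gather, assuming an accelerate
# `accelerator` and a variable-length 2D tensor per rank.
generated_audios = accelerator.pad_across_processes(
    generated_audios, dim=1, pad_index=0
)
generated_audios = accelerator.gather(generated_audios)
```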