Commit 3170ac02 authored by Dan Lyth

adding eval.py and simple train.py, re-instating run_parler_tts_training.py

parent 09df5026
eval.py (new file)

import torch
import evaluate
from transformers import AutoModel, AutoProcessor, pipeline


def clap_similarity(clap_model_name_or_path, texts, audios, device):
    # Load the CLAP model and processor used to embed the text descriptions and the generated audio.
    clap = AutoModel.from_pretrained(clap_model_name_or_path)
    clap_processor = AutoProcessor.from_pretrained(clap_model_name_or_path)
    clap_inputs = clap_processor(text=texts, audios=audios, padding=True, return_tensors="pt").to(device)
    clap.to(device)
    with torch.no_grad():
        text_features = clap.get_text_features(
            clap_inputs["input_ids"], attention_mask=clap_inputs.get("attention_mask", None)
        )
        audio_features = clap.get_audio_features(clap_inputs["input_features"])

    # Cosine similarity between the audio and text embeddings, averaged over the batch.
    cosine_sim = torch.nn.functional.cosine_similarity(audio_features, text_features, dim=1, eps=1e-8)

    # Move the model and inputs back to CPU to free accelerator memory.
    clap.to("cpu")
    clap_inputs.to("cpu")
    return cosine_sim.mean().to("cpu")


def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_size, sampling_rate):
    # Transcribe the generated audio with an ASR pipeline and score it against the prompts.
    metric = evaluate.load("wer")
    asr_pipeline = pipeline(model=asr_model_name_or_path, device=device)
    transcriptions = asr_pipeline(
        [{"raw": audio, "sampling_rate": sampling_rate} for audio in audios],
        batch_size=int(per_device_eval_batch_size),
    )

    word_error = 100 * metric.compute(
        predictions=[t["text"].lower() for t in transcriptions], references=[t.lower() for t in prompts]
    )

    return word_error, [t["text"] for t in transcriptions]
\ No newline at end of file
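For orientation, here is a minimal sketch of how these two helpers might be driven. The checkpoint names (laion/clap-htsat-unfused, distil-whisper/distil-large-v2) and the dummy waveforms are illustrative assumptions, not part of this commit:

# Minimal usage sketch; model names and dummy audio are assumptions for illustration only.
import numpy as np
from eval import clap_similarity, wer

texts = ["a woman speaking slowly", "a man speaking quickly"]
sampling_rate = 16_000
# Two seconds of random noise standing in for generated speech.
audios = [np.random.randn(2 * sampling_rate).astype(np.float32) for _ in texts]

# CLAP similarity between the text descriptions and the audio (assumed checkpoint).
sim = clap_similarity("laion/clap-htsat-unfused", texts, audios, device="cpu")

# WER of an ASR transcription against the original prompts (assumed checkpoint).
word_error, transcriptions = wer(
    "distil-whisper/distil-large-v2",
    prompts=texts,
    audios=audios,
    device="cpu",
    per_device_eval_batch_size=2,
    sampling_rate=sampling_rate,
)
print(f"CLAP similarity: {sim.item():.3f}, WER: {word_error:.1f}%")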
@@ -63,7 +63,7 @@ from parler_tts import (
 from parler_tts.utils import get_last_checkpoint, rotate_checkpoints, log_pred, log_metric
 from parler_tts.arguments import ModelArguments, DataTrainingArguments, ParlerTTSTrainingArguments
-from parler_tts.data import load_multiple_datasets, DataCollatorParlerTTSWithPadding, DataCollatorEncodecWithPadding
+from parler_tts.data import DataCollatorParlerTTSWithPadding, DataCollatorEncodecWithPadding
 logger = logging.getLogger(__name__)
@@ -104,7 +104,7 @@ def main():
     padding = "max_length" if data_args.pad_to_max_length else "longest"
-    ####### A. Preparation
+    # Accelerator preparation
     kwargs_handlers = [InitProcessGroupKwargs(timeout=timedelta(minutes=60))]
     if training_args.torch_compile:
         # TODO(YL): add more compile modes?
@@ -182,7 +182,7 @@ def main():
     set_seed(training_args.seed)
     num_workers = data_args.preprocessing_num_workers
-    # 1. First, lett's instantiate the feature extractor, tokenizers and model
+    # 1. First, let's instantiate the feature extractor (DAC), tokenizers and model
     # Note for distributed training, the .from_pretrained methods guarantee that only
     # one local process can concurrently download model & vocab.
@@ -222,79 +222,7 @@ def main():
     description_tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
     # 2. Now, let's load the dataset
-    # TODO add MDS dataset loading here
-    if data_args.save_to_disk is not None:
-        os.makedirs(data_args.save_to_disk, exist_ok=True)
-    # assume that the dataset has been saved to `save_to_disk` if the latter is not empty
-    dataset_was_precomputed = len(os.listdir(data_args.save_to_disk)) > 0
-    if dataset_was_precomputed:
-        vectorized_datasets = datasets.load_from_disk(data_args.save_to_disk)
-    else:
-        raw_datasets = DatasetDict()
-        columns_to_keep = {
-            "target_audio_column_name": data_args.target_audio_column_name,
-            "prompt_column_name": data_args.prompt_column_name,
-        }
-        if data_args.description_column_name is not None:
-            columns_to_keep["description_column_name"] = data_args.description_column_name
-        if training_args.do_train:
-            raw_datasets["train"] = load_multiple_datasets(
-                accelerator,
-                data_args.train_dataset_name,
-                data_args.train_dataset_config_name,
-                metadata_dataset_names=data_args.train_metadata_dataset_name,
-                splits=data_args.train_split_name,
-                dataset_samples=data_args.train_dataset_samples,
-                seed=training_args.seed,
-                cache_dir=model_args.cache_dir,
-                num_proc=data_args.preprocessing_num_workers,
-                id_column_name=data_args.id_column_name,
-                columns_to_keep=columns_to_keep.values(),
-                prompt_column_name=data_args.prompt_column_name,
-                audio_column_name=data_args.target_audio_column_name,
-                sampling_rate=sampling_rate,
-                logger=logger,
-                # streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
-            )
-            for key in columns_to_keep:
-                if columns_to_keep[key] not in raw_datasets["train"].column_names:
-                    raise ValueError(
-                        f"--{key} '{columns_to_keep[key]}' not found in dataset '{data_args.train_dataset_name}'."
-                        f" Make sure to set `--{key}` to the correct audio column - one of"
-                        f" {', '.join(raw_datasets['train'].column_names)}."
-                    )
-            if data_args.max_train_samples is not None:
-                raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
-        if training_args.do_eval:
-            raw_datasets["eval"] = load_multiple_datasets(
-                accelerator,
-                data_args.eval_dataset_name if data_args.eval_dataset_name else data_args.train_dataset_name,
-                data_args.eval_dataset_config_name
-                if data_args.eval_dataset_config_name
-                else data_args.train_dataset_config_name,
-                metadata_dataset_names=data_args.eval_metadata_dataset_name,
-                splits=data_args.eval_split_name,
-                cache_dir=model_args.cache_dir,
-                num_proc=data_args.preprocessing_num_workers,
-                id_column_name=data_args.id_column_name,
-                columns_to_keep=columns_to_keep.values(),
-                prompt_column_name=data_args.prompt_column_name,
-                audio_column_name=data_args.target_audio_column_name,
-                sampling_rate=sampling_rate,
-                logger=logger,
-                # streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
-            )
-            if data_args.max_eval_samples is not None:
-                raw_datasets["eval"] = (
-                    raw_datasets["eval"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
-                )
     # 3. Next, let's load the config.
     config = ParlerTTSConfig.from_pretrained(
@@ -330,250 +258,17 @@ def main():
     model.gradient_checkpointing_enable()
     # 4. Now we preprocess the datasets including loading the audio, resampling and normalization
-    # Thankfully, `datasets` takes care of automatically loading and resampling the audio,
-    # so that we just need to set the correct target sampling rate and normalize the input
-    # via the `feature_extractor`
+    # TODO add MDS dataset preprocessing here (only thing we'll need is the delay pattern)
     # derive max & min input length for sample rate & max duration
     sampling_rate = feature_extractor.sampling_rate
-    max_target_length = data_args.max_duration_in_seconds * sampling_rate
-    min_target_length = data_args.min_duration_in_seconds * sampling_rate
-    target_audio_column_name = data_args.target_audio_column_name
-    description_column_name = data_args.description_column_name
-    prompt_column_name = data_args.prompt_column_name
-    feature_extractor_input_name = feature_extractor.model_input_names[0]
-    audio_encoder_pad_token_id = config.decoder.pad_token_id
-    audio_encoder_eos_token_id = config.decoder.eos_token_id
-    audio_encoder_bos_token_id = model.generation_config.decoder_start_token_id
-    max_length = model.generation_config.max_length
-    num_codebooks = model.decoder.config.num_codebooks
-    bandwidth = model_args.bandwidth
     # Freeze Encoders
-    model.freeze_encoders(model_args.freeze_text_encoder)
+    model.freeze_encoders(model_args.freeze_text_encoder)  # TODO check this implementation
-    # Test all gather - used for warmout and avoiding timeout
-    test_tensor = torch.tensor([accelerator.process_index], device=accelerator.device)
-    gathered_tensor = accelerator.gather(test_tensor)
-    print("gathered_tensor", gathered_tensor)
-    accelerator.wait_for_everyone()
-    if not dataset_was_precomputed:
-        # Filter on text length
-        if description_column_name is not None and data_args.max_text_length is not None:
-            with accelerator.main_process_first():
-                # filter description that is shorter than max_text_length
-                raw_datasets = raw_datasets.filter(
-                    lambda x: len(x) < data_args.max_text_length,
-                    num_proc=num_workers,
-                    input_columns=[description_column_name],
-                )
-        # Preprocessing the dataset.
-        # We need to tokenize the texts.
-        def pass_through_processors(description, prompt):
-            batch = {}
-            batch["input_ids"] = description_tokenizer(description.strip())["input_ids"]
-            batch["prompt_input_ids"] = prompt_tokenizer(prompt.strip())["input_ids"]
-            return batch
-        with accelerator.main_process_first():
-            # this is a trick to avoid to rewrite the entire audio column which takes ages
-            vectorized_datasets = raw_datasets.map(
-                pass_through_processors,
-                remove_columns=next(iter(raw_datasets.values())).column_names,
-                input_columns=[description_column_name, prompt_column_name],
-                num_proc=num_workers,
-                desc="preprocess datasets",
-            )
-        # We use Accelerate to perform distributed inference
-        # T5 doesn't support fp16
-        autocast_kwargs = AutocastKwargs(enabled=(mixed_precision != "fp16"))
-        # Now we encode the audio labels with encodec.
-        ####### B. Encode audio
-        logger.info("*** Encode target audio with encodec ***")
-        # no need to prepare audio_decoder because used for inference without mixed precision
-        # see: https://huggingface.co/docs/accelerate/main/en/package_reference/accelerator#accelerate.Accelerator.prepare
-        if training_args.torch_compile:
-            audio_decoder = accelerator.prepare_model(model.audio_encoder, evaluation_mode=True)
-        else:
-            audio_decoder = model.audio_encoder
-        encoder_data_collator = DataCollatorEncodecWithPadding(
-            feature_extractor,
-            audio_column_name=target_audio_column_name,
-            feature_extractor_input_name=feature_extractor_input_name,
-            max_length=max_target_length,
-            padding=padding,
-        )
-        def apply_audio_decoder(batch):
-            len_audio = batch.pop("len_audio")
-            audio_decoder.to(batch["input_values"].device).eval()
-            with torch.no_grad():
-                labels = audio_decoder.encode(**batch, bandwidth=bandwidth)["audio_codes"]
-            output = {}
-            output["len_audio"] = len_audio
-            # (1, bsz, codebooks, seq_len) -> (bsz, seq_len, codebooks)
-            output["labels"] = labels.squeeze(0).transpose(1, 2)
-            output["ratio"] = torch.ones_like(len_audio) * labels.shape[-1] / len_audio.max()
-            return output
-        for split in vectorized_datasets:
-            data_loader = DataLoader(
-                raw_datasets[split],
-                batch_size=training_args.audio_encoder_per_device_batch_size,
-                collate_fn=encoder_data_collator,
-                num_workers=training_args.dataloader_num_workers,
-                pin_memory=True,
-            )
-            data_loader = accelerator.prepare(data_loader)
-            all_generated_labels = []
-            all_lens = []
-            for batch in tqdm(data_loader, disable=not accelerator.is_local_main_process):
-                generate_labels = apply_audio_decoder(batch)
-                generate_labels = accelerator.pad_across_processes(generate_labels, dim=1, pad_index=0)
-                generate_labels = accelerator.gather_for_metrics(generate_labels)
-                if accelerator.is_main_process:
-                    lab = generate_labels["labels"].cpu().transpose(1, 2).to(torch.int16)
-                    rat = generate_labels["ratio"].cpu().squeeze()
-                    lens = generate_labels["len_audio"].cpu().squeeze()
-                    lab = [l[:, : int(ratio * length)] for (l, ratio, length) in zip(lab, rat, lens)]
-                    all_generated_labels.extend(lab)
-                    all_lens.extend(lens)
-            # (1, codebooks, seq_len) where seq_len=1
-            bos_labels = torch.ones((1, num_codebooks, 1)) * audio_encoder_bos_token_id
-            if accelerator.is_main_process:
-                tmp_labels = Dataset.from_dict({"labels": all_generated_labels, "target_length": all_lens})
-                tmp_labels.save_to_disk(
-                    os.path.join(data_args.temporary_save_to_disk, split),
-                    num_proc=1 if split == "eval" else data_args.preprocessing_num_workers,
-                )
-            accelerator.wait_for_everyone()
-            del all_generated_labels
-            tmp_labels = datasets.load_from_disk(os.path.join(data_args.temporary_save_to_disk, split))
-            with accelerator.main_process_first():
-                vectorized_datasets[split] = concatenate_datasets([vectorized_datasets[split], tmp_labels], axis=1)
-            def postprocess_dataset(labels):
-                # (1, codebooks, seq_len)
-                labels = torch.tensor(labels).unsqueeze(0)
-                # add bos
-                labels = torch.cat([bos_labels, labels], dim=-1)
-                labels, delay_pattern_mask = build_delay_pattern_mask(
-                    labels,
-                    bos_token_id=audio_encoder_bos_token_id,
-                    pad_token_id=audio_encoder_eos_token_id,
-                    max_length=labels.shape[-1] + num_codebooks,
-                    num_codebooks=num_codebooks,
-                )
-                # the first ids of the delay pattern mask are precisely labels, we use the rest of the labels mask
-                # to take care of EOS
-                # we want labels to look like this:
-                #  - [B, a, b, E, E, E, E]
-                #  - [B, B, c, d, E, E, E]
-                #  - [B, B, B, e, f, E, E]
-                #  - [B, B, B, B, g, h, E]
-                labels = torch.where(delay_pattern_mask == -1, audio_encoder_eos_token_id, delay_pattern_mask)
-                # the first timestamp is associated to a row full of BOS, let's get rid of it
-                # we also remove the last timestampts (full of PAD)
-                output = {"labels": labels[:, 1:]}
-                return output
-            with accelerator.main_process_first():
-                vectorized_datasets[split] = vectorized_datasets[split].map(
-                    postprocess_dataset,
-                    num_proc=data_args.preprocessing_num_workers,  # this one is resource consuming if many processor.
-                    input_columns=["labels"],
-                    desc="Postprocessing labeling",
-                )
-        accelerator.free_memory()
-        del generate_labels, all_lens
-        with accelerator.main_process_first():
-            # NOTE: filtering is done at the end because in the `datasets` library, caching audio files is done after most operations
-            # caching audio files is time and disk-space consuming, so we want to avoid it at all costs, especially for large (>1Kh) audio datasets.
-            # That's also why we avoid to concat the processed datasets (vectorized_datasets) with the audio column present in raw_datasets.
-            def is_audio_in_length_range(length):
-                return length > min_target_length and length < max_target_length
-            # filter data that is shorter than min_target_length
-            vectorized_datasets = vectorized_datasets.filter(
-                is_audio_in_length_range,
-                num_proc=num_workers,
-                input_columns=["target_length"],
-            )
-        if description_column_name is not None and data_args.max_description_token_length is not None:
-            with accelerator.main_process_first():
-                # filter description that is shorter than max_text_length
-                vectorized_datasets = vectorized_datasets.filter(
-                    lambda x: len(x) < data_args.max_description_token_length,
-                    num_proc=num_workers,
-                    input_columns=["input_ids"],
-                )
-        if data_args.max_prompt_token_length is not None:
-            with accelerator.main_process_first():
-                # filter description that is shorter than max_text_length
-                vectorized_datasets = vectorized_datasets.filter(
-                    lambda x: len(x) < data_args.max_prompt_token_length,
-                    num_proc=num_workers,
-                    input_columns=["prompt_input_ids"],
-                )
-    if data_args.save_to_disk is not None and not dataset_was_precomputed:
-        if accelerator.is_main_process:
-            vectorized_datasets.save_to_disk(
-                data_args.save_to_disk,
-                num_proc=min(data_args.preprocessing_num_workers, len(vectorized_datasets["eval"]) - 1),
-            )
-        logger.info(f"Dataset saved at {data_args.save_to_disk}")
-    audio_max_length = None
-    if training_args.torch_compile:
-        audio_max_length = max(vectorized_datasets["train"]["target_length"])
-        with accelerator.main_process_first():
-            max_sample = vectorized_datasets["train"].filter(
-                lambda x: x == audio_max_length,
-                num_proc=num_workers,
-                input_columns=["target_length"],
-            )
-        audio_max_length = torch.tensor(max_sample[0]["labels"]).shape[1]
-    # for large datasets it is advised to run the preprocessing on a
-    # single machine first with ``args.preprocessing_only`` since there will mostly likely
-    # be a timeout when running the script in distributed mode.
-    # In a second step ``args.preprocessing_only`` can then be set to `False` to load the
-    # cached dataset
-    if data_args.preprocessing_only and data_args.save_to_disk is None:
-        raise ValueError(
-            "`preprocessing_only=True` but `save_to_disk` is not set. The latter should indicates where to save the dataset locally."
-        )
-    elif data_args.preprocessing_only:
-        logger.info(f"Data preprocessing finished. Files save at {data_args.save_to_disk}")
-        return
     # 6. Next, we can prepare the training.
-    # Let's use word CLAP similary and WER metrics as our evaluation metrics,
+    # Let's use word CLAP similary and WER metrics as our evaluation metrics  # TODO move this to seperate file
     # Define evaluation metrics during training, *i.e.* CLAP similarity
     clap = AutoModel.from_pretrained(model_args.clap_model_name_or_path)
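The bulk of the code removed in the hunk above is the audio-label preprocessing, and its one subtle step is the delay pattern: codebook row k is shifted right by k positions, with BOS filling the leading slots and EOS the trailing ones, exactly as drawn in the removed [B, a, b, E, ...] comments. Here is a hand-rolled toy reconstruction of that pattern, independent of the parler_tts helpers and with made-up token ids:

# Toy illustration of the delay pattern described in the removed comments above.
# B = BOS, E = EOS; the token ids are made up and not tied to any checkpoint.
import torch

num_codebooks, seq_len = 4, 2
bos, eos = 0, 1
# Codes for one sample: row k holds codebook k, i.e. [[a, b], [c, d], [e, f], [g, h]].
codes = torch.arange(2, 2 + num_codebooks * seq_len).reshape(num_codebooks, seq_len)

total_len = seq_len + num_codebooks + 1  # room for the staggered BOS/EOS positions
delayed = torch.full((num_codebooks, total_len), eos)
for k in range(num_codebooks):
    delayed[k, : k + 1] = bos                        # k+1 leading BOS tokens on row k
    delayed[k, k + 1 : k + 1 + seq_len] = codes[k]   # codebook k shifted right by k

print(delayed)
# tensor([[0, 2, 3, 1, 1, 1, 1],
#         [0, 0, 4, 5, 1, 1, 1],
#         [0, 0, 0, 6, 7, 1, 1],
#         [0, 0, 0, 0, 8, 9, 1]])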
@@ -630,7 +325,7 @@ def main():
     if training_args.max_steps < 0:
         num_epochs = int(training_args.num_train_epochs)
-        steps_per_epoch = len(vectorized_datasets["train"]) // (train_batch_size * gradient_accumulation_steps)
+        steps_per_epoch = len(vectorized_datasets["train"]) // (train_batch_size * gradient_accumulation_steps)  # TODO fix this missing variable
         total_train_steps = steps_per_epoch * num_epochs
     elif training_args.max_steps > 0:
         logger.info("max_steps is given, it will override any value given in num_train_epochs")
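As a quick sanity check of the schedule arithmetic in the hunk above (all numbers hypothetical):

# Hypothetical numbers, only to illustrate the steps_per_epoch arithmetic above.
train_examples = 10_000
train_batch_size = 8             # effective batch per optimizer step, across devices
gradient_accumulation_steps = 4
steps_per_epoch = train_examples // (train_batch_size * gradient_accumulation_steps)  # 312
total_train_steps = steps_per_epoch * 3  # with num_train_epochs = 3 -> 936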
@@ -673,7 +368,7 @@ def main():
         padding=padding,
         prompt_max_length=data_args.max_prompt_token_length,
         description_max_length=data_args.max_description_token_length,
-        audio_max_length=audio_max_length,
+        audio_max_length=audio_max_length,  # TODO add this variable
     )
     # Prepare everything with accelerate
@@ -869,6 +564,7 @@ def main():
             resume_step = None
         for batch in train_dataloader:
+            breakpoint()
             with accelerator.accumulate(model):
                 loss, train_metric = train_step(batch, accelerator, autocast_kwargs)
                 accelerator.backward(loss)