Commit eaf7947b authored by Dan Lyth's avatar Dan Lyth
Browse files

small train.py updates

parent 3170ac02
...@@ -24,8 +24,6 @@ import time ...@@ -24,8 +24,6 @@ import time
from multiprocess import set_start_method from multiprocess import set_start_method
from datetime import timedelta from datetime import timedelta
import evaluate
from tqdm import tqdm from tqdm import tqdm
from pathlib import Path from pathlib import Path
...@@ -33,22 +31,18 @@ import datasets ...@@ -33,22 +31,18 @@ import datasets
import torch import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from datasets import DatasetDict, Dataset, IterableDataset, concatenate_datasets from datasets import IterableDataset
from huggingface_hub import Repository, create_repo from huggingface_hub import Repository, create_repo
import transformers import transformers
from transformers import ( from transformers import (
AutoFeatureExtractor, AutoFeatureExtractor,
AutoModel,
AutoProcessor,
AutoTokenizer, AutoTokenizer,
HfArgumentParser HfArgumentParser,
) )
from transformers.trainer_pt_utils import LengthGroupedSampler from transformers.trainer_pt_utils import LengthGroupedSampler
from transformers import pipeline
from transformers.optimization import get_scheduler from transformers.optimization import get_scheduler
from transformers.utils import send_example_telemetry from transformers.utils import send_example_telemetry
from transformers import AutoModel
from accelerate import Accelerator from accelerate import Accelerator
...@@ -57,13 +51,13 @@ from accelerate.utils.memory import release_memory ...@@ -57,13 +51,13 @@ from accelerate.utils.memory import release_memory
from parler_tts import ( from parler_tts import (
ParlerTTSForConditionalGeneration, ParlerTTSForConditionalGeneration,
ParlerTTSConfig, ParlerTTSConfig
build_delay_pattern_mask,
) )
from parler_tts.utils import get_last_checkpoint, rotate_checkpoints, log_pred, log_metric from parler_tts.utils import get_last_checkpoint, rotate_checkpoints, log_pred, log_metric
from parler_tts.arguments import ModelArguments, DataTrainingArguments, ParlerTTSTrainingArguments from parler_tts.arguments import ModelArguments, DataTrainingArguments, ParlerTTSTrainingArguments
from parler_tts.data import DataCollatorParlerTTSWithPadding, DataCollatorEncodecWithPadding from parler_tts.data import DataCollatorParlerTTSWithPadding
from parler_tts.eval import clap_similarity, wer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -271,47 +265,22 @@ def main(): ...@@ -271,47 +265,22 @@ def main():
# Let's use word CLAP similarity and WER metrics as our evaluation metrics # TODO move this to separate file # Let's use word CLAP similarity and WER metrics as our evaluation metrics # TODO move this to separate file
# Define evaluation metrics during training, *i.e.* CLAP similarity # Define evaluation metrics during training, *i.e.* CLAP similarity
clap = AutoModel.from_pretrained(model_args.clap_model_name_or_path)
clap_processor = AutoProcessor.from_pretrained(model_args.clap_model_name_or_path)
metric = evaluate.load("wer")
def clap_similarity(texts, audios, device):
clap_inputs = clap_processor(text=texts, audios=audios, padding=True, return_tensors="pt").to(device)
clap.to(device)
with torch.no_grad():
text_features = clap.get_text_features(
clap_inputs["input_ids"], attention_mask=clap_inputs.get("attention_mask", None)
)
audio_features = clap.get_audio_features(clap_inputs["input_features"])
cosine_sim = torch.nn.functional.cosine_similarity(audio_features, text_features, dim=1, eps=1e-8)
clap.to("cpu")
clap_inputs.to("cpu")
return cosine_sim.mean().to("cpu")
def wer(prompts, audios, device):
asr_pipeline = pipeline(model=model_args.asr_model_name_or_path, device=device)
transcriptions = asr_pipeline(
[{"raw": audio, "sampling_rate": sampling_rate} for audio in audios],
batch_size=int(training_args.per_device_eval_batch_size),
)
word_error = 100 * metric.compute(
predictions=[t["text"].lower() for t in transcriptions], references=[t.lower() for t in prompts]
)
return word_error, [t["text"] for t in transcriptions]
eval_methods = {"clap": clap_similarity, "wer": wer}
def compute_metrics(audios, descriptions, prompts, device="cpu"): def compute_metrics(audios, descriptions, prompts, device="cpu"):
results = {}
input_ids = descriptions input_ids = descriptions
texts = description_tokenizer.batch_decode(input_ids, skip_special_tokens=True) texts = description_tokenizer.batch_decode(input_ids, skip_special_tokens=True)
prompts = prompt_tokenizer.batch_decode(prompts, skip_special_tokens=True) prompts = prompt_tokenizer.batch_decode(prompts, skip_special_tokens=True)
audios = [a.cpu().numpy() for a in audios] audios = [a.cpu().numpy() for a in audios]
results = {"clap": eval_methods["clap"](texts, audios, device)}
word_error, transcriptions = eval_methods["wer"](prompts, audios, device) clap_score = clap_similarity(model_args.clap_model_name_or_path, texts, audios, device)
results["clap"] = clap_score
word_error, transcriptions = wer(model_args.asr_model_name_or_path,
prompts,
audios,
device,
training_args.per_device_eval_batch_size,
sampling_rate)
results["wer"] = word_error results["wer"] = word_error
return results, texts, prompts, audios, transcriptions return results, texts, prompts, audios, transcriptions
...@@ -564,7 +533,6 @@ def main(): ...@@ -564,7 +533,6 @@ def main():
resume_step = None resume_step = None
for batch in train_dataloader: for batch in train_dataloader:
breakpoint()
with accelerator.accumulate(model): with accelerator.accumulate(model):
loss, train_metric = train_step(batch, accelerator, autocast_kwargs) loss, train_metric = train_step(batch, accelerator, autocast_kwargs)
accelerator.backward(loss) accelerator.backward(loss)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment