Commit eaf7947b authored by Dan Lyth
Browse files

small train.py updates

parent 3170ac02
......@@ -24,8 +24,6 @@ import time
from multiprocess import set_start_method
from datetime import timedelta
import evaluate
from tqdm import tqdm
from pathlib import Path
......@@ -33,22 +31,18 @@ import datasets
import torch
from torch.utils.data import DataLoader
from datasets import DatasetDict, Dataset, IterableDataset, concatenate_datasets
from datasets import IterableDataset
from huggingface_hub import Repository, create_repo
import transformers
from transformers import (
AutoFeatureExtractor,
AutoModel,
AutoProcessor,
AutoTokenizer,
HfArgumentParser
HfArgumentParser,
)
from transformers.trainer_pt_utils import LengthGroupedSampler
from transformers import pipeline
from transformers.optimization import get_scheduler
from transformers.utils import send_example_telemetry
from transformers import AutoModel
from accelerate import Accelerator
......@@ -57,13 +51,13 @@ from accelerate.utils.memory import release_memory
from parler_tts import (
ParlerTTSForConditionalGeneration,
ParlerTTSConfig,
build_delay_pattern_mask,
ParlerTTSConfig
)
from parler_tts.utils import get_last_checkpoint, rotate_checkpoints, log_pred, log_metric
from parler_tts.arguments import ModelArguments, DataTrainingArguments, ParlerTTSTrainingArguments
from parler_tts.data import DataCollatorParlerTTSWithPadding, DataCollatorEncodecWithPadding
from parler_tts.data import DataCollatorParlerTTSWithPadding
from parler_tts.eval import clap_similarity, wer
logger = logging.getLogger(__name__)
......@@ -271,47 +265,22 @@ def main():
# Let's use CLAP similarity and WER (word error rate) as our evaluation metrics  # TODO: move this to a separate file
# Define evaluation metrics during training, *i.e.* CLAP similarity
clap = AutoModel.from_pretrained(model_args.clap_model_name_or_path)
clap_processor = AutoProcessor.from_pretrained(model_args.clap_model_name_or_path)
metric = evaluate.load("wer")
def clap_similarity(texts, audios, device):
    """Return the mean CLAP cosine similarity between audios and their texts.

    Relies on the enclosing scope's ``clap`` model and ``clap_processor``.
    The model and the processed inputs are moved to ``device`` for the
    forward pass and sent back to the CPU afterwards.

    Args:
        texts: Text inputs handed to ``clap_processor`` as ``text=``.
        audios: Audio inputs handed to ``clap_processor`` as ``audios=``.
        device: Device on which the CLAP forward pass runs.

    Returns:
        The cosine similarity between audio and text features, averaged
        over all elements and placed on the CPU.
    """
    batch = clap_processor(text=texts, audios=audios, padding=True, return_tensors="pt")
    batch = batch.to(device)
    clap.to(device)

    # Pure inference: no gradients are needed for the similarity score.
    with torch.no_grad():
        text_embeds = clap.get_text_features(
            batch["input_ids"],
            attention_mask=batch.get("attention_mask", None),
        )
        audio_embeds = clap.get_audio_features(batch["input_features"])
        similarity = torch.nn.functional.cosine_similarity(
            audio_embeds, text_embeds, dim=1, eps=1e-8
        )

    # Send the model and its inputs back to the CPU once scoring is done.
    clap.to("cpu")
    batch.to("cpu")
    return similarity.mean().to("cpu")
def wer(prompts, audios, device):
    """Transcribe ``audios`` and score the transcriptions against ``prompts``.

    A fresh ASR pipeline is built on every call from
    ``model_args.asr_model_name_or_path``. ``sampling_rate``, ``metric`` and
    ``training_args`` are read from the enclosing scope.

    Args:
        prompts: Reference strings; lower-cased before scoring.
        audios: Raw audio inputs, one per prompt, fed to the ASR pipeline.
        device: Device passed to the ASR pipeline.

    Returns:
        A ``(word_error, transcriptions)`` tuple: the value of
        ``metric.compute`` multiplied by 100, and the transcribed texts in
        their original casing.
    """
    transcriber = pipeline(model=model_args.asr_model_name_or_path, device=device)
    asr_inputs = [{"raw": waveform, "sampling_rate": sampling_rate} for waveform in audios]
    outputs = transcriber(
        asr_inputs,
        batch_size=int(training_args.per_device_eval_batch_size),
    )

    hypotheses = [output["text"] for output in outputs]
    # Both sides are lower-cased so casing differences do not count as errors.
    score = metric.compute(
        predictions=[hypothesis.lower() for hypothesis in hypotheses],
        references=[prompt.lower() for prompt in prompts],
    )
    return 100 * score, hypotheses
eval_methods = {"clap": clap_similarity, "wer": wer}
# Decode the tokenized descriptions and prompts back to text, convert the
# generated audios to NumPy arrays, and score them with CLAP similarity and WER.
#
# Returns a 5-tuple: (results, texts, prompts, audios, transcriptions), where
# `results` holds the "clap" and "wer" entries.
#
# NOTE(review): this span comes from a rendered diff whose +/- markers and
# indentation were stripped, so it appears to interleave two versions of the
# scoring calls: one going through `eval_methods[...]` with
# (texts/prompts, audios, device), and one calling `clap_similarity` / `wer`
# directly with the model name (and, for `wer`, batch size and sampling rate)
# as extra arguments. Only one of the two call styles can match the helpers
# actually in scope -- confirm which side is intended before relying on this.
def compute_metrics(audios, descriptions, prompts, device="cpu"):
results = {}
input_ids = descriptions
# Turn token ids back into plain strings, dropping special tokens.
texts = description_tokenizer.batch_decode(input_ids, skip_special_tokens=True)
prompts = prompt_tokenizer.batch_decode(prompts, skip_special_tokens=True)
# Move each audio to the CPU and convert it to a NumPy array for the metric helpers.
audios = [a.cpu().numpy() for a in audios]
# Scoring through the `eval_methods` dispatch table (presumably the pre-change lines).
results = {"clap": eval_methods["clap"](texts, audios, device)}
word_error, transcriptions = eval_methods["wer"](prompts, audios, device)
# Scoring through direct helper calls that take the model name explicitly
# (presumably the post-change lines).
clap_score = clap_similarity(model_args.clap_model_name_or_path, texts, audios, device)
results["clap"] = clap_score
word_error, transcriptions = wer(model_args.asr_model_name_or_path,
prompts,
audios,
device,
training_args.per_device_eval_batch_size,
sampling_rate)
results["wer"] = word_error
return results, texts, prompts, audios, transcriptions
......@@ -564,7 +533,6 @@ def main():
resume_step = None
for batch in train_dataloader:
breakpoint()
with accelerator.accumulate(model):
loss, train_metric = train_step(batch, accelerator, autocast_kwargs)
accelerator.backward(loss)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment