"...composable_kernel.git" did not exist on "a2edd7d802b46737e886f0f42a4ee61af03243b7"
Unverified commit 5518cc2f, authored by Yoach Lacombe and committed by GitHub

Merge pull request #52 from sanchit-gandhi/wer-norm

[training] compute normalised wer
parents c2b90bdc ed484586
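
In short: instead of simply lowercasing predictions and references before scoring, the WER metric now runs both through Whisper's text normalisers. When the ASR model is a Whisper checkpoint, the pipeline is asked to return the detected language, and English transcriptions go through the English normaliser; everything else uses the basic multilingual normaliser.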
 import torch
 import evaluate
-from transformers import AutoModel, AutoProcessor, pipeline
+from transformers import AutoModel, AutoProcessor, pipeline, WhisperForConditionalGeneration, WhisperTokenizer, WhisperTokenizerFast
 
 
 def clap_similarity(clap_model_name_or_path, texts, audios, device):
@@ -24,13 +24,36 @@ def clap_similarity(clap_model_name_or_path, texts, audios, device):
 def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_size, sampling_rate):
     metric = evaluate.load("wer")
     asr_pipeline = pipeline(model=asr_model_name_or_path, device=device)
+
+    return_language = None
+    if isinstance(asr_pipeline.model, WhisperForConditionalGeneration):
+        return_language = True
+
     transcriptions = asr_pipeline(
         [{"raw": audio, "sampling_rate": sampling_rate} for audio in audios],
         batch_size=int(per_device_eval_batch_size),
+        return_language=return_language,
     )
-    word_error = 100 * metric.compute(
-        predictions=[t["text"].lower() for t in transcriptions], references=[t.lower() for t in prompts]
-    )
+
+    if isinstance(asr_pipeline.tokenizer, (WhisperTokenizer, WhisperTokenizerFast)):
+        tokenizer = asr_pipeline.tokenizer
+    else:
+        tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-large-v3")
+
+    english_normalizer = tokenizer.normalize
+    basic_normalizer = tokenizer.basic_normalize
+
+    normalized_predictions = []
+    normalized_references = []
+
+    for pred, ref in zip(transcriptions, prompts):
+        normalizer = english_normalizer if hasattr(pred, "language") and pred["language"] == "english" else basic_normalizer
+        norm_ref = normalizer(ref)
+        if len(norm_ref) > 0:
+            norm_pred = normalizer(pred["text"])
+            normalized_predictions.append(norm_pred)
+            normalized_references.append(norm_ref)
+
+    word_error = 100 * metric.compute(predictions=normalized_predictions, references=normalized_references)
+
     return word_error, [t["text"] for t in transcriptions]
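
For reference, the two normalisers bound in the new code are methods every WhisperTokenizer carries: normalize applies OpenAI's English text normaliser (lowercasing, expanding abbreviations such as "Mr." to "mister", standardising numbers, stripping punctuation), while basic_normalize applies the lighter language-agnostic normaliser. A minimal sketch of the difference, using the same openai/whisper-large-v3 checkpoint the code falls back to (the sample sentence is illustrative only; outputs are approximate):

    from transformers import WhisperTokenizer

    tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-large-v3")

    text = "Mr. Quilter is the apostle of the middle classes!"

    # English normaliser: aggressive, English-specific rewriting.
    print(tokenizer.normalize(text))
    # roughly: mister quilter is the apostle of the middle classes

    # Basic normaliser: lowercasing and punctuation stripping only.
    print(tokenizer.basic_normalize(text))
    # roughly: mr quilter is the apostle of the middle classes

Pairs whose normalised reference comes out empty are skipped, since WER is undefined against an empty reference.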