Unverified commit 69646e79, authored by Yoach Lacombe and committed by GitHub
Browse files

Merge branch 'huggingface:main' into nits-improvements

parents 0bab56b7 9232a47b
import torch import torch
import evaluate import evaluate
from transformers import AutoModel, AutoProcessor, pipeline from transformers import AutoModel, AutoProcessor, pipeline, WhisperForConditionalGeneration, WhisperTokenizer, WhisperTokenizerFast
def clap_similarity(clap_model_name_or_path, texts, audios, device): def clap_similarity(clap_model_name_or_path, texts, audios, device):
...@@ -24,13 +24,36 @@ def clap_similarity(clap_model_name_or_path, texts, audios, device): ...@@ -24,13 +24,36 @@ def clap_similarity(clap_model_name_or_path, texts, audios, device):
def _detected_language(prediction):
    """Return the language reported for one ASR pipeline result, or None.

    The result is looked up as a mapping: first a top-level ``"language"``
    key, then the ``"language"`` entry of the first item in ``"chunks"``.
    NOTE(review): Whisper pipelines called with ``return_language=True`` are
    expected to report the language per chunk -- confirm against the
    installed transformers version.
    """
    language = prediction.get("language")
    if language is not None:
        return language
    chunks = prediction.get("chunks")
    if isinstance(chunks, list) and chunks:
        return chunks[0].get("language")
    return None


def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_size, sampling_rate):
    """Transcribe `audios` with an ASR pipeline and score them against `prompts`.

    Args:
        asr_model_name_or_path: Model identifier or path passed to `pipeline`.
        prompts: Reference texts, one per audio, in the same order as `audios`.
        audios: Raw audio inputs, one per prompt.
        device: Device passed to `pipeline`.
        per_device_eval_batch_size: Batch size for transcription (cast to int).
        sampling_rate: Sampling rate attached to every raw audio input.

    Returns:
        A tuple ``(word_error, transcriptions)`` where ``word_error`` is the
        WER metric multiplied by 100 and ``transcriptions`` is the list of raw
        (un-normalized) transcribed texts.
    """
    metric = evaluate.load("wer")
    asr_pipeline = pipeline(model=asr_model_name_or_path, device=device)

    # Only ask for the language when the model is Whisper; for any other
    # model the argument stays None.
    return_language = None
    if isinstance(asr_pipeline.model, WhisperForConditionalGeneration):
        return_language = True

    transcriptions = asr_pipeline(
        [{"raw": audio, "sampling_rate": sampling_rate} for audio in audios],
        batch_size=int(per_device_eval_batch_size),
        return_language=return_language,
    )

    # Reuse the pipeline's tokenizer when it is a Whisper tokenizer; otherwise
    # load one solely to get access to its text normalizers.
    if isinstance(asr_pipeline.tokenizer, (WhisperTokenizer, WhisperTokenizerFast)):
        tokenizer = asr_pipeline.tokenizer
    else:
        tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-large-v3")

    english_normalizer = tokenizer.normalize
    basic_normalizer = tokenizer.basic_normalize

    normalized_predictions = []
    normalized_references = []

    for pred, ref in zip(transcriptions, prompts):
        # Pipeline results are dicts, so the language must be read by key;
        # an attribute check (`hasattr`) would never find it.
        if _detected_language(pred) == "english":
            normalizer = english_normalizer
        else:
            normalizer = basic_normalizer

        norm_ref = normalizer(ref)
        # Skip pairs whose reference is empty after normalization.
        if len(norm_ref) > 0:
            normalized_predictions.append(normalizer(pred["text"]))
            # Score against the normalized *reference*; appending the
            # prediction here would make the WER trivially zero.
            normalized_references.append(norm_ref)

    word_error = 100 * metric.compute(predictions=normalized_predictions, references=normalized_references)

    return word_error, [t["text"] for t in transcriptions]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment