Commit 76099f6c authored by sanchit-gandhi's avatar sanchit-gandhi
Browse files

generalise to multilingual

parent aca3f5e4
import torch import torch
import evaluate import evaluate
from transformers import AutoModel, AutoProcessor, pipeline from transformers import AutoModel, AutoProcessor, pipeline, WhisperForConditionalGeneration, WhisperTokenizer, WhisperTokenizerFast
def clap_similarity(clap_model_name_or_path, texts, audios, device): def clap_similarity(clap_model_name_or_path, texts, audios, device):
...@@ -24,14 +24,35 @@ def clap_similarity(clap_model_name_or_path, texts, audios, device): ...@@ -24,14 +24,35 @@ def clap_similarity(clap_model_name_or_path, texts, audios, device):
def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_size, sampling_rate): def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_size, sampling_rate):
metric = evaluate.load("wer") metric = evaluate.load("wer")
asr_pipeline = pipeline(model=asr_model_name_or_path, device=device) asr_pipeline = pipeline(model=asr_model_name_or_path, device=device)
return_language = None
if isinstance(asr_pipeline.model, WhisperForConditionalGeneration):
return_language = True
transcriptions = asr_pipeline( transcriptions = asr_pipeline(
[{"raw": audio, "sampling_rate": sampling_rate} for audio in audios], [{"raw": audio, "sampling_rate": sampling_rate} for audio in audios],
batch_size=int(per_device_eval_batch_size), batch_size=int(per_device_eval_batch_size),
return_language=return_language,
) )
normalizer = asr_pipeline.tokenizer.normalize if isinstance(asr_pipeline.tokenizer, (WhisperTokenizer, WhisperTokenizerFast)):
normalized_predictions = [normalizer(t["text"]) for t in transcriptions] tokenizer = asr_pipeline.tokenizer
normalized_references = [normalizer(t) for t in prompts] else:
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-large-v3")
english_normalizer = tokenizer.normalize
basic_normalizer = tokenizer.basic_normalize
normalized_predictions = []
normalized_references = []
for pred, ref in zip(transcriptions, prompts):
normalizer = english_normalizer if hasattr(pred, "language") and pred["language"] == "english" else basic_normalizer
norm_ref = normalizer(ref)
if len(norm_ref) > 0:
norm_pred = normalizer(pred)
normalized_predictions.append(norm_pred)
normalized_references.append(norm_pred)
word_error = 100 * metric.compute(predictions=normalized_predictions, references=normalized_references) word_error = 100 * metric.compute(predictions=normalized_predictions, references=normalized_references)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment