import evaluate
import numpy as np
import torch
import torchaudio
from accelerate.utils.memory import release_memory
from torchaudio.pipelines import SQUIM_OBJECTIVE
from transformers import (
    AutoModel,
    AutoProcessor,
    WhisperForConditionalGeneration,
    WhisperTokenizer,
    WhisperTokenizerFast,
    pipeline,
)


def clap_similarity(clap_model_name_or_path, texts, audios, device, input_sampling_rate=44100):
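    # Measure how well each generated audio matches its text description: embed both
    # with a CLAP model and return the mean pairwise cosine similarity.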
    clap = AutoModel.from_pretrained(clap_model_name_or_path)
    clap_processor = AutoProcessor.from_pretrained(clap_model_name_or_path)
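    # Resample the audio to the sampling rate expected by the CLAP feature extractor.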
    output_sampling_rate = clap_processor.feature_extractor.sampling_rate
    if input_sampling_rate != output_sampling_rate:
        audios = [
            torchaudio.functional.resample(torch.from_numpy(audio), input_sampling_rate, output_sampling_rate).numpy()
            for audio in audios
        ]
    clap_inputs = clap_processor(
        text=texts, audios=audios, padding=True, return_tensors="pt", sampling_rate=output_sampling_rate
    ).to(device)

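    # Embed descriptions and audio on the requested device without tracking gradients.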
    clap.to(device)
    with torch.no_grad():
        text_features = clap.get_text_features(
            clap_inputs["input_ids"], attention_mask=clap_inputs.get("attention_mask", None)
        )
        audio_features = clap.get_audio_features(clap_inputs["input_features"])

        cosine_sim = torch.nn.functional.cosine_similarity(audio_features, text_features, dim=1, eps=1e-8).mean()

    cosine_sim = cosine_sim.to("cpu")

    clap.to("cpu")
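    # Release the model, inputs and embeddings to free accelerator memory.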
    clap, clap_inputs, audio_features, text_features = release_memory(clap, clap_inputs, audio_features, text_features)
    return cosine_sim


def si_sdr(audios, device, input_sampling_rate=44100):
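    # Estimate each clip's SI-SDR (scale-invariant signal-to-distortion ratio) with
    # TorchAudio's reference-free SQUIM objective model; inputs are capped at 15 s.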
    max_audio_length = 15 * SQUIM_OBJECTIVE.sample_rate
    model = SQUIM_OBJECTIVE.get_model().to(device)

    output_sampling_rate = SQUIM_OBJECTIVE.sample_rate
    if input_sampling_rate != output_sampling_rate:
        audios = [
            torchaudio.functional.resample(
                torch.tensor(audio)[None, :].to(device).float(), input_sampling_rate, output_sampling_rate
            )
            for audio in audios
        ]
    else:
        # Even without resampling, convert to float tensors on the model's device so
        # that apply_squim below receives the (1, num_samples) shape it expects.
        audios = [torch.tensor(audio)[None, :].to(device).float() for audio in audios]

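    # Score one clip at a time; SQUIM expects a float waveform of shape (1, num_samples).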
    def apply_squim(waveform):
        with torch.no_grad():
            waveform = waveform[:, : min(max_audio_length, waveform.shape[1])]
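            # The SQUIM objective model predicts three quality metrics; the third
            # output is the SI-SDR estimate, which is all we keep here.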
            _, _, sdr_sample = model(waveform)
            sdr_sample = sdr_sample.cpu()[0]
        return sdr_sample

    si_sdrs = [apply_squim(audio) for audio in audios]
    audios, model = release_memory(audios, model)
    return si_sdrs


def wer(
    asr_model_name_or_path,
    prompts,
    audios,
    device,
    per_device_eval_batch_size,
    sampling_rate,
    noise_level_to_compute_clean_wer,
    si_sdr_measures,
):
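    # Transcribe the generated audio with an ASR pipeline and compute the word error
    # rate (WER) against the prompts, optionally split into clean/noisy subsets by SI-SDR.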
    metric = evaluate.load("wer")
    asr_pipeline = pipeline(model=asr_model_name_or_path, device=device, chunk_length_s=25.0)

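    # Whisper can report the language it detected for each chunk; request it so the
    # matching text normalizer can be chosen below.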
    return_language = None
    if isinstance(asr_pipeline.model, WhisperForConditionalGeneration):
        return_language = True

    transcriptions = asr_pipeline(
        [{"raw": audio, "sampling_rate": sampling_rate} for audio in audios],
        batch_size=int(per_device_eval_batch_size),
        return_language=return_language,
    )

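    # Prefer the pipeline's own Whisper tokenizer for text normalization; otherwise
    # borrow the normalizers from the whisper-large-v3 tokenizer.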
    if isinstance(asr_pipeline.tokenizer, (WhisperTokenizer, WhisperTokenizerFast)):
        tokenizer = asr_pipeline.tokenizer
    else:
        tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-large-v3")

    english_normalizer = tokenizer.normalize
    basic_normalizer = tokenizer.basic_normalize

    normalized_predictions = []
    normalized_references = []

    for pred, ref in zip(transcriptions, prompts):
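        # Use the English normalizer only when Whisper detected English; otherwise
        # fall back to the language-agnostic basic normalizer.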
        normalizer = (
            english_normalizer
            if isinstance(pred.get("chunks", None), list) and pred["chunks"][0].get("language", None) == "english"
            else basic_normalizer
        )
        norm_ref = normalizer(ref)
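        # Skip samples whose normalized reference is empty: WER is undefined for them.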
        if len(norm_ref) > 0:
            norm_pred = normalizer(pred["text"])
            normalized_predictions.append(norm_pred)
            normalized_references.append(norm_ref)

    word_error = 100 * metric.compute(predictions=normalized_predictions, references=normalized_references)

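    # Optionally split the WER into "clean" and "noisy" subsets using the SI-SDR
    # threshold. Note that this assumes no samples were dropped above (i.e. no empty
    # references), so that the mask stays aligned with the predictions.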
    clean_word_error = None
    noisy_word_error = None
    percent_clean_samples = 0
    if noise_level_to_compute_clean_wer and si_sdr_measures:
        si_sdr_measures = np.array(si_sdr_measures)
        mask = si_sdr_measures >= noise_level_to_compute_clean_wer
        if mask.any():
            clean_word_error = 100 * metric.compute(
                predictions=np.array(normalized_predictions)[mask], references=np.array(normalized_references)[mask]
            )
        # Guard the noisy subset too: if every sample is clean, ~mask selects nothing
        # and computing a WER over empty lists would fail.
        if not mask.all():
            noisy_word_error = 100 * metric.compute(
                predictions=np.array(normalized_predictions)[~mask], references=np.array(normalized_references)[~mask]
            )
        percent_clean_samples = mask.sum() / len(mask)  # a fraction in [0, 1], despite the name

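    # Move the ASR model off the accelerator and release its memory before returning.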
    asr_pipeline.model.to("cpu")
    asr_pipeline = release_memory(asr_pipeline)
    return word_error, [t["text"] for t in transcriptions], clean_word_error, noisy_word_error, percent_clean_samples