Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
parler-tts
Commits
69646e79
Unverified
Commit
69646e79
authored
May 22, 2024
by
Yoach Lacombe
Committed by
GitHub
May 22, 2024
Browse files
Merge branch 'huggingface:main' into nits-improvements
parents
0bab56b7
9232a47b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
27 additions
and
4 deletions
+27
-4
training/eval.py
training/eval.py
+27
-4
No files found.
training/eval.py
View file @
69646e79
import
torch
import
torch
import
evaluate
import
evaluate
from
transformers
import
AutoModel
,
AutoProcessor
,
pipeline
from
transformers
import
AutoModel
,
AutoProcessor
,
pipeline
,
WhisperForConditionalGeneration
,
WhisperTokenizer
,
WhisperTokenizerFast
def
clap_similarity
(
clap_model_name_or_path
,
texts
,
audios
,
device
):
def
clap_similarity
(
clap_model_name_or_path
,
texts
,
audios
,
device
):
...
@@ -24,13 +24,36 @@ def clap_similarity(clap_model_name_or_path, texts, audios, device):
...
@@ -24,13 +24,36 @@ def clap_similarity(clap_model_name_or_path, texts, audios, device):
def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_size, sampling_rate):
    """Compute the word error rate (WER) of ASR transcriptions against text prompts.

    Transcribes `audios` with an ASR pipeline, normalizes both the transcriptions
    and the reference `prompts` with Whisper's text normalizers, and scores them
    with the `evaluate` "wer" metric.

    Args:
        asr_model_name_or_path: Model id or path for the ASR `pipeline`.
        prompts: Reference texts, one per audio clip.
        audios: Raw audio arrays to transcribe.
        device: Device passed to the pipeline (e.g. "cuda:0", -1, ...).
        per_device_eval_batch_size: Batch size for the pipeline (cast to int).
        sampling_rate: Sampling rate of every entry in `audios`.

    Returns:
        A tuple `(word_error, texts)` where `word_error` is the WER as a
        percentage (0-100 scale) and `texts` is the list of raw transcriptions.
    """
    metric = evaluate.load("wer")
    asr_pipeline = pipeline(model=asr_model_name_or_path, device=device)

    # Only Whisper models support per-sample language detection; for other
    # models `return_language` stays None and no "language" key is emitted.
    return_language = None
    if isinstance(asr_pipeline.model, WhisperForConditionalGeneration):
        return_language = True

    transcriptions = asr_pipeline(
        [{"raw": audio, "sampling_rate": sampling_rate} for audio in audios],
        batch_size=int(per_device_eval_batch_size),
        return_language=return_language,
    )

    if isinstance(asr_pipeline.tokenizer, (WhisperTokenizer, WhisperTokenizerFast)):
        tokenizer = asr_pipeline.tokenizer
    else:
        # Non-Whisper tokenizer: load a Whisper tokenizer purely for its
        # `normalize` / `basic_normalize` text normalizers.
        tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-large-v3")

    english_normalizer = tokenizer.normalize
    basic_normalizer = tokenizer.basic_normalize

    normalized_predictions = []
    normalized_references = []

    for pred, ref in zip(transcriptions, prompts):
        # BUG FIX: `hasattr(pred, "language")` is always False on a plain dict,
        # so the English normalizer was never chosen. Use a key lookup instead;
        # `.get` returns None when language detection was not requested.
        normalizer = english_normalizer if pred.get("language") == "english" else basic_normalizer
        norm_ref = normalizer(ref)
        # Skip pairs whose reference normalizes to an empty string — the WER
        # metric cannot score an empty reference.
        if len(norm_ref) > 0:
            norm_pred = normalizer(pred["text"])
            normalized_predictions.append(norm_pred)
            # BUG FIX: the original appended `norm_pred` here, making
            # predictions and references identical and WER trivially 0.
            normalized_references.append(norm_ref)

    word_error = 100 * metric.compute(predictions=normalized_predictions, references=normalized_references)

    return word_error, [t["text"] for t in transcriptions]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment