Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
dbac8899
Unverified
Commit
dbac8899
authored
Jan 03, 2022
by
Patrick von Platen
Committed by
GitHub
Jan 03, 2022
Browse files
[Tests] Correct Wav2Vec2 & WavLM tests (#15015)
* up * up * up
parent
0b4c3a1a
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
13 additions
and
35 deletions
+13
-35
.github/workflows/self-scheduled.yml
.github/workflows/self-scheduled.yml
+1
-1
tests/test_modeling_tf_wav2vec2.py
tests/test_modeling_tf_wav2vec2.py
+7
-8
tests/test_modeling_wavlm.py
tests/test_modeling_wavlm.py
+5
-26
No files found.
.github/workflows/self-scheduled.yml
View file @
dbac8899
...
...
@@ -290,7 +290,7 @@ jobs:
-
name
:
Install dependencies
run
:
|
apt -y update && apt install -y libsndfile1-dev git
apt -y update && apt install -y libsndfile1-dev git
espeak-ng
pip install --upgrade pip
pip install .[sklearn,testing,onnx,sentencepiece,tf-speech,vision]
pip install https://github.com/kpu/kenlm/archive/master.zip
...
...
tests/test_modeling_tf_wav2vec2.py
View file @
dbac8899
...
...
@@ -15,6 +15,7 @@
import
copy
import
glob
import
inspect
import
math
import
unittest
...
...
@@ -23,6 +24,7 @@ import numpy as np
import
pytest
from
datasets
import
load_dataset
from
huggingface_hub
import
snapshot_download
from
transformers
import
Wav2Vec2Config
,
is_tf_available
from
transformers.file_utils
import
is_librosa_available
,
is_pyctcdecode_available
from
transformers.testing_utils
import
require_librosa
,
require_pyctcdecode
,
require_tf
,
slow
...
...
@@ -485,8 +487,6 @@ class TFWav2Vec2UtilsTest(unittest.TestCase):
@
slow
class
TFWav2Vec2ModelIntegrationTest
(
unittest
.
TestCase
):
def
_load_datasamples
(
self
,
num_samples
):
from
datasets
import
load_dataset
ds
=
load_dataset
(
"hf-internal-testing/librispeech_asr_dummy"
,
"clean"
,
split
=
"validation"
)
# automatic decoding with librispeech
speech_samples
=
ds
.
sort
(
"id"
).
filter
(
...
...
@@ -556,18 +556,17 @@ class TFWav2Vec2ModelIntegrationTest(unittest.TestCase):
@
require_pyctcdecode
@
require_librosa
def
test_wav2vec2_with_lm
(
self
):
ds
=
load_dataset
(
"common_voice"
,
"es"
,
split
=
"test"
,
streaming
=
True
)
sample
=
next
(
iter
(
ds
))
resampled_audio
=
librosa
.
resample
(
sample
[
"audio"
][
"array"
],
48_000
,
16_000
)
downloaded_folder
=
snapshot_download
(
"patrickvonplaten/common_voice_es_sample"
)
file_path
=
glob
.
glob
(
downloaded_folder
+
"/*"
)[
0
]
sample
=
librosa
.
load
(
file_path
,
sr
=
16_000
)[
0
]
model
=
TFWav2Vec2ForCTC
.
from_pretrained
(
"patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm"
)
processor
=
Wav2Vec2ProcessorWithLM
.
from_pretrained
(
"patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm"
)
input_values
=
processor
(
re
sample
d_audio
,
return_tensors
=
"tf"
).
input_values
input_values
=
processor
(
sample
,
return_tensors
=
"tf"
).
input_values
logits
=
model
(
input_values
).
logits
transcription
=
processor
.
batch_decode
(
logits
.
numpy
()).
text
self
.
assertEqual
(
transcription
[
0
],
"
bien y qué regalo vas a abrir primero
"
)
self
.
assertEqual
(
transcription
[
0
],
"
el libro ha sido escrito por cervantes
"
)
tests/test_modeling_wavlm.py
View file @
dbac8899
...
...
@@ -14,7 +14,6 @@
# limitations under the License.
""" Testing suite for the PyTorch WavLM model. """
import
copy
import
math
import
unittest
...
...
@@ -452,30 +451,9 @@ class WavLMModelTest(ModelTesterMixin, unittest.TestCase):
if
hasattr
(
module
,
"masked_spec_embed"
)
and
module
.
masked_spec_embed
is
not
None
:
module
.
masked_spec_embed
.
data
.
fill_
(
3
)
# overwrite from test_modeling_common
# as WavLM is not very precise
@
unittest
.
skip
(
reason
=
"Feed forward chunking is not implemented for WavLM"
)
def
test_feed_forward_chunking
(
self
):
(
original_config
,
inputs_dict
,
)
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
for
model_class
in
self
.
all_model_classes
:
torch
.
manual_seed
(
0
)
config
=
copy
.
deepcopy
(
original_config
)
model
=
model_class
(
config
)
model
.
to
(
torch_device
)
model
.
eval
()
hidden_states_no_chunk
=
model
(
**
self
.
_prepare_for_class
(
inputs_dict
,
model_class
))[
0
]
torch
.
manual_seed
(
0
)
config
.
chunk_size_feed_forward
=
1
model
=
model_class
(
config
)
model
.
to
(
torch_device
)
model
.
eval
()
hidden_states_with_chunk
=
model
(
**
self
.
_prepare_for_class
(
inputs_dict
,
model_class
))[
0
]
self
.
assertTrue
(
torch
.
allclose
(
hidden_states_no_chunk
,
hidden_states_with_chunk
,
atol
=
1e-2
))
pass
@
slow
def
test_model_from_pretrained
(
self
):
...
...
@@ -528,7 +506,7 @@ class WavLMModelIntegrationTest(unittest.TestCase):
def
test_inference_large
(
self
):
model
=
WavLMModel
.
from_pretrained
(
"microsoft/wavlm-large"
).
to
(
torch_device
)
feature_extractor
=
Wav2Vec2FeatureExtractor
.
from_pretrained
(
"microsoft/wavlm-
base-plus
"
,
return_attention_mask
=
True
"microsoft/wavlm-
large
"
,
return_attention_mask
=
True
)
input_speech
=
self
.
_load_datasamples
(
2
)
...
...
@@ -544,8 +522,9 @@ class WavLMModelIntegrationTest(unittest.TestCase):
)
EXPECTED_HIDDEN_STATES_SLICE
=
torch
.
tensor
(
[[[
0.
161
2
,
0.
4314
],
[
0.
1690
,
0.4344
]],
[[
0.
2086
,
0.1
396
],
[
0.
3014
,
0.0
903
]]]
[[[
0.
212
2
,
0.
0500
],
[
0.
2118
,
0.0563
]],
[[
0.
1353
,
0.1
818
],
[
0.
2453
,
0.0
595
]]]
)
self
.
assertTrue
(
torch
.
allclose
(
hidden_states_slice
,
EXPECTED_HIDDEN_STATES_SLICE
,
rtol
=
5e-2
))
def
test_inference_diarization
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment