chenpangpang / transformers · Commit dbac8899 (unverified)

Authored Jan 03, 2022 by Patrick von Platen; committed via GitHub on Jan 03, 2022.

[Tests] Correct Wav2Vec2 & WavLM tests (#15015)

* up
* up
* up

Parent: 0b4c3a1a

Showing 3 changed files, with 13 additions and 35 deletions (+13, -35):

.github/workflows/self-scheduled.yml (+1, -1)
tests/test_modeling_tf_wav2vec2.py (+7, -8)
tests/test_modeling_wavlm.py (+5, -26)
.github/workflows/self-scheduled.yml

@@ -290,7 +290,7 @@ jobs:
       - name: Install dependencies
         run: |
-          apt -y update && apt install -y libsndfile1-dev git
+          apt -y update && apt install -y libsndfile1-dev git espeak-ng
           pip install --upgrade pip
           pip install .[sklearn,testing,onnx,sentencepiece,tf-speech,vision]
           pip install https://github.com/kpu/kenlm/archive/master.zip
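The only workflow change adds espeak-ng to the system packages installed on the scheduled CI runner. The commit message does not say why, but espeak-ng is commonly needed as the backend for phonemizer-based speech tests; the following is a purely hypothetical local check (not part of this commit) that the command-line tools this step installs are actually on PATH:

import shutil

# Hypothetical helper, not from the commit: confirm the CLI tools the CI step
# installs are available before running the speech tests locally.
for binary in ("git", "espeak-ng"):
    if shutil.which(binary) is None:  # returns None when the executable is not on PATH
        raise EnvironmentError(f"`{binary}` not found; install it, e.g. `apt install -y {binary}`")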
tests/test_modeling_tf_wav2vec2.py

@@ -15,6 +15,7 @@
 import copy
+import glob
 import inspect
 import math
 import unittest

@@ -23,6 +24,7 @@ import numpy as np
 import pytest
 from datasets import load_dataset
+from huggingface_hub import snapshot_download

 from transformers import Wav2Vec2Config, is_tf_available
 from transformers.file_utils import is_librosa_available, is_pyctcdecode_available
 from transformers.testing_utils import require_librosa, require_pyctcdecode, require_tf, slow

@@ -485,8 +487,6 @@ class TFWav2Vec2UtilsTest(unittest.TestCase):
 @slow
 class TFWav2Vec2ModelIntegrationTest(unittest.TestCase):
     def _load_datasamples(self, num_samples):
-        from datasets import load_dataset
-
         ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
         # automatic decoding with librispeech
         speech_samples = ds.sort("id").filter(

@@ -556,18 +556,17 @@ class TFWav2Vec2ModelIntegrationTest(unittest.TestCase):
     @require_pyctcdecode
     @require_librosa
     def test_wav2vec2_with_lm(self):
-        ds = load_dataset("common_voice", "es", split="test", streaming=True)
-        sample = next(iter(ds))
-
-        resampled_audio = librosa.resample(sample["audio"]["array"], 48_000, 16_000)
+        downloaded_folder = snapshot_download("patrickvonplaten/common_voice_es_sample")
+        file_path = glob.glob(downloaded_folder + "/*")[0]
+        sample = librosa.load(file_path, sr=16_000)[0]

         model = TFWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
         processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")

-        input_values = processor(resampled_audio, return_tensors="tf").input_values
+        input_values = processor(sample, return_tensors="tf").input_values

         logits = model(input_values).logits

         transcription = processor.batch_decode(logits.numpy()).text

-        self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
+        self.assertEqual(transcription[0], "el libro ha sido escrito por cervantes")
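For reference, the rewritten test_wav2vec2_with_lm no longer streams the common_voice dataset; it downloads a single prepared Spanish sample from the Hub and loads it directly at 16 kHz. Below is a standalone sketch of that flow using the same checkpoints and calls as the diff above; it assumes TensorFlow, librosa, pyctcdecode and kenlm are installed, which is what the require_* decorators and the new CI dependencies provide.

import glob

import librosa
from huggingface_hub import snapshot_download

from transformers import TFWav2Vec2ForCTC, Wav2Vec2ProcessorWithLM

# Fetch the small Spanish sample the test now relies on and load it at 16 kHz,
# the sampling rate the XLSR-53 checkpoint expects.
downloaded_folder = snapshot_download("patrickvonplaten/common_voice_es_sample")
file_path = glob.glob(downloaded_folder + "/*")[0]
sample = librosa.load(file_path, sr=16_000)[0]

model = TFWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")

input_values = processor(sample, return_tensors="tf").input_values
logits = model(input_values).logits

# batch_decode applies the n-gram language model on top of the CTC logits.
transcription = processor.batch_decode(logits.numpy()).text
print(transcription[0])  # the test expects "el libro ha sido escrito por cervantes"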
tests/test_modeling_wavlm.py

@@ -14,7 +14,6 @@
 # limitations under the License.
 """ Testing suite for the PyTorch WavLM model. """

-import copy
 import math
 import unittest

@@ -452,30 +451,9 @@ class WavLMModelTest(ModelTesterMixin, unittest.TestCase):
         if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None:
             module.masked_spec_embed.data.fill_(3)

-    # overwrite from test_modeling_common
-    # as WavLM is not very precise
+    @unittest.skip(reason="Feed forward chunking is not implemented for WavLM")
     def test_feed_forward_chunking(self):
-        (
-            original_config,
-            inputs_dict,
-        ) = self.model_tester.prepare_config_and_inputs_for_common()
-        for model_class in self.all_model_classes:
-            torch.manual_seed(0)
-            config = copy.deepcopy(original_config)
-            model = model_class(config)
-            model.to(torch_device)
-            model.eval()
-
-            hidden_states_no_chunk = model(**self._prepare_for_class(inputs_dict, model_class))[0]
-
-            torch.manual_seed(0)
-            config.chunk_size_feed_forward = 1
-            model = model_class(config)
-            model.to(torch_device)
-            model.eval()
-
-            hidden_states_with_chunk = model(**self._prepare_for_class(inputs_dict, model_class))[0]
-            self.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-2))
+        pass

     @slow
     def test_model_from_pretrained(self):

@@ -528,7 +506,7 @@ class WavLMModelIntegrationTest(unittest.TestCase):
     def test_inference_large(self):
         model = WavLMModel.from_pretrained("microsoft/wavlm-large").to(torch_device)
         feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
-            "microsoft/wavlm-base-plus", return_attention_mask=True
+            "microsoft/wavlm-large", return_attention_mask=True
         )

         input_speech = self._load_datasamples(2)

@@ -544,8 +522,9 @@ class WavLMModelIntegrationTest(unittest.TestCase):
         )

         EXPECTED_HIDDEN_STATES_SLICE = torch.tensor(
-            [[[0.1612, 0.4314], [0.1690, 0.4344]], [[0.2086, 0.1396], [0.3014, 0.0903]]]
+            [[[0.2122, 0.0500], [0.2118, 0.0563]], [[0.1353, 0.1818], [0.2453, 0.0595]]]
         )

         self.assertTrue(torch.allclose(hidden_states_slice, EXPECTED_HIDDEN_STATES_SLICE, rtol=5e-2))

     def test_inference_diarization(self):
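The updated EXPECTED_HIDDEN_STATES_SLICE values go with loading the wavlm-large feature extractor rather than the wavlm-base-plus one the test previously used by mistake. A rough sketch of how such a reference slice could be recomputed follows; the dataset and checkpoints match the test above, but the exact slice indices sit in collapsed context not shown in this diff, so the 2x2x2 slice printed here is only illustrative.

import torch
from datasets import load_dataset

from transformers import Wav2Vec2FeatureExtractor, WavLMModel

# Two sorted validation samples from the dummy LibriSpeech set, mirroring the
# integration test's _load_datasamples helper.
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
speech = [x["array"] for x in ds.sort("id")[:2]["audio"]]

model = WavLMModel.from_pretrained("microsoft/wavlm-large")
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
    "microsoft/wavlm-large", return_attention_mask=True
)

inputs = feature_extractor(speech, sampling_rate=16_000, padding=True, return_tensors="pt")

with torch.no_grad():
    outputs = model(inputs.input_values, attention_mask=inputs.attention_mask)

# The test hard-codes a small slice of the last hidden state; print an
# illustrative 2x2x2 corner slice to compare against new reference values.
print(outputs.last_hidden_state[:, :2, -2:])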