Commit dbac8899 (unverified), authored by Patrick von Platen and committed by GitHub
Browse files

[Tests] Correct Wav2Vec2 & WavLM tests (#15015)

* up

* up

* up
parent 0b4c3a1a
@@ -290,7 +290,7 @@ jobs:
 - name: Install dependencies
 run: |
-apt -y update && apt install -y libsndfile1-dev git
+apt -y update && apt install -y libsndfile1-dev git espeak-ng
 pip install --upgrade pip
 pip install .[sklearn,testing,onnx,sentencepiece,tf-speech,vision]
 pip install https://github.com/kpu/kenlm/archive/master.zip
......
@@ -15,6 +15,7 @@
 import copy
+import glob
 import inspect
 import math
 import unittest
@@ -23,6 +24,7 @@ import numpy as np
 import pytest
 from datasets import load_dataset
+from huggingface_hub import snapshot_download
 from transformers import Wav2Vec2Config, is_tf_available
 from transformers.file_utils import is_librosa_available, is_pyctcdecode_available
 from transformers.testing_utils import require_librosa, require_pyctcdecode, require_tf, slow
@@ -485,8 +487,6 @@ class TFWav2Vec2UtilsTest(unittest.TestCase):
 @slow
 class TFWav2Vec2ModelIntegrationTest(unittest.TestCase):
 def _load_datasamples(self, num_samples):
-from datasets import load_dataset
 ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
 # automatic decoding with librispeech
 speech_samples = ds.sort("id").filter(
@@ -556,18 +556,17 @@ class TFWav2Vec2ModelIntegrationTest(unittest.TestCase):
 @require_pyctcdecode
 @require_librosa
 def test_wav2vec2_with_lm(self):
-ds = load_dataset("common_voice", "es", split="test", streaming=True)
-sample = next(iter(ds))
-resampled_audio = librosa.resample(sample["audio"]["array"], 48_000, 16_000)
+downloaded_folder = snapshot_download("patrickvonplaten/common_voice_es_sample")
+file_path = glob.glob(downloaded_folder + "/*")[0]
+sample = librosa.load(file_path, sr=16_000)[0]
 model = TFWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
 processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
-input_values = processor(resampled_audio, return_tensors="tf").input_values
+input_values = processor(sample, return_tensors="tf").input_values
 logits = model(input_values).logits
 transcription = processor.batch_decode(logits.numpy()).text
-self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
+self.assertEqual(transcription[0], "el libro ha sido escrito por cervantes")
@@ -14,7 +14,6 @@
 # limitations under the License.
 """ Testing suite for the PyTorch WavLM model. """
-import copy
 import math
 import unittest
@@ -452,30 +451,9 @@ class WavLMModelTest(ModelTesterMixin, unittest.TestCase):
 if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None:
 module.masked_spec_embed.data.fill_(3)
-# overwrite from test_modeling_common
-# as WavLM is not very precise
+@unittest.skip(reason="Feed forward chunking is not implemented for WavLM")
 def test_feed_forward_chunking(self):
-(
-original_config,
-inputs_dict,
-) = self.model_tester.prepare_config_and_inputs_for_common()
-for model_class in self.all_model_classes:
-torch.manual_seed(0)
-config = copy.deepcopy(original_config)
-model = model_class(config)
-model.to(torch_device)
-model.eval()
-hidden_states_no_chunk = model(**self._prepare_for_class(inputs_dict, model_class))[0]
-torch.manual_seed(0)
-config.chunk_size_feed_forward = 1
-model = model_class(config)
-model.to(torch_device)
-model.eval()
-hidden_states_with_chunk = model(**self._prepare_for_class(inputs_dict, model_class))[0]
-self.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-2))
+pass
 @slow
 def test_model_from_pretrained(self):
@@ -528,7 +506,7 @@ class WavLMModelIntegrationTest(unittest.TestCase):
 def test_inference_large(self):
 model = WavLMModel.from_pretrained("microsoft/wavlm-large").to(torch_device)
 feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
-"microsoft/wavlm-base-plus", return_attention_mask=True
+"microsoft/wavlm-large", return_attention_mask=True
 )
 input_speech = self._load_datasamples(2)
@@ -544,8 +522,9 @@ class WavLMModelIntegrationTest(unittest.TestCase):
 )
 EXPECTED_HIDDEN_STATES_SLICE = torch.tensor(
-[[[0.1612, 0.4314], [0.1690, 0.4344]], [[0.2086, 0.1396], [0.3014, 0.0903]]]
+[[[0.2122, 0.0500], [0.2118, 0.0563]], [[0.1353, 0.1818], [0.2453, 0.0595]]]
 )
 self.assertTrue(torch.allclose(hidden_states_slice, EXPECTED_HIDDEN_STATES_SLICE, rtol=5e-2))
 def test_inference_diarization(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment