Unverified Commit 1e847b40 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[WavLM] give model for precision (#14958)

parent 1c121916
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
""" Testing suite for the PyTorch WavLM model. """ """ Testing suite for the PyTorch WavLM model. """
import copy
import math import math
import unittest import unittest
...@@ -451,6 +452,31 @@ class WavLMModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -451,6 +452,31 @@ class WavLMModelTest(ModelTesterMixin, unittest.TestCase):
if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None: if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None:
module.masked_spec_embed.data.fill_(3) module.masked_spec_embed.data.fill_(3)
# overwrite from test_modeling_common
# as WavLM is not very precise
def test_feed_forward_chunking(self):
(
original_config,
inputs_dict,
) = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
torch.manual_seed(0)
config = copy.deepcopy(original_config)
model = model_class(config)
model.to(torch_device)
model.eval()
hidden_states_no_chunk = model(**self._prepare_for_class(inputs_dict, model_class))[0]
torch.manual_seed(0)
config.chunk_size_feed_forward = 1
model = model_class(config)
model.to(torch_device)
model.eval()
hidden_states_with_chunk = model(**self._prepare_for_class(inputs_dict, model_class))[0]
self.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-2))
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
model = WavLMModel.from_pretrained("microsoft/wavlm-base-plus") model = WavLMModel.from_pretrained("microsoft/wavlm-base-plus")
...@@ -497,7 +523,7 @@ class WavLMModelIntegrationTest(unittest.TestCase): ...@@ -497,7 +523,7 @@ class WavLMModelIntegrationTest(unittest.TestCase):
[[[0.0577, 0.1161], [0.0579, 0.1165]], [[0.0199, 0.1237], [0.0059, 0.0605]]] [[[0.0577, 0.1161], [0.0579, 0.1165]], [[0.0199, 0.1237], [0.0059, 0.0605]]]
) )
# TODO: update the tolerance after the CI moves to torch 1.10 # TODO: update the tolerance after the CI moves to torch 1.10
self.assertTrue(torch.allclose(hidden_states_slice, EXPECTED_HIDDEN_STATES_SLICE, atol=1e-2)) self.assertTrue(torch.allclose(hidden_states_slice, EXPECTED_HIDDEN_STATES_SLICE, atol=5e-2))
def test_inference_large(self): def test_inference_large(self):
model = WavLMModel.from_pretrained("microsoft/wavlm-large").to(torch_device) model = WavLMModel.from_pretrained("microsoft/wavlm-large").to(torch_device)
...@@ -520,7 +546,7 @@ class WavLMModelIntegrationTest(unittest.TestCase): ...@@ -520,7 +546,7 @@ class WavLMModelIntegrationTest(unittest.TestCase):
EXPECTED_HIDDEN_STATES_SLICE = torch.tensor( EXPECTED_HIDDEN_STATES_SLICE = torch.tensor(
[[[0.1612, 0.4314], [0.1690, 0.4344]], [[0.2086, 0.1396], [0.3014, 0.0903]]] [[[0.1612, 0.4314], [0.1690, 0.4344]], [[0.2086, 0.1396], [0.3014, 0.0903]]]
) )
self.assertTrue(torch.allclose(hidden_states_slice, EXPECTED_HIDDEN_STATES_SLICE, rtol=1e-2)) self.assertTrue(torch.allclose(hidden_states_slice, EXPECTED_HIDDEN_STATES_SLICE, rtol=5e-2))
def test_inference_diarization(self): def test_inference_diarization(self):
model = WavLMForAudioFrameClassification.from_pretrained("microsoft/wavlm-base-plus-sd").to(torch_device) model = WavLMForAudioFrameClassification.from_pretrained("microsoft/wavlm-base-plus-sd").to(torch_device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment