"vscode:/vscode.git/clone" did not exist on "653076ca307520ee85fd5f5de6918019f8521bb5"
Unverified Commit e1f6e490 authored by Anton Lozhkov's avatar Anton Lozhkov Committed by GitHub
Browse files

Fix integration tests for TFWav2Vec2 and TFHubert

parent 41cd52a7
......@@ -511,14 +511,12 @@ class TFHubertModelIntegrationTest(unittest.TestCase):
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
def test_inference_ctc_normal_batched(self):
model = TFHubertForCTC.from_pretrained("facebook/hubert-base-ls960")
processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-base-ls960", do_lower_case=True)
model = TFHubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft")
processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft", do_lower_case=True)
input_speech = self._load_datasamples(2)
input_values = processor(
input_speech, return_tensors="tf", padding=True, truncation=True, sampling_rate=16000
).input_values
input_values = processor(input_speech, return_tensors="tf", padding=True, sampling_rate=16000).input_values
logits = model(input_values).logits
......@@ -527,7 +525,7 @@ class TFHubertModelIntegrationTest(unittest.TestCase):
EXPECTED_TRANSCRIPTIONS = [
"a man said to the universe sir i exist",
"sweat covered brion's body trickling into the tight lowing cloth that was the only garment he wore",
"sweat covered brion's body trickling into the tight loin cloth that was the only garment he wore",
]
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
......@@ -537,20 +535,20 @@ class TFHubertModelIntegrationTest(unittest.TestCase):
input_speech = self._load_datasamples(4)
inputs = processor(input_speech, return_tensors="tf", padding=True, truncation=True)
inputs = processor(input_speech, return_tensors="tf", padding=True, sampling_rate=16000)
input_values = inputs.input_values
attention_mask = inputs.attention_mask
logits = model(input_values, attention_mask=attention_mask).logits
predicted_ids = tf.argmax(logits, dim=-1)
predicted_ids = tf.argmax(logits, axis=-1)
predicted_trans = processor.batch_decode(predicted_ids)
EXPECTED_TRANSCRIPTIONS = [
"a man said to the universe sir i exist",
"sweat covered brion's body trickling into the tight loin cloth that was the only garment he wore",
"the cut on his chest still dripping blood the ache of his overstrained eyes even the soaring arena around him with the thousands of spectators were trivialities not worth thinking about",
"his instant panic was followed by a small sharp blow high on his chest",
"his instant of panic was followed by a small sharp blow high on his chest",
]
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
......@@ -516,9 +516,7 @@ class TFWav2Vec2ModelIntegrationTest(unittest.TestCase):
input_speech = self._load_datasamples(2)
input_values = processor(
input_speech, return_tensors="tf", padding=True, truncation=True, sampling_rate=16000
).input_values
input_values = processor(input_speech, return_tensors="tf", padding=True, sampling_rate=16000).input_values
logits = model(input_values).logits
......@@ -537,14 +535,14 @@ class TFWav2Vec2ModelIntegrationTest(unittest.TestCase):
input_speech = self._load_datasamples(4)
inputs = processor(input_speech, return_tensors="tf", padding=True, truncation=True)
inputs = processor(input_speech, return_tensors="tf", padding=True, sampling_rate=16000)
input_values = inputs.input_values
attention_mask = inputs.attention_mask
logits = model(input_values, attention_mask=attention_mask).logits
predicted_ids = tf.argmax(logits, dim=-1)
predicted_ids = tf.argmax(logits, axis=-1)
predicted_trans = processor.batch_decode(predicted_ids)
EXPECTED_TRANSCRIPTIONS = [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment