Unverified Commit 14d058b9 authored by Sanchit Gandhi's avatar Sanchit Gandhi Committed by GitHub
Browse files

[W2V2 with LM] Fix decoder test with params (#21277)

parent 94a7edd9
......@@ -230,7 +230,6 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
self.assertListEqual(logit_scores_decoder, decoded_processor.logit_score)
self.assertListEqual(lm_scores_decoder, decoded_processor.lm_score)
@unittest.skip("Fix me Sanchit")
def test_decoder_with_params(self):
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()
......@@ -240,7 +239,7 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
logits = self._get_dummy_logits()
beam_width = 20
beam_width = 15
beam_prune_logp = -20.0
token_min_logp = -4.0
......@@ -264,9 +263,17 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
)
decoded_decoder = [d[0][0] for d in decoded_decoder_out]
logit_scores = [d[0][2] for d in decoded_decoder_out]
lm_scores = [d[0][3] for d in decoded_decoder_out]
self.assertListEqual(decoded_decoder, decoded_processor)
self.assertListEqual(["<s> </s> </s>", "<s> <s> </s>"], decoded_processor)
self.assertListEqual(["</s> <s> <s>", "<s> <s> <s>"], decoded_processor)
self.assertTrue(np.array_equal(logit_scores, decoded_processor_out.logit_score))
self.assertTrue(np.allclose([-20.054, -18.447], logit_scores, atol=1e-3))
self.assertTrue(np.array_equal(lm_scores, decoded_processor_out.lm_score))
self.assertTrue(np.allclose([-15.554, -13.9474], lm_scores, atol=1e-3))
def test_decoder_with_params_of_lm(self):
feature_extractor = self.get_feature_extractor()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment