Unverified Commit b4eef63a authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Data2Vec] Speed up test (#17660)

parent 5e428b71
...@@ -535,7 +535,7 @@ class Data2VecAudioModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -535,7 +535,7 @@ class Data2VecAudioModelTest(ModelTesterMixin, unittest.TestCase):
def test_mask_feature_prob_ctc(self): def test_mask_feature_prob_ctc(self):
model = Data2VecAudioForCTC.from_pretrained( model = Data2VecAudioForCTC.from_pretrained(
"facebook/data2vec-audio-base-960h", mask_feature_prob=0.2, mask_feature_length=2 "hf-internal-testing/tiny-random-data2vec-seq-class", mask_feature_prob=0.2, mask_feature_length=2
) )
model.to(torch_device).train() model.to(torch_device).train()
processor = Wav2Vec2Processor.from_pretrained( processor = Wav2Vec2Processor.from_pretrained(
...@@ -554,7 +554,7 @@ class Data2VecAudioModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -554,7 +554,7 @@ class Data2VecAudioModelTest(ModelTesterMixin, unittest.TestCase):
attention_mask=batch["attention_mask"].to(torch_device), attention_mask=batch["attention_mask"].to(torch_device),
).logits ).logits
self.assertEqual(logits.shape, (4, 299, 32)) self.assertEqual(logits.shape, (4, 1498, 32))
def test_mask_time_prob_ctc(self): def test_mask_time_prob_ctc(self):
model = Data2VecAudioForCTC.from_pretrained( model = Data2VecAudioForCTC.from_pretrained(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment