Unverified Commit 3f0febc4 authored by moto's avatar moto Committed by GitHub
Browse files

Fix HF wav2vec2 test (#1585)

parent 7deea259
...@@ -86,7 +86,7 @@ class TestHFIntegration(TorchaudioTestCase): ...@@ -86,7 +86,7 @@ class TestHFIntegration(TorchaudioTestCase):
self.assertEqual(ref, hyp) self.assertEqual(ref, hyp)
# Feature projection # Feature projection
x = torch.randn(3, 10, config['conv_dim'][-1]) x = torch.randn(3, 10, config['conv_dim'][-1])
ref = original.wav2vec2.feature_projection(x) ref = original.wav2vec2.feature_projection(x)[0]
hyp = imported.encoder.feature_projection(x) hyp = imported.encoder.feature_projection(x)
self.assertEqual(ref, hyp) self.assertEqual(ref, hyp)
# Convolutional Positional Encoder # Convolutional Positional Encoder
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment