Commit 7d0f3369 authored by Zhaoheng Ni's avatar Zhaoheng Ni Committed by Facebook GitHub Bot
Browse files

[BugFix] Fix extract_features method for WavLM models (#3350)

Summary:
Resolves https://github.com/pytorch/audio/issues/3347

`position_bias` is ignored in the `extract_features` method. This doesn't affect Wav2Vec2 or HuBERT models, but it changes the output of the transformer layers (all except the first layer) in the WavLM model, whose attention uses a relative position bias propagated between layers. This PR fixes the issue by passing `position_bias` through each layer in the method.

Pull Request resolved: https://github.com/pytorch/audio/pull/3350

Reviewed By: mthrok

Differential Revision: D46112148

Pulled By: nateanl

fbshipit-source-id: 3d21aa4b32b22da437b440097fd9b00238152596
parent fce54fd1
...@@ -144,6 +144,14 @@ class TestHFIntegration(TorchaudioTestCase): ...@@ -144,6 +144,14 @@ class TestHFIntegration(TorchaudioTestCase):
hyp = imported.encoder.transformer(x) hyp = imported.encoder.transformer(x)
self.assertEqual(ref, hyp) self.assertEqual(ref, hyp)
# Test get_intermediate_outputs method
b, l, e = 16, 3, config["hidden_size"]
x = torch.randn(b, l, e)
ref = original.encoder(x, output_hidden_states=True).hidden_states
hyp = imported.encoder.transformer.get_intermediate_outputs(x)
for i in range(len(hyp)):
self.assertEqual(ref[i + 1], hyp[i], atol=1e-4, rtol=0.001)
def _test_import_finetune(self, original, imported, config): def _test_import_finetune(self, original, imported, config):
# Aux # Aux
x = torch.randn(3, 10, config["hidden_size"]) x = torch.randn(3, 10, config["hidden_size"])
...@@ -243,6 +251,14 @@ class TestHFIntegration(TorchaudioTestCase): ...@@ -243,6 +251,14 @@ class TestHFIntegration(TorchaudioTestCase):
hyp = imported.encoder.transformer(x) hyp = imported.encoder.transformer(x)
self.assertEqual(ref, hyp) self.assertEqual(ref, hyp)
# Test get_intermediate_outputs method
b, l, e = 16, 3, config["hidden_size"]
x = torch.randn(b, l, e)
ref = original.encoder(x, output_hidden_states=True).hidden_states
hyp = imported.encoder.transformer.get_intermediate_outputs(x)
for i in range(len(hyp)):
self.assertEqual(ref[i + 1], hyp[i], atol=1e-4, rtol=0.001)
def _test_recreate(self, imported, reloaded, config): def _test_recreate(self, imported, reloaded, config):
# FeatureExtractor # FeatureExtractor
x = torch.randn(3, 1024) x = torch.randn(3, 1024)
......
...@@ -458,9 +458,10 @@ class Transformer(Module): ...@@ -458,9 +458,10 @@ class Transformer(Module):
raise ValueError(f"`num_layers` must be between [1, {len(self.layers)}]") raise ValueError(f"`num_layers` must be between [1, {len(self.layers)}]")
ret: List[Tensor] = [] ret: List[Tensor] = []
position_bias = None
x = self._preprocess(x) x = self._preprocess(x)
for layer in self.layers: for layer in self.layers:
x, _ = layer(x, attention_mask) # Ignore position_bias x, position_bias = layer(x, attention_mask, position_bias=position_bias)
ret.append(x) ret.append(x)
if num_layers is not None and len(ret) >= num_layers: if num_layers is not None and len(ret) >= num_layers:
return ret return ret
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment