Unverified Commit 4c940934 authored by Jonghwan Hyeon's avatar Jonghwan Hyeon Committed by GitHub
Browse files

Output `None` as attention when layer is skipped (#30597)

* Output `None` as attention when layer is skipped

* Add test for output_attentions
parent 39359e5b
@@ -727,7 +727,7 @@ class WavLMEncoder(nn.Module):
                 hidden_states, position_bias = layer_outputs[:2]
             if skip_the_layer:
-                layer_outputs = (None, None)
+                layer_outputs = (None, None, None)
             if output_attentions:
                 all_self_attentions = all_self_attentions + (layer_outputs[2],)
@@ -810,7 +810,7 @@ class WavLMEncoderStableLayerNorm(nn.Module):
                 hidden_states, position_bias = layer_outputs[:2]
             if skip_the_layer:
-                layer_outputs = (None, None)
+                layer_outputs = (None, None, None)
             if output_attentions:
                 all_self_attentions = all_self_attentions + (layer_outputs[2],)
......
@@ -288,6 +288,15 @@ class WavLMModelTester:
             loss.backward()
def check_output_attentions(self, config, input_values, attention_mask):
    """Regression check: attentions are returned even when LayerDrop skips layers.

    With ``layerdrop = 1.0`` every encoder layer is skipped in training mode,
    so the encoder must still emit one attention entry per layer (as ``None``)
    instead of raising an IndexError on the truncated layer outputs.

    Args:
        config: model configuration used to build the ``WavLMModel``.
        input_values: raw audio input tensor for the forward pass.
        attention_mask: padding mask matching ``input_values``.
    """
    model = WavLMModel(config=config)
    # Probability 1.0 forces every layer onto the skip path.
    model.config.layerdrop = 1.0
    model.to(torch_device)
    # LayerDrop is only active in training mode — eval would never skip.
    model.train()
    outputs = model(input_values, attention_mask=attention_mask, output_attentions=True)
    self.parent.assertTrue(len(outputs.attentions) > 0)
    # Every layer was skipped, so each per-layer attention must be emitted as None.
    self.parent.assertTrue(all(attention is None for attention in outputs.attentions))
    def check_labels_out_of_vocab(self, config, input_values, *args):
        model = WavLMForCTC(config)
        model.to(torch_device)
@@ -354,6 +363,10 @@ class WavLMModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.check_seq_classifier_training(*config_and_inputs)
def test_output_attentions(self):
    """Attentions must be produced even when LayerDrop skips every layer."""
    self.model_tester.check_output_attentions(*self.model_tester.prepare_config_and_inputs())
    def test_labels_out_of_vocab(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment