"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "1f6f32c24338ad1ff17475b836c7b4505da77714"
Unverified commit 61cf2ea9, authored by Matt, committed by GitHub

Fix incorrect output shapes for TF/PT LED (#13882)



* Fix issues with LED model

* Style pass

* Bugfixes

* correct attentions as well
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 5f34163b
@@ -1858,6 +1858,11 @@ class LEDEncoder(LEDPreTrainedModel):
         if padding_len > 0:
             # unpad `hidden_states` because the calling function is expecting a length == input_ids.size(1)
             hidden_states = hidden_states[:, :-padding_len]
+            if output_hidden_states:
+                encoder_states = tuple([state[:, :-padding_len] for state in encoder_states])
+            if output_attentions:
+                all_attentions = tuple([state[:, :, :-padding_len, :] for state in all_attentions])
+
         if not return_dict:
             return tuple(
...
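Note (context, not part of the diff): the LED encoder pads its input so the sequence length is a multiple of the attention window, and before this fix only the final hidden_states tensor was unpadded, while the per-layer encoder_states and all_attentions tuples still carried the padded length. The standalone sketch below uses made-up shapes rather than the real LEDEncoder code and shows the trimming that the two new if blocks perform.

import torch

batch, seq_len, hidden, heads, window = 2, 10, 8, 4, 16
padding_len = (window - seq_len % window) % window  # 6: the input is padded from 10 to 16
padded_len = seq_len + padding_len

# stand-ins for what the encoder has computed at the padded length
hidden_states = torch.zeros(batch, padded_len, hidden)
encoder_states = tuple(torch.zeros(batch, padded_len, hidden) for _ in range(3))
all_attentions = tuple(torch.zeros(batch, heads, padded_len, window + 1) for _ in range(3))

if padding_len > 0:
    # trim every returned tensor back to the caller's original sequence length
    hidden_states = hidden_states[:, :-padding_len]
    encoder_states = tuple(state[:, :-padding_len] for state in encoder_states)
    all_attentions = tuple(state[:, :, :-padding_len, :] for state in all_attentions)

assert hidden_states.shape[1] == seq_len
assert all(state.shape[1] == seq_len for state in encoder_states)
assert all(att.shape[2] == seq_len for att in all_attentions)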
@@ -1602,7 +1602,9 @@ class TFLEDEncoder(tf.keras.layers.Layer):
         super().__init__(**kwargs)
         self.config = config
         self.dropout = tf.keras.layers.Dropout(config.dropout)
-        self.layerdrop = config.encoder_layerdrop
+        if config.encoder_layerdrop > 0:
+            logger.warning("Layerdrop is currently disabled in TFLED models.")
+        self.layerdrop = 0.0
         self.padding_idx = config.pad_token_id

         if isinstance(config.attention_window, int):
@@ -1867,7 +1869,9 @@ class TFLEDDecoder(tf.keras.layers.Layer):
         self.config = config
         self.padding_idx = config.pad_token_id
         self.embed_tokens = embed_tokens
-        self.layerdrop = config.decoder_layerdrop
+        if config.decoder_layerdrop > 0:
+            logger.warning("Layerdrop is currently disabled in TFLED models.")
+        self.layerdrop = 0.0
         self.embed_positions = TFLEDLearnedPositionalEmbedding(
             config.max_decoder_position_embeddings,
             config.d_model,
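Background on the two TF hunks above: layerdrop randomly skips whole transformer layers during training, and the TF LED port does not implement that skipping, so both constructors now warn when the config asks for a non-zero value and pin self.layerdrop to 0.0. Below is a hedged, generic sketch of a layerdrop loop (placeholder names, not TFLED internals) showing why a value of 0.0 means every layer always runs.

import random

def run_layers(hidden_states, layers, layerdrop, training):
    for layer in layers:
        # with probability `layerdrop`, skip the layer entirely during training
        if training and random.random() < layerdrop:
            continue
        hidden_states = layer(hidden_states)
    return hidden_states

# with layerdrop pinned to 0.0, as in the change above, no layer is ever skipped
identity_like_layers = [lambda x: x + 1 for _ in range(4)]
assert run_layers(0, identity_like_layers, layerdrop=0.0, training=True) == 4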
@@ -2451,7 +2455,7 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
             past_key_values=outputs.past_key_values,  # index 1 of d outputs
             decoder_hidden_states=outputs.decoder_hidden_states,  # index 2 of d outputs
             decoder_attentions=outputs.decoder_attentions,  # index 3 of d outputs
-            encoder_last_hidden_state=outputs.last_hidden_state,  # index 0 of encoder outputs
+            encoder_last_hidden_state=outputs.encoder_last_hidden_state,  # index 0 of encoder outputs
             encoder_hidden_states=outputs.encoder_hidden_states,  # 1 of e out
             encoder_attentions=outputs.encoder_attentions,  # 2 of e out
             encoder_global_attentions=outputs.encoder_global_attentions,
...
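The one-line change above fixes a field mix-up: the conditional-generation wrapper was forwarding the decoder's last_hidden_state into the encoder_last_hidden_state slot of its return value, so that field carried a decoder-shaped tensor instead of the encoder-shaped one. A toy illustration with a stand-in dataclass (the real output class and shapes differ):

from dataclasses import dataclass

import numpy as np

@dataclass
class FakeSeq2SeqOutput:
    last_hidden_state: np.ndarray          # decoder output: (batch, tgt_len, d_model)
    encoder_last_hidden_state: np.ndarray  # encoder output: (batch, src_len, d_model)

outputs = FakeSeq2SeqOutput(
    last_hidden_state=np.zeros((1, 5, 16)),
    encoder_last_hidden_state=np.zeros((1, 32, 16)),
)

# before the fix, the wrapper re-exposed the decoder tensor under the encoder key,
# so callers saw a (1, 5, 16) array where a (1, 32, 16) one was documented
assert outputs.encoder_last_hidden_state.shape != outputs.last_hidden_state.shape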
@@ -126,9 +126,7 @@ class LEDModelTester:
         # because of padding `encoder_seq_length`, is different from `seq_length`. Relevant for
         # the `test_attention_outputs` and `test_hidden_states_output` tests
-        self.encoder_seq_length = (
-            self.seq_length + (self.attention_window - self.seq_length % self.attention_window) % self.attention_window
-        )
+        self.encoder_seq_length = self.seq_length

     def prepare_config_and_inputs(self):
         input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
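Why the tester no longer needs a rounded-up length: before this PR the encoder returned its hidden states at the padded length, so encoder_seq_length had to be seq_length rounded up to the next multiple of the attention window. A quick arithmetic check with illustrative numbers (not the tester's actual configuration):

seq_length = 7        # illustrative
attention_window = 4  # illustrative

# the old expected length: seq_length rounded up to a multiple of the window
padded = seq_length + (attention_window - seq_length % attention_window) % attention_window
assert padded == 8

# with the encoder now unpadding its outputs, the expectation is simply the input length
encoder_seq_length = seq_length
assert encoder_seq_length == 7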
@@ -354,32 +352,6 @@ class LEDModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
         # longformer cannot keep gradients in attentions or hidden states
         return

-    def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length):
-        # make sure tgt_length is padded
-        tgt_length = (
-            seq_length // config.attention_window[0] + (seq_length % config.attention_window[0] != 0)
-        ) * config.attention_window[0]
-
-        encoder_expected_shape = (batch_size, config.num_attention_heads, tgt_length, seq_length)
-        self.assertIsInstance(attentions, tuple)
-        self.assertListEqual(
-            [layer_attentions.shape for layer_attentions in attentions],
-            [encoder_expected_shape] * len(attentions),
-        )
-
-    def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, config, seq_length):
-        # make sure seq_length is padded
-        seq_length = (
-            seq_length // config.attention_window[0] + (seq_length % config.attention_window[0] != 0)
-        ) * config.attention_window[0]
-
-        encoder_expected_shape = (batch_size, seq_length, config.hidden_size)
-        self.assertIsInstance(hidden_states, tuple)
-        self.assertListEqual(
-            [layer_hidden_states.shape for layer_hidden_states in hidden_states],
-            [encoder_expected_shape] * len(hidden_states),
-        )
-
     def test_attention_outputs(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         config.return_dict = True
...
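For completeness, a hedged end-to-end check of the behaviour these removed overrides used to compensate for. It assumes a transformers build that includes this fix and uses allenai/led-base-16384 purely as an example checkpoint; any LED checkpoint should behave the same way.

import torch
from transformers import LEDModel, LEDTokenizer

tokenizer = LEDTokenizer.from_pretrained("allenai/led-base-16384")
model = LEDModel.from_pretrained("allenai/led-base-16384")

inputs = tokenizer("Output shapes should match the input length.", return_tensors="pt")
seq_len = inputs["input_ids"].shape[1]

with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True, output_attentions=True)

# every per-layer encoder hidden state is trimmed back to seq_len ...
assert all(h.shape[1] == seq_len for h in outputs.encoder_hidden_states)
# ... and so is the query dimension of every encoder attention map
assert all(a.shape[2] == seq_len for a in outputs.encoder_attentions)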