Unverified commit 2f2fefd6 authored by Yih-Dar, committed by GitHub

Fix LongformerModel hidden states (#15537)



* add undo padding

* fix

* fix tuple issue

* make style and quality

* move unpad logic to LongformerEncoder + unpad attentions + update tests

* move unpad logic to TFLongformerEncoder
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 68dec6bf
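Background for the change: Longformer pads its input up to a multiple of `config.attention_window` before running attention. Previously only `sequence_output` was sliced back to the original length in `LongformerModel`/`TFLongformerMainLayer`, so the tensors returned in `hidden_states` and `attentions` still carried the padded length. This commit moves the unpadding into `LongformerEncoder`/`TFLongformerEncoder` so every returned tensor is trimmed. The sketch below only illustrates the pad/unpad arithmetic; the helper names `pad_to_window_multiple` and `unpad_outputs` are made up for this example and are not part of the library.

```python
import torch


def pad_to_window_multiple(input_ids: torch.Tensor, attention_window: int, pad_token_id: int):
    """Pad the sequence dimension up to a multiple of `attention_window` (illustrative sketch)."""
    seq_len = input_ids.shape[1]
    # extra tokens needed so seq_len becomes a multiple of attention_window
    padding_len = (attention_window - seq_len % attention_window) % attention_window
    if padding_len > 0:
        pad = input_ids.new_full((input_ids.shape[0], padding_len), pad_token_id)
        input_ids = torch.cat([input_ids, pad], dim=-1)
    return input_ids, padding_len


def unpad_outputs(hidden_states, all_hidden_states, all_attentions, padding_len):
    """Slice the padding off every returned tensor so lengths match the original input."""
    if padding_len > 0:
        hidden_states = hidden_states[:, :-padding_len]
        if all_hidden_states is not None:
            all_hidden_states = tuple(state[:, :-padding_len] for state in all_hidden_states)
        if all_attentions is not None:
            # local attention tensors are (batch, heads, seq_len, window_size); slice axis 2
            all_attentions = tuple(state[:, :, :-padding_len, :] for state in all_attentions)
    return hidden_states, all_hidden_states, all_attentions
```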
@@ -1246,6 +1246,7 @@ class LongformerEncoder(nn.Module):
         hidden_states,
         attention_mask=None,
         head_mask=None,
+        padding_len=0,
         output_attentions=False,
         output_hidden_states=False,
         return_dict=True,
@@ -1308,6 +1309,16 @@ class LongformerEncoder(nn.Module):
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
 
+        # undo padding
+        if padding_len > 0:
+            # unpad `hidden_states` because the calling function is expecting a length == input_ids.size(1)
+            hidden_states = hidden_states[:, :-padding_len]
+            if output_hidden_states:
+                all_hidden_states = tuple([state[:, :-padding_len] for state in all_hidden_states])
+
+            if output_attentions:
+                all_attentions = tuple([state[:, :, :-padding_len, :] for state in all_attentions])
+
         if not return_dict:
             return tuple(
                 v for v in [hidden_states, all_hidden_states, all_attentions, all_global_attentions] if v is not None
@@ -1697,6 +1708,7 @@ class LongformerModel(LongformerPreTrainedModel):
             embedding_output,
             attention_mask=extended_attention_mask,
             head_mask=head_mask,
+            padding_len=padding_len,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
@@ -1704,11 +1716,6 @@ class LongformerModel(LongformerPreTrainedModel):
         sequence_output = encoder_outputs[0]
         pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
 
-        # undo padding
-        if padding_len > 0:
-            # unpad `sequence_output` because the calling function is expecting a length == input_ids.size(1)
-            sequence_output = sequence_output[:, :-padding_len]
-
         if not return_dict:
             return (sequence_output, pooled_output) + encoder_outputs[1:]
...
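With the unpadding done inside the encoder, every tensor the model returns should match the unpadded input length. A sanity check along these lines could be run against a small randomly initialized model; the config values below are arbitrary and chosen only to keep the example tiny, they are not the values used in the PR's tests.

```python
import torch
from transformers import LongformerConfig, LongformerModel

config = LongformerConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
    attention_window=8,
    max_position_embeddings=64,
)
model = LongformerModel(config).eval()

# 10 tokens is not a multiple of attention_window, so the model pads internally
input_ids = torch.randint(0, config.vocab_size, (1, 10))
with torch.no_grad():
    out = model(input_ids, output_hidden_states=True, output_attentions=True)

# every sequence dimension should equal input_ids.size(1), not the padded length
assert out.last_hidden_state.shape[1] == input_ids.shape[1]
assert all(h.shape[1] == input_ids.shape[1] for h in out.hidden_states)
assert all(a.shape[2] == input_ids.shape[1] for a in out.attentions)
```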
@@ -1587,13 +1587,23 @@ class TFLongformerEncoder(tf.keras.layers.Layer):
                 all_attentions = all_attentions + (tf.transpose(layer_outputs[1], (0, 2, 1, 3)),)
 
                 # bzs x num_attn_heads x num_global_attn x seq_len => bzs x num_attn_heads x seq_len x num_global_attn
-                all_global_attentions = all_global_attentions + (tf.transpose(layer_outputs[2], (0, 1, 3, 2)))
+                all_global_attentions = all_global_attentions + (tf.transpose(layer_outputs[2], (0, 1, 3, 2)),)
 
         # Add last layer
         if output_hidden_states:
             hidden_states_to_add = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states
             all_hidden_states = all_hidden_states + (hidden_states_to_add,)
 
+        # undo padding
+        # unpad `hidden_states` because the calling function is expecting a length == input_ids.size(1)
+        hidden_states = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states
+        if output_attentions:
+            all_attentions = (
+                tuple([state[:, :, :-padding_len, :] for state in all_attentions])
+                if padding_len > 0
+                else all_attentions
+            )
+
         if not return_dict:
             return tuple(
                 v for v in [hidden_states, all_hidden_states, all_attentions, all_global_attentions] if v is not None
@@ -1763,11 +1773,6 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
         sequence_output = encoder_outputs[0]
         pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
 
-        # undo padding
-        if padding_len > 0:
-            # unpad `sequence_output` because the calling function is expecting a length == input_ids.size(1)
-            sequence_output = sequence_output[:, :-padding_len]
-
         if not inputs["return_dict"]:
             return (
                 sequence_output,
...
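The "fix tuple issue" bullet in the commit message corresponds to the trailing comma added to the `all_global_attentions` line in the TFLongformerEncoder hunk above: without it, `(tf.transpose(...))` is just a parenthesized tensor rather than a one-element tuple, so the concatenation does not append to the accumulator. A plain-Python illustration of the distinction:

```python
# (x) is only x wrapped in parentheses; (x,) is a one-element tuple.
acc = ()
item = "tensor"  # stand-in for a tf.Tensor

not_a_tuple = (item)   # still a str; the parentheses change nothing
one_tuple = (item,)    # the trailing comma makes this a tuple

assert not isinstance(not_a_tuple, tuple)
assert isinstance(one_tuple, tuple)

acc = acc + one_tuple  # appends as intended: ("tensor",)
# acc + not_a_tuple would raise TypeError: can only concatenate tuple (not "str") to tuple
```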
@@ -74,12 +74,6 @@ class LongformerModelTester:
         # is x + self.attention_window + 1, where x is the number of tokens with global attention)
         self.key_length = self.attention_window + 2
 
-        # because of padding `encoder_seq_length`, is different from `seq_length`. Relevant for
-        # the `test_attention_outputs` and `test_hidden_states_output` tests
-        self.encoder_seq_length = (
-            self.seq_length + (self.attention_window - self.seq_length % self.attention_window) % self.attention_window
-        )
-
     def prepare_config_and_inputs(self):
         input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
...
@@ -74,12 +74,6 @@ class TFLongformerModelTester:
         # because its local attention only attends to `self.attention_window` and one before and one after
         self.key_length = self.attention_window + 2
 
-        # because of padding `encoder_seq_length`, is different from `seq_length`. Relevant for
-        # the `test_attention_outputs` and `test_hidden_states_output` tests
-        self.encoder_seq_length = (
-            self.seq_length + (self.attention_window - self.seq_length % self.attention_window) % self.attention_window
-        )
-
     def prepare_config_and_inputs(self):
         input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
...
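The testers no longer need `encoder_seq_length` because the model outputs now use the original `seq_length`, so the generic `test_attention_outputs` and `test_hidden_states_output` tests can compare shapes against `seq_length` directly. For reference, the deleted expression rounded the sequence length up to the next multiple of the attention window, which is the padded length the outputs used to have. A small numeric check of that formula, with illustrative values rather than the testers' own:

```python
def padded_length(seq_length: int, attention_window: int) -> int:
    # the expression removed from the testers above
    return seq_length + (attention_window - seq_length % attention_window) % attention_window


assert padded_length(10, 8) == 16   # 10 rounds up to the next multiple of 8
assert padded_length(16, 8) == 16   # already a multiple: nothing is added
```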