Unverified Commit 83eec97e authored by Julien Plu, committed by GitHub

Fix TF Longformer (#9348)

* Fix longformer

* Apply style

* Remove serving content

* Forgot a condition

* Apply style

* Address Patrick's comments

* Fix dtype
parent 30fa0b78
```diff
@@ -390,7 +390,7 @@ def _compute_global_attention_mask(input_ids_shape, sep_token_indices, before_se
     True` else after `sep_token_id`.
     """
-    assert sep_token_indices.shape[1] == 2, "`input_ids` should have two dimensions"
+    assert shape_list(sep_token_indices)[1] == 2, "`input_ids` should have two dimensions"
     question_end_index = tf.reshape(sep_token_indices, (input_ids_shape[0], 3, 2))[:, 0, 1]
     question_end_index = tf.cast(question_end_index[:, None], tf.dtypes.int32)  # size: batch_size x 1
     # bool attention mask with True in locations of global attention
```
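This hunk swaps the static `.shape` lookup for `shape_list`, which falls back to `tf.shape` for dimensions that are only known at run time; inside a traced `tf.function` (e.g. when saving the model) a static dimension can be `None`, and the assertion would then compare against `None` instead of the real size. The sketch below is illustrative only: `dynamic_safe_shape` is a made-up stand-in for the library's `shape_list` helper, showing the mixed static/dynamic pattern.

```python
import tensorflow as tf


def dynamic_safe_shape(tensor):
    # Hypothetical stand-in for transformers' `shape_list`: prefer the static
    # dimension when it is known, otherwise fall back to the dynamic tf.shape value.
    static = tensor.shape.as_list()
    dynamic = tf.shape(tensor)
    return [dynamic[i] if dim is None else dim for i, dim in enumerate(static)]


@tf.function(input_signature=[tf.TensorSpec(shape=[None, 2], dtype=tf.int64)])
def check(sep_token_indices):
    # While tracing, the first dimension is unknown (None), but the second
    # dimension is still statically 2, so the check keeps working.
    assert dynamic_safe_shape(sep_token_indices)[1] == 2
    return sep_token_indices


print(check(tf.constant([[0, 5], [0, 7]], dtype=tf.int64)))
```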
```diff
@@ -1028,7 +1028,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         )
         # pad to full matrix
-        padding = tf.constant(
+        padding = tf.convert_to_tensor(
             [[0, shape_list(input_tensor)[1] - window_overlap], [0, shape_list(input_tensor)[3] - window_overlap - 1]]
         )
```
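`tf.constant` only accepts Python or NumPy values, so it breaks as soon as one of the padding amounts is itself a symbolic tensor, which is what `shape_list` returns for a dynamic dimension; `tf.convert_to_tensor` autopacks a mix of Python ints and scalar tensors. A minimal sketch of the same situation, with invented names and sizes:

```python
import tensorflow as tf


@tf.function(input_signature=[tf.TensorSpec(shape=[None, None], dtype=tf.float32)])
def pad_to_window(x):
    # Dynamic dimension: only available as a scalar tensor while tracing.
    seq_len = tf.shape(x)[1]
    window_overlap = 4

    # tf.constant([[0, 0], [0, seq_len - window_overlap]]) would raise here,
    # because one of the entries is a tf.Tensor rather than a Python number.
    paddings = tf.convert_to_tensor([[0, 0], [0, seq_len - window_overlap]])
    return tf.pad(x, paddings)


print(pad_to_window(tf.ones((2, 10))).shape)  # (2, 16)
```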
```diff
@@ -1523,8 +1523,7 @@ class TFLongformerEncoder(tf.keras.layers.Layer):
         training=False,
     ):
         all_hidden_states = () if output_hidden_states else None
-        all_attentions = () if output_attentions else None
-        all_global_attentions = () if (output_attentions and is_global_attn) else None
+        all_attentions = all_global_attentions = () if output_attentions else None

         for i, layer_module in enumerate(self.layer):
             if output_hidden_states:
@@ -1547,9 +1546,8 @@ class TFLongformerEncoder(tf.keras.layers.Layer):
                 # bzs x seq_len x num_attn_heads x (num_global_attn + attention_window_len + 1) => bzs x num_attn_heads x seq_len x (num_global_attn + attention_window_len + 1)
                 all_attentions = all_attentions + (tf.transpose(layer_outputs[1], (0, 2, 1, 3)),)

-                if is_global_attn:
-                    # bzs x num_attn_heads x num_global_attn x seq_len => bzs x num_attn_heads x seq_len x num_global_attn
-                    all_global_attentions = all_global_attentions + (tf.transpose(layer_outputs[2], (0, 1, 3, 2)))
+                # bzs x num_attn_heads x num_global_attn x seq_len => bzs x num_attn_heads x seq_len x num_global_attn
+                all_global_attentions = all_global_attentions + (tf.transpose(layer_outputs[2], (0, 1, 3, 2)))

         # Add last layer
         if output_hidden_states:
```
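A note on why the `if is_global_attn:` guard disappears: in graph mode `is_global_attn` is typically a boolean tensor, so its value is not available while Python is building the per-layer output tuples; initializing `all_global_attentions` together with `all_attentions` and appending unconditionally keeps the traced structure identical on every call. A hedged, stand-alone sketch of that accumulation pattern (not the Longformer code itself):

```python
import tensorflow as tf


@tf.function(autograph=False)
def collect(hidden_states, is_global_attn):
    # Python tuples can be grown during tracing because the number of layers
    # is static; the loop is simply unrolled into the graph.
    all_attentions = all_global_attentions = ()
    for _ in range(3):
        all_attentions = all_attentions + (hidden_states,)
        # Guarding this append with `if is_global_attn:` would need the value of
        # a tensor at trace time, which is not available; appending unconditionally
        # keeps the collected structure the same on every trace.
        all_global_attentions = all_global_attentions + (hidden_states,)
    return all_attentions, all_global_attentions


outs = collect(tf.ones((1, 4)), tf.constant(True))
print(len(outs[0]), len(outs[1]))  # 3 3
```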
```diff
@@ -1766,24 +1764,26 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
                 )
             )

-            paddings = tf.constant([[0, 0], [0, padding_len]])
+            paddings = tf.convert_to_tensor([[0, 0], [0, padding_len]])

             if input_ids is not None:
                 input_ids = tf.pad(input_ids, paddings, constant_values=pad_token_id)
+
             if position_ids is not None:
                 # pad with position_id = pad_token_id as in modeling_roberta.RobertaEmbeddings
                 position_ids = tf.pad(position_ids, paddings, constant_values=pad_token_id)
+
             if inputs_embeds is not None:
-                input_ids_padding = tf.fill((batch_size, padding_len), self.pad_token_id)
-                inputs_embeds_padding = self.embeddings(input_ids_padding)
-                inputs_embeds = tf.concat([inputs_embeds, inputs_embeds_padding], axis=-2)
-
-            attention_mask = tf.pad(
-                attention_mask, paddings, constant_values=False
-            )  # no attention on the padding tokens
+
+                def pad_embeddings():
+                    input_ids_padding = tf.fill((batch_size, padding_len), self.pad_token_id)
+                    inputs_embeds_padding = self.embeddings(input_ids_padding)
+                    return tf.concat([inputs_embeds, inputs_embeds_padding], axis=-2)
+
+                inputs_embeds = tf.cond(padding_len > 0, pad_embeddings, lambda: inputs_embeds)
+
+            attention_mask = tf.pad(attention_mask, paddings, constant_values=False)  # no attention on the padding tokens
             token_type_ids = tf.pad(token_type_ids, paddings, constant_values=0)  # pad with token_type_id = 0

             return (
                 padding_len,
```
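Because `padding_len` can be a tensor once the sequence length is dynamic, the embeddings branch now goes through `tf.cond`, whose two branches are callables returning tensors of compatible shape and dtype. A self-contained sketch of the same pattern, with invented sizes and a dummy padding embedding:

```python
import tensorflow as tf

PAD_EMBEDDING = tf.zeros((1, 1, 8))  # stand-in for embeddings(pad_token_id)


@tf.function(input_signature=[tf.TensorSpec(shape=[1, None, 8], dtype=tf.float32),
                              tf.TensorSpec(shape=[], dtype=tf.int32)])
def maybe_pad(inputs_embeds, padding_len):
    def pad_embeddings():
        # Repeat the padding embedding `padding_len` times along the sequence axis.
        padding = tf.tile(PAD_EMBEDDING, (1, padding_len, 1))
        return tf.concat([inputs_embeds, padding], axis=-2)

    # Both branches are callables; tf.cond picks one at run time.
    return tf.cond(padding_len > 0, pad_embeddings, lambda: inputs_embeds)


print(maybe_pad(tf.ones((1, 5, 8)), tf.constant(3)).shape)  # (1, 8, 8)
print(maybe_pad(tf.ones((1, 5, 8)), tf.constant(0)).shape)  # (1, 5, 8)
```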
```diff
@@ -2171,16 +2171,14 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
         # set global attention on question tokens
         if inputs["global_attention_mask"] is None and inputs["input_ids"] is not None:
-            if inputs["input_ids"] is None:
-                logger.warning(
-                    "It is not possible to automatically generate the `global_attention_mask`. Please make sure that it is correctly set."
-                )
-            elif (
-                tf.where(inputs["input_ids"] == self.config.sep_token_id).shape[0] != 3 * inputs["input_ids"].shape[0]
+            if (
+                shape_list(tf.where(inputs["input_ids"] == self.config.sep_token_id))[0]
+                != 3 * shape_list(inputs["input_ids"])[0]
             ):
                 logger.warning(
-                    f"There should be exactly three separator tokens: {self.config.sep_token_id} in every sample for questions answering. You might also consider to set `global_attention_mask` manually in the forward function to avoid this. This is most likely an error."
+                    f"There should be exactly three separator tokens: {self.config.sep_token_id} in every sample for questions answering. You might also consider to set `global_attention_mask` manually in the forward function to avoid this. This is most likely an error. The global attention is disabled for this forward pass."
                 )
+                inputs["global_attention_mask"] = tf.fill(shape_list(inputs["input_ids"]), value=0)
             else:
                 logger.info("Initializing global attention on question tokens...")
                 # put global attention on all tokens until `config.sep_token_id` is reached
```
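The rewritten check counts SEP tokens with `tf.where` and compares against three per sample, and on a mismatch it now disables global attention by filling the mask with zeros instead of leaving it `None`. A toy eager-mode illustration of that counting and fallback (token ids are made up):

```python
import tensorflow as tf

sep_token_id = 2
# Two samples, each with three SEP tokens (question + context layout).
input_ids = tf.constant([[0, 5, 6, 2, 2, 7, 8, 2],
                         [0, 9, 2, 2, 3, 4, 5, 2]])

# Number of SEP occurrences across the whole batch vs. the expected 3 per sample.
num_sep = tf.shape(tf.where(input_ids == sep_token_id))[0]
expected = 3 * tf.shape(input_ids)[0]
print(int(num_sep), int(expected))  # 6 6

# Fallback used when the count does not match: no global attention at all.
disabled_global_attention = tf.fill(tf.shape(input_ids), value=0)
print(disabled_global_attention.shape)  # (2, 8)
```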
```diff
@@ -2317,8 +2315,8 @@ class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSeque
             inputs["global_attention_mask"] = tf.zeros_like(inputs["input_ids"])
             inputs["global_attention_mask"] = tf.tensor_scatter_nd_update(
                 inputs["global_attention_mask"],
-                [[i, 0] for i in range(inputs["input_ids"].shape[0])],
-                [1 for _ in range(inputs["input_ids"].shape[0])],
+                [[i, 0] for i in range(shape_list(inputs["input_ids"])[0])],
+                [1 for _ in range(shape_list(inputs["input_ids"])[0])],
             )

         outputs = self.longformer(
```
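For sequence classification the scatter puts a 1 at column 0 of every row, i.e. global attention on the first token of each sequence; only the way the batch size is read changes here. A small eager example of the same `tf.tensor_scatter_nd_update` call with invented sizes:

```python
import tensorflow as tf

batch_size, seq_len = 3, 6
global_attention_mask = tf.zeros((batch_size, seq_len), dtype=tf.int32)

# One (row, column) index per sequence, all pointing at the first token.
indices = [[i, 0] for i in range(batch_size)]
updates = [1 for _ in range(batch_size)]

global_attention_mask = tf.tensor_scatter_nd_update(global_attention_mask, indices, updates)
print(global_attention_mask.numpy())
# [[1 0 0 0 0 0]
#  [1 0 0 0 0 0]
#  [1 0 0 0 0 0]]
```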
```diff
@@ -2443,7 +2441,7 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic
             tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None
         )
         flat_global_attention_mask = (
-            tf.reshape(inputs["global_attention_mask"], (-1, inputs["global_attention_mask"].shape[-1]))
+            tf.reshape(inputs["global_attention_mask"], (-1, shape_list(inputs["global_attention_mask"])[-1]))
             if inputs["global_attention_mask"] is not None
             else None
         )
```
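In the multiple-choice head the `(batch, num_choices, seq_len)` inputs are flattened to `(batch * num_choices, seq_len)` before the shared Longformer body is called; reading the last dimension via `shape_list` keeps the reshape working with a dynamic sequence length. A toy sketch of that flattening (dimensions invented):

```python
import tensorflow as tf

batch_size, num_choices, seq_len = 2, 4, 6
global_attention_mask = tf.zeros((batch_size, num_choices, seq_len), dtype=tf.int32)

# Collapse the batch and choice axes; the Longformer body only sees 2-D masks.
flat_global_attention_mask = tf.reshape(
    global_attention_mask, (-1, tf.shape(global_attention_mask)[-1])
)
print(flat_global_attention_mask.shape)  # (8, 6)
```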