Unverified Commit 83eec97e authored by Julien Plu, committed by GitHub

Fix TF Longformer (#9348)

* Fix longformer

* Apply style

* Remove serving content

* Forgot a condition

* Apply style

* Address Patrick's comments

* Fix dtype
parent 30fa0b78
@@ -390,7 +390,7 @@ def _compute_global_attention_mask(input_ids_shape, sep_token_indices, before_se
True` else after `sep_token_id`.
"""
assert sep_token_indices.shape[1] == 2, "`input_ids` should have two dimensions"
assert shape_list(sep_token_indices)[1] == 2, "`input_ids` should have two dimensions"
question_end_index = tf.reshape(sep_token_indices, (input_ids_shape[0], 3, 2))[:, 0, 1]
question_end_index = tf.cast(question_end_index[:, None], tf.dtypes.int32) # size: batch_size x 1
# bool attention mask with True in locations of global attention
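Note on the `shape_list` substitution above: `tensor.shape` returns `None` for dimensions that are unknown while `tf.function` traces the model, so arithmetic or comparisons on it silently break; the library's `shape_list` helper falls back to the dynamic `tf.shape` value in that case. A minimal sketch of such a helper with the simplified body below (the real helper lives in the library's TF utilities; `batch_size_demo` is purely illustrative):

```python
import tensorflow as tf


def shape_list(x):
    """Return the shape of `x`, using the static dimension when it is known
    and the dynamic one otherwise (sketch of the transformers TF helper)."""
    static = x.shape.as_list()
    dynamic = tf.shape(x)
    return [dynamic[i] if dim is None else dim for i, dim in enumerate(static)]


@tf.function(input_signature=[tf.TensorSpec([None, None], tf.int32)])
def batch_size_demo(input_ids):
    # During tracing, input_ids.shape[0] is None, so `3 * input_ids.shape[0]`
    # would fail; the dynamic value returned by shape_list stays usable.
    return 3 * shape_list(input_ids)[0]


print(batch_size_demo(tf.zeros((4, 16), dtype=tf.int32)))  # tf.Tensor(12, shape=(), dtype=int32)
```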
@@ -1028,7 +1028,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
)
# pad to full matrix
padding = tf.constant(
padding = tf.convert_to_tensor(
[[0, shape_list(input_tensor)[1] - window_overlap], [0, shape_list(input_tensor)[3] - window_overlap - 1]]
)
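The `tf.constant` → `tf.convert_to_tensor` change matters because the padding amounts are computed from `shape_list(...)`, which yields symbolic tensors under `tf.function`; `tf.constant` cannot embed symbolic tensor elements, while `tf.convert_to_tensor` auto-packs them with `tf.stack`. The same substitution shows up again below for the `paddings` in `TFLongformerMainLayer`. A hedged sketch of the pattern, with an illustrative window value and function name:

```python
import tensorflow as tf

WINDOW_OVERLAP = 3  # illustrative value, not Longformer's actual configuration


@tf.function(input_signature=[tf.TensorSpec([None, None], tf.float32)])
def pad_to_window(x):
    # The padding amount depends on the dynamic sequence length, so the nested
    # list below contains a symbolic scalar. tf.constant(...) cannot embed it,
    # whereas tf.convert_to_tensor(...) packs tensor elements automatically.
    seq_len = tf.shape(x)[1]
    paddings = tf.convert_to_tensor([[0, 0], [0, -seq_len % WINDOW_OVERLAP]])
    return tf.pad(x, paddings)


print(pad_to_window(tf.ones((2, 7))).shape)  # (2, 9): length padded up to a multiple of 3
```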
@@ -1523,8 +1523,7 @@ class TFLongformerEncoder(tf.keras.layers.Layer):
training=False,
):
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
all_global_attentions = () if (output_attentions and is_global_attn) else None
all_attentions = all_global_attentions = () if output_attentions else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
@@ -1547,7 +1546,6 @@
# bzs x seq_len x num_attn_heads x (num_global_attn + attention_window_len + 1) => bzs x num_attn_heads x seq_len x (num_global_attn + attention_window_len + 1)
all_attentions = all_attentions + (tf.transpose(layer_outputs[1], (0, 2, 1, 3)),)
if is_global_attn:
# bzs x num_attn_heads x num_global_attn x seq_len => bzs x num_attn_heads x seq_len x num_global_attn
all_global_attentions = all_global_attentions + (tf.transpose(layer_outputs[2], (0, 1, 3, 2)))
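The two encoder hunks above tie the structure of the returned attention tuples to the static `output_attentions` flag alone, rather than to the tensor-valued `is_global_attn` condition, so the output signature no longer changes between `tf.function` traces. A toy layer sketching the same collection pattern (the layer, its `Dense` blocks, and the fake attention values are illustrative only):

```python
import tensorflow as tf


class TinyEncoder(tf.keras.layers.Layer):
    """Toy encoder: the returned structure depends only on static Python flags."""

    def __init__(self, num_layers=2, **kwargs):
        super().__init__(**kwargs)
        self.blocks = [tf.keras.layers.Dense(4) for _ in range(num_layers)]

    def call(self, hidden_states, output_attentions=False):
        # Both collections start from the same static flag; tuples are
        # immutable, so sharing the initial () is safe.
        all_attentions = all_global_attentions = () if output_attentions else None

        for block in self.blocks:
            hidden_states = block(hidden_states)
            if output_attentions:
                attn = tf.nn.softmax(hidden_states)  # stand-in for real attention probs
                all_attentions = all_attentions + (attn,)
                all_global_attentions = all_global_attentions + (attn,)

        return hidden_states, all_attentions, all_global_attentions


enc = TinyEncoder()
_, attns, global_attns = enc(tf.ones((2, 5, 4)), output_attentions=True)
print(len(attns), len(global_attns))  # 2 2
```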
@@ -1766,7 +1764,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
)
)
paddings = tf.constant([[0, 0], [0, padding_len]])
paddings = tf.convert_to_tensor([[0, 0], [0, padding_len]])
if input_ids is not None:
input_ids = tf.pad(input_ids, paddings, constant_values=pad_token_id)
@@ -1776,13 +1774,15 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
position_ids = tf.pad(position_ids, paddings, constant_values=pad_token_id)
if inputs_embeds is not None:
def pad_embeddings():
input_ids_padding = tf.fill((batch_size, padding_len), self.pad_token_id)
inputs_embeds_padding = self.embeddings(input_ids_padding)
inputs_embeds = tf.concat([inputs_embeds, inputs_embeds_padding], axis=-2)
return tf.concat([inputs_embeds, inputs_embeds_padding], axis=-2)
inputs_embeds = tf.cond(padding_len > 0, pad_embeddings, lambda: inputs_embeds)
attention_mask = tf.pad(
attention_mask, paddings, constant_values=False
) # no attention on the padding tokens
attention_mask = tf.pad(attention_mask, paddings, constant_values=False) # no attention on the padding tokens
token_type_ids = tf.pad(token_type_ids, paddings, constant_values=0) # pad with token_type_id = 0
return (
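Because `padding_len` is a symbolic scalar under `tf.function`, the embeddings can only be padded conditionally through `tf.cond`, which is also why `pad_embeddings` now returns the concatenated tensor instead of assigning to the closed-over variable. A minimal sketch of that pattern, assuming an illustrative window size and padding with zero vectors rather than embedded pad tokens as the real model does:

```python
import tensorflow as tf

ATTENTION_WINDOW = 4  # illustrative, not the real model's attention_window


@tf.function(input_signature=[tf.TensorSpec([None, None, 8], tf.float32)])
def pad_embeddings_to_window(inputs_embeds):
    # padding_len is symbolic, so a Python `if padding_len > 0:` cannot be
    # evaluated at trace time; tf.cond selects the branch at run time instead.
    seq_len = tf.shape(inputs_embeds)[1]
    padding_len = -seq_len % ATTENTION_WINDOW
    paddings = tf.convert_to_tensor([[0, 0], [0, padding_len], [0, 0]])

    def pad():
        return tf.pad(inputs_embeds, paddings)  # zero-pad along the sequence axis

    return tf.cond(padding_len > 0, pad, lambda: inputs_embeds)


print(pad_embeddings_to_window(tf.ones((2, 6, 8))).shape)  # (2, 8, 8)
```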
@@ -2171,16 +2171,14 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
# set global attention on question tokens
if inputs["global_attention_mask"] is None and inputs["input_ids"] is not None:
if inputs["input_ids"] is None:
logger.warning(
"It is not possible to automatically generate the `global_attention_mask`. Please make sure that it is correctly set."
)
elif (
tf.where(inputs["input_ids"] == self.config.sep_token_id).shape[0] != 3 * inputs["input_ids"].shape[0]
if (
shape_list(tf.where(inputs["input_ids"] == self.config.sep_token_id))[0]
!= 3 * shape_list(inputs["input_ids"])[0]
):
logger.warning(
f"There should be exactly three separator tokens: {self.config.sep_token_id} in every sample for questions answering. You might also consider to set `global_attention_mask` manually in the forward function to avoid this. This is most likely an error."
f"There should be exactly three separator tokens: {self.config.sep_token_id} in every sample for questions answering. You might also consider to set `global_attention_mask` manually in the forward function to avoid this. This is most likely an error. The global attention is disabled for this forward pass."
)
inputs["global_attention_mask"] = tf.fill(shape_list(inputs["input_ids"]), value=0)
else:
logger.info("Initializing global attention on question tokens...")
# put global attention on all tokens until `config.sep_token_id` is reached
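The question-answering hunk above moves the separator-token check onto `shape_list` and, instead of erroring out, now warns and falls back to an all-zero `global_attention_mask` when a sample does not contain exactly three separator tokens. A rough eager-mode sketch of that heuristic (`SEP_TOKEN_ID` and the function name are illustrative, and the mask construction only approximates `_compute_global_attention_mask`):

```python
import tensorflow as tf

SEP_TOKEN_ID = 2  # illustrative id, not tied to a real tokenizer


def question_global_attention_mask(input_ids):
    """<question> sep sep <context> sep layouts have three sep tokens per row;
    global attention goes on every position up to the first one."""
    batch_size = tf.shape(input_ids)[0]
    sep_indices = tf.where(input_ids == SEP_TOKEN_ID)  # (num_seps, 2), int64

    # Fallback mirroring the diff: warn and disable global attention when the
    # count is not 3 * batch_size (eager-mode check for illustration).
    if tf.shape(sep_indices)[0] != 3 * batch_size:
        print("unexpected number of separator tokens; disabling global attention")
        return tf.zeros_like(input_ids)

    # Column index of the first separator in every row.
    question_end = tf.reshape(sep_indices, (batch_size, 3, 2))[:, 0, 1]
    question_end = tf.cast(question_end[:, None], tf.int32)  # (batch_size, 1)
    positions = tf.range(tf.shape(input_ids)[1])[None, :]    # (1, seq_len)
    return tf.cast(positions <= question_end, tf.int32)


ids = tf.constant([[5, 6, 2, 2, 7, 8, 9, 2]])  # question tokens [5, 6], first sep at index 2
print(question_global_attention_mask(ids).numpy())  # [[1 1 1 0 0 0 0 0]]
```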
@@ -2317,8 +2315,8 @@ class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSeque
inputs["global_attention_mask"] = tf.zeros_like(inputs["input_ids"])
inputs["global_attention_mask"] = tf.tensor_scatter_nd_update(
inputs["global_attention_mask"],
[[i, 0] for i in range(inputs["input_ids"].shape[0])],
[1 for _ in range(inputs["input_ids"].shape[0])],
[[i, 0] for i in range(shape_list(inputs["input_ids"])[0])],
[1 for _ in range(shape_list(inputs["input_ids"])[0])],
)
outputs = self.longformer(
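The sequence-classification hunk keeps the Python list comprehensions for the scatter indices but reads the batch size through `shape_list`. The same default, global attention on the first (CLS) token of every row, can also be built from TF ops end to end; the sketch below is that variant for illustration, not the code in the diff:

```python
import tensorflow as tf


def global_attention_on_first_token(input_ids):
    """Set global attention on column 0 of every row, without a static batch size."""
    mask = tf.zeros_like(input_ids)
    rows = tf.range(tf.shape(input_ids)[0])
    indices = tf.stack([rows, tf.zeros_like(rows)], axis=1)  # (batch_size, 2): (row, 0)
    updates = tf.ones_like(rows)
    return tf.tensor_scatter_nd_update(mask, indices, updates)


print(global_attention_on_first_token(tf.zeros((3, 5), dtype=tf.int32)).numpy())
# [[1 0 0 0 0]
#  [1 0 0 0 0]
#  [1 0 0 0 0]]
```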
@@ -2443,7 +2441,7 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic
tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None
)
flat_global_attention_mask = (
tf.reshape(inputs["global_attention_mask"], (-1, inputs["global_attention_mask"].shape[-1]))
tf.reshape(inputs["global_attention_mask"], (-1, shape_list(inputs["global_attention_mask"])[-1]))
if inputs["global_attention_mask"] is not None
else None
)
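For multiple choice the inputs are flattened from `(batch_size, num_choices, seq_len)` to `(batch_size * num_choices, seq_len)`, and the hunk above reads the last dimension through `shape_list` so the reshape also works when the static shape is partially unknown. A small sketch of that flattening (function name and shapes are illustrative):

```python
import tensorflow as tf


def flatten_choices(global_attention_mask):
    # Keep the last dimension dynamic instead of reading a possibly-None
    # static shape, then merge the batch and choice dimensions.
    seq_len = tf.shape(global_attention_mask)[-1]
    return tf.reshape(global_attention_mask, (-1, seq_len))


mask = tf.ones((2, 4, 16), dtype=tf.int32)  # 2 samples x 4 answer choices
print(flatten_choices(mask).shape)          # (8, 16)
```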