Unverified Commit 22a32cf4 authored by Julien Plu's avatar Julien Plu Committed by GitHub
Browse files

Fix TF LED/Longformer attentions computation (#10007)

* Fix test

* Remove commented test

* Fix name

* Apply style

* Fix check copies

* Remove prints

* Restore boolean

* Fix reshape
parent 0d8e554d
......@@ -266,13 +266,26 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
),
lambda: attn_scores,
)
attn_probs = tf.nn.softmax(attn_scores, axis=-1)
# softmax sometimes inserts NaN if all positions are masked, replace them with 0
# Make sure to create a mask with the proper shape:
# if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1]
# if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1]
masked_index = tf.cond(
is_global_attn,
lambda: tf.tile(
is_index_masked[:, :, None, None],
(1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1),
),
lambda: tf.tile(
is_index_masked[:, :, None, None],
(1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + 1),
),
)
attn_probs = tf.where(
tf.broadcast_to(is_index_masked[:, :, None, None], shape_list(attn_probs)),
0.0,
masked_index,
tf.zeros(shape_list(masked_index), dtype=tf.dtypes.float32),
attn_probs,
)
......@@ -330,11 +343,23 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
)
# make sure that local attention probabilities are set to 0 for indices of global attn
# When is_global_attn is True, the last dimension is always self.one_sided_attn_window_size * 2 + 1 + 1
# because of the concat Line 713.
# Make sure to create a mask with the proper shape:
# if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1]
# if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1]
masked_global_attn_index = tf.cond(
is_global_attn,
lambda: tf.tile(
is_index_global_attn[:, :, None, None],
(1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1),
),
lambda: tf.tile(
is_index_global_attn[:, :, None, None],
(1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + 1),
),
)
attn_probs = tf.where(
tf.broadcast_to(is_index_global_attn[:, :, None, None], shape_list(attn_probs)),
tf.zeros(shape_list(attn_probs), dtype=tf.dtypes.float32),
masked_global_attn_index,
tf.zeros(shape_list(masked_global_attn_index), dtype=tf.dtypes.float32),
attn_probs,
)
......@@ -418,14 +443,9 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
axis=1,
)
first_chunk_mask = (
tf.broadcast_to(
tf.tile(
tf.range(chunks_count + 1)[None, :, None, None],
shape=(
batch_size * num_heads,
chunks_count + 1,
window_overlap,
window_overlap,
),
(batch_size * num_heads, 1, window_overlap, window_overlap),
)
< 1
)
......@@ -473,7 +493,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
mask_2d = mask_2d + tf.reverse(mask_2d, axis=[0, 1])
# broadcast to full matrix
mask_4d = tf.broadcast_to(mask_2d[None, :, None, :], shape_list(input_tensor))
mask_4d = tf.tile(mask_2d[None, :, None, :], (shape_list(input_tensor)[0], 1, 1, 1))
# inf tensor used for masking
inf_tensor = -float("inf") * tf.ones_like(input_tensor, dtype=tf.dtypes.float32)
......@@ -818,7 +838,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
global_attn_scores = tf.transpose(global_attn_scores_trans, (0, 2, 1, 3))
# mask global attn scores
attn_mask = tf.broadcast_to(is_index_masked[:, None, None, :], shape_list(global_attn_scores))
attn_mask = tf.tile(is_index_masked[:, None, None, :], (1, shape_list(global_attn_scores)[1], 1, 1))
global_attn_scores = tf.where(attn_mask, -10000.0, global_attn_scores)
global_attn_scores = tf.reshape(
global_attn_scores,
......@@ -1761,7 +1781,7 @@ class TFLEDEncoder(tf.keras.layers.Layer):
batch_size, seq_len = input_shape[:2]
padding_len = (attention_window - seq_len % attention_window) % attention_window
if tf.math.greater(padding_len, 0):
if padding_len > 0:
logger.info(
"Input ids are automatically padded from {} to {} to be a multiple of `config.attention_window`: {}".format(
seq_len, seq_len + padding_len, attention_window
......
......@@ -395,21 +395,20 @@ def _compute_global_attention_mask(input_ids_shape, sep_token_indices, before_se
question_end_index = tf.reshape(sep_token_indices, (input_ids_shape[0], 3, 2))[:, 0, 1]
question_end_index = tf.cast(question_end_index[:, None], tf.dtypes.int32) # size: batch_size x 1
# bool attention mask with True in locations of global attention
attention_mask = tf.range(input_ids_shape[1])
attention_mask = tf.range(input_ids_shape[1])[tf.newaxis, :]
attention_mask = tf.tile(attention_mask, (input_ids_shape[0], 1))
if before_sep_token is True:
attention_mask = tf.cast(
tf.broadcast_to(attention_mask, input_ids_shape) < tf.broadcast_to(question_end_index, input_ids_shape),
tf.dtypes.int32,
)
question_end_index = tf.tile(question_end_index, (1, input_ids_shape[1]))
attention_mask = tf.cast(attention_mask < question_end_index, tf.int32)
else:
# last token is separation token and should not be counted and in the middle are two separation tokens
question_end_index = tf.tile(question_end_index + 1, (1, input_ids_shape[1]))
attention_mask = (
tf.cast(
tf.broadcast_to(attention_mask, input_ids_shape)
> tf.broadcast_to(question_end_index + 1, input_ids_shape),
attention_mask > question_end_index,
tf.dtypes.int32,
)
* tf.cast(tf.broadcast_to(attention_mask, input_ids_shape) < input_ids_shape[-1], tf.dtypes.int32)
* tf.cast(attention_mask < input_ids_shape[-1], tf.dtypes.int32)
)
return attention_mask
......@@ -785,13 +784,26 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
),
lambda: attn_scores,
)
attn_probs = tf.nn.softmax(attn_scores, axis=-1)
# softmax sometimes inserts NaN if all positions are masked, replace them with 0
# Make sure to create a mask with the proper shape:
# if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1]
# if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1]
masked_index = tf.cond(
is_global_attn,
lambda: tf.tile(
is_index_masked[:, :, None, None],
(1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1),
),
lambda: tf.tile(
is_index_masked[:, :, None, None],
(1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + 1),
),
)
attn_probs = tf.where(
tf.broadcast_to(is_index_masked[:, :, None, None], shape_list(attn_probs)),
0.0,
masked_index,
tf.zeros(shape_list(masked_index), dtype=tf.dtypes.float32),
attn_probs,
)
......@@ -849,11 +861,23 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
)
# make sure that local attention probabilities are set to 0 for indices of global attn
# When is_global_attn is True, the last dimension is always self.one_sided_attn_window_size * 2 + 1 + 1
# because of the concat Line 713.
# Make sure to create a mask with the proper shape:
# if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1]
# if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1]
masked_global_attn_index = tf.cond(
is_global_attn,
lambda: tf.tile(
is_index_global_attn[:, :, None, None],
(1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1),
),
lambda: tf.tile(
is_index_global_attn[:, :, None, None],
(1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + 1),
),
)
attn_probs = tf.where(
tf.broadcast_to(is_index_global_attn[:, :, None, None], shape_list(attn_probs)),
tf.zeros(shape_list(attn_probs), dtype=tf.dtypes.float32),
masked_global_attn_index,
tf.zeros(shape_list(masked_global_attn_index), dtype=tf.dtypes.float32),
attn_probs,
)
......@@ -937,14 +961,9 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
axis=1,
)
first_chunk_mask = (
tf.broadcast_to(
tf.tile(
tf.range(chunks_count + 1)[None, :, None, None],
shape=(
batch_size * num_heads,
chunks_count + 1,
window_overlap,
window_overlap,
),
(batch_size * num_heads, 1, window_overlap, window_overlap),
)
< 1
)
......@@ -992,7 +1011,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
mask_2d = mask_2d + tf.reverse(mask_2d, axis=[0, 1])
# broadcast to full matrix
mask_4d = tf.broadcast_to(mask_2d[None, :, None, :], shape_list(input_tensor))
mask_4d = tf.tile(mask_2d[None, :, None, :], (shape_list(input_tensor)[0], 1, 1, 1))
# inf tensor used for masking
inf_tensor = -float("inf") * tf.ones_like(input_tensor, dtype=tf.dtypes.float32)
......@@ -1337,7 +1356,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
global_attn_scores = tf.transpose(global_attn_scores_trans, (0, 2, 1, 3))
# mask global attn scores
attn_mask = tf.broadcast_to(is_index_masked[:, None, None, :], shape_list(global_attn_scores))
attn_mask = tf.tile(is_index_masked[:, None, None, :], (1, shape_list(global_attn_scores)[1], 1, 1))
global_attn_scores = tf.where(attn_mask, -10000.0, global_attn_scores)
global_attn_scores = tf.reshape(
global_attn_scores,
......@@ -1735,7 +1754,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
batch_size, seq_len = input_shape[:2]
padding_len = (attention_window - seq_len % attention_window) % attention_window
if tf.math.greater(padding_len, 0):
if padding_len > 0:
logger.info(
"Input ids are automatically padded from {} to {} to be a multiple of `config.attention_window`: {}".format(
seq_len, seq_len + padding_len, attention_window
......
......@@ -78,7 +78,7 @@ class TFLEDModelTester:
# [num_attention_heads, encoder_seq_length, encoder_key_length], but TFLongformerSelfAttention
# returns attention of shape [num_attention_heads, encoder_seq_length, self.attention_window + 1]
# because its local attention only attends to `self.attention_window` and one before and one after
self.key_length = self.attention_window + 1
self.key_length = self.attention_window + 2
# because of padding `encoder_seq_length`, is different from `seq_length`. Relevant for
# the `test_attention_outputs` and `test_hidden_states_output` tests
......@@ -369,15 +369,8 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase):
pass
def test_saved_model_with_attentions_output(self):
# This test don't pass because of the error:
# condition [13,8,4,5], then [13,8,4,5], and else [13,8,4,6] must be broadcastable
# This occurs line 323 in modeling_tf_led.py because the condition line 255
# returns a tensor of shape
# [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 2]
# if is_global_attn is True and a tensor of shape
# [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1]
# This is due to the tf.concat call line 703 that adds one dimension
# Need to check with PVP how to properly fix this
# Temporarily disable this test in order to find
# how to better handle it without timing out the CI
pass
@slow
......
......@@ -339,15 +339,8 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
@slow
def test_saved_model_with_attentions_output(self):
# This test don't pass because of the error:
# condition [13,8,4,5], then [13,8,4,5], and else [13,8,4,6] must be broadcastable
# This occurs line 323 in modeling_tf_led.py because the condition line 255
# returns a tensor of shape
# [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 2]
# if is_global_attn is True and a tensor of shape
# [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1]
# This is due to the tf.concat call line 703 that adds one dimension
# Need to check with PVP how to properly fix this
# Temporarily disable this test in order to find
# how to better handle it without timing out the CI
pass
@slow
......@@ -371,7 +364,7 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
pass
def test_xla_mode(self):
# TODO JP: Make Blenderbot XLA compliant
# TODO JP: Make Longformer XLA compliant
pass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment