Unverified Commit 22a32cf4 authored by Julien Plu, committed by GitHub

Fix TF LED/Longformer attentions computation (#10007)

* Fix test

* Remove commented test

* Fix name

* Apply style

* Fix check copies

* Remove prints

* Restore boolean

* Fix reshape
parent 0d8e554d
@@ -266,13 +266,26 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
             ),
             lambda: attn_scores,
         )
 
         attn_probs = tf.nn.softmax(attn_scores, axis=-1)
 
         # softmax sometimes inserts NaN if all positions are masked, replace them with 0
+        # Make sure to create a mask with the proper shape:
+        # if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1]
+        # if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1]
+        masked_index = tf.cond(
+            is_global_attn,
+            lambda: tf.tile(
+                is_index_masked[:, :, None, None],
+                (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1),
+            ),
+            lambda: tf.tile(
+                is_index_masked[:, :, None, None],
+                (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + 1),
+            ),
+        )
         attn_probs = tf.where(
-            tf.broadcast_to(is_index_masked[:, :, None, None], shape_list(attn_probs)),
-            0.0,
+            masked_index,
+            tf.zeros(shape_list(masked_index), dtype=tf.dtypes.float32),
             attn_probs,
         )
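The change above avoids broadcasting the padding mask against the runtime shape of `attn_probs`, whose last dimension differs depending on whether global attention is active; instead, `tf.cond` picks between two explicitly tiled masks with statically known widths (the shape mismatch is the one described in the test comments removed further down). A minimal sketch of that pattern, using toy sizes and a made-up boolean mask; everything except the `tf.cond`/`tf.tile`/`tf.where` structure is an illustrative assumption:

```python
import tensorflow as tf

# Toy sizes; in the model these come from the layer config.
batch_size, seq_len, num_heads, window = 2, 8, 4, 2
max_num_global_attn_indices = 1
is_global_attn = tf.constant(True)
# toy [batch_size, seq_len] mask marking padded positions
is_index_masked = tf.constant([[False] * 6 + [True] * 2] * batch_size)

masked_index = tf.cond(
    is_global_attn,
    lambda: tf.tile(
        is_index_masked[:, :, None, None],
        (1, 1, num_heads, window * 2 + max_num_global_attn_indices + 1),
    ),
    lambda: tf.tile(
        is_index_masked[:, :, None, None],
        (1, 1, num_heads, window * 2 + 1),
    ),
)
print(masked_index.shape)  # (2, 8, 4, 6) here, since is_global_attn is True

# zero out masked positions with an exactly matching zeros tensor
attn_probs = tf.random.uniform(tf.shape(masked_index))
attn_probs = tf.where(masked_index, tf.zeros(tf.shape(masked_index)), attn_probs)
```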
@@ -330,11 +343,23 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         )
 
         # make sure that local attention probabilities are set to 0 for indices of global attn
-        # When is_global_attn is True, the last dimension is always self.one_sided_attn_window_size * 2 + 1 + 1
-        # because of the concat Line 713.
+        # Make sure to create a mask with the proper shape:
+        # if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1]
+        # if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1]
+        masked_global_attn_index = tf.cond(
+            is_global_attn,
+            lambda: tf.tile(
+                is_index_global_attn[:, :, None, None],
+                (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1),
+            ),
+            lambda: tf.tile(
+                is_index_global_attn[:, :, None, None],
+                (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + 1),
+            ),
+        )
         attn_probs = tf.where(
-            tf.broadcast_to(is_index_global_attn[:, :, None, None], shape_list(attn_probs)),
-            tf.zeros(shape_list(attn_probs), dtype=tf.dtypes.float32),
+            masked_global_attn_index,
+            tf.zeros(shape_list(masked_global_attn_index), dtype=tf.dtypes.float32),
             attn_probs,
         )
@@ -418,14 +443,9 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
             axis=1,
         )
         first_chunk_mask = (
-            tf.broadcast_to(
+            tf.tile(
                 tf.range(chunks_count + 1)[None, :, None, None],
-                shape=(
-                    batch_size * num_heads,
-                    chunks_count + 1,
-                    window_overlap,
-                    window_overlap,
-                ),
+                (batch_size * num_heads, 1, window_overlap, window_overlap),
             )
             < 1
         )
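Similarly, `first_chunk_mask` is now built with `tf.tile` multiples that leave the chunk axis untouched, and the `< 1` comparison keeps only chunk 0. A self-contained sketch with assumed toy sizes:

```python
import tensorflow as tf

# Toy stand-ins for batch_size * num_heads, chunks_count and window_overlap.
batch_x_heads, chunks_count, window_overlap = 6, 3, 2

first_chunk_mask = (
    tf.tile(
        tf.range(chunks_count + 1)[None, :, None, None],
        (batch_x_heads, 1, window_overlap, window_overlap),
    )
    < 1
)
print(first_chunk_mask.shape)  # (6, 4, 2, 2); True only where the chunk index is 0
```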
@@ -473,7 +493,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         mask_2d = mask_2d + tf.reverse(mask_2d, axis=[0, 1])
 
         # broadcast to full matrix
-        mask_4d = tf.broadcast_to(mask_2d[None, :, None, :], shape_list(input_tensor))
+        mask_4d = tf.tile(mask_2d[None, :, None, :], (shape_list(input_tensor)[0], 1, 1, 1))
 
         # inf tensor used for masking
         inf_tensor = -float("inf") * tf.ones_like(input_tensor, dtype=tf.dtypes.float32)
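Here the 2D mask is tiled only along the batch dimension, leaving a singleton axis for `tf.where` to broadcast, instead of broadcasting the mask to the full, partly dynamic shape of `input_tensor` up front. A sketch of the idea; all shapes below are assumptions, not the model's real ones:

```python
import tensorflow as tf

# Toy stand-ins for the chunked score tensor and its 2D invalid-locations mask.
batch_size, rows, chunks, cols = 3, 4, 5, 5
mask_2d = tf.linalg.band_part(tf.ones((rows, cols)), 0, -1)  # upper-triangular 0/1 mask

# tile only the batch axis; axis 2 stays 1 and is broadcast by tf.where
mask_4d = tf.tile(mask_2d[None, :, None, :], (batch_size, 1, 1, 1))
print(mask_4d.shape)  # (3, 4, 1, 5)

input_tensor = tf.random.normal((batch_size, rows, chunks, cols))
inf_tensor = -float("inf") * tf.ones_like(input_tensor)
masked = tf.where(tf.cast(mask_4d, tf.bool), inf_tensor, input_tensor)
```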
@@ -818,7 +838,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         global_attn_scores = tf.transpose(global_attn_scores_trans, (0, 2, 1, 3))
 
         # mask global attn scores
-        attn_mask = tf.broadcast_to(is_index_masked[:, None, None, :], shape_list(global_attn_scores))
+        attn_mask = tf.tile(is_index_masked[:, None, None, :], (1, shape_list(global_attn_scores)[1], 1, 1))
         global_attn_scores = tf.where(attn_mask, -10000.0, global_attn_scores)
         global_attn_scores = tf.reshape(
             global_attn_scores,
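The global-score mask follows the same recipe: tile the `[batch, seq_len]` padding mask across one dynamically sized axis and let `tf.where` broadcast the remaining singleton dimension. A toy sketch; the shapes and mask values are illustrative, not the model's real ones:

```python
import tensorflow as tf

# Toy shapes: [batch, num_heads, num_global_tokens, seq_len] scores and a
# [batch, seq_len] padding mask.
batch_size, num_heads, num_global, seq_len = 2, 4, 3, 6
scores = tf.random.normal((batch_size, num_heads, num_global, seq_len))
is_index_masked = tf.logical_not(tf.sequence_mask([4, 5], seq_len))  # True on padding

# tile across the head axis, read dynamically with tf.shape; axis 2 stays 1
attn_mask = tf.tile(is_index_masked[:, None, None, :], (1, tf.shape(scores)[1], 1, 1))
print(attn_mask.shape)  # (2, 4, 1, 6)

masked_scores = tf.where(attn_mask, -10000.0, scores)  # axis 2 broadcasts to num_global
```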
@@ -1761,7 +1781,7 @@ class TFLEDEncoder(tf.keras.layers.Layer):
         batch_size, seq_len = input_shape[:2]
 
         padding_len = (attention_window - seq_len % attention_window) % attention_window
-        if tf.math.greater(padding_len, 0):
+        if padding_len > 0:
             logger.info(
                 "Input ids are automatically padded from {} to {} to be a multiple of `config.attention_window`: {}".format(
                     seq_len, seq_len + padding_len, attention_window
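The padding arithmetic above rounds the sequence length up to the next multiple of `config.attention_window`; assuming `seq_len` is known statically, `padding_len` is a plain Python integer, which is why the `tf.math.greater` call is replaced by an ordinary comparison. A small sketch of the arithmetic (the helper name is made up for illustration):

```python
def compute_padding_len(seq_len: int, attention_window: int) -> int:
    # pad up to the next multiple of attention_window; 0 if already aligned
    return (attention_window - seq_len % attention_window) % attention_window

print(compute_padding_len(10, 4))  # 2 -> padded length 12
print(compute_padding_len(12, 4))  # 0 -> already a multiple, no padding
```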
...
@@ -395,21 +395,20 @@ def _compute_global_attention_mask(input_ids_shape, sep_token_indices, before_sep_token=True):
     question_end_index = tf.reshape(sep_token_indices, (input_ids_shape[0], 3, 2))[:, 0, 1]
     question_end_index = tf.cast(question_end_index[:, None], tf.dtypes.int32)  # size: batch_size x 1
     # bool attention mask with True in locations of global attention
-    attention_mask = tf.range(input_ids_shape[1])
+    attention_mask = tf.range(input_ids_shape[1])[tf.newaxis, :]
+    attention_mask = tf.tile(attention_mask, (input_ids_shape[0], 1))
     if before_sep_token is True:
-        attention_mask = tf.cast(
-            tf.broadcast_to(attention_mask, input_ids_shape) < tf.broadcast_to(question_end_index, input_ids_shape),
-            tf.dtypes.int32,
-        )
+        question_end_index = tf.tile(question_end_index, (1, input_ids_shape[1]))
+        attention_mask = tf.cast(attention_mask < question_end_index, tf.int32)
     else:
         # last token is separation token and should not be counted and in the middle are two separation tokens
+        question_end_index = tf.tile(question_end_index + 1, (1, input_ids_shape[1]))
         attention_mask = (
             tf.cast(
-                tf.broadcast_to(attention_mask, input_ids_shape)
-                > tf.broadcast_to(question_end_index + 1, input_ids_shape),
+                attention_mask > question_end_index,
                 tf.dtypes.int32,
             )
-            * tf.cast(tf.broadcast_to(attention_mask, input_ids_shape) < input_ids_shape[-1], tf.dtypes.int32)
+            * tf.cast(attention_mask < input_ids_shape[-1], tf.dtypes.int32)
         )
 
     return attention_mask
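The rewritten `_compute_global_attention_mask` tiles both the position indices and the per-example boundary to `[batch_size, seq_len]` and compares them elementwise, instead of broadcasting everything to `input_ids_shape`. A toy run of that construction; the batch size, sequence length, and boundary values below are made up:

```python
import tensorflow as tf

# Toy batch of 2 sequences of length 6; boundary index per example is made up.
batch_size, seq_len = 2, 6
question_end_index = tf.constant([[2], [4]], dtype=tf.int32)  # [batch_size, 1]

positions = tf.tile(tf.range(seq_len)[tf.newaxis, :], (batch_size, 1))  # [batch, seq_len]
boundary = tf.tile(question_end_index, (1, seq_len))                    # [batch, seq_len]

attention_mask = tf.cast(positions < boundary, tf.int32)
print(attention_mask.numpy())
# [[1 1 0 0 0 0]
#  [1 1 1 1 0 0]]
```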
@@ -785,13 +784,26 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
             ),
             lambda: attn_scores,
         )
 
         attn_probs = tf.nn.softmax(attn_scores, axis=-1)
 
         # softmax sometimes inserts NaN if all positions are masked, replace them with 0
+        # Make sure to create a mask with the proper shape:
+        # if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1]
+        # if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1]
+        masked_index = tf.cond(
+            is_global_attn,
+            lambda: tf.tile(
+                is_index_masked[:, :, None, None],
+                (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1),
+            ),
+            lambda: tf.tile(
+                is_index_masked[:, :, None, None],
+                (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + 1),
+            ),
+        )
         attn_probs = tf.where(
-            tf.broadcast_to(is_index_masked[:, :, None, None], shape_list(attn_probs)),
-            0.0,
+            masked_index,
+            tf.zeros(shape_list(masked_index), dtype=tf.dtypes.float32),
             attn_probs,
         )
@@ -849,11 +861,23 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         )
 
         # make sure that local attention probabilities are set to 0 for indices of global attn
-        # When is_global_attn is True, the last dimension is always self.one_sided_attn_window_size * 2 + 1 + 1
-        # because of the concat Line 713.
+        # Make sure to create a mask with the proper shape:
+        # if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1]
+        # if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1]
+        masked_global_attn_index = tf.cond(
+            is_global_attn,
+            lambda: tf.tile(
+                is_index_global_attn[:, :, None, None],
+                (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1),
+            ),
+            lambda: tf.tile(
+                is_index_global_attn[:, :, None, None],
+                (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + 1),
+            ),
+        )
         attn_probs = tf.where(
-            tf.broadcast_to(is_index_global_attn[:, :, None, None], shape_list(attn_probs)),
-            tf.zeros(shape_list(attn_probs), dtype=tf.dtypes.float32),
+            masked_global_attn_index,
+            tf.zeros(shape_list(masked_global_attn_index), dtype=tf.dtypes.float32),
             attn_probs,
         )
@@ -937,14 +961,9 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
             axis=1,
         )
         first_chunk_mask = (
-            tf.broadcast_to(
+            tf.tile(
                 tf.range(chunks_count + 1)[None, :, None, None],
-                shape=(
-                    batch_size * num_heads,
-                    chunks_count + 1,
-                    window_overlap,
-                    window_overlap,
-                ),
+                (batch_size * num_heads, 1, window_overlap, window_overlap),
             )
             < 1
         )
@@ -992,7 +1011,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         mask_2d = mask_2d + tf.reverse(mask_2d, axis=[0, 1])
 
         # broadcast to full matrix
-        mask_4d = tf.broadcast_to(mask_2d[None, :, None, :], shape_list(input_tensor))
+        mask_4d = tf.tile(mask_2d[None, :, None, :], (shape_list(input_tensor)[0], 1, 1, 1))
 
         # inf tensor used for masking
         inf_tensor = -float("inf") * tf.ones_like(input_tensor, dtype=tf.dtypes.float32)
@@ -1337,7 +1356,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         global_attn_scores = tf.transpose(global_attn_scores_trans, (0, 2, 1, 3))
 
         # mask global attn scores
-        attn_mask = tf.broadcast_to(is_index_masked[:, None, None, :], shape_list(global_attn_scores))
+        attn_mask = tf.tile(is_index_masked[:, None, None, :], (1, shape_list(global_attn_scores)[1], 1, 1))
         global_attn_scores = tf.where(attn_mask, -10000.0, global_attn_scores)
         global_attn_scores = tf.reshape(
             global_attn_scores,
@@ -1735,7 +1754,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
         batch_size, seq_len = input_shape[:2]
 
         padding_len = (attention_window - seq_len % attention_window) % attention_window
-        if tf.math.greater(padding_len, 0):
+        if padding_len > 0:
             logger.info(
                 "Input ids are automatically padded from {} to {} to be a multiple of `config.attention_window`: {}".format(
                     seq_len, seq_len + padding_len, attention_window
...
@@ -78,7 +78,7 @@ class TFLEDModelTester:
         # [num_attention_heads, encoder_seq_length, encoder_key_length], but TFLongformerSelfAttention
         # returns attention of shape [num_attention_heads, encoder_seq_length, self.attention_window + 1]
         # because its local attention only attends to `self.attention_window` and one before and one after
-        self.key_length = self.attention_window + 1
+        self.key_length = self.attention_window + 2
 
         # because of padding `encoder_seq_length`, is different from `seq_length`. Relevant for
         # the `test_attention_outputs` and `test_hidden_states_output` tests
@@ -369,15 +369,8 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase):
         pass
 
     def test_saved_model_with_attentions_output(self):
-        # This test don't pass because of the error:
-        # condition [13,8,4,5], then [13,8,4,5], and else [13,8,4,6] must be broadcastable
-        # This occurs line 323 in modeling_tf_led.py because the condition line 255
-        # returns a tensor of shape
-        # [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 2]
-        # if is_global_attn is True and a tensor of shape
-        # [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1]
-        # This is due to the tf.concat call line 703 that adds one dimension
-        # Need to check with PVP how to properly fix this
+        # Temporarily disable this test in order to find
+        # how to better handle it without timing out the CI
         pass
 
     @slow
...
@@ -339,15 +339,8 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
 
     @slow
     def test_saved_model_with_attentions_output(self):
-        # This test don't pass because of the error:
-        # condition [13,8,4,5], then [13,8,4,5], and else [13,8,4,6] must be broadcastable
-        # This occurs line 323 in modeling_tf_led.py because the condition line 255
-        # returns a tensor of shape
-        # [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 2]
-        # if is_global_attn is True and a tensor of shape
-        # [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1]
-        # This is due to the tf.concat call line 703 that adds one dimension
-        # Need to check with PVP how to properly fix this
+        # Temporarily disable this test in order to find
+        # how to better handle it without timing out the CI
        pass
 
     @slow
@@ -371,7 +364,7 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
         pass
 
     def test_xla_mode(self):
-        # TODO JP: Make Blenderbot XLA compliant
+        # TODO JP: Make Longformer XLA compliant
         pass
...