Unverified commit 19e737b9 authored by Julien Plu, committed by GitHub

Making TF Longformer-like models compliant with AMP (#10233)

* AMP

* Add LED

* Apply style

* Fix longformer
parent cd8c4c3f
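The gist of the change: TF LED and Longformer no longer hard-code float32 (and int32) dtypes inside their graph ops, and shape assertions only run eagerly, so the models can be used under Keras automatic mixed precision. A minimal sketch of what that enables, assuming TF 2.4+ and network access; the checkpoint name is illustrative:

import tensorflow as tf
from transformers import AutoTokenizer, TFLongformerModel

# Compute in float16, keep variables in float32 (Keras AMP).
tf.keras.mixed_precision.set_global_policy("mixed_float16")

tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")
model = TFLongformerModel.from_pretrained("allenai/longformer-base-4096")

inputs = tokenizer("Mixed precision smoke test.", return_tensors="tf")
outputs = model(**inputs)

# Under the mixed_float16 policy the activations should come out in float16.
print(outputs.last_hidden_state.dtype)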
@@ -55,8 +55,7 @@ LARGE_NEGATIVE = -1e8
 def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
-    shifted_input_ids = tf.cast(input_ids, tf.int32)
-    shifted_input_ids = tf.roll(shifted_input_ids, 1, axis=-1)
+    shifted_input_ids = tf.roll(input_ids, 1, axis=-1)
     start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id)
     shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1)
     # replace possible -100 values in labels by `pad_token_id`
@@ -65,7 +64,8 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
     )
     # "Verify that `labels` has only positive values and -100"
-    assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.cast(0, tf.int32))
+    if tf.executing_eagerly():
+        assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))
     # Make sure the assertion op is called by wrapping the result in an identity no-op
     with tf.control_dependencies([assert_gte0]):
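The recurring pattern in the hunks below: tf.debugging checks are wrapped in `if tf.executing_eagerly():` so they still fire while debugging eagerly but are skipped when the model is traced into a graph (for example inside tf.function, Keras fit, or SavedModel export). A standalone sketch of the pattern, with illustrative names:

import tensorflow as tf

def checked_add(a: tf.Tensor, b: tf.Tensor) -> tf.Tensor:
    if tf.executing_eagerly():
        # Evaluated in eager mode only; during tf.function tracing this branch is skipped.
        tf.debugging.assert_equal(tf.shape(a), tf.shape(b), message="a and b must have the same shape")
    return a + b

checked_add(tf.ones((2, 3)), tf.ones((2, 3)))               # assertion evaluated
tf.function(checked_add)(tf.ones((2, 3)), tf.ones((2, 3)))  # assertion skipped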
@@ -79,14 +79,13 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
     Make causal mask used for bi-directional self-attention.
     """
     bsz, tgt_len = input_ids_shape
-    mask = tf.ones((tgt_len, tgt_len), dtype=tf.float32) * LARGE_NEGATIVE
+    mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
     mask_cond = tf.range(shape_list(mask)[-1])
     mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask)
-    mask = tf.cast(mask, tf.float32)
     if past_key_values_length > 0:
-        mask = tf.concat([tf.zeros((tgt_len, past_key_values_length), dtype=tf.float32), mask], axis=-1)
+        mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1)
     return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))
@@ -97,9 +96,11 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values
     """
     src_len = shape_list(mask)[1]
     tgt_len = tgt_len if tgt_len is not None else src_len
-    expanded_mask = tf.cast(tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)), tf.float32)
-    return (1.0 - expanded_mask) * LARGE_NEGATIVE
+    one_cst = tf.constant(1.0)
+    mask = tf.cast(mask, dtype=one_cst.dtype)
+    expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1))
+    return (one_cst - expanded_mask) * LARGE_NEGATIVE
 class TFLEDLearnedPositionalEmbedding(TFSharedEmbeddings):
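The other recurring change, visible in _make_causal_mask and _expand_mask above: masks and constants are no longer pinned to float32; they are built dtype-free and cast to the dtype of the tensor they are combined with, so float16 activations are not silently promoted back to float32 under AMP. A small hypothetical sketch of the idea (function and variable names are illustrative; the -10000.0 filler mirrors the diff):

import tensorflow as tf

def add_padding_mask(scores: tf.Tensor, padding_mask: tf.Tensor) -> tf.Tensor:
    # padding_mask: (batch, seq_len), 1 for real tokens, 0 for padding.
    mask = tf.cast(padding_mask[:, None, None, :], dtype=scores.dtype)
    # -10000.0 stays representable in float16, unlike much larger magic constants.
    return scores + (1.0 - mask) * tf.cast(-10000.0, dtype=scores.dtype)

scores = tf.random.normal((2, 4, 5, 5), dtype=tf.float16)
padding = tf.constant([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
print(add_padding_mask(scores, padding).dtype)  # float16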
@@ -115,9 +116,7 @@ class TFLEDLearnedPositionalEmbedding(TFSharedEmbeddings):
         """Input is expected to be of size [bsz x seqlen]."""
         bsz, seq_len = input_shape[:2]
-        positions = tf.range(
-            past_key_values_length, seq_len + past_key_values_length, delta=1, dtype=tf.int32, name="range"
-        )
+        positions = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range")
         return super().call(positions)
@@ -212,6 +211,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         value_vectors = self.value(hidden_states)
         batch_size, seq_len, embed_dim = shape_list(hidden_states)
+        if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 embed_dim,
                 self.embed_dim,
@@ -219,7 +219,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
             )
         # normalize query
-        query_vectors /= tf.math.sqrt(tf.convert_to_tensor(self.head_dim, dtype=tf.dtypes.float32))
+        query_vectors /= tf.math.sqrt(tf.cast(self.head_dim, dtype=query_vectors.dtype))
         query_vectors = tf.reshape(query_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))
         key_vectors = tf.reshape(key_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))
@@ -230,7 +230,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         # diagonal mask with zeros everywhere and -inf inplace of padding
         diagonal_mask = self._sliding_chunks_query_key_matmul(
-            tf.ones(shape_list(attention_mask), dtype=tf.float32),
+            tf.ones(shape_list(attention_mask)),
             attention_mask,
             self.one_sided_attn_window_size,
         )
@@ -238,6 +238,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         # pad local attention probs
         attn_scores += diagonal_mask
+        if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(attn_scores),
                 [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1],
@@ -285,16 +286,18 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
             )
             attn_probs = tf.where(
                 masked_index,
-                tf.zeros(shape_list(masked_index), dtype=tf.dtypes.float32),
+                tf.zeros(shape_list(masked_index), dtype=attn_probs.dtype),
                 attn_probs,
             )
         if layer_head_mask is not None:
+            if tf.executing_eagerly():
                 tf.debugging.assert_equal(
                     shape_list(layer_head_mask),
                     [self.num_heads],
                     message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
                 )
             attn_probs = tf.reshape(layer_head_mask, (1, 1, -1, 1)) * attn_probs
         # apply dropout
@@ -316,6 +319,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
             ),
         )
+        if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(attn_output),
                 [batch_size, seq_len, self.num_heads, self.head_dim],
@@ -359,7 +363,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
             )
             attn_probs = tf.where(
                 masked_global_attn_index,
-                tf.zeros(shape_list(masked_global_attn_index), dtype=tf.dtypes.float32),
+                tf.zeros(shape_list(masked_global_attn_index), dtype=attn_probs.dtype),
                 attn_probs,
             )
@@ -375,6 +379,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         """
         batch_size, seq_len, num_heads, head_dim = shape_list(query)
+        if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 seq_len % (window_overlap * 2),
                 0,
@@ -401,10 +406,11 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         # bcxd: batch_size * num_heads x chunks x 2window_overlap x head_dim
         # bcyd: batch_size * num_heads x chunks x 2window_overlap x head_dim
         # bcxy: batch_size * num_heads x chunks x 2window_overlap x 2window_overlap
+        chunked_query = tf.cast(chunked_query, dtype=chunked_key.dtype)
         chunked_attention_scores = tf.einsum("bcxd,bcyd->bcxy", chunked_query, chunked_key)  # multiply
         # convert diagonals into columns
-        paddings = tf.convert_to_tensor([[0, 0], [0, 0], [0, 1], [0, 0]], dtype=tf.dtypes.int32)
+        paddings = tf.convert_to_tensor([[0, 0], [0, 0], [0, 1], [0, 0]])
         diagonal_chunked_attention_scores = self._pad_and_transpose_last_two_dims(chunked_attention_scores, paddings)
         # allocate space for the overall attention matrix where the chunks are combined. The last dimension
@@ -426,7 +432,10 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         # - copying the lower triangle
         diagonal_attn_scores_low_triang = tf.concat(
             [
-                tf.zeros((batch_size * num_heads, 1, window_overlap, window_overlap)),
+                tf.zeros(
+                    (batch_size * num_heads, 1, window_overlap, window_overlap),
+                    dtype=diagonal_chunked_attention_scores.dtype,
+                ),
                 diagonal_chunked_attention_scores[:, :, -(window_overlap + 1) : -1, window_overlap + 1 :],
             ],
             axis=1,
@@ -438,7 +447,10 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
                     shift=[1, window_overlap],
                     axis=[2, 3],
                 )[:, :, :window_overlap, :window_overlap],
-                tf.zeros((batch_size * num_heads, 1, window_overlap, window_overlap)),
+                tf.zeros(
+                    (batch_size * num_heads, 1, window_overlap, window_overlap),
+                    dtype=diagonal_chunked_attention_scores.dtype,
+                ),
             ],
             axis=1,
         )
@@ -496,7 +508,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         mask_4d = tf.tile(mask_2d[None, :, None, :], (shape_list(input_tensor)[0], 1, 1, 1))
         # inf tensor used for masking
-        inf_tensor = -float("inf") * tf.ones_like(input_tensor, dtype=tf.dtypes.float32)
+        inf_tensor = -float("inf") * tf.ones_like(input_tensor)
         # mask
         input_tensor = tf.where(tf.math.greater(mask_4d, 0), inf_tensor, input_tensor)
@@ -511,6 +523,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         batch_size, seq_len, num_heads, head_dim = shape_list(value)
+        if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 seq_len % (window_overlap * 2),
                 0,
@@ -547,7 +560,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         )
         # pad seq_len with w at the beginning of the sequence and another window overlap at the end
-        paddings = tf.convert_to_tensor([[0, 0], [window_overlap, window_overlap], [0, 0]], dtype=tf.dtypes.int32)
+        paddings = tf.convert_to_tensor([[0, 0], [window_overlap, window_overlap], [0, 0]])
         padded_value = tf.pad(value, paddings, constant_values=-1)
         # chunk padded_value into chunks of size 3 window overlap and an overlap of size window overlap
@@ -563,6 +576,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
             (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim),
         )
+        if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(chunked_value),
                 [batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim],
@@ -640,6 +654,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         # chunk with overlap
         chunked_hidden_states = tf.signal.frame(hidden_states, frame_size, frame_hop_size)
+        if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(chunked_hidden_states),
                 [batch_size, num_output_chunks, frame_size],
@@ -657,7 +672,8 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
     def _get_global_attn_indices(is_index_global_attn):
         """ compute global attn indices required throughout forward pass """
         # helper variable
-        num_global_attn_indices = tf.reduce_sum(tf.cast(is_index_global_attn, dtype=tf.dtypes.int32), axis=1)
+        num_global_attn_indices = tf.math.count_nonzero(is_index_global_attn, axis=1)
+        num_global_attn_indices = tf.cast(num_global_attn_indices, dtype=tf.constant(1).dtype)
         # max number of global attn indices in batch
         max_num_global_attn_indices = tf.reduce_max(num_global_attn_indices)
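A note on the count_nonzero change just above: tf.math.count_nonzero returns int64 by default, while the index arithmetic downstream expects the default integer dtype (tf.constant(1).dtype, i.e. int32), hence the extra cast. Illustrative check:

import tensorflow as tf

is_index_global_attn = tf.constant([[True, False, True], [False, False, True]])
num_global_attn_indices = tf.math.count_nonzero(is_index_global_attn, axis=1)
print(num_global_attn_indices.dtype)  # int64
num_global_attn_indices = tf.cast(num_global_attn_indices, dtype=tf.constant(1).dtype)
print(num_global_attn_indices.dtype, num_global_attn_indices.numpy())  # int32 [2 1]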
@@ -719,6 +735,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
            shape_list(attn_probs_from_global_key_trans)[-2:]
        )
        mask = tf.ones(mask_shape) * -10000.0
+       mask = tf.cast(mask, dtype=attn_probs_from_global_key_trans.dtype)
        # scatter mask
        attn_probs_from_global_key_trans = tf.tensor_scatter_nd_update(
@@ -805,7 +822,9 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         global_value_vectors = self.value_global(hidden_states)
         # normalize
-        global_query_vectors_only_global /= tf.math.sqrt(tf.convert_to_tensor(self.head_dim, dtype=tf.dtypes.float32))
+        global_query_vectors_only_global /= tf.math.sqrt(
+            tf.cast(self.head_dim, dtype=global_query_vectors_only_global.dtype)
+        )
         global_query_vectors_only_global = self.reshape_and_transpose(global_query_vectors_only_global, batch_size)
         global_key_vectors = self.reshape_and_transpose(global_key_vectors, batch_size)
         global_value_vectors = self.reshape_and_transpose(global_value_vectors, batch_size)
@@ -813,6 +832,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         # compute attn scores
         global_attn_scores = tf.matmul(global_query_vectors_only_global, global_key_vectors, transpose_b=True)
+        if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(global_attn_scores),
                 [batch_size * self.num_heads, max_num_global_attn_indices, seq_len],
@@ -828,6 +848,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
            shape_list(global_attn_scores_trans)[-2:]
        )
        global_attn_mask = tf.ones(mask_shape) * -10000.0
+       global_attn_mask = tf.cast(global_attn_mask, dtype=global_attn_scores_trans.dtype)
        # scatter mask
        global_attn_scores_trans = tf.tensor_scatter_nd_update(
@@ -850,6 +871,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         # apply layer head maskin
         if layer_head_mask is not None:
+            if tf.executing_eagerly():
                 tf.debugging.assert_equal(
                     shape_list(layer_head_mask),
                     [self.num_heads],
@@ -868,6 +890,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         # global attn output
         global_attn_output = tf.matmul(global_attn_probs, global_value_vectors)
+        if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(global_attn_output),
                 [batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim],
@@ -1023,6 +1046,7 @@ class TFLEDDecoderAttention(tf.keras.layers.Layer):
         src_len = shape_list(key_states)[1]
         attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
+        if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(attn_weights),
                 [bsz * self.num_heads, tgt_len, src_len],
@@ -1030,22 +1054,28 @@ class TFLEDDecoderAttention(tf.keras.layers.Layer):
             )
         if attention_mask is not None:
+            if tf.executing_eagerly():
                 tf.debugging.assert_equal(
                     shape_list(attention_mask),
                     [bsz, 1, tgt_len, src_len],
                     message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
                 )
-            attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
+            attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + tf.cast(
+                attention_mask, dtype=attn_weights.dtype
+            )
             attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
         attn_weights = tf.nn.softmax(attn_weights, axis=-1)
         if layer_head_mask is not None:
+            if tf.executing_eagerly():
                 tf.debugging.assert_equal(
                     shape_list(layer_head_mask),
                     [self.num_heads],
                     message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
                 )
             attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
                 attn_weights, (bsz, self.num_heads, tgt_len, src_len)
             )
@@ -1055,6 +1085,7 @@ class TFLEDDecoderAttention(tf.keras.layers.Layer):
         attn_output = tf.matmul(attn_probs, value_states)
+        if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(attn_output),
                 [bsz * self.num_heads, tgt_len, self.head_dim],
@@ -1111,6 +1142,7 @@ class TFLEDEncoderLayer(tf.keras.layers.Layer):
         hidden_states = layer_outputs[0]
+        if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(hidden_states),
                 shape_list(residual),
@@ -1707,12 +1739,13 @@ class TFLEDEncoder(tf.keras.layers.Layer):
         all_attentions = all_global_attentions = () if inputs["output_attentions"] else None
         # check if head_mask has a correct number of layers specified if desired
-        if inputs["head_mask"] is not None:
+        if inputs["head_mask"] is not None and tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(inputs["head_mask"])[0],
                 len(self.layers),
                 message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
             )
         # encoder layers
         for idx, encoder_layer in enumerate(self.layers):
@@ -1981,12 +2014,13 @@ class TFLEDDecoder(tf.keras.layers.Layer):
         present_key_values = ()
         # check if head_mask has a correct number of layers specified if desired
-        if inputs["head_mask"] is not None:
+        if inputs["head_mask"] is not None and tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(inputs["head_mask"])[0],
                 len(self.layers),
                 message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
             )
         for idx, decoder_layer in enumerate(self.layers):
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
             if inputs["output_hidden_states"]:
......
@@ -392,23 +392,22 @@ def _compute_global_attention_mask(input_ids_shape, sep_token_indices, before_se
     """
     assert shape_list(sep_token_indices)[1] == 2, "`input_ids` should have two dimensions"
-    question_end_index = tf.reshape(sep_token_indices, (input_ids_shape[0], 3, 2))[:, 0, 1]
-    question_end_index = tf.cast(question_end_index[:, None], tf.dtypes.int32)  # size: batch_size x 1
+    question_end_index = tf.reshape(sep_token_indices, (input_ids_shape[0], 3, 2))[:, 0, 1][:, None]
     # bool attention mask with True in locations of global attention
-    attention_mask = tf.range(input_ids_shape[1])[tf.newaxis, :]
+    attention_mask = tf.expand_dims(tf.range(input_ids_shape[1]), axis=0)
     attention_mask = tf.tile(attention_mask, (input_ids_shape[0], 1))
     if before_sep_token is True:
         question_end_index = tf.tile(question_end_index, (1, input_ids_shape[1]))
-        attention_mask = tf.cast(attention_mask < question_end_index, tf.int32)
+        attention_mask = tf.cast(attention_mask < question_end_index, dtype=question_end_index.dtype)
     else:
         # last token is separation token and should not be counted and in the middle are two separation tokens
         question_end_index = tf.tile(question_end_index + 1, (1, input_ids_shape[1]))
         attention_mask = (
             tf.cast(
                 attention_mask > question_end_index,
-                tf.dtypes.int32,
+                dtype=question_end_index.dtype,
             )
-            * tf.cast(attention_mask < input_ids_shape[-1], tf.dtypes.int32)
+            * tf.cast(attention_mask < input_ids_shape[-1], dtype=question_end_index.dtype)
         )
     return attention_mask
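For context, _compute_global_attention_mask marks the question tokens of a QA-style input (everything before the first separator) for global attention; the change keeps the whole computation in the dtype of sep_token_indices instead of forcing int32. The same result reproduced with plain TF ops on a toy batch (token ids are made up; sep_token_id=2 as in RoBERTa-style vocabularies):

import tensorflow as tf

input_ids = tf.constant([[0, 11, 12, 2, 2, 21, 22, 2]])  # <s> q q </s> </s> c c </s>
sep_token_indices = tf.cast(tf.where(input_ids == 2), dtype=input_ids.dtype)

batch_size, seq_len = input_ids.shape
question_end_index = tf.reshape(sep_token_indices, (batch_size, 3, 2))[:, 0, 1][:, None]
positions = tf.tile(tf.expand_dims(tf.range(seq_len), axis=0), (batch_size, 1))
global_attention_mask = tf.cast(positions < question_end_index, dtype=question_end_index.dtype)
print(global_attention_mask.numpy())  # [[1 1 1 0 0 0 0 0]]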
@@ -730,6 +729,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         value_vectors = self.value(hidden_states)
         batch_size, seq_len, embed_dim = shape_list(hidden_states)
+        if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 embed_dim,
                 self.embed_dim,
@@ -737,7 +737,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
             )
         # normalize query
-        query_vectors /= tf.math.sqrt(tf.convert_to_tensor(self.head_dim, dtype=tf.dtypes.float32))
+        query_vectors /= tf.math.sqrt(tf.cast(self.head_dim, dtype=query_vectors.dtype))
         query_vectors = tf.reshape(query_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))
         key_vectors = tf.reshape(key_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))
@@ -748,7 +748,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         # diagonal mask with zeros everywhere and -inf inplace of padding
         diagonal_mask = self._sliding_chunks_query_key_matmul(
-            tf.ones(shape_list(attention_mask), dtype=tf.float32),
+            tf.ones(shape_list(attention_mask)),
             attention_mask,
             self.one_sided_attn_window_size,
         )
@@ -756,6 +756,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         # pad local attention probs
         attn_scores += diagonal_mask
+        if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(attn_scores),
                 [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1],
@@ -803,16 +804,18 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
             )
             attn_probs = tf.where(
                 masked_index,
-                tf.zeros(shape_list(masked_index), dtype=tf.dtypes.float32),
+                tf.zeros(shape_list(masked_index), dtype=attn_probs.dtype),
                 attn_probs,
             )
         if layer_head_mask is not None:
+            if tf.executing_eagerly():
                 tf.debugging.assert_equal(
                     shape_list(layer_head_mask),
                     [self.num_heads],
                     message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
                 )
             attn_probs = tf.reshape(layer_head_mask, (1, 1, -1, 1)) * attn_probs
         # apply dropout
@@ -834,6 +837,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
             ),
         )
+        if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(attn_output),
                 [batch_size, seq_len, self.num_heads, self.head_dim],
@@ -877,7 +881,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
             )
             attn_probs = tf.where(
                 masked_global_attn_index,
-                tf.zeros(shape_list(masked_global_attn_index), dtype=tf.dtypes.float32),
+                tf.zeros(shape_list(masked_global_attn_index), dtype=attn_probs.dtype),
                 attn_probs,
             )
@@ -893,6 +897,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         """
         batch_size, seq_len, num_heads, head_dim = shape_list(query)
+        if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 seq_len % (window_overlap * 2),
                 0,
@@ -919,10 +924,11 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         # bcxd: batch_size * num_heads x chunks x 2window_overlap x head_dim
         # bcyd: batch_size * num_heads x chunks x 2window_overlap x head_dim
         # bcxy: batch_size * num_heads x chunks x 2window_overlap x 2window_overlap
+        chunked_query = tf.cast(chunked_query, dtype=chunked_key.dtype)
         chunked_attention_scores = tf.einsum("bcxd,bcyd->bcxy", chunked_query, chunked_key)  # multiply
         # convert diagonals into columns
-        paddings = tf.convert_to_tensor([[0, 0], [0, 0], [0, 1], [0, 0]], dtype=tf.dtypes.int32)
+        paddings = tf.convert_to_tensor([[0, 0], [0, 0], [0, 1], [0, 0]])
         diagonal_chunked_attention_scores = self._pad_and_transpose_last_two_dims(chunked_attention_scores, paddings)
         # allocate space for the overall attention matrix where the chunks are combined. The last dimension
@@ -944,7 +950,10 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         # - copying the lower triangle
         diagonal_attn_scores_low_triang = tf.concat(
             [
-                tf.zeros((batch_size * num_heads, 1, window_overlap, window_overlap)),
+                tf.zeros(
+                    (batch_size * num_heads, 1, window_overlap, window_overlap),
+                    dtype=diagonal_chunked_attention_scores.dtype,
+                ),
                 diagonal_chunked_attention_scores[:, :, -(window_overlap + 1) : -1, window_overlap + 1 :],
             ],
             axis=1,
@@ -956,7 +965,10 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
                     shift=[1, window_overlap],
                     axis=[2, 3],
                 )[:, :, :window_overlap, :window_overlap],
-                tf.zeros((batch_size * num_heads, 1, window_overlap, window_overlap)),
+                tf.zeros(
+                    (batch_size * num_heads, 1, window_overlap, window_overlap),
+                    dtype=diagonal_chunked_attention_scores.dtype,
+                ),
             ],
             axis=1,
         )
@@ -1014,7 +1026,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         mask_4d = tf.tile(mask_2d[None, :, None, :], (shape_list(input_tensor)[0], 1, 1, 1))
         # inf tensor used for masking
-        inf_tensor = -float("inf") * tf.ones_like(input_tensor, dtype=tf.dtypes.float32)
+        inf_tensor = -float("inf") * tf.ones_like(input_tensor)
         # mask
         input_tensor = tf.where(tf.math.greater(mask_4d, 0), inf_tensor, input_tensor)
@@ -1029,6 +1041,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         batch_size, seq_len, num_heads, head_dim = shape_list(value)
+        if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 seq_len % (window_overlap * 2),
                 0,
@@ -1065,7 +1078,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         )
         # pad seq_len with w at the beginning of the sequence and another window overlap at the end
-        paddings = tf.convert_to_tensor([[0, 0], [window_overlap, window_overlap], [0, 0]], dtype=tf.dtypes.int32)
+        paddings = tf.convert_to_tensor([[0, 0], [window_overlap, window_overlap], [0, 0]])
         padded_value = tf.pad(value, paddings, constant_values=-1)
         # chunk padded_value into chunks of size 3 window overlap and an overlap of size window overlap
@@ -1081,6 +1094,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
             (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim),
         )
+        if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(chunked_value),
                 [batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim],
@@ -1158,6 +1172,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         # chunk with overlap
         chunked_hidden_states = tf.signal.frame(hidden_states, frame_size, frame_hop_size)
+        if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(chunked_hidden_states),
                 [batch_size, num_output_chunks, frame_size],
@@ -1175,7 +1190,8 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
     def _get_global_attn_indices(is_index_global_attn):
         """ compute global attn indices required throughout forward pass """
         # helper variable
-        num_global_attn_indices = tf.reduce_sum(tf.cast(is_index_global_attn, dtype=tf.dtypes.int32), axis=1)
+        num_global_attn_indices = tf.math.count_nonzero(is_index_global_attn, axis=1)
+        num_global_attn_indices = tf.cast(num_global_attn_indices, dtype=tf.constant(1).dtype)
         # max number of global attn indices in batch
         max_num_global_attn_indices = tf.reduce_max(num_global_attn_indices)
@@ -1237,6 +1253,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
            shape_list(attn_probs_from_global_key_trans)[-2:]
        )
        mask = tf.ones(mask_shape) * -10000.0
+       mask = tf.cast(mask, dtype=attn_probs_from_global_key_trans.dtype)
        # scatter mask
        attn_probs_from_global_key_trans = tf.tensor_scatter_nd_update(
@@ -1323,7 +1340,9 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         global_value_vectors = self.value_global(hidden_states)
         # normalize
-        global_query_vectors_only_global /= tf.math.sqrt(tf.convert_to_tensor(self.head_dim, dtype=tf.dtypes.float32))
+        global_query_vectors_only_global /= tf.math.sqrt(
+            tf.cast(self.head_dim, dtype=global_query_vectors_only_global.dtype)
+        )
         global_query_vectors_only_global = self.reshape_and_transpose(global_query_vectors_only_global, batch_size)
         global_key_vectors = self.reshape_and_transpose(global_key_vectors, batch_size)
         global_value_vectors = self.reshape_and_transpose(global_value_vectors, batch_size)
@@ -1331,6 +1350,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         # compute attn scores
         global_attn_scores = tf.matmul(global_query_vectors_only_global, global_key_vectors, transpose_b=True)
+        if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(global_attn_scores),
                 [batch_size * self.num_heads, max_num_global_attn_indices, seq_len],
@@ -1346,6 +1366,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
            shape_list(global_attn_scores_trans)[-2:]
        )
        global_attn_mask = tf.ones(mask_shape) * -10000.0
+       global_attn_mask = tf.cast(global_attn_mask, dtype=global_attn_scores_trans.dtype)
        # scatter mask
        global_attn_scores_trans = tf.tensor_scatter_nd_update(
@@ -1368,6 +1389,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         # apply layer head maskin
         if layer_head_mask is not None:
+            if tf.executing_eagerly():
                 tf.debugging.assert_equal(
                     shape_list(layer_head_mask),
                     [self.num_heads],
@@ -1386,6 +1408,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         # global attn output
         global_attn_output = tf.matmul(global_attn_probs, global_value_vectors)
+        if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(global_attn_output),
                 [batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim],
@@ -2230,6 +2253,7 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
                 logger.info("Initializing global attention on question tokens...")
                 # put global attention on all tokens until `config.sep_token_id` is reached
                 sep_token_indices = tf.where(inputs["input_ids"] == self.config.sep_token_id)
+                sep_token_indices = tf.cast(sep_token_indices, dtype=inputs["input_ids"].dtype)
                 inputs["global_attention_mask"] = _compute_global_attention_mask(
                     shape_list(inputs["input_ids"]), sep_token_indices
                 )
......
@@ -362,10 +362,6 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase):
             self.assertEqual(model.config.output_hidden_states, True)
             check_encoder_attentions_output(outputs)

-    def test_mixed_precision(self):
-        # TODO JP: Make LED float16 compliant
-        pass
-
     def test_xla_mode(self):
         # TODO JP: Make LED XLA compliant
         pass
......
@@ -343,10 +343,6 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
        # This test is too long (>30sec) and makes fail the CI
        pass

-    def test_mixed_precision(self):
-        # TODO JP: Make Longformer float16 compliant
-        pass
-
     def test_xla_mode(self):
         # TODO JP: Make Longformer XLA compliant
         pass
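With the two test_mixed_precision overrides removed, the shared mixed precision test from TFModelTesterMixin runs again for LED and Longformer instead of being skipped. A standalone sketch of that kind of smoke test (model name and assertions are illustrative, not the test suite's actual code):

import unittest

import tensorflow as tf
from transformers import AutoTokenizer, TFLongformerModel

class MixedPrecisionSmokeTest(unittest.TestCase):
    def test_mixed_precision(self):
        tf.keras.mixed_precision.set_global_policy("mixed_float16")
        try:
            tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")
            model = TFLongformerModel.from_pretrained("allenai/longformer-base-4096")
            outputs = model(**tokenizer("hello world", return_tensors="tf"))
            self.assertIsNotNone(outputs.last_hidden_state)
        finally:
            # Restore the default policy so other tests are unaffected.
            tf.keras.mixed_precision.set_global_policy("float32")

if __name__ == "__main__":
    unittest.main()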
......