Unverified commit 19e737b9, authored by Julien Plu, committed by GitHub

Making TF Longformer-like models compliant with AMP (#10233)

* AMP

* Add LED

* Apply style

* Fix longformer
parent cd8c4c3f
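The recurring fix in this commit is to stop hard-coding `tf.float32`/`tf.int32` and instead derive dtypes from the tensors already in hand, so the layers keep working when Keras' mixed-precision (AMP) policy turns activations into float16. A minimal sketch of the before/after pattern (toy tensor, not code from the diff):

```python
import tensorflow as tf

# Hypothetical activations as they look under Keras' "mixed_float16" policy.
hidden_states = tf.random.normal((2, 8, 16), dtype=tf.float16)
head_dim = 16

# Not AMP-safe: a hard-coded float32 scalar cannot be combined with float16
# activations (TensorFlow does not promote dtypes automatically).
# scaled = hidden_states / tf.math.sqrt(tf.convert_to_tensor(head_dim, dtype=tf.float32))

# AMP-safe: derive the scalar's dtype from the tensor it is combined with.
scaled = hidden_states / tf.math.sqrt(tf.cast(head_dim, dtype=hidden_states.dtype))
print(scaled.dtype)  # float16
```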
@@ -392,23 +392,22 @@ def _compute_global_attention_mask(input_ids_shape, sep_token_indices, before_se
"""
assert shape_list(sep_token_indices)[1] == 2, "`input_ids` should have two dimensions"
question_end_index = tf.reshape(sep_token_indices, (input_ids_shape[0], 3, 2))[:, 0, 1]
question_end_index = tf.cast(question_end_index[:, None], tf.dtypes.int32) # size: batch_size x 1
question_end_index = tf.reshape(sep_token_indices, (input_ids_shape[0], 3, 2))[:, 0, 1][:, None]
# bool attention mask with True in locations of global attention
attention_mask = tf.range(input_ids_shape[1])[tf.newaxis, :]
attention_mask = tf.expand_dims(tf.range(input_ids_shape[1]), axis=0)
attention_mask = tf.tile(attention_mask, (input_ids_shape[0], 1))
if before_sep_token is True:
question_end_index = tf.tile(question_end_index, (1, input_ids_shape[1]))
attention_mask = tf.cast(attention_mask < question_end_index, tf.int32)
attention_mask = tf.cast(attention_mask < question_end_index, dtype=question_end_index.dtype)
else:
# the last token is a separation token and should not be counted, and there are two separation tokens in the middle
question_end_index = tf.tile(question_end_index + 1, (1, input_ids_shape[1]))
attention_mask = (
tf.cast(
attention_mask > question_end_index,
tf.dtypes.int32,
dtype=question_end_index.dtype,
)
* tf.cast(attention_mask < input_ids_shape[-1], tf.dtypes.int32)
* tf.cast(attention_mask < input_ids_shape[-1], dtype=question_end_index.dtype)
)
return attention_mask
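For orientation, a minimal usage sketch of the helper above (toy ids with `sep_token_id = 2`; the import paths are assumed from the v4 module layout and the `shape_list` helper is the library's shape utility, neither shown in this diff):

```python
import tensorflow as tf
from transformers.modeling_tf_utils import shape_list
from transformers.models.longformer.modeling_tf_longformer import _compute_global_attention_mask

# Toy QA-style row: <s> question </s> </s> context </s>, with sep_token_id = 2.
input_ids = tf.constant([[0, 11, 12, 2, 2, 21, 22, 23, 2]])

sep_token_indices = tf.where(input_ids == 2)                          # int64 coordinates
sep_token_indices = tf.cast(sep_token_indices, dtype=input_ids.dtype)

mask = _compute_global_attention_mask(shape_list(input_ids), sep_token_indices)
print(mask)  # 1s on the tokens before the first separator (global attention), 0s elsewhere
```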
@@ -730,14 +729,15 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
value_vectors = self.value(hidden_states)
batch_size, seq_len, embed_dim = shape_list(hidden_states)
tf.debugging.assert_equal(
embed_dim,
self.embed_dim,
message=f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}",
)
if tf.executing_eagerly():
tf.debugging.assert_equal(
embed_dim,
self.embed_dim,
message=f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}",
)
# normalize query
query_vectors /= tf.math.sqrt(tf.convert_to_tensor(self.head_dim, dtype=tf.dtypes.float32))
query_vectors /= tf.math.sqrt(tf.cast(self.head_dim, dtype=query_vectors.dtype))
query_vectors = tf.reshape(query_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))
key_vectors = tf.reshape(key_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))
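The shape assertions throughout this file are now wrapped in `if tf.executing_eagerly():`, presumably because `shape_list` can return symbolic tensors once the model is traced into a graph (tf.function, saved model, the AMP/XLA test paths), where these Python-side checks and their f-string messages are not meaningful. A minimal sketch of the guard pattern, with hypothetical names:

```python
import tensorflow as tf

def check_embed_dim(hidden_states, expected_embed_dim):
    embed_dim = hidden_states.shape[-1]
    # Only run the debug check eagerly; under tf.function tracing the shape may
    # be symbolic and the formatted message below would be misleading.
    if tf.executing_eagerly():
        tf.debugging.assert_equal(
            embed_dim,
            expected_embed_dim,
            message=f"hidden_states should have embed_dim = {expected_embed_dim}, but has {embed_dim}",
        )
    return hidden_states

check_embed_dim(tf.zeros((2, 8, 16)), 16)  # passes silently in eager mode
```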
@@ -748,7 +748,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
# diagonal mask with zeros everywhere and -inf in place of padding
diagonal_mask = self._sliding_chunks_query_key_matmul(
tf.ones(shape_list(attention_mask), dtype=tf.float32),
tf.ones(shape_list(attention_mask)),
attention_mask,
self.one_sided_attn_window_size,
)
@@ -756,11 +756,12 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
# pad local attention probs
attn_scores += diagonal_mask
tf.debugging.assert_equal(
shape_list(attn_scores),
[batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1],
message=f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads}, {self.one_sided_attn_window_size * 2 + 1}), but is of size {shape_list(attn_scores)}",
)
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_scores),
[batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1],
message=f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads}, {self.one_sided_attn_window_size * 2 + 1}), but is of size {shape_list(attn_scores)}",
)
# compute global attn indices required throughout forward fn
(
@@ -803,16 +804,18 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
)
attn_probs = tf.where(
masked_index,
tf.zeros(shape_list(masked_index), dtype=tf.dtypes.float32),
tf.zeros(shape_list(masked_index), dtype=attn_probs.dtype),
attn_probs,
)
if layer_head_mask is not None:
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
)
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
)
attn_probs = tf.reshape(layer_head_mask, (1, 1, -1, 1)) * attn_probs
# apply dropout
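The per-layer head mask is applied by reshaping it so it broadcasts over the local attention probabilities of shape `(batch_size, seq_len, num_heads, 2 * window + 1)`. A small sketch with toy shapes (not taken from the diff):

```python
import tensorflow as tf

batch_size, seq_len, num_heads, window = 2, 8, 4, 3
attn_probs = tf.random.uniform((batch_size, seq_len, num_heads, 2 * window + 1), dtype=tf.float16)

# One scalar per head; a zero entry prunes that head's attention entirely.
# Under AMP the mask has to share attn_probs' dtype for the multiply to work.
layer_head_mask = tf.constant([1.0, 0.0, 1.0, 1.0], dtype=attn_probs.dtype)

# (num_heads,) -> (1, 1, num_heads, 1) broadcasts over batch, positions and window.
masked_probs = tf.reshape(layer_head_mask, (1, 1, -1, 1)) * attn_probs
print(masked_probs.shape)  # (2, 8, 4, 7)
```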
@@ -834,11 +837,12 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
),
)
tf.debugging.assert_equal(
shape_list(attn_output),
[batch_size, seq_len, self.num_heads, self.head_dim],
message="Unexpected size",
)
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_output),
[batch_size, seq_len, self.num_heads, self.head_dim],
message="Unexpected size",
)
attn_output = tf.reshape(attn_output, (batch_size, seq_len, embed_dim))
@@ -877,7 +881,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
)
attn_probs = tf.where(
masked_global_attn_index,
tf.zeros(shape_list(masked_global_attn_index), dtype=tf.dtypes.float32),
tf.zeros(shape_list(masked_global_attn_index), dtype=attn_probs.dtype),
attn_probs,
)
@@ -893,16 +897,17 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
"""
batch_size, seq_len, num_heads, head_dim = shape_list(query)
tf.debugging.assert_equal(
seq_len % (window_overlap * 2),
0,
message=f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}",
)
tf.debugging.assert_equal(
shape_list(query),
shape_list(key),
message=f"Shape of query and key should be equal, but got query: {shape_list(query)} and key: {shape_list(key)}",
)
if tf.executing_eagerly():
tf.debugging.assert_equal(
seq_len % (window_overlap * 2),
0,
message=f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}",
)
tf.debugging.assert_equal(
shape_list(query),
shape_list(key),
message=f"Shape of query and key should be equal, but got query: {shape_list(query)} and key: {shape_list(key)}",
)
chunks_count = seq_len // window_overlap - 1
@@ -919,10 +924,11 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
# bcxd: batch_size * num_heads x chunks x 2window_overlap x head_dim
# bcyd: batch_size * num_heads x chunks x 2window_overlap x head_dim
# bcxy: batch_size * num_heads x chunks x 2window_overlap x 2window_overlap
chunked_query = tf.cast(chunked_query, dtype=chunked_key.dtype)
chunked_attention_scores = tf.einsum("bcxd,bcyd->bcxy", chunked_query, chunked_key) # multiply
# convert diagonals into columns
paddings = tf.convert_to_tensor([[0, 0], [0, 0], [0, 1], [0, 0]], dtype=tf.dtypes.int32)
paddings = tf.convert_to_tensor([[0, 0], [0, 0], [0, 1], [0, 0]])
diagonal_chunked_attention_scores = self._pad_and_transpose_last_two_dims(chunked_attention_scores, paddings)
# allocate space for the overall attention matrix where the chunks are combined. The last dimension
@@ -944,7 +950,10 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
# - copying the lower triangle
diagonal_attn_scores_low_triang = tf.concat(
[
tf.zeros((batch_size * num_heads, 1, window_overlap, window_overlap)),
tf.zeros(
(batch_size * num_heads, 1, window_overlap, window_overlap),
dtype=diagonal_chunked_attention_scores.dtype,
),
diagonal_chunked_attention_scores[:, :, -(window_overlap + 1) : -1, window_overlap + 1 :],
],
axis=1,
@@ -956,7 +965,10 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
shift=[1, window_overlap],
axis=[2, 3],
)[:, :, :window_overlap, :window_overlap],
tf.zeros((batch_size * num_heads, 1, window_overlap, window_overlap)),
tf.zeros(
(batch_size * num_heads, 1, window_overlap, window_overlap),
dtype=diagonal_chunked_attention_scores.dtype,
),
],
axis=1,
)
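A simplified, self-contained sketch of the chunked score computation used above: split query and key into overlapping chunks, align their dtypes (the cast this commit adds, since the query may not share the key's dtype under AMP), then compute per-chunk scores with `tf.einsum`. Toy shapes, not the production tensors:

```python
import tensorflow as tf

window_overlap = 2
batch_heads, seq_len, head_dim = 3, 8, 4     # batch_size * num_heads already flattened

# The two sides may surface with different dtypes, e.g. float32 vs. float16.
query = tf.random.normal((batch_heads, seq_len, head_dim))                  # float32
key = tf.random.normal((batch_heads, seq_len, head_dim), dtype=tf.float16)  # float16 activations

# Overlapping chunks of length 2 * window_overlap with hop window_overlap.
chunked_query = tf.signal.frame(query, 2 * window_overlap, window_overlap, axis=1)
chunked_key = tf.signal.frame(key, 2 * window_overlap, window_overlap, axis=1)

# Align dtypes before the einsum (the fix this commit applies).
chunked_query = tf.cast(chunked_query, dtype=chunked_key.dtype)

# bcxd,bcyd->bcxy: per-chunk attention scores.
chunked_scores = tf.einsum("bcxd,bcyd->bcxy", chunked_query, chunked_key)
print(chunked_scores.shape)  # (3, 3, 4, 4)
```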
@@ -1014,7 +1026,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
mask_4d = tf.tile(mask_2d[None, :, None, :], (shape_list(input_tensor)[0], 1, 1, 1))
# inf tensor used for masking
inf_tensor = -float("inf") * tf.ones_like(input_tensor, dtype=tf.dtypes.float32)
inf_tensor = -float("inf") * tf.ones_like(input_tensor)
# mask
input_tensor = tf.where(tf.math.greater(mask_4d, 0), inf_tensor, input_tensor)
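The `-inf` padding tensor is now built with `tf.ones_like`, so it inherits the input dtype instead of being pinned to float32. A minimal sketch of why that matters under AMP (toy scores):

```python
import tensorflow as tf

scores = tf.random.normal((1, 6), dtype=tf.float16)   # pretend attention scores under AMP
mask = tf.constant([[0, 0, 0, 0, 1, 1]])               # 1 marks padded positions

# ones_like keeps scores.dtype (float16 here); a hard-coded float32 tensor
# would make tf.where fail because both branches must share a dtype.
inf_tensor = -float("inf") * tf.ones_like(scores)
masked_scores = tf.where(tf.math.greater(mask, 0), inf_tensor, scores)

print(masked_scores.dtype)                              # float16
print(tf.nn.softmax(masked_scores, axis=-1)[0, -2:])    # ~0 weight on the masked positions
```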
@@ -1029,21 +1041,22 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
batch_size, seq_len, num_heads, head_dim = shape_list(value)
tf.debugging.assert_equal(
seq_len % (window_overlap * 2),
0,
message="Seq_len has to be multiple of 2 * window_overlap",
)
tf.debugging.assert_equal(
shape_list(attn_probs)[:3],
shape_list(value)[:3],
message="value and attn_probs must have same dims (except head_dim)",
)
tf.debugging.assert_equal(
shape_list(attn_probs)[3],
2 * window_overlap + 1,
message="attn_probs last dim has to be 2 * window_overlap + 1",
)
if tf.executing_eagerly():
tf.debugging.assert_equal(
seq_len % (window_overlap * 2),
0,
message="Seq_len has to be multiple of 2 * window_overlap",
)
tf.debugging.assert_equal(
shape_list(attn_probs)[:3],
shape_list(value)[:3],
message="value and attn_probs must have same dims (except head_dim)",
)
tf.debugging.assert_equal(
shape_list(attn_probs)[3],
2 * window_overlap + 1,
message="attn_probs last dim has to be 2 * window_overlap + 1",
)
chunks_count = seq_len // window_overlap - 1
@@ -1065,7 +1078,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
)
# pad seq_len with window_overlap at the beginning of the sequence and another window_overlap at the end
paddings = tf.convert_to_tensor([[0, 0], [window_overlap, window_overlap], [0, 0]], dtype=tf.dtypes.int32)
paddings = tf.convert_to_tensor([[0, 0], [window_overlap, window_overlap], [0, 0]])
padded_value = tf.pad(value, paddings, constant_values=-1)
# chunk padded_value into chunks of size 3 window overlap and an overlap of size window overlap
@@ -1081,11 +1094,12 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
(batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim),
)
tf.debugging.assert_equal(
shape_list(chunked_value),
[batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim],
message="Chunked value has the wrong shape",
)
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(chunked_value),
[batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim],
message="Chunked value has the wrong shape",
)
chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs)
context = tf.einsum("bcwd,bcdh->bcwh", chunked_attn_probs, chunked_value)
@@ -1158,11 +1172,12 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
# chunk with overlap
chunked_hidden_states = tf.signal.frame(hidden_states, frame_size, frame_hop_size)
tf.debugging.assert_equal(
shape_list(chunked_hidden_states),
[batch_size, num_output_chunks, frame_size],
message=f"Make sure chunking is correctly applied. `Chunked hidden states should have output dimension {[batch_size, frame_size, num_output_chunks]}, but got {shape_list(chunked_hidden_states)}.",
)
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(chunked_hidden_states),
[batch_size, num_output_chunks, frame_size],
message=f"Make sure chunking is correctly applied. `Chunked hidden states should have output dimension {[batch_size, frame_size, num_output_chunks]}, but got {shape_list(chunked_hidden_states)}.",
)
chunked_hidden_states = tf.reshape(
chunked_hidden_states,
@@ -1175,7 +1190,8 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
def _get_global_attn_indices(is_index_global_attn):
""" compute global attn indices required throughout forward pass """
# helper variable
num_global_attn_indices = tf.reduce_sum(tf.cast(is_index_global_attn, dtype=tf.dtypes.int32), axis=1)
num_global_attn_indices = tf.math.count_nonzero(is_index_global_attn, axis=1)
num_global_attn_indices = tf.cast(num_global_attn_indices, dtype=tf.constant(1).dtype)
# max number of global attn indices in batch
max_num_global_attn_indices = tf.reduce_max(num_global_attn_indices)
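`tf.math.count_nonzero` returns int64 by default, while the rest of the index bookkeeping here uses TensorFlow's default integer dtype (what `tf.constant(1).dtype` evaluates to, i.e. int32), hence the extra cast. A small sketch:

```python
import tensorflow as tf

is_index_global_attn = tf.constant([[True, False, True], [True, False, False]])

num_global = tf.math.count_nonzero(is_index_global_attn, axis=1)
print(num_global.dtype)   # int64 by default

# Align with TF's default integer dtype (int32) so later ops that mix this
# count with int32 indices do not hit a dtype mismatch.
num_global = tf.cast(num_global, dtype=tf.constant(1).dtype)
print(num_global.dtype)   # int32
```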
@@ -1237,6 +1253,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
shape_list(attn_probs_from_global_key_trans)[-2:]
)
mask = tf.ones(mask_shape) * -10000.0
mask = tf.cast(mask, dtype=attn_probs_from_global_key_trans.dtype)
# scatter mask
attn_probs_from_global_key_trans = tf.tensor_scatter_nd_update(
@@ -1323,7 +1340,9 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
global_value_vectors = self.value_global(hidden_states)
# normalize
global_query_vectors_only_global /= tf.math.sqrt(tf.convert_to_tensor(self.head_dim, dtype=tf.dtypes.float32))
global_query_vectors_only_global /= tf.math.sqrt(
tf.cast(self.head_dim, dtype=global_query_vectors_only_global.dtype)
)
global_query_vectors_only_global = self.reshape_and_transpose(global_query_vectors_only_global, batch_size)
global_key_vectors = self.reshape_and_transpose(global_key_vectors, batch_size)
global_value_vectors = self.reshape_and_transpose(global_value_vectors, batch_size)
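The hunks around this one build an additive `-10000.0` padding mask and now cast it to the dtype of the scores it is scattered into, since `tf.tensor_scatter_nd_update` requires matching dtypes. A simplified sketch of that pattern with toy shapes (not the actual transposed global-attention tensors):

```python
import tensorflow as tf

# Pretend global-attention scores under AMP.
scores = tf.random.normal((4, 6), dtype=tf.float16)

# Positions belonging to "padding" global tokens that must be masked out.
masked_positions = tf.constant([[0, 4], [0, 5], [2, 4], [2, 5]])

# The additive mask must share the score dtype, otherwise the scatter fails.
mask_values = tf.cast(tf.ones((4,)) * -10000.0, dtype=scores.dtype)

masked_scores = tf.tensor_scatter_nd_update(scores, masked_positions, mask_values)
print(masked_scores[0, 4:])  # [-10000., -10000.] in float16
```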
@@ -1331,11 +1350,12 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
# compute attn scores
global_attn_scores = tf.matmul(global_query_vectors_only_global, global_key_vectors, transpose_b=True)
tf.debugging.assert_equal(
shape_list(global_attn_scores),
[batch_size * self.num_heads, max_num_global_attn_indices, seq_len],
message=f"global_attn_scores have the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is {shape_list(global_attn_scores)}.",
)
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(global_attn_scores),
[batch_size * self.num_heads, max_num_global_attn_indices, seq_len],
message=f"global_attn_scores have the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is {shape_list(global_attn_scores)}.",
)
global_attn_scores = tf.reshape(
global_attn_scores,
@@ -1346,6 +1366,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
shape_list(global_attn_scores_trans)[-2:]
)
global_attn_mask = tf.ones(mask_shape) * -10000.0
global_attn_mask = tf.cast(global_attn_mask, dtype=global_attn_scores_trans.dtype)
# scatter mask
global_attn_scores_trans = tf.tensor_scatter_nd_update(
@@ -1368,11 +1389,12 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
# apply layer head masking
if layer_head_mask is not None:
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
)
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
)
global_attn_probs_float = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
global_attn_probs_float, (batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
)
@@ -1386,11 +1408,12 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
# global attn output
global_attn_output = tf.matmul(global_attn_probs, global_value_vectors)
tf.debugging.assert_equal(
shape_list(global_attn_output),
[batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim],
message=f"global_attn_output tensor has the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is {shape_list(global_attn_output)}.",
)
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(global_attn_output),
[batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim],
message=f"global_attn_output tensor has the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is {shape_list(global_attn_output)}.",
)
global_attn_output = tf.reshape(
global_attn_output,
@@ -2230,6 +2253,7 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
logger.info("Initializing global attention on question tokens...")
# put global attention on all tokens until `config.sep_token_id` is reached
sep_token_indices = tf.where(inputs["input_ids"] == self.config.sep_token_id)
sep_token_indices = tf.cast(sep_token_indices, dtype=inputs["input_ids"].dtype)
inputs["global_attention_mask"] = _compute_global_attention_mask(
shape_list(inputs["input_ids"]), sep_token_indices
)
@@ -362,10 +362,6 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase):
self.assertEqual(model.config.output_hidden_states, True)
check_encoder_attentions_output(outputs)
def test_mixed_precision(self):
# TODO JP: Make LED float16 compliant
pass
def test_xla_mode(self):
# TODO JP: Make LED XLA compliant
pass
@@ -343,10 +343,6 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
# This test is too long (>30sec) and makes fail the CI
pass
def test_mixed_precision(self):
# TODO JP: Make Longformer float16 compliant
pass
def test_xla_mode(self):
# TODO JP: Make Longformer XLA compliant
pass
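Once the `test_mixed_precision` overrides above are removed, the shared mixed-precision test from `TFModelTesterMixin` runs again for LED and Longformer. A hedged sketch of what such a smoke test looks like (illustrative only, not the repo's actual test code; the config values are arbitrary small numbers):

```python
import tensorflow as tf
from transformers import LongformerConfig, TFLongformerModel

# Enable Keras mixed precision, run a tiny model once, then restore the policy.
tf.keras.mixed_precision.set_global_policy("mixed_float16")

config = LongformerConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=1,
    num_attention_heads=2,
    intermediate_size=37,
    attention_window=4,
)
model = TFLongformerModel(config)
outputs = model(tf.constant([[0, 5, 6, 7, 2, 2, 8, 2]]))
print(outputs.last_hidden_state.dtype)  # expected to be float16 under the mixed policy

tf.keras.mixed_precision.set_global_policy("float32")
```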