Unverified Commit 858b7d58 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[TF Longformer] Improve Speed for TF Longformer (#6447)

* add tf graph compile tests

* fix conflict

* remove more tf transpose statements

* fix conflicts

* fix comment typos

* move function to class function

* fix black

* fix black

* make style
parent a75c64d8
......@@ -1088,7 +1088,7 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="bert-base-cased",
......
......@@ -677,7 +677,7 @@ class TFElectraForSequenceClassification(TFElectraPreTrainedModel, TFSequenceCla
self.electra = TFElectraMainLayer(config, name="electra")
self.classifier = TFElectraClassificationHead(config, name="classifier")
@add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
@add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="google/electra-small-discriminator",
......
......@@ -97,24 +97,36 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
self.embed_dim = config.hidden_size
self.query = tf.keras.layers.Dense(
self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="query"
self.embed_dim,
kernel_initializer=get_initializer(config.initializer_range),
name="query",
)
self.key = tf.keras.layers.Dense(
self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="key"
self.embed_dim,
kernel_initializer=get_initializer(config.initializer_range),
name="key",
)
self.value = tf.keras.layers.Dense(
self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="value"
self.embed_dim,
kernel_initializer=get_initializer(config.initializer_range),
name="value",
)
# separate projection layers for tokens with global attention
self.query_global = tf.keras.layers.Dense(
self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="query_global"
self.embed_dim,
kernel_initializer=get_initializer(config.initializer_range),
name="query_global",
)
self.key_global = tf.keras.layers.Dense(
self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="key_global"
self.embed_dim,
kernel_initializer=get_initializer(config.initializer_range),
name="key_global",
)
self.value_global = tf.keras.layers.Dense(
self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="value_global"
self.embed_dim,
kernel_initializer=get_initializer(config.initializer_range),
name="value_global",
)
self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
......@@ -148,23 +160,21 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
"""
# retrieve input args
hidden_states, attention_mask, output_attentions = inputs
attention_mask = tf.squeeze(tf.squeeze(attention_mask, axis=2), axis=1)
# is index masked or global attention
is_index_masked = tf.math.less(attention_mask, 0)
is_index_global_attn = tf.math.greater(attention_mask, 0)
is_global_attn = tf.math.reduce_any(is_index_global_attn)
hidden_states = tf.transpose(hidden_states, (1, 0, 2))
(
hidden_states,
attention_mask,
is_index_masked,
is_index_global_attn,
is_global_attn,
output_attentions,
) = inputs
# project hidden states
query_vectors = self.query(hidden_states)
key_vectors = self.key(hidden_states)
value_vectors = self.value(hidden_states)
seq_len, batch_size, embed_dim = shape_list(hidden_states)
batch_size, seq_len, embed_dim = shape_list(hidden_states)
tf.debugging.assert_equal(
embed_dim,
self.embed_dim,
......@@ -174,24 +184,19 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
# normalize query
query_vectors /= tf.math.sqrt(tf.constant(self.head_dim, dtype=tf.dtypes.float32))
query_vectors = tf.transpose(
tf.reshape(query_vectors, (seq_len, batch_size, self.num_heads, self.head_dim)), (1, 0, 2, 3)
)
key_vectors = tf.transpose(
tf.reshape(key_vectors, (seq_len, batch_size, self.num_heads, self.head_dim)), (1, 0, 2, 3)
)
query_vectors = tf.reshape(query_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))
key_vectors = tf.reshape(key_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))
# attn_probs = (batch_size, seq_len, num_heads, window*2+1)
attn_scores = self._sliding_chunks_query_key_matmul(
query_vectors, key_vectors, self.one_sided_attn_window_size
)
# values to pad for attention probs
float_mask = tf.cast((attention_mask != 0)[:, :, None, None], dtype=tf.float32) * -10000.0
# diagonal mask with zeros everywhere and -inf inplace of padding
diagonal_mask = self._sliding_chunks_query_key_matmul(
tf.ones(shape_list(float_mask), dtype=tf.float32), float_mask, self.one_sided_attn_window_size
tf.ones(shape_list(attention_mask), dtype=tf.float32),
attention_mask,
self.one_sided_attn_window_size,
)
# pad local attention probs
......@@ -231,15 +236,15 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
# softmax sometimes inserts NaN if all positions are masked, replace them with 0
attn_probs = tf.where(
tf.broadcast_to(is_index_masked[:, :, None, None], shape_list(attn_probs)), 0.0, attn_probs
tf.broadcast_to(is_index_masked[:, :, None, None], shape_list(attn_probs)),
0.0,
attn_probs,
)
# apply dropout
attn_probs = self.dropout(attn_probs, training=training)
value_vectors = tf.transpose(
tf.reshape(value_vectors, (seq_len, batch_size, self.num_heads, self.head_dim)), (1, 0, 2, 3)
)
value_vectors = tf.reshape(value_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))
# if global attention, compute sum of global and local attn
attn_output = tf.cond(
......@@ -257,9 +262,11 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
)
tf.debugging.assert_equal(
shape_list(attn_output), [batch_size, seq_len, self.num_heads, self.head_dim], message="Unexpected size"
shape_list(attn_output),
[batch_size, seq_len, self.num_heads, self.head_dim],
message="Unexpected size",
)
attn_output = tf.reshape(tf.transpose(attn_output, (1, 0, 2, 3)), (seq_len, batch_size, embed_dim))
attn_output = tf.reshape(attn_output, (batch_size, seq_len, embed_dim))
# compute value for global attention and overwrite to attention output
# TODO: remove the redundant computation
......@@ -278,8 +285,6 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
lambda: attn_output,
)
attn_output = tf.transpose(attn_output, (1, 0, 2))
# GLOBAL ATTN:
# With global attention, return global attention probabilities only
# batch_size x num_heads x max_num_global_attention_tokens x sequence_length
......@@ -294,7 +299,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
attn_probs = tf.cond(
is_global_attn,
lambda: self._get_global_attn_probs(attn_probs, max_num_global_attn_indices),
lambda: tf.transpose(attn_probs, (0, 2, 1, 3)),
lambda: attn_probs,
)
outputs = (attn_output, attn_probs)
......@@ -310,7 +315,6 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
],
axis=-1,
)
attn_probs = tf.transpose(attn_probs, (0, 2, 1, 3))
return attn_probs
def _sliding_chunks_query_key_matmul(self, query, key, window_overlap):
......@@ -332,7 +336,10 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
chunks_count = seq_len // window_overlap - 1
# group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size window_overlap * 2
query = tf.reshape(tf.transpose(query, (0, 2, 1, 3)), (batch_size * num_heads, seq_len, head_dim))
query = tf.reshape(
tf.transpose(query, (0, 2, 1, 3)),
(batch_size * num_heads, seq_len, head_dim),
)
key = tf.reshape(tf.transpose(key, (0, 2, 1, 3)), (batch_size * num_heads, seq_len, head_dim))
chunked_query = self._chunk(query, window_overlap)
......@@ -374,9 +381,11 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
)
diagonal_attn_scores_first_chunk = tf.concat(
[
tf.roll(diagonal_chunked_attention_scores, shift=[1, window_overlap], axis=[2, 3])[
:, :, :window_overlap, :window_overlap
],
tf.roll(
diagonal_chunked_attention_scores,
shift=[1, window_overlap],
axis=[2, 3],
)[:, :, :window_overlap, :window_overlap],
tf.zeros((batch_size * num_heads, 1, window_overlap, window_overlap)),
],
axis=1,
......@@ -385,13 +394,20 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
first_chunk_mask = (
tf.broadcast_to(
tf.range(chunks_count + 1)[None, :, None, None],
shape=(batch_size * num_heads, chunks_count + 1, window_overlap, window_overlap),
shape=(
batch_size * num_heads,
chunks_count + 1,
window_overlap,
window_overlap,
),
)
< 1
)
diagonal_attn_scores_low_triang = tf.where(
first_chunk_mask, diagonal_attn_scores_first_chunk, diagonal_attn_scores_low_triang
first_chunk_mask,
diagonal_attn_scores_first_chunk,
diagonal_attn_scores_low_triang,
)
# merging upper and lower triangle
......@@ -401,7 +417,10 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
# separate batch_size and num_heads dimensions again
diagonal_attention_scores = tf.transpose(
tf.reshape(diagonal_attention_scores, (batch_size, num_heads, seq_len, 2 * window_overlap + 1)),
tf.reshape(
diagonal_attention_scores,
(batch_size, num_heads, seq_len, 2 * window_overlap + 1),
),
(0, 2, 1, 3),
)
......@@ -412,7 +431,8 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
def _mask_invalid_locations(input_tensor, window_overlap):
# create correct upper triangle bool mask
mask_2d_upper = tf.reverse(
tf.linalg.band_part(tf.ones(shape=(window_overlap, window_overlap + 1)), -1, 0), axis=[0]
tf.linalg.band_part(tf.ones(shape=(window_overlap, window_overlap + 1)), -1, 0),
axis=[0],
)
# pad to full matrix
padding = tf.constant(
......@@ -443,7 +463,9 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
batch_size, seq_len, num_heads, head_dim = shape_list(value)
tf.debugging.assert_equal(
seq_len % (window_overlap * 2), 0, message="Seq_len has to be multiple of 2 * window_overlap"
seq_len % (window_overlap * 2),
0,
message="Seq_len has to be multiple of 2 * window_overlap",
)
tf.debugging.assert_equal(
shape_list(attn_probs)[:3],
......@@ -461,11 +483,19 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
chunked_attn_probs = tf.reshape(
tf.transpose(attn_probs, (0, 2, 1, 3)),
(batch_size * num_heads, seq_len // window_overlap, window_overlap, 2 * window_overlap + 1),
(
batch_size * num_heads,
seq_len // window_overlap,
window_overlap,
2 * window_overlap + 1,
),
)
# group batch_size and num_heads dimensions into one
value = tf.reshape(tf.transpose(value, (0, 2, 1, 3)), (batch_size * num_heads, seq_len, head_dim))
value = tf.reshape(
tf.transpose(value, (0, 2, 1, 3)),
(batch_size * num_heads, seq_len, head_dim),
)
# pad seq_len with w at the beginning of the sequence and another window overlap at the end
......@@ -478,10 +508,13 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
frame_hop_size = (shape_list(padded_value)[1] * head_dim - frame_size) // chunks_count
chunked_value = tf.signal.frame(
tf.reshape(padded_value, (batch_size * num_heads, -1)), frame_size, frame_hop_size
tf.reshape(padded_value, (batch_size * num_heads, -1)),
frame_size,
frame_hop_size,
)
chunked_value = tf.reshape(
chunked_value, (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim)
chunked_value,
(batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim),
)
tf.debugging.assert_equal(
......@@ -493,7 +526,10 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs)
context = tf.einsum("bcwd,bcdh->bcwh", chunked_attn_probs, chunked_value)
context = tf.transpose(tf.reshape(context, (batch_size, num_heads, seq_len, head_dim)), (0, 2, 1, 3))
context = tf.transpose(
tf.reshape(context, (batch_size, num_heads, seq_len, head_dim)),
(0, 2, 1, 3),
)
return context
@staticmethod
......@@ -502,12 +538,10 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
hidden_states_padded = tf.pad(
hidden_states_padded, paddings
) # padding value is not important because it will be overwritten
if len(shape_list(hidden_states_padded)) > 3:
batch_size, chunk_size, seq_length, hidden_dim = shape_list(hidden_states_padded)
hidden_states_padded = tf.reshape(hidden_states_padded, (batch_size, chunk_size, hidden_dim, seq_length))
else:
batch_size, seq_length, hidden_dim = shape_list(hidden_states_padded)
hidden_states_padded = tf.reshape(hidden_states_padded, (batch_size, hidden_dim, seq_length))
return hidden_states_padded
@staticmethod
......@@ -539,7 +573,8 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
:, :, :-window_overlap
] # total_num_heads x num_chunks x window_overlapL+window_overlapwindow_overlap
chunked_hidden_states = tf.reshape(
chunked_hidden_states, (total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim)
chunked_hidden_states,
(total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim),
) # total_num_heads x num_chunks, window_overlap x hidden_dim+window_overlap
chunked_hidden_states = chunked_hidden_states[:, :, :, :-1]
return chunked_hidden_states
......@@ -566,7 +601,8 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
)
chunked_hidden_states = tf.reshape(
chunked_hidden_states, (batch_size, num_output_chunks, 2 * window_overlap, hidden_dim)
chunked_hidden_states,
(batch_size, num_output_chunks, 2 * window_overlap, hidden_dim),
)
return chunked_hidden_states
......@@ -619,7 +655,12 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
key_vectors_only_global = tf.scatter_nd(
is_local_index_global_attn_nonzero,
global_key_vectors,
shape=(batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim),
shape=(
batch_size,
max_num_global_attn_indices,
self.num_heads,
self.head_dim,
),
)
# (batch_size, seq_len, num_heads, max_num_global_attn_indices)
......@@ -633,7 +674,9 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
# scatter mask
attn_probs_from_global_key_trans = tf.tensor_scatter_nd_update(
attn_probs_from_global_key_trans, is_local_index_no_global_attn_nonzero, mask
attn_probs_from_global_key_trans,
is_local_index_no_global_attn_nonzero,
mask,
)
# (batch_size, seq_len, num_heads, max_num_global_attn_indices)
......@@ -664,7 +707,12 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
value_vectors_only_global = tf.scatter_nd(
is_local_index_global_attn_nonzero,
global_value_vectors,
shape=(batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim),
shape=(
batch_size,
max_num_global_attn_indices,
self.num_heads,
self.head_dim,
),
)
# compute attn output only global
......@@ -690,14 +738,14 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
is_index_masked,
training,
):
seq_len, batch_size = shape_list(hidden_states)[:2]
batch_size, seq_len = shape_list(hidden_states)[:2]
# prepare global hidden states
global_attn_hidden_states = tf.gather_nd(hidden_states, tf.reverse(is_index_global_attn_nonzero, axis=[1]))
global_attn_hidden_states = tf.gather_nd(hidden_states, is_index_global_attn_nonzero)
global_attn_hidden_states = tf.scatter_nd(
tf.reverse(is_local_index_global_attn_nonzero, axis=[1]),
is_local_index_global_attn_nonzero,
global_attn_hidden_states,
shape=(max_num_global_attn_indices, batch_size, self.embed_dim),
shape=(batch_size, max_num_global_attn_indices, self.embed_dim),
)
# global key, query, value
......@@ -708,27 +756,12 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
# normalize
global_query_vectors_only_global /= tf.math.sqrt(tf.constant(self.head_dim, dtype=tf.dtypes.float32))
# (batch_size * self.num_heads, max_num_global_attn_indices, head_dim)
global_query_vectors_only_global = tf.transpose(
tf.reshape(
global_query_vectors_only_global,
(max_num_global_attn_indices, batch_size * self.num_heads, self.head_dim),
),
(1, 0, 2),
)
# (..., batch_size * self.num_heads, seq_len, head_dim)
global_key_vectors = tf.transpose(
tf.reshape(global_key_vectors, (-1, batch_size * self.num_heads, self.head_dim)), (1, 0, 2)
)
# (..., batch_size * self.num_heads, seq_len, head_dim)
global_value_vectors = tf.transpose(
tf.reshape(global_value_vectors, (-1, batch_size * self.num_heads, self.head_dim)), (1, 0, 2)
)
global_query_vectors_only_global = self.reshape_and_transpose(global_query_vectors_only_global, batch_size)
global_key_vectors = self.reshape_and_transpose(global_key_vectors, batch_size)
global_value_vectors = self.reshape_and_transpose(global_value_vectors, batch_size)
# compute attn scores
global_attn_scores = tf.matmul(global_query_vectors_only_global, tf.transpose(global_key_vectors, (0, 2, 1)))
global_attn_scores = tf.matmul(global_query_vectors_only_global, global_key_vectors, transpose_b=True)
tf.debugging.assert_equal(
shape_list(global_attn_scores),
......@@ -737,7 +770,8 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
)
global_attn_scores = tf.reshape(
global_attn_scores, (batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
global_attn_scores,
(batch_size, self.num_heads, max_num_global_attn_indices, seq_len),
)
global_attn_scores_trans = tf.transpose(global_attn_scores, (0, 2, 1, 3))
......@@ -748,7 +782,9 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
# scatter mask
global_attn_scores_trans = tf.tensor_scatter_nd_update(
global_attn_scores_trans, is_local_index_no_global_attn_nonzero, global_attn_mask
global_attn_scores_trans,
is_local_index_no_global_attn_nonzero,
global_attn_mask,
)
global_attn_scores = tf.transpose(global_attn_scores_trans, (0, 2, 1, 3))
......@@ -757,7 +793,8 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
global_attn_scores = tf.where(attn_mask, -10000.0, global_attn_scores)
global_attn_scores = tf.reshape(
global_attn_scores, (batch_size * self.num_heads, max_num_global_attn_indices, seq_len)
global_attn_scores,
(batch_size * self.num_heads, max_num_global_attn_indices, seq_len),
)
# compute global attn probs
......@@ -776,12 +813,14 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
)
global_attn_output = tf.reshape(
global_attn_output, (batch_size, self.num_heads, max_num_global_attn_indices, self.head_dim)
global_attn_output,
(batch_size, self.num_heads, max_num_global_attn_indices, self.head_dim),
)
# get only non zero global attn output
nonzero_global_attn_output = tf.gather_nd(
tf.transpose(global_attn_output, (0, 2, 1, 3)), is_local_index_global_attn_nonzero
tf.transpose(global_attn_output, (0, 2, 1, 3)),
is_local_index_global_attn_nonzero,
)
nonzero_global_attn_output = tf.reshape(
nonzero_global_attn_output,
......@@ -789,12 +828,21 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
)
# overwrite values with global attention
attn_output = tf.tensor_scatter_nd_update(
attn_output, tf.reverse(is_index_global_attn_nonzero, axis=[1]), nonzero_global_attn_output
attn_output, is_index_global_attn_nonzero, nonzero_global_attn_output
)
return attn_output
def reshape_and_transpose(self, vector, batch_size):
return tf.reshape(
tf.transpose(
tf.reshape(vector, (batch_size, -1, self.num_heads, self.head_dim)),
(0, 2, 1, 3),
),
(batch_size * self.num_heads, -1, self.head_dim),
)
class TFLongformerAttention(tf.keras.layers.Layer):
def __init__(self, config, layer_id=0, **kwargs):
......@@ -806,10 +854,20 @@ class TFLongformerAttention(tf.keras.layers.Layer):
raise NotImplementedError
def call(self, inputs, training=False):
input_tensor, attention_mask, output_attentions = inputs
(
hidden_states,
attention_mask,
is_index_masked,
is_index_global_attn,
is_global_attn,
output_attentions,
) = inputs
self_outputs = self.self_attention([input_tensor, attention_mask, output_attentions], training=training)
attention_output = self.dense_output(self_outputs[0], input_tensor, training=training)
self_outputs = self.self_attention(
[hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn, output_attentions],
training=training,
)
attention_output = self.dense_output(self_outputs[0], hidden_states, training=training)
outputs = (attention_output,) + self_outputs[1:]
return outputs
......@@ -823,9 +881,19 @@ class TFLongformerLayer(tf.keras.layers.Layer):
self.longformer_output = TFBertOutput(config, name="output")
def call(self, inputs, training=False):
hidden_states, attention_mask, output_attentions = inputs
(
hidden_states,
attention_mask,
is_index_masked,
is_index_global_attn,
is_global_attn,
output_attentions,
) = inputs
attention_outputs = self.attention([hidden_states, attention_mask, output_attentions], training=training)
attention_outputs = self.attention(
[hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn, output_attentions],
training=training,
)
attention_output = attention_outputs[0]
intermediate_output = self.intermediate(attention_output)
layer_output = self.longformer_output(intermediate_output, attention_output, training=training)
......@@ -848,12 +916,14 @@ class TFLongformerEncoder(tf.keras.layers.Layer):
attention_mask=None,
head_mask=None,
padding_len=0,
is_index_masked=None,
is_index_global_attn=None,
is_global_attn=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
):
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
for i, layer_module in enumerate(self.layer):
......@@ -861,11 +931,21 @@ class TFLongformerEncoder(tf.keras.layers.Layer):
hidden_states_to_add = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states
all_hidden_states = all_hidden_states + (hidden_states_to_add,)
layer_outputs = layer_module([hidden_states, attention_mask, output_attentions], training=training)
layer_outputs = layer_module(
[
hidden_states,
attention_mask,
is_index_masked,
is_index_global_attn,
is_global_attn,
output_attentions,
],
training=training,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
all_attentions = all_attentions + (tf.transpose(layer_outputs[1], (0, 2, 1, 3)),)
# Add last layer
if output_hidden_states:
......@@ -875,7 +955,9 @@ class TFLongformerEncoder(tf.keras.layers.Layer):
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
return TFBaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_attentions,
)
......@@ -982,7 +1064,14 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
if global_attention_mask is not None:
attention_mask = self._merge_to_attention_mask(attention_mask, global_attention_mask)
padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds = self._pad_to_window_size(
(
padding_len,
input_ids,
attention_mask,
token_type_ids,
position_ids,
inputs_embeds,
) = self._pad_to_window_size(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
......@@ -991,27 +1080,32 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
pad_token_id=self.pad_token_id,
)
# is index masked or global attention
is_index_masked = tf.math.less(attention_mask, 1)
is_index_global_attn = tf.math.greater(attention_mask, 1)
is_global_attn = tf.math.reduce_any(is_index_global_attn)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# Sizes are [batch_size, to_seq_length, 1, 1]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
extended_attention_mask = attention_mask[:, :, tf.newaxis, tf.newaxis]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# Since attention_mask is 1.0 for positions we want to locall attend locally and 0.0 for
# masked and global attn positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = tf.cast(extended_attention_mask, tf.float32)
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
extended_attention_mask = tf.cast(tf.math.abs(1 - extended_attention_mask), tf.dtypes.float32) * -10000.0
embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
padding_len=padding_len,
is_index_masked=is_index_masked,
is_index_global_attn=is_index_global_attn,
is_global_attn=is_global_attn,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
......@@ -1081,7 +1175,14 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
) # no attention on the padding tokens
token_type_ids = tf.pad(token_type_ids, paddings, constant_values=0) # pad with token_type_id = 0
return padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds
return (
padding_len,
input_ids,
attention_mask,
token_type_ids,
position_ids,
inputs_embeds,
)
@staticmethod
def _merge_to_attention_mask(attention_mask: tf.Tensor, global_attention_mask: tf.Tensor):
......@@ -1231,7 +1332,10 @@ class TFLongformerModel(TFLongformerPreTrainedModel):
return outputs
@add_start_docstrings("""Longformer Model with a `language modeling` head on top. """, LONGFORMER_START_DOCSTRING)
@add_start_docstrings(
"""Longformer Model with a `language modeling` head on top. """,
LONGFORMER_START_DOCSTRING,
)
class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModelingLoss):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
......@@ -1320,7 +1424,9 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
self.longformer = TFLongformerMainLayer(config, name="longformer")
self.qa_outputs = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
config.num_labels,
kernel_initializer=get_initializer(config.initializer_range),
name="qa_outputs",
)
@add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
......
......@@ -110,15 +110,6 @@ class TFModelTesterMixin:
def test_initialization(self):
pass
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# configs_no_init = _config_zero_init(config)
# for model_class in self.all_model_classes:
# model = model_class(config=configs_no_init)
# for name, param in model.named_parameters():
# if param.requires_grad:
# self.assertIn(param.data.mean().item(), [0.0, 1.0],
# msg="Parameter {} of model {} seems not properly initialized".format(name, model_class))
def test_save_load(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......@@ -134,6 +125,19 @@ class TFModelTesterMixin:
self.assert_outputs_same(after_outputs, outputs)
def test_graph_mode(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
inputs = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config)
@tf.function
def run_in_graph_mode():
return model(inputs)
outputs = run_in_graph_mode()
self.assertIsNotNone(outputs)
@slow
def test_saved_model_with_hidden_states_output(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......
......@@ -385,15 +385,16 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
self.assertTrue(shape_list(hidden_states), [1, 8, 4])
# pad along seq length dim
paddings = tf.constant([[0, 0], [0, 1], [0, 0]], dtype=tf.dtypes.int32)
paddings = tf.constant([[0, 0], [0, 0], [0, 1], [0, 0]], dtype=tf.dtypes.int32)
hidden_states = TFLongformerSelfAttention._chunk(hidden_states, window_overlap=2)
padded_hidden_states = TFLongformerSelfAttention._pad_and_transpose_last_two_dims(hidden_states, paddings)
self.assertTrue(shape_list(padded_hidden_states) == [1, 8, 5])
self.assertTrue(shape_list(padded_hidden_states) == [1, 1, 8, 5])
expected_added_dim = tf.zeros((5,), dtype=tf.dtypes.float32)
tf.debugging.assert_near(expected_added_dim, padded_hidden_states[0, -1, :], rtol=1e-6)
tf.debugging.assert_near(expected_added_dim, padded_hidden_states[0, 0, -1, :], rtol=1e-6)
tf.debugging.assert_near(
hidden_states[0, -1, :], tf.reshape(padded_hidden_states, (1, -1))[0, 24:32], rtol=1e-6
hidden_states[0, 0, -1, :], tf.reshape(padded_hidden_states, (1, -1))[0, 24:32], rtol=1e-6
)
def test_mask_invalid_locations(self):
......@@ -437,10 +438,16 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
hidden_states = self._get_hidden_states()
batch_size, seq_length, hidden_size = hidden_states.shape
attention_mask = tf.zeros((batch_size, 1, 1, seq_length), dtype=tf.dtypes.float32)
attention_mask = tf.where(tf.range(4)[None, None, None, :] > 1, -10000.0, attention_mask)
attention_mask = tf.zeros((batch_size, seq_length), dtype=tf.dtypes.float32)
is_index_global_attn = tf.math.greater(attention_mask, 1)
is_global_attn = tf.math.reduce_any(is_index_global_attn)
attention_mask = tf.where(tf.range(4)[None, :, None, None] > 1, -10000.0, attention_mask[:, :, None, None])
is_index_masked = tf.math.less(attention_mask[:, :, 0, 0], 0)
output_hidden_states = layer([hidden_states, attention_mask, None])[0]
output_hidden_states = layer(
[hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn, None]
)[0]
expected_slice = tf.convert_to_tensor(
[0.00188, 0.012196, -0.017051, -0.025571, -0.02996, 0.017297, -0.011521, 0.004848], dtype=tf.dtypes.float32
......@@ -461,12 +468,18 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
attention_mask_1 = tf.zeros((1, 1, 1, seq_length), dtype=tf.dtypes.float32)
attention_mask_2 = tf.zeros((1, 1, 1, seq_length), dtype=tf.dtypes.float32)
attention_mask_1 = tf.where(tf.range(4)[None, None, None, :] > 1, 10000.0, attention_mask_1)
attention_mask_1 = tf.where(tf.range(4)[None, None, None, :] > 2, -10000.0, attention_mask_1)
attention_mask_2 = tf.where(tf.range(4)[None, None, None, :] > 0, 10000.0, attention_mask_2)
attention_mask_1 = tf.where(tf.range(4)[None, :, None, None] > 1, 10000.0, attention_mask_1)
attention_mask_1 = tf.where(tf.range(4)[None, :, None, None] > 2, -10000.0, attention_mask_1)
attention_mask_2 = tf.where(tf.range(4)[None, :, None, None] > 0, 10000.0, attention_mask_2)
attention_mask = tf.concat([attention_mask_1, attention_mask_2], axis=0)
output_hidden_states = layer([hidden_states, attention_mask, None])[0]
is_index_masked = tf.math.less(attention_mask[:, :, 0, 0], 0)
is_index_global_attn = tf.math.greater(attention_mask[:, :, 0, 0], 0)
is_global_attn = tf.math.reduce_any(is_index_global_attn)
output_hidden_states = layer(
[hidden_states, -tf.math.abs(attention_mask), is_index_masked, is_index_global_attn, is_global_attn, None]
)[0]
self.assertTrue(output_hidden_states.shape, (2, 4, 8))
expected_slice_0 = tf.convert_to_tensor(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment