Unverified commit 858b7d58 authored by Patrick von Platen, committed by GitHub

[TF Longformer] Improve Speed for TF Longformer (#6447)

* add tf graph compile tests

* fix conflict

* remove more tf transpose statements

* fix conflicts

* fix comment typos

* move function to class function

* fix black

* fix black

* make style
parent a75c64d8
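The bulk of the change keeps hidden states in a batch-first layout so that most tf.transpose calls in the TF Longformer self-attention can be replaced by plain tf.reshape. A rough sketch of that pattern (shapes and values below are illustrative only, not taken from the diff):

import tensorflow as tf

batch_size, seq_len, num_heads, head_dim = 1, 8, 2, 4
hidden_states = tf.random.normal((batch_size, seq_len, num_heads * head_dim))

# Old pattern: a seq-first layout forces a transpose before and after the
# reshape that splits the hidden dimension into attention heads.
seq_first = tf.transpose(hidden_states, (1, 0, 2))
old_heads = tf.transpose(
    tf.reshape(seq_first, (seq_len, batch_size, num_heads, head_dim)), (1, 0, 2, 3)
)

# New pattern: stay batch-first and reshape directly.
new_heads = tf.reshape(hidden_states, (batch_size, seq_len, num_heads, head_dim))

tf.debugging.assert_near(old_heads, new_heads)  # same values, fewer ops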
@@ -1088,7 +1088,7 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific
             config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
         )

-    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
         checkpoint="bert-base-cased",
...
@@ -677,7 +677,7 @@ class TFElectraForSequenceClassification(TFElectraPreTrainedModel, TFSequenceCla
         self.electra = TFElectraMainLayer(config, name="electra")
         self.classifier = TFElectraClassificationHead(config, name="classifier")

-    @add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
+    @add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
         checkpoint="google/electra-small-discriminator",
...
@@ -97,24 +97,36 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         self.embed_dim = config.hidden_size
         self.query = tf.keras.layers.Dense(
-            self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="query"
+            self.embed_dim,
+            kernel_initializer=get_initializer(config.initializer_range),
+            name="query",
         )
         self.key = tf.keras.layers.Dense(
-            self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="key"
+            self.embed_dim,
+            kernel_initializer=get_initializer(config.initializer_range),
+            name="key",
         )
         self.value = tf.keras.layers.Dense(
-            self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="value"
+            self.embed_dim,
+            kernel_initializer=get_initializer(config.initializer_range),
+            name="value",
         )

         # separate projection layers for tokens with global attention
         self.query_global = tf.keras.layers.Dense(
-            self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="query_global"
+            self.embed_dim,
+            kernel_initializer=get_initializer(config.initializer_range),
+            name="query_global",
         )
         self.key_global = tf.keras.layers.Dense(
-            self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="key_global"
+            self.embed_dim,
+            kernel_initializer=get_initializer(config.initializer_range),
+            name="key_global",
         )
         self.value_global = tf.keras.layers.Dense(
-            self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="value_global"
+            self.embed_dim,
+            kernel_initializer=get_initializer(config.initializer_range),
+            name="value_global",
         )
         self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
@@ -148,23 +160,21 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         """
         # retrieve input args
-        hidden_states, attention_mask, output_attentions = inputs
-
-        attention_mask = tf.squeeze(tf.squeeze(attention_mask, axis=2), axis=1)
-
-        # is index masked or global attention
-        is_index_masked = tf.math.less(attention_mask, 0)
-        is_index_global_attn = tf.math.greater(attention_mask, 0)
-        is_global_attn = tf.math.reduce_any(is_index_global_attn)
-
-        hidden_states = tf.transpose(hidden_states, (1, 0, 2))
+        (
+            hidden_states,
+            attention_mask,
+            is_index_masked,
+            is_index_global_attn,
+            is_global_attn,
+            output_attentions,
+        ) = inputs

         # project hidden states
         query_vectors = self.query(hidden_states)
         key_vectors = self.key(hidden_states)
         value_vectors = self.value(hidden_states)

-        seq_len, batch_size, embed_dim = shape_list(hidden_states)
+        batch_size, seq_len, embed_dim = shape_list(hidden_states)

         tf.debugging.assert_equal(
             embed_dim,
             self.embed_dim,
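For reference, a minimal sketch (with made-up values) of how the three boolean inputs unpacked above are computed once in TFLongformerMainLayer further down, instead of inside every attention layer:

import tensorflow as tf

# Merged mask convention used in TFLongformerMainLayer:
# 0 = padding, 1 = local attention, 2 = global attention (example values).
attention_mask = tf.constant([[2, 1, 1, 0]])

is_index_masked = tf.math.less(attention_mask, 1)          # padding tokens
is_index_global_attn = tf.math.greater(attention_mask, 1)  # global tokens
is_global_attn = tf.math.reduce_any(is_index_global_attn)  # any global token at all?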
@@ -174,24 +184,19 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         # normalize query
         query_vectors /= tf.math.sqrt(tf.constant(self.head_dim, dtype=tf.dtypes.float32))

-        query_vectors = tf.transpose(
-            tf.reshape(query_vectors, (seq_len, batch_size, self.num_heads, self.head_dim)), (1, 0, 2, 3)
-        )
-        key_vectors = tf.transpose(
-            tf.reshape(key_vectors, (seq_len, batch_size, self.num_heads, self.head_dim)), (1, 0, 2, 3)
-        )
+        query_vectors = tf.reshape(query_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))
+        key_vectors = tf.reshape(key_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))

         # attn_probs = (batch_size, seq_len, num_heads, window*2+1)
         attn_scores = self._sliding_chunks_query_key_matmul(
             query_vectors, key_vectors, self.one_sided_attn_window_size
         )

-        # values to pad for attention probs
-        float_mask = tf.cast((attention_mask != 0)[:, :, None, None], dtype=tf.float32) * -10000.0
-
         # diagonal mask with zeros everywhere and -inf inplace of padding
         diagonal_mask = self._sliding_chunks_query_key_matmul(
-            tf.ones(shape_list(float_mask), dtype=tf.float32), float_mask, self.one_sided_attn_window_size
+            tf.ones(shape_list(attention_mask), dtype=tf.float32),
+            attention_mask,
+            self.one_sided_attn_window_size,
         )

         # pad local attention probs
@@ -231,15 +236,15 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         # softmax sometimes inserts NaN if all positions are masked, replace them with 0
         attn_probs = tf.where(
-            tf.broadcast_to(is_index_masked[:, :, None, None], shape_list(attn_probs)), 0.0, attn_probs
+            tf.broadcast_to(is_index_masked[:, :, None, None], shape_list(attn_probs)),
+            0.0,
+            attn_probs,
         )

         # apply dropout
         attn_probs = self.dropout(attn_probs, training=training)

-        value_vectors = tf.transpose(
-            tf.reshape(value_vectors, (seq_len, batch_size, self.num_heads, self.head_dim)), (1, 0, 2, 3)
-        )
+        value_vectors = tf.reshape(value_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))

         # if global attention, compute sum of global and local attn
         attn_output = tf.cond(
@@ -257,9 +262,11 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         )

         tf.debugging.assert_equal(
-            shape_list(attn_output), [batch_size, seq_len, self.num_heads, self.head_dim], message="Unexpected size"
+            shape_list(attn_output),
+            [batch_size, seq_len, self.num_heads, self.head_dim],
+            message="Unexpected size",
         )
-        attn_output = tf.reshape(tf.transpose(attn_output, (1, 0, 2, 3)), (seq_len, batch_size, embed_dim))
+        attn_output = tf.reshape(attn_output, (batch_size, seq_len, embed_dim))

         # compute value for global attention and overwrite to attention output
         # TODO: remove the redundant computation
@@ -278,8 +285,6 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
             lambda: attn_output,
         )

-        attn_output = tf.transpose(attn_output, (1, 0, 2))
-
         # GLOBAL ATTN:
         # With global attention, return global attention probabilities only
         # batch_size x num_heads x max_num_global_attention_tokens x sequence_length
@@ -294,7 +299,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         attn_probs = tf.cond(
             is_global_attn,
             lambda: self._get_global_attn_probs(attn_probs, max_num_global_attn_indices),
-            lambda: tf.transpose(attn_probs, (0, 2, 1, 3)),
+            lambda: attn_probs,
         )

         outputs = (attn_output, attn_probs)
@@ -310,7 +315,6 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
             ],
             axis=-1,
         )
-        attn_probs = tf.transpose(attn_probs, (0, 2, 1, 3))
         return attn_probs

     def _sliding_chunks_query_key_matmul(self, query, key, window_overlap):
@@ -332,7 +336,10 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         chunks_count = seq_len // window_overlap - 1

         # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size window_overlap * 2
-        query = tf.reshape(tf.transpose(query, (0, 2, 1, 3)), (batch_size * num_heads, seq_len, head_dim))
+        query = tf.reshape(
+            tf.transpose(query, (0, 2, 1, 3)),
+            (batch_size * num_heads, seq_len, head_dim),
+        )
         key = tf.reshape(tf.transpose(key, (0, 2, 1, 3)), (batch_size * num_heads, seq_len, head_dim))
         chunked_query = self._chunk(query, window_overlap)
@@ -374,9 +381,11 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         )
         diagonal_attn_scores_first_chunk = tf.concat(
             [
-                tf.roll(diagonal_chunked_attention_scores, shift=[1, window_overlap], axis=[2, 3])[
-                    :, :, :window_overlap, :window_overlap
-                ],
+                tf.roll(
+                    diagonal_chunked_attention_scores,
+                    shift=[1, window_overlap],
+                    axis=[2, 3],
+                )[:, :, :window_overlap, :window_overlap],
                 tf.zeros((batch_size * num_heads, 1, window_overlap, window_overlap)),
             ],
             axis=1,
@@ -385,13 +394,20 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         first_chunk_mask = (
             tf.broadcast_to(
                 tf.range(chunks_count + 1)[None, :, None, None],
-                shape=(batch_size * num_heads, chunks_count + 1, window_overlap, window_overlap),
+                shape=(
+                    batch_size * num_heads,
+                    chunks_count + 1,
+                    window_overlap,
+                    window_overlap,
+                ),
             )
             < 1
         )
         diagonal_attn_scores_low_triang = tf.where(
-            first_chunk_mask, diagonal_attn_scores_first_chunk, diagonal_attn_scores_low_triang
+            first_chunk_mask,
+            diagonal_attn_scores_first_chunk,
+            diagonal_attn_scores_low_triang,
         )

         # merging upper and lower triangle
@@ -401,7 +417,10 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         # separate batch_size and num_heads dimensions again
         diagonal_attention_scores = tf.transpose(
-            tf.reshape(diagonal_attention_scores, (batch_size, num_heads, seq_len, 2 * window_overlap + 1)),
+            tf.reshape(
+                diagonal_attention_scores,
+                (batch_size, num_heads, seq_len, 2 * window_overlap + 1),
+            ),
             (0, 2, 1, 3),
         )
@@ -412,7 +431,8 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
     def _mask_invalid_locations(input_tensor, window_overlap):
         # create correct upper triangle bool mask
         mask_2d_upper = tf.reverse(
-            tf.linalg.band_part(tf.ones(shape=(window_overlap, window_overlap + 1)), -1, 0), axis=[0]
+            tf.linalg.band_part(tf.ones(shape=(window_overlap, window_overlap + 1)), -1, 0),
+            axis=[0],
         )
         # pad to full matrix
         padding = tf.constant(
@@ -443,7 +463,9 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         batch_size, seq_len, num_heads, head_dim = shape_list(value)

         tf.debugging.assert_equal(
-            seq_len % (window_overlap * 2), 0, message="Seq_len has to be multiple of 2 * window_overlap"
+            seq_len % (window_overlap * 2),
+            0,
+            message="Seq_len has to be multiple of 2 * window_overlap",
         )
         tf.debugging.assert_equal(
             shape_list(attn_probs)[:3],
@@ -461,11 +483,19 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         chunked_attn_probs = tf.reshape(
             tf.transpose(attn_probs, (0, 2, 1, 3)),
-            (batch_size * num_heads, seq_len // window_overlap, window_overlap, 2 * window_overlap + 1),
+            (
+                batch_size * num_heads,
+                seq_len // window_overlap,
+                window_overlap,
+                2 * window_overlap + 1,
+            ),
         )

         # group batch_size and num_heads dimensions into one
-        value = tf.reshape(tf.transpose(value, (0, 2, 1, 3)), (batch_size * num_heads, seq_len, head_dim))
+        value = tf.reshape(
+            tf.transpose(value, (0, 2, 1, 3)),
+            (batch_size * num_heads, seq_len, head_dim),
+        )

         # pad seq_len with w at the beginning of the sequence and another window overlap at the end
@@ -478,10 +508,13 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         frame_hop_size = (shape_list(padded_value)[1] * head_dim - frame_size) // chunks_count
         chunked_value = tf.signal.frame(
-            tf.reshape(padded_value, (batch_size * num_heads, -1)), frame_size, frame_hop_size
+            tf.reshape(padded_value, (batch_size * num_heads, -1)),
+            frame_size,
+            frame_hop_size,
         )
         chunked_value = tf.reshape(
-            chunked_value, (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim)
+            chunked_value,
+            (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim),
         )

         tf.debugging.assert_equal(
@@ -493,7 +526,10 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs)

         context = tf.einsum("bcwd,bcdh->bcwh", chunked_attn_probs, chunked_value)
-        context = tf.transpose(tf.reshape(context, (batch_size, num_heads, seq_len, head_dim)), (0, 2, 1, 3))
+        context = tf.transpose(
+            tf.reshape(context, (batch_size, num_heads, seq_len, head_dim)),
+            (0, 2, 1, 3),
+        )
         return context

     @staticmethod
@@ -502,12 +538,10 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         hidden_states_padded = tf.pad(
             hidden_states_padded, paddings
         )  # padding value is not important because it will be overwritten
-        if len(shape_list(hidden_states_padded)) > 3:
-            batch_size, chunk_size, seq_length, hidden_dim = shape_list(hidden_states_padded)
-            hidden_states_padded = tf.reshape(hidden_states_padded, (batch_size, chunk_size, hidden_dim, seq_length))
-        else:
-            batch_size, seq_length, hidden_dim = shape_list(hidden_states_padded)
-            hidden_states_padded = tf.reshape(hidden_states_padded, (batch_size, hidden_dim, seq_length))
+        batch_size, chunk_size, seq_length, hidden_dim = shape_list(hidden_states_padded)
+        hidden_states_padded = tf.reshape(hidden_states_padded, (batch_size, chunk_size, hidden_dim, seq_length))
         return hidden_states_padded

     @staticmethod
@@ -539,7 +573,8 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
             :, :, :-window_overlap
         ]  # total_num_heads x num_chunks x window_overlapL+window_overlapwindow_overlap
         chunked_hidden_states = tf.reshape(
-            chunked_hidden_states, (total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim)
+            chunked_hidden_states,
+            (total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim),
         )  # total_num_heads x num_chunks, window_overlap x hidden_dim+window_overlap
         chunked_hidden_states = chunked_hidden_states[:, :, :, :-1]
         return chunked_hidden_states
@@ -566,7 +601,8 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         )
         chunked_hidden_states = tf.reshape(
-            chunked_hidden_states, (batch_size, num_output_chunks, 2 * window_overlap, hidden_dim)
+            chunked_hidden_states,
+            (batch_size, num_output_chunks, 2 * window_overlap, hidden_dim),
         )
         return chunked_hidden_states
@@ -619,7 +655,12 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         key_vectors_only_global = tf.scatter_nd(
             is_local_index_global_attn_nonzero,
             global_key_vectors,
-            shape=(batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim),
+            shape=(
+                batch_size,
+                max_num_global_attn_indices,
+                self.num_heads,
+                self.head_dim,
+            ),
         )

         # (batch_size, seq_len, num_heads, max_num_global_attn_indices)
@@ -633,7 +674,9 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         # scatter mask
         attn_probs_from_global_key_trans = tf.tensor_scatter_nd_update(
-            attn_probs_from_global_key_trans, is_local_index_no_global_attn_nonzero, mask
+            attn_probs_from_global_key_trans,
+            is_local_index_no_global_attn_nonzero,
+            mask,
         )

         # (batch_size, seq_len, num_heads, max_num_global_attn_indices)
@@ -664,7 +707,12 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         value_vectors_only_global = tf.scatter_nd(
             is_local_index_global_attn_nonzero,
             global_value_vectors,
-            shape=(batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim),
+            shape=(
+                batch_size,
+                max_num_global_attn_indices,
+                self.num_heads,
+                self.head_dim,
+            ),
         )

         # compute attn output only global
@@ -690,14 +738,14 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         is_index_masked,
         training,
     ):
-        seq_len, batch_size = shape_list(hidden_states)[:2]
+        batch_size, seq_len = shape_list(hidden_states)[:2]

         # prepare global hidden states
-        global_attn_hidden_states = tf.gather_nd(hidden_states, tf.reverse(is_index_global_attn_nonzero, axis=[1]))
+        global_attn_hidden_states = tf.gather_nd(hidden_states, is_index_global_attn_nonzero)
         global_attn_hidden_states = tf.scatter_nd(
-            tf.reverse(is_local_index_global_attn_nonzero, axis=[1]),
+            is_local_index_global_attn_nonzero,
             global_attn_hidden_states,
-            shape=(max_num_global_attn_indices, batch_size, self.embed_dim),
+            shape=(batch_size, max_num_global_attn_indices, self.embed_dim),
         )

         # global key, query, value
@@ -708,27 +756,12 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         # normalize
         global_query_vectors_only_global /= tf.math.sqrt(tf.constant(self.head_dim, dtype=tf.dtypes.float32))

-        # (batch_size * self.num_heads, max_num_global_attn_indices, head_dim)
-        global_query_vectors_only_global = tf.transpose(
-            tf.reshape(
-                global_query_vectors_only_global,
-                (max_num_global_attn_indices, batch_size * self.num_heads, self.head_dim),
-            ),
-            (1, 0, 2),
-        )
-
-        # (..., batch_size * self.num_heads, seq_len, head_dim)
-        global_key_vectors = tf.transpose(
-            tf.reshape(global_key_vectors, (-1, batch_size * self.num_heads, self.head_dim)), (1, 0, 2)
-        )
-
-        # (..., batch_size * self.num_heads, seq_len, head_dim)
-        global_value_vectors = tf.transpose(
-            tf.reshape(global_value_vectors, (-1, batch_size * self.num_heads, self.head_dim)), (1, 0, 2)
-        )
+        global_query_vectors_only_global = self.reshape_and_transpose(global_query_vectors_only_global, batch_size)
+        global_key_vectors = self.reshape_and_transpose(global_key_vectors, batch_size)
+        global_value_vectors = self.reshape_and_transpose(global_value_vectors, batch_size)

         # compute attn scores
-        global_attn_scores = tf.matmul(global_query_vectors_only_global, tf.transpose(global_key_vectors, (0, 2, 1)))
+        global_attn_scores = tf.matmul(global_query_vectors_only_global, global_key_vectors, transpose_b=True)

         tf.debugging.assert_equal(
             shape_list(global_attn_scores),
@@ -737,7 +770,8 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         )

         global_attn_scores = tf.reshape(
-            global_attn_scores, (batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
+            global_attn_scores,
+            (batch_size, self.num_heads, max_num_global_attn_indices, seq_len),
         )
         global_attn_scores_trans = tf.transpose(global_attn_scores, (0, 2, 1, 3))
@@ -748,7 +782,9 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         # scatter mask
         global_attn_scores_trans = tf.tensor_scatter_nd_update(
-            global_attn_scores_trans, is_local_index_no_global_attn_nonzero, global_attn_mask
+            global_attn_scores_trans,
+            is_local_index_no_global_attn_nonzero,
+            global_attn_mask,
         )
         global_attn_scores = tf.transpose(global_attn_scores_trans, (0, 2, 1, 3))
@@ -757,7 +793,8 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         global_attn_scores = tf.where(attn_mask, -10000.0, global_attn_scores)

         global_attn_scores = tf.reshape(
-            global_attn_scores, (batch_size * self.num_heads, max_num_global_attn_indices, seq_len)
+            global_attn_scores,
+            (batch_size * self.num_heads, max_num_global_attn_indices, seq_len),
         )

         # compute global attn probs
@@ -776,12 +813,14 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         )

         global_attn_output = tf.reshape(
-            global_attn_output, (batch_size, self.num_heads, max_num_global_attn_indices, self.head_dim)
+            global_attn_output,
+            (batch_size, self.num_heads, max_num_global_attn_indices, self.head_dim),
         )

         # get only non zero global attn output
         nonzero_global_attn_output = tf.gather_nd(
-            tf.transpose(global_attn_output, (0, 2, 1, 3)), is_local_index_global_attn_nonzero
+            tf.transpose(global_attn_output, (0, 2, 1, 3)),
+            is_local_index_global_attn_nonzero,
         )
         nonzero_global_attn_output = tf.reshape(
             nonzero_global_attn_output,
@@ -789,12 +828,21 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         )

         # overwrite values with global attention
         attn_output = tf.tensor_scatter_nd_update(
-            attn_output, tf.reverse(is_index_global_attn_nonzero, axis=[1]), nonzero_global_attn_output
+            attn_output, is_index_global_attn_nonzero, nonzero_global_attn_output
         )

         return attn_output

+    def reshape_and_transpose(self, vector, batch_size):
+        return tf.reshape(
+            tf.transpose(
+                tf.reshape(vector, (batch_size, -1, self.num_heads, self.head_dim)),
+                (0, 2, 1, 3),
+            ),
+            (batch_size * self.num_heads, -1, self.head_dim),
+        )
+

 class TFLongformerAttention(tf.keras.layers.Layer):
     def __init__(self, config, layer_id=0, **kwargs):
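The new reshape_and_transpose helper above replaces three hand-rolled transpose/reshape blocks for the global query, key and value projections. A small shape walk-through of what it does (the concrete numbers are illustrative only):

import tensorflow as tf

batch_size, seq_len, num_heads, head_dim = 2, 6, 4, 8
vector = tf.random.normal((batch_size, seq_len, num_heads * head_dim))

# Split the hidden dimension into heads, then fold the head dimension into
# the batch dimension, staying batch-first the whole time.
out = tf.reshape(
    tf.transpose(
        tf.reshape(vector, (batch_size, -1, num_heads, head_dim)),
        (0, 2, 1, 3),
    ),
    (batch_size * num_heads, -1, head_dim),
)
print(out.shape)  # (8, 6, 8) == (batch_size * num_heads, seq_len, head_dim)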
@@ -806,10 +854,20 @@ class TFLongformerAttention(tf.keras.layers.Layer):
         raise NotImplementedError

     def call(self, inputs, training=False):
-        input_tensor, attention_mask, output_attentions = inputs
+        (
+            hidden_states,
+            attention_mask,
+            is_index_masked,
+            is_index_global_attn,
+            is_global_attn,
+            output_attentions,
+        ) = inputs

-        self_outputs = self.self_attention([input_tensor, attention_mask, output_attentions], training=training)
-        attention_output = self.dense_output(self_outputs[0], input_tensor, training=training)
+        self_outputs = self.self_attention(
+            [hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn, output_attentions],
+            training=training,
+        )
+        attention_output = self.dense_output(self_outputs[0], hidden_states, training=training)
         outputs = (attention_output,) + self_outputs[1:]
         return outputs
@@ -823,9 +881,19 @@ class TFLongformerLayer(tf.keras.layers.Layer):
         self.longformer_output = TFBertOutput(config, name="output")

     def call(self, inputs, training=False):
-        hidden_states, attention_mask, output_attentions = inputs
+        (
+            hidden_states,
+            attention_mask,
+            is_index_masked,
+            is_index_global_attn,
+            is_global_attn,
+            output_attentions,
+        ) = inputs

-        attention_outputs = self.attention([hidden_states, attention_mask, output_attentions], training=training)
+        attention_outputs = self.attention(
+            [hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn, output_attentions],
+            training=training,
+        )
         attention_output = attention_outputs[0]
         intermediate_output = self.intermediate(attention_output)
         layer_output = self.longformer_output(intermediate_output, attention_output, training=training)
@@ -848,12 +916,14 @@ class TFLongformerEncoder(tf.keras.layers.Layer):
         attention_mask=None,
         head_mask=None,
         padding_len=0,
+        is_index_masked=None,
+        is_index_global_attn=None,
+        is_global_attn=None,
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
         training=False,
     ):
         all_hidden_states = () if output_hidden_states else None
         all_attentions = () if output_attentions else None

         for i, layer_module in enumerate(self.layer):
@@ -861,11 +931,21 @@ class TFLongformerEncoder(tf.keras.layers.Layer):
                 hidden_states_to_add = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states
                 all_hidden_states = all_hidden_states + (hidden_states_to_add,)

-            layer_outputs = layer_module([hidden_states, attention_mask, output_attentions], training=training)
+            layer_outputs = layer_module(
+                [
+                    hidden_states,
+                    attention_mask,
+                    is_index_masked,
+                    is_index_global_attn,
+                    is_global_attn,
+                    output_attentions,
+                ],
+                training=training,
+            )
             hidden_states = layer_outputs[0]

             if output_attentions:
-                all_attentions = all_attentions + (layer_outputs[1],)
+                all_attentions = all_attentions + (tf.transpose(layer_outputs[1], (0, 2, 1, 3)),)

         # Add last layer
         if output_hidden_states:
@@ -875,7 +955,9 @@ class TFLongformerEncoder(tf.keras.layers.Layer):
         if not return_dict:
             return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
         return TFBaseModelOutput(
-            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
+            last_hidden_state=hidden_states,
+            hidden_states=all_hidden_states,
+            attentions=all_attentions,
         )
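Illustrative shapes for the purely local-attention case: each layer now returns its attention probabilities with the heads dimension on axis 2, and the encoder applies the single transpose above to restore the usual heads-second layout (values below are mine, not from the diff):

import tensorflow as tf

batch_size, seq_len, num_heads, window = 1, 8, 2, 2
layer_attn = tf.zeros((batch_size, seq_len, num_heads, 2 * window + 1))

reported_attn = tf.transpose(layer_attn, (0, 2, 1, 3))
print(reported_attn.shape)  # (1, 2, 8, 5) == (batch, num_heads, seq_len, window*2+1)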
@@ -982,7 +1064,14 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
         if global_attention_mask is not None:
             attention_mask = self._merge_to_attention_mask(attention_mask, global_attention_mask)

-        padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds = self._pad_to_window_size(
+        (
+            padding_len,
+            input_ids,
+            attention_mask,
+            token_type_ids,
+            position_ids,
+            inputs_embeds,
+        ) = self._pad_to_window_size(
             input_ids=input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -991,27 +1080,32 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
             pad_token_id=self.pad_token_id,
         )

+        # is index masked or global attention
+        is_index_masked = tf.math.less(attention_mask, 1)
+        is_index_global_attn = tf.math.greater(attention_mask, 1)
+        is_global_attn = tf.math.reduce_any(is_index_global_attn)
+
         # We create a 3D attention mask from a 2D tensor mask.
-        # Sizes are [batch_size, 1, 1, to_seq_length]
+        # Sizes are [batch_size, to_seq_length, 1, 1]
         # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
         # this attention mask is more simple than the triangular masking of causal attention
         # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
-        extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
+        extended_attention_mask = attention_mask[:, :, tf.newaxis, tf.newaxis]

-        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
-        # masked positions, this operation will create a tensor which is 0.0 for
+        # Since attention_mask is 1.0 for positions we want to attend locally and 0.0 for
+        # masked and global attn positions, this operation will create a tensor which is 0.0 for
         # positions we want to attend and -10000.0 for masked positions.
         # Since we are adding it to the raw scores before the softmax, this is
         # effectively the same as removing these entirely.
-        extended_attention_mask = tf.cast(tf.math.abs(1 - extended_attention_mask), tf.dtypes.float32) * -10000.0
+        extended_attention_mask = tf.cast(extended_attention_mask, tf.float32)
+        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

         embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
         encoder_outputs = self.encoder(
             embedding_output,
             attention_mask=extended_attention_mask,
             padding_len=padding_len,
+            is_index_masked=is_index_masked,
+            is_index_global_attn=is_index_global_attn,
+            is_global_attn=is_global_attn,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
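A minimal sketch of the mask convention used above, with example values of my own rather than the PR's: after _merge_to_attention_mask, 0 marks padding, 1 local attention and 2 global attention, and the additive extended mask now keeps the batch-first (batch_size, seq_len, 1, 1) layout so no later transpose is needed:

import tensorflow as tf

attention_mask = tf.constant([[2, 1, 1, 0]], dtype=tf.float32)  # global, local, local, padding

extended_attention_mask = attention_mask[:, :, tf.newaxis, tf.newaxis]
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
print(tf.squeeze(extended_attention_mask).numpy())  # [ 10000.  0.  0. -10000.]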
@@ -1081,7 +1175,14 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
         )  # no attention on the padding tokens
         token_type_ids = tf.pad(token_type_ids, paddings, constant_values=0)  # pad with token_type_id = 0

-        return padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds
+        return (
+            padding_len,
+            input_ids,
+            attention_mask,
+            token_type_ids,
+            position_ids,
+            inputs_embeds,
+        )

     @staticmethod
     def _merge_to_attention_mask(attention_mask: tf.Tensor, global_attention_mask: tf.Tensor):
return outputs return outputs
@add_start_docstrings("""Longformer Model with a `language modeling` head on top. """, LONGFORMER_START_DOCSTRING) @add_start_docstrings(
"""Longformer Model with a `language modeling` head on top. """,
LONGFORMER_START_DOCSTRING,
)
class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModelingLoss): class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModelingLoss):
def __init__(self, config, *inputs, **kwargs): def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs) super().__init__(config, *inputs, **kwargs)
@@ -1320,7 +1424,9 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
         self.longformer = TFLongformerMainLayer(config, name="longformer")
         self.qa_outputs = tf.keras.layers.Dense(
-            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
+            config.num_labels,
+            kernel_initializer=get_initializer(config.initializer_range),
+            name="qa_outputs",
         )

     @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
...
@@ -110,15 +110,6 @@ class TFModelTesterMixin:
     def test_initialization(self):
         pass

-        # config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        # configs_no_init = _config_zero_init(config)
-        # for model_class in self.all_model_classes:
-        #     model = model_class(config=configs_no_init)
-        #     for name, param in model.named_parameters():
-        #         if param.requires_grad:
-        #             self.assertIn(param.data.mean().item(), [0.0, 1.0],
-        #             msg="Parameter {} of model {} seems not properly initialized".format(name, model_class))
-
     def test_save_load(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -134,6 +125,19 @@ class TFModelTesterMixin:

             self.assert_outputs_same(after_outputs, outputs)

+    def test_graph_mode(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            inputs = self._prepare_for_class(inputs_dict, model_class)
+            model = model_class(config)
+
+            @tf.function
+            def run_in_graph_mode():
+                return model(inputs)
+
+            outputs = run_in_graph_mode()
+            self.assertIsNotNone(outputs)
+
     @slow
     def test_saved_model_with_hidden_states_output(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
...
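The new test_graph_mode test wraps each model call in tf.function, which is what the removed transposes and the static batch-first layout make possible. A hedged usage sketch of the same idea for the TF Longformer itself (the checkpoint name is given only as an example, not taken from the diff):

import tensorflow as tf
from transformers import LongformerTokenizer, TFLongformerModel

tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
model = TFLongformerModel.from_pretrained("allenai/longformer-base-4096")

@tf.function
def forward(inputs):
    # The whole forward pass must trace into a single static graph,
    # mirroring what test_graph_mode asserts for every TF model class.
    return model(inputs)

inputs = tokenizer("Hello world!", return_tensors="tf")
outputs = forward(dict(inputs))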
@@ -385,15 +385,16 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
         self.assertTrue(shape_list(hidden_states), [1, 8, 4])

         # pad along seq length dim
-        paddings = tf.constant([[0, 0], [0, 1], [0, 0]], dtype=tf.dtypes.int32)
+        paddings = tf.constant([[0, 0], [0, 0], [0, 1], [0, 0]], dtype=tf.dtypes.int32)

+        hidden_states = TFLongformerSelfAttention._chunk(hidden_states, window_overlap=2)
         padded_hidden_states = TFLongformerSelfAttention._pad_and_transpose_last_two_dims(hidden_states, paddings)
-        self.assertTrue(shape_list(padded_hidden_states) == [1, 8, 5])
+        self.assertTrue(shape_list(padded_hidden_states) == [1, 1, 8, 5])

         expected_added_dim = tf.zeros((5,), dtype=tf.dtypes.float32)
-        tf.debugging.assert_near(expected_added_dim, padded_hidden_states[0, -1, :], rtol=1e-6)
+        tf.debugging.assert_near(expected_added_dim, padded_hidden_states[0, 0, -1, :], rtol=1e-6)
         tf.debugging.assert_near(
-            hidden_states[0, -1, :], tf.reshape(padded_hidden_states, (1, -1))[0, 24:32], rtol=1e-6
+            hidden_states[0, 0, -1, :], tf.reshape(padded_hidden_states, (1, -1))[0, 24:32], rtol=1e-6
         )

     def test_mask_invalid_locations(self):
@@ -437,10 +438,16 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
         hidden_states = self._get_hidden_states()
         batch_size, seq_length, hidden_size = hidden_states.shape

-        attention_mask = tf.zeros((batch_size, 1, 1, seq_length), dtype=tf.dtypes.float32)
-        attention_mask = tf.where(tf.range(4)[None, None, None, :] > 1, -10000.0, attention_mask)
+        attention_mask = tf.zeros((batch_size, seq_length), dtype=tf.dtypes.float32)
+        is_index_global_attn = tf.math.greater(attention_mask, 1)
+        is_global_attn = tf.math.reduce_any(is_index_global_attn)

-        output_hidden_states = layer([hidden_states, attention_mask, None])[0]
+        attention_mask = tf.where(tf.range(4)[None, :, None, None] > 1, -10000.0, attention_mask[:, :, None, None])
+        is_index_masked = tf.math.less(attention_mask[:, :, 0, 0], 0)
+
+        output_hidden_states = layer(
+            [hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn, None]
+        )[0]

         expected_slice = tf.convert_to_tensor(
             [0.00188, 0.012196, -0.017051, -0.025571, -0.02996, 0.017297, -0.011521, 0.004848], dtype=tf.dtypes.float32
@@ -461,12 +468,18 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
         attention_mask_1 = tf.zeros((1, 1, 1, seq_length), dtype=tf.dtypes.float32)
         attention_mask_2 = tf.zeros((1, 1, 1, seq_length), dtype=tf.dtypes.float32)

-        attention_mask_1 = tf.where(tf.range(4)[None, None, None, :] > 1, 10000.0, attention_mask_1)
-        attention_mask_1 = tf.where(tf.range(4)[None, None, None, :] > 2, -10000.0, attention_mask_1)
-        attention_mask_2 = tf.where(tf.range(4)[None, None, None, :] > 0, 10000.0, attention_mask_2)
+        attention_mask_1 = tf.where(tf.range(4)[None, :, None, None] > 1, 10000.0, attention_mask_1)
+        attention_mask_1 = tf.where(tf.range(4)[None, :, None, None] > 2, -10000.0, attention_mask_1)
+        attention_mask_2 = tf.where(tf.range(4)[None, :, None, None] > 0, 10000.0, attention_mask_2)
         attention_mask = tf.concat([attention_mask_1, attention_mask_2], axis=0)

-        output_hidden_states = layer([hidden_states, attention_mask, None])[0]
+        is_index_masked = tf.math.less(attention_mask[:, :, 0, 0], 0)
+        is_index_global_attn = tf.math.greater(attention_mask[:, :, 0, 0], 0)
+        is_global_attn = tf.math.reduce_any(is_index_global_attn)
+
+        output_hidden_states = layer(
+            [hidden_states, -tf.math.abs(attention_mask), is_index_masked, is_index_global_attn, is_global_attn, None]
+        )[0]

         self.assertTrue(output_hidden_states.shape, (2, 4, 8))
         expected_slice_0 = tf.convert_to_tensor(
...