chenpangpang / transformers / Commits / 19e737b9

Unverified commit 19e737b9, authored Feb 22, 2021 by Julien Plu, committed by GitHub on Feb 22, 2021.

Making TF Longformer-like models compliant with AMP (#10233)

* AMP
* Add LED
* Apply style
* Fix longformer

Parent: cd8c4c3f
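Every hunk in this commit follows the same automatic mixed precision (AMP) recipe: stop hard-coding tf.float32 / tf.int32, derive dtypes from the tensors that already flow through the layer, and keep debug-only assertions out of traced graphs. A minimal standalone sketch of the dtype side of that recipe (the tensor names are illustrative, not taken from the diff):

import tensorflow as tf

head_dim = 64
# Activations as they would look under mixed precision (float16 compute dtype).
query_vectors = tf.random.normal((2, 8, head_dim), dtype=tf.float16)

# Hard-coded dtype: breaks under AMP, because float16 / float32 is not a valid TF op.
# query_vectors /= tf.math.sqrt(tf.convert_to_tensor(head_dim, dtype=tf.float32))

# AMP-compliant: derive the dtype from the tensor itself.
query_vectors /= tf.math.sqrt(tf.cast(head_dim, dtype=query_vectors.dtype))

# Same idea for additive masks: cast the mask to the dtype of the scores.
attention_mask = tf.zeros((2, 1, 1, 8), dtype=tf.float32)
scores = tf.random.normal((2, 4, 8, 8), dtype=tf.float16)
scores += tf.cast(attention_mask, dtype=scores.dtype)
print(scores.dtype)  # float16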
Changes: showing 4 changed files with 267 additions and 217 deletions (+267 / -217).
src/transformers/models/led/modeling_tf_led.py                  +156  -122
src/transformers/models/longformer/modeling_tf_longformer.py    +111   -87
tests/test_modeling_tf_led.py                                     +0    -4
tests/test_modeling_tf_longformer.py                              +0    -4
src/transformers/models/led/modeling_tf_led.py (view file @ 19e737b9)
@@ -55,8 +55,7 @@ LARGE_NEGATIVE = -1e8
 def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
-    shifted_input_ids = tf.cast(input_ids, tf.int32)
-    shifted_input_ids = tf.roll(shifted_input_ids, 1, axis=-1)
+    shifted_input_ids = tf.roll(input_ids, 1, axis=-1)
     start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id)
     shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1)

     # replace possible -100 values in labels by `pad_token_id`
@@ -65,7 +64,8 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
     )

     # "Verify that `labels` has only positive values and -100"
-    assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.cast(0, tf.int32))
+    if tf.executing_eagerly():
+        assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))

     # Make sure the assertion op is called by wrapping the result in an identity no-op
     with tf.control_dependencies([assert_gte0]):
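The if tf.executing_eagerly(): guard added above, and repeated throughout both model files, makes the tf.debugging checks run only in eager mode, so they never become ops inside graphs traced by tf.function. A small sketch of the pattern, using an illustrative helper name:

import tensorflow as tf


def check_non_negative(shifted_input_ids):
    # Debug-only check: runs eagerly, but is skipped while the function is being
    # traced into a graph (the Python `if` then sees executing_eagerly() == False).
    if tf.executing_eagerly():
        tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))
    return shifted_input_ids


print(check_non_negative(tf.constant([[1, 2, 3]])))   # eager call: assertion runs
graph_fn = tf.function(check_non_negative)
print(graph_fn(tf.constant([[1, 2, 3]])))             # traced call: assertion skipped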
@@ -79,14 +79,13 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
     Make causal mask used for bi-directional self-attention.
     """
     bsz, tgt_len = input_ids_shape
-    mask = tf.ones((tgt_len, tgt_len), dtype=tf.float32) * LARGE_NEGATIVE
+    mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
     mask_cond = tf.range(shape_list(mask)[-1])

     mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask)
-    mask = tf.cast(mask, tf.float32)

     if past_key_values_length > 0:
-        mask = tf.concat([tf.zeros((tgt_len, past_key_values_length), dtype=tf.float32), mask], axis=-1)
+        mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1)

     return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))
@@ -97,9 +96,11 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values
     """
     src_len = shape_list(mask)[1]
     tgt_len = tgt_len if tgt_len is not None else src_len
-    expanded_mask = tf.cast(tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)), tf.float32)
+    one_cst = tf.constant(1.0)
+    mask = tf.cast(mask, dtype=one_cst.dtype)
+    expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1))

-    return (1.0 - expanded_mask) * LARGE_NEGATIVE
+    return (one_cst - expanded_mask) * LARGE_NEGATIVE


 class TFLEDLearnedPositionalEmbedding(TFSharedEmbeddings):
@@ -115,9 +116,7 @@ class TFLEDLearnedPositionalEmbedding(TFSharedEmbeddings):
         """Input is expected to be of size [bsz x seqlen]."""
         bsz, seq_len = input_shape[:2]

-        positions = tf.range(
-            past_key_values_length, seq_len + past_key_values_length, delta=1, dtype=tf.int32, name="range"
-        )
+        positions = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range")
         return super().call(positions)
@@ -212,6 +211,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         value_vectors = self.value(hidden_states)
         batch_size, seq_len, embed_dim = shape_list(hidden_states)

+        if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 embed_dim,
                 self.embed_dim,
@@ -219,7 +219,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         )

         # normalize query
-        query_vectors /= tf.math.sqrt(tf.convert_to_tensor(self.head_dim, dtype=tf.dtypes.float32))
+        query_vectors /= tf.math.sqrt(tf.cast(self.head_dim, dtype=query_vectors.dtype))
         query_vectors = tf.reshape(query_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))
         key_vectors = tf.reshape(key_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))
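Under the mixed_float16 Keras policy the dense projections return float16 activations, so dividing them by a float32 scalar built with tf.convert_to_tensor(..., dtype=tf.dtypes.float32) fails with a dtype mismatch; casting the scalar to query_vectors.dtype, as the hunk above does, keeps the scale in whatever dtype the activations use. A small reproduction sketch (assumes TF 2.4+ for set_global_policy):

import tensorflow as tf

tf.keras.mixed_precision.set_global_policy("mixed_float16")

dense = tf.keras.layers.Dense(64)               # compute dtype is float16 under AMP
query_vectors = dense(tf.random.normal((1, 8, 64)))
print(query_vectors.dtype)                      # float16

head_dim = 64
scale_f32 = tf.math.sqrt(tf.convert_to_tensor(head_dim, dtype=tf.float32))
try:
    query_vectors / scale_f32                   # float16 / float32: mismatched dtypes
except (tf.errors.InvalidArgumentError, TypeError) as err:
    print("hard-coded float32 breaks under AMP:", type(err).__name__)

scale = tf.math.sqrt(tf.cast(head_dim, dtype=query_vectors.dtype))
print((query_vectors / scale).dtype)            # float16

tf.keras.mixed_precision.set_global_policy("float32")   # restore the default policy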
@@ -230,7 +230,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         # diagonal mask with zeros everywhere and -inf inplace of padding
         diagonal_mask = self._sliding_chunks_query_key_matmul(
-            tf.ones(shape_list(attention_mask), dtype=tf.float32),
+            tf.ones(shape_list(attention_mask)),
             attention_mask,
             self.one_sided_attn_window_size,
         )

@@ -238,6 +238,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         # pad local attention probs
         attn_scores += diagonal_mask

+        if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(attn_scores),
                 [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1],
@@ -285,16 +286,18 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         )
         attn_probs = tf.where(
             masked_index,
-            tf.zeros(shape_list(masked_index), dtype=tf.dtypes.float32),
+            tf.zeros(shape_list(masked_index), dtype=attn_probs.dtype),
             attn_probs,
         )

         if layer_head_mask is not None:
+            if tf.executing_eagerly():
                 tf.debugging.assert_equal(
                     shape_list(layer_head_mask),
                     [self.num_heads],
                     message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
                 )

             attn_probs = tf.reshape(layer_head_mask, (1, 1, -1, 1)) * attn_probs

         # apply dropout

@@ -316,6 +319,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
             ),
         )

+        if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(attn_output),
                 [batch_size, seq_len, self.num_heads, self.head_dim],
@@ -359,7 +363,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         )
         attn_probs = tf.where(
             masked_global_attn_index,
-            tf.zeros(shape_list(masked_global_attn_index), dtype=tf.dtypes.float32),
+            tf.zeros(shape_list(masked_global_attn_index), dtype=attn_probs.dtype),
             attn_probs,
         )

@@ -375,6 +379,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         """
         batch_size, seq_len, num_heads, head_dim = shape_list(query)

+        if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 seq_len % (window_overlap * 2),
                 0,
@@ -401,10 +406,11 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         # bcxd: batch_size * num_heads x chunks x 2window_overlap x head_dim
         # bcyd: batch_size * num_heads x chunks x 2window_overlap x head_dim
         # bcxy: batch_size * num_heads x chunks x 2window_overlap x 2window_overlap
+        chunked_query = tf.cast(chunked_query, dtype=chunked_key.dtype)
         chunked_attention_scores = tf.einsum("bcxd,bcyd->bcxy", chunked_query, chunked_key)  # multiply

         # convert diagonals into columns
-        paddings = tf.convert_to_tensor([[0, 0], [0, 0], [0, 1], [0, 0]], dtype=tf.dtypes.int32)
+        paddings = tf.convert_to_tensor([[0, 0], [0, 0], [0, 1], [0, 0]])
         diagonal_chunked_attention_scores = self._pad_and_transpose_last_two_dims(chunked_attention_scores, paddings)

         # allocate space for the overall attention matrix where the chunks are combined. The last dimension
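tf.einsum, like tf.matmul, requires both operands to share a dtype, which is presumably why the hunk above casts chunked_query to chunked_key.dtype before the contraction. A minimal sketch with illustrative shapes:

import tensorflow as tf

chunked_query = tf.random.normal((2, 3, 8, 4), dtype=tf.float32)
chunked_key = tf.random.normal((2, 3, 8, 4), dtype=tf.float16)

# Align the query with the key dtype, then contract the head dimension away.
chunked_query = tf.cast(chunked_query, dtype=chunked_key.dtype)
scores = tf.einsum("bcxd,bcyd->bcxy", chunked_query, chunked_key)
print(scores.shape, scores.dtype)  # (2, 3, 8, 8) float16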
@@ -426,7 +432,10 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         # - copying the lower triangle
         diagonal_attn_scores_low_triang = tf.concat(
             [
-                tf.zeros((batch_size * num_heads, 1, window_overlap, window_overlap)),
+                tf.zeros(
+                    (batch_size * num_heads, 1, window_overlap, window_overlap),
+                    dtype=diagonal_chunked_attention_scores.dtype,
+                ),
                 diagonal_chunked_attention_scores[:, :, -(window_overlap + 1) : -1, window_overlap + 1 :],
             ],
             axis=1,

@@ -438,7 +447,10 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
                 shift=[1, window_overlap],
                 axis=[2, 3],
             )[:, :, :window_overlap, :window_overlap],
-            tf.zeros((batch_size * num_heads, 1, window_overlap, window_overlap)),
+            tf.zeros(
+                (batch_size * num_heads, 1, window_overlap, window_overlap),
+                dtype=diagonal_chunked_attention_scores.dtype,
+            ),
         ],
         axis=1,
     )
@@ -496,7 +508,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         mask_4d = tf.tile(mask_2d[None, :, None, :], (shape_list(input_tensor)[0], 1, 1, 1))

         # inf tensor used for masking
-        inf_tensor = -float("inf") * tf.ones_like(input_tensor, dtype=tf.dtypes.float32)
+        inf_tensor = -float("inf") * tf.ones_like(input_tensor)

         # mask
         input_tensor = tf.where(tf.math.greater(mask_4d, 0), inf_tensor, input_tensor)

@@ -511,6 +523,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         batch_size, seq_len, num_heads, head_dim = shape_list(value)

+        if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 seq_len % (window_overlap * 2),
                 0,
@@ -547,7 +560,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         )

         # pad seq_len with w at the beginning of the sequence and another window overlap at the end
-        paddings = tf.convert_to_tensor([[0, 0], [window_overlap, window_overlap], [0, 0]], dtype=tf.dtypes.int32)
+        paddings = tf.convert_to_tensor([[0, 0], [window_overlap, window_overlap], [0, 0]])
         padded_value = tf.pad(value, paddings, constant_values=-1)

         # chunk padded_value into chunks of size 3 window overlap and an overlap of size window overlap

@@ -563,6 +576,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
             (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim),
         )

+        if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(chunked_value),
                 [batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim],

@@ -640,6 +654,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         # chunk with overlap
         chunked_hidden_states = tf.signal.frame(hidden_states, frame_size, frame_hop_size)

+        if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(chunked_hidden_states),
                 [batch_size, num_output_chunks, frame_size],
@@ -657,7 +672,8 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
     def _get_global_attn_indices(is_index_global_attn):
         """ compute global attn indices required throughout forward pass """
         # helper variable
-        num_global_attn_indices = tf.reduce_sum(tf.cast(is_index_global_attn, dtype=tf.dtypes.int32), axis=1)
+        num_global_attn_indices = tf.math.count_nonzero(is_index_global_attn, axis=1)
+        num_global_attn_indices = tf.cast(num_global_attn_indices, dtype=tf.constant(1).dtype)

         # max number of global attn indices in batch
         max_num_global_attn_indices = tf.reduce_max(num_global_attn_indices)
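tf.math.count_nonzero returns int64 by default, while the surrounding index arithmetic in these models runs in int32; one plausible reading of the follow-up cast above is that tf.constant(1).dtype is simply a literal-free way of spelling int32. A quick check:

import tensorflow as tf

is_index_global_attn = tf.constant([[True, False, True], [False, False, True]])

num_global_attn_indices = tf.math.count_nonzero(is_index_global_attn, axis=1)
print(num_global_attn_indices.dtype)   # int64 (count_nonzero default)

num_global_attn_indices = tf.cast(num_global_attn_indices, dtype=tf.constant(1).dtype)
print(num_global_attn_indices.dtype)   # int32
print(tf.reduce_max(num_global_attn_indices).numpy())  # 2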
@@ -719,6 +735,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
             shape_list(attn_probs_from_global_key_trans)[-2:]
         )
         mask = tf.ones(mask_shape) * -10000.0
+        mask = tf.cast(mask, dtype=attn_probs_from_global_key_trans.dtype)

         # scatter mask
         attn_probs_from_global_key_trans = tf.tensor_scatter_nd_update(

@@ -805,7 +822,9 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         global_value_vectors = self.value_global(hidden_states)

         # normalize
-        global_query_vectors_only_global /= tf.math.sqrt(tf.convert_to_tensor(self.head_dim, dtype=tf.dtypes.float32))
+        global_query_vectors_only_global /= tf.math.sqrt(
+            tf.cast(self.head_dim, dtype=global_query_vectors_only_global.dtype)
+        )
         global_query_vectors_only_global = self.reshape_and_transpose(global_query_vectors_only_global, batch_size)
         global_key_vectors = self.reshape_and_transpose(global_key_vectors, batch_size)
         global_value_vectors = self.reshape_and_transpose(global_value_vectors, batch_size)
@@ -813,6 +832,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         # compute attn scores
         global_attn_scores = tf.matmul(global_query_vectors_only_global, global_key_vectors, transpose_b=True)

+        if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(global_attn_scores),
                 [batch_size * self.num_heads, max_num_global_attn_indices, seq_len],

@@ -828,6 +848,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
             shape_list(global_attn_scores_trans)[-2:]
         )
         global_attn_mask = tf.ones(mask_shape) * -10000.0
+        global_attn_mask = tf.cast(global_attn_mask, dtype=global_attn_scores_trans.dtype)

         # scatter mask
         global_attn_scores_trans = tf.tensor_scatter_nd_update(

@@ -850,6 +871,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         # apply layer head maskin
         if layer_head_mask is not None:
+            if tf.executing_eagerly():
                 tf.debugging.assert_equal(
                     shape_list(layer_head_mask),
                     [self.num_heads],
@@ -868,6 +890,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         # global attn output
         global_attn_output = tf.matmul(global_attn_probs, global_value_vectors)

+        if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(global_attn_output),
                 [batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim],

@@ -1023,6 +1046,7 @@ class TFLEDDecoderAttention(tf.keras.layers.Layer):
         src_len = shape_list(key_states)[1]
         attn_weights = tf.matmul(query_states, key_states, transpose_b=True)

+        if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(attn_weights),
                 [bsz * self.num_heads, tgt_len, src_len],
@@ -1030,22 +1054,28 @@ class TFLEDDecoderAttention(tf.keras.layers.Layer):
         )

         if attention_mask is not None:
+            if tf.executing_eagerly():
                 tf.debugging.assert_equal(
                     shape_list(attention_mask),
                     [bsz, 1, tgt_len, src_len],
                     message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
                 )

-            attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
+            attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + tf.cast(
+                attention_mask, dtype=attn_weights.dtype
+            )
             attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))

         attn_weights = tf.nn.softmax(attn_weights, axis=-1)

         if layer_head_mask is not None:
+            if tf.executing_eagerly():
                 tf.debugging.assert_equal(
                     shape_list(layer_head_mask),
                     [self.num_heads],
                     message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
                 )

             attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
                 attn_weights, (bsz, self.num_heads, tgt_len, src_len)
             )
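The decoder-attention hunk above casts the additive attention_mask to attn_weights.dtype at the point of use, because tf.add needs matching operand dtypes once the scores are float16. A small sketch with an illustrative mask value:

import tensorflow as tf

bsz, num_heads, tgt_len, src_len = 1, 2, 3, 3
attn_weights = tf.random.normal((bsz * num_heads, tgt_len, src_len), dtype=tf.float16)

# Additive mask prepared upstream in float32 (0 = keep, large negative = drop).
attention_mask = tf.constant([[[[0.0, 0.0, -10000.0]]]], dtype=tf.float32)

attn_weights = tf.reshape(attn_weights, (bsz, num_heads, tgt_len, src_len)) + tf.cast(
    attention_mask, dtype=attn_weights.dtype
)
attn_weights = tf.reshape(attn_weights, (bsz * num_heads, tgt_len, src_len))
attn_probs = tf.nn.softmax(attn_weights, axis=-1)
print(attn_probs.dtype, attn_probs[0, 0].numpy())  # masked last position gets ~0 probability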
@@ -1055,6 +1085,7 @@ class TFLEDDecoderAttention(tf.keras.layers.Layer):
         attn_output = tf.matmul(attn_probs, value_states)

+        if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(attn_output),
                 [bsz * self.num_heads, tgt_len, self.head_dim],

@@ -1111,6 +1142,7 @@ class TFLEDEncoderLayer(tf.keras.layers.Layer):
         hidden_states = layer_outputs[0]

+        if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(hidden_states),
                 shape_list(residual),
@@ -1707,12 +1739,13 @@ class TFLEDEncoder(tf.keras.layers.Layer):
         all_attentions = all_global_attentions = () if inputs["output_attentions"] else None

         # check if head_mask has a correct number of layers specified if desired
-        if inputs["head_mask"] is not None:
+        if inputs["head_mask"] is not None and tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(inputs["head_mask"])[0],
                 len(self.layers),
                 message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
             )

         # encoder layers
         for idx, encoder_layer in enumerate(self.layers):

@@ -1981,12 +2014,13 @@ class TFLEDDecoder(tf.keras.layers.Layer):
         present_key_values = ()

         # check if head_mask has a correct number of layers specified if desired
-        if inputs["head_mask"] is not None:
+        if inputs["head_mask"] is not None and tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(inputs["head_mask"])[0],
                 len(self.layers),
                 message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
             )

         for idx, decoder_layer in enumerate(self.layers):
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
             if inputs["output_hidden_states"]:
src/transformers/models/longformer/modeling_tf_longformer.py (view file @ 19e737b9)
@@ -392,23 +392,22 @@ def _compute_global_attention_mask(input_ids_shape, sep_token_indices, before_se
     """
     assert shape_list(sep_token_indices)[1] == 2, "`input_ids` should have two dimensions"
-    question_end_index = tf.reshape(sep_token_indices, (input_ids_shape[0], 3, 2))[:, 0, 1]
-    question_end_index = tf.cast(question_end_index[:, None], tf.dtypes.int32)  # size: batch_size x 1
+    question_end_index = tf.reshape(sep_token_indices, (input_ids_shape[0], 3, 2))[:, 0, 1][:, None]
     # bool attention mask with True in locations of global attention
-    attention_mask = tf.range(input_ids_shape[1])[tf.newaxis, :]
+    attention_mask = tf.expand_dims(tf.range(input_ids_shape[1]), axis=0)
     attention_mask = tf.tile(attention_mask, (input_ids_shape[0], 1))
     if before_sep_token is True:
         question_end_index = tf.tile(question_end_index, (1, input_ids_shape[1]))
-        attention_mask = tf.cast(attention_mask < question_end_index, tf.int32)
+        attention_mask = tf.cast(attention_mask < question_end_index, dtype=question_end_index.dtype)
     else:
         # last token is separation token and should not be counted and in the middle are two separation tokens
         question_end_index = tf.tile(question_end_index + 1, (1, input_ids_shape[1]))
         attention_mask = (
             tf.cast(
                 attention_mask > question_end_index,
-                tf.dtypes.int32,
+                dtype=question_end_index.dtype,
             )
-            * tf.cast(attention_mask < input_ids_shape[-1], tf.dtypes.int32)
+            * tf.cast(attention_mask < input_ids_shape[-1], dtype=question_end_index.dtype)
         )

     return attention_mask
@@ -730,6 +729,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         value_vectors = self.value(hidden_states)
         batch_size, seq_len, embed_dim = shape_list(hidden_states)

+        if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 embed_dim,
                 self.embed_dim,

@@ -737,7 +737,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         )

         # normalize query
-        query_vectors /= tf.math.sqrt(tf.convert_to_tensor(self.head_dim, dtype=tf.dtypes.float32))
+        query_vectors /= tf.math.sqrt(tf.cast(self.head_dim, dtype=query_vectors.dtype))
         query_vectors = tf.reshape(query_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))
         key_vectors = tf.reshape(key_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))
@@ -748,7 +748,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         # diagonal mask with zeros everywhere and -inf inplace of padding
         diagonal_mask = self._sliding_chunks_query_key_matmul(
-            tf.ones(shape_list(attention_mask), dtype=tf.float32),
+            tf.ones(shape_list(attention_mask)),
             attention_mask,
             self.one_sided_attn_window_size,
         )

@@ -756,6 +756,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         # pad local attention probs
         attn_scores += diagonal_mask

+        if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(attn_scores),
                 [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1],
@@ -803,16 +804,18 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         )
         attn_probs = tf.where(
             masked_index,
-            tf.zeros(shape_list(masked_index), dtype=tf.dtypes.float32),
+            tf.zeros(shape_list(masked_index), dtype=attn_probs.dtype),
             attn_probs,
         )

         if layer_head_mask is not None:
+            if tf.executing_eagerly():
                 tf.debugging.assert_equal(
                     shape_list(layer_head_mask),
                     [self.num_heads],
                     message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
                 )

             attn_probs = tf.reshape(layer_head_mask, (1, 1, -1, 1)) * attn_probs

         # apply dropout

@@ -834,6 +837,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
             ),
         )

+        if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(attn_output),
                 [batch_size, seq_len, self.num_heads, self.head_dim],
@@ -877,7 +881,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         )
         attn_probs = tf.where(
             masked_global_attn_index,
-            tf.zeros(shape_list(masked_global_attn_index), dtype=tf.dtypes.float32),
+            tf.zeros(shape_list(masked_global_attn_index), dtype=attn_probs.dtype),
             attn_probs,
         )

@@ -893,6 +897,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         """
         batch_size, seq_len, num_heads, head_dim = shape_list(query)

+        if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 seq_len % (window_overlap * 2),
                 0,
@@ -919,10 +924,11 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         # bcxd: batch_size * num_heads x chunks x 2window_overlap x head_dim
         # bcyd: batch_size * num_heads x chunks x 2window_overlap x head_dim
         # bcxy: batch_size * num_heads x chunks x 2window_overlap x 2window_overlap
+        chunked_query = tf.cast(chunked_query, dtype=chunked_key.dtype)
         chunked_attention_scores = tf.einsum("bcxd,bcyd->bcxy", chunked_query, chunked_key)  # multiply

         # convert diagonals into columns
-        paddings = tf.convert_to_tensor([[0, 0], [0, 0], [0, 1], [0, 0]], dtype=tf.dtypes.int32)
+        paddings = tf.convert_to_tensor([[0, 0], [0, 0], [0, 1], [0, 0]])
         diagonal_chunked_attention_scores = self._pad_and_transpose_last_two_dims(chunked_attention_scores, paddings)

         # allocate space for the overall attention matrix where the chunks are combined. The last dimension

@@ -944,7 +950,10 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         # - copying the lower triangle
         diagonal_attn_scores_low_triang = tf.concat(
             [
-                tf.zeros((batch_size * num_heads, 1, window_overlap, window_overlap)),
+                tf.zeros(
+                    (batch_size * num_heads, 1, window_overlap, window_overlap),
+                    dtype=diagonal_chunked_attention_scores.dtype,
+                ),
                 diagonal_chunked_attention_scores[:, :, -(window_overlap + 1) : -1, window_overlap + 1 :],
             ],
             axis=1,
@@ -956,7 +965,10 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
                 shift=[1, window_overlap],
                 axis=[2, 3],
             )[:, :, :window_overlap, :window_overlap],
-            tf.zeros((batch_size * num_heads, 1, window_overlap, window_overlap)),
+            tf.zeros(
+                (batch_size * num_heads, 1, window_overlap, window_overlap),
+                dtype=diagonal_chunked_attention_scores.dtype,
+            ),
         ],
         axis=1,
     )

@@ -1014,7 +1026,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         mask_4d = tf.tile(mask_2d[None, :, None, :], (shape_list(input_tensor)[0], 1, 1, 1))

         # inf tensor used for masking
-        inf_tensor = -float("inf") * tf.ones_like(input_tensor, dtype=tf.dtypes.float32)
+        inf_tensor = -float("inf") * tf.ones_like(input_tensor)

         # mask
         input_tensor = tf.where(tf.math.greater(mask_4d, 0), inf_tensor, input_tensor)
@@ -1029,6 +1041,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         batch_size, seq_len, num_heads, head_dim = shape_list(value)

+        if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 seq_len % (window_overlap * 2),
                 0,

@@ -1065,7 +1078,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         )

         # pad seq_len with w at the beginning of the sequence and another window overlap at the end
-        paddings = tf.convert_to_tensor([[0, 0], [window_overlap, window_overlap], [0, 0]], dtype=tf.dtypes.int32)
+        paddings = tf.convert_to_tensor([[0, 0], [window_overlap, window_overlap], [0, 0]])
         padded_value = tf.pad(value, paddings, constant_values=-1)

         # chunk padded_value into chunks of size 3 window overlap and an overlap of size window overlap
@@ -1081,6 +1094,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
             (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim),
         )

+        if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(chunked_value),
                 [batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim],

@@ -1158,6 +1172,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         # chunk with overlap
         chunked_hidden_states = tf.signal.frame(hidden_states, frame_size, frame_hop_size)

+        if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(chunked_hidden_states),
                 [batch_size, num_output_chunks, frame_size],
@@ -1175,7 +1190,8 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
     def _get_global_attn_indices(is_index_global_attn):
         """ compute global attn indices required throughout forward pass """
         # helper variable
-        num_global_attn_indices = tf.reduce_sum(tf.cast(is_index_global_attn, dtype=tf.dtypes.int32), axis=1)
+        num_global_attn_indices = tf.math.count_nonzero(is_index_global_attn, axis=1)
+        num_global_attn_indices = tf.cast(num_global_attn_indices, dtype=tf.constant(1).dtype)

         # max number of global attn indices in batch
         max_num_global_attn_indices = tf.reduce_max(num_global_attn_indices)

@@ -1237,6 +1253,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
             shape_list(attn_probs_from_global_key_trans)[-2:]
         )
         mask = tf.ones(mask_shape) * -10000.0
+        mask = tf.cast(mask, dtype=attn_probs_from_global_key_trans.dtype)

         # scatter mask
         attn_probs_from_global_key_trans = tf.tensor_scatter_nd_update(
@@ -1323,7 +1340,9 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         global_value_vectors = self.value_global(hidden_states)

         # normalize
-        global_query_vectors_only_global /= tf.math.sqrt(tf.convert_to_tensor(self.head_dim, dtype=tf.dtypes.float32))
+        global_query_vectors_only_global /= tf.math.sqrt(
+            tf.cast(self.head_dim, dtype=global_query_vectors_only_global.dtype)
+        )
         global_query_vectors_only_global = self.reshape_and_transpose(global_query_vectors_only_global, batch_size)
         global_key_vectors = self.reshape_and_transpose(global_key_vectors, batch_size)
         global_value_vectors = self.reshape_and_transpose(global_value_vectors, batch_size)

@@ -1331,6 +1350,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         # compute attn scores
         global_attn_scores = tf.matmul(global_query_vectors_only_global, global_key_vectors, transpose_b=True)

+        if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(global_attn_scores),
                 [batch_size * self.num_heads, max_num_global_attn_indices, seq_len],
@@ -1346,6 +1366,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
             shape_list(global_attn_scores_trans)[-2:]
         )
         global_attn_mask = tf.ones(mask_shape) * -10000.0
+        global_attn_mask = tf.cast(global_attn_mask, dtype=global_attn_scores_trans.dtype)

         # scatter mask
         global_attn_scores_trans = tf.tensor_scatter_nd_update(

@@ -1368,6 +1389,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         # apply layer head maskin
         if layer_head_mask is not None:
+            if tf.executing_eagerly():
                 tf.debugging.assert_equal(
                     shape_list(layer_head_mask),
                     [self.num_heads],

@@ -1386,6 +1408,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         # global attn output
         global_attn_output = tf.matmul(global_attn_probs, global_value_vectors)

+        if tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(global_attn_output),
                 [batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim],
@@ -2230,6 +2253,7 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
             logger.info("Initializing global attention on question tokens...")
             # put global attention on all tokens until `config.sep_token_id` is reached
             sep_token_indices = tf.where(inputs["input_ids"] == self.config.sep_token_id)
+            sep_token_indices = tf.cast(sep_token_indices, dtype=inputs["input_ids"].dtype)
             inputs["global_attention_mask"] = _compute_global_attention_mask(shape_list(inputs["input_ids"]), sep_token_indices)
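tf.where(condition) returns coordinates as int64, so the new cast above aligns sep_token_indices with the dtype of input_ids before the two are mixed inside _compute_global_attention_mask. A quick check:

import tensorflow as tf

input_ids = tf.constant([[101, 7, 102, 8, 9, 102]], dtype=tf.int32)
sep_token_id = 102

sep_token_indices = tf.where(input_ids == sep_token_id)
print(sep_token_indices.dtype)    # int64: tf.where returns coordinates as int64

sep_token_indices = tf.cast(sep_token_indices, dtype=input_ids.dtype)
print(sep_token_indices.numpy())  # [[0 2] [0 5]]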
tests/test_modeling_tf_led.py (view file @ 19e737b9)
@@ -362,10 +362,6 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase):
             self.assertEqual(model.config.output_hidden_states, True)
             check_encoder_attentions_output(outputs)

-    def test_mixed_precision(self):
-        # TODO JP: Make LED float16 compliant
-        pass
-
     def test_xla_mode(self):
         # TODO JP: Make LED XLA compliant
         pass
tests/test_modeling_tf_longformer.py (view file @ 19e737b9)
@@ -343,10 +343,6 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
         # This test is too long (>30sec) and makes fail the CI
         pass

-    def test_mixed_precision(self):
-        # TODO JP: Make Longformer float16 compliant
-        pass
-
     def test_xla_mode(self):
         # TODO JP: Make Longformer XLA compliant
         pass
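With LED and Longformer now AMP-compliant, the two test_mixed_precision overrides that skipped the shared check are removed, so the mixin test runs again for these models. The real check lives in TFModelTesterMixin; the sketch below only illustrates, with a toy Keras model, what such a mixed-precision smoke test typically does:

import unittest

import tensorflow as tf


class MixedPrecisionSmokeTest(unittest.TestCase):
    def test_mixed_precision(self):
        tf.keras.mixed_precision.set_global_policy("mixed_float16")
        try:
            model = tf.keras.Sequential(
                [tf.keras.layers.Dense(8, activation="relu"), tf.keras.layers.Dense(2)]
            )
            outputs = model(tf.random.normal((4, 16)))
            self.assertIsNotNone(outputs)
        finally:
            # Always restore the default policy so other tests are unaffected.
            tf.keras.mixed_precision.set_global_policy("float32")


if __name__ == "__main__":
    unittest.main()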