chenpangpang / transformers / Commits / 19e737b9

Unverified commit 19e737b9, authored Feb 22, 2021 by Julien Plu, committed by GitHub on Feb 22, 2021.

Making TF Longformer-like models compliant with AMP (#10233)

* AMP
* Add LED
* Apply style
* Fix longformer

parent cd8c4c3f
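Most of the diff below applies a single pattern: tensors and constants that were pinned to tf.float32 (or tf.dtypes.int32) now follow the dtype of the tensors they are combined with, so the layers also run under Keras automatic mixed precision (mixed_float16). A minimal sketch of that pattern follows; the function names are illustrative and not part of the repository.

import tensorflow as tf


def scale_query_float32_only(query_vectors, head_dim):
    # Pre-AMP style: the scaling factor is always float32, which no longer matches
    # query_vectors once the layer runs under a mixed_float16 policy.
    return query_vectors / tf.math.sqrt(tf.convert_to_tensor(head_dim, dtype=tf.float32))


def scale_query_amp_friendly(query_vectors, head_dim):
    # AMP-friendly style used throughout this commit: derive the dtype from the
    # tensor being scaled, so float16 activations stay float16.
    return query_vectors / tf.math.sqrt(tf.cast(head_dim, dtype=query_vectors.dtype))

Under tf.keras.mixed_precision.set_global_policy("mixed_float16") the second variant runs without dtype mismatches, while the first mixes float16 and float32 operands.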
Showing 4 changed files with 267 additions and 217 deletions (+267 / -217):

  src/transformers/models/led/modeling_tf_led.py                  +156  -122
  src/transformers/models/longformer/modeling_tf_longformer.py    +111   -87
  tests/test_modeling_tf_led.py                                      +0    -4
  tests/test_modeling_tf_longformer.py                               +0    -4
src/transformers/models/led/modeling_tf_led.py  View file @ 19e737b9

@@ -55,8 +55,7 @@ LARGE_NEGATIVE = -1e8

 def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
-    shifted_input_ids = tf.cast(input_ids, tf.int32)
-    shifted_input_ids = tf.roll(shifted_input_ids, 1, axis=-1)
+    shifted_input_ids = tf.roll(input_ids, 1, axis=-1)
     start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id)
     shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1)
     # replace possible -100 values in labels by `pad_token_id`

@@ -65,7 +64,8 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
     )

-    # "Verify that `labels` has only positive values and -100"
-    assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.cast(0, tf.int32))
+    if tf.executing_eagerly():
+        # "Verify that `labels` has only positive values and -100"
+        assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))

-    # Make sure the assertion op is called by wrapping the result in an identity no-op
-    with tf.control_dependencies([assert_gte0]):
+        # Make sure the assertion op is called by wrapping the result in an identity no-op
+        with tf.control_dependencies([assert_gte0]):

@@ -79,14 +79,13 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
     Make causal mask used for bi-directional self-attention.
     """
     bsz, tgt_len = input_ids_shape
-    mask = tf.ones((tgt_len, tgt_len), dtype=tf.float32) * LARGE_NEGATIVE
+    mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
     mask_cond = tf.range(shape_list(mask)[-1])
     mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask)
-    mask = tf.cast(mask, tf.float32)

     if past_key_values_length > 0:
-        mask = tf.concat([tf.zeros((tgt_len, past_key_values_length), dtype=tf.float32), mask], axis=-1)
+        mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1)

     return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))

@@ -97,9 +96,11 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values
     """
     src_len = shape_list(mask)[1]
     tgt_len = tgt_len if tgt_len is not None else src_len
-    expanded_mask = tf.cast(tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)), tf.float32)
+    one_cst = tf.constant(1.0)
+    mask = tf.cast(mask, dtype=one_cst.dtype)
+    expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1))

-    return (1.0 - expanded_mask) * LARGE_NEGATIVE
+    return (one_cst - expanded_mask) * LARGE_NEGATIVE


 class TFLEDLearnedPositionalEmbedding(TFSharedEmbeddings):

@@ -115,9 +116,7 @@ class TFLEDLearnedPositionalEmbedding(TFSharedEmbeddings):
         """Input is expected to be of size [bsz x seqlen]."""
         bsz, seq_len = input_shape[:2]
         positions = tf.range(
-            past_key_values_length, seq_len + past_key_values_length, delta=1, name="range"
+            past_key_values_length, seq_len + past_key_values_length, delta=1, dtype=tf.int32, name="range"
         )
         return super().call(positions)

@@ -212,6 +211,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         value_vectors = self.value(hidden_states)
         batch_size, seq_len, embed_dim = shape_list(hidden_states)

-        tf.debugging.assert_equal(
-            embed_dim,
-            self.embed_dim,
+        if tf.executing_eagerly():
+            tf.debugging.assert_equal(
+                embed_dim,
+                self.embed_dim,

@@ -219,7 +219,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         )

         # normalize query
-        query_vectors /= tf.math.sqrt(tf.convert_to_tensor(self.head_dim, dtype=tf.dtypes.float32))
+        query_vectors /= tf.math.sqrt(tf.cast(self.head_dim, dtype=query_vectors.dtype))
         query_vectors = tf.reshape(query_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))
         key_vectors = tf.reshape(key_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))

@@ -230,7 +230,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         # diagonal mask with zeros everywhere and -inf inplace of padding
         diagonal_mask = self._sliding_chunks_query_key_matmul(
-            tf.ones(shape_list(attention_mask), dtype=tf.float32),
+            tf.ones(shape_list(attention_mask)),
             attention_mask,
             self.one_sided_attn_window_size,
         )

@@ -238,6 +238,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         # pad local attention probs
         attn_scores += diagonal_mask

-        tf.debugging.assert_equal(
-            shape_list(attn_scores),
-            [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1],
+        if tf.executing_eagerly():
+            tf.debugging.assert_equal(
+                shape_list(attn_scores),
+                [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1],

@@ -285,16 +286,18 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         )
         attn_probs = tf.where(
             masked_index,
-            tf.zeros(shape_list(masked_index), dtype=tf.dtypes.float32),
+            tf.zeros(shape_list(masked_index), dtype=attn_probs.dtype),
             attn_probs,
         )

         if layer_head_mask is not None:
-            tf.debugging.assert_equal(
-                shape_list(layer_head_mask),
-                [self.num_heads],
-                message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
-            )
+            if tf.executing_eagerly():
+                tf.debugging.assert_equal(
+                    shape_list(layer_head_mask),
+                    [self.num_heads],
+                    message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
+                )
             attn_probs = tf.reshape(layer_head_mask, (1, 1, -1, 1)) * attn_probs

         # apply dropout

@@ -316,6 +319,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
             ),
         )

-        tf.debugging.assert_equal(
-            shape_list(attn_output),
-            [batch_size, seq_len, self.num_heads, self.head_dim],
+        if tf.executing_eagerly():
+            tf.debugging.assert_equal(
+                shape_list(attn_output),
+                [batch_size, seq_len, self.num_heads, self.head_dim],

@@ -359,7 +363,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
             )
             attn_probs = tf.where(
                 masked_global_attn_index,
-                tf.zeros(shape_list(masked_global_attn_index), dtype=tf.dtypes.float32),
+                tf.zeros(shape_list(masked_global_attn_index), dtype=attn_probs.dtype),
                 attn_probs,
             )

@@ -375,6 +379,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         """
         batch_size, seq_len, num_heads, head_dim = shape_list(query)

-        tf.debugging.assert_equal(
-            seq_len % (window_overlap * 2),
-            0,
+        if tf.executing_eagerly():
+            tf.debugging.assert_equal(
+                seq_len % (window_overlap * 2),
+                0,

@@ -401,10 +406,11 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         # bcxd: batch_size * num_heads x chunks x 2window_overlap x head_dim
         # bcyd: batch_size * num_heads x chunks x 2window_overlap x head_dim
         # bcxy: batch_size * num_heads x chunks x 2window_overlap x 2window_overlap
+        chunked_query = tf.cast(chunked_query, dtype=chunked_key.dtype)
         chunked_attention_scores = tf.einsum("bcxd,bcyd->bcxy", chunked_query, chunked_key)  # multiply

         # convert diagonals into columns
-        paddings = tf.convert_to_tensor([[0, 0], [0, 0], [0, 1], [0, 0]], dtype=tf.dtypes.int32)
+        paddings = tf.convert_to_tensor([[0, 0], [0, 0], [0, 1], [0, 0]])
         diagonal_chunked_attention_scores = self._pad_and_transpose_last_two_dims(chunked_attention_scores, paddings)

         # allocate space for the overall attention matrix where the chunks are combined. The last dimension

@@ -426,7 +432,10 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         # - copying the lower triangle
         diagonal_attn_scores_low_triang = tf.concat(
             [
-                tf.zeros((batch_size * num_heads, 1, window_overlap, window_overlap)),
+                tf.zeros(
+                    (batch_size * num_heads, 1, window_overlap, window_overlap),
+                    dtype=diagonal_chunked_attention_scores.dtype,
+                ),
                 diagonal_chunked_attention_scores[:, :, -(window_overlap + 1) : -1, window_overlap + 1 :],
             ],
             axis=1,

@@ -438,7 +447,10 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
                     shift=[1, window_overlap],
                     axis=[2, 3],
                 )[:, :, :window_overlap, :window_overlap],
-                tf.zeros((batch_size * num_heads, 1, window_overlap, window_overlap)),
+                tf.zeros(
+                    (batch_size * num_heads, 1, window_overlap, window_overlap),
+                    dtype=diagonal_chunked_attention_scores.dtype,
+                ),
             ],
             axis=1,
         )

@@ -496,7 +508,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         mask_4d = tf.tile(mask_2d[None, :, None, :], (shape_list(input_tensor)[0], 1, 1, 1))

         # inf tensor used for masking
-        inf_tensor = -float("inf") * tf.ones_like(input_tensor, dtype=tf.dtypes.float32)
+        inf_tensor = -float("inf") * tf.ones_like(input_tensor)

         # mask
         input_tensor = tf.where(tf.math.greater(mask_4d, 0), inf_tensor, input_tensor)

@@ -511,6 +523,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         batch_size, seq_len, num_heads, head_dim = shape_list(value)

-        tf.debugging.assert_equal(
-            seq_len % (window_overlap * 2),
-            0,
+        if tf.executing_eagerly():
+            tf.debugging.assert_equal(
+                seq_len % (window_overlap * 2),
+                0,

@@ -547,7 +560,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         )

         # pad seq_len with w at the beginning of the sequence and another window overlap at the end
-        paddings = tf.convert_to_tensor([[0, 0], [window_overlap, window_overlap], [0, 0]], dtype=tf.dtypes.int32)
+        paddings = tf.convert_to_tensor([[0, 0], [window_overlap, window_overlap], [0, 0]])
         padded_value = tf.pad(value, paddings, constant_values=-1)

         # chunk padded_value into chunks of size 3 window overlap and an overlap of size window overlap

@@ -563,6 +576,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
             (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim),
         )

-        tf.debugging.assert_equal(
-            shape_list(chunked_value),
-            [batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim],
+        if tf.executing_eagerly():
+            tf.debugging.assert_equal(
+                shape_list(chunked_value),
+                [batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim],

@@ -640,6 +654,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         # chunk with overlap
         chunked_hidden_states = tf.signal.frame(hidden_states, frame_size, frame_hop_size)

-        tf.debugging.assert_equal(
-            shape_list(chunked_hidden_states),
-            [batch_size, num_output_chunks, frame_size],
+        if tf.executing_eagerly():
+            tf.debugging.assert_equal(
+                shape_list(chunked_hidden_states),
+                [batch_size, num_output_chunks, frame_size],

@@ -657,7 +672,8 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
     def _get_global_attn_indices(is_index_global_attn):
         """ compute global attn indices required throughout forward pass """
         # helper variable
-        num_global_attn_indices = tf.reduce_sum(tf.cast(is_index_global_attn, dtype=tf.dtypes.int32), axis=1)
+        num_global_attn_indices = tf.math.count_nonzero(is_index_global_attn, axis=1)
+        num_global_attn_indices = tf.cast(num_global_attn_indices, dtype=tf.constant(1).dtype)

         # max number of global attn indices in batch
         max_num_global_attn_indices = tf.reduce_max(num_global_attn_indices)

@@ -719,6 +735,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
             shape_list(attn_probs_from_global_key_trans)[-2:]
         )
         mask = tf.ones(mask_shape) * -10000.0
+        mask = tf.cast(mask, dtype=attn_probs_from_global_key_trans.dtype)

         # scatter mask
         attn_probs_from_global_key_trans = tf.tensor_scatter_nd_update(

@@ -805,7 +822,9 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         global_value_vectors = self.value_global(hidden_states)

         # normalize
-        global_query_vectors_only_global /= tf.math.sqrt(tf.convert_to_tensor(self.head_dim, dtype=tf.dtypes.float32))
+        global_query_vectors_only_global /= tf.math.sqrt(
+            tf.cast(self.head_dim, dtype=global_query_vectors_only_global.dtype)
+        )
         global_query_vectors_only_global = self.reshape_and_transpose(global_query_vectors_only_global, batch_size)
         global_key_vectors = self.reshape_and_transpose(global_key_vectors, batch_size)
         global_value_vectors = self.reshape_and_transpose(global_value_vectors, batch_size)

@@ -813,6 +832,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         # compute attn scores
         global_attn_scores = tf.matmul(global_query_vectors_only_global, global_key_vectors, transpose_b=True)

-        tf.debugging.assert_equal(
-            shape_list(global_attn_scores),
-            [batch_size * self.num_heads, max_num_global_attn_indices, seq_len],
+        if tf.executing_eagerly():
+            tf.debugging.assert_equal(
+                shape_list(global_attn_scores),
+                [batch_size * self.num_heads, max_num_global_attn_indices, seq_len],

@@ -828,6 +848,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
             shape_list(global_attn_scores_trans)[-2:]
         )
         global_attn_mask = tf.ones(mask_shape) * -10000.0
+        global_attn_mask = tf.cast(global_attn_mask, dtype=global_attn_scores_trans.dtype)

         # scatter mask
         global_attn_scores_trans = tf.tensor_scatter_nd_update(

@@ -850,6 +871,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         # apply layer head maskin
         if layer_head_mask is not None:
-            tf.debugging.assert_equal(
-                shape_list(layer_head_mask),
-                [self.num_heads],
+            if tf.executing_eagerly():
+                tf.debugging.assert_equal(
+                    shape_list(layer_head_mask),
+                    [self.num_heads],

@@ -868,6 +890,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         # global attn output
         global_attn_output = tf.matmul(global_attn_probs, global_value_vectors)

-        tf.debugging.assert_equal(
-            shape_list(global_attn_output),
-            [batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim],
+        if tf.executing_eagerly():
+            tf.debugging.assert_equal(
+                shape_list(global_attn_output),
+                [batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim],

@@ -1023,6 +1046,7 @@ class TFLEDDecoderAttention(tf.keras.layers.Layer):
         src_len = shape_list(key_states)[1]
         attn_weights = tf.matmul(query_states, key_states, transpose_b=True)

-        tf.debugging.assert_equal(
-            shape_list(attn_weights),
-            [bsz * self.num_heads, tgt_len, src_len],
+        if tf.executing_eagerly():
+            tf.debugging.assert_equal(
+                shape_list(attn_weights),
+                [bsz * self.num_heads, tgt_len, src_len],

@@ -1030,22 +1054,28 @@ class TFLEDDecoderAttention(tf.keras.layers.Layer):
         )

         if attention_mask is not None:
-            tf.debugging.assert_equal(
-                shape_list(attention_mask),
-                [bsz, 1, tgt_len, src_len],
-                message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
-            )
+            if tf.executing_eagerly():
+                tf.debugging.assert_equal(
+                    shape_list(attention_mask),
+                    [bsz, 1, tgt_len, src_len],
+                    message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
+                )

-            attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
+            attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + tf.cast(
+                attention_mask, dtype=attn_weights.dtype
+            )
             attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))

         attn_weights = tf.nn.softmax(attn_weights, axis=-1)

         if layer_head_mask is not None:
-            tf.debugging.assert_equal(
-                shape_list(layer_head_mask),
-                [self.num_heads],
-                message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
-            )
+            if tf.executing_eagerly():
+                tf.debugging.assert_equal(
+                    shape_list(layer_head_mask),
+                    [self.num_heads],
+                    message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
+                )
             attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
                 attn_weights, (bsz, self.num_heads, tgt_len, src_len)
             )

@@ -1055,6 +1085,7 @@ class TFLEDDecoderAttention(tf.keras.layers.Layer):
         attn_output = tf.matmul(attn_probs, value_states)

-        tf.debugging.assert_equal(
-            shape_list(attn_output),
-            [bsz * self.num_heads, tgt_len, self.head_dim],
+        if tf.executing_eagerly():
+            tf.debugging.assert_equal(
+                shape_list(attn_output),
+                [bsz * self.num_heads, tgt_len, self.head_dim],

@@ -1111,6 +1142,7 @@ class TFLEDEncoderLayer(tf.keras.layers.Layer):
         hidden_states = layer_outputs[0]

-        tf.debugging.assert_equal(
-            shape_list(hidden_states),
-            shape_list(residual),
+        if tf.executing_eagerly():
+            tf.debugging.assert_equal(
+                shape_list(hidden_states),
+                shape_list(residual),

@@ -1707,12 +1739,13 @@ class TFLEDEncoder(tf.keras.layers.Layer):
         all_attentions = all_global_attentions = () if inputs["output_attentions"] else None

         # check if head_mask has a correct number of layers specified if desired
-        if inputs["head_mask"] is not None:
+        if inputs["head_mask"] is not None and tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(inputs["head_mask"])[0],
                 len(self.layers),
                 message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
             )

         # encoder layers
         for idx, encoder_layer in enumerate(self.layers):

@@ -1981,12 +2014,13 @@ class TFLEDDecoder(tf.keras.layers.Layer):
         present_key_values = ()

         # check if head_mask has a correct number of layers specified if desired
-        if inputs["head_mask"] is not None:
+        if inputs["head_mask"] is not None and tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(inputs["head_mask"])[0],
                 len(self.layers),
                 message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
             )

         for idx, decoder_layer in enumerate(self.layers):
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
             if inputs["output_hidden_states"]:
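The other recurring change above wraps every tf.debugging shape assertion in a tf.executing_eagerly() guard (or merges the guard into an existing condition), so the checks only run on concrete values and are skipped when the layer is traced into a graph. A reduced sketch of that guard, with an illustrative helper name that is not in the file:

import tensorflow as tf


def check_attn_scores_shape(attn_scores, batch_size, seq_len, num_heads, window_overlap):
    # Only assert on shapes when running eagerly; inside tf.function the shapes
    # may be symbolic and the runtime check is not wanted.
    if tf.executing_eagerly():
        tf.debugging.assert_equal(
            tf.shape(attn_scores),
            [batch_size, seq_len, num_heads, window_overlap * 2 + 1],
            message="attention scores have an unexpected shape",
        )
    return attn_scores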
src/transformers/models/longformer/modeling_tf_longformer.py  View file @ 19e737b9

@@ -392,23 +392,22 @@ def _compute_global_attention_mask(input_ids_shape, sep_token_indices, before_se
     """
     assert shape_list(sep_token_indices)[1] == 2, "`input_ids` should have two dimensions"
-    question_end_index = tf.reshape(sep_token_indices, (input_ids_shape[0], 3, 2))[:, 0, 1]
-    question_end_index = tf.cast(question_end_index[:, None], tf.dtypes.int32)  # size: batch_size x 1
+    question_end_index = tf.reshape(sep_token_indices, (input_ids_shape[0], 3, 2))[:, 0, 1][:, None]
     # bool attention mask with True in locations of global attention
-    attention_mask = tf.range(input_ids_shape[1])[tf.newaxis, :]
+    attention_mask = tf.expand_dims(tf.range(input_ids_shape[1]), axis=0)
     attention_mask = tf.tile(attention_mask, (input_ids_shape[0], 1))
     if before_sep_token is True:
         question_end_index = tf.tile(question_end_index, (1, input_ids_shape[1]))
-        attention_mask = tf.cast(attention_mask < question_end_index, tf.int32)
+        attention_mask = tf.cast(attention_mask < question_end_index, dtype=question_end_index.dtype)
     else:
         # last token is separation token and should not be counted and in the middle are two separation tokens
         question_end_index = tf.tile(question_end_index + 1, (1, input_ids_shape[1]))
         attention_mask = (
             tf.cast(
                 attention_mask > question_end_index,
-                tf.dtypes.int32,
+                dtype=question_end_index.dtype,
             )
-            * tf.cast(attention_mask < input_ids_shape[-1], tf.dtypes.int32)
+            * tf.cast(attention_mask < input_ids_shape[-1], dtype=question_end_index.dtype)
         )

     return attention_mask

@@ -730,6 +729,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         value_vectors = self.value(hidden_states)
         batch_size, seq_len, embed_dim = shape_list(hidden_states)

-        tf.debugging.assert_equal(
-            embed_dim,
-            self.embed_dim,
+        if tf.executing_eagerly():
+            tf.debugging.assert_equal(
+                embed_dim,
+                self.embed_dim,

@@ -737,7 +737,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         )

         # normalize query
-        query_vectors /= tf.math.sqrt(tf.convert_to_tensor(self.head_dim, dtype=tf.dtypes.float32))
+        query_vectors /= tf.math.sqrt(tf.cast(self.head_dim, dtype=query_vectors.dtype))
         query_vectors = tf.reshape(query_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))
         key_vectors = tf.reshape(key_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))

@@ -748,7 +748,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         # diagonal mask with zeros everywhere and -inf inplace of padding
         diagonal_mask = self._sliding_chunks_query_key_matmul(
-            tf.ones(shape_list(attention_mask), dtype=tf.float32),
+            tf.ones(shape_list(attention_mask)),
             attention_mask,
             self.one_sided_attn_window_size,
         )

@@ -756,6 +756,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         # pad local attention probs
         attn_scores += diagonal_mask

-        tf.debugging.assert_equal(
-            shape_list(attn_scores),
-            [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1],
+        if tf.executing_eagerly():
+            tf.debugging.assert_equal(
+                shape_list(attn_scores),
+                [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1],

@@ -803,16 +804,18 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         )
         attn_probs = tf.where(
             masked_index,
-            tf.zeros(shape_list(masked_index), dtype=tf.dtypes.float32),
+            tf.zeros(shape_list(masked_index), dtype=attn_probs.dtype),
             attn_probs,
         )

         if layer_head_mask is not None:
-            tf.debugging.assert_equal(
-                shape_list(layer_head_mask),
-                [self.num_heads],
-                message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
-            )
+            if tf.executing_eagerly():
+                tf.debugging.assert_equal(
+                    shape_list(layer_head_mask),
+                    [self.num_heads],
+                    message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
+                )
             attn_probs = tf.reshape(layer_head_mask, (1, 1, -1, 1)) * attn_probs

         # apply dropout

@@ -834,6 +837,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
             ),
         )

-        tf.debugging.assert_equal(
-            shape_list(attn_output),
-            [batch_size, seq_len, self.num_heads, self.head_dim],
+        if tf.executing_eagerly():
+            tf.debugging.assert_equal(
+                shape_list(attn_output),
+                [batch_size, seq_len, self.num_heads, self.head_dim],

@@ -877,7 +881,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
             )
             attn_probs = tf.where(
                 masked_global_attn_index,
-                tf.zeros(shape_list(masked_global_attn_index), dtype=tf.dtypes.float32),
+                tf.zeros(shape_list(masked_global_attn_index), dtype=attn_probs.dtype),
                 attn_probs,
             )

@@ -893,6 +897,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         """
         batch_size, seq_len, num_heads, head_dim = shape_list(query)

-        tf.debugging.assert_equal(
-            seq_len % (window_overlap * 2),
-            0,
+        if tf.executing_eagerly():
+            tf.debugging.assert_equal(
+                seq_len % (window_overlap * 2),
+                0,

@@ -919,10 +924,11 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         # bcxd: batch_size * num_heads x chunks x 2window_overlap x head_dim
         # bcyd: batch_size * num_heads x chunks x 2window_overlap x head_dim
         # bcxy: batch_size * num_heads x chunks x 2window_overlap x 2window_overlap
+        chunked_query = tf.cast(chunked_query, dtype=chunked_key.dtype)
         chunked_attention_scores = tf.einsum("bcxd,bcyd->bcxy", chunked_query, chunked_key)  # multiply

         # convert diagonals into columns
-        paddings = tf.convert_to_tensor([[0, 0], [0, 0], [0, 1], [0, 0]], dtype=tf.dtypes.int32)
+        paddings = tf.convert_to_tensor([[0, 0], [0, 0], [0, 1], [0, 0]])
         diagonal_chunked_attention_scores = self._pad_and_transpose_last_two_dims(chunked_attention_scores, paddings)

         # allocate space for the overall attention matrix where the chunks are combined. The last dimension

@@ -944,7 +950,10 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         # - copying the lower triangle
         diagonal_attn_scores_low_triang = tf.concat(
             [
-                tf.zeros((batch_size * num_heads, 1, window_overlap, window_overlap)),
+                tf.zeros(
+                    (batch_size * num_heads, 1, window_overlap, window_overlap),
+                    dtype=diagonal_chunked_attention_scores.dtype,
+                ),
                 diagonal_chunked_attention_scores[:, :, -(window_overlap + 1) : -1, window_overlap + 1 :],
             ],
             axis=1,

@@ -956,7 +965,10 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
                     shift=[1, window_overlap],
                     axis=[2, 3],
                 )[:, :, :window_overlap, :window_overlap],
-                tf.zeros((batch_size * num_heads, 1, window_overlap, window_overlap)),
+                tf.zeros(
+                    (batch_size * num_heads, 1, window_overlap, window_overlap),
+                    dtype=diagonal_chunked_attention_scores.dtype,
+                ),
             ],
             axis=1,
         )

@@ -1014,7 +1026,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         mask_4d = tf.tile(mask_2d[None, :, None, :], (shape_list(input_tensor)[0], 1, 1, 1))

         # inf tensor used for masking
-        inf_tensor = -float("inf") * tf.ones_like(input_tensor, dtype=tf.dtypes.float32)
+        inf_tensor = -float("inf") * tf.ones_like(input_tensor)

         # mask
         input_tensor = tf.where(tf.math.greater(mask_4d, 0), inf_tensor, input_tensor)

@@ -1029,6 +1041,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         batch_size, seq_len, num_heads, head_dim = shape_list(value)

-        tf.debugging.assert_equal(
-            seq_len % (window_overlap * 2),
-            0,
+        if tf.executing_eagerly():
+            tf.debugging.assert_equal(
+                seq_len % (window_overlap * 2),
+                0,

@@ -1065,7 +1078,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         )

         # pad seq_len with w at the beginning of the sequence and another window overlap at the end
-        paddings = tf.convert_to_tensor([[0, 0], [window_overlap, window_overlap], [0, 0]], dtype=tf.dtypes.int32)
+        paddings = tf.convert_to_tensor([[0, 0], [window_overlap, window_overlap], [0, 0]])
         padded_value = tf.pad(value, paddings, constant_values=-1)

         # chunk padded_value into chunks of size 3 window overlap and an overlap of size window overlap

@@ -1081,6 +1094,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
             (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim),
         )

-        tf.debugging.assert_equal(
-            shape_list(chunked_value),
-            [batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim],
+        if tf.executing_eagerly():
+            tf.debugging.assert_equal(
+                shape_list(chunked_value),
+                [batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim],

@@ -1158,6 +1172,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         # chunk with overlap
         chunked_hidden_states = tf.signal.frame(hidden_states, frame_size, frame_hop_size)

-        tf.debugging.assert_equal(
-            shape_list(chunked_hidden_states),
-            [batch_size, num_output_chunks, frame_size],
+        if tf.executing_eagerly():
+            tf.debugging.assert_equal(
+                shape_list(chunked_hidden_states),
+                [batch_size, num_output_chunks, frame_size],

@@ -1175,7 +1190,8 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
     def _get_global_attn_indices(is_index_global_attn):
         """ compute global attn indices required throughout forward pass """
         # helper variable
-        num_global_attn_indices = tf.reduce_sum(tf.cast(is_index_global_attn, dtype=tf.dtypes.int32), axis=1)
+        num_global_attn_indices = tf.math.count_nonzero(is_index_global_attn, axis=1)
+        num_global_attn_indices = tf.cast(num_global_attn_indices, dtype=tf.constant(1).dtype)

         # max number of global attn indices in batch
         max_num_global_attn_indices = tf.reduce_max(num_global_attn_indices)

@@ -1237,6 +1253,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
             shape_list(attn_probs_from_global_key_trans)[-2:]
         )
         mask = tf.ones(mask_shape) * -10000.0
+        mask = tf.cast(mask, dtype=attn_probs_from_global_key_trans.dtype)

         # scatter mask
         attn_probs_from_global_key_trans = tf.tensor_scatter_nd_update(

@@ -1323,7 +1340,9 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         global_value_vectors = self.value_global(hidden_states)

         # normalize
-        global_query_vectors_only_global /= tf.math.sqrt(tf.convert_to_tensor(self.head_dim, dtype=tf.dtypes.float32))
+        global_query_vectors_only_global /= tf.math.sqrt(
+            tf.cast(self.head_dim, dtype=global_query_vectors_only_global.dtype)
+        )
         global_query_vectors_only_global = self.reshape_and_transpose(global_query_vectors_only_global, batch_size)
         global_key_vectors = self.reshape_and_transpose(global_key_vectors, batch_size)
         global_value_vectors = self.reshape_and_transpose(global_value_vectors, batch_size)

@@ -1331,6 +1350,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         # compute attn scores
         global_attn_scores = tf.matmul(global_query_vectors_only_global, global_key_vectors, transpose_b=True)

-        tf.debugging.assert_equal(
-            shape_list(global_attn_scores),
-            [batch_size * self.num_heads, max_num_global_attn_indices, seq_len],
+        if tf.executing_eagerly():
+            tf.debugging.assert_equal(
+                shape_list(global_attn_scores),
+                [batch_size * self.num_heads, max_num_global_attn_indices, seq_len],

@@ -1346,6 +1366,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
             shape_list(global_attn_scores_trans)[-2:]
         )
         global_attn_mask = tf.ones(mask_shape) * -10000.0
+        global_attn_mask = tf.cast(global_attn_mask, dtype=global_attn_scores_trans.dtype)

         # scatter mask
         global_attn_scores_trans = tf.tensor_scatter_nd_update(

@@ -1368,6 +1389,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         # apply layer head maskin
         if layer_head_mask is not None:
-            tf.debugging.assert_equal(
-                shape_list(layer_head_mask),
-                [self.num_heads],
+            if tf.executing_eagerly():
+                tf.debugging.assert_equal(
+                    shape_list(layer_head_mask),
+                    [self.num_heads],

@@ -1386,6 +1408,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         # global attn output
         global_attn_output = tf.matmul(global_attn_probs, global_value_vectors)

-        tf.debugging.assert_equal(
-            shape_list(global_attn_output),
-            [batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim],
+        if tf.executing_eagerly():
+            tf.debugging.assert_equal(
+                shape_list(global_attn_output),
+                [batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim],

@@ -2230,6 +2253,7 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
                 logger.info("Initializing global attention on question tokens...")
                 # put global attention on all tokens until `config.sep_token_id` is reached
                 sep_token_indices = tf.where(inputs["input_ids"] == self.config.sep_token_id)
+                sep_token_indices = tf.cast(sep_token_indices, dtype=inputs["input_ids"].dtype)
                 inputs["global_attention_mask"] = _compute_global_attention_mask(
                     shape_list(inputs["input_ids"]), sep_token_indices
                 )
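Both modeling files also gained explicit casts for the additive -10000.0 masks before they are scattered into the attention scores: tf.tensor_scatter_nd_update expects the updates and the target tensor to share a dtype, and under mixed precision the scores are float16 while tf.ones defaults to float32. A condensed sketch of that pattern (helper name and arguments are illustrative, not from the repository):

import tensorflow as tf


def mask_scattered_positions(scores, scatter_indices, mask_shape):
    # Build the large negative mask in the default float32, then cast it to the
    # dtype of the scores (float16 under mixed_float16) before scattering.
    mask = tf.ones(mask_shape) * -10000.0
    mask = tf.cast(mask, dtype=scores.dtype)
    return tf.tensor_scatter_nd_update(scores, scatter_indices, mask)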
tests/test_modeling_tf_led.py  View file @ 19e737b9

@@ -362,10 +362,6 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase):
             self.assertEqual(model.config.output_hidden_states, True)
             check_encoder_attentions_output(outputs)

-    def test_mixed_precision(self):
-        # TODO JP: Make LED float16 compliant
-        pass
-
     def test_xla_mode(self):
         # TODO JP: Make LED XLA compliant
         pass
tests/test_modeling_tf_longformer.py  View file @ 19e737b9

@@ -343,10 +343,6 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
         # This test is too long (>30sec) and makes fail the CI
         pass

-    def test_mixed_precision(self):
-        # TODO JP: Make Longformer float16 compliant
-        pass
-
     def test_xla_mode(self):
         # TODO JP: Make Longformer XLA compliant
         pass
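With the local test_mixed_precision overrides deleted, LED and Longformer fall back to the shared mixed-precision test of the common tester instead of skipping it. The sketch below only illustrates the general shape of such a check, under an assumed helper name; it is not the TFModelTesterMixin implementation:

import tensorflow as tf


def run_forward_under_mixed_precision(model_class, config, dummy_inputs):
    # Switch Keras to mixed_float16, run one forward pass, and restore the
    # default policy afterwards so other tests are unaffected.
    tf.keras.mixed_precision.set_global_policy("mixed_float16")
    try:
        model = model_class(config)
        outputs = model(dummy_inputs)
    finally:
        tf.keras.mixed_precision.set_global_policy("float32")
    return outputs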