Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
09e6e71c
Commit
09e6e71c
authored
Mar 02, 2022
by
Zihan Wang
Browse files
lint
parent
32867f40
Changes
7
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
755 additions
and
631 deletions
+755
-631
official/projects/longformer/longformer_attention.py
official/projects/longformer/longformer_attention.py
+260
-171
official/projects/longformer/longformer_attention_test.py
official/projects/longformer/longformer_attention_test.py
+98
-60
official/projects/longformer/longformer_encoder.py
official/projects/longformer/longformer_encoder.py
+141
-151
official/projects/longformer/longformer_encoder_block.py
official/projects/longformer/longformer_encoder_block.py
+122
-136
official/projects/longformer/longformer_encoder_test.py
official/projects/longformer/longformer_encoder_test.py
+51
-36
official/projects/longformer/longformer_experiments.py
official/projects/longformer/longformer_experiments.py
+71
-65
official/projects/longformer/train.py
official/projects/longformer/train.py
+12
-12
No files found.
official/projects/longformer/longformer_attention.py
View file @
09e6e71c
This diff is collapsed.
Click to expand it.
official/projects/longformer/longformer_attention_test.py
View file @
09e6e71c
...
...
@@ -12,13 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for
the
attention
layer
."""
"""Tests for
official.nlp.projects.longformer.longformer_
attention."""
import
numpy
as
np
import
tensorflow
as
tf
from
tensorflow.python.distribute
import
combinations
from
tensorflow.python.keras
import
keras_parameterized
# pylint: disable=g-direct-tensorflow-import
from
tensorflow.python.keras
import
\
keras_parameterized
# pylint: disable=g-direct-tensorflow-import
from
official.projects.longformer
import
longformer_attention
from
official.modeling.tf_utils
import
get_shape_list
...
...
@@ -56,7 +57,7 @@ def _create_mock_attention_data(
if
include_mask
:
mask_shape
=
(
batch_size
,
num_heads
,
q_seq_length
,
total_seq_length
)
mask_data
=
np
.
random
.
randint
(
2
,
size
=
mask_shape
).
astype
(
"
float32
"
)
mask_data
=
np
.
random
.
randint
(
2
,
size
=
mask_shape
).
astype
(
'
float32
'
)
mask_data
=
dict
(
attention_mask
=
mask_data
)
data
.
update
(
mask_data
)
...
...
@@ -65,6 +66,12 @@ def _create_mock_attention_data(
@
keras_parameterized
.
run_all_keras_modes
class
LongformerAttentionTest
(
keras_parameterized
.
TestCase
):
def
setUp
(
self
):
super
(
LongformerAttentionTest
,
self
).
setUp
()
np
.
random
.
seed
(
0
)
tf
.
random
.
set_seed
(
0
)
def
_get_hidden_states
(
self
):
return
tf
.
convert_to_tensor
(
[
...
...
@@ -116,25 +123,33 @@ class LongformerAttentionTest(keras_parameterized.TestCase):
def
test_diagonalize
(
self
):
hidden_states
=
self
.
_get_hidden_states
()
hidden_states
=
tf
.
reshape
(
hidden_states
,
(
1
,
8
,
4
))
# set seq length = 8, hidden dim = 4
chunked_hidden_states
=
longformer_attention
.
LongformerAttention
.
_chunk
(
hidden_states
,
window_overlap
=
2
)
hidden_states
=
tf
.
reshape
(
hidden_states
,
(
1
,
8
,
4
))
# set seq length = 8, hidden dim = 4
chunked_hidden_states
=
longformer_attention
.
LongformerAttention
.
_chunk
(
hidden_states
,
window_overlap
=
2
)
window_overlap_size
=
get_shape_list
(
chunked_hidden_states
)[
2
]
self
.
assertTrue
(
window_overlap_size
==
4
)
padded_hidden_states
=
longformer_attention
.
LongformerAttention
.
_pad_and_diagonalize
(
chunked_hidden_states
)
padded_hidden_states
=
longformer_attention
.
LongformerAttention
.
_pad_and_diagonalize
(
chunked_hidden_states
)
self
.
assertTrue
(
get_shape_list
(
padded_hidden_states
)[
-
1
]
==
get_shape_list
(
chunked_hidden_states
)[
-
1
]
+
window_overlap_size
-
1
get_shape_list
(
padded_hidden_states
)[
-
1
]
==
get_shape_list
(
chunked_hidden_states
)[
-
1
]
+
window_overlap_size
-
1
)
# first row => [0.4983, 2.6918, -0.0071, 1.0492, 0.0000, 0.0000, 0.0000]
tf
.
debugging
.
assert_near
(
padded_hidden_states
[
0
,
0
,
0
,
:
4
],
chunked_hidden_states
[
0
,
0
,
0
],
rtol
=
1e-3
)
tf
.
debugging
.
assert_near
(
padded_hidden_states
[
0
,
0
,
0
,
4
:],
tf
.
zeros
((
3
,),
dtype
=
tf
.
dtypes
.
float32
),
rtol
=
1e-3
)
tf
.
debugging
.
assert_near
(
padded_hidden_states
[
0
,
0
,
0
,
:
4
],
chunked_hidden_states
[
0
,
0
,
0
],
rtol
=
1e-3
)
tf
.
debugging
.
assert_near
(
padded_hidden_states
[
0
,
0
,
0
,
4
:],
tf
.
zeros
((
3
,),
dtype
=
tf
.
dtypes
.
float32
),
rtol
=
1e-3
)
# last row => [0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629]
tf
.
debugging
.
assert_near
(
padded_hidden_states
[
0
,
0
,
-
1
,
3
:],
chunked_hidden_states
[
0
,
0
,
-
1
],
rtol
=
1e-3
)
tf
.
debugging
.
assert_near
(
padded_hidden_states
[
0
,
0
,
-
1
,
3
:],
chunked_hidden_states
[
0
,
0
,
-
1
],
rtol
=
1e-3
)
tf
.
debugging
.
assert_near
(
padded_hidden_states
[
0
,
0
,
-
1
,
:
3
],
tf
.
zeros
((
3
,),
dtype
=
tf
.
dtypes
.
float32
),
rtol
=
1e-3
padded_hidden_states
[
0
,
0
,
-
1
,
:
3
],
tf
.
zeros
((
3
,),
dtype
=
tf
.
dtypes
.
float32
),
rtol
=
1e-3
)
def
test_pad_and_transpose_last_two_dims
(
self
):
...
...
@@ -142,16 +157,21 @@ class LongformerAttentionTest(keras_parameterized.TestCase):
self
.
assertTrue
(
get_shape_list
(
hidden_states
),
[
1
,
8
,
4
])
# pad along seq length dim
paddings
=
tf
.
constant
([[
0
,
0
],
[
0
,
0
],
[
0
,
1
],
[
0
,
0
]],
dtype
=
tf
.
dtypes
.
int32
)
paddings
=
tf
.
constant
([[
0
,
0
],
[
0
,
0
],
[
0
,
1
],
[
0
,
0
]],
dtype
=
tf
.
dtypes
.
int32
)
hidden_states
=
longformer_attention
.
LongformerAttention
.
_chunk
(
hidden_states
,
window_overlap
=
2
)
padded_hidden_states
=
longformer_attention
.
LongformerAttention
.
_pad_and_transpose_last_two_dims
(
hidden_states
,
paddings
)
hidden_states
=
longformer_attention
.
LongformerAttention
.
_chunk
(
hidden_states
,
window_overlap
=
2
)
padded_hidden_states
=
longformer_attention
.
LongformerAttention
.
_pad_and_transpose_last_two_dims
(
hidden_states
,
paddings
)
self
.
assertTrue
(
get_shape_list
(
padded_hidden_states
)
==
[
1
,
1
,
8
,
5
])
expected_added_dim
=
tf
.
zeros
((
5
,),
dtype
=
tf
.
dtypes
.
float32
)
tf
.
debugging
.
assert_near
(
expected_added_dim
,
padded_hidden_states
[
0
,
0
,
-
1
,
:],
rtol
=
1e-6
)
tf
.
debugging
.
assert_near
(
expected_added_dim
,
padded_hidden_states
[
0
,
0
,
-
1
,
:],
rtol
=
1e-6
)
tf
.
debugging
.
assert_near
(
hidden_states
[
0
,
0
,
-
1
,
:],
tf
.
reshape
(
padded_hidden_states
,
(
1
,
-
1
))[
0
,
24
:
32
],
rtol
=
1e-6
hidden_states
[
0
,
0
,
-
1
,
:],
tf
.
reshape
(
padded_hidden_states
,
(
1
,
-
1
))[
0
,
24
:
32
],
rtol
=
1e-6
)
def
test_mask_invalid_locations
(
self
):
...
...
@@ -159,39 +179,55 @@ class LongformerAttentionTest(keras_parameterized.TestCase):
batch_size
=
1
seq_length
=
8
hidden_size
=
4
hidden_states
=
tf
.
reshape
(
hidden_states
,
(
batch_size
,
seq_length
,
hidden_size
))
hidden_states
=
longformer_attention
.
LongformerAttention
.
_chunk
(
hidden_states
,
window_overlap
=
2
)
hid_states_1
=
longformer_attention
.
LongformerAttention
.
_mask_invalid_locations
(
hidden_states
,
1
)
hid_states_2
=
longformer_attention
.
LongformerAttention
.
_mask_invalid_locations
(
hidden_states
,
2
)
hid_states_3
=
longformer_attention
.
LongformerAttention
.
_mask_invalid_locations
(
hidden_states
[:,
:,
:,
:
3
],
2
)
hid_states_4
=
longformer_attention
.
LongformerAttention
.
_mask_invalid_locations
(
hidden_states
[:,
:,
2
:,
:],
2
)
self
.
assertTrue
(
tf
.
math
.
reduce_sum
(
tf
.
cast
(
tf
.
math
.
is_inf
(
hid_states_1
),
tf
.
dtypes
.
int32
))
==
8
)
self
.
assertTrue
(
tf
.
math
.
reduce_sum
(
tf
.
cast
(
tf
.
math
.
is_inf
(
hid_states_2
),
tf
.
dtypes
.
int32
))
==
24
)
self
.
assertTrue
(
tf
.
math
.
reduce_sum
(
tf
.
cast
(
tf
.
math
.
is_inf
(
hid_states_3
),
tf
.
dtypes
.
int32
))
==
24
)
self
.
assertTrue
(
tf
.
math
.
reduce_sum
(
tf
.
cast
(
tf
.
math
.
is_inf
(
hid_states_4
),
tf
.
dtypes
.
int32
))
==
12
)
hidden_states
=
tf
.
reshape
(
hidden_states
,
(
batch_size
,
seq_length
,
hidden_size
))
hidden_states
=
longformer_attention
.
LongformerAttention
.
_chunk
(
hidden_states
,
window_overlap
=
2
)
hid_states_1
=
longformer_attention
.
LongformerAttention
.
_mask_invalid_locations
(
hidden_states
,
1
)
hid_states_2
=
longformer_attention
.
LongformerAttention
.
_mask_invalid_locations
(
hidden_states
,
2
)
hid_states_3
=
longformer_attention
.
LongformerAttention
.
_mask_invalid_locations
(
hidden_states
[:,
:,
:,
:
3
],
2
)
hid_states_4
=
longformer_attention
.
LongformerAttention
.
_mask_invalid_locations
(
hidden_states
[:,
:,
2
:,
:],
2
)
self
.
assertTrue
(
tf
.
math
.
reduce_sum
(
tf
.
cast
(
tf
.
math
.
is_inf
(
hid_states_1
),
tf
.
dtypes
.
int32
))
==
8
)
self
.
assertTrue
(
tf
.
math
.
reduce_sum
(
tf
.
cast
(
tf
.
math
.
is_inf
(
hid_states_2
),
tf
.
dtypes
.
int32
))
==
24
)
self
.
assertTrue
(
tf
.
math
.
reduce_sum
(
tf
.
cast
(
tf
.
math
.
is_inf
(
hid_states_3
),
tf
.
dtypes
.
int32
))
==
24
)
self
.
assertTrue
(
tf
.
math
.
reduce_sum
(
tf
.
cast
(
tf
.
math
.
is_inf
(
hid_states_4
),
tf
.
dtypes
.
int32
))
==
12
)
def
test_chunk
(
self
):
hidden_states
=
self
.
_get_hidden_states
()
batch_size
=
1
seq_length
=
8
hidden_size
=
4
hidden_states
=
tf
.
reshape
(
hidden_states
,
(
batch_size
,
seq_length
,
hidden_size
))
hidden_states
=
tf
.
reshape
(
hidden_states
,
(
batch_size
,
seq_length
,
hidden_size
))
chunked_hidden_states
=
longformer_attention
.
LongformerAttention
.
_chunk
(
hidden_states
,
window_overlap
=
2
)
chunked_hidden_states
=
longformer_attention
.
LongformerAttention
.
_chunk
(
hidden_states
,
window_overlap
=
2
)
# expected slices across chunk and seq length dim
expected_slice_along_seq_length
=
tf
.
convert_to_tensor
([
0.4983
,
-
0.7584
,
-
1.6944
],
dtype
=
tf
.
dtypes
.
float32
)
expected_slice_along_chunk
=
tf
.
convert_to_tensor
([
0.4983
,
-
1.8348
,
-
0.7584
,
2.0514
],
dtype
=
tf
.
dtypes
.
float32
)
expected_slice_along_seq_length
=
tf
.
convert_to_tensor
(
[
0.4983
,
-
0.7584
,
-
1.6944
],
dtype
=
tf
.
dtypes
.
float32
)
expected_slice_along_chunk
=
tf
.
convert_to_tensor
(
[
0.4983
,
-
1.8348
,
-
0.7584
,
2.0514
],
dtype
=
tf
.
dtypes
.
float32
)
self
.
assertTrue
(
get_shape_list
(
chunked_hidden_states
)
==
[
1
,
3
,
4
,
4
])
tf
.
debugging
.
assert_near
(
chunked_hidden_states
[
0
,
:,
0
,
0
],
expected_slice_along_seq_length
,
rtol
=
1e-3
)
tf
.
debugging
.
assert_near
(
chunked_hidden_states
[
0
,
0
,
:,
0
],
expected_slice_along_chunk
,
rtol
=
1e-3
)
tf
.
debugging
.
assert_near
(
chunked_hidden_states
[
0
,
:,
0
,
0
],
expected_slice_along_seq_length
,
rtol
=
1e-3
)
tf
.
debugging
.
assert_near
(
chunked_hidden_states
[
0
,
0
,
:,
0
],
expected_slice_along_chunk
,
rtol
=
1e-3
)
def
test_layer_local_attn
(
self
):
hidden_states
=
self
.
_get_hidden_states
()
batch_size
,
seq_length
,
hidden_size
=
hidden_states
.
shape
batch_size
,
seq_length
,
_
=
hidden_states
.
shape
layer
=
longformer_attention
.
LongformerAttention
(
num_heads
=
2
,
key_dim
=
4
,
...
...
@@ -203,14 +239,15 @@ class LongformerAttentionTest(keras_parameterized.TestCase):
attention_mask
=
tf
.
zeros
((
batch_size
,
seq_length
),
dtype
=
tf
.
dtypes
.
float32
)
is_index_global_attn
=
tf
.
math
.
greater
(
attention_mask
,
1
)
is_global_attn
=
tf
.
math
.
reduce_any
(
is_index_global_attn
)
attention_mask
=
tf
.
where
(
tf
.
range
(
4
)[
None
,
:,
None
,
None
]
>
1
,
-
10000.0
,
attention_mask
[:,
:,
None
,
None
])
attention_mask
=
tf
.
where
(
tf
.
range
(
4
)[
None
,
:,
None
,
None
]
>
1
,
-
10000.0
,
attention_mask
[:,
:,
None
,
None
])
is_index_masked
=
tf
.
math
.
less
(
attention_mask
[:,
:,
0
,
0
],
0
)
output_hidden_states
=
layer
(
hidden_states
=
hidden_states
,
attention_mask
=
attention_mask
,
is_index_masked
=
is_index_masked
,
is_index_global_attn
=
is_index_global_attn
,
is_global_attn
=
is_global_attn
,
is_index_masked
=
is_index_masked
,
is_index_global_attn
=
is_index_global_attn
,
)[
0
]
self
.
assertTrue
(
output_hidden_states
.
shape
,
(
1
,
4
,
8
))
...
...
@@ -226,32 +263,33 @@ class LongformerAttentionTest(keras_parameterized.TestCase):
)
hidden_states
=
self
.
_get_hidden_states
()
hidden_states
=
tf
.
concat
([
self
.
_get_hidden_states
(),
self
.
_get_hidden_states
()
-
0.5
],
axis
=
0
)
hidden_states
=
tf
.
concat
(
[
self
.
_get_hidden_states
(),
self
.
_get_hidden_states
()
-
0.5
],
axis
=
0
)
batch_size
,
seq_length
,
hidden_size
=
hidden_states
.
shape
# create attn mask
attention_mask_1
=
tf
.
zeros
((
1
,
1
,
1
,
seq_length
),
dtype
=
tf
.
dtypes
.
float32
)
attention_mask_2
=
tf
.
zeros
((
1
,
1
,
1
,
seq_length
),
dtype
=
tf
.
dtypes
.
float32
)
attention_mask_1
=
tf
.
where
(
tf
.
range
(
4
)[
None
,
:,
None
,
None
]
==
0
,
10000.0
,
attention_mask_1
)
attention_mask_1
=
tf
.
where
(
tf
.
range
(
4
)[
None
,
:,
None
,
None
]
>
2
,
-
10000.0
,
attention_mask_1
)
attention_mask_2
=
tf
.
where
(
tf
.
range
(
4
)[
None
,
:,
None
,
None
]
==
0
,
10000.0
,
attention_mask_2
)
attention_mask_1
=
tf
.
where
(
tf
.
range
(
4
)[
None
,
:,
None
,
None
]
==
0
,
10000.0
,
attention_mask_1
)
attention_mask_1
=
tf
.
where
(
tf
.
range
(
4
)[
None
,
:,
None
,
None
]
>
2
,
-
10000.0
,
attention_mask_1
)
attention_mask_2
=
tf
.
where
(
tf
.
range
(
4
)[
None
,
:,
None
,
None
]
==
0
,
10000.0
,
attention_mask_2
)
attention_mask
=
tf
.
concat
([
attention_mask_1
,
attention_mask_2
],
axis
=
0
)
is_index_masked
=
tf
.
math
.
less
(
attention_mask
[:,
:,
0
,
0
],
0
)
is_index_global_attn
=
tf
.
math
.
greater
(
attention_mask
[:,
:,
0
,
0
],
0
)
is_global_attn
=
tf
.
math
.
reduce_any
(
is_index_global_attn
)
output_hidden_states
=
layer
(
hidden_states
=
hidden_states
,
attention_mask
=-
tf
.
math
.
abs
(
attention_mask
),
is_index_masked
=
is_index_masked
,
is_index_global_attn
=
is_index_global_attn
,
is_global_attn
=
is_global_attn
,
is_index_masked
=
is_index_masked
,
is_index_global_attn
=
is_index_global_attn
,
)[
0
]
self
.
assertTrue
(
output_hidden_states
.
shape
,
(
2
,
4
,
8
))
if
__name__
==
"__main__"
:
np
.
random
.
seed
(
0
)
tf
.
random
.
set_seed
(
0
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
official/projects/longformer/longformer_encoder.py
View file @
09e6e71c
...
...
@@ -23,29 +23,16 @@ from absl import logging
import
tensorflow
as
tf
from
official.nlp.modeling
import
layers
from
official.projects.longformer.longformer_encoder_block
import
LongformerEncoderBlock
from
official.projects.longformer.longformer_encoder_block
import
\
LongformerEncoderBlock
from
official.modeling.tf_utils
import
get_shape_list
_Initializer
=
Union
[
str
,
tf
.
keras
.
initializers
.
Initializer
]
_approx_gelu
=
lambda
x
:
tf
.
keras
.
activations
.
gelu
(
x
,
approximate
=
True
)
# Transferred from huggingface.longformer.TFLongformerMainLayer & TFLongformerEncoder
class
LongformerEncoder
(
tf
.
keras
.
layers
.
Layer
):
"""Bi-directional Transformer-based encoder network.
This network implements a bi-directional Transformer-based encoder as
described in "BERT: Pre-training of Deep Bidirectional Transformers for
Language Understanding" (https://arxiv.org/abs/1810.04805). It includes the
embedding lookups and transformer layers, but not the masked language model
or classification task networks.
The default values for this object are taken from the BERT-Base implementation
in "BERT: Pre-training of Deep Bidirectional Transformers for Language
Understanding".
"""LongformerEncoder
Args:
vocab_size: The size of the token vocabulary.
attention_window: list of ints representing the window size for each layer.
...
...
@@ -165,15 +152,14 @@ class LongformerEncoder(tf.keras.layers.Layer):
num_attention_heads
=
num_attention_heads
,
inner_dim
=
inner_dim
,
inner_activation
=
inner_activation
,
# Longformer, instead of passing a list of attention_window, pass a value to sub-block
attention_window
=
attention_window
if
isinstance
(
attention_window
,
int
)
else
attention_window
[
i
],
attention_window
=
attention_window
[
i
],
layer_id
=
i
,
output_dropout
=
output_dropout
,
attention_dropout
=
attention_dropout
,
norm_first
=
norm_first
,
output_range
=
output_range
if
i
==
num_layers
-
1
else
None
,
kernel_initializer
=
initializer
,
name
=
'transformer/layer_
%d'
%
i
)
name
=
f
'transformer/layer_
{
i
}
'
)
self
.
_transformer_layers
.
append
(
layer
)
self
.
_pooler_layer
=
tf
.
keras
.
layers
.
Dense
(
...
...
@@ -198,7 +184,6 @@ class LongformerEncoder(tf.keras.layers.Layer):
'embedding_width'
:
embedding_width
,
'embedding_layer'
:
embedding_layer
,
'norm_first'
:
norm_first
,
# Longformer
'attention_window'
:
attention_window
,
'global_attention_size'
:
global_attention_size
,
'pad_token_id'
:
pad_token_id
,
...
...
@@ -214,9 +199,10 @@ class LongformerEncoder(tf.keras.layers.Layer):
word_ids
=
inputs
.
get
(
'input_word_ids'
)
# input_ids
mask
=
inputs
.
get
(
'input_mask'
)
# attention_mask
type_ids
=
inputs
.
get
(
'input_type_ids'
)
# token_type_ids
word_embeddings
=
inputs
.
get
(
'input_word_embeddings'
,
None
)
# input_embeds
word_embeddings
=
inputs
.
get
(
'input_word_embeddings'
,
None
)
# input_embeds
else
:
raise
ValueError
(
'Unexpected inputs type to
%s.'
%
self
.
__class__
)
raise
ValueError
(
f
'Unexpected inputs type to
{
self
.
__class__
}
.'
)
(
padding_len
,
...
...
@@ -247,34 +233,35 @@ class LongformerEncoder(tf.keras.layers.Layer):
batch_size
,
seq_len
=
get_shape_list
(
mask
)
# create masks with fixed len global_attention_size
mask
=
tf
.
transpose
(
tf
.
concat
(
values
=
[
tf
.
ones
((
self
.
_global_attention_size
,
batch_size
),
tf
.
int32
)
*
2
,
mask
=
tf
.
transpose
(
tf
.
concat
(
values
=
[
tf
.
ones
((
self
.
_global_attention_size
,
batch_size
),
tf
.
int32
)
*
2
,
tf
.
transpose
(
mask
)[
self
.
_global_attention_size
:]],
axis
=
0
))
is_index_masked
=
tf
.
math
.
less
(
mask
,
1
)
is_index_global_attn
=
tf
.
transpose
(
tf
.
concat
(
values
=
[
tf
.
ones
((
self
.
_global_attention_size
,
batch_size
),
tf
.
bool
),
tf
.
zeros
((
seq_len
-
self
.
_global_attention_size
,
tf
.
ones
((
self
.
_global_attention_size
,
batch_size
),
tf
.
bool
),
tf
.
zeros
((
seq_len
-
self
.
_global_attention_size
,
batch_size
),
tf
.
bool
)
],
axis
=
0
))
is_global_attn
=
self
.
_global_attention_size
>
0
# Longformer
attention_mask
=
mask
extended_attention_mask
=
tf
.
reshape
(
attention_mask
,
(
tf
.
shape
(
mask
)[
0
],
tf
.
shape
(
mask
)[
1
],
1
,
1
)
)
attention_mask
=
tf
.
cast
(
tf
.
math
.
abs
(
1
-
extended_attention_mask
),
tf
.
dtypes
.
float32
)
*
-
10000.0
attention_mask
=
tf
.
cast
(
tf
.
math
.
abs
(
1
-
extended_attention_mask
),
tf
.
dtypes
.
float32
)
*
-
10000.0
encoder_outputs
=
[]
x
=
embeddings
# TFLongformerEncoder
for
i
,
layer
in
enumerate
(
self
.
_transformer_layers
)
:
for
layer
in
self
.
_transformer_layers
:
x
=
layer
([
x
,
attention_mask
,
is_index_masked
,
is_index_global_attn
,
is_global_attn
])
is_index_global_attn
])
encoder_outputs
.
append
(
x
)
last_encoder_output
=
encoder_outputs
[
-
1
]
...
...
@@ -328,19 +315,19 @@ class LongformerEncoder(tf.keras.layers.Layer):
word_embeddings
,
pad_token_id
,
):
"""A helper function to pad tokens and mask to work with implementation of Longformer selfattention."""
# padding
attention_window
=
(
self
.
_attention_window
if
isinstance
(
self
.
_attention_window
,
int
)
else
max
(
self
.
_attention_window
)
)
attention_window
=
max
(
self
.
_attention_window
)
assert
attention_window
%
2
==
0
,
f
"`attention_window` should be an even value. Given
{
attention_window
}
"
assert
attention_window
%
2
==
0
,
\
f
'`attention_window` should be an even value. Given
{
attention_window
}
'
input_shape
=
get_shape_list
(
word_ids
)
if
word_ids
is
not
None
else
get_shape_list
(
word_embeddings
)
input_shape
=
get_shape_list
(
word_ids
)
if
word_ids
is
not
None
else
get_shape_list
(
word_embeddings
)
batch_size
,
seq_len
=
input_shape
[:
2
]
if
seq_len
is
not
None
:
padding_len
=
(
attention_window
-
seq_len
%
attention_window
)
%
attention_window
padding_len
=
(
attention_window
-
seq_len
%
attention_window
)
%
attention_window
else
:
padding_len
=
0
...
...
@@ -355,10 +342,13 @@ class LongformerEncoder(tf.keras.layers.Layer):
word_embeddings_padding
=
self
.
_embedding_layer
(
word_ids_padding
)
return
tf
.
concat
([
word_embeddings
,
word_embeddings_padding
],
axis
=-
2
)
word_embeddings
=
tf
.
cond
(
tf
.
math
.
greater
(
padding_len
,
0
),
pad_embeddings
,
lambda
:
word_embeddings
)
word_embeddings
=
tf
.
cond
(
tf
.
math
.
greater
(
padding_len
,
0
),
pad_embeddings
,
lambda
:
word_embeddings
)
mask
=
tf
.
pad
(
mask
,
paddings
,
constant_values
=
False
)
# no attention on the padding tokens
token_type_ids
=
tf
.
pad
(
type_ids
,
paddings
,
constant_values
=
0
)
# pad with token_type_id = 0
mask
=
tf
.
pad
(
mask
,
paddings
,
constant_values
=
False
)
# no attention on the padding tokens
token_type_ids
=
tf
.
pad
(
type_ids
,
paddings
,
constant_values
=
0
)
# pad with token_type_id = 0
return
(
padding_len
,
...
...
official/projects/longformer/longformer_encoder_block.py
View file @
09e6e71c
...
...
@@ -17,49 +17,13 @@ Longformer attention layer. Modified From huggingface/transformers
"""
import
tensorflow
as
tf
from
official.projects.longformer.longformer_attention
import
LongformerAttention
from
official.projects.longformer.longformer_attention
import
\
LongformerAttention
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
"Text"
)
class
LongformerEncoderBlock
(
tf
.
keras
.
layers
.
Layer
):
"""TransformerEncoderBlock layer.
This layer implements the Transformer Encoder from
"Attention Is All You Need". (https://arxiv.org/abs/1706.03762),
which combines a `tf.keras.layers.MultiHeadAttention` layer with a
two-layer feedforward network.
References:
[Attention Is All You Need](https://arxiv.org/abs/1706.03762)
[BERT: Pre-training of Deep Bidirectional Transformers for Language
Understanding](https://arxiv.org/abs/1810.04805)
"""
def
__init__
(
self
,
global_attention_size
,
num_attention_heads
,
inner_dim
,
inner_activation
,
# Longformer
attention_window
,
layer_id
=
0
,
output_range
=
None
,
kernel_initializer
=
"glorot_uniform"
,
bias_initializer
=
"zeros"
,
kernel_regularizer
=
None
,
bias_regularizer
=
None
,
activity_regularizer
=
None
,
kernel_constraint
=
None
,
bias_constraint
=
None
,
use_bias
=
True
,
norm_first
=
False
,
norm_epsilon
=
1e-12
,
output_dropout
=
0.0
,
attention_dropout
=
0.0
,
inner_dropout
=
0.0
,
attention_initializer
=
None
,
attention_axes
=
None
,
**
kwargs
):
"""Initializes `TransformerEncoderBlock`.
"""LongformerEncoderBlock.
Args:
num_attention_heads: Number of attention heads.
...
...
@@ -94,6 +58,32 @@ class LongformerEncoderBlock(tf.keras.layers.Layer):
attention over all axes, but batch, heads, and features.
**kwargs: keyword arguments/
"""
def
__init__
(
self
,
global_attention_size
,
num_attention_heads
,
inner_dim
,
inner_activation
,
# Longformer
attention_window
,
layer_id
=
0
,
output_range
=
None
,
kernel_initializer
=
"glorot_uniform"
,
bias_initializer
=
"zeros"
,
kernel_regularizer
=
None
,
bias_regularizer
=
None
,
activity_regularizer
=
None
,
kernel_constraint
=
None
,
bias_constraint
=
None
,
use_bias
=
True
,
norm_first
=
False
,
norm_epsilon
=
1e-12
,
output_dropout
=
0.0
,
attention_dropout
=
0.0
,
inner_dropout
=
0.0
,
attention_initializer
=
None
,
attention_axes
=
None
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
self
.
global_attention_size
=
global_attention_size
...
...
@@ -133,16 +123,16 @@ class LongformerEncoderBlock(tf.keras.layers.Layer):
input_tensor_shape
=
tf
.
TensorShape
(
input_shape
[
0
])
else
:
raise
ValueError
(
"The type of input shape argument is not supported, got:
%s"
%
type
(
input_shape
))
f
"The type of input shape argument is not supported, got:
"
f
"
{
type
(
input_shape
)
}
"
)
einsum_equation
=
"abc,cd->abd"
if
len
(
input_tensor_shape
.
as_list
())
>
3
:
einsum_equation
=
"...bc,cd->...bd"
hidden_size
=
input_tensor_shape
[
-
1
]
if
hidden_size
%
self
.
_num_heads
!=
0
:
raise
ValueError
(
"The input size (
%d
) is not a multiple of the number of attention "
"heads (
%d)"
%
(
hidden_size
,
self
.
_num_heads
)
)
f
"The input size (
{
hidden_size
}
) is not a multiple of the number of attention "
f
"heads (
{
self
.
_num_heads
}
)"
)
self
.
_attention_head_size
=
int
(
hidden_size
//
self
.
_num_heads
)
common_kwargs
=
dict
(
bias_initializer
=
self
.
_bias_initializer
,
...
...
@@ -216,7 +206,7 @@ class LongformerEncoderBlock(tf.keras.layers.Layer):
epsilon
=
self
.
_norm_epsilon
,
dtype
=
tf
.
float32
)
super
(
LongformerEncoderBlock
,
self
).
build
(
input_shape
)
super
().
build
(
input_shape
)
def
get_config
(
self
):
config
=
{
...
...
@@ -258,7 +248,7 @@ class LongformerEncoderBlock(tf.keras.layers.Layer):
tf
.
keras
.
initializers
.
serialize
(
self
.
_attention_initializer
),
"attention_axes"
:
self
.
_attention_axes
,
}
base_config
=
super
(
LongformerEncoderBlock
,
self
).
get_config
()
base_config
=
super
().
get_config
()
return
dict
(
list
(
base_config
.
items
())
+
list
(
config
.
items
()))
def
call
(
self
,
inputs
):
...
...
@@ -277,26 +267,23 @@ class LongformerEncoderBlock(tf.keras.layers.Layer):
An output tensor with the same dimensions as input/query tensor.
"""
if
isinstance
(
inputs
,
(
list
,
tuple
)):
if
len
(
inputs
)
==
5
:
if
len
(
inputs
)
==
4
:
(
input_tensor
,
attention_mask
,
is_index_masked
,
is_index_global_attn
,
is_global_attn
)
=
inputs
key_value
=
None
elif
len
(
inputs
)
==
6
:
elif
len
(
inputs
)
==
5
:
assert
False
# No key_value
else
:
raise
ValueError
(
"Unexpected inputs to %s with length at %d"
%
(
self
.
__class__
,
len
(
inputs
)))
raise
ValueError
(
f
"Unexpected inputs to
{
self
.
__class__
}
with length at
{
len
(
inputs
)
}
"
)
else
:
input_tensor
=
inputs
attention_mask
=
None
is_index_masked
=
None
is_index_global_attn
=
None
is_global_attn
=
None
key_value
=
None
if
self
.
_output_range
:
...
...
@@ -329,7 +316,6 @@ class LongformerEncoderBlock(tf.keras.layers.Layer):
attention_mask
=
attention_mask
,
is_index_masked
=
is_index_masked
,
is_index_global_attn
=
is_index_global_attn
,
is_global_attn
=
is_global_attn
)
# TFLongformerAttention.TFLongformerSelfOutput.* - {.dense}
attention_output
=
self
.
_attention_dropout
(
attention_output
)
...
...
official/projects/longformer/longformer_encoder_test.py
View file @
09e6e71c
...
...
@@ -12,44 +12,55 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.nlp.projects.
bigbird.
encoder."""
"""Tests for official.nlp.projects.
longformer.longformer_
encoder."""
import
numpy
as
np
import
tensorflow
as
tf
from
absl.testing
import
parameterized
from
tensorflow.python.keras
import
keras_parameterized
# pylint: disable=g-direct-tensorflow-import
from
tensorflow.python.keras
import
\
keras_parameterized
# pylint: disable=g-direct-tensorflow-import
from
tensorflow.python.distribute
import
combinations
from
official.projects.longformer.longformer_encoder
import
LongformerEncoder
@
keras_parameterized
.
run_all_keras_modes
class
LongformerEncoderTest
(
keras_parameterized
.
TestCase
):
def
setUp
(
self
):
super
(
LongformerEncoderTest
,
self
).
setUp
()
np
.
random
.
seed
(
0
)
tf
.
random
.
set_seed
(
0
)
@
combinations
.
generate
(
combinations
.
combine
(
attention_window
=
[
32
,
128
],
global_attention_size
=
[
0
,
1
,
2
]))
def
test_encoder
(
self
,
attention_window
,
global_attention_size
):
sequence_length
=
128
batch_size
=
2
vocab_size
=
1024
hidden_size
=
256
hidden_size
=
256
network
=
LongformerEncoder
(
global_attention_size
=
global_attention_size
,
vocab_size
=
vocab_size
,
attention_window
=
attention_window
,
attention_window
=
[
attention_window
]
,
hidden_size
=
hidden_size
,
num_layers
=
1
,
num_attention_heads
=
4
,
max_sequence_length
=
512
)
word_id_data
=
np
.
random
.
randint
(
vocab_size
,
size
=
(
batch_size
,
sequence_length
),
dtype
=
np
.
int32
)
mask_data
=
np
.
random
.
randint
(
2
,
size
=
(
batch_size
,
sequence_length
),
dtype
=
np
.
int32
)
type_id_data
=
np
.
random
.
randint
(
2
,
size
=
(
batch_size
,
sequence_length
),
dtype
=
np
.
int32
)
word_id_data
=
np
.
random
.
randint
(
vocab_size
,
size
=
(
batch_size
,
sequence_length
),
dtype
=
np
.
int32
)
mask_data
=
np
.
random
.
randint
(
2
,
size
=
(
batch_size
,
sequence_length
),
dtype
=
np
.
int32
)
type_id_data
=
np
.
random
.
randint
(
2
,
size
=
(
batch_size
,
sequence_length
),
dtype
=
np
.
int32
)
inputs
=
{
'input_word_ids'
:
word_id_data
,
'input_mask'
:
mask_data
,
'input_type_ids'
:
type_id_data
,
}
outputs
=
network
(
inputs
)
self
.
assertEqual
(
outputs
[
"
sequence_output
"
].
shape
,
self
.
assertEqual
(
outputs
[
'
sequence_output
'
].
shape
,
(
batch_size
,
sequence_length
,
hidden_size
))
@
combinations
.
generate
(
combinations
.
combine
(
...
...
@@ -62,24 +73,28 @@ class LongformerEncoderTest(keras_parameterized.TestCase):
network
=
LongformerEncoder
(
global_attention_size
=
global_attention_size
,
vocab_size
=
vocab_size
,
attention_window
=
32
,
attention_window
=
[
32
]
,
hidden_size
=
hidden_size
,
num_layers
=
1
,
num_attention_heads
=
4
,
max_sequence_length
=
512
,
norm_first
=
norm_first
)
word_id_data
=
np
.
random
.
randint
(
vocab_size
,
size
=
(
batch_size
,
sequence_length
),
dtype
=
np
.
int32
)
mask_data
=
np
.
random
.
randint
(
2
,
size
=
(
batch_size
,
sequence_length
),
dtype
=
np
.
int32
)
type_id_data
=
np
.
random
.
randint
(
2
,
size
=
(
batch_size
,
sequence_length
),
dtype
=
np
.
int32
)
word_id_data
=
np
.
random
.
randint
(
vocab_size
,
size
=
(
batch_size
,
sequence_length
),
dtype
=
np
.
int32
)
mask_data
=
np
.
random
.
randint
(
2
,
size
=
(
batch_size
,
sequence_length
),
dtype
=
np
.
int32
)
type_id_data
=
np
.
random
.
randint
(
2
,
size
=
(
batch_size
,
sequence_length
),
dtype
=
np
.
int32
)
inputs
=
{
'input_word_ids'
:
word_id_data
,
'input_mask'
:
mask_data
,
'input_type_ids'
:
type_id_data
,
}
outputs
=
network
(
inputs
)
self
.
assertEqual
(
outputs
[
"
sequence_output
"
].
shape
,
self
.
assertEqual
(
outputs
[
'
sequence_output
'
].
shape
,
(
batch_size
,
sequence_length
,
hidden_size
))
if
__name__
==
"
__main__
"
:
if
__name__
==
'
__main__
'
:
tf
.
test
.
main
()
official/projects/longformer/longformer_experiments.py
View file @
09e6e71c
...
...
@@ -34,22 +34,24 @@ AdamWeightDecay = optimization.AdamWeightDecayConfig
PolynomialLr
=
optimization
.
PolynomialLrConfig
PolynomialWarmupConfig
=
optimization
.
PolynomialWarmupConfig
@
dataclasses
.
dataclass
class
LongformerOptimizationConfig
(
optimization
.
OptimizationConfig
):
optimizer
:
optimization
.
OptimizerConfig
=
optimization
.
OptimizerConfig
(
type
=
"
adamw
"
,
type
=
'
adamw
'
,
adamw
=
AdamWeightDecay
(
weight_decay_rate
=
0.01
,
exclude_from_weight_decay
=
[
"
LayerNorm
"
,
"
layer_norm
"
,
"
bias
"
],
exclude_from_weight_decay
=
[
'
LayerNorm
'
,
'
layer_norm
'
,
'
bias
'
],
epsilon
=
1e-6
))
learning_rate
:
optimization
.
LrConfig
=
optimization
.
LrConfig
(
type
=
"
polynomial
"
,
type
=
'
polynomial
'
,
polynomial
=
PolynomialLr
(
initial_learning_rate
=
1e-4
,
decay_steps
=
1000000
,
end_learning_rate
=
0.0
))
warmup
:
optimization
.
WarmupConfig
=
optimization
.
WarmupConfig
(
type
=
"polynomial"
,
polynomial
=
PolynomialWarmupConfig
(
warmup_steps
=
10000
))
type
=
'polynomial'
,
polynomial
=
PolynomialWarmupConfig
(
warmup_steps
=
10000
))
@
exp_factory
.
register_config_factory
(
'longformer/pretraining'
)
def
longformer_pretraining
()
->
cfg
.
ExperimentConfig
:
...
...
@@ -62,11 +64,14 @@ def longformer_pretraining() -> cfg.ExperimentConfig:
type
=
"any"
,
any
=
LongformerEncoderConfig
()),
cls_heads
=
[
bert
.
ClsHeadConfig
(
inner_dim
=
768
,
num_classes
=
2
,
dropout_rate
=
0.1
,
name
=
'next_sentence'
)
inner_dim
=
768
,
num_classes
=
2
,
dropout_rate
=
0.1
,
name
=
'next_sentence'
)
]
),
train_data
=
pretrain_dataloader
.
BertPretrainDataConfig
(
use_v2_feature_names
=
True
),
validation_data
=
pretrain_dataloader
.
BertPretrainDataConfig
(
use_v2_feature_names
=
True
,
train_data
=
pretrain_dataloader
.
BertPretrainDataConfig
(
use_v2_feature_names
=
True
),
validation_data
=
pretrain_dataloader
.
BertPretrainDataConfig
(
use_v2_feature_names
=
True
,
is_training
=
False
)),
trainer
=
cfg
.
TrainerConfig
(
optimizer_config
=
LongformerOptimizationConfig
(),
train_steps
=
1000000
),
...
...
@@ -76,6 +81,7 @@ def longformer_pretraining() -> cfg.ExperimentConfig:
])
return
config
@
exp_factory
.
register_config_factory
(
'longformer/glue'
)
def
longformer_glue
()
->
cfg
.
ExperimentConfig
:
config
=
cfg
.
ExperimentConfig
(
...
...
official/projects/longformer/train.py
View file @
09e6e71c
...
...
@@ -24,7 +24,6 @@ from official.core import task_factory
from
official.core
import
train_lib
from
official.core
import
train_utils
from
official.modeling
import
performance
from
official.projects.longformer
import
longformer_experiments
FLAGS
=
flags
.
FLAGS
...
...
@@ -43,7 +42,8 @@ def main(_):
# GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
# dtype is float16
if
params
.
runtime
.
mixed_precision_dtype
:
performance
.
set_mixed_precision_policy
(
params
.
runtime
.
mixed_precision_dtype
)
performance
.
set_mixed_precision_policy
(
params
.
runtime
.
mixed_precision_dtype
)
distribution_strategy
=
distribute_utils
.
get_distribution_strategy
(
distribution_strategy
=
params
.
runtime
.
distribution_strategy
,
all_reduce_alg
=
params
.
runtime
.
all_reduce_alg
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment